Implement Cain with scdaling incorporated;

Removing scaling from ScaledLinear, ScaledConv1d, etc.
This commit is contained in:
Daniel Povey 2022-05-20 13:36:01 +08:00
parent 8fd9e64fdf
commit abe5abb688
2 changed files with 123 additions and 354 deletions

View File

@ -22,6 +22,7 @@ from typing import Optional, Tuple
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import Embedding as ScaledEmbedding
def _ntuple(n):
@ -154,10 +155,8 @@ class BasicNorm(torch.nn.Module):
class ScaledLinear(nn.Linear):
"""
A modified version of nn.Linear where the parameters are scaled before
use, via:
weight = self.weight * self.weight_scale.exp()
bias = self.bias * self.bias_scale.exp()
A modified version of nn.Linear that gives an easy way to set the
default initial parameter scale.
Args:
Accepts the standard args and kwargs that nn.Linear accepts
@ -168,59 +167,26 @@ class ScaledLinear(nn.Linear):
(affects the initialization of weight_scale and bias_scale).
Another option, if you want to do something like this, is
to re-initialize the parameters.
initial_speed: this affects how fast the parameter will
learn near the start of training; you can set it to a
value less than one if you suspect that a module
is contributing to instability near the start of training.
Nnote: regardless of the use of this option, it's best to
use schedulers like Noam that have a warm-up period.
Alternatively you can set it to more than 1 if you want it to
initially train faster. Must be greater than 0.
"""
def __init__(
self,
*args,
initial_scale: float = 1.0,
initial_speed: float = 1.0,
**kwargs
):
super(ScaledLinear, self).__init__(*args, **kwargs)
initial_scale = torch.tensor(initial_scale).log()
self.weight_scale = nn.Parameter(initial_scale.clone().detach())
if self.bias is not None:
self.bias_scale = nn.Parameter(initial_scale.clone().detach())
else:
self.register_parameter("bias_scale", None)
self._reset_parameters(
initial_speed
) # Overrides the reset_parameters in nn.Linear
def _reset_parameters(self, initial_speed: float):
std = 0.1 / initial_speed
a = (3 ** 0.5) * std
nn.init.uniform_(self.weight, -a, a)
if self.bias is not None:
nn.init.constant_(self.bias, 0.0)
fan_in = self.weight.shape[1] * self.weight[0][0].numel()
scale = fan_in ** -0.5 # 1/sqrt(fan_in)
with torch.no_grad():
self.weight_scale += torch.tensor(scale / std).log()
self.weight[:] *= initial_scale
if self.bias is not None:
self.bias[:] *= initial_scale
def get_weight(self):
return self.weight * self.weight_scale.exp()
def get_weight(self): # not needed any more but kept for back compatibility
return self.weight
def get_bias(self):
if self.bias is None or self.bias_scale is None:
return None
return self.bias
return self.bias * self.bias_scale.exp()
def forward(self, input: Tensor) -> Tensor:
return torch.nn.functional.linear(
input, self.get_weight(), self.get_bias()
)
class ScaledConv1d(nn.Conv1d):
@ -229,66 +195,20 @@ class ScaledConv1d(nn.Conv1d):
self,
*args,
initial_scale: float = 1.0,
initial_speed: float = 1.0,
**kwargs
):
super(ScaledConv1d, self).__init__(*args, **kwargs)
initial_scale = torch.tensor(initial_scale).log()
self.weight_scale = nn.Parameter(initial_scale.clone().detach())
if self.bias is not None:
self.bias_scale = nn.Parameter(initial_scale.clone().detach())
else:
self.register_parameter("bias_scale", None)
self._reset_parameters(
initial_speed
) # Overrides the reset_parameters in base class
def _reset_parameters(self, initial_speed: float):
std = 0.1 / initial_speed
a = (3 ** 0.5) * std
nn.init.uniform_(self.weight, -a, a)
if self.bias is not None:
nn.init.constant_(self.bias, 0.0)
fan_in = self.weight.shape[1] * self.weight[0][0].numel()
scale = fan_in ** -0.5 # 1/sqrt(fan_in)
with torch.no_grad():
self.weight_scale += torch.tensor(scale / std).log()
self.weight[:] *= initial_scale
if self.bias is not None:
self.bias[:] *= initial_scale
def get_weight(self):
return self.weight * self.weight_scale.exp()
def get_bias(self):
bias = self.bias
bias_scale = self.bias_scale
if bias is None or bias_scale is None:
return None
return bias * bias_scale.exp()
def get_weight(self): # TODO: delete
return self.weight
def forward(self, input: Tensor) -> Tensor:
F = torch.nn.functional
if self.padding_mode != "zeros":
return F.conv1d(
F.pad(
input,
self._reversed_padding_repeated_twice,
mode=self.padding_mode,
),
self.get_weight(),
self.get_bias(),
self.stride,
(0,),
self.dilation,
self.groups,
)
return F.conv1d(
input,
self.get_weight(),
self.get_bias(),
self.stride,
self.padding,
self.dilation,
self.groups,
)
def get_bias(self): # TODO: delete
return self.bias
class ScaledConv2d(nn.Conv2d):
@ -297,70 +217,20 @@ class ScaledConv2d(nn.Conv2d):
self,
*args,
initial_scale: float = 1.0,
initial_speed: float = 1.0,
**kwargs
):
super(ScaledConv2d, self).__init__(*args, **kwargs)
initial_scale = torch.tensor(initial_scale).log()
self.weight_scale = nn.Parameter(initial_scale.clone().detach())
if self.bias is not None:
self.bias_scale = nn.Parameter(initial_scale.clone().detach())
else:
self.register_parameter("bias_scale", None)
self._reset_parameters(
initial_speed
) # Overrides the reset_parameters in base class
def _reset_parameters(self, initial_speed: float):
std = 0.1 / initial_speed
a = (3 ** 0.5) * std
nn.init.uniform_(self.weight, -a, a)
if self.bias is not None:
nn.init.constant_(self.bias, 0.0)
fan_in = self.weight.shape[1] * self.weight[0][0].numel()
scale = fan_in ** -0.5 # 1/sqrt(fan_in)
with torch.no_grad():
self.weight_scale += torch.tensor(scale / std).log()
self.weight[:] *= initial_scale
if self.bias is not None:
self.bias[:] *= initial_scale
def get_weight(self):
return self.weight * self.weight_scale.exp()
return self.weight
def get_bias(self):
# see https://github.com/pytorch/pytorch/issues/24135
bias = self.bias
bias_scale = self.bias_scale
if bias is None or bias_scale is None:
return None
return bias * bias_scale.exp()
return self.bias
def _conv_forward(self, input, weight):
F = torch.nn.functional
if self.padding_mode != "zeros":
return F.conv2d(
F.pad(
input,
self._reversed_padding_repeated_twice,
mode=self.padding_mode,
),
weight,
self.get_bias(),
self.stride,
(0, 0),
self.dilation,
self.groups,
)
return F.conv2d(
input,
weight,
self.get_bias(),
self.stride,
self.padding,
self.dilation,
self.groups,
)
def forward(self, input: Tensor) -> Tensor:
return self._conv_forward(input, self.get_weight())
class ActivationBalancer(torch.nn.Module):
@ -464,179 +334,6 @@ class DoubleSwish(torch.nn.Module):
return DoubleSwishFunction.apply(x)
class ScaledEmbedding(nn.Module):
r"""This is a modified version of nn.Embedding that introduces a learnable scale
on the parameters. Note: due to how we initialize it, it's best used with
schedulers like Noam that have a warmup period.
It is a simple lookup table that stores embeddings of a fixed dictionary and size.
This module is often used to store word embeddings and retrieve them using indices.
The input to the module is a list of indices, and the output is the corresponding
word embeddings.
Args:
num_embeddings (int): size of the dictionary of embeddings
embedding_dim (int): the size of each embedding vector
padding_idx (int, optional): If given, pads the output with the embedding vector at :attr:`padding_idx`
(initialized to zeros) whenever it encounters the index.
max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
is renormalized to have norm :attr:`max_norm`.
norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of
the words in the mini-batch. Default ``False``.
sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor.
See Notes for more details regarding sparse gradients.
initial_speed (float, optional): This affects how fast the parameter will
learn near the start of training; you can set it to a value less than
one if you suspect that a module is contributing to instability near
the start of training. Nnote: regardless of the use of this option,
it's best to use schedulers like Noam that have a warm-up period.
Alternatively you can set it to more than 1 if you want it to
initially train faster. Must be greater than 0.
Attributes:
weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim)
initialized from :math:`\mathcal{N}(0, 1)`
Shape:
- Input: :math:`(*)`, LongTensor of arbitrary shape containing the indices to extract
- Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}`
.. note::
Keep in mind that only a limited number of optimizers support
sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`),
:class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad` (`CPU`)
.. note::
With :attr:`padding_idx` set, the embedding vector at
:attr:`padding_idx` is initialized to all zeros. However, note that this
vector can be modified afterwards, e.g., using a customized
initialization method, and thus changing the vector used to pad the
output. The gradient for this vector from :class:`~torch.nn.Embedding`
is always zero.
Examples::
>>> # an Embedding module containing 10 tensors of size 3
>>> embedding = nn.Embedding(10, 3)
>>> # a batch of 2 samples of 4 indices each
>>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
>>> embedding(input)
tensor([[[-0.0251, -1.6902, 0.7172],
[-0.6431, 0.0748, 0.6969],
[ 1.4970, 1.3448, -0.9685],
[-0.3677, -2.7265, -0.1685]],
[[ 1.4970, 1.3448, -0.9685],
[ 0.4362, -0.4004, 0.9400],
[-0.6431, 0.0748, 0.6969],
[ 0.9124, -2.3616, 1.1151]]])
>>> # example with padding_idx
>>> embedding = nn.Embedding(10, 3, padding_idx=0)
>>> input = torch.LongTensor([[0,2,0,5]])
>>> embedding(input)
tensor([[[ 0.0000, 0.0000, 0.0000],
[ 0.1535, -2.0309, 0.9315],
[ 0.0000, 0.0000, 0.0000],
[-0.1655, 0.9897, 0.0635]]])
"""
__constants__ = [
"num_embeddings",
"embedding_dim",
"padding_idx",
"scale_grad_by_freq",
"sparse",
]
num_embeddings: int
embedding_dim: int
padding_idx: int
scale_grad_by_freq: bool
weight: Tensor
sparse: bool
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
padding_idx: Optional[int] = None,
scale_grad_by_freq: bool = False,
sparse: bool = False,
initial_speed: float = 1.0,
) -> None:
super(ScaledEmbedding, self).__init__()
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
if padding_idx is not None:
if padding_idx > 0:
assert (
padding_idx < self.num_embeddings
), "Padding_idx must be within num_embeddings"
elif padding_idx < 0:
assert (
padding_idx >= -self.num_embeddings
), "Padding_idx must be within num_embeddings"
padding_idx = self.num_embeddings + padding_idx
self.padding_idx = padding_idx
self.scale_grad_by_freq = scale_grad_by_freq
self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters()
self.sparse = sparse
self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim))
self.reset_parameters(initial_speed)
def reset_parameters(self, initial_speed: float = 1.0) -> None:
std = 0.1 / initial_speed
nn.init.normal_(self.weight, std=std)
nn.init.constant_(self.scale, torch.tensor(1.0 / std).log())
if self.padding_idx is not None:
with torch.no_grad():
self.weight[self.padding_idx].fill_(0)
def forward(self, input: Tensor) -> Tensor:
F = torch.nn.functional
scale = self.scale.exp()
if input.numel() < self.num_embeddings:
return (
F.embedding(
input,
self.weight,
self.padding_idx,
None,
2.0, # None, 2.0 relate to normalization
self.scale_grad_by_freq,
self.sparse,
)
* scale
)
else:
return F.embedding(
input,
self.weight * scale,
self.padding_idx,
None,
2.0, # None, 2.0 relates to normalization
self.scale_grad_by_freq,
self.sparse,
)
def extra_repr(self) -> str:
s = "{num_embeddings}, {embedding_dim}, scale={scale}"
if self.padding_idx is not None:
s += ", padding_idx={padding_idx}"
if self.scale_grad_by_freq is not False:
s += ", scale_grad_by_freq={scale_grad_by_freq}"
if self.sparse is not False:
s += ", sparse=True"
return s.format(**self.__dict__)
def _test_activation_balancer_sign():

View File

@ -205,9 +205,16 @@ class Abel(Optimizer):
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-08).
target_rms (float, optional): target root-mean-square value of
x_norm in factorization (conceptually). Actually this now just becomes
a factor in the learning rate, we may remove it at some point.
parameters, when we normalize them (only conceptually!)..
actually this now just becomes
a factor in the learning rate for non-scalar parameters.
rms_eps (float, optional): epsilon value such that when parameter
rms value goes below this, we stop learning slower for that parameter,
i.e. we treat the rms as if it were rms_eps.
rms_max (float, optional): maximum root-mean-square value for
non-scalar parameters, such that we don't let the parameter
values exceed this; we'll shrink any parameter matrices that become
larger than this.
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
@ -438,6 +445,9 @@ class Cain(Optimizer):
co-ordinates every so often, just in the optimizer, to separate big/small
directions.
This version of Cain absorbs the learned scales of the parameter matrices into
the optimizer.
Eve is a modified version of AdamW with a special
way of setting the weight-decay / shrinkage-factor, which is designed to make the
rms of the parameters approach a particular target_rms (default: 0.1). This is
@ -456,12 +466,17 @@ class Cain(Optimizer):
running averages of gradient and its square (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
weight_decay (float, optional): weight decay coefficient (default: 3e-4;
this value means that the weight would decay significantly after
about 3k minibatches. Is not multiplied by learning rate, but
is conditional on RMS-value of parameter being > target_rms.
target_rms (float, optional): target root-mean-square value of
parameters, if they fall below this we will stop applying weight decay.
rms_eps (float, optional): epsilon value such that when parameter
rms value goes below this, we stop learning slower for that parameter,
i.e. we treat the rms as if it were rms_eps.
rms_max (float, optional): maximum root-mean-square value for
non-scalar parameters, such that we don't let the parameter
values exceed this; we'll shrink any parameter matrices that become
larger than this.
.. _Adam\: A Method for Stochastic Optimization:
@ -478,8 +493,9 @@ class Cain(Optimizer):
lr=1e-3,
betas=(0.9, 0.98),
eps=1e-8,
weight_decay=1e-3,
target_rms=0.1,
rms_eps=1.0e-05,
rms_max=10.0,
):
if not 0.0 <= lr:
@ -494,18 +510,20 @@ class Cain(Optimizer):
raise ValueError(
"Invalid beta parameter at index 1: {}".format(betas[1])
)
if not 0 <= weight_decay <= 0.1:
raise ValueError(
"Invalid weight_decay value: {}".format(weight_decay)
)
if not 0 < target_rms <= 10.0:
raise ValueError("Invalid target_rms value: {}".format(target_rms))
if not 0.1 < rms_max <= 1000.0:
raise ValueError("Invalid rms_max value: {}".format(rms_max))
if not 0.0 < rms_eps < 0.1:
raise ValueError("Invalid rms_eps value: {}".format(rms_eps))
defaults = dict(
lr=lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
target_rms=target_rms,
rms_eps=rms_eps,
rms_max=rms_max,
)
super(Cain, self).__init__(params, defaults)
@ -526,6 +544,12 @@ class Cain(Optimizer):
loss = closure()
for group in self.param_groups:
beta1, beta2 = group["betas"]
lr = group["lr"]
target_rms = group["target_rms"]
rms_eps = group["rms_eps"]
rms_max = group["rms_max"]
for p in group["params"]:
if p.grad is None:
continue
@ -550,14 +574,61 @@ class Cain(Optimizer):
state["exp_avg_sq"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
if p.numel() > 1:
# we keep track of the RMS value of the parameters to
# help set the learning rate... this emulates Eve.
state["param_rms"] = (p**2).mean().sqrt()
# we learn the scale on the parameters in a way
# that also emulates Eve.
state["scale_exp_avg_sq"] = torch.zeros((), device=p.device,
dtype=p.dtype)
state["scale_exp_avg"] = torch.zeros((), device=p.device,
dtype=p.dtype)
step = state["step"]
if p.numel() > 1:
# This part learns the scale on the parameters.
# scale_deriv is the derivative w.r.t. a log-space scale on the parameters, i.e.
# if we imagine: p = underlying_param * scale.exp(), this is the derivative
# w.r.t. scale (it does not actually matter what value `scale` currently has).
scale_deriv = (p * grad).sum()
scale_exp_avg_sq = state["scale_exp_avg_sq"]
scale_exp_avg = state["scale_exp_avg"]
scale_exp_avg_sq.mul_(beta2).addcmul_(scale_deriv, scale_deriv,
value=1 - beta2)
scale_exp_avg.mul_(beta1).add_(scale_deriv, alpha=(1-beta1))
scale_bias_correction2 = 1 - beta2 ** step
scale_denom = (scale_exp_avg_sq.sqrt()).add_(group["eps"])
scale_alpha = -lr*(scale_bias_correction2 ** 0.5)
# scale_delta is going to be the change in log-scale, i.e. if
# p = underlying_param * scale.exp(),
# delta is the change in `scale`.
scale_delta = scale_alpha * (scale_exp_avg / scale_denom)
# the following is equivalent to:
# if torch.all(state["param_rms"] > rms_max):
# delta = -1.0e-03
# which means: if the parameter root-mean-square value goes above the
# specified limit, start to decay the parameters (and ignore the computed
# delta).
scale_delta = torch.where(state["param_rms"] > rms_max,
torch.tensor(-1.0e-03, device=scale_delta.device,
dtype=scale_delta.dtype),
scale_delta)
p.mul_(1 + scale_delta)
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
beta1, beta2 = group["betas"]
# forget bias_correction1. We use normal momentum. exp_avg really
# just stores the moving-average gradient step.
bias_correction2 = 1 - beta2 ** self._step_for_bias_correction(state["step"])
bias_correction2 = 1 - beta2 ** self._step_for_bias_correction(step)
grad = self._change_coordinates(grad, state, forward=True)
@ -566,23 +637,22 @@ class Cain(Optimizer):
denom = (exp_avg_sq.sqrt()).add_(group["eps"])
step_size = group["lr"] # forget bias_correction1
target_rms = group["target_rms"]
weight_decay = group["weight_decay"]
this_delta = grad / denom
this_delta = self._change_coordinates(this_delta, state, forward=False)
# treat/repurpose exp_avg as moving-average of `this_delta`
exp_avg.mul_(beta1).add_(this_delta, alpha=-step_size*(1-beta1)*(bias_correction2 ** 0.5))
alpha = -lr*(1-beta1)*(bias_correction2 ** 0.5)
if p.numel() > 1:
# avoid applying this weight-decay on "scaling factors"
# (which are scalar).
is_above_target_rms = p.norm() > (
target_rms * (p.numel() ** 0.5)
)
p.mul_(1 - (weight_decay * is_above_target_rms))
if step % 10 == 0:
state["param_rms"].fill_((p**2).mean().sqrt())
# imagine param was normalized to rms=target_rms and we stored the
# scale as a separate scalar. This is how fast we'd be learning.
alpha = (alpha / target_rms) * state["param_rms"].clamp(min=rms_eps)
# treat/repurpose exp_avg as moving-average of `this_delta`
# don't use alpha=alpha as I don't think
exp_avg.mul_(beta1).add_(this_delta, alpha=alpha)
p.add_(exp_avg)
state["step"] += 1
@ -1059,7 +1129,7 @@ def _test_eve_cain():
input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
for iter in [0,1]:
for iter in [1,0]:
fix_random_seed(42)
Linear = torch.nn.Linear if iter == 0 else ScaledLinear
m = torch.nn.Sequential(Linear(E, 200),
@ -1071,7 +1141,7 @@ def _test_eve_cain():
if iter == 0: optim = Eve(m.parameters(), lr=0.003)
else: optim = Cain(m.parameters(), lr=0.003)
scheduler = Eden(optim, lr_batches=300, lr_epochs=20, verbose=False)
scheduler = Eden(optim, lr_batches=200, lr_epochs=10, verbose=False)
start = timeit.default_timer()
for epoch in range(150):
@ -1121,5 +1191,7 @@ def _test_eve_cain():
if __name__ == "__main__":
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
_test_eve_cain()
#_test_eden()