Mirror of https://github.com/k2-fsa/icefall.git
Implement Cain with scaling incorporated; removing scaling from ScaledLinear, ScaledConv1d, etc.
parent 8fd9e64fdf
commit abe5abb688
@@ -22,6 +22,7 @@ from typing import Optional, Tuple
 import torch
 import torch.nn as nn
 from torch import Tensor
+from torch.nn import Embedding as ScaledEmbedding


 def _ntuple(n):
@@ -154,10 +155,8 @@ class BasicNorm(torch.nn.Module):

 class ScaledLinear(nn.Linear):
     """
-    A modified version of nn.Linear where the parameters are scaled before
-    use, via:
-         weight = self.weight * self.weight_scale.exp()
-         bias = self.bias * self.bias_scale.exp()
+    A modified version of nn.Linear that gives an easy way to set the
+    default initial parameter scale.

     Args:
         Accepts the standard args and kwargs that nn.Linear accepts
@@ -168,59 +167,26 @@ class ScaledLinear(nn.Linear):
         (affects the initialization of weight_scale and bias_scale).
         Another option, if you want to do something like this, is
         to re-initialize the parameters.
-        initial_speed: this affects how fast the parameter will
-           learn near the start of training; you can set it to a
-           value less than one if you suspect that a module
-           is contributing to instability near the start of training.
-           Nnote: regardless of the use of this option, it's best to
-           use schedulers like Noam that have a warm-up period.
-           Alternatively you can set it to more than 1 if you want it to
-           initially train faster.  Must be greater than 0.
     """

     def __init__(
         self,
         *args,
         initial_scale: float = 1.0,
-        initial_speed: float = 1.0,
         **kwargs
     ):
         super(ScaledLinear, self).__init__(*args, **kwargs)
-        initial_scale = torch.tensor(initial_scale).log()
-        self.weight_scale = nn.Parameter(initial_scale.clone().detach())
-        if self.bias is not None:
-            self.bias_scale = nn.Parameter(initial_scale.clone().detach())
-        else:
-            self.register_parameter("bias_scale", None)
-
-        self._reset_parameters(
-            initial_speed
-        )  # Overrides the reset_parameters in nn.Linear
-
-    def _reset_parameters(self, initial_speed: float):
-        std = 0.1 / initial_speed
-        a = (3 ** 0.5) * std
-        nn.init.uniform_(self.weight, -a, a)
-        if self.bias is not None:
-            nn.init.constant_(self.bias, 0.0)
-        fan_in = self.weight.shape[1] * self.weight[0][0].numel()
-        scale = fan_in ** -0.5  # 1/sqrt(fan_in)
         with torch.no_grad():
-            self.weight_scale += torch.tensor(scale / std).log()
+            self.weight[:] *= initial_scale
+            if self.bias is not None:
+                self.bias[:] *= initial_scale

-    def get_weight(self):
-        return self.weight * self.weight_scale.exp()
+    def get_weight(self):  # not needed any more but kept for back compatibility
+        return self.weight

     def get_bias(self):
-        if self.bias is None or self.bias_scale is None:
-            return None
-        return self.bias * self.bias_scale.exp()
-
-    def forward(self, input: Tensor) -> Tensor:
-        return torch.nn.functional.linear(
-            input, self.get_weight(), self.get_bias()
-        )
+        return self.bias


 class ScaledConv1d(nn.Conv1d):
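To make the effect of this hunk concrete, here is a minimal, runnable sketch of what ScaledLinear reduces to after the change (the class name ScaledLinearSketch and the demo at the bottom are illustrative, not part of the commit): initial_scale is applied once to the freshly initialized parameters, get_weight()/get_bias() simply return them, and forward() falls back to nn.Linear's own implementation. Any rescaling during training is now meant to happen inside the optimizer.

import torch
import torch.nn as nn


class ScaledLinearSketch(nn.Linear):
    """Sketch of the post-commit ScaledLinear: an nn.Linear whose freshly
    initialized parameters are multiplied once by `initial_scale`."""

    def __init__(self, *args, initial_scale: float = 1.0, **kwargs):
        super().__init__(*args, **kwargs)
        with torch.no_grad():
            self.weight[:] *= initial_scale
            if self.bias is not None:
                self.bias[:] *= initial_scale

    def get_weight(self):  # kept only for backward compatibility
        return self.weight

    def get_bias(self):
        return self.bias


if __name__ == "__main__":
    torch.manual_seed(0)
    m = ScaledLinearSketch(4, 8, initial_scale=0.25)
    x = torch.randn(2, 4)
    print(m(x).shape)  # forward() is just nn.Linear's forward now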
@@ -229,66 +195,20 @@ class ScaledConv1d(nn.Conv1d):
         self,
         *args,
         initial_scale: float = 1.0,
-        initial_speed: float = 1.0,
         **kwargs
     ):
         super(ScaledConv1d, self).__init__(*args, **kwargs)
-        initial_scale = torch.tensor(initial_scale).log()
-        self.weight_scale = nn.Parameter(initial_scale.clone().detach())
-        if self.bias is not None:
-            self.bias_scale = nn.Parameter(initial_scale.clone().detach())
-        else:
-            self.register_parameter("bias_scale", None)
-        self._reset_parameters(
-            initial_speed
-        )  # Overrides the reset_parameters in base class
-
-    def _reset_parameters(self, initial_speed: float):
-        std = 0.1 / initial_speed
-        a = (3 ** 0.5) * std
-        nn.init.uniform_(self.weight, -a, a)
-        if self.bias is not None:
-            nn.init.constant_(self.bias, 0.0)
-        fan_in = self.weight.shape[1] * self.weight[0][0].numel()
-        scale = fan_in ** -0.5  # 1/sqrt(fan_in)
         with torch.no_grad():
-            self.weight_scale += torch.tensor(scale / std).log()
+            self.weight[:] *= initial_scale
+            if self.bias is not None:
+                self.bias[:] *= initial_scale

-    def get_weight(self):
-        return self.weight * self.weight_scale.exp()
-
-    def get_bias(self):
-        bias = self.bias
-        bias_scale = self.bias_scale
-        if bias is None or bias_scale is None:
-            return None
-        return bias * bias_scale.exp()
+    def get_weight(self):  # TODO: delete
+        return self.weight

-    def forward(self, input: Tensor) -> Tensor:
-        F = torch.nn.functional
-        if self.padding_mode != "zeros":
-            return F.conv1d(
-                F.pad(
-                    input,
-                    self._reversed_padding_repeated_twice,
-                    mode=self.padding_mode,
-                ),
-                self.get_weight(),
-                self.get_bias(),
-                self.stride,
-                (0,),
-                self.dilation,
-                self.groups,
-            )
-        return F.conv1d(
-            input,
-            self.get_weight(),
-            self.get_bias(),
-            self.stride,
-            self.padding,
-            self.dilation,
-            self.groups,
-        )
+    def get_bias(self):  # TODO: delete
+        return self.bias


 class ScaledConv2d(nn.Conv2d):
@@ -297,70 +217,20 @@ class ScaledConv2d(nn.Conv2d):
         self,
         *args,
         initial_scale: float = 1.0,
-        initial_speed: float = 1.0,
         **kwargs
     ):
         super(ScaledConv2d, self).__init__(*args, **kwargs)
-        initial_scale = torch.tensor(initial_scale).log()
-        self.weight_scale = nn.Parameter(initial_scale.clone().detach())
-        if self.bias is not None:
-            self.bias_scale = nn.Parameter(initial_scale.clone().detach())
-        else:
-            self.register_parameter("bias_scale", None)
-        self._reset_parameters(
-            initial_speed
-        )  # Overrides the reset_parameters in base class
-
-    def _reset_parameters(self, initial_speed: float):
-        std = 0.1 / initial_speed
-        a = (3 ** 0.5) * std
-        nn.init.uniform_(self.weight, -a, a)
-        if self.bias is not None:
-            nn.init.constant_(self.bias, 0.0)
-        fan_in = self.weight.shape[1] * self.weight[0][0].numel()
-        scale = fan_in ** -0.5  # 1/sqrt(fan_in)
         with torch.no_grad():
-            self.weight_scale += torch.tensor(scale / std).log()
+            self.weight[:] *= initial_scale
+            if self.bias is not None:
+                self.bias[:] *= initial_scale

     def get_weight(self):
-        return self.weight * self.weight_scale.exp()
+        return self.weight

     def get_bias(self):
-        # see https://github.com/pytorch/pytorch/issues/24135
-        bias = self.bias
-        bias_scale = self.bias_scale
-        if bias is None or bias_scale is None:
-            return None
-        return bias * bias_scale.exp()
-
-    def _conv_forward(self, input, weight):
-        F = torch.nn.functional
-        if self.padding_mode != "zeros":
-            return F.conv2d(
-                F.pad(
-                    input,
-                    self._reversed_padding_repeated_twice,
-                    mode=self.padding_mode,
-                ),
-                weight,
-                self.get_bias(),
-                self.stride,
-                (0, 0),
-                self.dilation,
-                self.groups,
-            )
-        return F.conv2d(
-            input,
-            weight,
-            self.get_bias(),
-            self.stride,
-            self.padding,
-            self.dilation,
-            self.groups,
-        )
-
-    def forward(self, input: Tensor) -> Tensor:
-        return self._conv_forward(input, self.get_weight())
+        return self.bias


 class ActivationBalancer(torch.nn.Module):
@@ -464,179 +334,6 @@ class DoubleSwish(torch.nn.Module):
         return DoubleSwishFunction.apply(x)


-class ScaledEmbedding(nn.Module):
-    r"""This is a modified version of nn.Embedding that introduces a learnable scale
-    on the parameters.  Note: due to how we initialize it, it's best used with
-    schedulers like Noam that have a warmup period.
-
-    It is a simple lookup table that stores embeddings of a fixed dictionary and size.
-
-    This module is often used to store word embeddings and retrieve them using indices.
-    The input to the module is a list of indices, and the output is the corresponding
-    word embeddings.
-
-    Args:
-        num_embeddings (int): size of the dictionary of embeddings
-        embedding_dim (int): the size of each embedding vector
-        padding_idx (int, optional): If given, pads the output with the embedding vector at :attr:`padding_idx`
-                                         (initialized to zeros) whenever it encounters the index.
-        max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
-                                    is renormalized to have norm :attr:`max_norm`.
-        norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
-        scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of
-                                                the words in the mini-batch. Default ``False``.
-        sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor.
-                                 See Notes for more details regarding sparse gradients.
-
-        initial_speed (float, optional): This affects how fast the parameter will
-           learn near the start of training; you can set it to a value less than
-           one if you suspect that a module is contributing to instability near
-           the start of training.  Nnote: regardless of the use of this option,
-           it's best to use schedulers like Noam that have a warm-up period.
-           Alternatively you can set it to more than 1 if you want it to
-           initially train faster.   Must be greater than 0.
-
-
-    Attributes:
-        weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim)
-                         initialized from :math:`\mathcal{N}(0, 1)`
-
-    Shape:
-        - Input: :math:`(*)`, LongTensor of arbitrary shape containing the indices to extract
-        - Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}`
-
-    .. note::
-        Keep in mind that only a limited number of optimizers support
-        sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`),
-        :class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad` (`CPU`)
-
-    .. note::
-        With :attr:`padding_idx` set, the embedding vector at
-        :attr:`padding_idx` is initialized to all zeros. However, note that this
-        vector can be modified afterwards, e.g., using a customized
-        initialization method, and thus changing the vector used to pad the
-        output. The gradient for this vector from :class:`~torch.nn.Embedding`
-        is always zero.
-
-    Examples::
-
-        >>> # an Embedding module containing 10 tensors of size 3
-        >>> embedding = nn.Embedding(10, 3)
-        >>> # a batch of 2 samples of 4 indices each
-        >>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
-        >>> embedding(input)
-        tensor([[[-0.0251, -1.6902,  0.7172],
-                 [-0.6431,  0.0748,  0.6969],
-                 [ 1.4970,  1.3448, -0.9685],
-                 [-0.3677, -2.7265, -0.1685]],
-
-                [[ 1.4970,  1.3448, -0.9685],
-                 [ 0.4362, -0.4004,  0.9400],
-                 [-0.6431,  0.0748,  0.6969],
-                 [ 0.9124, -2.3616,  1.1151]]])
-
-
-        >>> # example with padding_idx
-        >>> embedding = nn.Embedding(10, 3, padding_idx=0)
-        >>> input = torch.LongTensor([[0,2,0,5]])
-        >>> embedding(input)
-        tensor([[[ 0.0000,  0.0000,  0.0000],
-                 [ 0.1535, -2.0309,  0.9315],
-                 [ 0.0000,  0.0000,  0.0000],
-                 [-0.1655,  0.9897,  0.0635]]])
-
-    """
-    __constants__ = [
-        "num_embeddings",
-        "embedding_dim",
-        "padding_idx",
-        "scale_grad_by_freq",
-        "sparse",
-    ]
-
-    num_embeddings: int
-    embedding_dim: int
-    padding_idx: int
-    scale_grad_by_freq: bool
-    weight: Tensor
-    sparse: bool
-
-    def __init__(
-        self,
-        num_embeddings: int,
-        embedding_dim: int,
-        padding_idx: Optional[int] = None,
-        scale_grad_by_freq: bool = False,
-        sparse: bool = False,
-        initial_speed: float = 1.0,
-    ) -> None:
-        super(ScaledEmbedding, self).__init__()
-        self.num_embeddings = num_embeddings
-        self.embedding_dim = embedding_dim
-        if padding_idx is not None:
-            if padding_idx > 0:
-                assert (
-                    padding_idx < self.num_embeddings
-                ), "Padding_idx must be within num_embeddings"
-            elif padding_idx < 0:
-                assert (
-                    padding_idx >= -self.num_embeddings
-                ), "Padding_idx must be within num_embeddings"
-                padding_idx = self.num_embeddings + padding_idx
-        self.padding_idx = padding_idx
-        self.scale_grad_by_freq = scale_grad_by_freq
-
-        self.scale = nn.Parameter(torch.zeros(()))  # see reset_parameters()
-        self.sparse = sparse
-
-        self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim))
-        self.reset_parameters(initial_speed)
-
-    def reset_parameters(self, initial_speed: float = 1.0) -> None:
-        std = 0.1 / initial_speed
-        nn.init.normal_(self.weight, std=std)
-        nn.init.constant_(self.scale, torch.tensor(1.0 / std).log())
-
-        if self.padding_idx is not None:
-            with torch.no_grad():
-                self.weight[self.padding_idx].fill_(0)
-
-    def forward(self, input: Tensor) -> Tensor:
-        F = torch.nn.functional
-        scale = self.scale.exp()
-        if input.numel() < self.num_embeddings:
-            return (
-                F.embedding(
-                    input,
-                    self.weight,
-                    self.padding_idx,
-                    None,
-                    2.0,  # None, 2.0 relate to normalization
-                    self.scale_grad_by_freq,
-                    self.sparse,
-                )
-                * scale
-            )
-        else:
-            return F.embedding(
-                input,
-                self.weight * scale,
-                self.padding_idx,
-                None,
-                2.0,  # None, 2.0 relates to normalization
-                self.scale_grad_by_freq,
-                self.sparse,
-            )
-
-    def extra_repr(self) -> str:
-        s = "{num_embeddings}, {embedding_dim}, scale={scale}"
-        if self.padding_idx is not None:
-            s += ", padding_idx={padding_idx}"
-        if self.scale_grad_by_freq is not False:
-            s += ", scale_grad_by_freq={scale_grad_by_freq}"
-        if self.sparse is not False:
-            s += ", sparse=True"
-        return s.format(**self.__dict__)


 def _test_activation_balancer_sign():
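Taken together with the `from torch.nn import Embedding as ScaledEmbedding` line added in the first hunk, deleting this class means existing call sites keep working but now construct a plain nn.Embedding: the learnable log-scale and the initial_speed option are gone, consistent with moving all scale learning into the optimizer. A small usage sketch under that assumption (illustrative, not part of the commit):

import torch
from torch.nn import Embedding as ScaledEmbedding  # the alias added in the first hunk

# Call sites that used to build the custom ScaledEmbedding still work, but they
# now get an ordinary nn.Embedding: default N(0, 1) initialization and no
# separate learnable scale.
emb = ScaledEmbedding(10, 3, padding_idx=0)
tokens = torch.tensor([[0, 2, 0, 5]])
print(emb(tokens).shape)  # torch.Size([1, 4, 3])

The remaining hunks appear to come from the optimizer source file, the one defining Abel, Cain, Eve and Eden.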
@@ -205,9 +205,16 @@ class Abel(Optimizer):
         eps (float, optional): term added to the denominator to improve
             numerical stability (default: 1e-08).
         target_rms (float, optional): target root-mean-square value of
-           x_norm in factorization (conceptually).  Actually this now just becomes
-           a factor in the learning rate, we may remove it at some point.
+           parameters, when we normalize them (only conceptually!)..
+           actually this now just becomes
+           a factor in the learning rate for non-scalar parameters.
+        rms_eps (float, optional): epsilon value such that when parameter
+          rms value goes below this, we stop learning slower for that parameter,
+          i.e. we treat the rms as if it were rms_eps.
+        rms_max (float, optional): maximum root-mean-square value for
+          non-scalar parameters, such that we don't let the parameter
+          values exceed this; we'll shrink any parameter matrices that become
+          larger than this.

     .. _Adam\: A Method for Stochastic Optimization:
         https://arxiv.org/abs/1412.6980
@@ -438,6 +445,9 @@ class Cain(Optimizer):
     co-ordinates every so often, just in the optimizer, to separate big/small
     directions.

+    This version of Cain absorbs the learned scales of the parameter matrices into
+    the optimizer.
+
     Eve is a modified version of AdamW with a special
     way of setting the weight-decay / shrinkage-factor, which is designed to make the
     rms of the parameters approach a particular target_rms (default: 0.1).  This is
@@ -456,12 +466,17 @@ class Cain(Optimizer):
             running averages of gradient and its square (default: (0.9, 0.999))
         eps (float, optional): term added to the denominator to improve
             numerical stability (default: 1e-8)
-        weight_decay (float, optional): weight decay coefficient (default: 3e-4;
-            this value means that the weight would decay significantly after
-            about 3k minibatches.  Is not multiplied by learning rate, but
-            is conditional on RMS-value of parameter being > target_rms.
         target_rms (float, optional): target root-mean-square value of
             parameters, if they fall below this we will stop applying weight decay.
+        rms_eps (float, optional): epsilon value such that when parameter
+          rms value goes below this, we stop learning slower for that parameter,
+          i.e. we treat the rms as if it were rms_eps.
+        rms_max (float, optional): maximum root-mean-square value for
+          non-scalar parameters, such that we don't let the parameter
+          values exceed this; we'll shrink any parameter matrices that become
+          larger than this.


     .. _Adam\: A Method for Stochastic Optimization:
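The rms_max entry added to the docstring above maps onto concrete behaviour later in this patch: once a tensor's RMS grows past the limit, the optimizer ignores the proposed scale update and shrinks the tensor slightly instead (rms_eps, the floor on the RMS used for the step size, is illustrated after the step-size hunk further below). A hedged summary of the rms_max behaviour in code (illustrative helper, not the commit's code; the real logic sits inside Cain.step()):

import torch

def clamp_scale_step(param_rms: torch.Tensor,
                     proposed_log_scale_step: torch.Tensor,
                     rms_max: float = 10.0) -> torch.Tensor:
    """If the parameter's RMS exceeds rms_max, drop the proposed log-scale update
    and shrink the tensor slightly instead (the patch uses -1.0e-03)."""
    shrink = torch.full_like(proposed_log_scale_step, -1.0e-03)
    return torch.where(param_rms > rms_max, shrink, proposed_log_scale_step)

rms = torch.tensor(12.0)                          # above the default rms_max of 10.0
step = torch.tensor(0.05)                         # whatever the Adam-style update proposed
print(clamp_scale_step(rms, step))                # tensor(-0.0010): decay instead of grow
print(clamp_scale_step(torch.tensor(0.5), step))  # tensor(0.0500): normal update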
@@ -478,8 +493,9 @@ class Cain(Optimizer):
         lr=1e-3,
         betas=(0.9, 0.98),
         eps=1e-8,
-        weight_decay=1e-3,
         target_rms=0.1,
+        rms_eps=1.0e-05,
+        rms_max=10.0,
     ):

         if not 0.0 <= lr:
@@ -494,18 +510,20 @@ class Cain(Optimizer):
             raise ValueError(
                 "Invalid beta parameter at index 1: {}".format(betas[1])
             )
-        if not 0 <= weight_decay <= 0.1:
-            raise ValueError(
-                "Invalid weight_decay value: {}".format(weight_decay)
-            )
         if not 0 < target_rms <= 10.0:
             raise ValueError("Invalid target_rms value: {}".format(target_rms))
+        if not 0.1 < rms_max <= 1000.0:
+            raise ValueError("Invalid rms_max value: {}".format(rms_max))
+        if not 0.0 < rms_eps < 0.1:
+            raise ValueError("Invalid rms_eps value: {}".format(rms_eps))

         defaults = dict(
             lr=lr,
             betas=betas,
             eps=eps,
-            weight_decay=weight_decay,
             target_rms=target_rms,
+            rms_eps=rms_eps,
+            rms_max=rms_max,
         )
         super(Cain, self).__init__(params, defaults)

@@ -526,6 +544,12 @@ class Cain(Optimizer):
                 loss = closure()

         for group in self.param_groups:
+            beta1, beta2 = group["betas"]
+            lr = group["lr"]
+            target_rms = group["target_rms"]
+            rms_eps = group["rms_eps"]
+            rms_max = group["rms_max"]
+
             for p in group["params"]:
                 if p.grad is None:
                     continue
@@ -550,14 +574,61 @@ class Cain(Optimizer):
                     state["exp_avg_sq"] = torch.zeros_like(
                         p, memory_format=torch.preserve_format
                     )
+                    if p.numel() > 1:
+                        # we keep track of the RMS value of the parameters to
+                        # help set the learning rate... this emulates Eve.
+                        state["param_rms"] = (p**2).mean().sqrt()
+                        # we learn the scale on the parameters in a way
+                        # that also emulates Eve.
+                        state["scale_exp_avg_sq"] = torch.zeros((), device=p.device,
+                                                                dtype=p.dtype)
+                        state["scale_exp_avg"] = torch.zeros((), device=p.device,
+                                                             dtype=p.dtype)
+
+
+                step = state["step"]
+
+                if p.numel() > 1:
+                    # This part learns the scale on the parameters.
+                    # scale_deriv is the derivative w.r.t. a log-space scale on the parameters, i.e.
+                    # if we imagine: p = underlying_param * scale.exp(), this is the derivative
+                    # w.r.t. scale (it does not actually matter what value `scale` currently has).
+                    scale_deriv = (p * grad).sum()
+                    scale_exp_avg_sq = state["scale_exp_avg_sq"]
+                    scale_exp_avg = state["scale_exp_avg"]
+                    scale_exp_avg_sq.mul_(beta2).addcmul_(scale_deriv, scale_deriv,
+                                                          value=1 - beta2)
+
+                    scale_exp_avg.mul_(beta1).add_(scale_deriv, alpha=(1-beta1))
+                    scale_bias_correction2 = 1 - beta2 ** step
+
+                    scale_denom = (scale_exp_avg_sq.sqrt()).add_(group["eps"])
+
+                    scale_alpha = -lr*(scale_bias_correction2 ** 0.5)
+
+                    # scale_delta is going to be the change in log-scale, i.e. if
+                    # p = underlying_param * scale.exp(),
+                    # delta is the change in `scale`.
+                    scale_delta = scale_alpha * (scale_exp_avg / scale_denom)
+
+                    # the following is equivalent to:
+                    #    if torch.all(state["param_rms"] > rms_max):
+                    #        delta = -1.0e-03
+                    # which means: if the parameter root-mean-square value goes above the
+                    # specified limit, start to decay the parameters (and ignore the computed
+                    # delta).
+                    scale_delta = torch.where(state["param_rms"] > rms_max,
+                                              torch.tensor(-1.0e-03, device=scale_delta.device,
+                                                           dtype=scale_delta.dtype),
+                                              scale_delta)
+
+                    p.mul_(1 + scale_delta)
+
                 exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]

-                beta1, beta2 = group["betas"]
-
                 # forget bias_correction1.  We use normal momentum.  exp_avg really
                 # just stores the moving-average gradient step.
-                bias_correction2 = 1 - beta2 ** self._step_for_bias_correction(state["step"])
+                bias_correction2 = 1 - beta2 ** self._step_for_bias_correction(step)

                 grad = self._change_coordinates(grad, state, forward=True)

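The block added above is the core of "Cain with scaling incorporated": rather than each module carrying a learnable log-scale, the optimizer computes the derivative of the loss with respect to an implicit log-scale of every non-scalar tensor, scale_deriv = (p * grad).sum(), runs an Adam-style update on that single scalar, and applies the result multiplicatively with p.mul_(1 + scale_delta). The toy check below (illustrative, not part of the patch) verifies the identity the comment relies on: if p = underlying_param * scale.exp(), then d(loss)/d(scale) equals (p * d(loss)/dp).sum().

import torch

torch.manual_seed(0)

# Toy check of the identity used above: if p = u * exp(s), then
# d(loss)/d(s) == (p * d(loss)/d(p)).sum().
u = torch.randn(3, 4)
s = torch.zeros((), requires_grad=True)   # the implicit log-scale
p = u * s.exp()
loss = (p ** 3).sum()                     # any differentiable function of p
loss.backward()                           # autograd's d(loss)/d(s) lands in s.grad

p_value = u * s.exp().detach()            # current parameter value
grad_wrt_p = 3 * p_value ** 2             # hand-computed d(loss)/d(p) for this toy loss
scale_deriv = (p_value * grad_wrt_p).sum()

print(torch.allclose(s.grad, scale_deriv))  # True

Applying the step as p.mul_(1 + scale_delta) rather than p.mul_(scale_delta.exp()) is a first-order approximation that is cheap and accurate for the small per-step deltas involved.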
@@ -566,23 +637,22 @@ class Cain(Optimizer):
                 denom = (exp_avg_sq.sqrt()).add_(group["eps"])

-                step_size = group["lr"]  # forget bias_correction1
-                target_rms = group["target_rms"]
-                weight_decay = group["weight_decay"]

                 this_delta = grad / denom
                 this_delta = self._change_coordinates(this_delta, state, forward=False)

-                # treat/repurpose exp_avg as moving-average of `this_delta`
-                exp_avg.mul_(beta1).add_(this_delta, alpha=-step_size*(1-beta1)*(bias_correction2 ** 0.5))
-
+                alpha = -lr*(1-beta1)*(bias_correction2 ** 0.5)
                 if p.numel() > 1:
-                    # avoid applying this weight-decay on "scaling factors"
-                    # (which are scalar).
-                    is_above_target_rms = p.norm() > (
-                        target_rms * (p.numel() ** 0.5)
-                    )
-                    p.mul_(1 - (weight_decay * is_above_target_rms))
+                    if step % 10 == 0:
+                        state["param_rms"].fill_((p**2).mean().sqrt())
+                    # imagine param was normalized to rms=target_rms and we stored the
+                    # scale as a separate scalar.  This is how fast we'd be learning.
+                    alpha = (alpha / target_rms) * state["param_rms"].clamp(min=rms_eps)
+
+                # treat/repurpose exp_avg as moving-average of `this_delta`
+                # don't use alpha=alpha as I don't think
+                exp_avg.mul_(beta1).add_(this_delta, alpha=alpha)

                 p.add_(exp_avg)
                 state["step"] += 1
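In the rewritten step above, the old weight_decay branch is gone and the step size alpha is instead scaled Eve-style: the cached RMS of the tensor (refreshed every 10 steps) is clamped from below at rms_eps, divided by target_rms, and folded into alpha, so a tensor whose RMS sits at target_rms trains at the nominal learning rate while very small tensors do not stall completely. A minimal sketch of just that scaling rule (illustrative helper names, not the commit's code):

import torch

def eve_like_alpha(base_alpha: float,
                   param: torch.Tensor,
                   target_rms: float = 0.1,
                   rms_eps: float = 1.0e-05) -> torch.Tensor:
    """Scale the (negative) step size by the parameter's RMS, as the updated
    Cain step does for non-scalar tensors; illustrative helper, not project code."""
    param_rms = (param ** 2).mean().sqrt()
    return (base_alpha / target_rms) * param_rms.clamp(min=rms_eps)

w = 0.1 * torch.randn(200, 100)          # a tensor whose RMS is roughly target_rms
print(float(eve_like_alpha(-0.003, w)))  # close to -0.003
tiny = torch.zeros(10, 10)               # RMS below rms_eps is clamped, not zeroed
print(float(eve_like_alpha(-0.003, tiny)))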
@@ -1059,7 +1129,7 @@ def _test_eve_cain():
     input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
     output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()

-    for iter in [0,1]:
+    for iter in [1,0]:
         fix_random_seed(42)
         Linear = torch.nn.Linear if iter == 0 else ScaledLinear
         m = torch.nn.Sequential(Linear(E, 200),
@@ -1071,7 +1141,7 @@ def _test_eve_cain():

     if iter == 0: optim = Eve(m.parameters(), lr=0.003)
     else: optim = Cain(m.parameters(), lr=0.003)
-    scheduler = Eden(optim, lr_batches=300, lr_epochs=20, verbose=False)
+    scheduler = Eden(optim, lr_batches=200, lr_epochs=10, verbose=False)

     start = timeit.default_timer()
     for epoch in range(150):
@@ -1121,5 +1191,7 @@ def _test_eve_cain():


 if __name__ == "__main__":
+    torch.set_num_threads(1)
+    torch.set_num_interop_threads(1)
     _test_eve_cain()
     #_test_eden()