Implement Cain with scaling incorporated;

Removing scaling from ScaledLinear, ScaledConv1d, etc.
This commit is contained in:
Daniel Povey 2022-05-20 13:36:01 +08:00
parent 8fd9e64fdf
commit abe5abb688
2 changed files with 123 additions and 354 deletions

View File

@@ -22,6 +22,7 @@ from typing import Optional, Tuple
import torch
import torch.nn as nn
from torch import Tensor
+ from torch.nn import Embedding as ScaledEmbedding
def _ntuple(n):
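With this change, ScaledEmbedding becomes a plain alias for torch.nn.Embedding; the old ScaledEmbedding class with its learnable log-scale is deleted further down in this file. A minimal sketch of what callers now get (the sizes below are only illustrative):

import torch
from torch.nn import Embedding as ScaledEmbedding  # same alias as in this diff

embed = ScaledEmbedding(500, 256)           # hypothetical vocab of 500, dim 256
tokens = torch.randint(0, 500, (4, 7))      # (batch, sequence) of token ids
vectors = embed(tokens)                     # -> shape (4, 7, 256)
assert isinstance(embed, torch.nn.Embedding)  # no extra scale parameter any more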
@@ -154,10 +155,8 @@ class BasicNorm(torch.nn.Module):
class ScaledLinear(nn.Linear):
"""
- A modified version of nn.Linear where the parameters are scaled before
- use, via:
-     weight = self.weight * self.weight_scale.exp()
-     bias = self.bias * self.bias_scale.exp()
+ A modified version of nn.Linear that gives an easy way to set the
+ default initial parameter scale.
Args:
Accepts the standard args and kwargs that nn.Linear accepts
@@ -168,59 +167,26 @@ class ScaledLinear(nn.Linear):
(affects the initialization of weight_scale and bias_scale).
Another option, if you want to do something like this, is
to re-initialize the parameters.
initial_speed: this affects how fast the parameter will
learn near the start of training; you can set it to a
value less than one if you suspect that a module
is contributing to instability near the start of training.
Nnote: regardless of the use of this option, it's best to
use schedulers like Noam that have a warm-up period.
Alternatively you can set it to more than 1 if you want it to
initially train faster. Must be greater than 0.
""" """
def __init__( def __init__(
self, self,
*args, *args,
initial_scale: float = 1.0, initial_scale: float = 1.0,
initial_speed: float = 1.0,
**kwargs **kwargs
): ):
super(ScaledLinear, self).__init__(*args, **kwargs) super(ScaledLinear, self).__init__(*args, **kwargs)
initial_scale = torch.tensor(initial_scale).log()
self.weight_scale = nn.Parameter(initial_scale.clone().detach())
if self.bias is not None:
self.bias_scale = nn.Parameter(initial_scale.clone().detach())
else:
self.register_parameter("bias_scale", None)
self._reset_parameters(
initial_speed
) # Overrides the reset_parameters in nn.Linear
def _reset_parameters(self, initial_speed: float):
std = 0.1 / initial_speed
a = (3 ** 0.5) * std
nn.init.uniform_(self.weight, -a, a)
if self.bias is not None:
nn.init.constant_(self.bias, 0.0)
fan_in = self.weight.shape[1] * self.weight[0][0].numel()
scale = fan_in ** -0.5 # 1/sqrt(fan_in)
with torch.no_grad():
- self.weight_scale += torch.tensor(scale / std).log()
+ self.weight[:] *= initial_scale
+ if self.bias is not None:
+ self.bias[:] *= initial_scale
- def get_weight(self):
- return self.weight * self.weight_scale.exp()
+ def get_weight(self):  # not needed any more but kept for back compatibility
+ return self.weight
- def get_bias(self):
- if self.bias is None or self.bias_scale is None:
- return None
- return self.bias * self.bias_scale.exp()
+ return self.bias
- def forward(self, input: Tensor) -> Tensor:
- return torch.nn.functional.linear(
- input, self.get_weight(), self.get_bias()
- )
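In other words, ScaledLinear (and the conv variants below) no longer carry learnable weight_scale / bias_scale parameters or a custom forward; the only difference from nn.Linear is that the freshly initialized parameters are multiplied once by initial_scale. A condensed sketch of that pattern, assuming the stock nn.Linear initialization is an acceptable starting point (the class name here is illustrative, not from the diff):

import torch
import torch.nn as nn

class InitScaledLinear(nn.Linear):
    """nn.Linear whose initial parameters are multiplied by `initial_scale`
    once at construction time; there is no learnable scale afterwards."""

    def __init__(self, *args, initial_scale: float = 1.0, **kwargs):
        super().__init__(*args, **kwargs)
        with torch.no_grad():
            self.weight[:] *= initial_scale
            if self.bias is not None:
                self.bias[:] *= initial_scale

# e.g. a projection that should start out small:
proj = InitScaledLinear(256, 256, initial_scale=0.25)
out = proj(torch.randn(8, 256))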
class ScaledConv1d(nn.Conv1d):
@@ -229,66 +195,20 @@ class ScaledConv1d(nn.Conv1d):
self,
*args,
initial_scale: float = 1.0,
- initial_speed: float = 1.0,
**kwargs
):
super(ScaledConv1d, self).__init__(*args, **kwargs)
initial_scale = torch.tensor(initial_scale).log()
self.weight_scale = nn.Parameter(initial_scale.clone().detach())
if self.bias is not None:
self.bias_scale = nn.Parameter(initial_scale.clone().detach())
else:
self.register_parameter("bias_scale", None)
self._reset_parameters(
initial_speed
) # Overrides the reset_parameters in base class
def _reset_parameters(self, initial_speed: float):
std = 0.1 / initial_speed
a = (3 ** 0.5) * std
nn.init.uniform_(self.weight, -a, a)
if self.bias is not None:
nn.init.constant_(self.bias, 0.0)
fan_in = self.weight.shape[1] * self.weight[0][0].numel()
scale = fan_in ** -0.5 # 1/sqrt(fan_in)
with torch.no_grad():
- self.weight_scale += torch.tensor(scale / std).log()
+ self.weight[:] *= initial_scale
+ if self.bias is not None:
+ self.bias[:] *= initial_scale
- def get_weight(self):
- return self.weight * self.weight_scale.exp()
+ def get_weight(self):  # TODO: delete
+ return self.weight
- def get_bias(self):
- bias = self.bias
- bias_scale = self.bias_scale
- if bias is None or bias_scale is None:
- return None
- return bias * bias_scale.exp()
+ def get_bias(self):  # TODO: delete
+ return self.bias
- def forward(self, input: Tensor) -> Tensor:
- F = torch.nn.functional
if self.padding_mode != "zeros":
return F.conv1d(
F.pad(
input,
self._reversed_padding_repeated_twice,
mode=self.padding_mode,
),
self.get_weight(),
self.get_bias(),
self.stride,
(0,),
self.dilation,
self.groups,
)
return F.conv1d(
input,
self.get_weight(),
self.get_bias(),
self.stride,
self.padding,
self.dilation,
self.groups,
)
class ScaledConv2d(nn.Conv2d):
@@ -297,70 +217,20 @@ class ScaledConv2d(nn.Conv2d):
self,
*args,
initial_scale: float = 1.0,
- initial_speed: float = 1.0,
**kwargs
):
super(ScaledConv2d, self).__init__(*args, **kwargs)
initial_scale = torch.tensor(initial_scale).log()
self.weight_scale = nn.Parameter(initial_scale.clone().detach())
if self.bias is not None:
self.bias_scale = nn.Parameter(initial_scale.clone().detach())
else:
self.register_parameter("bias_scale", None)
self._reset_parameters(
initial_speed
) # Overrides the reset_parameters in base class
def _reset_parameters(self, initial_speed: float):
std = 0.1 / initial_speed
a = (3 ** 0.5) * std
nn.init.uniform_(self.weight, -a, a)
if self.bias is not None:
nn.init.constant_(self.bias, 0.0)
fan_in = self.weight.shape[1] * self.weight[0][0].numel()
scale = fan_in ** -0.5 # 1/sqrt(fan_in)
with torch.no_grad():
- self.weight_scale += torch.tensor(scale / std).log()
+ self.weight[:] *= initial_scale
+ if self.bias is not None:
+ self.bias[:] *= initial_scale
def get_weight(self):
- return self.weight * self.weight_scale.exp()
+ return self.weight
def get_bias(self):
- # see https://github.com/pytorch/pytorch/issues/24135
- bias = self.bias
- bias_scale = self.bias_scale
- if bias is None or bias_scale is None:
- return None
- return bias * bias_scale.exp()
+ return self.bias
def _conv_forward(self, input, weight):
F = torch.nn.functional
if self.padding_mode != "zeros":
return F.conv2d(
F.pad(
input,
self._reversed_padding_repeated_twice,
mode=self.padding_mode,
),
weight,
self.get_bias(),
self.stride,
(0, 0),
self.dilation,
self.groups,
)
return F.conv2d(
input,
weight,
self.get_bias(),
self.stride,
self.padding,
self.dilation,
self.groups,
)
def forward(self, input: Tensor) -> Tensor:
return self._conv_forward(input, self.get_weight())
class ActivationBalancer(torch.nn.Module):
@@ -464,179 +334,6 @@ class DoubleSwish(torch.nn.Module):
return DoubleSwishFunction.apply(x)
class ScaledEmbedding(nn.Module):
r"""This is a modified version of nn.Embedding that introduces a learnable scale
on the parameters. Note: due to how we initialize it, it's best used with
schedulers like Noam that have a warmup period.
It is a simple lookup table that stores embeddings of a fixed dictionary and size.
This module is often used to store word embeddings and retrieve them using indices.
The input to the module is a list of indices, and the output is the corresponding
word embeddings.
Args:
num_embeddings (int): size of the dictionary of embeddings
embedding_dim (int): the size of each embedding vector
padding_idx (int, optional): If given, pads the output with the embedding vector at :attr:`padding_idx`
(initialized to zeros) whenever it encounters the index.
max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
is renormalized to have norm :attr:`max_norm`.
norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of
the words in the mini-batch. Default ``False``.
sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor.
See Notes for more details regarding sparse gradients.
initial_speed (float, optional): This affects how fast the parameter will
learn near the start of training; you can set it to a value less than
one if you suspect that a module is contributing to instability near
the start of training. Nnote: regardless of the use of this option,
it's best to use schedulers like Noam that have a warm-up period.
Alternatively you can set it to more than 1 if you want it to
initially train faster. Must be greater than 0.
Attributes:
weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim)
initialized from :math:`\mathcal{N}(0, 1)`
Shape:
- Input: :math:`(*)`, LongTensor of arbitrary shape containing the indices to extract
- Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}`
.. note::
Keep in mind that only a limited number of optimizers support
sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`),
:class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad` (`CPU`)
.. note::
With :attr:`padding_idx` set, the embedding vector at
:attr:`padding_idx` is initialized to all zeros. However, note that this
vector can be modified afterwards, e.g., using a customized
initialization method, and thus changing the vector used to pad the
output. The gradient for this vector from :class:`~torch.nn.Embedding`
is always zero.
Examples::
>>> # an Embedding module containing 10 tensors of size 3
>>> embedding = nn.Embedding(10, 3)
>>> # a batch of 2 samples of 4 indices each
>>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
>>> embedding(input)
tensor([[[-0.0251, -1.6902, 0.7172],
[-0.6431, 0.0748, 0.6969],
[ 1.4970, 1.3448, -0.9685],
[-0.3677, -2.7265, -0.1685]],
[[ 1.4970, 1.3448, -0.9685],
[ 0.4362, -0.4004, 0.9400],
[-0.6431, 0.0748, 0.6969],
[ 0.9124, -2.3616, 1.1151]]])
>>> # example with padding_idx
>>> embedding = nn.Embedding(10, 3, padding_idx=0)
>>> input = torch.LongTensor([[0,2,0,5]])
>>> embedding(input)
tensor([[[ 0.0000, 0.0000, 0.0000],
[ 0.1535, -2.0309, 0.9315],
[ 0.0000, 0.0000, 0.0000],
[-0.1655, 0.9897, 0.0635]]])
"""
__constants__ = [
"num_embeddings",
"embedding_dim",
"padding_idx",
"scale_grad_by_freq",
"sparse",
]
num_embeddings: int
embedding_dim: int
padding_idx: int
scale_grad_by_freq: bool
weight: Tensor
sparse: bool
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
padding_idx: Optional[int] = None,
scale_grad_by_freq: bool = False,
sparse: bool = False,
initial_speed: float = 1.0,
) -> None:
super(ScaledEmbedding, self).__init__()
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
if padding_idx is not None:
if padding_idx > 0:
assert (
padding_idx < self.num_embeddings
), "Padding_idx must be within num_embeddings"
elif padding_idx < 0:
assert (
padding_idx >= -self.num_embeddings
), "Padding_idx must be within num_embeddings"
padding_idx = self.num_embeddings + padding_idx
self.padding_idx = padding_idx
self.scale_grad_by_freq = scale_grad_by_freq
self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters()
self.sparse = sparse
self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim))
self.reset_parameters(initial_speed)
def reset_parameters(self, initial_speed: float = 1.0) -> None:
std = 0.1 / initial_speed
nn.init.normal_(self.weight, std=std)
nn.init.constant_(self.scale, torch.tensor(1.0 / std).log())
if self.padding_idx is not None:
with torch.no_grad():
self.weight[self.padding_idx].fill_(0)
def forward(self, input: Tensor) -> Tensor:
F = torch.nn.functional
scale = self.scale.exp()
if input.numel() < self.num_embeddings:
return (
F.embedding(
input,
self.weight,
self.padding_idx,
None,
2.0, # None, 2.0 relate to normalization
self.scale_grad_by_freq,
self.sparse,
)
* scale
)
else:
return F.embedding(
input,
self.weight * scale,
self.padding_idx,
None,
2.0, # None, 2.0 relates to normalization
self.scale_grad_by_freq,
self.sparse,
)
def extra_repr(self) -> str:
s = "{num_embeddings}, {embedding_dim}, scale={scale}"
if self.padding_idx is not None:
s += ", padding_idx={padding_idx}"
if self.scale_grad_by_freq is not False:
s += ", scale_grad_by_freq={scale_grad_by_freq}"
if self.sparse is not False:
s += ", sparse=True"
return s.format(**self.__dict__)
def _test_activation_balancer_sign():

View File

@@ -205,9 +205,16 @@ class Abel(Optimizer):
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-08).
target_rms (float, optional): target root-mean-square value of
- x_norm in factorization (conceptually). Actually this now just becomes
- a factor in the learning rate, we may remove it at some point.
+ parameters, when we normalize them (only conceptually!);
+ actually this now just becomes
+ a factor in the learning rate for non-scalar parameters.
rms_eps (float, optional): epsilon value such that when parameter
rms value goes below this, we stop learning slower for that parameter,
i.e. we treat the rms as if it were rms_eps.
rms_max (float, optional): maximum root-mean-square value for
non-scalar parameters, such that we don't let the parameter
values exceed this; we'll shrink any parameter matrices that become
larger than this.
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
@@ -438,6 +445,9 @@ class Cain(Optimizer):
co-ordinates every so often, just in the optimizer, to separate big/small
directions.
This version of Cain absorbs the learned scales of the parameter matrices into
the optimizer.
Eve is a modified version of AdamW with a special
way of setting the weight-decay / shrinkage-factor, which is designed to make the
rms of the parameters approach a particular target_rms (default: 0.1). This is
@@ -456,12 +466,17 @@ class Cain(Optimizer):
running averages of gradient and its square (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
weight_decay (float, optional): weight decay coefficient (default: 3e-4;
this value means that the weight would decay significantly after
about 3k minibatches. Is not multiplied by learning rate, but
is conditional on RMS-value of parameter being > target_rms.
target_rms (float, optional): target root-mean-square value of
parameters, if they fall below this we will stop applying weight decay.
rms_eps (float, optional): epsilon value such that when parameter
rms value goes below this, we stop learning slower for that parameter,
i.e. we treat the rms as if it were rms_eps.
rms_max (float, optional): maximum root-mean-square value for
non-scalar parameters, such that we don't let the parameter
values exceed this; we'll shrink any parameter matrices that become
larger than this.
.. _Adam\: A Method for Stochastic Optimization:
@@ -478,8 +493,9 @@ class Cain(Optimizer):
lr=1e-3,
betas=(0.9, 0.98),
eps=1e-8,
weight_decay=1e-3,
target_rms=0.1,
rms_eps=1.0e-05,
rms_max=10.0,
):
if not 0.0 <= lr:
@@ -494,18 +510,20 @@ class Cain(Optimizer):
raise ValueError(
"Invalid beta parameter at index 1: {}".format(betas[1])
)
if not 0 <= weight_decay <= 0.1:
raise ValueError(
"Invalid weight_decay value: {}".format(weight_decay)
)
if not 0 < target_rms <= 10.0:
raise ValueError("Invalid target_rms value: {}".format(target_rms))
if not 0.1 < rms_max <= 1000.0:
raise ValueError("Invalid rms_max value: {}".format(rms_max))
if not 0.0 < rms_eps < 0.1:
raise ValueError("Invalid rms_eps value: {}".format(rms_eps))
defaults = dict(
lr=lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
target_rms=target_rms,
rms_eps=rms_eps,
rms_max=rms_max,
)
super(Cain, self).__init__(params, defaults)
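For reference, the constructor now takes rms_eps and rms_max in place of weight_decay; the shrinkage that weight_decay used to provide is handled by the rms_max check inside step(). A usage sketch with the new defaults spelled out (the model is only a stand-in, and the import path for Cain is an assumption about this project's layout):

import torch
from optim import Cain   # assumed module path within this recipe

model = torch.nn.Linear(80, 256)
optimizer = Cain(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.98),
    eps=1e-8,
    target_rms=0.1,    # conceptual RMS the parameters are "normalized" to
    rms_eps=1.0e-05,   # floor on the per-tensor RMS used to scale the step
    rms_max=10.0,      # above this RMS the tensor is shrunk rather than grown
)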
@@ -526,6 +544,12 @@ class Cain(Optimizer):
loss = closure()
for group in self.param_groups:
beta1, beta2 = group["betas"]
lr = group["lr"]
target_rms = group["target_rms"]
rms_eps = group["rms_eps"]
rms_max = group["rms_max"]
for p in group["params"]:
if p.grad is None:
continue
@@ -550,14 +574,61 @@ class Cain(Optimizer):
state["exp_avg_sq"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
if p.numel() > 1:
# we keep track of the RMS value of the parameters to
# help set the learning rate... this emulates Eve.
state["param_rms"] = (p**2).mean().sqrt()
# we learn the scale on the parameters in a way
# that also emulates Eve.
state["scale_exp_avg_sq"] = torch.zeros((), device=p.device,
dtype=p.dtype)
state["scale_exp_avg"] = torch.zeros((), device=p.device,
dtype=p.dtype)
step = state["step"]
if p.numel() > 1:
# This part learns the scale on the parameters.
# scale_deriv is the derivative w.r.t. a log-space scale on the parameters, i.e.
# if we imagine: p = underlying_param * scale.exp(), this is the derivative
# w.r.t. scale (it does not actually matter what value `scale` currently has).
scale_deriv = (p * grad).sum()
scale_exp_avg_sq = state["scale_exp_avg_sq"]
scale_exp_avg = state["scale_exp_avg"]
scale_exp_avg_sq.mul_(beta2).addcmul_(scale_deriv, scale_deriv,
value=1 - beta2)
scale_exp_avg.mul_(beta1).add_(scale_deriv, alpha=(1-beta1))
scale_bias_correction2 = 1 - beta2 ** step
scale_denom = (scale_exp_avg_sq.sqrt()).add_(group["eps"])
scale_alpha = -lr*(scale_bias_correction2 ** 0.5)
# scale_delta is going to be the change in log-scale, i.e. if
# p = underlying_param * scale.exp(),
# delta is the change in `scale`.
scale_delta = scale_alpha * (scale_exp_avg / scale_denom)
# the following is equivalent to:
# if torch.all(state["param_rms"] > rms_max):
# delta = -1.0e-03
# which means: if the parameter root-mean-square value goes above the
# specified limit, start to decay the parameters (and ignore the computed
# delta).
scale_delta = torch.where(state["param_rms"] > rms_max,
torch.tensor(-1.0e-03, device=scale_delta.device,
dtype=scale_delta.dtype),
scale_delta)
p.mul_(1 + scale_delta)
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
- beta1, beta2 = group["betas"]
# forget bias_correction1. We use normal momentum. exp_avg really
# just stores the moving-average gradient step.
- bias_correction2 = 1 - beta2 ** self._step_for_bias_correction(state["step"])
+ bias_correction2 = 1 - beta2 ** self._step_for_bias_correction(step)
grad = self._change_coordinates(grad, state, forward=True)
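The block above is where the learned scale is absorbed into the optimizer: for every non-scalar tensor, (p * grad).sum() is treated as the derivative of the loss w.r.t. a log-space scale on p (imagining p = underlying_param * scale.exp()), Adam-style first and second moments are kept for that scalar, and the tensor is multiplied by (1 + scale_delta); once the tensor's RMS exceeds rms_max, the scale is instead decayed by 1e-3 per step. A self-contained toy version of just that scale update, assuming the caller maintains the same state entries (a sketch, not the code from Cain.step()):

import torch

def scale_update(p, grad, state, lr=1e-3, beta1=0.9, beta2=0.98,
                 eps=1e-8, rms_max=10.0):
    # Derivative w.r.t. a log-space scale on p: if p = underlying * scale.exp(),
    # then d(loss)/d(scale) = sum(p * grad), whatever value `scale` has.
    scale_deriv = (p * grad).sum()
    state["scale_exp_avg"].mul_(beta1).add_(scale_deriv, alpha=1 - beta1)
    state["scale_exp_avg_sq"].mul_(beta2).addcmul_(scale_deriv, scale_deriv,
                                                   value=1 - beta2)
    bias_correction2 = 1 - beta2 ** state["step"]
    denom = state["scale_exp_avg_sq"].sqrt() + eps
    scale_delta = -lr * (bias_correction2 ** 0.5) * state["scale_exp_avg"] / denom
    # If the tensor has grown too large, decay it instead of applying the delta
    # (the diff does this with torch.where so it stays branch-free).
    if state["param_rms"] > rms_max:
        scale_delta = torch.tensor(-1.0e-03)
    p.mul_(1 + scale_delta)

# toy usage
p = torch.randn(20, 30) * 0.1
state = dict(step=1,
             param_rms=(p ** 2).mean().sqrt(),
             scale_exp_avg=torch.zeros(()),
             scale_exp_avg_sq=torch.zeros(()))
scale_update(p, torch.randn_like(p), state)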
@@ -566,23 +637,22 @@ class Cain(Optimizer):
denom = (exp_avg_sq.sqrt()).add_(group["eps"])
step_size = group["lr"] # forget bias_correction1
target_rms = group["target_rms"]
weight_decay = group["weight_decay"]
this_delta = grad / denom
this_delta = self._change_coordinates(this_delta, state, forward=False)
# treat/repurpose exp_avg as moving-average of `this_delta`
exp_avg.mul_(beta1).add_(this_delta, alpha=-step_size*(1-beta1)*(bias_correction2 ** 0.5))
alpha = -lr*(1-beta1)*(bias_correction2 ** 0.5)
if p.numel() > 1:
- # avoid applying this weight-decay on "scaling factors"
- # (which are scalar).
- is_above_target_rms = p.norm() > (
-     target_rms * (p.numel() ** 0.5)
- )
- p.mul_(1 - (weight_decay * is_above_target_rms))
+ if step % 10 == 0:
+     state["param_rms"].fill_((p**2).mean().sqrt())
+ # imagine param was normalized to rms=target_rms and we stored the
+ # scale as a separate scalar. This is how fast we'd be learning.
+ alpha = (alpha / target_rms) * state["param_rms"].clamp(min=rms_eps)
# treat/repurpose exp_avg as moving-average of `this_delta`
# don't use alpha=alpha as I don't think
exp_avg.mul_(beta1).add_(this_delta, alpha=alpha)
p.add_(exp_avg)
state["step"] += 1
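The hunk above replaces the old explicit weight decay with an Eve-like rescaling of the step: the per-tensor factor alpha is multiplied by the tensor's running RMS (refreshed every 10 steps in the diff) divided by target_rms, with rms_eps as a floor so that tensors that have shrunk to nearly zero keep learning. A sketch of just that step-size rule, recomputing the RMS on every call rather than caching it (not the full update):

import torch

def effective_alpha(p, lr, beta1, bias_correction2,
                    target_rms=0.1, rms_eps=1.0e-05):
    # Base factor applied to the (coordinate-changed) delta before it is
    # averaged into exp_avg and added to the parameter.
    alpha = -lr * (1 - beta1) * (bias_correction2 ** 0.5)
    if p.numel() > 1:
        param_rms = (p ** 2).mean().sqrt()
        # Behave as if p were stored normalized to target_rms with a separate
        # scalar scale: bigger tensors take proportionally bigger steps.
        alpha = (alpha / target_rms) * param_rms.clamp(min=rms_eps)
    return alpha

print(effective_alpha(torch.randn(20, 30) * 0.1, lr=1e-3,
                      beta1=0.9, bias_correction2=0.5))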
@@ -1059,7 +1129,7 @@ def _test_eve_cain():
input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
- for iter in [0,1]:
+ for iter in [1,0]:
fix_random_seed(42)
Linear = torch.nn.Linear if iter == 0 else ScaledLinear
m = torch.nn.Sequential(Linear(E, 200),
@@ -1071,7 +1141,7 @@ def _test_eve_cain():
if iter == 0: optim = Eve(m.parameters(), lr=0.003)
else: optim = Cain(m.parameters(), lr=0.003)
- scheduler = Eden(optim, lr_batches=300, lr_epochs=20, verbose=False)
+ scheduler = Eden(optim, lr_batches=200, lr_epochs=10, verbose=False)
start = timeit.default_timer()
for epoch in range(150):
@@ -1121,5 +1191,7 @@ def _test_eve_cain():
if __name__ == "__main__":
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
_test_eve_cain()
#_test_eden()