mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-18 21:44:18 +00:00
In the middle of reworking for new idea
This commit is contained in:
parent
41368f6b63
commit
e9e2a85c95
@ -48,7 +48,10 @@ class LearnedGradient(Optimizer):
|
||||
learning the scale on the parameters (we'll keep it >= this size)
|
||||
param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
|
||||
learning the scale on the parameters (we'll keep it <= this size)
|
||||
lr_est_period: The periodicity, in steps, with which we update the learning-rate matrix.
|
||||
lr_est_period: The periodicity, in steps, with which we update the learning-rate matrices.
|
||||
orthogonalize_period: How often we re-orthogonalize the gradient covariance; this
|
||||
gets multiplied by lr_est_period, i.e. we orthogonalize less frequently
|
||||
than we update the learning-rate matrixes.
|
||||
max_step_scale: Value that determines limit on how large steps can be per element,
|
||||
intended to prevent instability during early steps before the learning-rate
|
||||
matrices are well learned.
|
||||
@ -58,12 +61,14 @@ class LearnedGradient(Optimizer):
|
||||
params,
|
||||
lr=3e-02,
|
||||
size_lr_scale=0.1,
|
||||
meta_lr_scale=1.0,
|
||||
meta_lr_scale=0.1,
|
||||
betas=(0.9, 0.98),
|
||||
eps=1.0e-08,
|
||||
param_min_rms=1.0e-05,
|
||||
param_max_rms=2.0,
|
||||
lr_est_period=10,
|
||||
max_lr_eig=4.0,
|
||||
orthogonalize_period=4,
|
||||
max_step_scale=2.0
|
||||
):
|
||||
|
||||
@ -77,6 +82,8 @@ class LearnedGradient(Optimizer):
|
||||
param_min_rms=param_min_rms,
|
||||
param_max_rms=param_max_rms,
|
||||
lr_est_period=lr_est_period,
|
||||
max_lr_eig=max_lr_eig,
|
||||
orthogonalize_period=orthogonalize_period,
|
||||
max_step_scale=max_step_scale,
|
||||
)
|
||||
|
||||
@ -104,16 +111,15 @@ class LearnedGradient(Optimizer):
|
||||
meta_lr_scale = group["meta_lr_scale"]
|
||||
size_lr = lr * group["size_lr_scale"]
|
||||
beta1, beta2 = group["betas"]
|
||||
max_lr_eig = group["max_lr_eig"]
|
||||
eps = group["eps"]
|
||||
size_update_period = 4
|
||||
param_min_eps = group["param_min_rms"]
|
||||
param_max_eps = group["param_max_rms"]
|
||||
lr_est_period = group["lr_est_period"]
|
||||
orthogonalize_period = group["orthogonalize_period"]
|
||||
max_step_scale = group["max_step_scale"]
|
||||
|
||||
|
||||
small_tensor_cutoff = 5 # cutoff below which we use Adam
|
||||
|
||||
for p in group["params"]:
|
||||
if p.grad is None:
|
||||
continue
|
||||
@ -134,60 +140,69 @@ class LearnedGradient(Optimizer):
|
||||
kwargs = {'device':p.device, 'dtype':p.dtype}
|
||||
|
||||
|
||||
if p.numel() > 1:
|
||||
# we learn the scale on the parameters in a way that
|
||||
# also emulates Eve. scale_exp_avg_sq is the
|
||||
# moving-average square of the gradient w.r.t. this
|
||||
# scalar scale.
|
||||
state["scale_exp_avg_sq"] = torch.zeros((), **kwargs)
|
||||
state["scale_exp_avg"] = torch.zeros((), **kwargs)
|
||||
# for speed, we do the scale update only periodically,
|
||||
# and meanwhile we store the derivs w.r.t. the grad scale
|
||||
# as a list
|
||||
state["scale_grads"] = torch.zeros(size_update_period, **kwargs)
|
||||
|
||||
state["exp_avg"] = torch.zeros_like(
|
||||
# 'delta' implements conventional momentum. There are
|
||||
# several different kinds of update going on, so rather than
|
||||
# compute "exp_avg" like in Adam, we store and decay a
|
||||
# parameter-change "delta", which combines all forms of
|
||||
# update. this is equivalent to how it's done in Adam,
|
||||
# except for the first few steps.
|
||||
state["delta"] = torch.zeros_like(
|
||||
p, memory_format=torch.preserve_format
|
||||
)
|
||||
|
||||
if p.numel() > small_tensor_cutoff:
|
||||
# moving_grad is a moving average of the raw gradient.
|
||||
# We do the accumulation on the input side because it
|
||||
# makes it easier to get grads w.r.t. the learning rate
|
||||
# matrices.
|
||||
|
||||
# "scale" is just a per-tensor scalar that determines the
|
||||
# rate of learning (in terms of the rms value of the change
|
||||
# in parameter)
|
||||
if p.numel() > 1:
|
||||
# "param_rms" just periodically records the scalar root-mean-square value of
|
||||
# the parameter tensor.
|
||||
param_rms = (p**2).mean().sqrt().add_(eps)
|
||||
state["param_rms"] = param_rms
|
||||
# scalar exp_avg_sq. not the same as scale_exp_avg_sq
|
||||
state["exp_avg_sq_scalar"] = torch.zeros_like((), **kwargs)
|
||||
|
||||
for dim in range(p.ndim):
|
||||
size = p.shape[dim]
|
||||
if size == 1 or size == p.numel():
|
||||
if size == 1:
|
||||
# if size == p.numel(), is_one_axis == True above
|
||||
continue
|
||||
|
||||
state[f"lr_{dim}"] = torch.eye(size, **kwargs)
|
||||
# Some parts of the update -- the parts where we are
|
||||
# doing to update lr_{dim} will require grad. these
|
||||
# parts will be done inside torch.enable_grad().
|
||||
# The rest of this function is inside no_grad(), so
|
||||
# it will ignore the requires_grad == True.
|
||||
state[f"lr_{dim}"].requires_grad = True
|
||||
# exp_avg_sq is in the projected space.
|
||||
state["exp_avg_sq"] = torch.zeros_like(
|
||||
p, memory_format=torch.preserve_format
|
||||
)
|
||||
|
||||
|
||||
# proj_{dim} will be the learning-rate matrix for this
|
||||
# dimension, times an orthogonal matrix that we'll periodically
|
||||
# re-estimate when we "reorthogonalize" (this orthogonalizes
|
||||
# the gradient covariance).
|
||||
state[f"proj_{dim}"] = torch.eye(size, **kwargs)
|
||||
|
||||
# proj_grad_{dim} is the gradient w.r.t. proj_{grad}
|
||||
# (technically w.r.t a matrix to the left of
|
||||
# proj_{grad}, we'll project it later); we accumulate
|
||||
# this for `lr_est_period` steps, and then update
|
||||
# proj_{dim} and zero this matrix
|
||||
state[f"proj_grad_{dim}"] = torch.zeros(size, size, **kwargs)
|
||||
|
||||
# grad_cov_{dim} is the covariance of gradients on this axis (without
|
||||
# any co-ordinate changes), treating all other axes as as a batch axis.
|
||||
# This is needed when we re-orthogonalize, and to compute the
|
||||
# grad_rms_{dim}. We store it as a decaying average, decaying with beta2
|
||||
|
||||
# grad_cov_{dim} is the covariance of gradients on this axis,
|
||||
# treating all other axes as as a batch axis. This is needed
|
||||
# only for purposes of computing the scalar factor on the
|
||||
# learning rate; and, as a result of this, also contributes something
|
||||
# to the gradient w.r.t. f"lr_{dim}", which is one reason
|
||||
# why we allocate a variable to keep track of its moving average
|
||||
# instead of just using a temporary and smoothing the scalar factor.
|
||||
state[f"grad_cov_{dim}"] = torch.zeros(size, size, **kwargs)
|
||||
else:
|
||||
# size <= small_tensor_cutoff: do a form of Adam
|
||||
state["exp_avg_sq"] = torch.zeros_like(p,
|
||||
mamory_format=preserve_memory_format)
|
||||
|
||||
# proj_grad_var_{dim} is the variance of gradients of the thing we're
|
||||
# going to multiply proj_{dim} by (between it and the "underlying
|
||||
# parameter matrix"), aggregated over each of its 2 dimensions
|
||||
# and decayed over time with beta1 each time we do a step on proj_{dim},
|
||||
# e.g. every lr_est_period steps. It's a factored representation
|
||||
# of something like exp_avg_sq, used for estimating proj_{dim}; we avoid
|
||||
# storing an exp_avg_sq for proj_{dim}, to save GPU memory.
|
||||
state[f"proj_grad_var_{dim}"] = torch.ones(2, size, **kwargs)
|
||||
|
||||
|
||||
step = state["step"]
|
||||
@ -206,14 +221,14 @@ class LearnedGradient(Optimizer):
|
||||
beta1, beta2, step, size_lr,
|
||||
param_min_rms, param_max_rms)
|
||||
|
||||
if numel <= small_tensor_cutoff:
|
||||
if numel == 1:
|
||||
# For parameters with very few elements we just use a form
|
||||
# of Adam with a scale factor to reflect the overall
|
||||
# parameter rms.
|
||||
self._step_small(beta1, beta2, eps, lr, p, grad, state)
|
||||
self._step_scalar(beta1, beta2, eps, lr, p, grad, state)
|
||||
else:
|
||||
if step % lr_est_period == 0:
|
||||
self._update_lrs(group, grad, state)
|
||||
self._update_lrs(group, p, state)
|
||||
self._step(group, p, grad, state)
|
||||
|
||||
state["step"] = step + 1
|
||||
@ -289,28 +304,129 @@ class LearnedGradient(Optimizer):
|
||||
|
||||
def _update_lrs(self,
|
||||
group: dict,
|
||||
grad: Tensor,
|
||||
p: Tensor,
|
||||
state: dict) -> None:
|
||||
"""
|
||||
Updates the learning-rate matrices state["lr_{dim}"].
|
||||
|
||||
Args:
|
||||
group: dict to look up configuration values
|
||||
grad: current gradient for the parameter that we are updating
|
||||
p: parameter matrix that we are updating. The learning rate matrices
|
||||
are actually factors of p, so p itself will change when we change
|
||||
them.
|
||||
state: state dict for the current parameter
|
||||
"""
|
||||
# meta_lr is the learning rate for the learning rate matrix.
|
||||
meta_lr = group["lr"] * group["meta_lr_scale"]
|
||||
# meta_lr is the learning rate for the learning rate matrix. Including the factor
|
||||
# (group["lr_est_period"] ** 0.5) is intended to make it approximately invariant
|
||||
# to the lr_est_period (close to convergence).
|
||||
meta_lr = group["lr"] * group["meta_lr_scale"] * (group["lr_est_period"] ** 0.5)
|
||||
beta1, beta2 = group["betas"]
|
||||
eps = group["eps"]
|
||||
moving_g = state["exp_avg"]
|
||||
|
||||
ndim = p.ndim
|
||||
|
||||
# Update grad_cov.
|
||||
self._store_grad_stats(grad, state, beta2)
|
||||
for dim in range(ndim):
|
||||
size = p.shape[dim]
|
||||
if size == 1:
|
||||
continue
|
||||
|
||||
p_t = p.transpose(dim, -1)
|
||||
this_p = p_t.reshape(-1, size)
|
||||
|
||||
# proj_grad is a summed gradient, indexed [grad_idx][param_idx] (don't worry
|
||||
# if this notation about the indexing is unclear, just keep reading).
|
||||
#
|
||||
# Suppose M is our parameter matrix p reshaped to (size, -1) with the -1 being
|
||||
# all other dims treated as a batch dim, and M_grad is the loss derivative w.r.t. M,
|
||||
#
|
||||
# Then proj_grad is the accumulated value of M_grad M^T, which will be of shape
|
||||
# (size, size). If we view the parameter matrix as actually (proj M) with
|
||||
# proj being numerically equal to I, i.e. as
|
||||
# matmul(proj, M), then instead of the loss being tr(M_grad^T M), we can write
|
||||
# it as tr(M_grad^T proj M), which can also be written as tr(M M_grad^T proj)
|
||||
# = tr( (M_grad M^T)^T proj), where we can write M_grad M^T == proj_grad.
|
||||
#
|
||||
# Now, proj_grad is convenient to compute but it's not in a very normalized
|
||||
# space, we want to normalize it first. We're going to view M as
|
||||
# M = T N,
|
||||
# where T is our proj_{dim} matrix (see the code), and N == T^{-1} M.
|
||||
# It's more convenient to have the projection between T and N, i.e. have proj2,
|
||||
# again numerically equal to I, so
|
||||
# M == T proj2 N.
|
||||
# [Be careful because when we use T as a superscript it means transpose).
|
||||
# We want the derivative w.r.t. proj2.
|
||||
# So suppose the loss is
|
||||
# tr(M_grad^T M) == tr(M_grad^T T proj2 N) == tr(N M_grad^T T proj2)
|
||||
# == tr( (T^T M_grad N^T)^T proj2 )
|
||||
# == tr( (T^T M_grad N^T)^T proj2 )
|
||||
# and using N == T^{-1} M, so N^T = M^T T^{-T},
|
||||
# == tr( (T^T M_grad M^T T^{-T})^T proj2 )
|
||||
# so we can view the thing in parentheses in the above expression as proj2_grad, i.e.
|
||||
# proj2_grad = T^T M_grad M^T T^{-T}
|
||||
# proj2_grad = T^T proj_grad T^{-T}
|
||||
#
|
||||
# note: p = state[f"proj_{dim}"] actually corresponds to T^T, since it has a
|
||||
# shape interpretable as (diagonalized_index, canonical_index).
|
||||
#
|
||||
# So proj2_grad = p proj_grad p^{-1} (eq.1)
|
||||
#
|
||||
# After computing proj2_grad we compute/update a factorized representation of
|
||||
# its variance, factorized over its 2 axes, and do an Adam-type update
|
||||
# for it, except without momentum at this level (momentum will be applied
|
||||
# later). This gives us proj2_delta, according to the
|
||||
# equation: proj2 = I + proj2_delta.
|
||||
#
|
||||
# This is the point at which we control the magnitude of the parameter change.
|
||||
#
|
||||
# We now have to update the factor T using proj2, as in:
|
||||
# T <-- T (I + proj2_delta)
|
||||
# T <-- T + T proj2_delta
|
||||
# and since p == T^T, this becomes:
|
||||
# p <-- p + proj2_delta^T p
|
||||
# or equivalently:
|
||||
# p_delta = proj2_delta^T p (eq.2)
|
||||
#
|
||||
# We also need to update M itself; this is easiest done by computing proj_delta
|
||||
# (which is the deviation of proj from I). The relationship between proj and proj2
|
||||
# is given by equating
|
||||
# proj T N = T proj2 N,
|
||||
# proj T = T proj2,
|
||||
# proj = T proj2 T^{-1}
|
||||
# proj = p^T proj2 p^{-T} .
|
||||
# we can apply the above formula directly to proj2_delta, since the "I" part just
|
||||
# becomes a no-op, i.e.
|
||||
# proj_delta = p^T proj2_delta p^{-T}.
|
||||
# ... and to turn this back into a delta on M, we'll do:
|
||||
# M_delta = proj_delta M.
|
||||
|
||||
|
||||
proj_grad = state[f"proj_grad_{dim}"]
|
||||
p = state[f"proj_{dim}"]
|
||||
|
||||
# torch.linalg.solve(A, B) returns X such that AX=B,
|
||||
# meaning X = A^{-1} B. We want to compute X = proj_grad p^{-1},
|
||||
# so X^T = p^{-T} proj_grad^T.
|
||||
X = torch.linalg.solve(p.t(), proj_grad.t()).t()
|
||||
assert torch.allclose(proj_grad, torch.matmul(X, p)) # TEMP.
|
||||
|
||||
# See (eq.1), the next line implements proj_grad = p proj_grad p^{-1}
|
||||
proj2_grad = torch.matmul(p, X)
|
||||
|
||||
# for now just estimate it from proj2_grad, later we will try aggregating it
|
||||
# across iterations. This is like exp_avg_sq in Adam.
|
||||
proj2_grad_var = self._get_matrix_var(proj_grad, group["eps"])
|
||||
|
||||
denom = proj2_grad_var.sqrt()
|
||||
|
||||
alpha = -meta_lr * (1-beta1)
|
||||
|
||||
# we'll take into account alpha later on.
|
||||
proj2_delta = proj2_grad_var / denom
|
||||
|
||||
p_delta = torch.matmul(proj2_delta.t(), p)
|
||||
|
||||
|
||||
|
||||
with torch.enable_grad():
|
||||
|
||||
# To update the learning rate matrices, we are asking the question, "What if,
|
||||
# on the previous step, we had used a slightly different learning-rate matrix?
|
||||
@ -345,7 +461,7 @@ class LearnedGradient(Optimizer):
|
||||
# [something] as it contributes no gradient, becomes -alpha *
|
||||
# (moving_d_norm * grad).sum(), and ignoring the scale alpha, that
|
||||
# is -(moving_d_norm * grad).sum(), so we call this neg_loss as it's the
|
||||
# negative of hte iss.
|
||||
# negative of the loss.
|
||||
neg_loss = (grad * moving_d_norm).sum()
|
||||
|
||||
# after the following, there will grads in state[f"lr_{dim}"]
|
||||
@ -357,6 +473,32 @@ class LearnedGradient(Optimizer):
|
||||
if p.shape[dim] != 1:
|
||||
self._update_lr(state[f"lr_{dim}"], meta_lr)
|
||||
|
||||
def _get_matrix_var(self,
|
||||
x: Tensor,
|
||||
eps: float) -> Tensor:
|
||||
"""
|
||||
Estimate a variance of elements of x, which must be a tensor with exactly
|
||||
2 diimensions, neither of which can have size == 1. This is done using
|
||||
a structured representation, so the variance is a product over the 2
|
||||
dimensions (otherwise, it wouldn't be possible to estimate a variance from
|
||||
a single sample). Eventually we may do this via aggregation across
|
||||
time as well; this approach seems to give a high variance estimate even
|
||||
for largish dimension, e.g. if x.shape == (100,100) the 95th percentile
|
||||
is from about 0.6 to 1.6.
|
||||
|
||||
Args:
|
||||
x: tensor to estimate the variance of elements of. Require x.ndim == 2,
|
||||
neither size must be 1.
|
||||
eps: Epsilon to avoid zero variance: epsilon is dimensionally a rms value.
|
||||
"""
|
||||
assert x.ndim == 2 and 1 not in list(x.shape) # TODO: remove?
|
||||
|
||||
var0 = (x**2).mean(dim=1, keepdim=True) + eps*eps
|
||||
var1 = ((x / var0.sqrt()) ** 2).mean(dim=0, keepdim=True) + eps*eps
|
||||
# esetimate var0 one more time.
|
||||
var0 = ((x / var1.sqrt()) ** 2).mean(dim=1, keepdim=True) + eps*eps
|
||||
return var1 * var0
|
||||
|
||||
|
||||
|
||||
def _update_lr(self,
|
||||
@ -421,7 +563,7 @@ class LearnedGradient(Optimizer):
|
||||
state: dict):
|
||||
"""
|
||||
This function does the core update of self._step, in the case where the tensor
|
||||
has more than about 5 elements. It multiplies the moving-average gradient tensor by a
|
||||
has more than 1 elements. It multiplies the moving-average gradient tensor by a
|
||||
state["lr_{dim}"] on each non-trivial axis/dimension, does a length normalization,
|
||||
and takes a step in that direction.
|
||||
|
||||
@ -434,55 +576,60 @@ class LearnedGradient(Optimizer):
|
||||
This function modifies p.
|
||||
"""
|
||||
lr = group["lr"]
|
||||
beta1 = group["betas"][0]
|
||||
beta1, beta2 = group["betas"]
|
||||
eps = group["eps"]
|
||||
max_step_scale = group["max_step_scale"]
|
||||
|
||||
moving_g = state["exp_avg"]
|
||||
# include the current grad in moving_g
|
||||
moving_g.mul_(beta1).add(grad, alpha=(1-beta1))
|
||||
|
||||
# get the parameter delta (up to a scalar factor) by multiplying by a
|
||||
# learning-rate matrix for each dimension.
|
||||
moving_d = self._multiply_by_lr(moving_g)
|
||||
for dim in range(grad.ndim):
|
||||
size = grad.shape[size]
|
||||
# accumulate some stats for learning the projections
|
||||
proj_grad = state[f"proj_grad_{dim}"]
|
||||
grad_cov = state[f"grad_cov_{dim}"]
|
||||
this_p = p.transpose(-1, dim).reshape(-1, size)
|
||||
this_g = g.transpose(-1, dim).reshape(-1, size)
|
||||
proj_grad.add_(torch.matmul(this_g.t(), this_p))
|
||||
# could perhaps accumulate grad_cov less frequently.
|
||||
grad_cov.mul_(beta2).add_(torch.matmul(this_g.t(), this_g))
|
||||
|
||||
# moving_d contains a moving average of parameter change contributions
|
||||
# from multiple iterations with weights: (1-beta) * (1 + beta + beta^2 +
|
||||
# beta^3, etc.), which sum to 1 for large step. Suppose the rms value
|
||||
# of each individual gradient were r. Assuming the terms are
|
||||
# independent, with rms=r, the variance of the elements of moving_d
|
||||
# would be r^2 * (1-beta)^2 * (1 + beta^2 + beta^4 + ... ) which equals
|
||||
# r^2 * (1-beta)^2 / (1-beta^2), so the measured rms of moving_d would
|
||||
# be moving_d_rms = individual_d_rms * (1-beta) / sqrt(1-beta^2), so
|
||||
# individual_d_rms = moving_d_rms * sqrt(1-beta^2) / (1-beta) ... this
|
||||
# would be the root-mean-square value of the individual deltas, if we
|
||||
# were to proces them individually (if they were independent), and this
|
||||
# is what we want to normalize by for greater independence of the beta
|
||||
# value, i.e. so beta only affects momentum and not the overall speed
|
||||
moving_d_rms = (moving_d**2).mean().sqrt() + eps
|
||||
grad = self._project(grad, forward=True, state)
|
||||
|
||||
# e.g. if beta1==0.98, this formula gives a factor of 9.95, reflecting
|
||||
# that the individual rms's of the un-scaled parameter deltas without
|
||||
# momentum are larger than their moving average. see the justification
|
||||
# in a comment above.
|
||||
d_rms = moving_d_rms * (((1 - beta1 * beta1) ** 0.5) / (1 - beta1))
|
||||
param_rms = state["param_rms"]
|
||||
alpha = -lr * param_rms / d_rms
|
||||
exp_avg_sq = state["exp_avg_sq"]
|
||||
exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
|
||||
value=(1-beta2))
|
||||
|
||||
# apply max_step_scale to limit magnitude of change on individual
|
||||
# elements, this should help stability.
|
||||
limit = moving_d_rms * max_step_scale
|
||||
moving_d_clamped = moving_d.clamp_(min=-limit, max=limit)
|
||||
step = state["step"]
|
||||
bias_correction2 = 1 - beta2 ** (step + 1)
|
||||
if bias_correction2 < 0.99:
|
||||
# note: not in-place.
|
||||
exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2)
|
||||
denom = exp_avg_sq.sqrt() + eps
|
||||
|
||||
if random.random() < 0.001:
|
||||
clamp_diff_rms = ((moving_d_clamped - moving_d) ** 2).mean().sqrt()
|
||||
rel_clamp_diff = clamp_diff_rms / d_rms
|
||||
logging.info(f"Clamping difference is {clamp_diff_rms.item():.3g} vs. {d_rms.item():.3g}: "
|
||||
f"relative diff {rel_clamp_diff.item():.3g}, shape={tuple(p.shape)}")
|
||||
grad = grad / denom
|
||||
|
||||
p.add_(moving_d_clamped, alpha=alpha)
|
||||
# and project back..
|
||||
grad = self._project(grad, forward=False, state)
|
||||
|
||||
def _step_small(self,
|
||||
|
||||
scalar_exp_avg_sq = state["scalar_exp_avg_sq"]
|
||||
scalar_exp_avg_sq.mul_(beta2).add_((grad**2).mean(),
|
||||
alpha=1-beta2)
|
||||
|
||||
|
||||
if bias_correction2 < 0.99:
|
||||
# note: not in-place.
|
||||
scalar_exp_avg_sq = scalar_exp_avg_sq * (1.0 / bias_correction2)
|
||||
|
||||
denom = scalar_exp_avg_sq.sqrt() + eps # scalar.
|
||||
|
||||
alpha = state["param_rms"] * (1-beta1) * lr / denom
|
||||
|
||||
delta = state["delta"]
|
||||
delta.mul_(beta1).add_(grad, alpha=alpha)
|
||||
|
||||
param.add_(delta)
|
||||
|
||||
|
||||
def _step_scalar(self,
|
||||
beta1: float,
|
||||
beta2: float,
|
||||
eps: float,
|
||||
@ -491,11 +638,11 @@ class LearnedGradient(Optimizer):
|
||||
grad: Tensor,
|
||||
state: dict):
|
||||
"""
|
||||
A form of the core update for tensors with a small number of elements, e.g. <5. This is
|
||||
A form of the core update for tensors with a small number of elements, e.g. scalars. This is
|
||||
Adam where, if the numel() > 1, the learning rate is proportional to the parameter rms value.
|
||||
"""
|
||||
exp_avg = state["exp_avg"]
|
||||
exp_avg_sq = state["exp_avg_sq"]
|
||||
exp_avg_sq = state["scalar_exp_avg_sq"]
|
||||
exp_avg.mul_(beta1).add_(grad, alpha=1-beta1)
|
||||
exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
|
||||
value=1-beta2)
|
||||
@ -505,11 +652,40 @@ class LearnedGradient(Optimizer):
|
||||
bias_correction2 = 1 - beta2 ** (state["step"] + 1)
|
||||
denom = (exp_avg_sq.sum() / (exp_avg_sq.numel() * bias_correction2)).sqrt() + eps
|
||||
alpha = -lr / denom
|
||||
if p.numel() > 1:
|
||||
alpha = alpha * state["param_rms"]
|
||||
# for now not supporting p.numel() > 1.
|
||||
#if p.numel() > 1:
|
||||
# alpha = alpha * state["param_rms"]
|
||||
p.add_(exp_avg, alpha=alpha)
|
||||
|
||||
|
||||
def _project(self,
|
||||
x: Tensor,
|
||||
forward: bool,
|
||||
state: dict) -> Tensor:
|
||||
"""
|
||||
Multiply a tensor x by proj_{dim} on each of its dimensions.
|
||||
If forward == True, converts x from canonical to preconditioned/diagonalized
|
||||
co-ordinates; if forward == False, does the reverse.
|
||||
|
||||
Args:
|
||||
x: The tensor to project, e.g. a gradient or a parameter change.
|
||||
forward: if True, go in the forward directin (from canonical to diagnonalized
|
||||
co-ordinates; if False, the reverse. They differ by a transpose,
|
||||
not an inverse, and do not make a round trip.
|
||||
"""
|
||||
for dim in range(x.ndim):
|
||||
if x.shape[dim] == 1:
|
||||
continue
|
||||
proj = state[f"proj_{dim}"]
|
||||
if forward:
|
||||
# dimension 0 of `proj` is diagonalized co-ordinate.
|
||||
proj = proj.t()
|
||||
# TODO: could possibly somehow force the output format to be unchanged.
|
||||
x = x.transpose(-1, dim)
|
||||
x = torch.matmul(x, proj)
|
||||
x = x.transpose(-1, dim)
|
||||
return x
|
||||
|
||||
def _store_grad_stats(self,
|
||||
grad: Tensor,
|
||||
state: dict,
|
||||
|
Loading…
x
Reference in New Issue
Block a user