commit e9e2a85c95 (parent 41368f6b63)

    In the middle of reworking for new idea
@@ -48,7 +48,10 @@ class LearnedGradient(Optimizer):
                    learning the scale on the parameters (we'll keep it >= this size)
     param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
                    learning the scale on the parameters (we'll keep it <= this size)
-    lr_est_period: The periodicity, in steps, with which we update the learning-rate matrix.
+    lr_est_period: The periodicity, in steps, with which we update the learning-rate matrices.
+    orthogonalize_period: How often we re-orthogonalize the gradient covariance; this
+                   gets multiplied by lr_est_period, i.e. we orthogonalize less frequently
+                   than we update the learning-rate matrices.
     max_step_scale: Value that determines limit on how large steps can be per element,
                    intended to prevent instability during early steps before the learning-rate
                    matrices are well learned.
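
A hypothetical usage sketch (not part of this commit) of the constructor with the arguments documented above.  The import path is an assumption (in icefall this class lives in the recipe's optim.py), and since the commit is mid-rework the step itself may not yet run end-to-end.

import torch
from optim import LearnedGradient  # assumed location of this class

model = torch.nn.Linear(256, 384)
optimizer = LearnedGradient(
    model.parameters(),
    lr=3e-02,
    meta_lr_scale=0.1,        # this commit changes the default from 1.0
    lr_est_period=10,         # update the learning-rate matrices every 10 steps
    orthogonalize_period=4,   # re-orthogonalize every 4 * lr_est_period steps
    max_step_scale=2.0,
)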
@@ -58,12 +61,14 @@ class LearnedGradient(Optimizer):
         params,
         lr=3e-02,
         size_lr_scale=0.1,
-        meta_lr_scale=1.0,
+        meta_lr_scale=0.1,
         betas=(0.9, 0.98),
         eps=1.0e-08,
         param_min_rms=1.0e-05,
         param_max_rms=2.0,
         lr_est_period=10,
+        max_lr_eig=4.0,
+        orthogonalize_period=4,
         max_step_scale=2.0
     ):

@@ -77,6 +82,8 @@ class LearnedGradient(Optimizer):
             param_min_rms=param_min_rms,
             param_max_rms=param_max_rms,
             lr_est_period=lr_est_period,
+            max_lr_eig=max_lr_eig,
+            orthogonalize_period=orthogonalize_period,
             max_step_scale=max_step_scale,
         )

@@ -104,16 +111,15 @@ class LearnedGradient(Optimizer):
             meta_lr_scale = group["meta_lr_scale"]
             size_lr = lr * group["size_lr_scale"]
             beta1, beta2 = group["betas"]
+            max_lr_eig = group["max_lr_eig"]
             eps = group["eps"]
             size_update_period = 4
             param_min_eps = group["param_min_rms"]
             param_max_eps = group["param_max_rms"]
             lr_est_period = group["lr_est_period"]
+            orthogonalize_period = group["orthogonalize_period"]
             max_step_scale = group["max_step_scale"]

-            small_tensor_cutoff = 5  # cutoff below which we use Adam
-
             for p in group["params"]:
                 if p.grad is None:
                     continue
@@ -134,60 +140,69 @@ class LearnedGradient(Optimizer):
                     kwargs = {'device':p.device, 'dtype':p.dtype}

-                    if p.numel() > 1:
-                        # we learn the scale on the parameters in a way that
-                        # also emulates Eve.  scale_exp_avg_sq is the
-                        # moving-average square of the gradient w.r.t. this
-                        # scalar scale.
-                        state["scale_exp_avg_sq"] = torch.zeros((), **kwargs)
-                        state["scale_exp_avg"] = torch.zeros((), **kwargs)
-                        # for speed, we do the scale update only periodically,
-                        # and meanwhile we store the derivs w.r.t. the grad scale
-                        # as a list
-                        state["scale_grads"] = torch.zeros(size_update_period, **kwargs)
-
-                    state["exp_avg"] = torch.zeros_like(
+                    # 'delta' implements conventional momentum.  There are
+                    # several different kinds of update going on, so rather than
+                    # compute "exp_avg" like in Adam, we store and decay a
+                    # parameter-change "delta", which combines all forms of
+                    # update.  This is equivalent to how it's done in Adam,
+                    # except for the first few steps.
+                    state["delta"] = torch.zeros_like(
                         p, memory_format=torch.preserve_format
                     )

-                    if p.numel() > small_tensor_cutoff:
-                        # moving_grad is a moving average of the raw gradient.
-                        # We do the accumulation on the input side because it
-                        # makes it easier to get grads w.r.t. the learning rate
-                        # matrices.
-
-                        # "scale" is just a per-tensor scalar that determines the
-                        # rate of learning (in terms of the rms value of the change
-                        # in parameter)
+                    if p.numel() > 1:
+                        # "param_rms" just periodically records the scalar root-mean-square value of
+                        # the parameter tensor.
                         param_rms = (p**2).mean().sqrt().add_(eps)
                         state["param_rms"] = param_rms
+                        # scalar exp_avg_sq.  not the same as scale_exp_avg_sq
+                        state["scalar_exp_avg_sq"] = torch.zeros((), **kwargs)

                         for dim in range(p.ndim):
                             size = p.shape[dim]
-                            if size == 1 or size == p.numel():
+                            if size == 1:
                                 # if size == p.numel(), is_one_axis == True above
                                 continue

-                            state[f"lr_{dim}"] = torch.eye(size, **kwargs)
-                            # Some parts of the update -- the parts where we are
-                            # doing to update lr_{dim} will require grad. these
-                            # parts will be done inside torch.enable_grad().
-                            # The rest of this function is inside no_grad(), so
-                            # it will ignore the requires_grad == True.
-                            state[f"lr_{dim}"].requires_grad = True
-
-                            # grad_cov_{dim} is the covariance of gradients on this axis,
-                            # treating all other axes as as a batch axis. This is needed
-                            # only for purposes of computing the scalar factor on the
-                            # learning rate; and, as a result of this, also contributes something
-                            # to the gradient w.r.t. f"lr_{dim}", which is one reason
-                            # why we allocate a variable to keep track of its moving average
-                            # instead of just using a temporary and smoothing the scalar factor.
-                            state[f"grad_cov_{dim}"] = torch.zeros(size, size, **kwargs)
-                    else:
-                        # size <= small_tensor_cutoff: do a form of Adam
-                        state["exp_avg_sq"] = torch.zeros_like(p,
-                                                               mamory_format=preserve_memory_format)
+                            # exp_avg_sq is in the projected space.
+                            state["exp_avg_sq"] = torch.zeros_like(
+                                p, memory_format=torch.preserve_format
+                            )
+
+                            # proj_{dim} will be the learning-rate matrix for this
+                            # dimension, times an orthogonal matrix that we'll periodically
+                            # re-estimate when we "reorthogonalize" (this orthogonalizes
+                            # the gradient covariance).
+                            state[f"proj_{dim}"] = torch.eye(size, **kwargs)
+
+                            # proj_grad_{dim} is the gradient w.r.t. proj_{dim}
+                            # (technically w.r.t. a matrix to the left of
+                            # proj_{dim}, we'll project it later); we accumulate
+                            # this for `lr_est_period` steps, and then update
+                            # proj_{dim} and zero this matrix
+                            state[f"proj_grad_{dim}"] = torch.zeros(size, size, **kwargs)
+
+                            # grad_cov_{dim} is the covariance of gradients on this axis (without
+                            # any co-ordinate changes), treating all other axes as a batch axis.
+                            # This is needed when we re-orthogonalize, and to compute the
+                            # grad_rms_{dim}.  We store it as a decaying average, decaying with beta2,
+                            # only for purposes of computing the scalar factor on the
+                            # learning rate; and, as a result of this, also contributes something
+                            # to the gradient w.r.t. f"lr_{dim}", which is one reason
+                            # why we allocate a variable to keep track of its moving average
+                            # instead of just using a temporary and smoothing the scalar factor.
+                            state[f"grad_cov_{dim}"] = torch.zeros(size, size, **kwargs)
+
+                            # proj_grad_var_{dim} is the variance of gradients of the thing we're
+                            # going to multiply proj_{dim} by (between it and the "underlying
+                            # parameter matrix"), aggregated over each of its 2 dimensions
+                            # and decayed over time with beta1 each time we do a step on proj_{dim},
+                            # e.g. every lr_est_period steps.  It's a factored representation
+                            # of something like exp_avg_sq, used for estimating proj_{dim}; we avoid
+                            # storing an exp_avg_sq for proj_{dim}, to save GPU memory.
+                            state[f"proj_grad_var_{dim}"] = torch.ones(2, size, **kwargs)

                 step = state["step"]
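
To make the initialization concrete, here is a sketch (not from the commit) of the state this loop creates for a hypothetical weight of shape (256, 384); the names come from the diff above, and the shapes follow from the torch.zeros/torch.eye calls.

import torch

p = torch.randn(256, 384)
state = {
    "delta": torch.zeros_like(p),          # momentum on the parameter change
    "param_rms": (p ** 2).mean().sqrt(),   # scalar record of the parameter rms
    "scalar_exp_avg_sq": torch.zeros(()),  # scalar second moment
    "exp_avg_sq": torch.zeros_like(p),     # second moment, kept in the projected space
}
for dim, size in enumerate(p.shape):       # dim 0: size 256, dim 1: size 384
    state[f"proj_{dim}"] = torch.eye(size)               # learning-rate matrix factor
    state[f"proj_grad_{dim}"] = torch.zeros(size, size)  # accumulated grad w.r.t. proj
    state[f"grad_cov_{dim}"] = torch.zeros(size, size)   # decaying gradient covariance
    state[f"proj_grad_var_{dim}"] = torch.ones(2, size)  # factored variance for proj updates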
@@ -206,14 +221,14 @@ class LearnedGradient(Optimizer):
                                       beta1, beta2, step, size_lr,
                                       param_min_rms, param_max_rms)

-                if numel <= small_tensor_cutoff:
+                if numel == 1:
                     # For parameters with very few elements we just use a form
                     # of Adam with a scale factor to reflect the overall
                     # parameter rms.
-                    self._step_small(beta1, beta2, eps, lr, p, grad, state)
+                    self._step_scalar(beta1, beta2, eps, lr, p, grad, state)
                 else:
                     if step % lr_est_period == 0:
-                        self._update_lrs(group, grad, state)
+                        self._update_lrs(group, p, state)
                     self._step(group, p, grad, state)

                 state["step"] = step + 1
@@ -289,28 +304,129 @@ class LearnedGradient(Optimizer):

     def _update_lrs(self,
                     group: dict,
-                    grad: Tensor,
+                    p: Tensor,
                     state: dict) -> None:
         """
         Updates the learning-rate matrices state["lr_{dim}"].

         Args:
           group: dict to look up configuration values
-          grad: current gradient for the parameter that we are updating
+          p: parameter matrix that we are updating.  The learning rate matrices
+             are actually factors of p, so p itself will change when we change
+             them.
           state: state dict for the current parameter
         """
-        # meta_lr is the learning rate for the learning rate matrix.
-        meta_lr = group["lr"] * group["meta_lr_scale"]
+        # meta_lr is the learning rate for the learning rate matrix.  Including the factor
+        # (group["lr_est_period"] ** 0.5) is intended to make it approximately invariant
+        # to the lr_est_period (close to convergence).
+        meta_lr = group["lr"] * group["meta_lr_scale"] * (group["lr_est_period"] ** 0.5)
         beta1, beta2 = group["betas"]
         eps = group["eps"]
         moving_g = state["exp_avg"]

         ndim = p.ndim

-        # Update grad_cov.
-        self._store_grad_stats(grad, state, beta2)
+        for dim in range(ndim):
+            size = p.shape[dim]
+            if size == 1:
+                continue
+
+            p_t = p.transpose(dim, -1)
+            this_p = p_t.reshape(-1, size)
+
+            # proj_grad is a summed gradient, indexed [grad_idx][param_idx] (don't worry
+            # if this notation about the indexing is unclear, just keep reading).
+            #
+            # Suppose M is our parameter matrix p reshaped to (size, -1), with the -1 being
+            # all other dims treated as a batch dim, and M_grad is the loss derivative w.r.t. M.
+            #
+            # Then proj_grad is the accumulated value of M_grad M^T, which will be of shape
+            # (size, size).  If we view the parameter matrix as actually being (proj M), with
+            # proj numerically equal to I, i.e. as
+            # matmul(proj, M), then instead of the loss being tr(M_grad^T M), we can write
+            # it as tr(M_grad^T proj M), which can also be written as tr(M M_grad^T proj)
+            # = tr( (M_grad M^T)^T proj ), where we can write M_grad M^T == proj_grad.
+            #
+            # Now, proj_grad is convenient to compute but it's not in a very normalized
+            # space, so we want to normalize it first.  We're going to view M as
+            #    M = T N,
+            # where T is our proj_{dim} matrix (see the code), and N == T^{-1} M.
+            # It's more convenient to have the projection between T and N, i.e. to have proj2,
+            # again numerically equal to I, so
+            #    M == T proj2 N.
+            # [Be careful: when we use T as a superscript it means transpose.]
+            # We want the derivative w.r.t. proj2.
+            # So suppose the loss is
+            #    tr(M_grad^T M) == tr(M_grad^T T proj2 N) == tr(N M_grad^T T proj2)
+            #                   == tr( (T^T M_grad N^T)^T proj2 )
+            # and using N == T^{-1} M, so N^T = M^T T^{-T},
+            #                   == tr( (T^T M_grad M^T T^{-T})^T proj2 )
+            # so we can view the thing in parentheses in the above expression as proj2_grad, i.e.
+            #    proj2_grad = T^T M_grad M^T T^{-T}
+            #    proj2_grad = T^T proj_grad T^{-T}
+            #
+            # note: p = state[f"proj_{dim}"] actually corresponds to T^T, since it has a
+            # shape interpretable as (diagonalized_index, canonical_index).
+            #
+            # So proj2_grad = p proj_grad p^{-1}    (eq.1)
+            #
+            # After computing proj2_grad we compute/update a factorized representation of
+            # its variance, factorized over its 2 axes, and do an Adam-type update
+            # for it, except without momentum at this level (momentum will be applied
+            # later).  This gives us proj2_delta, according to the
+            # equation:  proj2 = I + proj2_delta.
+            #
+            # This is the point at which we control the magnitude of the parameter change.
+            #
+            # We now have to update the factor T using proj2, as in:
+            #    T <-- T (I + proj2_delta)
+            #    T <-- T + T proj2_delta
+            # and since p == T^T, this becomes:
+            #    p <-- p + proj2_delta^T p
+            # or equivalently:
+            #    p_delta = proj2_delta^T p    (eq.2)
+            #
+            # We also need to update M itself; this is easiest done by computing proj_delta
+            # (which is the deviation of proj from I).  The relationship between proj and proj2
+            # is given by equating
+            #    proj T N = T proj2 N,
+            #    proj T = T proj2,
+            #    proj = T proj2 T^{-1}
+            #    proj = p^T proj2 p^{-T} .
+            # We can apply the above formula directly to proj2_delta, since the "I" part just
+            # becomes a no-op, i.e.
+            #    proj_delta = p^T proj2_delta p^{-T},
+            # ... and to turn this back into a delta on M, we'll do:
+            #    M_delta = proj_delta M.
+
+            proj_grad = state[f"proj_grad_{dim}"]
+            p = state[f"proj_{dim}"]
+
+            # torch.linalg.solve(A, B) returns X such that AX=B,
+            # meaning X = A^{-1} B.  We want to compute X = proj_grad p^{-1},
+            # so X^T = p^{-T} proj_grad^T.
+            X = torch.linalg.solve(p.t(), proj_grad.t()).t()
+            assert torch.allclose(proj_grad, torch.matmul(X, p))  # TEMP.
+
+            # See (eq.1): the next line implements proj2_grad = p proj_grad p^{-1}
+            proj2_grad = torch.matmul(p, X)
+
+            # for now just estimate it from proj2_grad, later we will try aggregating it
+            # across iterations.  This is like exp_avg_sq in Adam.
+            proj2_grad_var = self._get_matrix_var(proj2_grad, group["eps"])
+
+            denom = proj2_grad_var.sqrt()
+
+            alpha = -meta_lr * (1-beta1)
+
+            # we'll take into account alpha later on.
+            proj2_delta = proj2_grad / denom
+
+            p_delta = torch.matmul(proj2_delta.t(), p)

-        with torch.enable_grad():
+
             # To update the learning rate matrices, we are asking the question, "What if,
             # on the previous step, we had used a slightly different learning-rate matrix?
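
The derivation above is easy to check numerically.  A standalone sketch (not part of the commit) verifying that the solve() trick computes proj_grad p^{-1}, and that (eq.1) agrees with an explicit inverse:

import torch

size = 8
torch.manual_seed(0)
p = torch.eye(size) + 0.1 * torch.randn(size, size)  # stand-in for proj_{dim}, i.e. T^T
proj_grad = torch.randn(size, size)                  # stand-in for the accumulated M_grad M^T

# solve(A, B) gives A^{-1} B, so this is (p^{-T} proj_grad^T)^T = proj_grad p^{-1}.
X = torch.linalg.solve(p.t(), proj_grad.t()).t()
assert torch.allclose(torch.matmul(X, p), proj_grad, atol=1e-5)

proj2_grad = torch.matmul(p, X)  # p proj_grad p^{-1}, i.e. (eq.1)
reference = p @ proj_grad @ torch.linalg.inv(p)
assert torch.allclose(proj2_grad, reference, atol=1e-4)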
@@ -345,7 +461,7 @@ class LearnedGradient(Optimizer):
             # [something] as it contributes no gradient, becomes -alpha *
             # (moving_d_norm * grad).sum(), and ignoring the scale alpha, that
             # is -(moving_d_norm * grad).sum(), so we call this neg_loss as it's the
-            # negative of hte iss.
+            # negative of the loss.
             neg_loss = (grad * moving_d_norm).sum()

             # after the following, there will be grads in state[f"lr_{dim}"]
@@ -357,6 +473,32 @@ class LearnedGradient(Optimizer):
             if p.shape[dim] != 1:
                 self._update_lr(state[f"lr_{dim}"], meta_lr)

+    def _get_matrix_var(self,
+                        x: Tensor,
+                        eps: float) -> Tensor:
+        """
+        Estimate a variance of elements of x, which must be a tensor with exactly
+        2 dimensions, neither of which can have size == 1.  This is done using
+        a structured representation, so the variance is a product over the 2
+        dimensions (otherwise, it wouldn't be possible to estimate a variance from
+        a single sample).  Eventually we may do this via aggregation across
+        time as well; this approach seems to give a high-variance estimate even
+        for largish dimensions: e.g. if x.shape == (100,100), the central 95% of
+        the estimate is from about 0.6 to 1.6.
+
+        Args:
+           x: tensor to estimate the variance of elements of.  We require x.ndim == 2,
+              and neither size may be 1.
+           eps: Epsilon to avoid zero variance; epsilon is dimensionally an rms value.
+        """
+        assert x.ndim == 2 and 1 not in list(x.shape)  # TODO: remove?
+
+        var0 = (x**2).mean(dim=1, keepdim=True) + eps*eps
+        var1 = ((x / var0.sqrt()) ** 2).mean(dim=0, keepdim=True) + eps*eps
+        # estimate var0 one more time.
+        var0 = ((x / var1.sqrt()) ** 2).mean(dim=1, keepdim=True) + eps*eps
+        return var1 * var0
+
+
     def _update_lr(self,
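
For reference, the factored variance estimate added above as a standalone sketch (same arithmetic, outside the class), with a quick look at the spread of the estimates for unit-variance input:

import torch

def get_matrix_var(x: torch.Tensor, eps: float) -> torch.Tensor:
    # Model the variance of each element as (row factor) * (column factor),
    # refined by alternating between the two dimensions.
    var0 = (x ** 2).mean(dim=1, keepdim=True) + eps * eps
    var1 = ((x / var0.sqrt()) ** 2).mean(dim=0, keepdim=True) + eps * eps
    var0 = ((x / var1.sqrt()) ** 2).mean(dim=1, keepdim=True) + eps * eps
    return var1 * var0  # broadcasts to x's shape

torch.manual_seed(0)
x = torch.randn(100, 100)  # true element variance is 1
v = get_matrix_var(x, eps=1e-8)
print(v.shape)  # torch.Size([100, 100])
print(v.quantile(0.025).item(), v.quantile(0.975).item())
# roughly the 0.6 .. 1.6 spread the docstring mentions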
@@ -421,7 +563,7 @@ class LearnedGradient(Optimizer):
               state: dict):
        """
        This function does the core update of self._step, in the case where the tensor
-       has more than about 5 elements.  It multiplies the moving-average gradient tensor by a
+       has more than 1 element.  It multiplies the moving-average gradient tensor by a
        state["lr_{dim}"] on each non-trivial axis/dimension, does a length normalization,
        and takes a step in that direction.

@@ -434,55 +576,60 @@ class LearnedGradient(Optimizer):
        This function modifies p.
        """
        lr = group["lr"]
-       beta1 = group["betas"][0]
+       beta1, beta2 = group["betas"]
        eps = group["eps"]
-       max_step_scale = group["max_step_scale"]
-
-       moving_g = state["exp_avg"]
-       # include the current grad in moving_g
-       moving_g.mul_(beta1).add(grad, alpha=(1-beta1))
-
-       # get the parameter delta (up to a scalar factor) by multiplying by a
-       # learning-rate matrix for each dimension.
-       moving_d = self._multiply_by_lr(moving_g)
-
-       # moving_d contains a moving average of parameter change contributions
-       # from multiple iterations with weights: (1-beta) * (1 + beta + beta^2 +
-       # beta^3, etc.), which sum to 1 for large step.  Suppose the rms value
-       # of each individual gradient were r.  Assuming the terms are
-       # independent, with rms=r, the variance of the elements of moving_d
-       # would be r^2 * (1-beta)^2 * (1 + beta^2 + beta^4 + ... ) which equals
-       # r^2 * (1-beta)^2 / (1-beta^2), so the measured rms of moving_d would
-       # be moving_d_rms = individual_d_rms * (1-beta) / sqrt(1-beta^2), so
-       # individual_d_rms = moving_d_rms * sqrt(1-beta^2) / (1-beta) ... this
-       # would be the root-mean-square value of the individual deltas, if we
-       # were to proces them individually (if they were independent), and this
-       # is what we want to normalize by for greater independence of the beta
-       # value, i.e. so beta only affects momentum and not the overall speed
-       moving_d_rms = (moving_d**2).mean().sqrt() + eps
-
-       # e.g. if beta1==0.98, this formula gives a factor of 9.95, reflecting
-       # that the individual rms's of the un-scaled parameter deltas without
-       # momentum are larger than their moving average.  see the justification
-       # in a comment above.
-       d_rms = moving_d_rms * (((1 - beta1 * beta1) ** 0.5) / (1 - beta1))
-       param_rms = state["param_rms"]
-       alpha = -lr * param_rms / d_rms
-
-       # apply max_step_scale to limit magnitude of change on individual
-       # elements, this should help stability.
-       limit = moving_d_rms * max_step_scale
-       moving_d_clamped = moving_d.clamp_(min=-limit, max=limit)
-
-       if random.random() < 0.001:
-           clamp_diff_rms = ((moving_d_clamped - moving_d) ** 2).mean().sqrt()
-           rel_clamp_diff = clamp_diff_rms / d_rms
-           logging.info(f"Clamping difference is {clamp_diff_rms.item():.3g} vs. {d_rms.item():.3g}: "
-                        f"relative diff {rel_clamp_diff.item():.3g}, shape={tuple(p.shape)}")
-
-       p.add_(moving_d_clamped, alpha=alpha)
+       for dim in range(grad.ndim):
+           size = grad.shape[dim]
+           # accumulate some stats for learning the projections
+           proj_grad = state[f"proj_grad_{dim}"]
+           grad_cov = state[f"grad_cov_{dim}"]
+           this_p = p.transpose(-1, dim).reshape(-1, size)
+           this_g = grad.transpose(-1, dim).reshape(-1, size)
+           proj_grad.add_(torch.matmul(this_g.t(), this_p))
+           # could perhaps accumulate grad_cov less frequently.
+           grad_cov.mul_(beta2).add_(torch.matmul(this_g.t(), this_g))
+
+       grad = self._project(grad, forward=True, state=state)
+
+       exp_avg_sq = state["exp_avg_sq"]
+       exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
+                                       value=(1-beta2))
+
+       step = state["step"]
+       bias_correction2 = 1 - beta2 ** (step + 1)
+       if bias_correction2 < 0.99:
+           # note: not in-place.
+           exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2)
+       denom = exp_avg_sq.sqrt() + eps
+
+       grad = grad / denom
+
+       # and project back..
+       grad = self._project(grad, forward=False, state=state)
+
+       scalar_exp_avg_sq = state["scalar_exp_avg_sq"]
+       scalar_exp_avg_sq.mul_(beta2).add_((grad**2).mean(),
+                                          alpha=1-beta2)
+
+       if bias_correction2 < 0.99:
+           # note: not in-place.
+           scalar_exp_avg_sq = scalar_exp_avg_sq * (1.0 / bias_correction2)
+
+       denom = scalar_exp_avg_sq.sqrt() + eps  # scalar.
+
+       alpha = -state["param_rms"] * (1-beta1) * lr / denom
+
+       delta = state["delta"]
+       delta.mul_(beta1).add_(grad, alpha=alpha)
+
+       p.add_(delta)

-   def _step_small(self,
+   def _step_scalar(self,
                    beta1: float,
                    beta2: float,
                    eps: float,
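
The normalization argument in the removed comment block can be checked numerically; a standalone sketch (not part of the commit):

import torch

# For independent unit-rms gradients, the moving average m <- beta*m + (1-beta)*g
# has rms (1-beta)/sqrt(1-beta^2), so the per-step rms is recovered via the
# factor sqrt(1-beta^2)/(1-beta).
beta = 0.98
factor = ((1 - beta * beta) ** 0.5) / (1 - beta)
print(f"{factor:.3f}")  # 9.950 -- the "factor of 9.95" cited for beta1 == 0.98

g = torch.randn(50000)  # white-noise stand-in for per-step gradients
m = torch.zeros(())
sq = []
for x in g:
    m = beta * m + (1 - beta) * x
    sq.append(m * m)
moving_rms = torch.stack(sq[2000:]).mean().sqrt()  # skip warm-up steps
print(f"{(1.0 / moving_rms).item():.3f}")          # close to `factor`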
@@ -491,11 +638,11 @@ class LearnedGradient(Optimizer):
                     grad: Tensor,
                     state: dict):
        """
-       A form of the core update for tensors with a small number of elements, e.g. <5.  This is
+       A form of the core update for tensors with a small number of elements, e.g. scalars.  This is
        Adam where, if the numel() > 1, the learning rate is proportional to the parameter rms value.
        """
        exp_avg = state["exp_avg"]
-       exp_avg_sq = state["exp_avg_sq"]
+       exp_avg_sq = state["scalar_exp_avg_sq"]
        exp_avg.mul_(beta1).add_(grad, alpha=1-beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
                                        value=1-beta2)
@@ -505,11 +652,40 @@ class LearnedGradient(Optimizer):
        bias_correction2 = 1 - beta2 ** (state["step"] + 1)
        denom = (exp_avg_sq.sum() / (exp_avg_sq.numel() * bias_correction2)).sqrt() + eps
        alpha = -lr / denom
-       if p.numel() > 1:
-           alpha = alpha * state["param_rms"]
+       # for now not supporting p.numel() > 1.
+       #if p.numel() > 1:
+       #    alpha = alpha * state["param_rms"]
        p.add_(exp_avg, alpha=alpha)

+   def _project(self,
+                x: Tensor,
+                forward: bool,
+                state: dict) -> Tensor:
+       """
+       Multiply a tensor x by proj_{dim} on each of its dimensions.
+       If forward == True, converts x from canonical to preconditioned/diagonalized
+       co-ordinates; if forward == False, does the reverse.
+
+       Args:
+          x: The tensor to project, e.g. a gradient or a parameter change.
+          forward: if True, go in the forward direction (from canonical to diagonalized
+              co-ordinates); if False, the reverse.  They differ by a transpose,
+              not an inverse, and do not make a round trip.
+       """
+       for dim in range(x.ndim):
+           if x.shape[dim] == 1:
+               continue
+           proj = state[f"proj_{dim}"]
+           if forward:
+               # dimension 0 of `proj` is the diagonalized co-ordinate.
+               proj = proj.t()
+           # TODO: could possibly somehow force the output format to be unchanged.
+           x = x.transpose(-1, dim)
+           x = torch.matmul(x, proj)
+           x = x.transpose(-1, dim)
+       return x
+
    def _store_grad_stats(self,
                          grad: Tensor,
                          state: dict,
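
A standalone sketch (not part of the commit) of the convention _project implements: forward multiplies by proj.t() and backward by proj on each dimension, so forward followed by backward applies proj.t() then proj -- a round trip only when the projections are orthogonal:

import torch

def project(x: torch.Tensor, forward: bool, proj_per_dim: dict) -> torch.Tensor:
    # Mirrors the _project method above, with the state dict replaced by a
    # plain {dim: matrix} mapping (an assumption for this sketch).
    for dim in range(x.ndim):
        if x.shape[dim] == 1:
            continue
        proj = proj_per_dim[dim]
        if forward:
            proj = proj.t()  # dimension 0 of proj is the diagonalized co-ordinate
        x = x.transpose(-1, dim)
        x = torch.matmul(x, proj)
        x = x.transpose(-1, dim)
    return x

torch.manual_seed(0)
x = torch.randn(4, 6)
# With orthogonal projections, forward then backward recovers x exactly.
projs = {d: torch.linalg.qr(torch.randn(s, s))[0] for d, s in enumerate(x.shape)}
y = project(project(x, forward=True, proj_per_dim=projs), forward=False, proj_per_dim=projs)
assert torch.allclose(x, y, atol=1e-5)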