mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-18 13:34:20 +00:00
Draft of the new method..
This commit is contained in:
parent
e9e2a85c95
commit
26815d177f
@ -48,9 +48,11 @@ class LearnedGradient(Optimizer):
|
||||
learning the scale on the parameters (we'll keep it >= this size)
|
||||
param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
|
||||
learning the scale on the parameters (we'll keep it <= this size)
|
||||
lr_mat_min: Minimum singular value of learning-rate matrices Q.
|
||||
lr_mat_max: Maximum singular value of learning-rate matrices Q.
|
||||
lr_est_period: The periodicity, in steps, with which we update the learning-rate matrices.
|
||||
orthogonalize_period: How often we re-orthogonalize the gradient covariance; this
|
||||
gets multiplied by lr_est_period, i.e. we orthogonalize less frequently
|
||||
diagonalize_period: How often we re-diagonalize the gradient covariance; this
|
||||
gets multiplied by lr_est_period, i.e. we diagonalize less frequently
|
||||
than we update the learning-rate matrixes.
|
||||
max_step_scale: Value that determines limit on how large steps can be per element,
|
||||
intended to prevent instability during early steps before the learning-rate
|
||||
@ -64,11 +66,13 @@ class LearnedGradient(Optimizer):
|
||||
meta_lr_scale=0.1,
|
||||
betas=(0.9, 0.98),
|
||||
eps=1.0e-08,
|
||||
size_update_period=4,
|
||||
param_min_rms=1.0e-05,
|
||||
param_max_rms=2.0,
|
||||
lr_mat_min=0.01,
|
||||
lr_mat_max=4.0,
|
||||
lr_est_period=10,
|
||||
max_lr_eig=4.0,
|
||||
orthogonalize_period=4,
|
||||
diagonalize_period=4,
|
||||
max_step_scale=2.0
|
||||
):
|
||||
|
||||
@ -79,11 +83,13 @@ class LearnedGradient(Optimizer):
|
||||
meta_lr_scale=meta_lr_scale,
|
||||
betas=betas,
|
||||
eps=eps,
|
||||
size_update_period=size_update_period,
|
||||
param_min_rms=param_min_rms,
|
||||
param_max_rms=param_max_rms,
|
||||
lr_mat_min=lr_mat_min,
|
||||
lr_mat_max=lr_mat_max,
|
||||
lr_est_period=lr_est_period,
|
||||
max_lr_eig=max_lr_eig,
|
||||
orthogonalize_period=orthogonalize_period,
|
||||
diagonalize_period=diagonalize_period,
|
||||
max_step_scale=max_step_scale,
|
||||
)
|
||||
|
||||
@ -112,12 +118,13 @@ class LearnedGradient(Optimizer):
|
||||
size_lr = lr * group["size_lr_scale"]
|
||||
beta1, beta2 = group["betas"]
|
||||
max_lr_eig = group["max_lr_eig"]
|
||||
min_lr_eig = group["max_lr_eig"]
|
||||
eps = group["eps"]
|
||||
size_update_period = 4
|
||||
param_min_eps = group["param_min_rms"]
|
||||
param_max_eps = group["param_max_rms"]
|
||||
param_min_rms = group["param_min_rms"]
|
||||
param_max_rms = group["param_max_rms"]
|
||||
lr_est_period = group["lr_est_period"]
|
||||
orthogonalize_period = group["orthogonalize_period"]
|
||||
diagonalize_period = group["diagonalize_period"]
|
||||
max_step_scale = group["max_step_scale"]
|
||||
|
||||
for p in group["params"]:
|
||||
@ -158,39 +165,52 @@ class LearnedGradient(Optimizer):
|
||||
# scalar exp_avg_sq. not the same as scale_exp_avg_sq
|
||||
state["exp_avg_sq_scalar"] = torch.zeros_like((), **kwargs)
|
||||
|
||||
# exp_avg_sq is in the projected space. It is periodically zeroed, and we
|
||||
# set zero_step.
|
||||
state["exp_avg_sq"] = torch.zeros_like(
|
||||
p, memory_format=torch.preserve_format
|
||||
)
|
||||
state["zero_step"] = 0
|
||||
|
||||
for dim in range(p.ndim):
|
||||
size = p.shape[dim]
|
||||
if size == 1:
|
||||
# if size == p.numel(), is_one_axis == True above
|
||||
continue
|
||||
|
||||
# exp_avg_sq is in the projected space.
|
||||
state["exp_avg_sq"] = torch.zeros_like(
|
||||
p, memory_format=torch.preserve_format
|
||||
)
|
||||
|
||||
|
||||
# proj_{dim} will be the learning-rate matrix for this
|
||||
# Q_{dim}, which we will call q in this comment, will be the learning-rate matrix for this
|
||||
# dimension, times an orthogonal matrix that we'll periodically
|
||||
# re-estimate when we "reorthogonalize" (this orthogonalizes
|
||||
# re-estimate when we "rediagonalize" (this diagonalizes
|
||||
# the gradient covariance).
|
||||
state[f"proj_{dim}"] = torch.eye(size, **kwargs)
|
||||
# If our parameter matrix M is of shape (..., size),
|
||||
# we'll view M as being M == torch.matmul(N, q)
|
||||
# so p can be interpreted as having shape
|
||||
# (size, size), interpreted as being indexed [diagonalized_coordinate, canonical_coordinate]
|
||||
# by "diagonalize" we mean that we diagonalize the gradient covariance.
|
||||
state[f"Q_{dim}"] = torch.eye(size, **kwargs)
|
||||
|
||||
# proj_grad_{dim} is the gradient w.r.t. proj_{grad}
|
||||
# (technically w.r.t a matrix to the left of
|
||||
# proj_{grad}, we'll project it later); we accumulate
|
||||
# this for `lr_est_period` steps, and then update
|
||||
# proj_{dim} and zero this matrix
|
||||
# proj_grad_{dim} is the gradient w.r.t. `proj`,
|
||||
# which is a matrix we introduce to the right of p, i.e.
|
||||
# instead of M == torch.matmul(N, p) we view it as
|
||||
# M == torch.matmul(torch.matmul(N, torch.matmul(p, proj)))
|
||||
# or equivalently M == torch.matmul(M_underlying, proj)
|
||||
# where M_underlying is numerically equal to M, and proj is numerically
|
||||
# equal to I, but they are viewed as separate variables.
|
||||
#
|
||||
# proj_grad_{dim} will be set to the sum of
|
||||
# torch.matmul(M^T, M_grad) over a number of steps (`lr_est_period` steps),
|
||||
# and will be zeroed after we call self_update_lrs().
|
||||
state[f"proj_grad_{dim}"] = torch.zeros(size, size, **kwargs)
|
||||
|
||||
# grad_cov_{dim} is the covariance of gradients on this axis (without
|
||||
# any co-ordinate changes), treating all other axes as as a batch axis.
|
||||
# This is needed when we re-orthogonalize, and to compute the
|
||||
# This is needed when we re-diagonalize, and to compute the
|
||||
# grad_rms_{dim}. We store it as a decaying average, decaying with beta2
|
||||
|
||||
# only for purposes of computing the scalar factor on the
|
||||
# learning rate; and, as a result of this, also contributes something
|
||||
# to the gradient w.r.t. f"lr_{dim}", which is one reason
|
||||
# to the gradient w.r.t. f"Q_{dim}", which is one reason
|
||||
# why we allocate a variable to keep track of its moving average
|
||||
# instead of just using a temporary and smoothing the scalar factor.
|
||||
state[f"grad_cov_{dim}"] = torch.zeros(size, size, **kwargs)
|
||||
@ -220,17 +240,21 @@ class LearnedGradient(Optimizer):
|
||||
scale_grads, param_rms,
|
||||
beta1, beta2, step, size_lr,
|
||||
param_min_rms, param_max_rms)
|
||||
|
||||
state["delta"].mul_(beta1)
|
||||
if numel == 1:
|
||||
# For parameters with very few elements we just use a form
|
||||
# of Adam with a scale factor to reflect the overall
|
||||
# parameter rms.
|
||||
# parameter rms. Updates delta.
|
||||
self._step_scalar(beta1, beta2, eps, lr, p, grad, state)
|
||||
else:
|
||||
|
||||
if step % lr_est_period == 0:
|
||||
self._update_lrs(group, p, state)
|
||||
if step % (lr_est_period * diagonalize_period) == 0:
|
||||
self._diagonalize_lrs(group, p, state)
|
||||
self._zero_exp_avg_sq()
|
||||
self._step(group, p, grad, state)
|
||||
|
||||
p.add_(delta)
|
||||
state["step"] = step + 1
|
||||
|
||||
return loss
|
||||
@ -238,6 +262,7 @@ class LearnedGradient(Optimizer):
|
||||
def _size_update(self,
|
||||
p: Tensor,
|
||||
state: dict,
|
||||
scale_grad: Tensor,
|
||||
beta1: float,
|
||||
beta2: float,
|
||||
size_lr: float,
|
||||
@ -258,7 +283,6 @@ class LearnedGradient(Optimizer):
|
||||
tensor p (for enforcing param_min_rms and param_max_rms)
|
||||
beta1: Beta value for decaying gradient stats
|
||||
beta2: Beta value for decaying gradient-squared stats
|
||||
step: The step index
|
||||
size_lr: The learning rate to apply to `scale`
|
||||
param_min_rms: User-specified minimum on the rms value of p, we will constrain it to >= this
|
||||
param_max_rms: User-specified maximum on the rms value of p, we will constrain it to <= this
|
||||
@ -268,19 +292,14 @@ class LearnedGradient(Optimizer):
|
||||
# products (p * grad).sum() for the `size_update_period` most recent steps
|
||||
size_update_period, = scale_grads.shape
|
||||
|
||||
# correct beta1 and beta2 for the size update period: we will have
|
||||
# faster decay at this level. The same for the learning rate.
|
||||
beta1_corr = beta1 ** size_update_period
|
||||
# correct beta2 for the size update period: we will have
|
||||
# faster decay at this level.
|
||||
beta2_corr = beta2 ** size_update_period
|
||||
size_lr_corr = size_lr * size_update_period
|
||||
|
||||
scale_exp_avg_sq = state["scale_exp_avg_sq"]
|
||||
scale_exp_avg_sq.mul_(beta2_corr).add_(
|
||||
(scale_grads ** 2).mean(), value=1-beta2_corr)
|
||||
|
||||
scale_exp_avg = state["scale_exp_avg"]
|
||||
scale_exp_avg.mul_(beta1_corr).add_(
|
||||
scale_grads.mean(), value=1-beta1_corr)
|
||||
|
||||
# The 1st time we reach here is when size_step == 1.
|
||||
size_step = (step + 1) // size_update_period
|
||||
@ -291,15 +310,17 @@ class LearnedGradient(Optimizer):
|
||||
eps = 1.0e-10 # this value is not critical.
|
||||
denom = scale_exp_avg_sq.sqrt() + eps
|
||||
|
||||
scale_step = -size_lr_corr * (bias_correction2 ** 0.5) * scale_exp_avg / denom
|
||||
scale_step = -size_lr * (bias_correction2 ** 0.5) * scale_grads.sum() / denom
|
||||
|
||||
is_too_small = (param_rms < param_min_rms)
|
||||
is_too_large = (param_rms > param_max_rms)
|
||||
|
||||
scale_step.masked_fill_(is_too_small, size_lr_corr)
|
||||
scale_step.masked_fill_(is_too_large, -size_lr_corr)
|
||||
scale_step.masked_fill_(is_too_small, size_lr * size_update_period)
|
||||
scale_step.masked_fill_(is_too_large, -size_lr * size_update_period)
|
||||
|
||||
p.add_(p, alpha=scale_step)
|
||||
delta = state["delta"]
|
||||
# the factor of (1-beta1) relates to momentum.
|
||||
delta.add_(p, alpha=(1-beta1) * scale_step)
|
||||
|
||||
|
||||
def _update_lrs(self,
|
||||
@ -307,7 +328,7 @@ class LearnedGradient(Optimizer):
|
||||
p: Tensor,
|
||||
state: dict) -> None:
|
||||
"""
|
||||
Updates the learning-rate matrices state["lr_{dim}"].
|
||||
Updates the learning-rate matrices state["Q_{dim}"].
|
||||
|
||||
Args:
|
||||
group: dict to look up configuration values
|
||||
@ -322,7 +343,7 @@ class LearnedGradient(Optimizer):
|
||||
meta_lr = group["lr"] * group["meta_lr_scale"] * (group["lr_est_period"] ** 0.5)
|
||||
beta1, beta2 = group["betas"]
|
||||
eps = group["eps"]
|
||||
moving_g = state["exp_avg"]
|
||||
delta = state["delta"]
|
||||
ndim = p.ndim
|
||||
|
||||
for dim in range(ndim):
|
||||
@ -330,148 +351,247 @@ class LearnedGradient(Optimizer):
|
||||
if size == 1:
|
||||
continue
|
||||
|
||||
p_t = p.transpose(dim, -1)
|
||||
this_p = p_t.reshape(-1, size)
|
||||
# M is the parameter matrix, of shape (-1, size) where the -1 covers
|
||||
# all other dimensions of the tensor.
|
||||
M_full = p.transpose(dim, -1)
|
||||
M = M.reshape(-1, size)
|
||||
|
||||
# proj_grad is a summed gradient, indexed [grad_idx][param_idx] (don't worry
|
||||
# if this notation about the indexing is unclear, just keep reading).
|
||||
# proj_grad is a summed gradient, of shape (size, size),
|
||||
# indexed [old_param_idx][new_param_idx] (the point is,
|
||||
# param_idx corresponds to the index of the current M, new_param_idx
|
||||
# corresponds to the index of the M after a small change).
|
||||
#
|
||||
# Suppose M is our parameter matrix p reshaped to (size, -1) with the -1 being
|
||||
# all other dims treated as a batch dim, and M_grad is the loss derivative w.r.t. M,
|
||||
# Suppose M is our parameter matrix p reshaped to (-1, size) with
|
||||
# the -1 being all other dims of the tensor treated as a batch dim
|
||||
# and "size" is the size of whatever dimension we are working on
|
||||
# right now. And suppose M_grad is the loss derivative w.r.t. M
|
||||
# (for compatibility with things like Torch, we use a convention
|
||||
# where all loss derivatives/grads have the same shape as the things they
|
||||
# are derivative with respect to).
|
||||
#
|
||||
# Then proj_grad is the accumulated value of M_grad M^T, which will be of shape
|
||||
# (size, size). If we view the parameter matrix as actually (proj M) with
|
||||
# proj being numerically equal to I, i.e. as
|
||||
# matmul(proj, M), then instead of the loss being tr(M_grad^T M), we can write
|
||||
# it as tr(M_grad^T proj M), which can also be written as tr(M M_grad^T proj)
|
||||
# = tr( (M_grad M^T)^T proj), where we can write M_grad M^T == proj_grad.
|
||||
# proj_grad is the accumulated value, on the last few steps, of M M_grad^T,
|
||||
# which will be of shape (size, size). For purposes of our update, if we view
|
||||
# the parameter matrix M as (M_underlying proj) with
|
||||
# proj being numerically equal to I but treated as a variable, i.e. as
|
||||
# matmul(M_underlying, proj), then instead of the loss being tr(M_grad^T M), we can write
|
||||
# it as tr(M_grad^T M_underlying proj), which can also be written as tr(proj_grad^T proj)
|
||||
# where proj_grad == M_underlying^T M_grad.
|
||||
#
|
||||
# Now, proj_grad is convenient to compute but it's not in a very normalized
|
||||
# space, we want to normalize it first. We're going to view M as
|
||||
# M = T N,
|
||||
# where T is our proj_{dim} matrix (see the code), and N == T^{-1} M.
|
||||
# It's more convenient to have the projection between T and N, i.e. have proj2,
|
||||
# again numerically equal to I, so
|
||||
# M == T proj2 N.
|
||||
# [Be careful because when we use T as a superscript it means transpose).
|
||||
# We want the derivative w.r.t. proj2.
|
||||
# So suppose the loss is
|
||||
# tr(M_grad^T M) == tr(M_grad^T T proj2 N) == tr(N M_grad^T T proj2)
|
||||
# == tr( (T^T M_grad N^T)^T proj2 )
|
||||
# == tr( (T^T M_grad N^T)^T proj2 )
|
||||
# and using N == T^{-1} M, so N^T = M^T T^{-T},
|
||||
# == tr( (T^T M_grad M^T T^{-T})^T proj2 )
|
||||
# so we can view the thing in parentheses in the above expression as proj2_grad, i.e.
|
||||
# proj2_grad = T^T M_grad M^T T^{-T}
|
||||
# proj2_grad = T^T proj_grad T^{-T}
|
||||
#
|
||||
# note: p = state[f"proj_{dim}"] actually corresponds to T^T, since it has a
|
||||
# shape interpretable as (diagonalized_index, canonical_index).
|
||||
#
|
||||
# So proj2_grad = p proj_grad p^{-1} (eq.1)
|
||||
# space, we want to normalize it before doing gradient descent on it.
|
||||
# We're going to view M as
|
||||
# M == N Q,
|
||||
# where Q is our Q_{dim} "learning-rate" matrix, of shape (size, size) indexed
|
||||
# [diagonalized_index, canonical_index] (see the code), and N == M Q^{-1}.
|
||||
# It's more convenient to do gradient descent on a `proj` matrix that's between N and Q,
|
||||
# i.e. define `proj2`, also numerically equal to I but treated as a variable,
|
||||
# and have:
|
||||
# M == N proj2 Q
|
||||
# We want to obtain the derivative w.r.t. proj2.
|
||||
# So the linearized pseudo-loss is
|
||||
# tr(M_grad^T M) == tr(M_grad^T N proj2 Q) == tr(Q M_grad^T N proj2),
|
||||
# which we can view as tr(proj2_grad^T proj2), where
|
||||
# proj2_grad == (Q M_grad^T N)^T == N^T M_grad Q^T == Q^{-T} M^T M_grad Q^T
|
||||
# proj2_grad = Q^{-T} proj_grad Q^T (eq.1)
|
||||
#
|
||||
# After computing proj2_grad we compute/update a factorized representation of
|
||||
# its variance, factorized over its 2 axes, and do an Adam-type update
|
||||
# for it, except without momentum at this level (momentum will be applied
|
||||
# later). This gives us proj2_delta, according to the
|
||||
# equation: proj2 = I + proj2_delta.
|
||||
# equation:
|
||||
# proj2 = I + proj2_delta
|
||||
# (again: proj2_delta is derived from proj2_grad using an Adam-style update).
|
||||
#
|
||||
# This is the point at which we control the magnitude of the parameter change.
|
||||
#
|
||||
# We now have to update the factor T using proj2, as in:
|
||||
# T <-- T (I + proj2_delta)
|
||||
# T <-- T + T proj2_delta
|
||||
# and since p == T^T, this becomes:
|
||||
# p <-- p + proj2_delta^T p
|
||||
# or equivalently:
|
||||
# p_delta = proj2_delta^T p (eq.2)
|
||||
# We now have to update the "learning-rate matrix" Q using proj2, as in
|
||||
# Q := proj2 Q
|
||||
# Q := (proj2_delta + I) Q, or:
|
||||
# Q += proj2_delta Q (eq.2)
|
||||
#
|
||||
# We also need to update M itself; this is easiest done by computing proj_delta
|
||||
# (which is the deviation of proj from I). The relationship between proj and proj2
|
||||
# (which is the difference proj from I). The relationship between proj and proj2
|
||||
# is given by equating
|
||||
# proj T N = T proj2 N,
|
||||
# proj T = T proj2,
|
||||
# proj = T proj2 T^{-1}
|
||||
# proj = p^T proj2 p^{-T} .
|
||||
# we can apply the above formula directly to proj2_delta, since the "I" part just
|
||||
# becomes a no-op, i.e.
|
||||
# proj_delta = p^T proj2_delta p^{-T}.
|
||||
# ... and to turn this back into a delta on M, we'll do:
|
||||
# M_delta = proj_delta M.
|
||||
|
||||
|
||||
# M == N proj2 Q == N Q proj,
|
||||
# so proj2 Q == Q proj
|
||||
# proj == Q^{-1} proj2 Q
|
||||
# and subtracting out the constant "I" part,
|
||||
# proj_delta == Q^{-1} proj2_delta Q (eq.3)
|
||||
# ... so the change in Q is given by:
|
||||
# Q := Q (I + proj_delta),
|
||||
# Q += Q proj_delta
|
||||
# Q_delta = Q proj_delta
|
||||
# and looking at how we compute proj_delta (eq.3), we notice the right-hand subexpression
|
||||
# is the sam as p_delta, i.e.:
|
||||
# Q_delta = proj2_delta Q (eq.4)
|
||||
#
|
||||
# and then because we view the parameter matrix as
|
||||
# "M proj",
|
||||
# with proj being numerically equal to I but learned here,
|
||||
# it becomes M (I + proj_delta) = M + M proj_delta.
|
||||
# So the update to M is:
|
||||
# M += M proj_delta
|
||||
# or equivalently:
|
||||
# M_delta = M proj_delta (eq.5)
|
||||
proj_grad = state[f"proj_grad_{dim}"]
|
||||
p = state[f"proj_{dim}"]
|
||||
Q = state[f"Q_{dim}"]
|
||||
|
||||
# torch.linalg.solve(A, B) returns X such that AX=B,
|
||||
# meaning X = A^{-1} B. We want to compute X = proj_grad p^{-1},
|
||||
# so X^T = p^{-T} proj_grad^T.
|
||||
X = torch.linalg.solve(p.t(), proj_grad.t()).t()
|
||||
assert torch.allclose(proj_grad, torch.matmul(X, p)) # TEMP.
|
||||
# Next we want to implement (eq.1), proj2_grad = Q^{-T} proj_grad Q^T
|
||||
# torch.linalg.solve(A, B) returns A^{-1} B.
|
||||
Q_invt_proj_grad = torch.linalg.solve(Q.t(), proj_grad)
|
||||
assert torch.allclose(proj_grad, torch.matmul(Q.t(), Q_invt_proj_grad)) # TODO: remove
|
||||
|
||||
# See (eq.1), the next line implements proj_grad = p proj_grad p^{-1}
|
||||
proj2_grad = torch.matmul(p, X)
|
||||
# (eq.1), proj2_grad = Q^{-T} proj_grad Q^T
|
||||
proj2_grad = torch.matmul(Q_invt_proj_grad, Q.t())
|
||||
|
||||
# for now just estimate it from proj2_grad, later we will try aggregating it
|
||||
# across iterations. This is like exp_avg_sq in Adam.
|
||||
# for now just estimate proj2_grad_var from proj2_grad; later, we
|
||||
# may try aggregating it across iterations. This is like exp_avg_sq
|
||||
# in Adam.
|
||||
proj2_grad_var = self._get_matrix_var(proj_grad, group["eps"])
|
||||
|
||||
denom = proj2_grad_var.sqrt()
|
||||
|
||||
alpha = -meta_lr * (1-beta1)
|
||||
|
||||
# we'll take into account alpha later on.
|
||||
# we'll take into account the factor of -meta_lr * (1-beta1) later on;
|
||||
# actually proj2_delta is the negative of a parameter change.
|
||||
proj2_delta = proj2_grad_var / denom
|
||||
|
||||
p_delta = torch.matmul(proj2_delta.t(), p)
|
||||
# See (eq.4), Q_delta = proj2_delta q
|
||||
Q_delta = torch.matmul(proj2_delta, Q)
|
||||
|
||||
# See (eq.3), proj_delta = Q^{-1} proj2_delta Q = Q^{-1} Q_delta
|
||||
proj_delta = torch.linalg.solve(Q, Q_delta)
|
||||
|
||||
M_delta = torch.matmul(M, proj_delta) # (eq.5), M_delta = M proj_delta
|
||||
|
||||
alpha = -meta_lr * (1-beta1)
|
||||
|
||||
# To update the learning rate matrices, we are asking the question, "What if,
|
||||
# on the previous step, we had used a slightly different learning-rate matrix?
|
||||
# How would that have affected the loss function on the current step?"
|
||||
moving_d = self._multiply_by_lr(moving_g)
|
||||
M_delta_full = M_delta.reshape(M_full.shape)
|
||||
|
||||
# moving_d_rms is an rms value of moving_d computed via inner
|
||||
# products with the grad, which is a basis-independent way of
|
||||
# computing the rms value. you have to view this as the rms times
|
||||
# an arbitrary multiplicative factor. in the "real" update we don't
|
||||
# actually do the length normalization like this because it's slow,
|
||||
# we do it a different, simpler way; but here we do it this way
|
||||
# because we are computing gradients and it's going to make it more
|
||||
# sensitive to the norm used. If we used the regular norm on moving_d, it
|
||||
# might concentrate all the movement in "don't-care" dimensions in
|
||||
# parameter space, assuming such dimensions exist.
|
||||
moving_d_rms = (self._multiply_by_grad_cov(moving_d) * moving_d).mean().sqrt()
|
||||
this_delta = M_delta_full.transpose(dim, -1)
|
||||
# we'l be applying momentum on delta, so multiply by (1-beta1) to
|
||||
# get the total magnitude of the change right. The 2 changes below
|
||||
# are the final changes, the only 2 we make in this loop that have
|
||||
# side effects.
|
||||
delta.add_(this_delta, alpha=-meta_lr*(1-beta1))
|
||||
# there is no momentum on Q.
|
||||
Q.add_(Q_delta, alpha=-meta_lr)
|
||||
|
||||
# moving_d_norm is the parameter "delta" moving_d, length-normaized but still with
|
||||
# an arbitrary multiplicative constant. We don't care about this constant
|
||||
# because we are going to normalize the grad length when we do the update;
|
||||
# but it's important to have normalized by dividing by moving_d_rms because it
|
||||
# makes the gradient "scale-independent" in the sense that it is zero in
|
||||
# the direction of increasing or decreasing the parameter scale.
|
||||
moving_d_norm = moving_d / moving_d_rms
|
||||
def _zero_exp_avg_sq(self,
|
||||
state: dict) -> None:
|
||||
"""
|
||||
Zero the exp_avg_sq stats, and set state["zero_step"] to state["step"]
|
||||
"""
|
||||
state["exp_avg_sq"].zero_()
|
||||
state["zero_step"] = state["step"]
|
||||
|
||||
# view the current parameter p as [something] - alpha *
|
||||
# moving_d_norm, where alpha contains the learning rate and other
|
||||
# factors, but we don't care about the numerical value of alpha
|
||||
# because we're going to divide by the gradient norm anyway. The
|
||||
# (pseudo-) loss function would be (p * grad).sum(), which ignoring
|
||||
# [something] as it contributes no gradient, becomes -alpha *
|
||||
# (moving_d_norm * grad).sum(), and ignoring the scale alpha, that
|
||||
# is -(moving_d_norm * grad).sum(), so we call this neg_loss as it's the
|
||||
# negative of the loss.
|
||||
neg_loss = (grad * moving_d_norm).sum()
|
||||
def _diagonalize_lrs(self,
|
||||
group: dict,
|
||||
p: Tensor,
|
||||
state: dict) -> None:
|
||||
"""
|
||||
Diagonalizes the learning-rate matrices state["Q_{dim}"], and also
|
||||
applies the min- and max- singular value constraints. This Q
|
||||
is applied to the parameter matrix M such that we interpret
|
||||
M == N Q.
|
||||
By "diagonalize", we mean to left-miltiply Q by an orthogonal matrix,
|
||||
Q := U Q,
|
||||
so that the covariance of gradients w.r.t. N, measured
|
||||
over the 2nd index of N while treating its 1st index as a batch index,
|
||||
is diagonal.
|
||||
|
||||
# after the following, there will grads in state[f"lr_{dim}"]
|
||||
# that are the negative of loss function grads, i.e. we want to go
|
||||
# in the forward grad direction.
|
||||
neg_loss.backward()
|
||||
Before doing this, also limits the singular values of Q so that the
|
||||
max,min eig are lr_min_eig,lr_max_eig, and then renormalizing so the
|
||||
mean is 1.0 and limiting again.
|
||||
|
||||
Args:
|
||||
group: dict to look up configuration values
|
||||
p: parameter matrix that we are updating. The learning rate matrices
|
||||
are actually factors of p, so p itself will change when we change
|
||||
them.
|
||||
state: state dict for the current parameter
|
||||
"""
|
||||
# meta_lr is the learning rate for the learning rate matrix. Including the factor
|
||||
# (group["lr_est_period"] ** 0.5) is intended to make it approximately invariant
|
||||
# to the lr_est_period (close to convergence).
|
||||
|
||||
lr_mat_min = group["lr_mat_min"]
|
||||
lr_mat_max = group["lr_mat_max"]
|
||||
|
||||
ndim = p.ndim
|
||||
|
||||
for dim in range(ndim):
|
||||
size = p.shape[dim]
|
||||
if size == 1:
|
||||
continue
|
||||
Q = state[f"Q_{dim}"]
|
||||
|
||||
if True:
|
||||
# This block limits the singular values of Q.
|
||||
try:
|
||||
U,S,V = Q.svd()
|
||||
except:
|
||||
# if SVD fails for some reason we can rotate Q by something arbitrary on the
|
||||
# left, as we're going to rotate here anyway, and try again.
|
||||
logging.warn("Error doing SVD on Q. Trying again after rotation.")
|
||||
U,_,_ = torch.randn_like(Q).svd() # Create random orthogonal matrix.
|
||||
Q[:] = torch.matmul(U, Q)
|
||||
U,S,V = Q.svd()
|
||||
S_new = S.clamp(lr_mat_min, lr_mat_max)
|
||||
S_new *= 1.0 / S_new.mean() # normalize so mean is 1.0
|
||||
S_new.clamp_(lr_mat_min, lr_mat_max) # apply limits once more.
|
||||
# Reconstruct Q with the modified S.
|
||||
Q[:] = torch.matmul(U * S, V.t())
|
||||
if random.random() < 0.1:
|
||||
subsample = max(1, S.numel() // 20)
|
||||
logging.info(f"shape={tuple(p.shape)}, dim={dim}, modifed S from {S[::subsample]} to {S_new[::subsample]}")
|
||||
|
||||
if True:
|
||||
# This block does the actual diagonalization.
|
||||
# Suppose the actual parameter matrix p is M, of shape (-1, size), where
|
||||
# the -1 represents all other tensor dims treated as a batch dimension.
|
||||
# M_grad is the same shape as M. We could write a pseudo-loss as
|
||||
# loss = tr(M_grad^T M)
|
||||
# Because we can decompose M as M == N Q, we can write:
|
||||
# loss = tr(M_grad^T N Q) = tr(Q M_grad^T N),
|
||||
# so we can write this at tr(N_grad^T N),
|
||||
# where N_grad == (Q M_grad^T)^T = M_grad Q^T.
|
||||
# Now,
|
||||
# grad_cov == M_grad^T M_grad,
|
||||
# decaying-averaged over minibatches; this is of shape (size,size).
|
||||
# Using N_grad = M_grad Q^T, we can write:
|
||||
# N_grad_cov = Q M_grad^T M_grad Q^T
|
||||
# = Q grad_cov Q^T
|
||||
# (note: this makes sense because the 1st index of Q is the diagonalized
|
||||
# index).
|
||||
|
||||
grad_cov = state[f"grad_cov_{dim}"]
|
||||
N_grad_cov = torch.matmul(Q, torch.matmul(grad_cov, Q))
|
||||
N_grad_cov = N_grad_cov + N_grad_cov.t() # ensure symmetric
|
||||
U, S, V = N_grad_cov.svd()
|
||||
|
||||
# N_grad_cov is SPD, so
|
||||
# N_grad_cov = U S U^T.
|
||||
|
||||
# Now, we can diagonalize N_grad_cov with:
|
||||
# U^T N_grad_cov U == S.
|
||||
# N_grad_cov is a sum of N_grad^T N_grad.
|
||||
# So U^T N_grad^T N_grad U is diagonal.
|
||||
# The linearized pseudo-loss can be written as tr(N_grad^T N_grad).
|
||||
# This can be written as tr(U U^T N_grad^T N_grad), since U U^T == I,
|
||||
#
|
||||
# which we can rearrange as tr(U^T N_grad^T N U). This can be interpreted
|
||||
# as tr(hat_N_grad hat_N), where:
|
||||
# hat_N_grad = N_grad U
|
||||
# hat_N = N U
|
||||
# (hat_N means \hat{N}, or N with a hat on it).
|
||||
# So if we interpret hat_N = N U, the gradient covariance w.r.t.
|
||||
# hat_N will be diagonalized. We also modify Q to hat_Q when
|
||||
# we modify hat_N, to keep the product M = N Q = N U U^T Q = hat_N hat_Q
|
||||
# This can be done by setting
|
||||
# hat_Q = U^T Q (eq.10)
|
||||
#
|
||||
# This is the only thing we have to do, as N is implicit
|
||||
# and not materialized at this point.
|
||||
Q[:] = torch.matmul(U.t(), Q)
|
||||
|
||||
for dim in range(ndim):
|
||||
if p.shape[dim] != 1:
|
||||
self._update_lr(state[f"lr_{dim}"], meta_lr)
|
||||
|
||||
def _get_matrix_var(self,
|
||||
x: Tensor,
|
||||
@ -581,14 +701,18 @@ class LearnedGradient(Optimizer):
|
||||
|
||||
|
||||
for dim in range(grad.ndim):
|
||||
# This block accumulates the statistics proj_grad_{dim} and
|
||||
# grad_cov_{dim}, which are for periodically updating the
|
||||
# learning-rate matrices.
|
||||
size = grad.shape[size]
|
||||
# accumulate some stats for learning the projections
|
||||
proj_grad = state[f"proj_grad_{dim}"]
|
||||
grad_cov = state[f"grad_cov_{dim}"]
|
||||
this_p = p.transpose(-1, dim).reshape(-1, size)
|
||||
this_g = g.transpose(-1, dim).reshape(-1, size)
|
||||
proj_grad.add_(torch.matmul(this_g.t(), this_p))
|
||||
# could perhaps accumulate grad_cov less frequently.
|
||||
this_m = p.transpose(-1, dim).reshape(-1, size) # parameter matrix M
|
||||
this_g = g.transpose(-1, dim).reshape(-1, size) # M_grad
|
||||
proj_grad.add_(torch.matmul(this_m.t(), this_g))
|
||||
# could perhaps accumulate grad_cov less frequently; it's only
|
||||
# needed when we rediagonalize which is not that common.
|
||||
grad_cov.mul_(beta2).add_(torch.matmul(this_g.t(), this_g))
|
||||
|
||||
grad = self._project(grad, forward=True, state)
|
||||
@ -624,26 +748,25 @@ class LearnedGradient(Optimizer):
|
||||
alpha = state["param_rms"] * (1-beta1) * lr / denom
|
||||
|
||||
delta = state["delta"]
|
||||
delta.mul_(beta1).add_(grad, alpha=alpha)
|
||||
delta.add_(grad, alpha=alpha)
|
||||
|
||||
param.add_(delta)
|
||||
|
||||
|
||||
def _step_scalar(self,
|
||||
beta1: float,
|
||||
beta2: float,
|
||||
eps: float,
|
||||
lr: float,
|
||||
p: Tensor,
|
||||
grad: Tensor,
|
||||
state: dict):
|
||||
beta1: float,
|
||||
beta2: float,
|
||||
eps: float,
|
||||
lr: float,
|
||||
p: Tensor,
|
||||
grad: Tensor,
|
||||
state: dict):
|
||||
"""
|
||||
A form of the core update for tensors with a small number of elements, e.g. scalars. This is
|
||||
Adam where, if the numel() > 1, the learning rate is proportional to the parameter rms value.
|
||||
A form of the core update for tensors with a small number of elements,
|
||||
e.g. scalars. This is Adam where, if the numel() > 1, the learning rate
|
||||
is proportional to the parameter rms value.
|
||||
"""
|
||||
exp_avg = state["exp_avg"]
|
||||
exp_avg_sq = state["scalar_exp_avg_sq"]
|
||||
exp_avg.mul_(beta1).add_(grad, alpha=1-beta1)
|
||||
exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
|
||||
value=1-beta2)
|
||||
|
||||
@ -651,11 +774,10 @@ class LearnedGradient(Optimizer):
|
||||
# slower update at the start will help stability anyway.
|
||||
bias_correction2 = 1 - beta2 ** (state["step"] + 1)
|
||||
denom = (exp_avg_sq.sum() / (exp_avg_sq.numel() * bias_correction2)).sqrt() + eps
|
||||
alpha = -lr / denom
|
||||
# for now not supporting p.numel() > 1.
|
||||
#if p.numel() > 1:
|
||||
# alpha = alpha * state["param_rms"]
|
||||
p.add_(exp_avg, alpha=alpha)
|
||||
alpha = -lr * (1-beta1) / denom
|
||||
|
||||
delta = state["delta"]
|
||||
delta.add_(grad / denom, alpha=alpha)
|
||||
|
||||
|
||||
def _project(self,
|
||||
@ -670,19 +792,19 @@ class LearnedGradient(Optimizer):
|
||||
Args:
|
||||
x: The tensor to project, e.g. a gradient or a parameter change.
|
||||
forward: if True, go in the forward directin (from canonical to diagnonalized
|
||||
co-ordinates; if False, the reverse. They differ by a transpose,
|
||||
co-ordinates); if False, the reverse. They differ by a transpose,
|
||||
not an inverse, and do not make a round trip.
|
||||
"""
|
||||
for dim in range(x.ndim):
|
||||
if x.shape[dim] == 1:
|
||||
continue
|
||||
proj = state[f"proj_{dim}"]
|
||||
if forward:
|
||||
# dimension 0 of `proj` is diagonalized co-ordinate.
|
||||
proj = proj.t()
|
||||
Q = state[f"Q_{dim}"]
|
||||
if not forward:
|
||||
# Q is indexed [canonical_index, diagonalized_index]
|
||||
Q = Q.t()
|
||||
# TODO: could possibly somehow force the output format to be unchanged.
|
||||
x = x.transpose(-1, dim)
|
||||
x = torch.matmul(x, proj)
|
||||
x = torch.matmul(x, Q)
|
||||
x = x.transpose(-1, dim)
|
||||
return x
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user