mirror of https://github.com/k2-fsa/icefall.git
synced 2025-09-18 21:44:18 +00:00

commit 26815d177f (parent e9e2a85c95)

    Draft of the new method.
@@ -48,9 +48,11 @@ class LearnedGradient(Optimizer):
           learning the scale on the parameters (we'll keep it >= this size)
      param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
           learning the scale on the parameters (we'll keep it <= this size)
+     lr_mat_min: Minimum singular value of learning-rate matrices Q.
+     lr_mat_max: Maximum singular value of learning-rate matrices Q.
      lr_est_period: The periodicity, in steps, with which we update the learning-rate matrices.
-     orthogonalize_period: How often we re-orthogonalize the gradient covariance; this
-          gets multiplied by lr_est_period, i.e. we orthogonalize less frequently
+     diagonalize_period: How often we re-diagonalize the gradient covariance; this
+          gets multiplied by lr_est_period, i.e. we diagonalize less frequently
           than we update the learning-rate matrixes.
      max_step_scale: Value that determines limit on how large steps can be per element,
           intended to prevent instability during early steps before the learning-rate
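For illustration, a minimal standalone PyTorch sketch (not part of the diff) of the clamp-renormalize-clamp rule that the lr_mat_min/lr_mat_max arguments describe; the commit applies this to the singular values of Q in _diagonalize_lrs further below:

import torch

def constrain_singular_values(S, lr_mat_min=0.01, lr_mat_max=4.0):
    # clamp the singular values to [lr_mat_min, lr_mat_max],
    # renormalize so the mean is 1.0, then apply the limits once more
    S = S.clamp(lr_mat_min, lr_mat_max)
    S = S / S.mean()
    return S.clamp(lr_mat_min, lr_mat_max)

print(constrain_singular_values(torch.tensor([1.0e-04, 0.5, 1.0, 9.0])))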
@@ -64,11 +66,13 @@ class LearnedGradient(Optimizer):
                 meta_lr_scale=0.1,
                 betas=(0.9, 0.98),
                 eps=1.0e-08,
+                size_update_period=4,
                 param_min_rms=1.0e-05,
                 param_max_rms=2.0,
+                lr_mat_min=0.01,
+                lr_mat_max=4.0,
                 lr_est_period=10,
-                max_lr_eig=4.0,
-                orthogonalize_period=4,
+                diagonalize_period=4,
                 max_step_scale=2.0
     ):

@@ -79,11 +83,13 @@ class LearnedGradient(Optimizer):
            meta_lr_scale=meta_lr_scale,
            betas=betas,
            eps=eps,
+           size_update_period=size_update_period,
            param_min_rms=param_min_rms,
            param_max_rms=param_max_rms,
+           lr_mat_min=lr_mat_min,
+           lr_mat_max=lr_mat_max,
            lr_est_period=lr_est_period,
-           max_lr_eig=max_lr_eig,
-           orthogonalize_period=orthogonalize_period,
+           diagonalize_period=diagonalize_period,
            max_step_scale=max_step_scale,
        )

@@ -112,12 +118,13 @@ class LearnedGradient(Optimizer):
            size_lr = lr * group["size_lr_scale"]
            beta1, beta2 = group["betas"]
            max_lr_eig = group["max_lr_eig"]
+           min_lr_eig = group["max_lr_eig"]
            eps = group["eps"]
            size_update_period = 4
-           param_min_eps = group["param_min_rms"]
-           param_max_eps = group["param_max_rms"]
+           param_min_rms = group["param_min_rms"]
+           param_max_rms = group["param_max_rms"]
            lr_est_period = group["lr_est_period"]
-           orthogonalize_period = group["orthogonalize_period"]
+           diagonalize_period = group["diagonalize_period"]
            max_step_scale = group["max_step_scale"]

            for p in group["params"]:
@@ -158,39 +165,52 @@ class LearnedGradient(Optimizer):
                # scalar exp_avg_sq.  not the same as scale_exp_avg_sq
                state["exp_avg_sq_scalar"] = torch.zeros_like((), **kwargs)

+               # exp_avg_sq is in the projected space.  It is periodically zeroed, and we
+               # set zero_step.
+               state["exp_avg_sq"] = torch.zeros_like(
+                   p, memory_format=torch.preserve_format
+               )
+               state["zero_step"] = 0
+
                for dim in range(p.ndim):
                    size = p.shape[dim]
                    if size == 1:
                        # if size == p.numel(), is_one_axis == True above
                        continue

-                   # exp_avg_sq is in the projected space.
-                   state["exp_avg_sq"] = torch.zeros_like(
-                       p, memory_format=torch.preserve_format
-                   )
-
-                   # proj_{dim} will be the learning-rate matrix for this
+                   # Q_{dim}, which we will call q in this comment, will be the learning-rate matrix for this
                    # dimension, times an orthogonal matrix that we'll periodically
-                   # re-estimate when we "reorthogonalize" (this orthogonalizes
+                   # re-estimate when we "rediagonalize" (this diagonalizes
                    # the gradient covariance).
-                   state[f"proj_{dim}"] = torch.eye(size, **kwargs)
+                   # If our parameter matrix M is of shape (..., size),
+                   # we'll view M as being M == torch.matmul(N, q)
+                   # so p can be interpreted as having shape
+                   # (size, size), interpreted as being indexed [diagonalized_coordinate, canonical_coordinate]
+                   # by "diagonalize" we mean that we diagonalize the gradient covariance.
+                   state[f"Q_{dim}"] = torch.eye(size, **kwargs)

-                   # proj_grad_{dim} is the gradient w.r.t. proj_{grad}
-                   # (technically w.r.t a matrix to the left of
-                   # proj_{grad}, we'll project it later); we accumulate
-                   # this for `lr_est_period` steps, and then update
-                   # proj_{dim} and zero this matrix
+                   # proj_grad_{dim} is the gradient w.r.t. `proj`,
+                   # which is a matrix we introduce to the right of p, i.e.
+                   # instead of M == torch.matmul(N, p) we view it as
+                   # M == torch.matmul(torch.matmul(N, torch.matmul(p, proj)))
+                   # or equivalently M == torch.matmul(M_underlying, proj)
+                   # where M_underlying is numerically equal to M, and proj is numerically
+                   # equal to I, but they are viewed as separate variables.
+                   #
+                   # proj_grad_{dim} will be set to the sum of
+                   # torch.matmul(M^T, M_grad) over a number of steps (`lr_est_period` steps),
+                   # and will be zeroed after we call self_update_lrs().
                    state[f"proj_grad_{dim}"] = torch.zeros(size, size, **kwargs)

                    # grad_cov_{dim} is the covariance of gradients on this axis (without
                    # any co-ordinate changes), treating all other axes as as a batch axis.
-                   # This is needed when we re-orthogonalize, and to compute the
+                   # This is needed when we re-diagonalize, and to compute the
                    # grad_rms_{dim}.  We store it as a decaying average, decaying with beta2

                    # only for purposes of computing the scalar factor on the
                    # learning rate; and, as a result of this, also contributes something
-                   # to the gradient w.r.t. f"lr_{dim}", which is one reason
+                   # to the gradient w.r.t. f"Q_{dim}", which is one reason
                    # why we allocate a variable to keep track of its moving average
                    # instead of just using a temporary and smoothing the scalar factor.
                    state[f"grad_cov_{dim}"] = torch.zeros(size, size, **kwargs)
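A standalone sketch of the per-dimension state this hunk initializes; init_dim_state is a hypothetical helper, not a method of the optimizer, and it omits the dtype/device kwargs of the real code:

import torch

def init_dim_state(p: torch.Tensor) -> dict:
    state = {"exp_avg_sq": torch.zeros_like(p),  # projected space; periodically zeroed
             "zero_step": 0}
    for dim in range(p.ndim):
        size = p.shape[dim]
        if size == 1:
            continue
        # learning-rate matrix Q_{dim}, starting at the identity
        state[f"Q_{dim}"] = torch.eye(size)
        # sum of M^T M_grad over the last lr_est_period steps
        state[f"proj_grad_{dim}"] = torch.zeros(size, size)
        # decaying average of M_grad^T M_grad
        state[f"grad_cov_{dim}"] = torch.zeros(size, size)
    return state

print(sorted(init_dim_state(torch.randn(256, 384)).keys()))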
@@ -220,17 +240,21 @@ class LearnedGradient(Optimizer):
                                  scale_grads, param_rms,
                                  beta1, beta2, step, size_lr,
                                  param_min_rms, param_max_rms)
+               state["delta"].mul_(beta1)
                if numel == 1:
                    # For parameters with very few elements we just use a form
                    # of Adam with a scale factor to reflect the overall
-                   # parameter rms.
+                   # parameter rms.  Updates delta.
                    self._step_scalar(beta1, beta2, eps, lr, p, grad, state)
                else:

                    if step % lr_est_period == 0:
                        self._update_lrs(group, p, state)
+                   if step % (lr_est_period * diagonalize_period) == 0:
+                       self._diagonalize_lrs(group, p, state)
+                       self._zero_exp_avg_sq()
                    self._step(group, p, grad, state)
-
+               p.add_(delta)
                state["step"] = step + 1

        return loss
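A toy sketch of the momentum scheme this hunk moves to: the shared delta buffer decays by beta1 once per step, each sub-update contributes a (1-beta1)-scaled change, and the parameter is updated once at the end (the -0.01*grad term is a stand-in for the real preconditioned step):

import torch

beta1 = 0.9
p = torch.randn(4)
delta = torch.zeros_like(p)
for step in range(3):
    grad = torch.randn(4)                      # stand-in gradient
    delta.mul_(beta1)                          # state["delta"].mul_(beta1)
    delta.add_(-0.01 * grad, alpha=1 - beta1)  # a sub-update's contribution
    p.add_(delta)                              # single p.add_(delta) per step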
@@ -238,6 +262,7 @@ class LearnedGradient(Optimizer):
    def _size_update(self,
                     p: Tensor,
                     state: dict,
+                    scale_grad: Tensor,
                     beta1: float,
                     beta2: float,
                     size_lr: float,
@@ -258,7 +283,6 @@ class LearnedGradient(Optimizer):
            tensor p (for enforcing param_min_rms and param_max_rms)
         beta1: Beta value for decaying gradient stats
         beta2: Beta value for decaying gradient-squared stats
-        step: The step index
         size_lr: The learning rate to apply to `scale`
         param_min_rms: User-specified minimum on the rms value of p, we will constrain it to >= this
         param_max_rms: User-specified maximum on the rms value of p, we will constrain it to <= this
@@ -268,19 +292,14 @@ class LearnedGradient(Optimizer):
        # products (p * grad).sum() for the `size_update_period` most recent steps
        size_update_period, = scale_grads.shape

-       # correct beta1 and beta2 for the size update period: we will have
-       # faster decay at this level.  The same for the learning rate.
-       beta1_corr = beta1 ** size_update_period
+       # correct beta2 for the size update period: we will have
+       # faster decay at this level.
        beta2_corr = beta2 ** size_update_period
-       size_lr_corr = size_lr * size_update_period

        scale_exp_avg_sq = state["scale_exp_avg_sq"]
        scale_exp_avg_sq.mul_(beta2_corr).add_(
            (scale_grads ** 2).mean(), value=1-beta2_corr)

-       scale_exp_avg = state["scale_exp_avg"]
-       scale_exp_avg.mul_(beta1_corr).add_(
-           scale_grads.mean(), value=1-beta1_corr)

        # The 1st time we reach here is when size_step == 1.
        size_step = (step + 1) // size_update_period
@@ -291,15 +310,17 @@ class LearnedGradient(Optimizer):
        eps = 1.0e-10  # this value is not critical.
        denom = scale_exp_avg_sq.sqrt() + eps

-       scale_step = -size_lr_corr * (bias_correction2 ** 0.5) * scale_exp_avg / denom
+       scale_step = -size_lr * (bias_correction2 ** 0.5) * scale_grads.sum() / denom

        is_too_small = (param_rms < param_min_rms)
        is_too_large = (param_rms > param_max_rms)

-       scale_step.masked_fill_(is_too_small, size_lr_corr)
-       scale_step.masked_fill_(is_too_large, -size_lr_corr)
+       scale_step.masked_fill_(is_too_small, size_lr * size_update_period)
+       scale_step.masked_fill_(is_too_large, -size_lr * size_update_period)

-       p.add_(p, alpha=scale_step)
+       delta = state["delta"]
+       # the factor of (1-beta1) relates to momentum.
+       delta.add_(p, alpha=(1-beta1) * scale_step)


    def _update_lrs(self,
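A self-contained sketch of the revised size-update arithmetic above, with toy values; bias_correction2 here assumes the first size-step, the real code reads these quantities from group/state, and the tensor-valued scale_step is converted to a float before being used as alpha:

import torch

size_lr, beta1, beta2, size_update_period = 0.003, 0.9, 0.98, 4
param_min_rms, param_max_rms = 1.0e-05, 2.0
p = torch.randn(8, 8)
delta = torch.zeros_like(p)
scale_exp_avg_sq = torch.zeros(())

scale_grads = torch.randn(size_update_period)  # stand-in for (p * grad).sum() history
beta2_corr = beta2 ** size_update_period
scale_exp_avg_sq.mul_(beta2_corr).add_((scale_grads ** 2).mean(),
                                       alpha=1 - beta2_corr)
denom = scale_exp_avg_sq.sqrt() + 1.0e-10
bias_correction2 = 1 - beta2_corr
scale_step = -size_lr * (bias_correction2 ** 0.5) * scale_grads.sum() / denom

param_rms = (p ** 2).mean().sqrt()
scale_step = scale_step.masked_fill(param_rms < param_min_rms,
                                    size_lr * size_update_period)
scale_step = scale_step.masked_fill(param_rms > param_max_rms,
                                    -size_lr * size_update_period)
# the factor (1-beta1) relates to momentum applied via delta
delta.add_(p, alpha=float((1 - beta1) * scale_step))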
@@ -307,7 +328,7 @@ class LearnedGradient(Optimizer):
                    p: Tensor,
                    state: dict) -> None:
        """
-       Updates the learning-rate matrices state["lr_{dim}"].
+       Updates the learning-rate matrices state["Q_{dim}"].

        Args:
          group: dict to look up configuration values
@@ -322,7 +343,7 @@ class LearnedGradient(Optimizer):
        meta_lr = group["lr"] * group["meta_lr_scale"] * (group["lr_est_period"] ** 0.5)
        beta1, beta2 = group["betas"]
        eps = group["eps"]
-       moving_g = state["exp_avg"]
+       delta = state["delta"]
        ndim = p.ndim

        for dim in range(ndim):
@@ -330,148 +351,247 @@ class LearnedGradient(Optimizer):
            if size == 1:
                continue

-           p_t = p.transpose(dim, -1)
-           this_p = p_t.reshape(-1, size)
+           # M is the parameter matrix, of shape (-1, size) where the -1 covers
+           # all other dimensions of the tensor.
+           M_full = p.transpose(dim, -1)
+           M = M.reshape(-1, size)

-           # proj_grad is a summed gradient, indexed [grad_idx][param_idx] (don't worry
-           # if this notation about the indexing is unclear, just keep reading).
+           # proj_grad is a summed gradient, of shape (size, size),
+           # indexed [old_param_idx][new_param_idx] (the point is,
+           # param_idx corresponds to the index of the current M, new_param_idx
+           # corresponds to the index of the M after a small change).
            #
-           # Suppose M is our parameter matrix p reshaped to (size, -1) with the -1 being
-           # all other dims treated as a batch dim, and M_grad is the loss derivative w.r.t. M,
+           # Suppose M is our parameter matrix p reshaped to (-1, size) with
+           # the -1 being all other dims of the tensor treated as a batch dim
+           # and "size" is the size of whatever dimension we are working on
+           # right now.  And suppose M_grad is the loss derivative w.r.t. M
+           # (for compatibility with things like Torch, we use a convention
+           # where all loss derivatives/grads have the same shape as the things they
+           # are derivative with respect to).
            #
-           # Then proj_grad is the accumulated value of M_grad M^T, which will be of shape
-           # (size, size).  If we view the parameter matrix as actually (proj M) with
-           # proj being numerically equal to I, i.e. as
-           # matmul(proj, M), then instead of the loss being tr(M_grad^T M), we can write
-           # it as tr(M_grad^T proj M), which can also be written as tr(M M_grad^T proj)
-           # = tr( (M_grad M^T)^T proj), where we can write M_grad M^T == proj_grad.
+           # proj_grad is the accumulated value, on the last few steps, of M M_grad^T,
+           # which will be of shape (size, size).  For purposes of our update, if we view
+           # the parameter matrix M as (M_underlying proj) with
+           # proj being numerically equal to I but treated as a variable, i.e. as
+           # matmul(M_underlying, proj), then instead of the loss being tr(M_grad^T M), we can write
+           # it as tr(M_grad^T M_underlying proj), which can also be written as tr(proj_grad^T proj)
+           # where proj_grad == M_underlying^T M_grad.
            #
            # Now, proj_grad is convenient to compute but it's not in a very normalized
-           # space, we want to normalize it first.  We're going to view M as
-           #   M = T N,
-           # where T is our proj_{dim} matrix (see the code), and N == T^{-1} M.
-           # It's more convenient to have the projection between T and N, i.e. have proj2,
-           # again numerically equal to I, so
-           #   M == T proj2 N.
-           # [Be careful because when we use T as a superscript it means transpose).
-           # We want the derivative w.r.t. proj2.
-           # So suppose the loss is
-           #   tr(M_grad^T M) == tr(M_grad^T T proj2 N) == tr(N M_grad^T T proj2)
-           #                  == tr( (T^T M_grad N^T)^T proj2 )
-           #                  == tr( (T^T M_grad N^T)^T proj2 )
-           # and using N == T^{-1} M, so N^T = M^T T^{-T},
-           #                  == tr( (T^T M_grad M^T T^{-T})^T proj2 )
-           # so we can view the thing in parentheses in the above expression as proj2_grad, i.e.
-           #   proj2_grad = T^T M_grad M^T T^{-T}
-           #   proj2_grad = T^T proj_grad T^{-T}
-           #
-           # note: p = state[f"proj_{dim}"] actually corresponds to T^T, since it has a
-           # shape interpretable as (diagonalized_index, canonical_index).
-           #
-           # So proj2_grad = p proj_grad p^{-1}  (eq.1)
+           # space, we want to normalize it before doing gradient descent on it.
+           # We're going to view M as
+           #   M == N Q,
+           # where Q is our Q_{dim} "learning-rate" matrix, of shape (size, size) indexed
+           # [diagonalized_index, canonical_index] (see the code), and N == M Q^{-1}.
+           # It's more convenient to do gradient descent on a `proj` matrix that's between N and Q,
+           # i.e. define `proj2`, also numerically equal to I but treated as a variable,
+           # and have:
+           #   M == N proj2 Q
+           # We want to obtain the derivative w.r.t. proj2.
+           # So the linearized pseudo-loss is
+           #   tr(M_grad^T M) == tr(M_grad^T N proj2 Q) == tr(Q M_grad^T N proj2),
+           # which we can view as tr(proj2_grad^T proj2), where
+           #   proj2_grad == (Q M_grad^T N)^T == N^T M_grad Q^T == Q^{-T} M^T M_grad Q^T
+           #   proj2_grad = Q^{-T} proj_grad Q^T   (eq.1)
            #
            # After computing proj2_grad we compute/update a factorized representation of
            # its variance, factorized over its 2 axes, and do an Adam-type update
            # for it, except without momentum at this level (momentum will be applied
            # later).  This gives us proj2_delta, according to the
-           # equation: proj2 = I + proj2_delta.
+           # equation:
+           #   proj2 = I + proj2_delta
+           # (again: proj2_delta is derived from proj2_grad using an Adam-style update).
            #
-           # This is the point at which we control the magnitude of the parameter change.
-           #
-           # We now have to update the factor T using proj2, as in:
-           #   T <-- T (I + proj2_delta)
-           #   T <-- T + T proj2_delta
-           # and since p == T^T, this becomes:
-           #   p <-- p + proj2_delta^T p
-           # or equivalently:
-           #   p_delta = proj2_delta^T p  (eq.2)
+           # We now have to update the "learning-rate matrix" Q using proj2, as in
+           #   Q := proj2 Q
+           #   Q := (proj2_delta + I) Q, or:
+           #   Q += proj2_delta Q   (eq.2)
            #
            # We also need to update M itself; this is easiest done by computing proj_delta
-           # (which is the deviation of proj from I).  The relationship between proj and proj2
+           # (which is the difference proj from I).  The relationship between proj and proj2
            # is given by equating
-           #   proj T N = T proj2 N,
-           #   proj T = T proj2,
-           #   proj = T proj2 T^{-1}
-           #   proj = p^T proj2 p^{-T} .
-           # we can apply the above formula directly to proj2_delta, since the "I" part just
-           # becomes a no-op, i.e.
-           #   proj_delta = p^T proj2_delta p^{-T}.
-           # ... and to turn this back into a delta on M, we'll do:
-           #   M_delta = proj_delta M.
+           #   M == N proj2 Q == N Q proj,
+           # so proj2 Q == Q proj
+           #   proj == Q^{-1} proj2 Q
+           # and subtracting out the constant "I" part,
+           #   proj_delta == Q^{-1} proj2_delta Q   (eq.3)
+           # ... so the change in Q is given by:
+           #   Q := Q (I + proj_delta),
+           #   Q += Q proj_delta
+           #   Q_delta = Q proj_delta
+           # and looking at how we compute proj_delta (eq.3), we notice the right-hand subexpression
+           # is the sam as p_delta, i.e.:
+           #   Q_delta = proj2_delta Q  (eq.4)
+           #
+           # and then because we view the parameter matrix as
+           #   "M proj",
+           # with proj being numerically equal to I but learned here,
+           # it becomes M (I + proj_delta) = M + M proj_delta.
+           # So the update to M is:
+           #   M += M proj_delta
+           # or equivalently:
+           #   M_delta = M proj_delta  (eq.5)
            proj_grad = state[f"proj_grad_{dim}"]
-           p = state[f"proj_{dim}"]
+           Q = state[f"Q_{dim}"]

-           # torch.linalg.solve(A, B) returns X such that AX=B,
-           # meaning X = A^{-1} B.  We want to compute X = proj_grad p^{-1},
-           # so X^T = p^{-T} proj_grad^T.
-           X = torch.linalg.solve(p.t(), proj_grad.t()).t()
-           assert torch.allclose(proj_grad, torch.matmul(X, p))  # TEMP.
+           # Next we want to implement (eq.1), proj2_grad = Q^{-T} proj_grad Q^T
+           # torch.linalg.solve(A, B) returns A^{-1} B.
+           Q_invt_proj_grad = torch.linalg.solve(Q.t(), proj_grad)
+           assert torch.allclose(proj_grad, torch.matmul(Q.t(), Q_invt_proj_grad))  # TODO: remove

-           # See (eq.1), the next line implements proj_grad = p proj_grad p^{-1}
-           proj2_grad = torch.matmul(p, X)
+           # (eq.1), proj2_grad = Q^{-T} proj_grad Q^T
+           proj2_grad = torch.matmul(Q_invt_proj_grad, Q.t())

-           # for now just estimate it from proj2_grad, later we will try aggregating it
-           # across iterations.  This is like exp_avg_sq in Adam.
+           # for now just estimate proj2_grad_var from proj2_grad; later, we
+           # may try aggregating it across iterations.  This is like exp_avg_sq
+           # in Adam.
            proj2_grad_var = self._get_matrix_var(proj_grad, group["eps"])

            denom = proj2_grad_var.sqrt()

-           alpha = -meta_lr * (1-beta1)
-           # we'll take into account alpha later on.
+           # we'll take into account the factor of -meta_lr * (1-beta1) later on;
+           # actually proj2_delta is the negative of a parameter change.
            proj2_delta = proj2_grad_var / denom

-           p_delta = torch.matmul(proj2_delta.t(), p)
-
-           # To update the learning rate matrices, we are asking the question, "What if,
-           # on the previous step, we had used a slightly different learning-rate matrix?
-           # How would that have affected the loss function on the current step?"
-           moving_d = self._multiply_by_lr(moving_g)
-
-           # moving_d_rms is an rms value of moving_d computed via inner
-           # products with the grad, which is a basis-independent way of
-           # computing the rms value.  you have to view this as the rms times
-           # an arbitrary multiplicative factor.  in the "real" update we don't
-           # actually do the length normalization like this because it's slow,
-           # we do it a different, simpler way; but here we do it this way
-           # because we are computing gradients and it's going to make it more
-           # sensitive to the norm used.  If we used the regular norm on moving_d, it
-           # might concentrate all the movement in "don't-care" dimensions in
-           # parameter space, assuming such dimensions exist.
-           moving_d_rms = (self._multiply_by_grad_cov(moving_d) * moving_d).mean().sqrt()
-
-           # moving_d_norm is the parameter "delta" moving_d, length-normaized but still with
-           # an arbitrary multiplicative constant.  We don't care about this constant
-           # because we are going to normalize the grad length when we do the update;
-           # but it's important to have normalized by dividing by moving_d_rms because it
-           # makes the gradient "scale-independent" in the sense that it is zero in
-           # the direction of increasing or decreasing the parameter scale.
-           moving_d_norm = moving_d / moving_d_rms
-
-           # view the current parameter p as [something] - alpha *
-           # moving_d_norm, where alpha contains the learning rate and other
-           # factors, but we don't care about the numerical value of alpha
-           # because we're going to divide by the gradient norm anyway.  The
-           # (pseudo-) loss function would be (p * grad).sum(), which ignoring
-           # [something] as it contributes no gradient, becomes -alpha *
-           # (moving_d_norm * grad).sum(), and ignoring the scale alpha, that
-           # is -(moving_d_norm * grad).sum(), so we call this neg_loss as it's the
-           # negative of the loss.
-           neg_loss = (grad * moving_d_norm).sum()
-
-           # after the following, there will grads in state[f"lr_{dim}"]
-           # that are the negative of loss function grads, i.e. we want to go
-           # in the forward grad direction.
-           neg_loss.backward()
+           # See (eq.4), Q_delta = proj2_delta q
+           Q_delta = torch.matmul(proj2_delta, Q)
+
+           # See (eq.3), proj_delta = Q^{-1} proj2_delta Q = Q^{-1} Q_delta
+           proj_delta = torch.linalg.solve(Q, Q_delta)
+
+           M_delta = torch.matmul(M, proj_delta)  # (eq.5), M_delta = M proj_delta
+
+           alpha = -meta_lr * (1-beta1)
+
+           M_delta_full = M_delta.reshape(M_full.shape)
+
+           this_delta = M_delta_full.transpose(dim, -1)
+           # we'l be applying momentum on delta, so multiply by (1-beta1) to
+           # get the total magnitude of the change right.  The 2 changes below
+           # are the final changes, the only 2 we make in this loop that have
+           # side effects.
+           delta.add_(this_delta, alpha=-meta_lr*(1-beta1))
+           # there is no momentum on Q.
+           Q.add_(Q_delta, alpha=-meta_lr)
+
+
+   def _zero_exp_avg_sq(self,
+                        state: dict) -> None:
+       """
+       Zero the exp_avg_sq stats, and set state["zero_step"] to state["step"]
+       """
+       state["exp_avg_sq"].zero_()
+       state["zero_step"] = state["step"]
+
+   def _diagonalize_lrs(self,
+                        group: dict,
+                        p: Tensor,
+                        state: dict) -> None:
+       """
+       Diagonalizes the learning-rate matrices state["Q_{dim}"], and also
+       applies the min- and max- singular value constraints.  This Q
+       is applied to the parameter matrix M such that we interpret
+          M == N Q.
+       By "diagonalize", we mean to left-miltiply Q by an orthogonal matrix,
+          Q := U Q,
+       so that the covariance of gradients w.r.t. N, measured
+       over the 2nd index of N while treating its 1st index as a batch index,
+       is diagonal.
+
+       Before doing this, also limits the singular values of Q so that the
+       max,min eig are lr_min_eig,lr_max_eig, and then renormalizing so the
+       mean is 1.0 and limiting again.
+
+       Args:
+         group: dict to look up configuration values
+         p: parameter matrix that we are updating.  The learning rate matrices
+           are actually factors of p, so p itself will change when we change
+           them.
+         state: state dict for the current parameter
+       """
+       # meta_lr is the learning rate for the learning rate matrix.  Including the factor
+       # (group["lr_est_period"] ** 0.5) is intended to make it approximately invariant
+       # to the lr_est_period (close to convergence).
+
+       lr_mat_min = group["lr_mat_min"]
+       lr_mat_max = group["lr_mat_max"]
+
+       ndim = p.ndim
+
        for dim in range(ndim):
-           if p.shape[dim] != 1:
-               self._update_lr(state[f"lr_{dim}"], meta_lr)
+           size = p.shape[dim]
+           if size == 1:
+               continue
+           Q = state[f"Q_{dim}"]
+
+           if True:
+               # This block limits the singular values of Q.
+               try:
+                   U,S,V = Q.svd()
+               except:
+                   # if SVD fails for some reason we can rotate Q by something arbitrary on the
+                   # left, as we're going to rotate here anyway, and try again.
+                   logging.warn("Error doing SVD on Q.  Trying again after rotation.")
+                   U,_,_ = torch.randn_like(Q).svd()  # Create random orthogonal matrix.
+                   Q[:] = torch.matmul(U, Q)
+                   U,S,V = Q.svd()
+               S_new = S.clamp(lr_mat_min, lr_mat_max)
+               S_new *= 1.0 / S_new.mean()  # normalize so mean is 1.0
+               S_new.clamp_(lr_mat_min, lr_mat_max)  # apply limits once more.
+               # Reconstruct Q with the modified S.
+               Q[:] = torch.matmul(U * S, V.t())
+               if random.random() < 0.1:
+                   subsample = max(1, S.numel() // 20)
+                   logging.info(f"shape={tuple(p.shape)}, dim={dim}, modifed S from {S[::subsample]} to {S_new[::subsample]}")
+
+           if True:
+               # This block does the actual diagonalization.
+               # Suppose the actual parameter matrix p is M, of shape (-1, size), where
+               # the -1 represents all other tensor dims treated as a batch dimension.
+               # M_grad is the same shape as M.  We could write a pseudo-loss as
+               #   loss = tr(M_grad^T M)
+               # Because we can decompose M as M == N Q, we can write:
+               #   loss = tr(M_grad^T N Q) = tr(Q M_grad^T N),
+               # so we can write this at tr(N_grad^T N),
+               # where N_grad == (Q M_grad^T)^T = M_grad Q^T.
+               # Now,
+               #   grad_cov == M_grad^T M_grad,
+               # decaying-averaged over minibatches; this is of shape (size,size).
+               # Using N_grad = M_grad Q^T, we can write:
+               #   N_grad_cov = Q M_grad^T M_grad Q^T
+               #              = Q grad_cov Q^T
+               # (note: this makes sense because the 1st index of Q is the diagonalized
+               # index).
+
+               grad_cov = state[f"grad_cov_{dim}"]
+               N_grad_cov = torch.matmul(Q, torch.matmul(grad_cov, Q))
+               N_grad_cov = N_grad_cov + N_grad_cov.t()  # ensure symmetric
+               U, S, V = N_grad_cov.svd()
+
+               # N_grad_cov is SPD, so
+               #   N_grad_cov = U S U^T.
+
+               # Now, we can diagonalize N_grad_cov with:
+               #   U^T N_grad_cov U == S.
+               # N_grad_cov is a sum of N_grad^T N_grad.
+               # So U^T N_grad^T N_grad U is diagonal.
+               # The linearized pseudo-loss can be written as tr(N_grad^T N_grad).
+               # This can be written as tr(U U^T N_grad^T N_grad), since U U^T == I,
+               #
+               # which we can rearrange as tr(U^T N_grad^T N U).  This can be interpreted
+               # as tr(hat_N_grad hat_N), where:
+               #   hat_N_grad = N_grad U
+               #   hat_N = N U
+               # (hat_N means \hat{N}, or N with a hat on it).
+               # So if we interpret hat_N = N U, the gradient covariance w.r.t.
+               # hat_N will be diagonalized.  We also modify Q to hat_Q when
+               # we modify hat_N, to keep the product M = N Q = N U U^T Q = hat_N hat_Q
+               # This can be done by setting
+               #   hat_Q = U^T Q   (eq.10)
+               #
+               # This is the only thing we have to do, as N is implicit
+               # and not materialized at this point.
+               Q[:] = torch.matmul(U.t(), Q)


    def _get_matrix_var(self,
                        x: Tensor,
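A self-contained numeric sketch of the (eq.1)-(eq.5) pipeline and of the diagonalization step above, with random stand-ins for M and M_grad; the crude abs().mean() normalization stands in for the Adam-style _get_matrix_var estimate, and the N_grad_cov line follows the comment N_grad_cov = Q grad_cov Q^T:

import torch

size = 6
M = torch.randn(20, size)       # parameter matrix, other dims flattened
M_grad = torch.randn(20, size)  # loss derivative w.r.t. M
Q = torch.eye(size) + 0.01 * torch.randn(size, size)

proj_grad = M.t() @ M_grad                                 # accumulated M^T M_grad
proj2_grad = torch.linalg.solve(Q.t(), proj_grad) @ Q.t()  # (eq.1)
proj2_delta = proj2_grad / proj2_grad.abs().mean()         # stand-in normalization
Q_delta = proj2_delta @ Q                                  # (eq.4)
proj_delta = torch.linalg.solve(Q, Q_delta)                # (eq.3)
M_delta = M @ proj_delta                                   # (eq.5)

# Diagonalization: with N_grad_cov = Q grad_cov Q^T, setting Q := U^T Q makes
# the gradient covariance w.r.t. N diagonal (cf. eq.10).
grad_cov = M_grad.t() @ M_grad
N_grad_cov = Q @ grad_cov @ Q.t()
N_grad_cov = N_grad_cov + N_grad_cov.t()  # ensure symmetric
U, S, V = N_grad_cov.svd()
Q = U.t() @ Q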
@@ -581,14 +701,18 @@ class LearnedGradient(Optimizer):


        for dim in range(grad.ndim):
+           # This block accumulates the statistics proj_grad_{dim} and
+           # grad_cov_{dim}, which are for periodically updating the
+           # learning-rate matrices.
            size = grad.shape[size]
            # accumulate some stats for learning the projections
            proj_grad = state[f"proj_grad_{dim}"]
            grad_cov = state[f"grad_cov_{dim}"]
-           this_p = p.transpose(-1, dim).reshape(-1, size)
-           this_g = g.transpose(-1, dim).reshape(-1, size)
-           proj_grad.add_(torch.matmul(this_g.t(), this_p))
-           # could perhaps accumulate grad_cov less frequently.
+           this_m = p.transpose(-1, dim).reshape(-1, size)  # parameter matrix M
+           this_g = g.transpose(-1, dim).reshape(-1, size)  # M_grad
+           proj_grad.add_(torch.matmul(this_m.t(), this_g))
+           # could perhaps accumulate grad_cov less frequently; it's only
+           # needed when we rediagonalize which is not that common.
            grad_cov.mul_(beta2).add_(torch.matmul(this_g.t(), this_g))

        grad = self._project(grad, forward=True, state)
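A sketch of the per-dimension statistics this hunk accumulates, for a single dim of a 2-D parameter: proj_grad sums M^T M_grad, while grad_cov keeps a beta2-decayed average of M_grad^T M_grad (toy tensors, outside the optimizer):

import torch

beta2 = 0.98
p = torch.randn(16, 32)  # parameter
g = torch.randn(16, 32)  # its gradient
dim = 1
size = p.shape[dim]
this_m = p.transpose(-1, dim).reshape(-1, size)  # parameter matrix M
this_g = g.transpose(-1, dim).reshape(-1, size)  # M_grad
proj_grad = torch.zeros(size, size)
grad_cov = torch.zeros(size, size)
proj_grad.add_(torch.matmul(this_m.t(), this_g))
grad_cov.mul_(beta2).add_(torch.matmul(this_g.t(), this_g))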
@@ -624,7 +748,7 @@ class LearnedGradient(Optimizer):
        alpha = state["param_rms"] * (1-beta1) * lr / denom

        delta = state["delta"]
-       delta.mul_(beta1).add_(grad, alpha=alpha)
+       delta.add_(grad, alpha=alpha)

        param.add_(delta)

@@ -638,12 +762,11 @@ class LearnedGradient(Optimizer):
                     grad: Tensor,
                     state: dict):
        """
-       A form of the core update for tensors with a small number of elements, e.g. scalars.  This is
-       Adam where, if the numel() > 1, the learning rate is proportional to the parameter rms value.
+       A form of the core update for tensors with a small number of elements,
+       e.g. scalars.  This is Adam where, if the numel() > 1, the learning rate
+       is proportional to the parameter rms value.
        """
-       exp_avg = state["exp_avg"]
        exp_avg_sq = state["scalar_exp_avg_sq"]
-       exp_avg.mul_(beta1).add_(grad, alpha=1-beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
                                        value=1-beta2)

@@ -651,11 +774,10 @@ class LearnedGradient(Optimizer):
        # slower update at the start will help stability anyway.
        bias_correction2 = 1 - beta2 ** (state["step"] + 1)
        denom = (exp_avg_sq.sum() / (exp_avg_sq.numel() * bias_correction2)).sqrt() + eps
-       alpha = -lr / denom
-       # for now not supporting p.numel() > 1.
-       #if p.numel() > 1:
-       #    alpha = alpha * state["param_rms"]
-       p.add_(exp_avg, alpha=alpha)
+       alpha = -lr * (1-beta1) / denom
+
+       delta = state["delta"]
+       delta.add_(grad / denom, alpha=alpha)


    def _project(self,
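A toy sketch of the scalar path after this change: plain Adam second-moment statistics, with the step routed through the shared delta buffer rather than applied to p directly; as an assumption for the sketch, the division by denom is applied once, folded into the added term:

import torch

lr, beta1, beta2, eps = 0.01, 0.9, 0.98, 1.0e-08
p = torch.tensor(0.5)
grad = torch.tensor(0.2)
step = 0
exp_avg_sq = torch.zeros(())
delta = torch.zeros(())

exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
bias_correction2 = 1 - beta2 ** (step + 1)
denom = (exp_avg_sq.sum() / (exp_avg_sq.numel() * bias_correction2)).sqrt() + eps
delta.add_(grad / denom, alpha=-lr * (1 - beta1))
p.add_(delta)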
@@ -670,19 +792,19 @@ class LearnedGradient(Optimizer):
        Args:
          x: The tensor to project, e.g. a gradient or a parameter change.
          forward: if True, go in the forward directin (from canonical to diagnonalized
-           co-ordinates; if False, the reverse.  They differ by a transpose,
+           co-ordinates); if False, the reverse.  They differ by a transpose,
            not an inverse, and do not make a round trip.
        """
        for dim in range(x.ndim):
            if x.shape[dim] == 1:
                continue
-           proj = state[f"proj_{dim}"]
-           if forward:
-               # dimension 0 of `proj` is diagonalized co-ordinate.
-               proj = proj.t()
+           Q = state[f"Q_{dim}"]
+           if not forward:
+               # Q is indexed [canonical_index, diagonalized_index]
+               Q = Q.t()
            # TODO: could possibly somehow force the output format to be unchanged.
            x = x.transpose(-1, dim)
-           x = torch.matmul(x, proj)
+           x = torch.matmul(x, Q)
            x = x.transpose(-1, dim)
        return x

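A standalone sketch of _project's loop, assuming a hypothetical dict of Q matrices per dim rather than the optimizer's state: forward multiplies by Q, backward by Q^T, and since the two directions differ by a transpose rather than an inverse, they do not round-trip:

import torch

def project(x: torch.Tensor, Qs: dict, forward: bool) -> torch.Tensor:
    for dim in range(x.ndim):
        if x.shape[dim] == 1 or dim not in Qs:
            continue
        Q = Qs[dim]
        if not forward:
            Q = Q.t()
        x = x.transpose(-1, dim)
        x = torch.matmul(x, Q)
        x = x.transpose(-1, dim)
    return x

x = torch.randn(3, 5)
Qs = {1: torch.randn(5, 5)}
y = project(x, Qs, forward=True)
print(y.shape)  # torch.Size([3, 5])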