Remove rank-1 dims (i.e. dims where size == numel()) from processing.

Daniel Povey 2022-07-09 13:36:48 +08:00
parent 2fc9eb9789
commit ffeef4ede4


@@ -178,8 +178,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
for dim in range(p.ndim):
size = p.shape[dim]
-if size == 1:
# if size == p.numel(), is_one_axis == True above
+if size == 1 or size == p.numel():
continue
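The same condition change is repeated at each place below that loops over a tensor's dims. A minimal standalone sketch (hypothetical shapes, not code from this commit) of what the new size == p.numel() clause skips: if one dim's size equals the total element count, every other dim has size 1, so the tensor is effectively rank-1 and gets no per-dim treatment.

    import torch

    for shape in [(512, 1, 3), (1, 256), (4, 5)]:
        p = torch.zeros(shape)
        for dim in range(p.ndim):
            size = p.shape[dim]
            skip = (size == 1 or size == p.numel())
            print(shape, dim, "skip" if skip else "process")
    # (512, 1, 3): dims 0 and 2 processed, dim 1 skipped
    # (1, 256):    both dims skipped (size 1, and size 256 == numel 256)
    # (4, 5):      both dims processed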
@@ -320,7 +319,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
(step + lr_update_period))
for dim in range(p.ndim):
size = p.shape[dim]
-if size == 1:
+if size == 1 or size == p.numel():
continue
# M is the parameter matrix, of shape (-1, size) where the -1 covers
@@ -363,7 +362,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
for dim in range(ndim):
size = p.shape[dim]
-if size == 1:
+if size == 1 or size == p.numel():
continue
param_cov = state[f"param_cov_{dim}"]
U, S, V = _svd(param_cov)
@@ -472,7 +471,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
# grad_cov_{dim}, which are for periodically updating the
# learning-rate matrices.
size = grad.shape[dim]
-if size == 1:
+if size == 1 or size == p.numel():
continue
if step % grad_cov_period == 0 or step < grad_cov_period:
grad_cov = state[f"grad_cov_{dim}"]
@@ -559,8 +558,10 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
co-ordinates); if False, the reverse. They differ by a transpose,
not an inverse, and do not make a round trip.
"""
+numel = x.numel()
for dim in range(x.ndim):
-if x.shape[dim] == 1:
+size = x.shape[dim]
+if size == 1 or size == numel:
continue
Q = state[f"Q_{dim}"]
if forward:
@@ -575,236 +576,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
x = x.transpose(-1, dim)
return x
def _store_grad_stats(self,
grad: Tensor,
state: dict,
beta2: float) -> None:
"""
Accumulate some stats for each dimension of the tensor, about the covariance of gradients.
The stats are decayed by beta2 as they are accumulated; the overall scalar factor will be normalized out later.
"""
ndim = grad.ndim
kwargs = {'device': grad.device, 'dtype': grad.dtype}
for dim in range(ndim):
size = grad.shape[dim]
if size == 1:
continue
grad_cov = state[f"grad_cov_{dim}"]
g = grad.transpose(dim, -1)
g = g.reshape(-1, g.shape[-1])
# We don't care about the scalar factor, we will normalize when we
# use it.
grad_cov.mul_(beta2).add_(torch.matmul(g.t(), g))
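A quick shape check (made-up sizes, not part of the original file) of the transpose/reshape pattern above: for a chosen dim, the gradient is flattened to (-1, size) so that g.t() @ g accumulates a (size, size) covariance over all remaining positions.

    import torch

    grad = torch.randn(2, 3, 4)          # hypothetical gradient tensor
    dim = 1
    g = grad.transpose(dim, -1)          # shape (2, 4, 3)
    g = g.reshape(-1, g.shape[-1])       # shape (8, 3): 8 "samples" of size 3
    grad_cov = torch.matmul(g.t(), g)    # shape (3, 3), summed over the 8 samples
    assert grad_cov.shape == (3, 3)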
def _estimate_projections(self,
p: Tensor,
state: dict,
param_eps: float,
param_rel_eps: float,
param_rel_max: float,
param_reverse_cutoff: float,
beta3: float,
param_pow: float,
grad_pow: float,
grad_min_rand: float,
) -> None:
"""
Estimate the per-dimension projections proj_0, proj_1 and so on.
These projections, when applied with forward==True by _change_coordinates(),
first rotate into a space
where the optimal learning-rate-scale is diagonalized (see self._estimate_proj);
then scale by the parameter scale. When applied with forward == False
they first scale by the parameter scale and then rotate back into canonical
co-ordinates. (Thus, applying with forward==True and then forward==False
is not a round trip).
Args:
p: The parameter for which we are estimating the projections. Supplied
because we need this to estimate parameter covariance matrices in
each of its tensor directions.
state: The state dict for this parameter tensor
param_eps: Small epsilon representing a minimum parameter magnitude;
once the estimated parameter rms for some element of p gets below
this value, we'll treat it as having this value.
param_rel_eps: Another constraint on minimum parameter rms, relative to
sqrt(overall_param_rms), where overall_param_rms is the sqrt of the
parameter variance averaged over the entire tensor.
beta3: discounting parameter for param_cov, applied once every estimate_period.
param_pow: param_pow==1.0 means fully normalizing for parameter scale;
setting to a smaller value closer to 0.0 means partial normalization.
grad_min_rand: minimum proportion of random tensor to mix with the
gradient covariance matrix.
"""
kwargs = {'device':p.device, 'dtype':p.dtype}
ndim = p.ndim
for dim in range(ndim):
size = p.shape[dim]
if size == 1:
continue
proj = state[f"proj_{dim}"]
if proj.ndim == 2:
# This dimension gets the full-covariance treatment.
count = state[f"grad_cov_count_{dim}"]
assert count != 0
grad_cov = state[f"grad_cov_{dim}"] / count
del state[f"grad_cov_{dim}"] # save memory
#self._randomize_lowrank_cov(grad_cov, count, p.shape,
# grad_min_rand)
cur_param_cov = self._get_param_cov(p, dim)
cur_param_cov = cur_param_cov / (cur_param_cov.diag().mean() + 1.0e-20) # normalize size..
param_cov = state[f"param_cov_{dim}"]
param_cov.mul_(beta3).add_(cur_param_cov, alpha=(1-beta3))
# We want the orthogonal matrix U that diagonalizes P, where
# P is the SPD matrix such that P G^{param_pow} P^T == C^{param_pow},
# where G == grad_cov and C == param_cov.
# .. this is the same as the U that diagonalizes a different P
# such that P G P^T == C^{param_pow/(2-param_pow)},
# since the P's are related by taking-to-a-power.
num_samples = p.numel() // size
reverse_cutoff = (param_reverse_cutoff if num_samples > size//4 else 1.0e+10)
P = self._estimate_proj(grad_cov,
param_cov,
param_pow / grad_pow,
param_rel_eps,
param_rel_max,
reverse_cutoff)
# The only thing we want from P is the basis that diagonalizes
# it, i.e. if we do the symmetric svd P = U S U^T, we can
# interpret the shape of U as (canonical_coordinate,
# diagonalized_coordinate)
U, S, _ = P.svd()
# `proj` will be of shape (diagonalized_coordinate, canonical_coordinate)
# Later we will modify `proj` to incorporate the parameter scale.
proj[:] = U.t()
else:
# This dimension will be treated diagonally. For now just set `proj` to
# all ones, later we will modify it in _estimate_param_scale()
proj.fill_(1.0)
self._estimate_param_scale(p, state, param_eps,
param_pow, param_rel_eps, param_rel_max,
param_reverse_cutoff)
def _multiply_by_lr(self,
grad: Tensor,
state: dict,
forward: bool) -> Tensor:
"""
Multiply the grad by a positive semi-definite matrix for each dimension,
interpreted as a non-scalar learning-rate factor.
"""
cur_grad = grad
for dim in range(grad.ndim):
size = grad.shape[dim]
if size == 1:
continue
M = state[f"lr_{dim}"]
if M.ndim == 1:
# We are treating this dimension diagonally because it is too big.
M = M.reshape(size, *([1] * (grad.ndim-1-dim)))
cur_grad = cur_grad * M
else:
cur_grad = self._multiply_on_dim(cur_grad, M, dim)
return cur_grad
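A shape-only sketch (made-up sizes, not from the original file) of the reshape-and-broadcast used above for dimensions treated diagonally: the 1-D factor is reshaped so that elementwise multiplication applies along dimension `dim` only.

    import torch

    grad = torch.randn(2, 3, 4)
    dim, ndim = 1, grad.ndim
    M = torch.rand(grad.shape[dim])                            # diagonal factors, shape (3,)
    M = M.reshape(grad.shape[dim], *([1] * (ndim - 1 - dim)))  # shape (3, 1)
    out = grad * M                                             # broadcasts along dim 1 only
    assert out.shape == grad.shape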
def _multiply_by_grad_cov(self,
grad: Tensor,
state: dict,
forward: bool) -> Tensor:
"""
Multiply the grad by a positive semi-definite matrix for each dimension,
interpreted as a non-scalar learning-rate factor.
"""
cur_grad = grad
for dim in range(grad.ndim):
size = grad.shape[dim]
if size == 1:
continue
M = state[f"grad_cov_{dim}"]
if M.ndim == 1:
# We are treating this dimension diagonally because it is too big.
M = M.reshape(size, *([1] * (grad.ndim-1-dim)))
# Dividing by (M.mean() + 1.0e-20) is just arbitrary normalization,
# to avoid getting very large or very small outputs.
cur_grad = cur_grad * (M / (M.mean() + 1.0e-20))
else:
# Dividing by (M.diag().mean() + 1.0e-20) is just arbitrary normalization,
# to avoid getting very large or very small outputs. We only need the result
# up to an arbitrary scalar factor.
cur_grad = self._multiply_on_dim(cur_grad, M, dim) / (M.diag().mean() + 1.0e-20)
return cur_grad
def _change_coordinates(self,
grad: Tensor,
state: dict,
forward: bool) -> Tensor:
"""
Multiply the grad by a matrix for each dimension, interpreted as a non-orthogonal
basis change.
If forward == True, `grad` is interpreted as: gradient -> gradient in new basis;
if forward == False, it is interpreted as:
parameter-change in new basis -> parameter-change in canonical basis.
These differ by a transpose (this is not the same as inversion as the matrix is not
orthogonal). Therefore the co-ordinate change applied both forward and backward
does not represent a "round trip".
We know at this point that `grad` has more than one non-trivial dimension/axis
(i.e. >1 dimension with size != 1).
Args:
p: the Parameter that the grad is for; may be needed if we are re-estimating
the learning-rate matrix
grad: the gradient to precondition (will not be modified by this function)
state: the state-dict where we will look up the projections.
Returns:
The co-ordinate-changed grad or parameter change
"""
cur_grad = grad
ndim = grad.ndim
step = state["step"]
for dim in range(ndim):
size = grad.shape[dim]
if size == 1:
continue
proj = state[f"proj_{dim}"]
if proj.ndim == 1:
# We are treating this dimension diagonally because it is too big.
proj = proj.reshape(size, *([1] * (ndim-1-dim)))
cur_grad = cur_grad * proj
else:
cur_grad = self._multiply_on_dim(cur_grad, proj, dim, forward)
return cur_grad
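A toy illustration (random matrices, not the optimizer's own state) of the docstring's point that the forward and backward applications differ by a transpose rather than an inverse: for a non-orthogonal projection, applying proj.t() and then proj is not a round trip.

    import torch

    proj = torch.randn(3, 3)            # hypothetical non-orthogonal projection
    g = torch.randn(5, 3)               # rows are vectors in canonical coordinates

    g_new = g @ proj.t()                # "forward": canonical -> new basis
    back = g_new @ proj                 # "backward": new basis -> canonical
    # proj.t() is not proj.inverse() unless proj is orthogonal, so in general
    # back != g; the two directions are adjoints of each other, not inverses.
    print(torch.allclose(back, g))      # almost surely False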
def _multiply_on_dim(self, x: Tensor, M: Tensor, dim: int) -> Tensor:
"""
Matrix-multiplies x by symmetric matrix `M` on x's dimension numbered `dim`
Args:
x: Tensor to multiply, of arbitrary shape
M: symmetric matrix to multiply x by, of shape
(x.shape[dim], x.shape[dim]).
"""
# Note: can use either M or M.t() here because M is symmetric
return torch.matmul(x.transpose(-1, dim), M.t())
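A small shape check (hypothetical sizes) of the helper as written: the transpose moves dimension `dim` to the last position before the matmul, so the result comes back with `dim` and the last dimension swapped relative to x.

    import torch

    x = torch.randn(2, 3, 4)
    M = torch.randn(3, 3)
    M = M + M.t()                                  # symmetric, as the docstring assumes
    out = torch.matmul(x.transpose(-1, 1), M.t())  # dim == 1
    print(out.shape)                               # torch.Size([2, 4, 3])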
def _smooth_param_rms(self,
group: dict,
rms: Tensor,
@@ -836,176 +607,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
return rms / new_mean
def _estimate_proj(self,
grad_cov: Tensor,
param_cov: Tensor,
param_pow: float,
param_rel_eps: float,
param_rel_max: float,
param_reverse_cutoff: float) -> Tensor:
"""
Return a symmetric positive definite matrix P such that
(P grad_cov P^T == param_cov^{param_pow}),
i.e. a descent direction that makes the covariance of steps look
like the covariance of the parameters. This will be normalized so
that the mean of its diagonal is 1.0 (since we'll have a separate
such projection per dim of the tensor, and we don't want to count
the scalar factor many times).
Args:
grad_cov: Covariance of gradients, a symmetric-positive-definite
matrix [except for roundoff!] of shape (size, size)
param_cov: Covariance of parameters, a symmetric-positive-definite
matrix [except for roundoff!] of shape (size, size)
param_pow: Power that we take "param_cov" to, between 0 and 1.
Normally it will be 1, and not all the comments mention it.
Returns:
A matrix P such that (alpha P grad_cov P^T == param_cov^param_pow)
"""
# symmetrize grad_cov and param_cov. The projection P only cares about the
# relative sizes, so doubling both of them does not matter.
G = grad_cov + grad_cov.t()
C = param_cov + param_cov.t()
U, S, V = C.svd()
# Using G == grad_cov and C == param_cov, and not writing transpose because
# all these quantities are symmetric, we can use:
#
# P = C^{0.5} (C^{0.5} G C^{0.5})^{-0.5} C^{0.5}
#
# You can verify fairly easily that P G P == C: multiply on left and right
# by C^{-0.5} on each side and you can see that both sides equal I.
#
# We can reduce the amount of roundoff by, after decomposing C
# with SVD as (U S U^T), writing C^{0.5} = U Q = Q^T U^T, with Q = S^0.5 U^T.
#
# Then we can write P as follows:
# P = Q^T U^T (U Q G Q^T U^T)^{-0.5} U Q,
# and we easily show that for orthogonal U, (U X U^T)^{-0.5} = U X^{-0.5} U^T, so we
# can do this and cancel U^T U = U U^T = I, giving:
# P = Q^T (Q G Q^T)^{-0.5} Q.
# because C is symmetric, C == U S U^T, we can ignore V.
# S_sqrt is S.sqrt() in the limit where param_pow == 1.0,
# param_rel_eps=0, param_rel_max=inf
S_smoothed = self._smooth_param_diag_var(S, param_pow,
param_rel_eps, param_rel_max,
param_reverse_cutoff)
S_sqrt = S_smoothed ** 0.5
Q = (U * S_sqrt).t()
X = torch.matmul(Q, torch.matmul(G, Q.t()))
U2, S2, _ = X.svd()
# Because X^{-0.5} = U2 S2^{-0.5} U2^T, we have
# P = Q^T U2 S2^{-0.5} U2^T Q
# = Y Y^T, where
# Y = Q^T U2 S2^{-0.25}
S2_pow = (S2 + 1.0e-35) ** -0.25
Y = torch.matmul(Q.t(), U2) * S2_pow
P = torch.matmul(Y, Y.t())
if random.random() < 1.0: #0.025:
# TODO: remove this testing code.
assert (P - P.t()).abs().mean() < 0.01 # make sure symmetric.
try:
P = 0.5 * (P + P.t())
_,s,_ = P.svd()
logging.info(f"Min,max eig of P: {s.min()},{s.max()}")
except:
pass
# testing... note, this is only true modulo "eps"
C_check = torch.matmul(torch.matmul(P, G), P)
C_smoothed = torch.matmul(Q.t(), Q)
# C_check should equal C_smoothed
C_diff = C_check - C_smoothed
# Roundoff can cause significant differences, so use a fairly large
# threshold of 0.01. We may increase this later or even remove the check.
if not C_diff.abs().mean() < 0.01 * C_smoothed.diag().mean():
logging.info(f"Warning: large C_diff: {C_diff.abs().mean()}, C diag mean: {C_smoothed.diag().mean()}")
return P
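A self-contained numerical check of the identity the comments derive, using random SPD matrices and none of the smoothing, eps or param_pow handling above (so this is only a sketch of the param_pow == 1 case): with Q = S^{0.5} U^T from C = U S U^T, and P = Q^T (Q G Q^T)^{-0.5} Q, we get P G P == C.

    import torch

    torch.manual_seed(0)
    size = 6
    A = torch.randn(size, size, dtype=torch.float64)
    B = torch.randn(size, size, dtype=torch.float64)
    G = A @ A.t() + 1e-3 * torch.eye(size, dtype=torch.float64)  # stand-in for grad_cov
    C = B @ B.t() + 1e-3 * torch.eye(size, dtype=torch.float64)  # stand-in for param_cov

    U, S, _ = C.svd()
    Q = (U * S.sqrt()).t()              # so that Q.t() @ Q == C
    X = Q @ G @ Q.t()
    U2, S2, _ = X.svd()
    Y = (Q.t() @ U2) * S2.pow(-0.25)    # P == Y @ Y.t() == Q.t() @ X^{-0.5} @ Q
    P = Y @ Y.t()

    print(torch.allclose(P @ G @ P, C, atol=1e-6))   # True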
def reset_speedup(self):
"""
Reset the step_period so that we'll do a step each time, until the next
time the speedup is recalibrated.
"""
for group in self.param_groups:
group["speedup_step"] = -group["speedup_first_step"]
for p in group["params"]:
state = self.state[p]
state["step_period"] = 1
def _recalibrate_speedup(self, group: dict):
param_prods = []
speedup = group["lr_for_speedup"] / group["lr"]
if speedup > 10.0:
speedup = 10.0
if speedup <= 1.0:
for p in group["params"]:
state = self.state[p]
state["step_period"] = 1
return
_, beta2, _ = group["betas"]
for p in group["params"]:
if p.grad is None:
continue
# We ignore the part of the parameter change that comes from the scale
# (see 'scale_deriv' in the code)
state = self.state[p]
step = state["step"]
# the sqrt() is just a heuristic, it seems to give periods that work better.
param_prods.append(state["grad_times_step"].sqrt())
param_prods = torch.stack(param_prods).to('cpu')
# TEMP
param_prods = param_prods.sqrt()
avg_update_freq = 1.0 / speedup
param_freqs = param_prods * (avg_update_freq / param_prods.mean())
# Don't delay more than 1 / (3*speedup) steps for any parameter, to
# ensure that no parameter waits super-long for an update. Note: with
# beta=0.1, this would mean an average delay of about 30*speedup steps
# before the gradient makes it to the parameter.
min_freq = 1.0 / (3 * speedup)
# get the mean after applying limits of [min_freq..1] for the frequencies of
# update
param_freqs_mean = param_freqs.clamp(min=min_freq, max=1.0).mean()
# renormalize after applying min and max limits
param_freqs = param_freqs * (avg_update_freq / param_freqs_mean)
# ...and apply the limits again.
param_freqs = param_freqs.clamp(min=min_freq, max=1.0)
param_periods = (1.0 / param_freqs).to(torch.int32) # .to(torch.int32) rounds down
actual_speedup = 1.0 / (1.0 / param_periods).mean()
param_prods = param_prods.tolist()
param_freqs = param_freqs.tolist()
param_periods = param_periods.tolist()
logging.info(f"PrAdam._recalibrate_speedup: speedup = {speedup:.2g}, actual_speedup = {actual_speedup:.2g}")
print_info = random.random() < 0.01
i = 0
for p in group["params"]:
if p.grad is None:
continue
state = self.state[p]
state["step_period"] = param_periods[i]
if print_info:
logging.info(f"Shape = {p.shape}, prod = {param_prods[i]}, freq = {param_freqs[i]}, period = {param_periods[i]}.")
i += 1
# Do nothing more, as yet.
def _svd(x: Tensor):
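A toy numerical walk-through (invented prods and speedup, not real statistics) of the frequency/period arithmetic in _recalibrate_speedup above: per-parameter frequencies are scaled to average 1/speedup, clamped to [1/(3*speedup), 1], renormalized, clamped again, and truncated to integer step periods.

    import torch

    speedup = 4.0
    param_prods = torch.tensor([0.1, 1.0, 4.0, 16.0])   # hypothetical per-parameter stats
    avg_update_freq = 1.0 / speedup
    param_freqs = param_prods * (avg_update_freq / param_prods.mean())
    min_freq = 1.0 / (3 * speedup)
    mean = param_freqs.clamp(min=min_freq, max=1.0).mean()
    param_freqs = (param_freqs * (avg_update_freq / mean)).clamp(min=min_freq, max=1.0)
    param_periods = (1.0 / param_freqs).to(torch.int32)  # truncation rounds down
    actual_speedup = 1.0 / (1.0 / param_periods).mean()
    print(param_periods.tolist(), float(actual_speedup))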