Reworked how the inverse is done; fixed a bug in _apply_min_max_with_metric regarding how M is normalized.

Daniel Povey 2022-08-04 09:46:14 +08:00
parent 766bf69a98
commit dc9133227f


@@ -124,12 +124,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
(4) is for smoothing the grad covariance used for (2)
cov_pow: This was mainly added for development and experimentation purposes;
it allows you to smooth the parameter covariance matrices at the
stages (1), (2), (3) of smoothing mentioned above, and also
the gradient covariance matrix used in stage (2) of smoothing. If the
1st 3 values are not 1.0 it will cause a measurable speed penalty because it
requires SVD. Recommend to leave all these at 1.0.
eps: A general-purpose epsilon to prevent division by zero
denom_rel_eps: An epsilon value used to keep the elements of denominator (exp_avg_grad_sqrt.sqrt())
not too small relative to the mean value; this will limit the speed of update
@@ -163,9 +157,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
lr=3e-02,
betas=(0.9, 0.98),
size_lr_scale=0.1,
cov_min=(0.025, 0.0025, 0.1, 0.0001),
cov_min=(0.025, 0.05, 0.1, 0.0001),
cov_max=(5.0, 20.0, 3.5, 40.0),
cov_pow=(1.0, 1.0, 1.0, 1.0),
param_rms_smooth0=0.4,
param_rms_smooth1=0.2,
eps=1.0e-08,
@@ -189,7 +182,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
size_lr_scale=size_lr_scale,
cov_min=cov_min,
cov_max=cov_max,
cov_pow=cov_pow,
param_rms_smooth0=param_rms_smooth0,
param_rms_smooth1=param_rms_smooth1,
betas=betas,
@@ -763,8 +755,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
P = self._smooth_cov(P,
max(smooth, group["cov_min"][0]),
group["cov_max"][0],
group["cov_pow"][0])
group["cov_max"][0])
#P = P.clone()
#P_diag = _diag(P)
#P_diag_mean = _mean(P_diag, exclude_dims=[0], keepdim=True)
@@ -775,11 +766,11 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
#G_diag *= 1.005 # ensure invertible.
G = self._smooth_cov(G,
group["cov_min"][3],
group["cov_max"][3],
group["cov_pow"][3])
group["cov_max"][3])
G_inv = self._safe_inverse(G, group["cov_min"][3])
P = self._apply_min_max_with_metric(P, G,
P = self._apply_min_max_with_metric(P, G, G_inv,
group["cov_min"][1],
group["cov_max"][1])
@@ -787,36 +778,30 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
# Apply a 3rd round of smoothing in the canonical basis.
P = self._smooth_cov(P,
group["cov_min"][2],
group["cov_max"][2],
group["cov_pow"][2])
group["cov_max"][2])
return P
def _smooth_cov(self,
X: Tensor,
min_eig: float,
max_eig: float,
power: float = 1.0) -> Tensor:
max_eig: float) -> Tensor:
"""
Returns a `smoothed` version of a symmetric positive definite covariance matrix
[with block-diagonal structure, in a batch]. This is done without SVD (which
can be very slow).
The eigenvalues L will be transformed as:
Returns a `smoothed` version of a nonzero symmetric positive semidefinite covariance matrix
(actually a batch of such covariance matrices, each one with block-diagonal structure).
This is done without SVD or eigenvalue decomposition or inversion, since those things
can be very slow.
L = L + min_eig * L.mean() + eps
L /= L.mean()
L = 1 / (1/L + 1/max_eig) # soft-min between L and max_eig
L = L ** power # need SVD for this, will get rid of the requirement later.
L /= L.mean()
The eigenvalues L will be transformed in a way roughly approximating the following:
# Note on approximation functions like x^0.75 for smallish x: on wolframalpha, type:
# plot x^0.75 and 0.05 + (1.1x - 0.18 x^2 + 0.02 x^3) for x from 0 to 10
# [this starts to diverge after 5 or so]
L /= (L**2).mean().sqrt() # normalize so RMS eigenvalue is 1.0
L = 1 / (1/L + 1/max_eig) # apply soft-min with max_eig as the maximum
L *= (1.0 - min_eig) / L.mean()
L += min_eig
Args:
min_eig: minimum allowed eigenvalue of returned X
max_eig: maximum allowed eigenvalue of returned X
power: power to take eigenvalues to
X: the batch of symmetric positive definite tensors we are smoothing;
X: the batch of symmetric positive semidefinite tensors we are smoothing;
of shape (batch_size, num_blocks, block_size, block_size)
"""
eps = 1.0e-20
@@ -824,7 +809,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
size = X.shape[1] * X.shape[3]
def rms_eig(Y):
# rms of eigenvalues.
return (_sum(Y**2, exclude_dims=[0], keepdim=True) / size).sqrt()
return (_sum(Y**2, exclude_dims=[0], keepdim=True) / size + eps).sqrt()
X /= rms_eig(X)
# eig_ceil is the maximum possible eigenvalue that X could possibly
@@ -860,16 +845,18 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
# normalize to have rms eig == 1.
X = 0.5 * (X + X.transpose(2, 3)) # make sure exactly symmetric.
diag = _diag(X).unsqueeze(-1) # Aliased with X
diag += (_mean(diag, exclude_dims=[0], keepdim=True) * min_eig + eps)
X /= rms_eig(X)
def mean_eig(Y):
return _mean(_diag(Y), exclude_dims=[0], keepdim=True).unsqueeze(-1)
X *= ((1.0 - min_eig) / (mean_eig(X) + eps))
_diag(X).add_(min_eig)
return X
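# A minimal standalone sketch (not from this file; function name is hypothetical) of the
# eigenvalue transform that the _smooth_cov docstring above describes, written directly on
# a vector of eigenvalues L.  The real method operates on the matrices themselves without
# any eigendecomposition; this reference version only makes the intended mapping concrete.
import torch

def smooth_eigs_reference(L: torch.Tensor, min_eig: float, max_eig: float) -> torch.Tensor:
    L = L / (L ** 2).mean().sqrt()       # normalize so the rms eigenvalue is 1.0
    L = 1.0 / (1.0 / L + 1.0 / max_eig)  # soft-min between L and max_eig
    L = L * (1.0 - min_eig) / L.mean()   # rescale so the mean becomes (1 - min_eig)
    return L + min_eig                   # floor at min_eig; the mean eigenvalue is now 1.0

# e.g. smooth_eigs_reference(torch.tensor([0.01, 1.0, 100.0]), min_eig=0.1, max_eig=5.0)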
def _apply_min_max_with_metric(self,
X: Tensor,
M: Tensor,
M_inverse: Tensor,
min_eig: float,
max_eig: float) -> Tensor:
"""
@@ -882,6 +869,16 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
Args:
X: Batch of block-diagonal nonzero positive definite matrices to have max eigenvalue applied
M: Batch of positive definite block-diagonal matrices to use as metrics w.r.t. X.
M_inverse: If min_eig != 0.0, the inverse of M should be supplied by the user;
this is expected to have about the same value that M.inverse() would have,
but is supplied separately because the calling code inverts in a safer-than-normal
way based on knowledge of the minimum eigenvalue of M.
min_eig: The minimum eigenvalue of the returned Tensor, which will have average
eigenvalue equal to 1.0 for each block-diagonal matrix.
max_eig: The maximum eigenvalue we aim for, expressed as a multiple of the
root-mean-square (rms) eigenvalue of X; because we later normalize and
apply min_eig, the returned tensor may end up with eigenvalues somewhat
larger than this.
"""
X = X.clone()
# size of the block-diagonal matrix..
@@ -894,7 +891,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
def rms_eig(Y):
# rms of eigenvalues.
return (_sum(Y**2, exclude_dims=[0], keepdim=True) / size).sqrt()
return (_sum(Y*Y.transpose(2,3), exclude_dims=[0], keepdim=True) / size).sqrt()
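# For Y = X @ M, which is similar to the SPD matrix M^{0.5} X M^{0.5}, summing
# Y * Y.transpose(2, 3) gives trace(Y @ Y), i.e. the sum of Y's squared eigenvalues;
# summing Y**2 would instead give the sum of squared singular values, which differs
# here because Y is in general not symmetric.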
# make sure eigs of M^{0.5} X M^{0.5} have rms value 1. This imposes limit on the max.
X /= rms_eig(torch.matmul(X, M))
@@ -931,12 +928,79 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
X = 0.5 * (X + X.transpose(-2, -1)) # make sure exactly symmetric.
if min_eig != 0.0:
X /= mean_eig(X)
X = X * (1.0-min_eig) + min_eig * M.inverse()
mean_XM_eig = _sum(X * M.transpose(2, 3), exclude_dims=[0], keepdim=True) / size
X /= mean_XM_eig
X = X * (1.0-min_eig) + min_eig * M_inverse
X = 0.5 * (X + X.transpose(-2, -1)) # make sure exactly symmetric.
return X
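# A small standalone check (hypothetical, not part of this file) of what the min_eig branch
# above does: once X is scaled so that the mean eigenvalue of X @ M equals 1, the update
# X <- (1 - min_eig) * X + min_eig * M^{-1} floors the eigenvalues of M^{0.5} X M^{0.5}
# at min_eig while keeping their mean equal to 1.
import torch

def check_min_eig_floor(dim: int = 8, min_eig: float = 0.1) -> torch.Tensor:
    X = torch.randn(dim, 2 * dim); X = X @ X.t()   # random SPD "parameter" covariance
    M = torch.randn(dim, 2 * dim); M = M @ M.t()   # random SPD metric
    X = X / (X @ M).diagonal().mean()              # mean eigenvalue of X @ M -> 1
    X = (1.0 - min_eig) * X + min_eig * M.inverse()
    # X @ M is similar to M^{0.5} X M^{0.5}, so the two share eigenvalues.
    evals = torch.linalg.eigvals(X @ M).real
    return evals                                   # all >= min_eig (up to roundoff), mean ~ 1.0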
def _safe_inverse(self,
X: Tensor,
min_eig: float,
tolerance: float = 1.2):
"""
Returns X.inverse(), but in a safe way that limits how inaccurate the result
can become due to numerical roundoff.
Args:
X: batch of positive definite tensors to be inverted, of shape (*, size, size).
Mathematically, X must be symmetric and positive definite (SPD), but this
routine allows it to have negative eigenvalues and just assumes that
any eigenvalues <min_eig should actually equal min_eig.
min_eig: the minimum possible eigenvalue that X is known mathematically to have
(e.g. 0.1).
tolerance: A tolerance that should be >1.0; it controls the tradeoff between
redundant computation and accuracy.
"""
try:
size = X.shape[-1]
I = torch.eye(size, dtype=X.dtype, device=X.device)
if True:
X_inv = X.inverse()
else:
C = X.cholesky() # X = C C^T
I = I.reshape(*([1]*(C.ndim-2)), size, size).expand(*X.shape)
# note, X^{-1} == C^{-T} C^{-1}
# next line does: X_inv = C^{-1} X_inv (== C^{-1})
X_inv = torch.triangular_solve(I, C, upper=False, transpose=False)[0]
# next line does: X_inv = C^{-T} X_inv (== C^{-T} C^{-1})
X_inv = torch.triangular_solve(X_inv, C, upper=False, transpose=True)[0]
# Make sure X_inv is exactly symmetric (the input is assumed symmetric)
X_inv = 0.5 * (X_inv + X_inv.transpose(-2, -1))
# OK, X_inv is positive definite and we require that its largest eigenvalue
# should be < (1/min_eig) * tolerance. That requires that
# I * tolerance - X_inv should be positive definite.
X_inv_check = ((tolerance / min_eig) * I - X_inv)
# now X_inv_check == I * tolerance - X_inv.
# We don't need the result of X_inv_check.cholesky() below, we just execute the
# statement as a way of checking that X_inv_check is positive definite.
# we also check that X_inv has no negative eigenvalues, because if any of the
# original X's eigenvalues had been negative, they would show up as negative
# eigenvalues in the answer.
torch.stack([X_inv, X_inv_check]).cholesky()
return X_inv
except RuntimeError as e:
# _svd(X) is an alternative to X.svd() that retries in response to
# failures caused by SVD implementation bugs.
U,S,V = _svd(X)
S_sign = (2*((U * V).sum(dim=-2) > 0) - 1)
# the method with S_with_sign below is a quick but not 100% guaranteed
# way to find the eigenvalues of an SPD matrix (it might not work
# if there are tied eigenvalues); this does not matter as we are only
# using it to print diagnostics.
S_with_sign = S * S_sign
logging.warning(f"Caught error in _safe_inverse() (this is normal). Shape={X.shape}, minimum eigenvalue "
f"is {S_with_sign.min().item()} vs. {min_eig} with tolerance {tolerance}; exception was: {e}")
S_inv = 1.0 / S.clamp(min=min_eig)
X_inv = torch.matmul(U * S_inv.unsqueeze(-2), U.transpose(-2, -1))
# Make sure X_inv is exactly symmetric (the input is assumed symmetric)
X_inv = 0.5 * (X_inv + X_inv.transpose(-2, -1))
return X_inv
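# A compact restatement (sketch only, not part of this file) of the condition that the
# try-block of _safe_inverse verifies: `(tolerance / min_eig) * I - X_inv` has a Cholesky
# factorization exactly when every eigenvalue of X_inv is below tolerance / min_eig, i.e.
# when every eigenvalue of X is above min_eig / tolerance; if not, _safe_inverse falls back
# to the clamped SVD-based inverse above.
import torch

def inverse_is_trustworthy(X: torch.Tensor, min_eig: float, tolerance: float = 1.2) -> bool:
    X_inv = X.inverse()
    X_inv = 0.5 * (X_inv + X_inv.transpose(-2, -1))   # symmetrize against roundoff
    I = torch.eye(X.shape[-1], dtype=X.dtype, device=X.device)
    check = (tolerance / min_eig) * I - X_inv
    try:
        # Cholesky succeeds only for positive definite input: this verifies both that X_inv
        # has no negative eigenvalues and that its largest eigenvalue is suitably bounded.
        torch.linalg.cholesky(torch.stack([X_inv, check]))
        return True
    except RuntimeError:
        return False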
def _update_grad_cov(self,
group: dict,
@@ -1949,14 +2013,38 @@ def _test_svd():
X2 = torch.matmul(U*S, V.t())
assert torch.allclose(X2, X, atol=0.001)
def _test_safe_inverse():
for dim in [10, 20, 100]:
a = torch.randn(2, 3, dim, dim//2)
a = torch.matmul(a, a.transpose(-2, -1))
min_eig = 0.001
_diag(a).add_(min_eig)
for tolerance in [0.9, 0.99, 1.0, 1.01, 1.1]:
self = None
inv = PrAdam._safe_inverse(self, a, min_eig, tolerance)
prod = (torch.matmul(inv, a))
err = prod - torch.eye(dim)
err = torch.sum((err**2), dim=(2,3)).sqrt()
logging.info(f"_test_safe_inverse(): dim={dim}, tolerance={tolerance}, err={err}")
def _test_smooth_cov():
b = (torch.arange(10)*0.5).exp().diag()
b = b.unsqueeze(0).unsqueeze(0)
c = PrAdam._smooth_cov(None, b, 0.0, 10.0)
logging.info(f"c[noinv] = {_diag(c)}")
c = PrAdam._smooth_cov(None, b, 0.0, 10.0, 1.00001)
logging.info(f"c[svd] = {_diag(c)}")
def _test_apply_min_max():
dim = 50
X = torch.randn(1, 1, dim, dim*2)
X = torch.matmul(X, X.transpose(2, 3))
M = torch.randn(1, 1, dim, dim*2)
M = torch.matmul(M, M.transpose(2, 3))
for min_eig in [0.0, 0.1]:
for max_eig in [5.0, 10.0]:
for Xscale in [1.0, 2.0]:
for Mscale in [1.0, 2.0]:
self = None
thisM = M * Mscale
thisMinv = thisM.inverse()
Y = PrAdam._apply_min_max_with_metric(self, X*Xscale, thisM, thisMinv,
min_eig, max_eig)
Y /= _mean(_diag(Y), exclude_dims=[0], keepdim=True)
logging.info(f"min_eig={min_eig},max_eig={max_eig},Xscale={Xscale},Mscale={Mscale}, rms of Y = {(Y**2).mean().item()}")
if __name__ == "__main__":
torch.set_num_threads(1)
@@ -1964,8 +2052,9 @@ if __name__ == "__main__":
logging.getLogger().setLevel(logging.INFO)
import subprocess
s = subprocess.check_output("git status -uno .; git log -1", shell=True)
_test_smooth_cov()
logging.info(s)
_test_safe_inverse()
_test_apply_min_max()
#_test_svd()
import sys
if len(sys.argv) > 1: