Merge branch 'pradam_exp1m4_nophase1_noinv' into pradam_exp1m4_nophase1_rework_noinv

commit 2042c9862c
Author: Daniel Povey
Date:   2022-07-31 01:32:36 -07:00


@@ -856,13 +856,34 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
mean_eig = _mean(diag, exclude_dims=[0], keepdim=True)
diag += (mean_eig * min_eig + eps)
cur_diag_mean = mean_eig * (1 + min_eig) + eps
# The following statements are equivalent to:
#   L /= L.mean()
#   L = 1 / (1/L + 1/max_eig)   # soft-min between L and max_eig
# where L denotes the eigenvalues of X.
X_inv = X.inverse()
_diag(X_inv).add_(1. / (max_eig * cur_diag_mean))
X = X_inv.inverse()
X /= cur_diag_mean.unsqueeze(-1)
# OK, now the mean of the diagonal of X is 1 (or less than 1, in
# which case X is extremely tiny).
# eig_ceil is the maximum possible eigenvalue that X could possibly
# have at this time.
eig_ceil = X.shape[-1]
while max_eig <= eig_ceil * 0.5:
# this is equivalent to the operation: l -> l - 0.5/eig_ceil*l*l
# on eigenvalues, which maps eig_ceil to 0.5*eig_ceil and is monotonically
# increasing from 0..eig_ceil.
logging.info(f"X={X}, eig_ceil={eig_ceil}")
X = X - 0.5/eig_ceil * torch.matmul(X, X)
eig_ceil = 0.5 * eig_ceil
# max_eig > eig_ceil * 0.5
if max_eig < eig_ceil:
# map l -> l - coeff*l*l; if l == eig_ceil, this
# takes us to:
#   eig_ceil - ((eig_ceil-max_eig)/(eig_ceil*eig_ceil))*eig_ceil*eig_ceil
#   == max_eig
# .. and the fact that coeff <= 0.5/eig_ceil [since max_eig > eig_ceil*0.5]
# means that the function is monotonic on inputs from 0 to eig_ceil.
coeff = (eig_ceil - max_eig) / (eig_ceil*eig_ceil)
X = X - coeff * torch.matmul(X, X)
# Normalize again.
X /= _mean(_diag(X), exclude_dims=[0], keepdim=True).unsqueeze(-1)
X = 0.5 * (X + X.transpose(-2, -1)) # make sure exactly symmetric.
return X
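
The two comment blocks in this hunk describe purely matrix-level implementations of eigenvalue operations: a soft-min of the eigenvalues against max_eig done via the inverse, and a polynomial map l -> l - coeff*l*l that lowers the largest possible eigenvalue without any eigendecomposition. Below is a minimal standalone sketch (not part of the commit; the helper names soft_min_eigs and cap_eigs are made up here) that checks both claims on a small symmetric positive-definite matrix.

import torch

def soft_min_eigs(X: torch.Tensor, max_eig: float) -> torch.Tensor:
    # Equivalent to l -> 1 / (1/l + 1/max_eig) on the eigenvalues of X:
    # invert, shift every eigenvalue of the inverse by 1/max_eig (a diagonal
    # add), then invert back.
    X_inv = X.inverse()
    X_inv.diagonal(dim1=-2, dim2=-1).add_(1.0 / max_eig)
    return X_inv.inverse()

def cap_eigs(X: torch.Tensor, eig_ceil: float, max_eig: float) -> torch.Tensor:
    # One step of the map l -> l - coeff*l*l, which sends eig_ceil to max_eig
    # and is monotonically increasing on [0, eig_ceil] when coeff <= 0.5/eig_ceil.
    coeff = (eig_ceil - max_eig) / (eig_ceil * eig_ceil)
    return X - coeff * torch.matmul(X, X)

torch.manual_seed(0)
A = torch.randn(6, 6, dtype=torch.float64)
X = torch.matmul(A, A.t())              # small SPD test matrix
l = torch.linalg.eigvalsh(X)            # its eigenvalues, ascending

max_eig = 2.0
got = torch.linalg.eigvalsh(soft_min_eigs(X, max_eig))
want = 1.0 / (1.0 / l + 1.0 / max_eig)  # eigenvalue-domain soft-min
assert torch.allclose(got, want)

eig_ceil = float(l.max()) + 1.0         # any upper bound on the eigenvalues
target = 0.75 * eig_ceil
got = torch.linalg.eigvalsh(cap_eigs(X, eig_ceil, target))
want = l - (eig_ceil - target) / eig_ceil**2 * l * l
assert torch.allclose(got, want)
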
@@ -1120,7 +1141,7 @@ def _mean(x: Tensor,
# the interface changes in future.
return x.mean(dim=tuple(range(x.ndim)), keepdim=keepdim)
elif x.ndim == 1:
assert exclude_dim == [0] or exclude_dim == [-1]
assert exclude_dims == [0] or exclude_dims == [-1]
return x # if one dim is excluded, there are no dims to mean, and
# x.mean(dim=[]) means all dims so we should not call mean().
exclude_dims_norm = [i if i >= 0 else i + x.ndim for i in exclude_dims]
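
The _mean helper is only partly visible in this hunk; the visible lines say it averages over every dimension except those listed in exclude_dims, and that the 1-D case must simply return x because x.mean(dim=[]) would average over all dims. A small sketch of an equivalent helper under that reading (assumed, since the rest of _mean is not shown in the diff; the name mean_excluding is made up here):

import torch
from torch import Tensor
from typing import List

def mean_excluding(x: Tensor, exclude_dims: List[int], keepdim: bool = False) -> Tensor:
    if len(exclude_dims) == 0:
        return x.mean(dim=tuple(range(x.ndim)), keepdim=keepdim)
    if x.ndim == 1:
        # The only existing dim is excluded, so there is nothing to average
        # over; x.mean(dim=[]) would average over *all* dims, so return x.
        return x
    exclude = {d if d >= 0 else d + x.ndim for d in exclude_dims}
    dims = tuple(d for d in range(x.ndim) if d not in exclude)
    return x.mean(dim=dims, keepdim=keepdim)

x = torch.arange(12.0).reshape(3, 4)
print(mean_excluding(x, exclude_dims=[0], keepdim=True).shape)  # torch.Size([3, 1])
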
@@ -1806,7 +1827,7 @@ def _test_eve_cain():
fix_random_seed(42)
Linear = torch.nn.Linear if iter == 0 else ScaledLinear
hidden_dim = 200
hidden_dim = 300
m = torch.nn.Sequential(Linear(E, hidden_dim),
torch.nn.PReLU(),
Linear(hidden_dim, hidden_dim),
@@ -1823,16 +1844,6 @@ def _test_eve_cain():
elif iter == 3: optim = PrAdam(m.parameters(), lr=0.03, max_block_size=256)
scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False)
#TEMP
if iter == 3:
a = torch.randn(5, 10, 5, 5)
a = torch.matmul(a, a.transpose(2, 3))
b1 = optim._smooth_cov(a, 0.1, 4.0, 0.999999999)
b2 = optim._smooth_cov(a, 0.1, 4.0, 1.0)
diff = (b1 - b2)
ratio = (diff**2).sqrt().mean() / (b1**2).sqrt().mean()
logging.info(f"ratio = {ratio}")
assert ratio < 0.01
start = timeit.default_timer()
avg_loss = 0.0
@@ -1891,12 +1902,22 @@ def _test_svd():
X2 = torch.matmul(U*S, V.t())
assert torch.allclose(X2, X, atol=0.001)
def _test_smooth_cov():
b = (torch.arange(10)*0.5).exp().diag()
b = b.unsqueeze(0).unsqueeze(0)
c = PrAdam._smooth_cov(None, b, 0.0, 10.0)
logging.info(f"c[noinv] = {_diag(c)}")
c = PrAdam._smooth_cov(None, b, 0.0, 10.0, 1.00001)
logging.info(f"c[svd] = {_diag(c)}")
if __name__ == "__main__":
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
logging.getLogger().setLevel(logging.INFO)
import subprocess
s = subprocess.check_output("git status -uno .; git log -1", shell=True)
_test_smooth_cov()
logging.info(s)
#_test_svd()
_test_eve_cain()
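
_test_smooth_cov above smooths a diagonal covariance whose entries grow from 1 to roughly e^4.5 ≈ 90, with min_eig=0.0 and max_eig=10.0, and logs the resulting diagonals for the no-inverse and SVD code paths. As a rough sanity reference (an assumption based on the smoothing code earlier in this commit, with min_eig == 0 and the small eps term ignored), the eigenvalue-domain version of the same smoothing is sketched here:

import torch

eigs = (torch.arange(10) * 0.5).exp()   # same diagonal as in the test above
max_eig = 10.0

l = eigs / eigs.mean()                  # normalize eigenvalues to mean 1
l = 1.0 / (1.0 / l + 1.0 / max_eig)     # soft-min against max_eig
l = l / l.mean()                        # renormalize so the mean is 1 again
print(l)  # roughly the diagonal that both c[noinv] and c[svd] should show
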