Another bug fix, regarding Q being transposed.
This commit is contained in:
parent
ad2e698fc3
commit
fb36712e6b
@ -60,7 +60,7 @@ class LearnedGradient(Optimizer):
|
|||||||
params,
|
params,
|
||||||
lr=3e-02,
|
lr=3e-02,
|
||||||
size_lr_scale=0.1,
|
size_lr_scale=0.1,
|
||||||
meta_lr_scale=0.1,
|
meta_lr_scale=0.2,
|
||||||
betas=(0.9, 0.98),
|
betas=(0.9, 0.98),
|
||||||
eps=1.0e-08,
|
eps=1.0e-08,
|
||||||
size_update_period=1,
|
size_update_period=1,
|
||||||
@ -183,7 +183,7 @@ class LearnedGradient(Optimizer):
|
|||||||
# re-estimate when we "rediagonalize" (this diagonalizes
|
# re-estimate when we "rediagonalize" (this diagonalizes
|
||||||
# the gradient covariance).
|
# the gradient covariance).
|
||||||
# If our parameter matrix M is of shape (..., size),
|
# If our parameter matrix M is of shape (..., size),
|
||||||
# we'll view M as being M == torch.matmul(N, q)
|
# we'll view M as being M == torch.matmul(N, Q)
|
||||||
# so p can be interpreted as having shape
|
# so p can be interpreted as having shape
|
||||||
# (size, size), interpreted as being indexed [diagonalized_coordinate, canonical_coordinate]
|
# (size, size), interpreted as being indexed [diagonalized_coordinate, canonical_coordinate]
|
||||||
# by "diagonalize" we mean that we diagonalize the gradient covariance.
|
# by "diagonalize" we mean that we diagonalize the gradient covariance.
|
||||||
@ -255,6 +255,7 @@ class LearnedGradient(Optimizer):
|
|||||||
if step % lr_est_period == 0:
|
if step % lr_est_period == 0:
|
||||||
self._update_lrs(group, p, state)
|
self._update_lrs(group, p, state)
|
||||||
if step % (lr_est_period * diagonalize_period) == 0:
|
if step % (lr_est_period * diagonalize_period) == 0:
|
||||||
|
logging.info("Diagonalizing")
|
||||||
self._diagonalize_lrs(group, p, state)
|
self._diagonalize_lrs(group, p, state)
|
||||||
self._zero_exp_avg_sq(state)
|
self._zero_exp_avg_sq(state)
|
||||||
if True:
|
if True:
|
||||||
@ -555,8 +556,8 @@ class LearnedGradient(Optimizer):
|
|||||||
S_new *= 1.0 / S_new.mean() # normalize so mean is 1.0
|
S_new *= 1.0 / S_new.mean() # normalize so mean is 1.0
|
||||||
S_new.clamp_(lr_mat_min, lr_mat_max) # apply limits once more.
|
S_new.clamp_(lr_mat_min, lr_mat_max) # apply limits once more.
|
||||||
# Reconstruct Q with the modified S.
|
# Reconstruct Q with the modified S.
|
||||||
Q[:] = torch.matmul(U * S, V.t())
|
Q[:] = torch.matmul(U * S_new, V.t())
|
||||||
if random.random() < 0.1:
|
if random.random() < 0.03:
|
||||||
subsample = max(1, S.numel() // 20)
|
subsample = max(1, S.numel() // 20)
|
||||||
logging.info(f"shape={tuple(p.shape)}, dim={dim}, modifed S from {S[::subsample]} to {S_new[::subsample]}")
|
logging.info(f"shape={tuple(p.shape)}, dim={dim}, modifed S from {S[::subsample]} to {S_new[::subsample]}")
|
||||||
|
|
||||||
@ -580,7 +581,7 @@ class LearnedGradient(Optimizer):
|
|||||||
# index).
|
# index).
|
||||||
|
|
||||||
grad_cov = state[f"grad_cov_{dim}"]
|
grad_cov = state[f"grad_cov_{dim}"]
|
||||||
N_grad_cov = torch.matmul(Q, torch.matmul(grad_cov, Q))
|
N_grad_cov = torch.matmul(Q, torch.matmul(grad_cov, Q.t()))
|
||||||
N_grad_cov = N_grad_cov + N_grad_cov.t() # ensure symmetric
|
N_grad_cov = N_grad_cov + N_grad_cov.t() # ensure symmetric
|
||||||
U, S, V = N_grad_cov.svd()
|
U, S, V = N_grad_cov.svd()
|
||||||
|
|
||||||
@ -590,7 +591,7 @@ class LearnedGradient(Optimizer):
|
|||||||
# Now, we can diagonalize N_grad_cov with:
|
# Now, we can diagonalize N_grad_cov with:
|
||||||
# U^T N_grad_cov U == S.
|
# U^T N_grad_cov U == S.
|
||||||
# N_grad_cov is a sum of N_grad^T N_grad.
|
# N_grad_cov is a sum of N_grad^T N_grad.
|
||||||
# So U^T N_grad^T N_grad U is diagonal.
|
# We know U^T N_grad_cov U is diagonal, so U^T N_grad^T N_grad U is diagonal.
|
||||||
# The linearized pseudo-loss can be written as tr(N_grad^T N_grad).
|
# The linearized pseudo-loss can be written as tr(N_grad^T N_grad).
|
||||||
# This can be written as tr(U U^T N_grad^T N_grad), since U U^T == I,
|
# This can be written as tr(U U^T N_grad^T N_grad), since U U^T == I,
|
||||||
#
|
#
|
||||||
@ -601,7 +602,8 @@ class LearnedGradient(Optimizer):
|
|||||||
# (hat_N means \hat{N}, or N with a hat on it).
|
# (hat_N means \hat{N}, or N with a hat on it).
|
||||||
# So if we interpret hat_N = N U, the gradient covariance w.r.t.
|
# So if we interpret hat_N = N U, the gradient covariance w.r.t.
|
||||||
# hat_N will be diagonalized. We also modify Q to hat_Q when
|
# hat_N will be diagonalized. We also modify Q to hat_Q when
|
||||||
# we modify hat_N, to keep the product M = N Q = N U U^T Q = hat_N hat_Q
|
# we modify hat_N, to keep the product M unchanged:
|
||||||
|
# M = N Q = N U U^T Q = hat_N hat_Q
|
||||||
# This can be done by setting
|
# This can be done by setting
|
||||||
# hat_Q = U^T Q (eq.10)
|
# hat_Q = U^T Q (eq.10)
|
||||||
#
|
#
|
||||||
@ -814,8 +816,10 @@ class LearnedGradient(Optimizer):
|
|||||||
if x.shape[dim] == 1:
|
if x.shape[dim] == 1:
|
||||||
continue
|
continue
|
||||||
Q = state[f"Q_{dim}"]
|
Q = state[f"Q_{dim}"]
|
||||||
if not forward:
|
if forward:
|
||||||
# Q is indexed [canonical_index, diagonalized_index]
|
# Q is indexed [diagonalized_index, canonical_index]; in the forward
|
||||||
|
# direction we want to change canonical to diagonalized index, so have
|
||||||
|
# to transpose.
|
||||||
Q = Q.t()
|
Q = Q.t()
|
||||||
# TODO: could possibly somehow force the output format to be unchanged.
|
# TODO: could possibly somehow force the output format to be unchanged.
|
||||||
x = x.transpose(-1, dim)
|
x = x.transpose(-1, dim)
|
||||||
@ -1156,7 +1160,7 @@ class LearnedGradient(Optimizer):
|
|||||||
try:
|
try:
|
||||||
P = 0.5 * (P + P.t())
|
P = 0.5 * (P + P.t())
|
||||||
_,s,_ = P.svd()
|
_,s,_ = P.svd()
|
||||||
print(f"Min,max eig of P: {s.min()},{s.max()}")
|
logging.info(f"Min,max eig of P: {s.min()},{s.max()}")
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
# testing... note, this is only true modulo "eps"
|
# testing... note, this is only true modulo "eps"
|
||||||
@ -1167,7 +1171,7 @@ class LearnedGradient(Optimizer):
|
|||||||
# Roundoff can cause significant differences, so use a fairly large
|
# Roundoff can cause significant differences, so use a fairly large
|
||||||
# threshold of 0.001. We may increase this later or even remove the check.
|
# threshold of 0.001. We may increase this later or even remove the check.
|
||||||
if not C_diff.abs().mean() < 0.01 * C_smoothed.diag().mean():
|
if not C_diff.abs().mean() < 0.01 * C_smoothed.diag().mean():
|
||||||
print(f"Warning: large C_diff: {C_diff.abs().mean()}, C diag mean: {C_smoothed.diag().mean()}")
|
logging.info(f"Warning: large C_diff: {C_diff.abs().mean()}, C diag mean: {C_smoothed.diag().mean()}")
|
||||||
|
|
||||||
return P
|
return P
|
||||||
|
|
||||||
@ -1545,7 +1549,7 @@ class Cain(Optimizer):
|
|||||||
var_factor = var_factor.mean(dim=dims, keepdim=True)
|
var_factor = var_factor.mean(dim=dims, keepdim=True)
|
||||||
|
|
||||||
#if random.random() < 0.01:
|
#if random.random() < 0.01:
|
||||||
# print(f"shape={p.shape}, var_factor={var_factor}")
|
# logging.info(f"shape={p.shape}, var_factor={var_factor}")
|
||||||
param_rms.mul_(var_factor.sqrt())
|
param_rms.mul_(var_factor.sqrt())
|
||||||
|
|
||||||
|
|
||||||
@ -1648,7 +1652,7 @@ class LRScheduler(object):
|
|||||||
def print_lr(self, is_verbose, group, lr):
|
def print_lr(self, is_verbose, group, lr):
|
||||||
"""Display the current learning rate."""
|
"""Display the current learning rate."""
|
||||||
if is_verbose:
|
if is_verbose:
|
||||||
print(
|
logging.info(
|
||||||
f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate"
|
f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate"
|
||||||
f" of group {group} to {lr:.4e}."
|
f" of group {group} to {lr:.4e}."
|
||||||
)
|
)
|
||||||
@ -1797,7 +1801,7 @@ class Eve(Optimizer):
|
|||||||
|
|
||||||
if random.random() < 0.0005:
|
if random.random() < 0.0005:
|
||||||
step = (exp_avg/denom) * step_size
|
step = (exp_avg/denom) * step_size
|
||||||
print(f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}")
|
logging.info(f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}")
|
||||||
|
|
||||||
|
|
||||||
return loss
|
return loss
|
||||||
@ -1863,8 +1867,8 @@ def _test_eden():
|
|||||||
scheduler.step_batch()
|
scheduler.step_batch()
|
||||||
optim.zero_grad()
|
optim.zero_grad()
|
||||||
|
|
||||||
print("last lr = ", scheduler.get_last_lr())
|
logging.info(f"last lr = {scheduler.get_last_lr()}")
|
||||||
print("state dict = ", scheduler.state_dict())
|
logging.info(f"state dict = {scheduler.state_dict()}")
|
||||||
|
|
||||||
|
|
||||||
def _test_eve_cain():
|
def _test_eve_cain():
|
||||||
@ -1873,7 +1877,7 @@ def _test_eve_cain():
|
|||||||
E = 100
|
E = 100
|
||||||
B = 4
|
B = 4
|
||||||
T = 2
|
T = 2
|
||||||
print("in test_eve_cain")
|
logging.info("in test_eve_cain")
|
||||||
device = torch.device('cuda')
|
device = torch.device('cuda')
|
||||||
dtype = torch.float32
|
dtype = torch.float32
|
||||||
|
|
||||||
@ -1921,7 +1925,8 @@ def _test_eve_cain():
|
|||||||
avg_loss = loss.item()
|
avg_loss = loss.item()
|
||||||
else:
|
else:
|
||||||
avg_loss = 0.95 * avg_loss + 0.05 * loss.item()
|
avg_loss = 0.95 * avg_loss + 0.05 * loss.item()
|
||||||
if n == 0 and epoch % 5 == 0:
|
#if n == 0 and epoch % 5 == 0:
|
||||||
|
if True:
|
||||||
norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item()
|
norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item()
|
||||||
norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item()
|
norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item()
|
||||||
norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item()
|
norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item()
|
||||||
@ -1931,7 +1936,7 @@ def _test_eve_cain():
|
|||||||
#scale2 = '%.2e' % (m[2].weight_scale.exp().item())
|
#scale2 = '%.2e' % (m[2].weight_scale.exp().item())
|
||||||
#scale2b = '%.2e' % (m[2].bias_scale.exp().item())
|
#scale2b = '%.2e' % (m[2].bias_scale.exp().item())
|
||||||
lr = scheduler.get_last_lr()[0]
|
lr = scheduler.get_last_lr()[0]
|
||||||
print(f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}, norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b}
|
logging.info(f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}, norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b}
|
||||||
loss.log().backward()
|
loss.log().backward()
|
||||||
optim.step()
|
optim.step()
|
||||||
optim.zero_grad()
|
optim.zero_grad()
|
||||||
@ -1940,24 +1945,24 @@ def _test_eve_cain():
|
|||||||
#diagnostic.print_diagnostics()
|
#diagnostic.print_diagnostics()
|
||||||
|
|
||||||
stop = timeit.default_timer()
|
stop = timeit.default_timer()
|
||||||
print(f"Iter={iter}, Time taken: {stop - start}")
|
logging.info(f"Iter={iter}, Time taken: {stop - start}")
|
||||||
|
|
||||||
print("last lr = ", scheduler.get_last_lr())
|
logging.info(f"last lr = {scheduler.get_last_lr()}")
|
||||||
#print("state dict = ", scheduler.state_dict())
|
#logging.info("state dict = ", scheduler.state_dict())
|
||||||
#print("optim state_dict = ", optim.state_dict())
|
#logging.info("optim state_dict = ", optim.state_dict())
|
||||||
print("input_magnitudes = ", input_magnitudes)
|
logging.info(f"input_magnitudes = {input_magnitudes}")
|
||||||
print("output_magnitudes = ", output_magnitudes)
|
logging.info(f"output_magnitudes = {output_magnitudes}")
|
||||||
|
|
||||||
def stddev(x):
|
def stddev(x):
|
||||||
return ((x-x.mean())**2).mean().sqrt()
|
return ((x-x.mean())**2).mean().sqrt()
|
||||||
print("Un-normalized input col magnitudes log-stddev: ", stddev((m[0].weight**2).sum(dim=0).sqrt().log()))
|
logging.info(f"Un-normalized input col magnitudes log-stddev: {stddev((m[0].weight**2).sum(dim=0).sqrt().log())}")
|
||||||
print("Normalized input col magnitudes log-stddev: ", stddev(((m[0].weight**2).sum(dim=0).sqrt() * input_magnitudes).log()))
|
logging.info(f"Normalized input col magnitudes log-stddev: {stddev(((m[0].weight**2).sum(dim=0).sqrt() * input_magnitudes).log())}")
|
||||||
|
|
||||||
print("Un-normalized 0-output row magnitudes log-stddev: ", stddev((m[0].weight**2).sum(dim=1).sqrt().log()))
|
logging.info(f"Un-normalized 0-output row magnitudes log-stddev: {stddev((m[0].weight**2).sum(dim=1).sqrt().log())}")
|
||||||
print("Un-normalized 2-input col magnitudes log-stddev: ", stddev((m[2].weight**2).sum(dim=0).sqrt().log()))
|
logging.info("Un-normalized 2-input col magnitudes log-stddev: {stddev((m[2].weight**2).sum(dim=0).sqrt().log())}")
|
||||||
print("Un-normalized 2-output row magnitudes log-stddev: ", stddev((m[2].weight**2).sum(dim=1).sqrt().log()))
|
logging.info("Un-normalized 2-output row magnitudes log-stddev: {stddev((m[2].weight**2).sum(dim=1).sqrt().log())}")
|
||||||
|
|
||||||
print("Normalized output row magnitudes log-stddev: ", stddev(((m[2].weight**2).sum(dim=1).sqrt() / output_magnitudes).log()))
|
logging.info("Normalized output row magnitudes log-stddev: {stddev(((m[2].weight**2).sum(dim=1).sqrt() / output_magnitudes).log())}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user