Another bug fix, regarding Q being transposed.

Daniel Povey 2022-07-08 05:22:24 +08:00
parent ad2e698fc3
commit fb36712e6b


@@ -60,7 +60,7 @@ class LearnedGradient(Optimizer):
params,
lr=3e-02,
size_lr_scale=0.1,
meta_lr_scale=0.1,
meta_lr_scale=0.2,
betas=(0.9, 0.98),
eps=1.0e-08,
size_update_period=1,
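
For context, these are keyword defaults of the LearnedGradient constructor. A hedged sketch of passing them explicitly (the import path and the model are stand-ins, and only the arguments visible in this hunk are shown):

from torch import nn
from optim import LearnedGradient  # assumed import path for the class defined in this file

model = nn.Linear(256, 256)        # stand-in model
optimizer = LearnedGradient(
    model.parameters(),
    lr=3e-02,
    size_lr_scale=0.1,
    meta_lr_scale=0.2,             # default raised from 0.1 by this commit
    betas=(0.9, 0.98),
    eps=1.0e-08,
    size_update_period=1,
)
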
@@ -183,7 +183,7 @@ class LearnedGradient(Optimizer):
# re-estimate when we "rediagonalize" (this diagonalizes
# the gradient covariance).
# If our parameter matrix M is of shape (..., size),
# we'll view M as being M == torch.matmul(N, q)
# we'll view M as being M == torch.matmul(N, Q)
# so Q can be interpreted as having shape
# (size, size), interpreted as being indexed [diagonalized_coordinate, canonical_coordinate]
# by "diagonalize" we mean that we diagonalize the gradient covariance.
@@ -255,6 +255,7 @@ class LearnedGradient(Optimizer):
if step % lr_est_period == 0:
self._update_lrs(group, p, state)
if step % (lr_est_period * diagonalize_period) == 0:
logging.info("Diagonalizing")
self._diagonalize_lrs(group, p, state)
self._zero_exp_avg_sq(state)
if True:
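
A tiny sketch, with made-up period values, of when the two checks above fire: learning-rate estimation every lr_est_period steps, and the (now logged) re-diagonalization every lr_est_period * diagonalize_period steps:

lr_est_period = 10         # hypothetical values
diagonalize_period = 4
for step in range(0, 81):
    if step % lr_est_period == 0:
        pass               # _update_lrs would run here (steps 0, 10, 20, ...)
    if step % (lr_est_period * diagonalize_period) == 0:
        print(f"step {step}: 'Diagonalizing' would be logged")   # steps 0, 40, 80
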
@@ -555,8 +556,8 @@ class LearnedGradient(Optimizer):
S_new *= 1.0 / S_new.mean() # normalize so mean is 1.0
S_new.clamp_(lr_mat_min, lr_mat_max) # apply limits once more.
# Reconstruct Q with the modified S.
Q[:] = torch.matmul(U * S, V.t())
if random.random() < 0.1:
Q[:] = torch.matmul(U * S_new, V.t())
if random.random() < 0.03:
subsample = max(1, S.numel() // 20)
logging.info(f"shape={tuple(p.shape)}, dim={dim}, modifed S from {S[::subsample]} to {S_new[::subsample]}")
@@ -580,7 +581,7 @@ class LearnedGradient(Optimizer):
# index).
grad_cov = state[f"grad_cov_{dim}"]
N_grad_cov = torch.matmul(Q, torch.matmul(grad_cov, Q))
N_grad_cov = torch.matmul(Q, torch.matmul(grad_cov, Q.t()))
N_grad_cov = N_grad_cov + N_grad_cov.t() # ensure symmetric
U, S, V = N_grad_cov.svd()
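
This is the transposition bug named in the commit message: with M = N Q, the gradient w.r.t. N is N_grad = M_grad Q^T, so a covariance accumulated in M-coordinates must be conjugated as Q grad_cov Q^T rather than Q grad_cov Q. A quick numerical check of that identity (shapes are made up):

import torch

size, batch = 5, 7
Q = torch.randn(size, size)
M_grad = torch.randn(batch, size)           # gradients w.r.t. M, stacked over a batch
grad_cov = M_grad.t() @ M_grad              # covariance accumulated in M-coordinates
N_grad = M_grad @ Q.t()                     # since M = N Q, dLoss/dN = dLoss/dM @ Q^T
assert torch.allclose(N_grad.t() @ N_grad,  # covariance in N-coordinates ...
                      Q @ grad_cov @ Q.t(), # ... equals Q grad_cov Q^T
                      atol=1e-4)
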
@@ -590,7 +591,7 @@ class LearnedGradient(Optimizer):
# Now, we can diagonalize N_grad_cov with:
# U^T N_grad_cov U == S.
# N_grad_cov is a sum of N_grad^T N_grad.
# So U^T N_grad^T N_grad U is diagonal.
# We know U^T N_grad_cov U is diagonal, so U^T N_grad^T N_grad U is diagonal.
# The linearized pseudo-loss can be written as tr(N_grad^T N_grad).
# This can be written as tr(U U^T N_grad^T N_grad), since U U^T == I,
#
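
As a concrete check of "U^T N_grad_cov U == S" for a symmetric positive semi-definite matrix (the SVD of such a matrix is an eigendecomposition, so conjugating by U diagonalizes it):

import torch

size = 6
A = torch.randn(10, size)
N_grad_cov = A.t() @ A                     # symmetric PSD, like a sum of N_grad^T N_grad
U, S, V = N_grad_cov.svd()                 # for symmetric PSD input, U and V coincide (numerically)
D = U.t() @ N_grad_cov @ U                 # should be numerically diagonal, with diag == S
off_diag = D - torch.diag(D.diag())
assert off_diag.abs().max() < 1e-4 * S.max()
assert torch.allclose(D.diag(), S, rtol=1e-3, atol=1e-4)
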
@@ -601,7 +602,8 @@ class LearnedGradient(Optimizer):
# (hat_N means \hat{N}, or N with a hat on it).
# So if we interpret hat_N = N U, the gradient covariance w.r.t.
# hat_N will be diagonalized. We also modify Q to hat_Q when
# we modify hat_N, to keep the product M = N Q = N U U^T Q = hat_N hat_Q
# we modify hat_N, to keep the product M unchanged:
# M = N Q = N U U^T Q = hat_N hat_Q
# This can be done by setting
# hat_Q = U^T Q (eq.10)
#
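
A quick numerical confirmation of the argument above: for any orthogonal U, replacing N with N U and Q with U^T Q (eq.10) leaves the product M = N Q unchanged.

import torch

size = 6
N = torch.randn(4, size)
Q = torch.randn(size, size)
U, _ = torch.linalg.qr(torch.randn(size, size))          # any orthogonal U
hat_N = N @ U
hat_Q = U.t() @ Q                                         # eq.10
assert torch.allclose(N @ Q, hat_N @ hat_Q, atol=1e-4)    # M = N Q = hat_N hat_Q
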
@@ -814,8 +816,10 @@ class LearnedGradient(Optimizer):
if x.shape[dim] == 1:
continue
Q = state[f"Q_{dim}"]
if not forward:
# Q is indexed [canonical_index, diagonalized_index]
if forward:
# Q is indexed [diagonalized_index, canonical_index]; in the forward
# direction we want to change canonical to diagonalized index, so have
# to transpose.
Q = Q.t()
# TODO: could possibly somehow force the output format to be unchanged.
x = x.transpose(-1, dim)
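
A simplified, hypothetical helper (not the optimizer's actual method) illustrating the indexing convention the new comment spells out: Q is [diagonalized_index, canonical_index], so going canonical -> diagonalized along some dim of x multiplies by Q^T, and the reverse direction multiplies by Q:

import torch

def apply_Q(x, Q, dim, forward):
    # Q is indexed [diagonalized_index, canonical_index].
    # forward: canonical -> diagonalized, which requires the transpose.
    Qm = Q.t() if forward else Q
    x = x.transpose(-1, dim)
    x = torch.matmul(x, Qm)
    return x.transpose(-1, dim)

size = 5
x = torch.randn(3, size, 7)                        # canonical coordinates along dim=1
Q, _ = torch.linalg.qr(torch.randn(size, size))    # orthogonal, [diagonalized, canonical]
y = apply_Q(x, Q, dim=1, forward=True)             # now indexed by diagonalized coordinate
x2 = apply_Q(y, Q, dim=1, forward=False)           # back to canonical (Q orthogonal)
assert torch.allclose(x, x2, atol=1e-5)
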
@@ -1156,7 +1160,7 @@ class LearnedGradient(Optimizer):
try:
P = 0.5 * (P + P.t())
_,s,_ = P.svd()
print(f"Min,max eig of P: {s.min()},{s.max()}")
logging.info(f"Min,max eig of P: {s.min()},{s.max()}")
except:
pass
# testing... note, this is only true modulo "eps"
@@ -1167,7 +1171,7 @@ class LearnedGradient(Optimizer):
# Roundoff can cause significant differences, so use a fairly large
# threshold of 0.01. We may increase this later or even remove the check.
if not C_diff.abs().mean() < 0.01 * C_smoothed.diag().mean():
print(f"Warning: large C_diff: {C_diff.abs().mean()}, C diag mean: {C_smoothed.diag().mean()}")
logging.info(f"Warning: large C_diff: {C_diff.abs().mean()}, C diag mean: {C_smoothed.diag().mean()}")
return P
@@ -1545,7 +1549,7 @@ class Cain(Optimizer):
var_factor = var_factor.mean(dim=dims, keepdim=True)
#if random.random() < 0.01:
# print(f"shape={p.shape}, var_factor={var_factor}")
# logging.info(f"shape={p.shape}, var_factor={var_factor}")
param_rms.mul_(var_factor.sqrt())
@@ -1648,7 +1652,7 @@ class LRScheduler(object):
def print_lr(self, is_verbose, group, lr):
"""Display the current learning rate."""
if is_verbose:
print(
logging.info(
f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate"
f" of group {group} to {lr:.4e}."
)
@@ -1797,7 +1801,7 @@ class Eve(Optimizer):
if random.random() < 0.0005:
step = (exp_avg/denom) * step_size
print(f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}")
logging.info(f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}")
return loss
@@ -1863,8 +1867,8 @@ def _test_eden():
scheduler.step_batch()
optim.zero_grad()
print("last lr = ", scheduler.get_last_lr())
print("state dict = ", scheduler.state_dict())
logging.info(f"last lr = {scheduler.get_last_lr()}")
logging.info(f"state dict = {scheduler.state_dict()}")
def _test_eve_cain():
@@ -1873,7 +1877,7 @@ def _test_eve_cain():
E = 100
B = 4
T = 2
print("in test_eve_cain")
logging.info("in test_eve_cain")
device = torch.device('cuda')
dtype = torch.float32
@@ -1921,7 +1925,8 @@ def _test_eve_cain():
avg_loss = loss.item()
else:
avg_loss = 0.95 * avg_loss + 0.05 * loss.item()
if n == 0 and epoch % 5 == 0:
#if n == 0 and epoch % 5 == 0:
if True:
norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item()
norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item()
norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item()
@@ -1931,7 +1936,7 @@ def _test_eve_cain():
#scale2 = '%.2e' % (m[2].weight_scale.exp().item())
#scale2b = '%.2e' % (m[2].bias_scale.exp().item())
lr = scheduler.get_last_lr()[0]
print(f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}, norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b}
logging.info(f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}, norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b}
loss.log().backward()
optim.step()
optim.zero_grad()
@@ -1940,24 +1945,24 @@ def _test_eve_cain():
#diagnostic.print_diagnostics()
stop = timeit.default_timer()
print(f"Iter={iter}, Time taken: {stop - start}")
logging.info(f"Iter={iter}, Time taken: {stop - start}")
print("last lr = ", scheduler.get_last_lr())
#print("state dict = ", scheduler.state_dict())
#print("optim state_dict = ", optim.state_dict())
print("input_magnitudes = ", input_magnitudes)
print("output_magnitudes = ", output_magnitudes)
logging.info(f"last lr = {scheduler.get_last_lr()}")
#logging.info("state dict = ", scheduler.state_dict())
#logging.info("optim state_dict = ", optim.state_dict())
logging.info(f"input_magnitudes = {input_magnitudes}")
logging.info(f"output_magnitudes = {output_magnitudes}")
def stddev(x):
return ((x-x.mean())**2).mean().sqrt()
print("Un-normalized input col magnitudes log-stddev: ", stddev((m[0].weight**2).sum(dim=0).sqrt().log()))
print("Normalized input col magnitudes log-stddev: ", stddev(((m[0].weight**2).sum(dim=0).sqrt() * input_magnitudes).log()))
logging.info(f"Un-normalized input col magnitudes log-stddev: {stddev((m[0].weight**2).sum(dim=0).sqrt().log())}")
logging.info(f"Normalized input col magnitudes log-stddev: {stddev(((m[0].weight**2).sum(dim=0).sqrt() * input_magnitudes).log())}")
print("Un-normalized 0-output row magnitudes log-stddev: ", stddev((m[0].weight**2).sum(dim=1).sqrt().log()))
print("Un-normalized 2-input col magnitudes log-stddev: ", stddev((m[2].weight**2).sum(dim=0).sqrt().log()))
print("Un-normalized 2-output row magnitudes log-stddev: ", stddev((m[2].weight**2).sum(dim=1).sqrt().log()))
logging.info(f"Un-normalized 0-output row magnitudes log-stddev: {stddev((m[0].weight**2).sum(dim=1).sqrt().log())}")
logging.info("Un-normalized 2-input col magnitudes log-stddev: {stddev((m[2].weight**2).sum(dim=0).sqrt().log())}")
logging.info("Un-normalized 2-output row magnitudes log-stddev: {stddev((m[2].weight**2).sum(dim=1).sqrt().log())}")
print("Normalized output row magnitudes log-stddev: ", stddev(((m[2].weight**2).sum(dim=1).sqrt() / output_magnitudes).log()))
logging.info("Normalized output row magnitudes log-stddev: {stddev(((m[2].weight**2).sum(dim=1).sqrt() / output_magnitudes).log())}")