Another bug fix, regarding Q being transposed.

This commit is contained in:
Daniel Povey 2022-07-08 05:22:24 +08:00
parent ad2e698fc3
commit fb36712e6b

View File

@ -60,7 +60,7 @@ class LearnedGradient(Optimizer):
params, params,
lr=3e-02, lr=3e-02,
size_lr_scale=0.1, size_lr_scale=0.1,
meta_lr_scale=0.1, meta_lr_scale=0.2,
betas=(0.9, 0.98), betas=(0.9, 0.98),
eps=1.0e-08, eps=1.0e-08,
size_update_period=1, size_update_period=1,
@ -183,7 +183,7 @@ class LearnedGradient(Optimizer):
# re-estimate when we "rediagonalize" (this diagonalizes # re-estimate when we "rediagonalize" (this diagonalizes
# the gradient covariance). # the gradient covariance).
# If our parameter matrix M is of shape (..., size), # If our parameter matrix M is of shape (..., size),
# we'll view M as being M == torch.matmul(N, q) # we'll view M as being M == torch.matmul(N, Q)
# so p can be interpreted as having shape # so p can be interpreted as having shape
# (size, size), interpreted as being indexed [diagonalized_coordinate, canonical_coordinate] # (size, size), interpreted as being indexed [diagonalized_coordinate, canonical_coordinate]
# by "diagonalize" we mean that we diagonalize the gradient covariance. # by "diagonalize" we mean that we diagonalize the gradient covariance.
@ -255,6 +255,7 @@ class LearnedGradient(Optimizer):
if step % lr_est_period == 0: if step % lr_est_period == 0:
self._update_lrs(group, p, state) self._update_lrs(group, p, state)
if step % (lr_est_period * diagonalize_period) == 0: if step % (lr_est_period * diagonalize_period) == 0:
logging.info("Diagonalizing")
self._diagonalize_lrs(group, p, state) self._diagonalize_lrs(group, p, state)
self._zero_exp_avg_sq(state) self._zero_exp_avg_sq(state)
if True: if True:
@ -555,8 +556,8 @@ class LearnedGradient(Optimizer):
S_new *= 1.0 / S_new.mean() # normalize so mean is 1.0 S_new *= 1.0 / S_new.mean() # normalize so mean is 1.0
S_new.clamp_(lr_mat_min, lr_mat_max) # apply limits once more. S_new.clamp_(lr_mat_min, lr_mat_max) # apply limits once more.
# Reconstruct Q with the modified S. # Reconstruct Q with the modified S.
Q[:] = torch.matmul(U * S, V.t()) Q[:] = torch.matmul(U * S_new, V.t())
if random.random() < 0.1: if random.random() < 0.03:
subsample = max(1, S.numel() // 20) subsample = max(1, S.numel() // 20)
logging.info(f"shape={tuple(p.shape)}, dim={dim}, modifed S from {S[::subsample]} to {S_new[::subsample]}") logging.info(f"shape={tuple(p.shape)}, dim={dim}, modifed S from {S[::subsample]} to {S_new[::subsample]}")
@ -580,7 +581,7 @@ class LearnedGradient(Optimizer):
# index). # index).
grad_cov = state[f"grad_cov_{dim}"] grad_cov = state[f"grad_cov_{dim}"]
N_grad_cov = torch.matmul(Q, torch.matmul(grad_cov, Q)) N_grad_cov = torch.matmul(Q, torch.matmul(grad_cov, Q.t()))
N_grad_cov = N_grad_cov + N_grad_cov.t() # ensure symmetric N_grad_cov = N_grad_cov + N_grad_cov.t() # ensure symmetric
U, S, V = N_grad_cov.svd() U, S, V = N_grad_cov.svd()
@ -590,7 +591,7 @@ class LearnedGradient(Optimizer):
# Now, we can diagonalize N_grad_cov with: # Now, we can diagonalize N_grad_cov with:
# U^T N_grad_cov U == S. # U^T N_grad_cov U == S.
# N_grad_cov is a sum of N_grad^T N_grad. # N_grad_cov is a sum of N_grad^T N_grad.
# So U^T N_grad^T N_grad U is diagonal. # We know U^T N_grad_cov U is diagonal, so U^T N_grad^T N_grad U is diagonal.
# The linearized pseudo-loss can be written as tr(N_grad^T N_grad). # The linearized pseudo-loss can be written as tr(N_grad^T N_grad).
# This can be written as tr(U U^T N_grad^T N_grad), since U U^T == I, # This can be written as tr(U U^T N_grad^T N_grad), since U U^T == I,
# #
@ -601,7 +602,8 @@ class LearnedGradient(Optimizer):
# (hat_N means \hat{N}, or N with a hat on it). # (hat_N means \hat{N}, or N with a hat on it).
# So if we interpret hat_N = N U, the gradient covariance w.r.t. # So if we interpret hat_N = N U, the gradient covariance w.r.t.
# hat_N will be diagonalized. We also modify Q to hat_Q when # hat_N will be diagonalized. We also modify Q to hat_Q when
# we modify hat_N, to keep the product M = N Q = N U U^T Q = hat_N hat_Q # we modify hat_N, to keep the product M unchanged:
# M = N Q = N U U^T Q = hat_N hat_Q
# This can be done by setting # This can be done by setting
# hat_Q = U^T Q (eq.10) # hat_Q = U^T Q (eq.10)
# #
@ -814,8 +816,10 @@ class LearnedGradient(Optimizer):
if x.shape[dim] == 1: if x.shape[dim] == 1:
continue continue
Q = state[f"Q_{dim}"] Q = state[f"Q_{dim}"]
if not forward: if forward:
# Q is indexed [canonical_index, diagonalized_index] # Q is indexed [diagonalized_index, canonical_index]; in the forward
# direction we want to change canonical to diagonalized index, so have
# to transpose.
Q = Q.t() Q = Q.t()
# TODO: could possibly somehow force the output format to be unchanged. # TODO: could possibly somehow force the output format to be unchanged.
x = x.transpose(-1, dim) x = x.transpose(-1, dim)
@ -1156,7 +1160,7 @@ class LearnedGradient(Optimizer):
try: try:
P = 0.5 * (P + P.t()) P = 0.5 * (P + P.t())
_,s,_ = P.svd() _,s,_ = P.svd()
print(f"Min,max eig of P: {s.min()},{s.max()}") logging.info(f"Min,max eig of P: {s.min()},{s.max()}")
except: except:
pass pass
# testing... note, this is only true modulo "eps" # testing... note, this is only true modulo "eps"
@ -1167,7 +1171,7 @@ class LearnedGradient(Optimizer):
# Roundoff can cause significant differences, so use a fairly large # Roundoff can cause significant differences, so use a fairly large
# threshold of 0.001. We may increase this later or even remove the check. # threshold of 0.001. We may increase this later or even remove the check.
if not C_diff.abs().mean() < 0.01 * C_smoothed.diag().mean(): if not C_diff.abs().mean() < 0.01 * C_smoothed.diag().mean():
print(f"Warning: large C_diff: {C_diff.abs().mean()}, C diag mean: {C_smoothed.diag().mean()}") logging.info(f"Warning: large C_diff: {C_diff.abs().mean()}, C diag mean: {C_smoothed.diag().mean()}")
return P return P
@ -1545,7 +1549,7 @@ class Cain(Optimizer):
var_factor = var_factor.mean(dim=dims, keepdim=True) var_factor = var_factor.mean(dim=dims, keepdim=True)
#if random.random() < 0.01: #if random.random() < 0.01:
# print(f"shape={p.shape}, var_factor={var_factor}") # logging.info(f"shape={p.shape}, var_factor={var_factor}")
param_rms.mul_(var_factor.sqrt()) param_rms.mul_(var_factor.sqrt())
@ -1648,7 +1652,7 @@ class LRScheduler(object):
def print_lr(self, is_verbose, group, lr): def print_lr(self, is_verbose, group, lr):
"""Display the current learning rate.""" """Display the current learning rate."""
if is_verbose: if is_verbose:
print( logging.info(
f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate" f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate"
f" of group {group} to {lr:.4e}." f" of group {group} to {lr:.4e}."
) )
@ -1797,7 +1801,7 @@ class Eve(Optimizer):
if random.random() < 0.0005: if random.random() < 0.0005:
step = (exp_avg/denom) * step_size step = (exp_avg/denom) * step_size
print(f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}") logging.info(f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}")
return loss return loss
@ -1863,8 +1867,8 @@ def _test_eden():
scheduler.step_batch() scheduler.step_batch()
optim.zero_grad() optim.zero_grad()
print("last lr = ", scheduler.get_last_lr()) logging.info(f"last lr = {scheduler.get_last_lr()}")
print("state dict = ", scheduler.state_dict()) logging.info(f"state dict = {scheduler.state_dict()}")
def _test_eve_cain(): def _test_eve_cain():
@ -1873,7 +1877,7 @@ def _test_eve_cain():
E = 100 E = 100
B = 4 B = 4
T = 2 T = 2
print("in test_eve_cain") logging.info("in test_eve_cain")
device = torch.device('cuda') device = torch.device('cuda')
dtype = torch.float32 dtype = torch.float32
@ -1921,7 +1925,8 @@ def _test_eve_cain():
avg_loss = loss.item() avg_loss = loss.item()
else: else:
avg_loss = 0.95 * avg_loss + 0.05 * loss.item() avg_loss = 0.95 * avg_loss + 0.05 * loss.item()
if n == 0 and epoch % 5 == 0: #if n == 0 and epoch % 5 == 0:
if True:
norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item() norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item()
norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item() norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item()
norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item() norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item()
@ -1931,7 +1936,7 @@ def _test_eve_cain():
#scale2 = '%.2e' % (m[2].weight_scale.exp().item()) #scale2 = '%.2e' % (m[2].weight_scale.exp().item())
#scale2b = '%.2e' % (m[2].bias_scale.exp().item()) #scale2b = '%.2e' % (m[2].bias_scale.exp().item())
lr = scheduler.get_last_lr()[0] lr = scheduler.get_last_lr()[0]
print(f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}, norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b} logging.info(f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}, norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b}
loss.log().backward() loss.log().backward()
optim.step() optim.step()
optim.zero_grad() optim.zero_grad()
@ -1940,24 +1945,24 @@ def _test_eve_cain():
#diagnostic.print_diagnostics() #diagnostic.print_diagnostics()
stop = timeit.default_timer() stop = timeit.default_timer()
print(f"Iter={iter}, Time taken: {stop - start}") logging.info(f"Iter={iter}, Time taken: {stop - start}")
print("last lr = ", scheduler.get_last_lr()) logging.info(f"last lr = {scheduler.get_last_lr()}")
#print("state dict = ", scheduler.state_dict()) #logging.info("state dict = ", scheduler.state_dict())
#print("optim state_dict = ", optim.state_dict()) #logging.info("optim state_dict = ", optim.state_dict())
print("input_magnitudes = ", input_magnitudes) logging.info(f"input_magnitudes = {input_magnitudes}")
print("output_magnitudes = ", output_magnitudes) logging.info(f"output_magnitudes = {output_magnitudes}")
def stddev(x): def stddev(x):
return ((x-x.mean())**2).mean().sqrt() return ((x-x.mean())**2).mean().sqrt()
print("Un-normalized input col magnitudes log-stddev: ", stddev((m[0].weight**2).sum(dim=0).sqrt().log())) logging.info(f"Un-normalized input col magnitudes log-stddev: {stddev((m[0].weight**2).sum(dim=0).sqrt().log())}")
print("Normalized input col magnitudes log-stddev: ", stddev(((m[0].weight**2).sum(dim=0).sqrt() * input_magnitudes).log())) logging.info(f"Normalized input col magnitudes log-stddev: {stddev(((m[0].weight**2).sum(dim=0).sqrt() * input_magnitudes).log())}")
print("Un-normalized 0-output row magnitudes log-stddev: ", stddev((m[0].weight**2).sum(dim=1).sqrt().log())) logging.info(f"Un-normalized 0-output row magnitudes log-stddev: {stddev((m[0].weight**2).sum(dim=1).sqrt().log())}")
print("Un-normalized 2-input col magnitudes log-stddev: ", stddev((m[2].weight**2).sum(dim=0).sqrt().log())) logging.info("Un-normalized 2-input col magnitudes log-stddev: {stddev((m[2].weight**2).sum(dim=0).sqrt().log())}")
print("Un-normalized 2-output row magnitudes log-stddev: ", stddev((m[2].weight**2).sum(dim=1).sqrt().log())) logging.info("Un-normalized 2-output row magnitudes log-stddev: {stddev((m[2].weight**2).sum(dim=1).sqrt().log())}")
print("Normalized output row magnitudes log-stddev: ", stddev(((m[2].weight**2).sum(dim=1).sqrt() / output_magnitudes).log())) logging.info("Normalized output row magnitudes log-stddev: {stddev(((m[2].weight**2).sum(dim=1).sqrt() / output_magnitudes).log())}")