Some bug fixes; seems to be working.

Daniel Povey 2022-06-17 12:04:21 +08:00
parent 827a37c7fc
commit 9e92d13a33


@@ -94,10 +94,6 @@ class NeutralGradient(Optimizer):
             raise ValueError(
                 "Invalid beta parameter at index 1: {}".format(betas[1])
             )
-        if not 0.0 < betas[2] < 1.0:
-            raise ValueError(
-                "Invalid beta parameter at index 2: {}".format(betas[2])
-            )
         if not 0 < scale_speed < 1.0:
             raise ValueError("Invalid scale_speed value: {}".format(scale_speed))
         if not 0.0 <= grad_eps < 0.1:
@@ -199,10 +195,10 @@ class NeutralGradient(Optimizer):
                 if not is_one_axis:
                     # each parameter has a different random time, modulo estimate_period,
-                    # to re-estimate the projections. "steps_this_period" will
-                    # be reset to 0 when it reaches esetimate_period.
-                    state["steps_this_period"] = random.random(0,
+                    # to re-estimate the projections. "step_within_period" will increase by
+                    # 1 on each step, and will be reset to 0 when it reaches estimate_period.
+                    state["step_within_period"] = random.randint(0,
                                                                  estimate_period-stats_steps)
                 used_scale = False
                 for dim in range(p.ndim):
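Note: the old line was a genuine bug, not just a rename: random.random() takes no arguments, so random.random(0, estimate_period - stats_steps) raised a TypeError as soon as a multi-axis parameter was initialized; random.randint(0, n) is the correct call. The point of the random start is to stagger the expensive projection re-estimation across parameters. A minimal standalone sketch (the period values are made up for illustration):

    import random

    estimate_period, stats_steps = 2000, 200  # assumed example values
    # Each parameter starts at a random offset within its period, so the
    # costly re-estimation steps are spread over time instead of all
    # landing on the same optimizer step.
    offsets = [random.randint(0, estimate_period - stats_steps) for _ in range(4)]
    print(offsets)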
@@ -212,7 +208,7 @@ class NeutralGradient(Optimizer):
                         continue
                     elif size > max_fullcov_size:
                         # diagonal only...
-                        state[f"proj_{dim}"] = torch.ones(size, **kwargs) * param_rms
+                        state[f"proj_{dim}"] = torch.ones(size, **kwargs)
                     else:
                         state[f"proj_{dim}"] = torch.eye(size, **kwargs)
                 if not used_scale:
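Note: dropping the `* param_rms` factor means the diagonal projection now starts as plain ones, the pure-scale analogue of the identity used in the full-covariance branch; per-element scaling is applied later when the projections are estimated. The size threshold trades accuracy for memory: a full projection is size x size while the diagonal one is just size. A small sketch of the cost difference (threshold value taken from the test change further down):

    import torch

    max_fullcov_size = 1000
    for size in (512, 4096):
        if size > max_fullcov_size:
            proj = torch.ones(size)   # diagonal-only: O(size) memory
        else:
            proj = torch.eye(size)    # full covariance: O(size^2) memory
        print(size, proj.numel())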
@@ -279,12 +275,12 @@ class NeutralGradient(Optimizer):
                 else:
                     # The full update.
                     step_within_period = state["step_within_period"]
-                    if step_within_period == estimate_step:
+                    if step_within_period == estimate_period:
                         self._estimate_projections(p, state, param_eps, param_rel_eps, param_pow)
                         state["step_within_period"] = 0
                     if step_within_period >= estimate_period - stats_steps:
-                        self._store_grad_stats(grad, state)
+                        self._store_grad_stats(grad, state, max_fullcov_size)
                     cur_grad = grad
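Note: the first fix corrects an undefined name (there is no estimate_step; the counter is compared against estimate_period), and the second passes max_fullcov_size through, presumably so _store_grad_stats can tell diagonal dims from full-covariance ones. The schedule these two lines implement, as a standalone sketch with stand-in calls:

    estimate_period, stats_steps = 2000, 200  # assumed values, as above
    step_within_period = 1798
    for _ in range(5):
        if step_within_period == estimate_period:
            # _estimate_projections(...) would run here, consuming the stats,
            # and the counter resets
            step_within_period = 0
        if step_within_period >= estimate_period - stats_steps:
            pass  # _store_grad_stats(...) accumulates gradient covariance
        step_within_period += 1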
@@ -299,6 +295,10 @@ class NeutralGradient(Optimizer):
                     # stats when we changed the co-ordinates.
                     bias_correction2 = 1 - beta2 ** (min(step, step_within_period) + 1)
+                    cur_grad = cur_grad / (exp_avg_sq.sqrt() + grad_eps)
+                    if bias_correction2 < 0.99:
+                        cur_grad *= bias_correction2
                     cur_grad = self._change_coordinates(cur_grad, state, forward=False)
                     if random.random() < 0.004:
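Note: the inserted lines fold the Adam-style denominator into cur_grad before rotating back to the original coordinates. For reference, the textbook correction divides the second-moment EMA by bias_correction2 before the square root; this commit instead multiplies the update by bias_correction2 while it is still well below 1, which damps the very first steps more strongly. A sketch of the standard form for comparison (hyperparameter values assumed):

    import torch

    beta2, grad_eps = 0.98, 1e-8
    exp_avg_sq = torch.zeros(4)
    grad = torch.randn(4)
    for t in range(3):
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        bias_correction2 = 1 - beta2 ** (t + 1)
        # divide the EMA by its bias correction, then take the sqrt
        denom = (exp_avg_sq / bias_correction2).sqrt() + grad_eps
        update = grad / denom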
@@ -316,7 +316,7 @@ class NeutralGradient(Optimizer):
                         # is from gradient descent.
                         prod = (grad*cur_grad).mean()
                         cos_angle = prod / ((grad**2).mean() * (cur_grad**2).mean()).sqrt()
-                        if random.random() < 0.01 or cos_angle < 0.01:
+                        if random.random() < 0.04 or cos_angle < 0.01:
                             print(f"cos_angle = {cos_angle}, shape={grad.shape}")
                 alpha = -lr * (1-beta1)
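Note: only the logging rate changes here (1% of steps to 4%). The quantity itself is the cosine similarity between the raw gradient and the preconditioned update; values near zero or negative would mean the step has almost no descent component. Pulled out as a standalone helper for clarity (hypothetical name, same math as the diff):

    import torch

    def grad_cos_angle(grad: torch.Tensor, cur_grad: torch.Tensor) -> torch.Tensor:
        # cos(angle) = <g, g'> / (|g| |g'|); the .mean() calls introduce
        # matching 1/numel factors that cancel, so this equals the diff's expression
        prod = (grad * cur_grad).mean()
        return prod / ((grad**2).mean() * (cur_grad**2).mean()).sqrt()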
@@ -417,22 +417,22 @@ class NeutralGradient(Optimizer):
             size = p.shape[dim]
             if size == 1:
                 continue
-            count = state[f"grad_cov_count_{dim}"]
-            assert count != 0  # we can find a way to deal with this case later,
-                               # if it happens.
-            grad_cov = state[f"grad_cov_{dim}"] / count
-            del state[f"grad_cov_{dim}"]  # save memory
             proj = state[f"proj_{dim}"]
             if proj.ndim == 2:
                 # This dimension gets the full-covariance treatment.
+                count = state[f"grad_cov_count_{dim}"]
+                assert count != 0
+                grad_cov = state[f"grad_cov_{dim}"] / count
+                del state[f"grad_cov_{dim}"]  # save memory
                 self._randomize_lowrank_cov(grad_cov, count)
                 param_cov = self._get_param_cov(p, dim)
                 # P is the SPD matrix such that P G P^T == C^{param_pow},
                 # where G == grad_cov and C == param_cov.
-                P = self._estimate_proj(grad_cov_smoothed,
-                                        param_cov_smoothed,
+                P = self._estimate_proj(grad_cov,
+                                        param_cov,
                                         param_pow)
                 # The only thing we want from P is the basis that diagonalizes
                 # it, i.e. if we do the symmetric svd P = U S U^T, we can
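Note: this refactor moves the grad-covariance bookkeeping inside the `proj.ndim == 2` branch (diagonal dims no longer accumulate a full covariance matrix) and fixes two undefined names: grad_cov_smoothed and param_cov_smoothed never existed in this scope; the new code passes grad_cov and param_cov directly, with _randomize_lowrank_cov supplying the regularization just above. On the "symmetric svd" in the comment: for an SPD matrix the eigendecomposition and the SVD coincide, so the diagonalizing basis U can come from torch.linalg.eigh. A sketch (the matrix here is a stand-in, not the optimizer's P):

    import torch

    size = 8
    A = torch.randn(size, size)
    P = A @ A.t() + size * torch.eye(size)  # stand-in SPD matrix
    S, U = torch.linalg.eigh(P)             # P = U diag(S) U^T
    assert torch.allclose(U @ torch.diag(S) @ U.t(), P, atol=1e-4)
    # U is the orthogonal basis that diagonalizes P; per the comment,
    # only U is kept and the eigenvalues S are discarded.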
@@ -470,13 +470,12 @@ class NeutralGradient(Optimizer):
         # Rotate `p` to the diagonalized basis. At this point there is no scaling,
         # just an orthogonal transformation; this function is going to add the
         # scaling to state[f"proj_{dim}"]
-        rotated_p = self._change_coordinates(rotated_p, state, forward=True)
+        rotated_p = self._change_coordinates(p, state, forward=True)
         params_sq = rotated_p**2
         params_sq.add_(param_eps*param_eps +
                        param_rel_eps*param_rel_eps * params_sq.mean())
-        param_var = torch.ones_like(p)
         for _ in range(3):
             # Iterate 3 times, this should be enough to converge.
@@ -486,13 +485,16 @@ class NeutralGradient(Optimizer):
                 continue
             # p will have at least one non-trivial dim.
             other_dims = [ i for i in range(p.ndim) if i != dim ]
-            this_var = param_var.mean(dim=other_dims, keepdim=True)
-            param_var = param_var / this_var
+            # Compute diagonal variance along this dimension
+            # (this is after normalizing previous dimensions' variance)
+            this_var = params_sq.mean(dim=other_dims, keepdim=True)
+            params_sq = params_sq / this_var
             this_var = this_var.reshape(size)
             this_scale = (this_var ** (param_pow * 0.5)).reshape(size)
             proj = state[f"proj_{dim}"]
+            #print(f"iter={_}, dim={dim}, this_scale = {this_scale}")
             if proj.ndim == 1:
                 proj *= this_scale
             else:
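Note: the old code read from param_var, which the previous hunk shows was initialized to ones and never updated with actual parameter statistics, so this_var was always 1 and every this_scale came out as 1. The fix normalizes params_sq in place, dimension by dimension, which factorizes the squared-parameter tensor into one variance vector per dim. The same fixed-point iteration as a standalone sketch:

    import torch

    p = torch.randn(16, 32) * torch.linspace(0.1, 2.0, 32)  # toy parameters
    params_sq = p**2
    for _ in range(3):  # a few sweeps are enough to converge
        for dim in range(p.ndim):
            other_dims = [i for i in range(p.ndim) if i != dim]
            this_var = params_sq.mean(dim=other_dims, keepdim=True)
            params_sq = params_sq / this_var  # divide out this dim's variance
    # params_sq is now roughly flat; the per-dim factors have absorbed
    # the row/column variance structure of p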
@@ -591,6 +593,7 @@ class NeutralGradient(Optimizer):
         # To be confident in our estimate of the covariance, we want `rank` (which
         # actually represents the number of outer products added together)
         # to be at least `required_rank`; otherwise we'll add a random component.
+        size = cov.shape[0]
         required_rank = int(size * 1.2) + 1
         if rank < required_rank:
@@ -598,7 +601,7 @@ class NeutralGradient(Optimizer):
             # params are exactly zero or we sample a zero matrix; the exact value is not
             # going to affect the performance.
             eps = 1.0e-20
-            param_cov_scale = param_cov.diag().mean() + 1.0e-20
+            cov_scale = cov.diag().mean() + 1.0e-20
             # The following formula assumes that the "missing" outer products are
             # expected to be smaller than the ones that we do have, i.e. 0.2 the size.
@@ -606,12 +609,12 @@ class NeutralGradient(Optimizer):
             # i.e. how big it is relative to the trace of the existing matrix.
             missing_rank = (required_rank - rank)
             rand_scale = 0.2 * missing_rank / rank
-            R = torch.randn(size, size)
-            R = torch.matmul(C, C.t())  # positive semidefinite random matrix
+            R = torch.randn(size, size, device=cov.device, dtype=cov.dtype)
+            R = torch.matmul(R, R.t())  # positive semidefinite random matrix
             R_scale = R.diag().mean() + 1.0e-20
             if random.random() < 0.02:
                 print(f"required_rank={required_rank}, rank={rank}, size={size}, R_scale={R_scale}")
-            param_cov.add_(R, alpha=rand_scale * param_cov_scale / R_scale)
+            cov.add_(R, alpha=rand_scale * cov_scale / R_scale)

    def _multiply_on_dim(self, x: Tensor, proj: Tensor, dim: int,
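Note: two bugs in one line here: the old code built a random R and then threw it away by reading C (apparently undefined in this scope), and the old randn had no device/dtype, which would fail for CUDA or non-float32 parameters. X @ X.T is positive semidefinite by construction (v^T X X^T v = |X^T v|^2 >= 0), and the param_cov -> cov renames match the names the rest of the method actually uses (see `size = cov.shape[0]` above). Worked numbers for the scale formula: with size=100, required_rank = int(100*1.2)+1 = 121; if only rank=80 outer products were accumulated, missing_rank=41 and rand_scale = 0.2*41/80, roughly 0.10, so the random part contributes about 10% of the existing trace. A sketch:

    import torch

    size, rank = 100, 80                      # assumed accumulated rank
    cov = torch.eye(size) * 0.5               # stand-in covariance
    required_rank = int(size * 1.2) + 1       # 121
    rand_scale = 0.2 * (required_rank - rank) / rank   # ~0.1025
    R = torch.randn(size, size, device=cov.device, dtype=cov.dtype)
    R = R @ R.t()                             # positive semidefinite
    cov.add_(R, alpha=rand_scale * cov.diag().mean() / R.diag().mean())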
@@ -1356,8 +1359,10 @@ def _test_eve_cain():
         if iter == 0: optim = Eve(m.parameters(), lr=0.003)
         elif iter == 1: optim = Cain(m.parameters(), lr=0.03)
-        elif iter == 2: optim = NeutralGradient(m.parameters(), lr=0.03, max_size=10)
-        elif iter == 3: optim = NeutralGradient(m.parameters(), lr=0.03, max_size=1000)
+        elif iter == 2: optim = NeutralGradient(m.parameters(), lr=0.03, max_fullcov_size=10,
+                                                estimate_period=500, stats_steps=100)
+        elif iter == 3: optim = NeutralGradient(m.parameters(), lr=0.03, max_fullcov_size=1000,
+                                                estimate_period=500, stats_steps=100)
         scheduler = Eden(optim, lr_batches=200, lr_epochs=10, verbose=False)
         start = timeit.default_timer()
@@ -1375,7 +1380,10 @@ def _test_eve_cain():
         for n, (x,y) in enumerate(train_pairs):
             y_out = m(x)
             loss = ((y_out - y)**2).mean() * 100.0
-            avg_loss = 0.95 * avg_loss + 0.05 * loss.item()
+            if epoch == 0 and n == 0:
+                avg_loss = loss.item()
+            else:
+                avg_loss = 0.95 * avg_loss + 0.05 * loss.item()
             if n == 0 and epoch % 10 == 0:
                 norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item()
                 norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item()
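Note: previously avg_loss started from whatever initial value it was given (presumably 0.0, which is not shown in this hunk), so the printed moving average was biased low for the first few dozen batches; seeding it with the first observed loss removes that startup bias. The same pattern as a small reusable helper (illustrative, not from the repo):

    class RunningLoss:
        # exponential moving average seeded with the first observation
        def __init__(self, decay: float = 0.95):
            self.decay = decay
            self.value = None

        def update(self, loss: float) -> float:
            if self.value is None:
                self.value = loss
            else:
                self.value = self.decay * self.value + (1 - self.decay) * loss
            return self.value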