Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-18 21:44:18 +00:00)
Some bug fixes.. seems to be working.
This commit is contained in:
parent
827a37c7fc
commit
9e92d13a33
@@ -94,10 +94,6 @@ class NeutralGradient(Optimizer):
             raise ValueError(
                 "Invalid beta parameter at index 1: {}".format(betas[1])
             )
-        if not 0.0 < betas[2] < 1.0:
-            raise ValueError(
-                "Invalid beta parameter at index 2: {}".format(betas[2])
-            )
         if not 0 < scale_speed < 1.0:
             raise ValueError("Invalid scale_speed value: {}".format(scale_speed))
         if not 0.0 <= grad_eps < 0.1:
@@ -199,10 +195,10 @@ class NeutralGradient(Optimizer):

                 if not is_one_axis:
                     # each parameter has a different random time, modulo estimate_period,
-                    # to re-estimate the projections. "steps_this_period" will
-                    # be reset to 0 when it reaches esetimate_period.
-                    state["steps_this_period"] = random.random(0,
-                                                               estimate_period-stats_steps)
+                    # to re-estimate the projections. "step_within_period" will increase by
+                    # 1 on each step, and will be reset to 0 when it reaches estimate_period.
+                    state["step_within_period"] = random.randint(0,
+                                                                 estimate_period-stats_steps)

                 used_scale = False
                 for dim in range(p.ndim):
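
Note on the hunk above: random.random() takes no arguments, so the removed line would have raised a TypeError at runtime; random.randint(0, estimate_period - stats_steps) is the fix, and it also gives each parameter a random phase so that projections are not all re-estimated on the same step. A small sketch of that staggering, using only the names that appear in the diff (the helper itself is hypothetical):

import random

def init_phase(estimate_period: int, stats_steps: int) -> int:
    # Each parameter starts at a random offset within its period, so different
    # parameters pay the projection re-estimation cost on different steps.
    return random.randint(0, estimate_period - stats_steps)

# e.g. with estimate_period=500 and stats_steps=100 the offset is uniform over
# 0..400, so the first re-estimation lands between 100 and 500 steps in.
print([init_phase(500, 100) for _ in range(4)])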
@@ -212,7 +208,7 @@ class NeutralGradient(Optimizer):
                         continue
                     elif size > max_fullcov_size:
                         # diagonal only...
-                        state[f"proj_{dim}"] = torch.ones(size, **kwargs) * param_rms
+                        state[f"proj_{dim}"] = torch.ones(size, **kwargs)
                     else:
                         state[f"proj_{dim}"] = torch.eye(size, **kwargs)
                     if not used_scale:
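
The removal of the `* param_rms` factor above only changes the initial scale of the diagonal projections; the shape logic is untouched: dimensions larger than max_fullcov_size get a vector (diagonal-only) projection, smaller ones get a full identity matrix. A rough standalone sketch of that initialization, not the optimizer's actual code path:

import torch

def init_projections(p: torch.Tensor, max_fullcov_size: int) -> dict:
    # One projection per tensor dimension: identity (full covariance) for small
    # dims, a vector of ones (diagonal only) for large dims, none for size-1 dims.
    state = {}
    kwargs = {"device": p.device, "dtype": p.dtype}
    for dim in range(p.ndim):
        size = p.shape[dim]
        if size == 1:
            continue
        elif size > max_fullcov_size:
            state[f"proj_{dim}"] = torch.ones(size, **kwargs)   # diagonal only
        else:
            state[f"proj_{dim}"] = torch.eye(size, **kwargs)    # full covariance
    return state

p = torch.zeros(512, 3)
print({k: v.shape for k, v in init_projections(p, max_fullcov_size=10).items()})
# proj_0 is a length-512 vector, proj_1 is a 3x3 identity.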
@@ -279,12 +275,12 @@ class NeutralGradient(Optimizer):
                 else:
                     # The full update.
                     step_within_period = state["step_within_period"]
-                    if step_within_period == estimate_step:
+                    if step_within_period == estimate_period:
                         self._estimate_projections(p, state, param_eps, param_rel_eps, param_pow)
                         state["step_within_period"] = 0

                     if step_within_period >= estimate_period - stats_steps:
-                        self._store_grad_stats(grad, state)
+                        self._store_grad_stats(grad, state, max_fullcov_size)

                     cur_grad = grad

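
The hunk above fixes the period check to compare against estimate_period (estimate_step does not exist) and passes max_fullcov_size through to the stats collection. The resulting schedule: accumulate gradient statistics only during the last stats_steps of each period, then re-estimate the projections and reset the counter. A hedged sketch of that schedule in isolation (the function and its return values are placeholders, not the optimizer's real API):

def run_period_schedule(num_steps: int, estimate_period: int, stats_steps: int,
                        start_offset: int = 0):
    # Mirrors the step_within_period bookkeeping: which steps collect stats,
    # and which steps trigger a projection re-estimate.
    stats, estimates = [], []
    step_within_period = start_offset
    for step in range(num_steps):
        if step_within_period == estimate_period:
            estimates.append(step)
            step_within_period = 0
        if step_within_period >= estimate_period - stats_steps:
            stats.append(step)
        step_within_period += 1
    return stats, estimates

stats, estimates = run_period_schedule(1200, estimate_period=500, stats_steps=100,
                                        start_offset=400)
print(estimates)    # a re-estimate roughly every estimate_period steps
print(len(stats))   # about stats_steps of stats collection before each re-estimate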
@@ -299,6 +295,10 @@ class NeutralGradient(Optimizer):
                     # stats when we changed the co-ordinates.
                     bias_correction2 = 1 - beta2 ** (min(step, step_within_period) + 1)

+                    cur_grad = cur_grad / (exp_avg_sq.sqrt() + grad_eps)
+                    if bias_correction2 < 0.99:
+                        cur_grad *= bias_correction2
+
                     cur_grad = self._change_coordinates(cur_grad, state, forward=False)

                     if random.random() < 0.004:
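
The added lines normalize the coordinate-changed gradient by the root of its second-moment estimate, Adam style, before rotating back; bias_correction2 uses min(step, step_within_period) because the stats are discarded whenever the coordinates change. A quick worked example of the standard debiasing arithmetic (generic Adam math, not this optimizer's exact update):

import torch

beta2 = 0.98
g = torch.full((4,), 2.0)              # pretend every gradient is exactly 2.0
exp_avg_sq = torch.zeros(4)

for step in range(3):
    exp_avg_sq.mul_(beta2).addcmul_(g, g, value=1 - beta2)
    bias_correction2 = 1 - beta2 ** (step + 1)
    # Early on exp_avg_sq underestimates E[g^2] = 4.0; dividing by the
    # bias-correction factor recovers the right scale.
    print(step, exp_avg_sq[0].item(), (exp_avg_sq[0] / bias_correction2).item())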
@@ -316,7 +316,7 @@ class NeutralGradient(Optimizer):
                         # is from gradient descent.
                         prod = (grad*cur_grad).mean()
                         cos_angle = prod / ((grad**2).mean() * (cur_grad**2).mean()).sqrt()
-                        if random.random() < 0.01 or cos_angle < 0.01:
+                        if random.random() < 0.04 or cos_angle < 0.01:
                             print(f"cos_angle = {cos_angle}, shape={grad.shape}")

             alpha = -lr * (1-beta1)
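
The only change here is printing the diagnostic more often (4% of the time instead of 1%). The quantity itself is the cosine of the angle between the raw gradient and the preconditioned step; a standalone version of the same formula:

import torch

def grad_cos_angle(grad: torch.Tensor, cur_grad: torch.Tensor) -> float:
    # cos(angle) = <g, g'> / (||g|| ||g'||). Using .mean() as in the diff rather
    # than .sum() changes nothing, since the 1/numel factors cancel.
    prod = (grad * cur_grad).mean()
    return (prod / ((grad**2).mean() * (cur_grad**2).mean()).sqrt()).item()

g = torch.tensor([1.0, 0.0])
print(grad_cos_angle(g, torch.tensor([1.0, 1.0])))   # about 0.707
print(grad_cos_angle(g, torch.tensor([0.0, 1.0])))   # 0.0: step orthogonal to the gradient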
@@ -417,22 +417,22 @@ class NeutralGradient(Optimizer):
             size = p.shape[dim]
             if size == 1:
                 continue
-            count = state[f"grad_cov_count_{dim}"]
-            assert count != 0 # we can find a way to deal with this case later,
-                              # if it happens.
-            grad_cov = state[f"grad_cov_{dim}"] / count
-            del state[f"grad_cov_{dim}"] # save memory
             proj = state[f"proj_{dim}"]

             if proj.ndim == 2:
                 # This dimension gets the full-covariance treatment.
+                count = state[f"grad_cov_count_{dim}"]
+                assert count != 0
+                grad_cov = state[f"grad_cov_{dim}"] / count
+                del state[f"grad_cov_{dim}"] # save memory
+
                 self._randomize_lowrank_cov(grad_cov, count)
                 param_cov = self._get_param_cov(p, dim)

                 # P is the SPD matrix such that P G P^T == C^{param_pow},
                 # where G == grad_cov and C == param_cov.
-                P = self._estimate_proj(grad_cov_smoothed,
-                                        param_cov_smoothed,
+                P = self._estimate_proj(grad_cov,
+                                        param_cov,
                                         param_pow)
                 # The only thing we want from P is the basis that diagonalizes
                 # it, i.e. if we do the symmetric svd P = U S U^T, we can
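
This hunk drops the stale *_smoothed names and moves the grad-covariance bookkeeping inside the full-covariance branch. The comment defines P implicitly as the SPD solution of P G P^T = C^{param_pow}. For param_pow = 1 one standard closed form is P = G^{-1/2} (G^{1/2} C G^{1/2})^{1/2} G^{-1/2} (the matrix geometric mean of G^{-1} and C); the sketch below only verifies that identity numerically and is not necessarily how _estimate_proj computes it:

import torch

def sqrtm(A: torch.Tensor) -> torch.Tensor:
    # Symmetric PSD matrix square root via eigendecomposition.
    evals, evecs = torch.linalg.eigh(A)
    return (evecs * evals.clamp(min=0).sqrt()) @ evecs.t()

def solve_p(G: torch.Tensor, C: torch.Tensor) -> torch.Tensor:
    # SPD solution of P G P^T == C (the param_pow == 1 case).
    G_half = sqrtm(G)
    G_half_inv = torch.linalg.inv(G_half)
    return G_half_inv @ sqrtm(G_half @ C @ G_half) @ G_half_inv

torch.manual_seed(0)
n = 5
A = torch.randn(n, n, dtype=torch.float64)
B = torch.randn(n, n, dtype=torch.float64)
G = A @ A.t() + 0.1 * torch.eye(n, dtype=torch.float64)   # stand-in grad covariance
C = B @ B.t() + 0.1 * torch.eye(n, dtype=torch.float64)   # stand-in param covariance
P = solve_p(G, C)
print(torch.allclose(P @ G @ P.t(), C))   # True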
@@ -470,13 +470,12 @@ class NeutralGradient(Optimizer):
             # Rotate `p` to the diagonalize basis. At this point there is no scaling,
             # just an orthogonal transformation; this function is going to add the
             # scaling to state[f"proj_{dim}"]
-            rotated_p = self._change_coordinates(rotated_p, state, forward=True)
+            rotated_p = self._change_coordinates(p, state, forward=True)

             params_sq = rotated_p**2
             params_sq.add_(param_eps*param_eps +
                            param_rel_eps*param_rel_eps * params_sq.mean())

-            param_var = torch.ones_like(p)

             for _ in range(3):
                 # Iterate 3 times, this should be enough to converge.
@@ -486,13 +485,16 @@ class NeutralGradient(Optimizer):
                     continue
                 # p will have at least one non-trivial dim.
                 other_dims = [ i for i in range(p.ndim) if i != dim ]
-                this_var = param_var.mean(dim=other_dims, keepdim=True)
-                param_var = param_var / this_var
+                # Compute diagonal variance along this dimension
+                # (this is after normalizing previous dimensions' variance)
+                this_var = params_sq.mean(dim=other_dims, keepdim=True)
+                params_sq = params_sq / this_var

+                this_var = this_var.reshape(size)
                 this_scale = (this_var ** (param_pow * 0.5)).reshape(size)
                 proj = state[f"proj_{dim}"]

                 #print(f"iter={_}, dim={dim}, this_scale = {this_scale}")
                 if proj.ndim == 1:
                     proj *= this_scale
                 else:
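
The loop edited above factorizes the smoothed squared parameters into one variance per tensor dimension by repeatedly taking the mean over all other dims and dividing it out; per the existing comment, three passes are enough to converge. A toy version of that alternating normalization, detached from the optimizer state (the diff folds the factors into proj_{dim} instead of returning them):

import torch

def per_dim_scales(params_sq: torch.Tensor, num_iters: int = 3):
    # Approximate params_sq by an outer product of per-dimension factors:
    # alternately take the mean over the other dims and divide it out,
    # accumulating the factors.
    work = params_sq.clone()
    scales = [torch.ones(s) for s in work.shape]
    for _ in range(num_iters):
        for dim in range(work.ndim):
            size = work.shape[dim]
            if size == 1:
                continue
            other_dims = [i for i in range(work.ndim) if i != dim]
            this_var = work.mean(dim=other_dims, keepdim=True)
            work = work / this_var
            scales[dim] = scales[dim] * this_var.reshape(size)
    return scales

row = torch.tensor([1.0, 4.0, 9.0]).reshape(3, 1)
col = torch.tensor([1.0, 16.0]).reshape(1, 2)
sq = row * col                                  # exactly factorizable 3x2 example
s0, s1 = per_dim_scales(sq)
print(torch.allclose(s0.reshape(3, 1) * s1.reshape(1, 2), sq))   # True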
@@ -591,6 +593,7 @@ class NeutralGradient(Optimizer):
         # To be confident in our estimate of the covariance, we want `rank` (which
         # actually represents the number of outer products added together)
         # to be at least `required_rank`; otherwise we'll add a random component.
+        size = cov.shape[0]
         required_rank = int(size * 1.2) + 1

         if rank < required_rank:
@@ -598,7 +601,7 @@ class NeutralGradient(Optimizer):
             # params are exactly zero or we sample a zero matrix; the exact value is not
             # going to affect the performance.
             eps = 1.0e-20
-            param_cov_scale = param_cov.diag().mean() + 1.0e-20
+            cov_scale = cov.diag().mean() + 1.0e-20

             # The following formula assumes that the "missing" outer products are
             # expected to be smaller than the ones that we do have, i.e. 0.2 the size.
@@ -606,12 +609,12 @@ class NeutralGradient(Optimizer):
             # i.e. how big it is relative to the trace of the existing matrix.
             missing_rank = (required_rank - rank)
             rand_scale = 0.2 * missing_rank / rank
-            R = torch.randn(size, size)
-            R = torch.matmul(C, C.t()) # positive semidefinite random matrix
+            R = torch.randn(size, size, device=cov.device, dtype=cov.dtype)
+            R = torch.matmul(R, R.t()) # positive semidefinite random matrix
             R_scale = R.diag().mean() + 1.0e-20
             if random.random() < 0.02:
                 print(f"required_rank={required_rank}, rank={rank}, size={size}, R_scale={R_scale}")
-            param_cov.add_(R, alpha=rand_scale * param_cov_scale / R_scale)
+            cov.add_(R, alpha=rand_scale * cov_scale / R_scale)


     def _multiply_on_dim(self, x: Tensor, proj: Tensor, dim: int,
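
Taken together, the last three hunks make this helper operate on whichever covariance it is handed (cov rather than param_cov), build the random matrix on the right device and dtype, and multiply R by its own transpose instead of the undefined C. The underlying idea: when the accumulated covariance was built from fewer outer products than needed for full rank, mix in a small random positive semidefinite matrix, scaled relative to the covariance's own trace. A standalone sketch of that fix-up, mirroring the diff but as a free function on a toy example:

import torch

def randomize_lowrank_cov(cov: torch.Tensor, rank: int) -> None:
    # If `cov` came from too few outer products to be full rank, add a random
    # PSD matrix in place, scaled relative to cov's average diagonal.
    size = cov.shape[0]
    required_rank = int(size * 1.2) + 1
    if rank < required_rank:
        cov_scale = cov.diag().mean() + 1.0e-20
        missing_rank = required_rank - rank
        rand_scale = 0.2 * missing_rank / rank
        R = torch.randn(size, size, device=cov.device, dtype=cov.dtype)
        R = torch.matmul(R, R.t())          # positive semidefinite random matrix
        R_scale = R.diag().mean() + 1.0e-20
        cov.add_(R, alpha=rand_scale * cov_scale / R_scale)

torch.manual_seed(0)
g = torch.randn(3, 8, dtype=torch.float64)   # only 3 outer products for an 8x8 covariance
cov = g.t() @ g
print(torch.linalg.eigvalsh(cov).min().item())   # about 0: rank-deficient
randomize_lowrank_cov(cov, rank=3)
print(torch.linalg.eigvalsh(cov).min().item())   # strictly positive afterwards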
@@ -1356,8 +1359,10 @@ def _test_eve_cain():

         if iter == 0: optim = Eve(m.parameters(), lr=0.003)
         elif iter == 1: optim = Cain(m.parameters(), lr=0.03)
-        elif iter == 2: optim = NeutralGradient(m.parameters(), lr=0.03, max_size=10)
-        elif iter == 3: optim = NeutralGradient(m.parameters(), lr=0.03, max_size=1000)
+        elif iter == 2: optim = NeutralGradient(m.parameters(), lr=0.03, max_fullcov_size=10,
+                                                estimate_period=500, stats_steps=100)
+        elif iter == 3: optim = NeutralGradient(m.parameters(), lr=0.03, max_fullcov_size=1000,
+                                                estimate_period=500, stats_steps=100)
         scheduler = Eden(optim, lr_batches=200, lr_epochs=10, verbose=False)

         start = timeit.default_timer()
@@ -1375,7 +1380,10 @@ def _test_eve_cain():
         for n, (x,y) in enumerate(train_pairs):
             y_out = m(x)
             loss = ((y_out - y)**2).mean() * 100.0
-            avg_loss = 0.95 * avg_loss + 0.05 * loss.item()
+            if epoch == 0 and n == 0:
+                avg_loss = loss.item()
+            else:
+                avg_loss = 0.95 * avg_loss + 0.05 * loss.item()
             if n == 0 and epoch % 10 == 0:
                 norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item()
                 norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item()
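
The last hunk seeds the running average with the very first loss instead of decaying up from its initial value, which removes the startup bias of the exponential moving average in the printed numbers. A tiny standalone illustration (assuming, as the removed line implies, that avg_loss previously started from zero):

def ema(values, decay=0.95, seed_with_first=True):
    avg = 0.0
    out = []
    for i, v in enumerate(values):
        if i == 0 and seed_with_first:
            avg = v                        # new behaviour: start at the first value
        else:
            avg = decay * avg + (1 - decay) * v
        out.append(round(avg, 3))
    return out

losses = [10.0] * 5                        # constant loss of 10
print(ema(losses, seed_with_first=True))   # [10.0, 10.0, 10.0, 10.0, 10.0]
print(ema(losses, seed_with_first=False))  # creeps up from 0.5 toward 10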