Mirror of https://github.com/k2-fsa/icefall.git

Commit 117d348f70

    This time get it right; last time was starting from the wrong base.
@@ -52,6 +52,10 @@ class NeutralGradient(Optimizer):
                      decreasing with the parameter norm.
          param_rel_eps: An epsilon that limits how different different parameter rms estimates can
                      be within the same tensor
+         param_rel_max: Limits how big we allow the parameter variance eigs to be relative to the
+                     original mean.  Setting this to a not-very-large value like 1 ensures that
+                     we only apply the param scaling for smaller-than-average, not larger-than-average,
+                     eigenvalues, which limits the dispersion of eigenvalues of the parameter tensors.
          param_max: To prevent parameter tensors getting too large, we will clip elements to
                      -param_max..param_max.
          max_fullcov_size: We will only use the full-covariance (not-diagonal) update for
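
To make param_rel_max concrete: the smoothing rule this commit adds further down (_smooth_param_diag_var) first normalizes the eigenvalues to mean 1.0 and then maps each value x to 1/(1/(x + param_rel_eps) + 1/param_rel_max), so smaller-than-average values pass through nearly unchanged while values above param_rel_max are squashed toward it. A tiny illustrative sketch (the numbers are made up):

    import torch

    # Hypothetical parameter-variance eigenvalues, already normalized to mean 1.0.
    eigs = torch.tensor([0.01, 0.5, 1.0, 10.0])
    param_rel_eps, param_rel_max = 1.0e-04, 1.0

    smoothed = 1.0 / (1.0 / (eigs + param_rel_eps) + 1.0 / param_rel_max)
    print(smoothed)  # approx [0.0100, 0.3333, 0.5000, 0.9091]: the large eig is capped near 1.0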
@@ -85,11 +89,12 @@ class NeutralGradient(Optimizer):
            self,
            params,
            lr=3e-02,
-           betas=(0.9, 0.98),
+           betas=(0.9, 0.98, 0.95),
            scale_speed=0.1,
            grad_eps=1e-10,
            param_eps=1.0e-06,
            param_rel_eps=1.0e-04,
+           param_rel_max=1.0,
            param_max_rms=2.0,
            param_min_rms=1.0e-05,
            max_fullcov_size=1023,
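
For orientation, a minimal usage sketch of the updated constructor with the defaults shown above; the toy model and the import path are assumptions for illustration. The third element of betas is the new beta3 that decays the accumulated parameter-covariance stats once per estimate_period, and param_rel_max caps how large the parameter-variance eigenvalues may be relative to their mean.

    import torch
    from optim import NeutralGradient  # assumed import path for the file this diff modifies

    model = torch.nn.Linear(256, 256)
    optimizer = NeutralGradient(
        model.parameters(),
        lr=3e-02,
        betas=(0.9, 0.98, 0.95),  # (grad momentum, grad^2 averaging, param_cov decay)
        param_rel_eps=1.0e-04,
        param_rel_max=1.0,        # only smaller-than-average directions get scaled up
        max_fullcov_size=1023,
    )

    loss = model(torch.randn(8, 256)).pow(2).mean()
    loss.backward()
    optimizer.step()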
@@ -140,6 +145,7 @@ class NeutralGradient(Optimizer):
                        grad_eps=grad_eps,
                        param_eps=param_eps,
                        param_rel_eps=param_rel_eps,
+                       param_rel_max=param_rel_max,
                        param_max_rms=param_max_rms,
                        param_min_rms=param_min_rms,
                        max_fullcov_size=max_fullcov_size,
@@ -174,11 +180,12 @@ class NeutralGradient(Optimizer):

        for group in self.param_groups:
            lr = group["lr"]
-           beta1, beta2 = group["betas"]
+           beta1, beta2, beta3 = group["betas"]
            scale_speed = group["scale_speed"]
            grad_eps = group["grad_eps"]
            param_eps = group["param_eps"]
            param_rel_eps = group["param_rel_eps"]
+           param_rel_max = group["param_rel_max"]
            param_max_rms = group["param_max_rms"]
            param_min_rms = group["param_min_rms"]
            max_fullcov_size = group["max_fullcov_size"]
@@ -271,8 +278,7 @@ class NeutralGradient(Optimizer):
                state["step_within_period"] = random.randint(0,
                                                             estimate_period-stats_steps)

-               if param_pow != 1.0:
-                   state["scalar_exp_avg_sq"] = torch.zeros((), **kwargs)
+               state["scalar_exp_avg_sq"] = torch.zeros((), **kwargs)

                used_scale = False
                for dim in range(p.ndim):
@@ -285,6 +291,11 @@ class NeutralGradient(Optimizer):
                        state[f"proj_{dim}"] = torch.ones(size, **kwargs)
                    else:
                        state[f"proj_{dim}"] = torch.eye(size, **kwargs)
+
+                       # we will decay the param stats via beta3; they are only
+                       # accumulated once every `estimate_period`
+                       state[f"param_cov_{dim}"] = torch.zeros(size, size, **kwargs)
+
                if not used_scale:
                    param_rms = (p**2).mean().sqrt().add_(param_eps)
                    state[f"proj_{dim}"] *= param_rms
@@ -389,6 +400,7 @@ class NeutralGradient(Optimizer):
            step_within_period = state["step_within_period"]
            if step_within_period == estimate_period:
                self._estimate_projections(p, state, param_eps, param_rel_eps,
+                                           param_rel_max, beta3,
                                            param_pow, grad_pow, grad_min_rand)
                state["step_within_period"] = 0
            elif step_within_period >= estimate_period - stats_steps:
@@ -437,7 +449,8 @@ class NeutralGradient(Optimizer):
                logging.info(f"cos_angle = {cos_angle}, shape={grad.shape}")

            alpha = -lr * (1-beta1)
-           if param_pow != 1.0 or grad_pow != 1.0:
+
+           if True:
                # Renormalize scale of cur_grad
                scalar_exp_avg_sq = state["scalar_exp_avg_sq"]
                scalar_exp_avg_sq.mul_(beta2).add_((cur_grad**2).mean()/n_cached_grads, alpha=(1-beta2))
@@ -509,6 +522,8 @@ class NeutralGradient(Optimizer):
                              state: dict,
                              param_eps: float,
                              param_rel_eps: float,
+                             param_rel_max: float,
+                             beta3: float,
                              param_pow: float,
                              grad_pow: float,
                              grad_min_rand: float,
@@ -523,7 +538,6 @@ class NeutralGradient(Optimizer):
        co-ordinates.  (Thus, applying with forward==True and then forward==False
        is not a round trip).

-
        Args:
          p: The parameter for which we are estimating the projections.  Supplied
            because we need this to estimate parameter covariance matrices in
@@ -535,6 +549,7 @@ class NeutralGradient(Optimizer):
          param_rel_eps: Another constraint on minimum parameter rms, relative to
              sqrt(overall_param_rms), where overall_param_rms is the sqrt of the
              parameter variance averaged over the entire tensor.
+         beta3: discounting parameter for param_cov, applied once every estimate_period.
          param_pow: param_pow==1.0 means fully normalizing for parameter scale;
              setting to a smaller value closer to 0.0 means partial normalization.
          grad_min_rand: minimum proportion of random tensor to mix with the
@@ -557,10 +572,13 @@ class NeutralGradient(Optimizer):
                grad_cov = state[f"grad_cov_{dim}"] / count
                del state[f"grad_cov_{dim}"]  # save memory

-               self._randomize_lowrank_cov(grad_cov, count, p.shape,
-                                           grad_min_rand)
+               #self._randomize_lowrank_cov(grad_cov, count, p.shape,
+               #                            grad_min_rand)

-               param_cov = self._get_param_cov(p, dim)
+               cur_param_cov = self._get_param_cov(p, dim)
+               cur_param_cov = cur_param_cov / (cur_param_cov.diag().mean() + 1.0e-20)  # normalize size..
+               param_cov = state[f"param_cov_{dim}"]
+               param_cov.mul_(beta3).add_(cur_param_cov, alpha=(1-beta3))

                # We want the orthogonal matrix U that diagonalizes P, where
                # P is the SPD matrix such that P G^{param_pow} P^T == C^{param_pow},
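
The new accumulation above normalizes each fresh covariance estimate by its mean diagonal before folding it into the stored stats, so the running param_cov tracks the shape of the covariance while beta3 controls how quickly old estimates are forgotten. A self-contained sketch of that pattern (illustrative; new_cov_estimate stands in for whatever _get_param_cov returns):

    import torch

    def accumulate_param_cov(running_cov: torch.Tensor,
                             new_cov_estimate: torch.Tensor,
                             beta3: float = 0.95) -> torch.Tensor:
        # Normalize the fresh estimate so its diagonal has mean 1, then decay the
        # running stats by beta3, mirroring the lines above.
        cur = new_cov_estimate / (new_cov_estimate.diag().mean() + 1.0e-20)
        return running_cov.mul_(beta3).add_(cur, alpha=(1 - beta3))

    running = torch.zeros(4, 4)
    for _ in range(10):
        x = torch.randn(100, 4)
        running = accumulate_param_cov(running, x.t() @ x / 100)
    print(running.diag().mean())  # independent of the overall magnitude of x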
@@ -570,7 +588,11 @@ class NeutralGradient(Optimizer):
            # since the P's are related by taking-to-a-power.
            P = self._estimate_proj(grad_cov,
                                    param_cov,
-                                   param_pow / grad_pow)
+                                   param_pow / grad_pow,
+                                   param_rel_eps,
+                                   param_rel_max)
+
+
            # The only thing we want from P is the basis that diagonalizes
            # it, i.e. if we do the symmetric svd P = U S U^T, we can
            # interpret the shape of U as (canonical_coordinate,
@@ -585,7 +607,8 @@ class NeutralGradient(Optimizer):
                proj.fill_(1.0)

        self._estimate_param_scale(p, state, param_eps,
-                                  param_rel_eps, param_pow)
+                                  param_rel_eps, param_rel_max,
+                                  param_pow)


    def _estimate_param_scale(self,
@@ -593,6 +616,7 @@ class NeutralGradient(Optimizer):
                              state: dict,
                              param_eps: float,
                              param_rel_eps: float,
+                             param_rel_max: float,
                              param_pow: float) -> None:
        """
        This is called from _estimate_projections() after suitably "diagonalizing" bases
@@ -615,8 +639,10 @@ class NeutralGradient(Optimizer):


        params_sq_norm = params_sq
        if param_pow != 1.0:
            params_sq_partnorm = params_sq

+       # param_diag_vars is the diagonals of the parameter variances
+       # on each tensor dim, with an arbitrary scalar factor.
+       param_diag_vars = [None] * p.ndim

        for _ in range(3):
            # Iterate 3 times, this should be enough to converge.
@@ -630,35 +656,39 @@ class NeutralGradient(Optimizer):
                # (this is after normalizing previous dimensions' variance)
                this_var = params_sq_norm.mean(dim=other_dims, keepdim=True)
                params_sq_norm = params_sq_norm / this_var
                if param_pow != 1.0:
                    params_sq_partnorm = params_sq_partnorm / (this_var ** param_pow)

                this_var = this_var.reshape(size)
-               this_scale = (this_var ** (param_pow * 0.5)).reshape(size)
                proj = state[f"proj_{dim}"]

-               if proj.ndim == 1:
-                   proj *= this_scale
+               if param_diag_vars[dim] == None:
+                   param_diag_vars[dim] = this_var.reshape(size)
                else:
                    # scale the rows; dim 0 of `proj` corresponds to the
                    # diagonalized space.
-                   proj *= this_scale.unsqueeze(-1)
+                   param_diag_vars[dim] *= this_var.reshape(size)

+       for dim in range(p.ndim):
+           size = p.shape[dim]
+           if size == 1:
+               continue
+           param_diag_var = param_diag_vars[dim]
+           param_diag_var = self._smooth_param_diag_var(param_diag_var,
+                                                        param_pow,
+                                                        param_rel_eps,
+                                                        param_rel_max)
+           param_scale = param_diag_var ** 0.5
+           proj = state[f"proj_{dim}"]
+
+           if proj.ndim == 1:
+               proj *= param_scale
+           else:
+               # scale the rows; dim 0 of `proj` corresponds to the
+               # diagonalized space.
+               proj *= param_scale.unsqueeze(-1)
+       # the scalar factors on the various dimensions' projections are a little
+       # arbitrary, and their product is arbitrary too.  we normalize the scalar
+       # factor outside this function, see scalar_exp_avg_sq.

        if param_pow != 1.0:
            # need to get the overall scale correct, as if we had param_pow == 1.0
            scale = (params_sq_partnorm.mean() ** 0.5)
            for dim in range(p.ndim):
                size = p.shape[dim]
                if size == 1:
                    continue
                else:
                    state[f"proj_{dim}"] *= scale
                    break

        # Reset the squared-gradient stats because we have changed the basis.
        state["exp_avg_sq"].fill_(0.0)
-       if param_pow != 1.0:
-           state["scalar_exp_avg_sq"].fill_(0.0)
+       state["scalar_exp_avg_sq"].fill_(0.0)

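
The restructuring above splits the per-dimension scaling into two passes: the loop now only accumulates per-dim diagonal variances in param_diag_vars, and a second loop smooths them with _smooth_param_diag_var and applies the resulting scales to the projections. A rough standalone sketch of the idea on a plain 2-D tensor; it inlines the smoothing formula and skips the iterative renormalization across dims that the real code performs:

    import torch

    def smooth(var, param_pow=1.0, param_rel_eps=1.0e-04, param_rel_max=1.0):
        var = var / (var.mean() + 1.0e-20)
        return (1.0 / (1.0 / (var + param_rel_eps) + 1.0 / param_rel_max)) ** param_pow

    p = torch.randn(16, 32) * torch.linspace(0.1, 2.0, 32)  # columns of varying scale
    params_sq = p ** 2

    # Pass 1: accumulate a diagonal variance estimate for each dim.
    param_diag_vars = []
    for dim in range(p.ndim):
        other_dims = [d for d in range(p.ndim) if d != dim]
        param_diag_vars.append(params_sq.mean(dim=other_dims))

    # Pass 2: smooth each dim's variances and turn them into per-dim scales.
    param_scales = [smooth(v) ** 0.5 for v in param_diag_vars]
    print([s.shape for s in param_scales])  # [torch.Size([16]), torch.Size([32])]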
@@ -726,8 +756,8 @@ class NeutralGradient(Optimizer):

            param_cov = torch.matmul(p.t(), p) / num_outer_products

-           self._randomize_lowrank_cov(param_cov, num_outer_products,
-                                       p.shape)
+           #self._randomize_lowrank_cov(param_cov, num_outer_products,
+           #                            p.shape)

        return param_cov

@@ -794,8 +824,31 @@ class NeutralGradient(Optimizer):
                            (proj.t() if forward else proj)).transpose(-1, dim)


-   def _estimate_proj(self, grad_cov: Tensor, param_cov: Tensor,
-                      param_pow: float = 1.0) -> Tensor:
+   def _smooth_param_diag_var(self,
+                              param_diag: Tensor,
+                              param_pow: float,
+                              param_rel_eps: float,
+                              param_rel_max: float) -> Tensor:
+       """
+       Applies the smoothing formula to the eigenvalues of the parameter covariance
+       tensor.
+       (Actually when we use this function in _get_param_cov, they won't exactly
+       be the eigenvalues because we diagonalize relative to the grad_cov,
+       but they will be parameter covariances in certain directions).
+       """
+       # normalize so mean = 1.0
+       param_diag = param_diag / (param_diag.mean() + 1.0e-20)
+       # use 1/(1/x + 1/y) to apply a soft upper limit of param_rel_max to param_diag;
+       # use param_rel_eps here to prevent division by zero.
+       param_diag = 1. / (1. / (param_diag + param_rel_eps) + 1. / param_rel_max)
+       return param_diag ** param_pow
+
+   def _estimate_proj(self,
+                      grad_cov: Tensor,
+                      param_cov: Tensor,
+                      param_pow: float = 1.0,
+                      param_rel_eps: float = 0.0,
+                      param_rel_max: float = 1.0e+10) -> Tensor:
        """
        Return a symmetric positive definite matrix P such that
              (P grad_cov P^T == param_cov^{param_pow}),
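
A quick property check of the new helper (a sketch; the import path is an assumption, and since the method does not touch self it can be called unbound here): with param_pow == 1.0 the output is monotonic in the input and never exceeds param_rel_max, which is what keeps larger-than-average directions from being scaled up.

    import torch
    from optim import NeutralGradient  # assumed import path for the file this diff modifies

    eigs = torch.tensor([0.2, 1.0, 5.0, 50.0])
    out = NeutralGradient._smooth_param_diag_var(None, eigs, param_pow=1.0,
                                                 param_rel_eps=1.0e-04,
                                                 param_rel_max=1.0)
    assert torch.all(out <= 1.0)           # never above param_rel_max
    assert torch.all(out[1:] >= out[:-1])  # ordering of directions is preserved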
@@ -839,8 +892,11 @@ class NeutralGradient(Optimizer):
        # P = Q^T (Q G Q^T)^{-0.5} Q.

        # because C is symmetric, C == U S U^T, we can ignore V.
-       # S_sqrt is S.sqrt() in the normal case where param_pow == 1.0
-       S_sqrt = S ** (0.5 * param_pow)
+       # S_sqrt is S.sqrt() in the limit where param_pow == 1.0,
+       # param_rel_eps=0, param_rel_max=inf
+       S_smoothed = self._smooth_param_diag_var(S, param_pow,
+                                                param_rel_eps, param_rel_max)
+       S_sqrt = S_smoothed ** 0.5
        Q = (U * S_sqrt).t()

        X = torch.matmul(Q, torch.matmul(G, Q.t()))
@@ -856,23 +912,24 @@ class NeutralGradient(Optimizer):

        P = torch.matmul(Y, Y.t())

-       if random.random() < 0.025:
-
-           # TEMP:
-           _,s,_ = P.svd()
-           print(f"Min,max eig of P: {s.min()},{s.max()}")
-
+       if random.random() < 1.0: #0.025:
+           # TODO: remove this testing code.
+           assert (P - P.t()).abs().mean() < 0.01  # make sure symmetric.
+           try:
+               P = 0.5 * (P + P.t())
+               _,s,_ = P.svd()
+               print(f"Min,max eig of P: {s.min()},{s.max()}")
+           except:
+               pass
            # testing... note, this is only true modulo "eps"
            C_check = torch.matmul(torch.matmul(P, G), P)
-           # C_check should equal C
-           C_diff = C_check - C
+           C_smoothed = torch.matmul(Q.t(), Q)
+           # C_check should equal C_smoothed
+           C_diff = C_check - C_smoothed
            # Roundoff can cause significant differences, so use a fairly large
            # threshold of 0.01.  We may increase this later or even remove the check.
-           if not C_diff.abs().mean() < 0.01 * C.diag().mean():
-               print(f"Warning: large C_diff: {C_diff.abs().mean()}, C diag mean: {C.diag().mean()}")
+           if not C_diff.abs().mean() < 0.01 * C_smoothed.diag().mean():
+               print(f"Warning: large C_diff: {C_diff.abs().mean()}, C diag mean: {C_smoothed.diag().mean()}")

        return P

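
The check above can be reproduced outside the optimizer. The sketch below builds random SPD matrices G and C, follows the same construction described in the comments (Q built from the smoothed eigenvalues of C, then P = Q^T (Q G Q^T)^{-0.5} Q), and confirms that P G P recovers the smoothed C rather than the raw C. It illustrates the algebra only and is not a copy of the optimizer's code:

    import torch

    def smooth(S, param_pow=1.0, param_rel_eps=1.0e-04, param_rel_max=1.0):
        S = S / (S.mean() + 1.0e-20)
        return (1.0 / (1.0 / (S + param_rel_eps) + 1.0 / param_rel_max)) ** param_pow

    def random_spd(n):
        A = torch.randn(n, n, dtype=torch.float64)
        return A @ A.t() + 0.1 * torch.eye(n, dtype=torch.float64)

    torch.manual_seed(0)
    G, C = random_spd(6), random_spd(6)

    S, U = torch.linalg.eigh(C)         # C == U @ diag(S) @ U.t()
    Q = (U * smooth(S) ** 0.5).t()      # Q == diag(sqrt(S_smoothed)) @ U.t()

    X = Q @ G @ Q.t()
    Sx, Ux = torch.linalg.eigh(X)
    P = Q.t() @ (Ux @ torch.diag(Sx ** -0.5) @ Ux.t()) @ Q   # Q^T (Q G Q^T)^{-0.5} Q

    C_smoothed = Q.t() @ Q
    assert torch.allclose(P @ G @ P, C_smoothed, atol=1e-8)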
@@ -899,7 +956,7 @@ class NeutralGradient(Optimizer):
                return


-       _, beta2 = group["betas"]
+       _, beta2, _ = group["betas"]

        for p in group["params"]:
            if p.grad is None:
@@ -1602,9 +1659,11 @@ def _test_eve_cain():
    if iter == 0: optim = Eve(m.parameters(), lr=0.003)
    elif iter == 1: optim = Cain(m.parameters(), lr=0.03)
    elif iter == 2: optim = NeutralGradient(m.parameters(), lr=0.04, max_fullcov_size=10,
-                                           estimate_period=500, stats_steps=100)
+                                           estimate_period=500, stats_steps=100,
+                                           lr_for_speedup=0.02)
    elif iter == 3: optim = NeutralGradient(m.parameters(), lr=0.04, max_fullcov_size=1000,
-                                           estimate_period=500, stats_steps=100)
+                                           estimate_period=500, stats_steps=100,
+                                           lr_for_speedup=0.02)
    scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False)

    start = timeit.default_timer()