Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-18 21:44:18 +00:00)

Commit c6f62dfc08 (parent e597b68160)
Commit message: "Version that is passing the tests"

The diff below modifies the NeutralGradient optimizer class and its test helper _test_eve_cain().
@@ -43,12 +43,14 @@ class NeutralGradient(Optimizer):
           at 0.03 and decreases over time, i.e. much higher than other common
           optimizers.
    beta: Momentum constants for regular momentum, and stats of gradients.
-   eps: An epsilon to prevent division by zero
    scale_speed: This scales the learning rate for purposes of learning the overall scale
          of each tensor
-   rms_eps: An epsilon on the rms value of the parameter, such that when the parameter
+   eps: An epsilon to prevent division by zero
+   param_eps: An epsilon on the rms value of the parameter, such that when the parameter
          gets smaller than this we'll start using a fixed learning rate, not
          decreasing with the parameter norm.
+   cond_eps: An epsilon that limits the condition number of gradient and parameter
+         covariance matrices
    param_max: To prevent parameter tensors getting too large, we will clip elements to
          -param_max..param_max.
    max_size: We will only use the full-covariance (not-diagonal) update for
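Aside: to make the renamed hyperparameters concrete, here is a minimal usage sketch of the constructor with the new argument names. The model and the exact values are illustrative assumptions (only the names and defaults come from the diff); NeutralGradient is assumed to be in scope from the file being modified.

    import torch

    # Hypothetical usage of the updated constructor (values are illustrative only).
    model = torch.nn.Linear(256, 256)
    optimizer = NeutralGradient(
        model.parameters(),
        lr=3e-02,              # the docstring suggests a relatively high initial rate (~0.03)
        betas=(0.9, 0.98, 0.99),
        eps=1e-8,              # guards divisions by zero
        param_eps=1.0e-05,     # floor on the parameter rms used when scaling updates
        cond_eps=1.0e-08,      # limits the condition number of the covariance matrices
    )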
@@ -64,9 +66,10 @@ class NeutralGradient(Optimizer):
            params,
            lr=1e-2,
            betas=(0.9, 0.98, 0.99),
-           eps=1e-8,
            scale_speed=0.05,
-           rms_eps=1.0e-05,
+           eps=1e-8,
+           param_eps=1.0e-05,
+           cond_eps=1.0e-08,
            param_max=10.0,
            max_size=1023,
            stats_period=1,
@@ -91,8 +94,8 @@ class NeutralGradient(Optimizer):
            )
        if not 0 < scale_speed < 1.0:
            raise ValueError("Invalid scale_speed value: {}".format(scale_speed))
-       if not 0.0 < rms_eps < 0.1:
-           raise ValueError("Invalid rms_eps value: {}".format(rms_eps))
+       if not 0.0 < param_eps < 0.1:
+           raise ValueError("Invalid param_eps value: {}".format(param_eps))
        if not 0.1 < param_max <= 10000.0:
            raise ValueError("Invalid param_max value: {}".format(param_max))
        if not stats_period > 0:
@@ -106,7 +109,8 @@ class NeutralGradient(Optimizer):
            betas=betas,
            eps=eps,
            scale_speed=0.05,
-           rms_eps=rms_eps,
+           param_eps=param_eps,
+           cond_eps=cond_eps,
            param_max=param_max,
            max_size=max_size,
            stats_period=stats_period,
@@ -135,7 +139,7 @@ class NeutralGradient(Optimizer):
            beta1, beta2, beta3 = group["betas"]
            lr = group["lr"]
            scale_speed = group["scale_speed"]
-           rms_eps = group["rms_eps"]
+           param_eps = group["param_eps"]
            eps = group["eps"]
            param_max = group["param_max"]
            max_size = group["max_size"]
@@ -163,15 +167,18 @@ class NeutralGradient(Optimizer):
                        p, memory_format=torch.preserve_format
                    )
                    # Exponential moving average of squared gradient value
-                   # after all multiplications. This is used to help
-                   # determine the scalar factor in the update size. It is
-                   # resert after every `estimate_period` minibatches.
+                   state["exp_avg_sq"] = torch.zeros_like(p)

                    kwargs = {'device':p.device, 'dtype':p.dtype}

                    if p.numel() > 1:
-                       state["scalar_exp_avg_sq"] = torch.zeros((), **kwargs)
                        is_one_axis = (p.numel() in list(p.shape)) # e.g. a bias
+                       if not is_one_axis:
+                           state["ref_exp_avg_sq"] = torch.ones_like(p)
+                       else:
+                           state["param_rms"] = torch.zeros((), **kwargs)

                        # we learn the scale on the parameters in a way that
                        # also emulates Eve. scale_exp_avg_sq is the
@@ -179,14 +186,6 @@ class NeutralGradient(Optimizer):
                        # scalar scale.
                        state["scale_exp_avg_sq"] = torch.zeros((), **kwargs)
-
-                       state["param_rms"] = torch.zeros((), **kwargs)
-
-                       # TODO: may have to add some special stuff here if
-                       if is_one_axis:
-                           state["exp_avg_sq"] = torch.zeros_like(p)
-                       else:
-                           state["scalar_exp_avg_sq"] = torch.zeros((), **kwargs)

                        for dim in range(p.ndim):
                            size = p.shape[dim]
                            if size == 1 or size == p.numel():
@@ -195,15 +194,12 @@ class NeutralGradient(Optimizer):
                            elif size > max_size:
                                # diagonal only...
                                state[f"grad_cov_{dim}"] = torch.zeros(size, **kwargs)
-                               state[f"proj_{dim}"] = torch.zeros(size, **kwargs)
+                               state[f"proj_{dim}"] = torch.ones(size, **kwargs)
                            else:
                                state[f"grad_cov_{dim}"] = torch.zeros(size, size,
                                                                       **kwargs)
-                               state[f"proj_{dim}"] = torch.zeros(size, size,
-                                                                  **kwargs)
-                   else:
-                       # p.numel() == 1: scalar parameter
-                       state["exp_avg_sq"] = torch.zeros_like(p)
+                               state[f"proj_{dim}"] = torch.eye(size, **kwargs)

                delta = state["delta"]
                step = state["step"]
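Aside: switching the proj_{dim} initialization from zeros() to ones()/eye() means each per-dimension projection starts out as an identity transform, so the gradient passes through essentially unchanged until the first re-estimation. A small standalone sketch of that effect, with arbitrary shapes and the same multiply-on-dim pattern the helper methods appear to use (illustrative only, not the committed code):

    import torch

    grad = torch.randn(4, 5)

    # Full-covariance dimension (dim=1): multiplying on that dim by the identity
    # matrix returns the gradient unchanged.
    proj_full = torch.eye(5)
    out = torch.matmul(grad.transpose(-1, 1), proj_full).transpose(-1, 1)
    print(torch.allclose(out, grad))               # True

    # Diagonal-only dimension (dim=0): a vector of ones, reshaped for broadcasting,
    # is likewise a no-op scaling.
    proj_diag = torch.ones(4).reshape(4, 1)
    print(torch.allclose(grad * proj_diag, grad))  # True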
@@ -240,44 +236,62 @@ class NeutralGradient(Optimizer):

                        delta.add_(p, alpha=scale_delta)

-               param_rms = state["param_rms"]
-               if step % 10 == 0:
-                   param_rms.copy_((p**2).mean().sqrt())
+               exp_avg_sq = state["exp_avg_sq"]
+               # Decay the second moment running average coefficient
+               exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

+               bias_correction2 = 1 - beta2 ** (step + 1)

                if is_one_axis:
-                   exp_avg_sq = state["exp_avg_sq"]
-                   # Decay the second moment running average coefficient
-                   exp_avg_sq.mul_(beta3).addcmul_(grad, grad, value=1 - beta3)
+                   param_rms = state["param_rms"]
+                   if step % 10 == 0:
+                       param_rms.copy_((p**2).mean().sqrt())

-                   bias_correction2 = 1 - beta3 ** (step + 1)
                    grad_rms = (exp_avg_sq.sqrt()).add_(eps)
                    this_delta = grad / grad_rms
-                   alpha = (-(1-beta1) * lr * bias_correction2) * param_rms.clamp(min=rms_eps)
+                   alpha = (-(1-beta1) * lr * bias_correction2) * param_rms.clamp(min=param_eps)

                    delta.add_(this_delta, alpha=alpha)
-                   # below, will do: delta.add_(this_delta, alpha=alpha)
                else:
                    # The full update.
-                   conditioned_grad = self._precondition_grad(p, grad, state, beta3, max_size,
-                                                              stats_period, estimate_period,
-                                                              eps, rms_eps)
-                   scalar_exp_avg_sq = state["scalar_exp_avg_sq"]
-                   grad_sq = (grad**2).mean()
-                   conditioned_grad_sq = (conditioned_grad**2).mean()
-                   prod = (grad*conditioned_grad).mean()
-                   cos_angle = prod / (grad_sq * conditioned_grad_sq).sqrt()
-                   if random.random() < 0.001 or cos_angle < 0.0075:
-                       print(f"cos_angle = {cos_angle}, shape={grad.shape}")
+                   # First, if needed, accumulate per-dimension stats from the grad
+                   if step % stats_period == 0:
+                       self._accumulate_per_dim_stats(grad, state, beta3, eps)

-                   #assert grad_sq - grad_sq == 0
-                   #assert conditioned_grad_sq - conditioned_grad_sq == 0
-                   scalar_exp_avg_sq.mul_(beta2).add_(grad_sq, alpha=(1-beta2))
-                   bias_correction2 = 1 - beta2 ** (step + 1)
-                   avg_grad_sq = scalar_exp_avg_sq / bias_correction2
-                   scale = param_rms.clamp(min=rms_eps) * (grad_sq / ((avg_grad_sq * conditioned_grad_sq) + 1.0e-20)).sqrt()
-                   alpha = -lr * (1-beta1) * scale
-                   delta.add_(conditioned_grad, alpha=alpha)
+                   if step % estimate_period == 0 or step in [50, 200, 400]:
+                       self._estimate(p, state, beta3, max_size,
+                                      stats_period, estimate_period,
+                                      eps, param_eps)
+                   ref_exp_avg_sq = state["ref_exp_avg_sq"]  # computed in self._estimate()
+                   # do ** 0.25, not ** 0.5, because we divide this into two factors:
+                   # one to be applied before the grad preconditioning, and one after.
+                   # This keeps the overall transformation on the gradient symmetric
+                   # (if viewed as a matrix of shape (p.numel(), p.numel())
+                   # Note: ref_exp_avg_sq had eps*eps added to it when we created it.
+
+                   grad_scale = (ref_exp_avg_sq / (exp_avg_sq/bias_correction2 + eps*eps)) ** 0.25
+
+                   #print(f"grad_scale mean = {grad_scale.mean()}, shape = {p.shape}")
+
+                   cur_grad = grad * grad_scale
+                   cur_grad = self._precondition_grad(cur_grad, state)
+                   cur_grad *= grad_scale
+
+                   if True: # testing
+                       # in principle, the cur_grad is supposed to have the same rms as params, on average.
+                       cur_grad_rms = (cur_grad**2).mean().sqrt()
+                       param_rms = (p**2).mean().sqrt()
+                       #print(f"cur_grad_rms={cur_grad_rms}, param_rms={param_rms}")
+
+                   if random.random() < 0.1:
+                       prod = (grad*cur_grad).mean()
+                       cos_angle = prod / ((grad**2).mean() * (cur_grad**2).mean()).sqrt()
+                       if random.random() < 0.01 or cos_angle < 0.0075:
+                           print(f"cos_angle = {cos_angle}, shape={grad.shape}")
+
+                   alpha = -lr * (1-beta1)
+                   delta.add_(cur_grad, alpha=alpha)
                else:
                    # p.numel() == 1
                    # Decay the second moment running average coefficient
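Aside: the "** 0.25, not ** 0.5" comment above can be illustrated in isolation. Applying the square root of the desired diagonal rescaling both before and after a symmetric preconditioner keeps the overall linear map on the gradient symmetric, whereas applying the full rescaling on one side only does not. A toy sketch with arbitrary small shapes (not the committed code):

    import torch

    n = 4
    P = torch.randn(n, n)
    P = P @ P.t()                          # symmetric preconditioner, like proj
    d = torch.rand(n) + 0.1                # desired per-coordinate rescaling
    s = d ** 0.5                           # analogous to grad_scale (the ** 0.25 factor)

    M_split = torch.diag(s) @ P @ torch.diag(s)   # scale before and after
    M_one_sided = torch.diag(d) @ P               # all the scaling on one side

    print(torch.allclose(M_split, M_split.t()))          # True: symmetric
    print(torch.allclose(M_one_sided, M_one_sided.t()))  # generally False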
@@ -301,18 +315,135 @@ class NeutralGradient(Optimizer):
        return loss


-   def _precondition_grad(self,
-                          p: Tensor,
-                          grad: Tensor,
-                          state: dict,
-                          beta3: float,
-                          max_size: int,
-                          stats_period: int,
-                          estimate_period: int,
-                          eps: float,
-                          rms_eps: float
+   def _accumulate_per_dim_stats(self,
+                                 grad: Tensor,
+                                 state: dict,
+                                 beta3: float,
+                                 eps: float) -> None:
+       """
+       Accumulate some stats for each dimension of the tensor, about the covariance of gradients.
+       """
+       ndim = grad.ndim
+       for dim in range(ndim):
+           size = grad.shape[dim]
+           if size == 1:
+               continue
+           grad_cov = state[f"grad_cov_{dim}"]
+           if grad_cov.ndim == 1:
+               # We are treating this dimension diagonally because it is too big.
+               other_dims = [ i for i in range(ndim) if i != dim]
+               grad_cov.mul_(beta3).add_((grad**2).mean(dim=other_dims))
+           else:
+               # full-covariance stats.
+               this_beta3 = self._get_this_beta3(beta3, grad.numel(), size)
+               grad_cov.mul_(this_beta3).add_(self._get_cov(grad, dim),
+                                              alpha=(1.0-this_beta3))
+
+   def _estimate(self,
+                 p: Tensor,
+                 state: dict,
+                 beta3: float,
+                 max_size: int,
+                 stats_period: int,
+                 estimate_period: int,
+                 eps: float,
+                 param_eps: float
                  ) -> Tensor:
       """
+       Estimate the per-dimension projections proj_0, proj_1 and so on, and
+       ref_exp_avg_sq.
+       """
+       kwargs = {'device':p.device, 'dtype':p.dtype}
+       ref_exp_avg_sq = torch.ones([1] * p.ndim, **kwargs)
+
+       ndim = p.ndim
+       step = state["step"]
+       assert step % stats_period == 0
+       norm_step = step // stats_period
+
+       def update_ref_exp_avg_sq(ref_exp_avg_sq, grad_cov_smoothed, dim):
+           """
+           Updates ref_exp_avg_sq, taking into account `grad_cov_smoothed` which is
+           a tensor of shape (p.shape[dim],) containing the smoothed diagonal covariance
+           of grads for dimension `dim`.
+           Returns new ref_exp_avg_sq
+           """
+           if dim > 0:
+               # We only want to "count" the overall scale of the gradients' variance
+               # (i.e. mean variance) once, so for all dims other than 0 we normalize
+               # this out.
+               grad_cov_smoothed = grad_cov_smoothed / grad_cov_smoothed.mean()
+           for _ in range(dim + 1, ndim):
+               grad_cov_smoothed = grad_cov_smoothed.unsqueeze(-1)
+           return ref_exp_avg_sq * grad_cov_smoothed
+
+       for dim in range(ndim):
+           size = p.shape[dim]
+           if size == 1:
+               continue
+           grad_cov = state[f"grad_cov_{dim}"]
+           proj = state[f"proj_{dim}"]
+
+           if grad_cov.ndim == 1:
+               # This dim is treated diagonally, because it is too large.
+               other_dims = [ i for i in range(ndim) if i != dim]
+               bias_correction3 = 1 - beta3 ** (norm_step + 1)
+               # _smoothed means we have the eps terms.
+               param_var_smoothed = (p**2).mean(dim=other_dims) + param_eps*param_eps
+               if bias_correction3 < 0.9999:
+                   grad_cov = grad_cov / bias_correction3
+               grad_var_smoothed = grad_cov + eps*eps
+               ref_exp_avg_sq = update_ref_exp_avg_sq(ref_exp_avg_sq,
+                                                      grad_var_smoothed,
+                                                      dim)
+               if dim != 0:
+                   # If dim != 0, make the grad_var_smoothed and
+                   # param_var_smoothed have the same mean, because when
+                   # preconditioning gradients, we only want to count the
+                   # overall change in scale once.
+                   grad_var_smoothed *= (param_var_smoothed.sum() / grad_var_smoothed.sum())
+               proj[:] = (param_var_smoothed / grad_var_smoothed).sqrt()
+           else:
+               param_cov_smoothed = self._estimate_and_smooth_param_cov(p, dim,
+                                                                        param_eps)
+               grad_cov_smoothed = self._smooth_grad_cov(p, grad_cov,
+                                                         eps, norm_step,
+                                                         beta3)
+               ref_exp_avg_sq = update_ref_exp_avg_sq(ref_exp_avg_sq,
+                                                      grad_cov_smoothed.diag(),
+                                                      dim)
+
+
+               if dim != 0:
+                   # If dim != 0, make the traces of grad_cov_smoothed and
+                   # param_cov_smoothed be the same, because when
+                   # preconditioning gradients, we only want to count the
+                   # overall change in scale once.
+                   grad_cov_smoothed *= (param_cov_smoothed.diag().sum() /
+                                         grad_cov_smoothed.diag().sum())
+
+               proj[:] = self._estimate_proj(grad_cov_smoothed,
+                                             param_cov_smoothed)
+
+       state["ref_exp_avg_sq"][:] = ref_exp_avg_sq
+
+   def _get_this_beta3(self, beta3: float, numel: int, size: int):
+       """
+       Get a discounting value beta3 (e.g. 0.99) that is used in computing moving
+       averages of stats. The idea is that the stats are accumulated from a tensor
+       x that has `numel` elements and `size == x.shape[dim]`, and we want the covariance
+       of shape (dim, dim). We may need to make beta3 closer to 1 if this means
+       we'll be getting only very rank-deficient stats each time.
+       """
+       rank_per_iter = numel // size  # maximum rank of each iteration's covaraince
+       safety_factor = 4.0  # should be > 1.0
+       grad_num_steps_needed = safety_factor * size / rank_per_iter
+       return max(beta3, 1 - 1. / grad_num_steps_needed)
+
+   def _precondition_grad(self,
+                          grad: Tensor,
+                          state: dict) -> Tensor:
+       """
       Multiply the grad by a positive-semidefinite matrix for each dimension, to
       try to make its covariance the same as that of the parameters (up to a
       scalar constant).
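Aside: to make the new _get_this_beta3 logic concrete, here is the same arithmetic restated as a standalone function with two worked examples; the tensor shapes are invented for illustration.

    def get_this_beta3(beta3, numel, size, safety_factor=4.0):
        # Each step contributes a covariance update of rank at most numel // size;
        # if that is much smaller than size, decay more slowly so enough outer
        # products accumulate before the stats are used.
        rank_per_iter = numel // size
        steps_needed = safety_factor * size / rank_per_iter
        return max(beta3, 1 - 1.0 / steps_needed)

    # A square 512x512 weight: full-rank stats arrive quickly, so beta3 is unchanged.
    print(get_this_beta3(0.99, numel=512 * 512, size=512))   # 0.99
    # A 512x4 weight, covariance over the size-512 dim: only rank-4 updates per step,
    # so the effective decay is pushed to about 1 - 1/512 ~= 0.998.
    print(get_this_beta3(0.99, numel=512 * 4, size=512))     # ~0.998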
@@ -329,67 +460,110 @@ class NeutralGradient(Optimizer):

       """
       cur_grad = grad
+      ndim = grad.ndim
       step = state["step"]
-      for dim in range(p.ndim):
-          size = p.shape[dim]
+      for dim in range(ndim):
+          size = grad.shape[dim]
           if size == 1:
               continue
-          grad_cov = state[f"grad_cov_{dim}"]
           proj = state[f"proj_{dim}"]

-          if size > max_size:
+          if proj.ndim == 1:
              # We are treating this dimension diagonally because it is too big.
-             if step % stats_period == 0 or step < 100:
-                 # for diagonal dims, both re-estimate stats and re-estimate
-                 # transform every `stats_period`, as the re-estimation is
-                 # easy.
-                 other_dims = [ i for i in range(p.ndim) if i != dim]
-                 grad_cov.mul_(beta3).add_((grad**2).mean(dim=other_dims))
-                 # estimate param_var, the variance of the parameters.
-                 param_var = (p**2).mean(dim=other_dims)
-
-                 factor = ((param_var + rms_eps*rms_eps) / (grad_cov + eps*eps)).sqrt()
-                 # normalize these factors to stop preconditioned-grad mean from getting
-                 # very large or small
-                 factor = factor / factor.mean()
-                 proj[:] = factor
-             proj = proj.reshape(size, *([1] * (p.ndim-1-dim)))
+             proj = proj.reshape(size, *([1] * (ndim-1-dim)))
              cur_grad = cur_grad * proj
          else:
-             # This dimension gets the full-covariance treatment.
-             grad_num_points = grad.numel() // size
-             grad_num_steps_needed = 4 * size / grad_num_points
-
-             if step % stats_period == 0 or step < 100:
-                 # normally use beta3 for the decay of the decaying-average
-                 # gradient stats, but we may use a value closer to 1 if we
-                 # have to estimate a high-dimensional gradient from very
-                 # low-rank covariance contributions.
-                 this_beta3 = max(beta3, 1 - 1. / grad_num_steps_needed)
-                 grad_cov.mul_(this_beta3).add_(self._get_cov(grad, dim),
-                                                alpha=(1.0-this_beta3))
-             if step % estimate_period == 0 or (step < estimate_period and step in [10,40]):
-                 # Estimate the parameter covariance on this dimension,
-                 # treating the other dimensions of the parameter as we would
-                 # frames. Smooth with the diagonal because there are not
-                 # many points to estimate the covariance from.
-                 param_cov = self._get_cov(p, dim)
-
-                 min_diag = 0.2
-                 if True: # Smooth parameter
-                     num_points = p.numel() // size
-                     # later we may be able to find a more principled formula for this.
-                     diag_smooth = max(min_diag, dim / (dim + num_points))
-                     diag = param_cov.diag()
-                     param_cov.mul_(1-diag_smooth).add_((diag*diag_smooth+ rms_eps*rms_eps).diag())
-                 if True:
-                     diag_smooth = min_diag  # This is likely far from optimal!
-                     diag = grad_cov.diag()
-                     grad_cov = grad_cov * (1-diag_smooth) + (diag*diag_smooth + eps*eps).diag()
-                 proj[:] = self._estimate_proj(grad_cov, param_cov)
              cur_grad = self._multiply_on_dim(cur_grad, proj, dim)
       return cur_grad

+  def _estimate_and_smooth_param_cov(self, p: Tensor, dim: int,
+                                     param_eps: float,
+                                     cond_eps: float = 1.0e-10,
+                                     min_diag_smooth: float = 0.2) -> Tensor:
+      """
+      Compute a smoothed version of a covariance matrix for one dimension of
+      a parameter tensor, estimated by treating the other
+      dimensions as a batch index.
+
+        p: The parameter tensor from which we are estimating the covariance
+        dim: The dimenion that we want the covariances for.
+        param_eps: A small epsilon value that represents the minimum root-mean-square
+              parameter value that we'll estimate.
+        cond_eps: An epsilon value that limits the condition number of the resulting
+              matrix. Caution: this applies to the variance, not the rms value,
+              so should be quite small.
+
+      Returns:
+         A non-singular covariance matrix of shape (p.shape[dim], p.shape[dim])
+      """
+      p = p.transpose(dim, -1)
+      p = p.reshape(-1, p.shape[-1])
+      (num_outer_products, size) = p.shape
+      param_cov = torch.matmul(p.t(), p) / p.shape[0]
+
+      # later we may be able to find a more principled formula for this.
+      diag_smooth = max(min_diag_smooth, size / (size + num_outer_products))
+      diag = param_cov.diag()
+      extra_diag = (diag * diag_smooth) + (diag.max() * cond_eps +
+                                           param_eps * param_eps)
+      param_cov.mul_(1-diag_smooth).add_(extra_diag.diag())
+      return param_cov
+
+  def _smooth_grad_cov(self,
+                       p: Tensor,
+                       grad_cov: Tensor,
+                       eps: float,
+                       norm_step: int,
+                       beta3: float,
+                       cond_eps: float = 1.0e-10,
+                       min_diag_smooth: float = 0.2) -> Tensor:
+      """
+      Compute a smoothed version of a covariance matrix for one dimension of
+      a gradient tensor. The covariance stats are supplied; this function
+      is responsible for making sure it is nonsingular and smoothed with its
+      diagonal in a reasonably smart way.
+
+        p: The parameter tensor from which we are estimating the covariance
+              of the gradient (over one of its dimensions); only needed for the shape
+        grad_cov: The moving-average covariance statistics, to be smoothed
+        eps: An epsilon value that represents a root-mean-square gradient
+              magnitude, used to avoid zero values
+        norm_step: The step count that reflects how many times we updated
+              grad_cov, we assume we have updated it (norm_step + 1) times; this
+              is just step divided by stats_perio.
+        beta3: The user-supplied beta value for decaying the gradient covariance
+              stats
+        cond_eps: An epsilon value that limits the condition number of the resulting
+              matrix. Caution: this applies to the variance, not the rms value,
+              so should be quite small.
+        min_diag_smooth: A minimum proportion by which we smooth the covariance
+              with the diagonal.
+
+      Returns:
+         A non-singular covariance matrix of gradients, of shape (p.shape[dim], p.shape[dim])
+      """
+      size = grad_cov.shape[0]
+      this_beta3 = self._get_this_beta3(beta3, p.numel(), size)
+      bias_correction3 = 1 - this_beta3 ** (norm_step + 1)
+      if bias_correction3 < 0.9999:
+          grad_cov = grad_cov / bias_correction3
+
+      rank_per_iter = p.numel() // size  # maximum rank of each iteration's covaraince
+      # the second part of the following formula roughly represents the number of
+      # frames that have a "large" weight in the stats.
+      num_iters_in_stats = min(norm_step + 1, 1.0 / (1 - beta3))
+      # grad_cov_rank is generally suposed to represent the number of
+      # individual outer products that are represented with significant weight
+      # in grad_cov
+      num_outer_products = rank_per_iter * num_iters_in_stats
+      diag_smooth = max(min_diag_smooth,
+                        size / (size + num_outer_products))
+      diag = grad_cov.diag()
+      extra_diag = (diag * diag_smooth) + (diag.max() * cond_eps +
+                                           eps * eps)
+      grad_cov.mul_(1-diag_smooth).add_(extra_diag.diag())
+      return grad_cov
+
   def _get_cov(self, x: Tensor, dim: int) -> Tensor:
       """
       Computes the covariance of x over dimension `dim`, returning something
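Aside: the two new smoothing helpers follow the same pattern: blend a possibly rank-deficient covariance with its own diagonal plus small epsilon terms so that it becomes safely invertible. A self-contained sketch of that pattern with invented sizes (not the committed code):

    import torch

    size, num_outer_products = 8, 3
    x = torch.randn(num_outer_products, size)
    cov = x.t() @ x / num_outer_products          # rank <= 3, hence singular

    min_diag_smooth, cond_eps, eps = 0.2, 1.0e-10, 1.0e-8
    diag_smooth = max(min_diag_smooth, size / (size + num_outer_products))
    diag = cov.diag()
    extra_diag = diag * diag_smooth + diag.max() * cond_eps + eps * eps
    cov = cov * (1 - diag_smooth) + extra_diag.diag()

    print(torch.linalg.matrix_rank(cov).item())   # 8: now full rank / invertible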
@@ -409,6 +583,18 @@ class NeutralGradient(Optimizer):
       return torch.matmul(x.transpose(-1, dim),
                           proj).transpose(-1, dim)

+  def _multiply_on_dim_diag(self, x: Tensor, proj: Tensor, dim: int) -> Tensor:
+      """
+      Matrix-multiplies x by `proj` on x's dimension numbered `dim`
+      Args:
+           x: Tensor to multiply, of arbitrary shape
+          proj: Symmetric matrix to multiply x by, of shape (x.shape[dim], x.shape[dim])
+      """
+      return torch.matmul(x.transpose(-1, dim),
+                          proj).transpose(-1, dim)
+
+
+
   def _estimate_proj(self, grad_cov: Tensor, param_cov: Tensor) -> Tensor:
       """
       Return a symmetric positive definite matrix P such that
@@ -425,11 +611,12 @@ class NeutralGradient(Optimizer):
         param_cov: Covariance of parameters, a symmetric-positive-definite
              matrix [except for roundoff!] of shape (size, size)
       Returns:
-           A matrix P such that (alpha P grad_cov P^T == param_cov) for some
-           real alpha, and such that P.diag().mean() == 1.
+           A matrix P such that (alpha P grad_cov P^T == param_cov)
       """
-      grad_cov = grad_cov + grad_cov.t()
-      C = param_cov + param_cov.t() # we will normalize away any scalar factor.
+      # symmetrize grad_cov and param_cov. The projection P only cares about the
+      # relative sizes, so doubling both of them does not matter.
+      G = grad_cov + grad_cov.t()
+      C = param_cov + param_cov.t()
       U, S, V = C.svd()

       # Using G == grad_cov and C == param_cov, and not writing transpose because
@@ -451,7 +638,7 @@ class NeutralGradient(Optimizer):

       # because C is symmetric, C == U S U^T, we can ignore V.
       Q = (U * S.sqrt()).t()
-      X = torch.matmul(Q, torch.matmul(grad_cov, Q.t()))
+      X = torch.matmul(Q, torch.matmul(G, Q.t()))

       U2, S2, _ = X.svd()
       # Because X^{-0.5} = U2 S2^{-0.5} U2^T, we have
@@ -464,22 +651,16 @@ class NeutralGradient(Optimizer):

       P = torch.matmul(Y, Y.t())

-      P = P * (1.0 / P.diag().mean())  # normalize size of P.
-
       if random.random() < 0.1:
-          #U, S, V = P.svd()
-          #if random.random() < 0.1:
-          #    print(f"Min,max eig of P: {S.min().item()},{S.max().item()}")
-
          # TODO: remove this testing code.
          assert (P - P.t()).abs().mean() < 0.01  # make sure symmetric.
          # testing... note, this is only true modulo "eps"
-         C_check = torch.matmul(torch.matmul(P, grad_cov), P)
-         # C_check should equal C, up to a constant. First normalize both
-         C_diff = (C_check / C_check.diag().mean()) - (C / C.diag().mean())
-         # actually, roundoff can cause significant differences.
-         assert C_diff.abs().mean() < 0.001
+         C_check = torch.matmul(torch.matmul(P, G), P)
+         # C_check should equal C
+         C_diff = C_check - C
+         # Roundoff can cause significant differences, so use a fairly large
+         # threshold of 0.001. We may increase this later or even remove the check.
+         assert C_diff.abs().mean() < 0.01 * C.diag().mean()

       return P

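Aside: for readers checking the revised assertion, here is a standalone sketch of the property _estimate_proj is after: a symmetric positive definite P with P @ G @ P equal to C, now without the extra scalar normalization of P. The construction below is one known closed form, shown with random SPD matrices in float64; it is only meant to illustrate the check, not to reproduce the committed implementation.

    import torch

    size = 6

    def random_spd(n):
        a = torch.randn(n, n, dtype=torch.float64)
        return a @ a.t() + n * torch.eye(n, dtype=torch.float64)

    G, C = random_spd(size), random_spd(size)

    # With Q satisfying Q.t() @ Q == C, the matrix P = Q.t() (Q G Q.t())^{-1/2} Q
    # solves P G P = C and is symmetric positive definite.
    U, S, _ = C.svd()
    Q = (U * S.sqrt()).t()                    # Q.t() @ Q == C (C is symmetric PD)
    X = Q @ G @ Q.t()
    U2, S2, _ = X.svd()
    X_inv_sqrt = (U2 * (S2 ** -0.5)) @ U2.t()
    P = Q.t() @ X_inv_sqrt @ Q

    C_check = P @ G @ P
    print((C_check - C).abs().mean().item() < 0.01 * C.diag().mean().item())  # True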
@@ -1112,6 +1293,7 @@ def _test_eve_cain():
    device = torch.device('cuda')
    dtype = torch.float32

+   fix_random_seed(42)
    # these input_magnitudes and output_magnitudes are to test that
    # Abel is working as we expect and is able to adjust scales of
    # different dims differently.