Use exp_avg_sq as ref_exp_avg_sq

Daniel Povey 2022-06-15 14:03:26 +08:00
parent 860322bf30
commit f9bb745cf3


@@ -49,7 +49,7 @@ class NeutralGradient(Optimizer):
          param_eps: An epsilon on the rms value of the parameter, such that when the parameter
               gets smaller than this we'll start using a fixed learning rate, not
               decreasing with the parameter norm.
-         cond_eps: An epsilon that limits the condition number of gradient and parameter
+         rel_eps: An epsilon that limits the condition number of gradient and parameter
               covariance matrices
          param_max: To prevent parameter tensors getting too large, we will clip elements to
               -param_max..param_max.
@ -69,12 +69,12 @@ class NeutralGradient(Optimizer):
scale_speed=0.1, scale_speed=0.1,
eps=1e-8, eps=1e-8,
param_eps=1.0e-05, param_eps=1.0e-05,
cond_eps=1.0e-10, rel_eps=1.0e-10,
param_max=10.0, param_max=10.0,
min_diag_smooth=1.0, min_diag_smooth=0.2,
max_size=1023, max_size=1023,
stats_period=1, stats_period=1,
estimate_period=200, estimate_period=50,
): ):
if not 0.0 <= lr: if not 0.0 <= lr:
@ -111,7 +111,7 @@ class NeutralGradient(Optimizer):
eps=eps, eps=eps,
scale_speed=scale_speed, scale_speed=scale_speed,
param_eps=param_eps, param_eps=param_eps,
cond_eps=cond_eps, rel_eps=rel_eps,
min_diag_smooth=min_diag_smooth, min_diag_smooth=min_diag_smooth,
param_max=param_max, param_max=param_max,
max_size=max_size, max_size=max_size,
@ -143,7 +143,7 @@ class NeutralGradient(Optimizer):
scale_speed = group["scale_speed"] scale_speed = group["scale_speed"]
param_eps = group["param_eps"] param_eps = group["param_eps"]
eps = group["eps"] eps = group["eps"]
cond_eps = group["cond_eps"] rel_eps = group["rel_eps"]
min_diag_smooth = group["min_diag_smooth"] min_diag_smooth = group["min_diag_smooth"]
param_max = group["param_max"] param_max = group["param_max"]
max_size = group["max_size"] max_size = group["max_size"]
@@ -267,11 +267,11 @@ class NeutralGradient(Optimizer):
                     self._estimate(p, state, beta3, max_size,
                                    stats_period, estimate_period,
                                    eps, param_eps,
-                                   cond_eps, min_diag_smooth)
+                                   rel_eps, min_diag_smooth)
                 # TEMP!! Override the setting inside _estimate.
-                #state["ref_exp_avg_sq"][:] = ((exp_avg_sq/bias_correction2 + eps*eps) *
-                # state["ref_exp_avg_sq"]).sqrt()
+                state["ref_exp_avg_sq"][:] = (exp_avg_sq/bias_correction2 + eps*eps)
                 ref_exp_avg_sq = state["ref_exp_avg_sq"] # computed in self._estimate()
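
Note (illustrative, not code from this commit): the newly uncommented line seeds ref_exp_avg_sq from the bias-corrected Adam-style second moment of the gradient plus eps*eps, which is what the commit title refers to. A minimal standalone sketch of that quantity, with invented sizes and hyperparameters:

    import torch

    # Invented values; in the optimizer these come from the param group and state.
    beta2, eps, num_steps, numel = 0.98, 1e-8, 200, 100

    exp_avg_sq = torch.zeros(numel)
    for step in range(1, num_steps + 1):
        grad = torch.randn(numel)
        # running average of grad**2, as in Adam
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

    bias_correction2 = 1 - beta2 ** num_steps
    # mirrors: state["ref_exp_avg_sq"][:] = (exp_avg_sq/bias_correction2 + eps*eps)
    ref_exp_avg_sq = exp_avg_sq / bias_correction2 + eps * eps
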
@@ -288,7 +288,7 @@ class NeutralGradient(Optimizer):
                 cur_grad = grad
                 cur_grad = cur_grad * grad_scale
-                cur_grad = self._precondition_grad(cur_grad, state)
+                cur_grad = self._multiply_grad(cur_grad, state)
                 cur_grad *= grad_scale
                 if random.random() < 0.004:
@@ -348,11 +348,16 @@ class NeutralGradient(Optimizer):
                 grad_cov = state[f"grad_cov_{dim}"]
                 if grad_cov.ndim == 1:
                     # We are treating this dimension diagonally because it is too big.
+                    # note: other_dims is nonempty because we know ndim != 1 when we get
+                    # here. dim=[] to torch mean() does not work as you would expect.
                     other_dims = [ i for i in range(ndim) if i != dim]
-                    grad_cov.mul_(beta3).add_((grad**2).mean(dim=other_dims))
+                    grad_cov.mul_(beta3).add_((grad**2).mean(dim=other_dims),
+                                              alpha=(1.0-beta3))
                 else:
                     # full-covariance stats.
                     this_beta3 = self._get_this_beta3(beta3, grad.numel(), size)
+                    if this_beta3 != beta3 and random.random() < 0.01:
+                        print("this_beta3=", this_beta3)
                     grad_cov.mul_(this_beta3).add_(self._get_cov(grad, dim),
                                                    alpha=(1.0-this_beta3))
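
Note (toy sketch, not from the commit): the alpha=(1.0-beta3) added in the diagonal branch turns the accumulator into a proper exponential moving average of the per-step statistics, matching the full-covariance branch; without it the accumulated value settles near stats / (1 - beta3) instead of near the per-step value.

    import torch

    beta3 = 0.99
    stats = torch.full((4,), 2.0)   # pretend (grad**2).mean(dim=other_dims) is constant
    ema = torch.zeros(4)            # new form, with alpha
    raw = torch.zeros(4)            # old form, without alpha
    for _ in range(1000):
        ema.mul_(beta3).add_(stats, alpha=1.0 - beta3)
        raw.mul_(beta3).add_(stats)
    print(ema[0].item())   # ~2.0, the per-step value
    print(raw[0].item())   # ~200.0, i.e. 2.0 / (1 - beta3)
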
@@ -365,7 +370,7 @@ class NeutralGradient(Optimizer):
                   estimate_period: int,
                   eps: float,
                   param_eps: float,
-                  cond_eps: float,
+                  rel_eps: float,
                   min_diag_smooth: float
                   ) -> Tensor:
         """
@ -379,6 +384,7 @@ class NeutralGradient(Optimizer):
step = state["step"] step = state["step"]
assert step % stats_period == 0 assert step % stats_period == 0
norm_step = step // stats_period norm_step = step // stats_period
param_rel_eps = 1.0e-04
scale_change = True scale_change = True
@@ -411,13 +417,25 @@ class NeutralGradient(Optimizer):
                 other_dims = [ i for i in range(ndim) if i != dim]
                 bias_correction3 = 1 - beta3 ** (norm_step + 1)
                 # _smoothed means we have the eps terms.
-                param_var_smoothed = (p**2).mean(dim=other_dims) + param_eps*param_eps
+                param_var_smoothed = (p**2).mean(dim=other_dims)
+                param_var_smoothed.add_(param_eps*param_eps + param_rel_eps * param_var_smoothed.mean())
                 if bias_correction3 < 0.9999:
                     grad_cov = grad_cov / bias_correction3
-                grad_var_smoothed = grad_cov + eps*eps
+                grad_var_smoothed = grad_cov + (eps*eps + rel_eps * grad_cov.mean())
                 ref_exp_avg_sq = update_ref_exp_avg_sq(ref_exp_avg_sq,
                                                        grad_var_smoothed,
                                                        dim, scale_change)
+                if True:
+                    def check_close(a, b):
+                        assert ((a-b).abs().sum() < 0.01 * (a.abs()+b.abs()).sum())
+                    param_cov_smoothed = self._estimate_and_smooth_param_cov(p, dim,
+                                                                             param_eps,
+                                                                             rel_eps=param_rel_eps,
+                                                                             min_diag_smooth=min_diag_smooth)
+                    check_close(param_var_smoothed, param_cov_smoothed.diag())
                 if scale_change:
                     scale_change = False # only use the scale change once
                 else:
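
Note (illustrative numbers only): this hunk replaces the purely absolute floors (eps*eps, param_eps*param_eps) with an absolute plus relative floor such as rel_eps * grad_cov.mean(), so no element of the smoothed variance can be tiny relative to the typical (mean) value, which is what bounds the condition number. A tiny sketch of the effect:

    import torch

    eps, rel_eps = 1e-8, 1e-10          # values taken from the defaults above
    grad_cov = torch.rand(64) ** 8      # diagonal variances, some nearly zero
    old_smoothed = grad_cov + eps * eps                                # absolute floor only
    new_smoothed = grad_cov + (eps * eps + rel_eps * grad_cov.mean())  # absolute + relative
    print(old_smoothed.min() / old_smoothed.mean())   # small: floor is unrelated to the mean
    print(new_smoothed.min() / new_smoothed.mean())   # at least about rel_eps
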
@@ -427,15 +445,16 @@ class NeutralGradient(Optimizer):
                     # to count the overall change in scale once.
                     grad_var_smoothed *= (param_var_smoothed.sum() / grad_var_smoothed.sum())
                 proj[:] = (param_var_smoothed / grad_var_smoothed).sqrt()
             else:
                 param_cov_smoothed = self._estimate_and_smooth_param_cov(p, dim,
                                                                          param_eps,
-                                                                         cond_eps=1.0e-04,
+                                                                         rel_eps=param_rel_eps,
                                                                          min_diag_smooth=min_diag_smooth)
                 grad_cov_smoothed = self._smooth_grad_cov(p, grad_cov,
-                                                          eps, norm_step,
+                                                          eps, norm_step+100000, # TEMp
                                                           beta3,
-                                                          cond_eps=cond_eps,
+                                                          rel_eps=rel_eps,
                                                           min_diag_smooth=min_diag_smooth)
                 ref_exp_avg_sq = update_ref_exp_avg_sq(ref_exp_avg_sq,
                                                        grad_cov_smoothed.diag(),
@ -473,9 +492,9 @@ class NeutralGradient(Optimizer):
return ans return ans
def _precondition_grad(self, def _multiply_grad(self,
grad: Tensor, grad: Tensor,
state: dict) -> Tensor: state: dict) -> Tensor:
""" """
Multiply the grad by a positive-semidefinite matrix for each dimension, to Multiply the grad by a positive-semidefinite matrix for each dimension, to
try to make its covariance the same as that of the parameters (up to a try to make its covariance the same as that of the parameters (up to a
@ -511,7 +530,7 @@ class NeutralGradient(Optimizer):
def _estimate_and_smooth_param_cov(self, p: Tensor, dim: int, def _estimate_and_smooth_param_cov(self, p: Tensor, dim: int,
param_eps: float, param_eps: float,
cond_eps: float = 1.0e-10, rel_eps: float = 1.0e-10,
min_diag_smooth: float = 0.2) -> Tensor: min_diag_smooth: float = 0.2) -> Tensor:
""" """
Compute a smoothed version of a covariance matrix for one dimension of Compute a smoothed version of a covariance matrix for one dimension of
@ -522,7 +541,7 @@ class NeutralGradient(Optimizer):
dim: The dimenion that we want the covariances for. dim: The dimenion that we want the covariances for.
param_eps: A small epsilon value that represents the minimum root-mean-square param_eps: A small epsilon value that represents the minimum root-mean-square
parameter value that we'll estimate. parameter value that we'll estimate.
cond_eps: An epsilon value that limits the condition number of the resulting rel_eps: An epsilon value that limits the condition number of the resulting
matrix. Caution: this applies to the variance, not the rms value, matrix. Caution: this applies to the variance, not the rms value,
so should be quite small. so should be quite small.
@ -540,7 +559,7 @@ class NeutralGradient(Optimizer):
#diag_smooth = min_diag_smooth #diag_smooth = min_diag_smooth
diag_smooth = 0.4 diag_smooth = 0.4
diag = param_cov.diag() diag = param_cov.diag()
extra_diag = (diag * diag_smooth) + (diag.max() * cond_eps + extra_diag = (diag * diag_smooth) + (diag.mean() * rel_eps +
param_eps * param_eps) param_eps * param_eps)
param_cov.mul_(1-diag_smooth).add_(extra_diag.diag()) param_cov.mul_(1-diag_smooth).add_(extra_diag.diag())
return param_cov return param_cov
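
Note (a hedged sketch of the smoothing pattern above, with invented shapes): the covariance is shrunk toward an inflated copy of its own diagonal, and the epsilon term now scales with diag.mean() (via rel_eps) instead of diag.max() (via cond_eps), i.e. the floor is relative to the typical diagonal value rather than the largest one.

    import torch

    param_eps, rel_eps, diag_smooth = 1.0e-05, 1.0e-10, 0.4   # values mirror the code above

    x = torch.randn(200, 16)
    param_cov = (x.t() @ x) / x.shape[0]    # stand-in for the parameter covariance
    diag = param_cov.diag()
    extra_diag = (diag * diag_smooth) + (diag.mean() * rel_eps + param_eps * param_eps)
    param_cov = param_cov * (1 - diag_smooth) + extra_diag.diag()
    # Net effect: off-diagonal entries are scaled by (1 - diag_smooth), while each
    # diagonal entry keeps its value plus a small positive floor, so the result
    # stays symmetric and strictly positive definite.
    print(torch.linalg.eigvalsh(param_cov).min() > 0)
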
@ -551,7 +570,7 @@ class NeutralGradient(Optimizer):
eps: float, eps: float,
norm_step: int, norm_step: int,
beta3: float, beta3: float,
cond_eps: float = 1.0e-10, rel_eps: float = 1.0e-10,
min_diag_smooth: float = 0.2) -> Tensor: min_diag_smooth: float = 0.2) -> Tensor:
""" """
Compute a smoothed version of a covariance matrix for one dimension of Compute a smoothed version of a covariance matrix for one dimension of
@ -569,7 +588,7 @@ class NeutralGradient(Optimizer):
is just step divided by stats_perio. is just step divided by stats_perio.
beta3: The user-supplied beta value for decaying the gradient covariance beta3: The user-supplied beta value for decaying the gradient covariance
stats stats
cond_eps: An epsilon value that limits the condition number of the resulting rel_eps: An epsilon value that limits the condition number of the resulting
matrix. Caution: this applies to the variance, not the rms value, matrix. Caution: this applies to the variance, not the rms value,
so should be quite small. so should be quite small.
min_diag_smooth: A minimum proportion by which we smooth the covariance min_diag_smooth: A minimum proportion by which we smooth the covariance
@ -584,7 +603,7 @@ class NeutralGradient(Optimizer):
if bias_correction3 < 0.9999: if bias_correction3 < 0.9999:
grad_cov = grad_cov / bias_correction3 grad_cov = grad_cov / bias_correction3
rank_per_iter = p.numel() // size # maximum rank of each iteration's covaraince rank_per_iter = p.numel() // size # maximum rank of each iteration's covariance
# the second part of the following formula roughly represents the number of # the second part of the following formula roughly represents the number of
# frames that have a "large" weight in the stats. # frames that have a "large" weight in the stats.
num_iters_in_stats = min(norm_step + 1, 1.0 / (1 - beta3)) num_iters_in_stats = min(norm_step + 1, 1.0 / (1 - beta3))
@ -598,8 +617,8 @@ class NeutralGradient(Optimizer):
print(f"grad diag_smooth = {diag_smooth}, shape={p.shape}") print(f"grad diag_smooth = {diag_smooth}, shape={p.shape}")
diag = grad_cov.diag() diag = grad_cov.diag()
extra_diag = (diag * diag_smooth) + (diag.max() * cond_eps + extra_diag = (diag * diag_smooth).add_(diag.mean() * rel_eps +
eps * eps) eps * eps)
grad_cov = (grad_cov * (1-diag_smooth)).add_(extra_diag.diag()) grad_cov = (grad_cov * (1-diag_smooth)).add_(extra_diag.diag())
return grad_cov return grad_cov