mirror of https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00

Some code improvements...

This commit is contained in:
parent f9bb745cf3
commit 94e48ff07e
@@ -49,8 +49,8 @@ class NeutralGradient(Optimizer):
           param_eps: An epsilon on the rms value of the parameter, such that when the parameter
                      gets smaller than this we'll start using a fixed learning rate, not
                      decreasing with the parameter norm.
-            rel_eps: An epsilon that limits the condition number of gradient and parameter
-                     covariance matrices
+       grad_rel_eps: An epsilon that limits the condition number of gradient covariance matrices
+      param_rel_eps: An epsilon that limits the condition number of parameter covariance matrices
           param_max: To prevent parameter tensors getting too large, we will clip elements to
                      -param_max..param_max.
            max_size: We will only use the full-covariance (not-diagonal) update for
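As a rough illustration of what the relative epsilons described above do: adding a small multiple of the mean diagonal to a covariance matrix bounds its condition number. The sketch below is illustrative only; the function name and arguments are not part of the optimizer's code.

import torch

def smooth_cov(cov: torch.Tensor, rel_eps: float, abs_eps: float) -> torch.Tensor:
    # Adding rel_eps * mean(diag) to every diagonal entry keeps the ratio of the
    # largest to smallest eigenvalue at roughly dim / rel_eps or better;
    # abs_eps**2 guards against an all-zero covariance.
    size = cov.shape[0]
    return cov + torch.eye(size, dtype=cov.dtype) * (rel_eps * cov.diag().mean() + abs_eps * abs_eps)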
@@ -67,11 +67,13 @@ class NeutralGradient(Optimizer):
             lr=1e-2,
             betas=(0.9, 0.98, 0.98),
             scale_speed=0.1,
-            eps=1e-8,
+            grad_eps=1e-8,
             param_eps=1.0e-05,
-            rel_eps=1.0e-10,
+            grad_rel_eps=1.0e-10,
+            param_rel_eps=1.0e-04,
             param_max=10.0,
-            min_diag_smooth=0.2,
+            grad_diag_smooth=0.2,
+            param_diag_smooth=0.4,
             max_size=1023,
             stats_period=1,
             estimate_period=50,
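For orientation, this is how the optimizer might be constructed after this change; the import path and the toy model are assumptions, and the keyword names simply follow the new signature above.

import torch
from optim import NeutralGradient  # assumed location of this class within the recipe

model = torch.nn.Linear(256, 256)  # placeholder model
optimizer = NeutralGradient(
    model.parameters(),
    lr=1e-2,
    grad_eps=1e-8,           # previously `eps`
    grad_rel_eps=1.0e-10,    # previously `rel_eps`
    param_rel_eps=1.0e-04,   # new: separate condition-number floor for parameter covariances
    grad_diag_smooth=0.2,    # previously `min_diag_smooth`
    param_diag_smooth=0.4,   # new: heavier diagonal smoothing for parameter covariances
)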
@@ -79,8 +81,8 @@ class NeutralGradient(Optimizer):

         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
-        if not 0.0 <= eps:
-            raise ValueError("Invalid epsilon value: {}".format(eps))
+        if not 0.0 <= grad_eps:
+            raise ValueError("Invalid grad_eps value: {}".format(grad_eps))
         if not 0.0 <= betas[0] < 1.0:
             raise ValueError(
                 "Invalid beta parameter at index 0: {}".format(betas[0])
@@ -108,11 +110,13 @@ class NeutralGradient(Optimizer):
         defaults = dict(
             lr=lr,
             betas=betas,
-            eps=eps,
             scale_speed=scale_speed,
+            grad_eps=grad_eps,
             param_eps=param_eps,
-            rel_eps=rel_eps,
-            min_diag_smooth=min_diag_smooth,
+            grad_rel_eps=grad_rel_eps,
+            param_rel_eps=param_rel_eps,
+            grad_diag_smooth=grad_diag_smooth,
+            param_diag_smooth=param_diag_smooth,
             param_max=param_max,
             max_size=max_size,
             stats_period=stats_period,
@@ -142,9 +146,11 @@ class NeutralGradient(Optimizer):
             lr = group["lr"]
             scale_speed = group["scale_speed"]
             param_eps = group["param_eps"]
-            eps = group["eps"]
-            rel_eps = group["rel_eps"]
-            min_diag_smooth = group["min_diag_smooth"]
+            grad_eps = group["grad_eps"]
+            grad_rel_eps = group["grad_rel_eps"]
+            param_rel_eps = group["param_rel_eps"]
+            grad_diag_smooth = group["grad_diag_smooth"]
+            param_diag_smooth = group["param_diag_smooth"]
             param_max = group["param_max"]
             max_size = group["max_size"]
             stats_period = group["stats_period"]
@@ -229,7 +235,7 @@ class NeutralGradient(Optimizer):
                     # anything here. May fix this at some point.
                     scale_bias_correction2 = 1 - beta2 ** (step + 1)

-                    scale_denom = (scale_exp_avg_sq.sqrt()).add_(eps)
+                    scale_denom = (scale_exp_avg_sq.sqrt()).add_(grad_eps)

                     scale_alpha = -lr * (scale_bias_correction2 ** 0.5) * scale_speed * (1-beta1)

@@ -253,7 +259,7 @@ class NeutralGradient(Optimizer):
                     if step % 10 == 0:
                         param_rms.copy_((p**2).mean().sqrt())

-                    grad_rms = (exp_avg_sq.sqrt()).add_(eps)
+                    grad_rms = (exp_avg_sq.sqrt()).add_(grad_eps)
                     this_delta = grad / grad_rms
                     alpha = (-(1-beta1) * lr * bias_correction2) * param_rms.clamp(min=param_eps)
                     delta.add_(this_delta, alpha=alpha)
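Ignoring the momentum and moving-average bookkeeping, the diagonal branch above amounts to: normalize the gradient by its RMS, then scale the step by the parameter's own RMS floored at `param_eps`. A standalone sketch under those simplifications, not the optimizer's exact code:

import torch

def diagonal_step_sketch(p: torch.Tensor, grad: torch.Tensor, exp_avg_sq: torch.Tensor,
                         lr: float, grad_eps: float, param_eps: float) -> torch.Tensor:
    # Normalize the gradient to roughly unit RMS, then make the step proportional to the
    # parameter's RMS (floored at param_eps), so the relative change per step is about lr.
    param_rms = (p ** 2).mean().sqrt()
    grad_rms = exp_avg_sq.sqrt() + grad_eps
    return -lr * param_rms.clamp(min=param_eps) * (grad / grad_rms)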
@@ -261,16 +267,16 @@ class NeutralGradient(Optimizer):
                     # The full update.
                     # First, if needed, accumulate per-dimension stats from the grad
                     if step % stats_period == 0:
-                        self._accumulate_per_dim_stats(grad, state, beta3, eps)
+                        self._accumulate_per_dim_stats(grad, state, beta3, grad_eps)

                     if step % estimate_period == 0 or step in [25, 50, 200, 400]:
                         self._estimate(p, state, beta3, max_size,
                                        stats_period, estimate_period,
-                                       eps, param_eps,
-                                       rel_eps, min_diag_smooth)
+                                       grad_eps, param_eps,
+                                       grad_rel_eps, param_rel_eps,
+                                       grad_diag_smooth, param_diag_smooth)

                         # TEMP!! Override the setting inside _estimate.
-                        state["ref_exp_avg_sq"][:] = (exp_avg_sq/bias_correction2 + eps*eps)
+                        state["ref_exp_avg_sq"][:] = (exp_avg_sq/bias_correction2 + grad_eps*grad_eps)


                     ref_exp_avg_sq = state["ref_exp_avg_sq"] # computed in self._estimate()
@@ -281,7 +287,8 @@ class NeutralGradient(Optimizer):
                     # (if viewed as a matrix of shape (p.numel(), p.numel())
                     # Note: ref_exp_avg_sq had eps*eps added to it when we created it.

-                    grad_scale = (ref_exp_avg_sq / (exp_avg_sq/bias_correction2 + eps*eps)) ** 0.25
+                    grad_scale = (ref_exp_avg_sq / (exp_avg_sq/bias_correction2 +
+                                                    grad_eps*grad_eps)) ** 0.25

                     #if random.random() < 0.02:
                     #    print(f"grad_scale mean = {grad_scale.mean().item():.43}, shape = {p.shape}")
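What this expression computes, read in isolation: grad_scale is the fourth root of the ratio between a reference second moment and the current bias-corrected one, i.e. the square root of a ratio of RMS values, so the gradient is only partially renormalized toward the reference statistics. A hedged standalone sketch:

import torch

def grad_scale_sketch(ref_exp_avg_sq: torch.Tensor, exp_avg_sq: torch.Tensor,
                      bias_correction2: float, grad_eps: float) -> torch.Tensor:
    # The 0.25 power makes this a half-strength normalization: a ratio of variances
    # to the 1/4 power equals a ratio of RMS values to the 1/2 power.
    return (ref_exp_avg_sq / (exp_avg_sq / bias_correction2 + grad_eps * grad_eps)) ** 0.25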
@@ -368,14 +375,15 @@ class NeutralGradient(Optimizer):
                   max_size: int,
                   stats_period: int,
                   estimate_period: int,
-                  eps: float,
+                  grad_eps: float,
                   param_eps: float,
-                  rel_eps: float,
-                  min_diag_smooth: float
-                  ) -> Tensor:
+                  grad_rel_eps: float,
+                  param_rel_eps: float,
+                  grad_diag_smooth: float,
+                  param_diag_smooth: float
+                  ) -> None:
         """
-        Estimate the per-dimension projections proj_0, proj_1 and so on, and
-        ref_exp_avg_sq.
+        Estimate the per-dimension projections proj_0, proj_1 and so on
         """
         kwargs = {'device':p.device, 'dtype':p.dtype}
         ref_exp_avg_sq = torch.ones([1] * p.ndim, **kwargs)
@@ -384,27 +392,9 @@ class NeutralGradient(Optimizer):
         step = state["step"]
         assert step % stats_period == 0
         norm_step = step // stats_period
-        param_rel_eps = 1.0e-04

         scale_change = True

-        def update_ref_exp_avg_sq(ref_exp_avg_sq, grad_cov_smoothed, dim,
-                                  scale_change):
-            """
-            Updates ref_exp_avg_sq, taking into account `grad_cov_smoothed` which is
-            a tensor of shape (p.shape[dim],) containing the smoothed diagonal covariance
-            of grads for dimension `dim`.
-            Returns new ref_exp_avg_sq
-            """
-            if not scale_change:
-                # We only want to "count" the overall scale of the gradients' variance
-                # (i.e. mean variance) once, so for all dims other than 0 we normalize
-                # this out.
-                grad_cov_smoothed = grad_cov_smoothed / grad_cov_smoothed.mean()
-            for _ in range(dim + 1, ndim):
-                grad_cov_smoothed = grad_cov_smoothed.unsqueeze(-1)
-            return ref_exp_avg_sq * grad_cov_smoothed
-
         for dim in range(ndim):
             size = p.shape[dim]
             if size == 1:
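For context on the helper removed in this hunk: it accumulated a separable, per-dimension model of the gradient's second moment by broadcasting each dimension's smoothed variance across the remaining axes. A standalone sketch of that broadcasting pattern, with illustrative names (this is not the code that replaces it):

import torch

def broadcast_per_dim_variance(ref_exp_avg_sq: torch.Tensor,
                               var_along_dim: torch.Tensor,
                               dim: int, ndim: int) -> torch.Tensor:
    # var_along_dim has shape (p.shape[dim],); reshape it to (1, ..., size, ..., 1)
    # so that multiplying it in builds up a factored, per-dimension variance model.
    shape = [1] * ndim
    shape[dim] = var_along_dim.numel()
    return ref_exp_avg_sq * var_along_dim.reshape(shape)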
@@ -421,20 +411,8 @@ class NeutralGradient(Optimizer):
                 param_var_smoothed.add_(param_eps*param_eps + param_rel_eps * param_var_smoothed.mean())
                 if bias_correction3 < 0.9999:
                     grad_cov = grad_cov / bias_correction3
-                grad_var_smoothed = grad_cov + (eps*eps + rel_eps * grad_cov.mean())
-                ref_exp_avg_sq = update_ref_exp_avg_sq(ref_exp_avg_sq,
-                                                       grad_var_smoothed,
-                                                       dim, scale_change)
-
-                if True:
-                    def check_close(a, b):
-                        assert ((a-b).abs().sum() < 0.01 * (a.abs()+b.abs()).sum())
-                    param_cov_smoothed = self._estimate_and_smooth_param_cov(p, dim,
-                                                                             param_eps,
-                                                                             rel_eps=param_rel_eps,
-                                                                             min_diag_smooth=min_diag_smooth)
-                    check_close(param_var_smoothed, param_cov_smoothed.diag())
-
+                grad_var_smoothed = grad_cov + (grad_eps*grad_eps +
+                                                grad_rel_eps * grad_cov.mean())

                 if scale_change:
                     scale_change = False  # only use the scale change once
@@ -449,16 +427,13 @@ class NeutralGradient(Optimizer):
             else:
                 param_cov_smoothed = self._estimate_and_smooth_param_cov(p, dim,
                                                                          param_eps,
-                                                                         rel_eps=param_rel_eps,
-                                                                         min_diag_smooth=min_diag_smooth)
+                                                                         param_rel_eps=param_rel_eps,
+                                                                         param_diag_smooth=param_diag_smooth)
                 grad_cov_smoothed = self._smooth_grad_cov(p, grad_cov,
-                                                          eps, norm_step+100000, # TEMp
+                                                          grad_eps, norm_step,
                                                           beta3,
-                                                          rel_eps=rel_eps,
-                                                          min_diag_smooth=min_diag_smooth)
-                ref_exp_avg_sq = update_ref_exp_avg_sq(ref_exp_avg_sq,
-                                                       grad_cov_smoothed.diag(),
-                                                       dim, scale_change)
+                                                          grad_rel_eps=grad_rel_eps,
+                                                          grad_diag_smooth=grad_diag_smooth)

                 if scale_change:
                     scale_change = False  # only use the scale change once
@@ -473,7 +448,6 @@ class NeutralGradient(Optimizer):
             proj[:] = self._estimate_proj(grad_cov_smoothed,
                                           param_cov_smoothed)

-        state["ref_exp_avg_sq"][:] = ref_exp_avg_sq + eps*eps

     def _get_this_beta3(self, beta3: float, numel: int, size: int):
         """
@@ -530,8 +504,8 @@ class NeutralGradient(Optimizer):

     def _estimate_and_smooth_param_cov(self, p: Tensor, dim: int,
                                        param_eps: float,
-                                       rel_eps: float = 1.0e-10,
-                                       min_diag_smooth: float = 0.2) -> Tensor:
+                                       param_rel_eps: float = 1.0e-10,
+                                       param_diag_smooth: float = 0.4) -> Tensor:
         """
         Compute a smoothed version of a covariance matrix for one dimension of
         a parameter tensor, estimated by treating the other
@@ -541,7 +515,7 @@ class NeutralGradient(Optimizer):
               dim: The dimenion that we want the covariances for.
         param_eps: A small epsilon value that represents the minimum root-mean-square
                    parameter value that we'll estimate.
-          rel_eps: An epsilon value that limits the condition number of the resulting
+    param_rel_eps: An epsilon value that limits the condition number of the resulting
                    matrix. Caution: this applies to the variance, not the rms value,
                    so should be quite small.

@@ -556,10 +530,12 @@ class NeutralGradient(Optimizer):
         # later we may be able to find a more principled formula for this.
         #if random.random() < 0.2:
         #    print(f"param diag_smooth = {diag_smooth}, shape={p.shape}")
-        #diag_smooth = min_diag_smooth
-        diag_smooth = 0.4
+
+        diag_smooth = max(param_diag_smooth,
+                          size / (size + num_outer_products))
+
         diag = param_cov.diag()
-        extra_diag = (diag * diag_smooth) + (diag.mean() * rel_eps +
+        extra_diag = (diag * diag_smooth) + (diag.mean() * param_rel_eps +
                                              param_eps * param_eps)
         param_cov.mul_(1-diag_smooth).add_(extra_diag.diag())
         return param_cov
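The smoothing above interpolates the covariance toward its own diagonal, shrinking harder when few outer products have been accumulated relative to the matrix size, plus the usual relative and absolute epsilons. A hedged standalone sketch of the same pattern (illustrative names, not a drop-in replacement):

import torch

def smooth_toward_diagonal(cov: torch.Tensor, num_outer_products: int,
                           diag_smooth_floor: float, rel_eps: float, abs_eps: float) -> torch.Tensor:
    # With few sample outer products relative to the matrix size, off-diagonal entries
    # are poorly estimated, so shrink more strongly toward the diagonal.
    size = cov.shape[0]
    diag_smooth = max(diag_smooth_floor, size / (size + num_outer_products))
    diag = cov.diag()
    extra_diag = diag * diag_smooth + (diag.mean() * rel_eps + abs_eps * abs_eps)
    return cov * (1 - diag_smooth) + extra_diag.diag()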
@@ -567,11 +543,11 @@ class NeutralGradient(Optimizer):
     def _smooth_grad_cov(self,
                          p: Tensor,
                          grad_cov: Tensor,
-                         eps: float,
+                         grad_eps: float,
                          norm_step: int,
                          beta3: float,
-                         rel_eps: float = 1.0e-10,
-                         min_diag_smooth: float = 0.2) -> Tensor:
+                         grad_rel_eps: float = 1.0e-10,
+                         grad_diag_smooth: float = 0.2) -> Tensor:
         """
         Compute a smoothed version of a covariance matrix for one dimension of
         a gradient tensor. The covariance stats are supplied; this function
@@ -581,17 +557,17 @@ class NeutralGradient(Optimizer):
           p: The parameter tensor from which we are estimating the covariance
              of the gradient (over one of its dimensions); only needed for the shape
           grad_cov: The moving-average covariance statistics, to be smoothed
-          eps: An epsilon value that represents a root-mean-square gradient
+          grad_eps: An epsilon value that represents a root-mean-square gradient
              magnitude, used to avoid zero values
           norm_step: The step count that reflects how many times we updated
              grad_cov, we assume we have updated it (norm_step + 1) times; this
              is just step divided by stats_perio.
           beta3: The user-supplied beta value for decaying the gradient covariance
              stats
-          rel_eps: An epsilon value that limits the condition number of the resulting
+          grad_rel_eps: An epsilon value that limits the condition number of the resulting
              matrix. Caution: this applies to the variance, not the rms value,
              so should be quite small.
-          min_diag_smooth: A minimum proportion by which we smooth the covariance
+          grad_diag_smooth: A minimum proportion by which we smooth the covariance
              with the diagonal.

         Returns:
@@ -611,14 +587,14 @@ class NeutralGradient(Optimizer):
         # individual outer products that are represented with significant weight
         # in grad_cov
         num_outer_products = rank_per_iter * num_iters_in_stats
-        diag_smooth = max(min_diag_smooth,
+        diag_smooth = max(grad_diag_smooth,
                           size / (size + num_outer_products))
         if random.random() < 0.5:
             print(f"grad diag_smooth = {diag_smooth}, shape={p.shape}")

         diag = grad_cov.diag()
-        extra_diag = (diag * diag_smooth).add_(diag.mean() * rel_eps +
-                                               eps * eps)
+        extra_diag = (diag * diag_smooth).add_(diag.mean() * grad_rel_eps +
+                                               grad_eps * grad_eps)
         grad_cov = (grad_cov * (1-diag_smooth)).add_(extra_diag.diag())
         return grad_cov
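To see how the new grad_diag_smooth floor interacts with the data-dependent term above: with few accumulated outer products the covariance is shrunk almost entirely to its diagonal, and the weight decays toward the floor as statistics accumulate. A small illustrative calculation (the numbers are made up):

size = 512               # dimension of the covariance matrix
grad_diag_smooth = 0.2   # the new floor
for num_outer_products in (32, 512, 8192):
    diag_smooth = max(grad_diag_smooth, size / (size + num_outer_products))
    print(num_outer_products, round(diag_smooth, 3))
# 32 -> 0.941, 512 -> 0.5, 8192 -> 0.2 (the floor takes over)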
@@ -1370,6 +1346,7 @@ def _test_eve_cain():
     scheduler = Eden(optim, lr_batches=200, lr_epochs=10, verbose=False)

     start = timeit.default_timer()
+    avg_loss = 0.0
     for epoch in range(150):
         scheduler.step_epoch()

@@ -1383,6 +1360,7 @@ def _test_eve_cain():
         for n, (x,y) in enumerate(train_pairs):
             y_out = m(x)
             loss = ((y_out - y)**2).mean() * 100.0
+            avg_loss = 0.95 * avg_loss + 0.05 * loss.item()
             if n == 0 and epoch % 10 == 0:
                 norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item()
                 norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item()
@@ -1392,7 +1370,7 @@ def _test_eve_cain():
                 #scale1b = '%.2e' % (m[0].bias_scale.exp().item())
                 #scale2 = '%.2e' % (m[2].weight_scale.exp().item())
                 #scale2b = '%.2e' % (m[2].bias_scale.exp().item())
-                print(f"Iter {iter}, epoch {epoch}, batch {n}, loss {loss.item()}, norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b}
+                print(f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b}
             loss.log().backward()
             optim.step()
             optim.zero_grad()
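The avg_loss added in this test is a plain exponential moving average of the per-batch loss with decay 0.95; note it starts at 0.0, so the printed value is biased low for the first few batches. A minimal standalone illustration with made-up losses:

avg_loss = 0.0
for loss_value in [4.0, 3.5, 3.2, 3.0]:   # made-up per-batch losses
    avg_loss = 0.95 * avg_loss + 0.05 * loss_value
print(f"avg_loss {avg_loss:.4g}")   # -> avg_loss 0.6314 (warming up from zero)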