Some changes to the algorithm; more diagnostic printing

Daniel Povey 2022-05-15 14:12:23 +08:00
parent 67c402a369
commit 6306f24430


@@ -214,7 +214,8 @@ class Abel(Optimizer):
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-08).
target_rms (float, optional): target root-mean-square value of
x_norm in factorization (conceptually).
x_norm in factorization (conceptually). Actually, this now just becomes
a factor in the learning rate; we may remove it at some point.
.. _Adam\: A Method for Stochastic Optimization:
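As the amended docstring says, target_rms now enters only as a factor in the learning rate. A tiny numeric illustration of what that means in practice (the values below are made up; only the lr_eff formula is taken from a later hunk in this commit):

lr = 3e-4
target_rms = 0.1
bias_correction2 = 0.999  # stands in for the usual Adam-style 1 - beta2**step term

# target_rms no longer shapes the factorization; it just rescales the step:
lr_eff = lr * (bias_correction2 ** 0.5) / target_rms
print(lr_eff)  # ~3e-3, i.e. 1/target_rms times larger than lr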
@@ -245,7 +246,7 @@ class Abel(Optimizer):
raise ValueError(
"Invalid beta parameter at index 1: {}".format(betas[1])
)
if not 0 < target_rms <= 10.0:
if not 0 < target_rms <= 1.0:
raise ValueError("Invalid target_rms value: {}".format(target_rms))
defaults = dict(
lr=lr,
@@ -314,7 +315,8 @@ class Abel(Optimizer):
torch.zeros(*shape, dtype=p.dtype, device=p.device)
for shape in _get_factor_shapes(list(p.shape))
]
state["factorization"] = _init_factorization(p, eps)
# we call _init_factorization inside step().
state["factorization"] = torch.zeros_like(p)
exp_avg_sq = state["exp_avg_sq"]
delta, exp_avg_sq = state["delta"], state["exp_avg_sq"]
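The expensive _init_factorization call is thus deferred from state creation to step(), where the next hunk also re-runs it periodically. A condensed sketch of that lazy-init-plus-refresh pattern, with a hypothetical init_factorization() standing in for the real helper:

import torch

def init_factorization(p: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # Hypothetical stand-in for _init_factorization(); the real helper builds
    # the per-dimension factorization of p.
    return p.detach().abs().clamp(min=eps)

p = torch.randn(4, 3)
state = {"factorization": torch.zeros_like(p)}  # cheap placeholder at state creation

for step in range(3001):
    if step % 1000 == 0:
        # The first pass fills in the real values; later passes refresh them so
        # roundoff from repeated in-place updates cannot accumulate.
        state["factorization"][:] = init_factorization(p)
    # ... the rest of the optimizer step would read state["factorization"] here ...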
@@ -350,15 +352,32 @@ class Abel(Optimizer):
factor_grads = _get_factor_grads(p, grad)
if step < 10 or step % 10 == 1:
# do this only every 10 steps, to save time.
# update the factorization only every 10 steps, to save time.
if step % 1000 == 0:
# Periodically refresh the factorization; this is
# out of concern that after a large number of
# multiplications, roundoff could cause it to drift
# significantly away from being a product of
# independent factors, one per dimension.
factorization[:] = _init_factorization(p, eps)
_update_factorization(p, factorization,
speed=0.1,
eps=eps)
factors_sum = None
numel = p.numel()
for g, e in zip(factor_grads, factors_exp_avg_sq):
this_numel = numel / g.numel()
update_exp_avg_sq(g, e)
this_denom = (e/bias_correction2 + eps*eps).sqrt()
# this_numel**0.5 will turn into this_numel**-0.25 when we sqrt() and
# divide by denom. This becomes a factor in the learning rate.
# The technique naturally makes us learn faster by (this_numel**0.5)
# in any of the parameter directions corresponding to rows or columns,
# and because we don't want to be too aggressive and risk instability or
# excess param noise, we reduce this to (this_numel**0.25) by dividing
# the effective LR by this_numel**0.25.
this_denom = (e*(this_numel**0.5)/bias_correction2 + eps*eps).sqrt()
assert g.shape == this_denom.shape
factor_delta = g / this_denom
factors_sum = (factor_delta if factors_sum is None
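The comment above argues that scaling the factor's second-moment estimate by this_numel**0.5 inside the square root shrinks its effective step by this_numel**0.25. A short numeric check of that claim (all values are made up; only the this_denom formula comes from the diff):

import torch

numel, g_numel = 512 * 256, 256      # whole parameter vs. one factor (e.g. a column scale)
this_numel = numel / g_numel         # 512
e = torch.full((g_numel,), 0.04)     # made-up bias-corrected second-moment values
g = torch.full((g_numel,), 0.2)      # made-up factor gradient
bias_correction2, eps = 1.0, 1e-8

plain_denom = (e / bias_correction2 + eps * eps).sqrt()
scaled_denom = (e * (this_numel ** 0.5) / bias_correction2 + eps * eps).sqrt()

ratio = (g / scaled_denom) / (g / plain_denom)
print(ratio[0].item(), this_numel ** -0.25)  # both ~0.21: the step shrinks by this_numel**0.25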
@@ -371,18 +390,7 @@ class Abel(Optimizer):
# from updating the factors (before considering momentum!)
# We multiply by target_rms*target_rms here to have the same
# effect as multiplying by target_rms after the sqrt, so that later
# when we divide by denom it's like dividing the learning rate by
# target_rms. This is to simulate the effect of keeping `factorization`
# updated in such a way that x_norm had an average root-mean-square
# value of `target_rms` instead of 1.0; this would require
# `factorization` being multiplied by `1/target_rms` below when setting
# `delta`. We do it this way to save a tensor operation.
# We are very lazy about when to add `eps`, just doing it whenever is
# most convenient, since this is really not going to be critical.
denom = (exp_avg_sq*target_rms*target_rms/bias_correction2 +
denom = (exp_avg_sq/bias_correction2 +
eps*eps).sqrt()
# `grad * factorization / denom` represents the change in parameters x, only considering
@@ -398,9 +406,8 @@ class Abel(Optimizer):
# rate or momentum.
this_delta = ((grad * factorization / denom) + p * factors_sum)
# compute the moving-average change in parameters, and add it to p.
delta.mul_(beta1).add_(this_delta, alpha=-lr*(1-beta1))
lr_eff = lr * (bias_correction2 ** 0.5) / target_rms
delta.mul_(beta1).add_(this_delta, alpha=-lr_eff*(1-beta1))
p.add_(delta)
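With target_rms folded into the step size, the final update smooths this_delta with beta1 momentum at the effective learning rate lr_eff. A self-contained sketch of just these last few lines, using made-up tensors and an assumed Adam-style bias_correction2:

import torch

lr, beta1, beta2, target_rms = 3e-4, 0.9, 0.98, 0.1
step = 100
bias_correction2 = 1 - beta2 ** step   # assumed definition, as in Adam

p = torch.randn(4, 3)
delta = torch.zeros_like(p)            # moving-average change in the parameters
this_delta = torch.randn_like(p)       # stands in for grad * factorization / denom + p * factors_sum

lr_eff = lr * (bias_correction2 ** 0.5) / target_rms
delta.mul_(beta1).add_(this_delta, alpha=-lr_eff * (1 - beta1))
p.add_(delta)                          # apply the smoothed update in place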
@@ -749,7 +756,7 @@ def _test_abel():
# Abel is working as we expect and is able to adjust scales of
# different dims differently.
input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
output_magnitudes = (0.0 * torch.randn(E, dtype=dtype, device=device)).exp()
output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
for iter in [0,1]:
Linear = torch.nn.Linear if iter == 0 else ScaledLinear
@@ -800,8 +807,10 @@ def _test_abel():
print("Normalized input col magnitudes log-stddev: ", stddev(((m[0].weight**2).sum(dim=0).sqrt() * input_magnitudes).log()))
print("Un-normalized 0-output row magnitudes log-stddev: ", stddev((m[0].weight**2).sum(dim=1).sqrt().log()))
print("Un-normalized 2-output col magnitudes log-stddev: ", stddev((m[0].weight**2).sum(dim=0).sqrt().log()))
print("Un-normalized 2-output row magnitudes log-stddev: ", stddev((m[0].weight**2).sum(dim=1).sqrt().log()))
print("Un-normalized 2-input col magnitudes log-stddev: ", stddev((m[2].weight**2).sum(dim=0).sqrt().log()))
print("Un-normalized 2-output row magnitudes log-stddev: ", stddev((m[2].weight**2).sum(dim=1).sqrt().log()))
print("Normalized output row magnitudes log-stddev: ", stddev(((m[2].weight**2).sum(dim=1).sqrt() / output_magnitudes).log()))