mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 16:14:17 +00:00
Some changes to algorithm; more diagnostics printing
This commit is contained in:
parent
67c402a369
commit
6306f24430
@ -214,7 +214,8 @@ class Abel(Optimizer):
|
||||
eps (float, optional): term added to the denominator to improve
|
||||
numerical stability (default: 1e-08).
|
||||
target_rms (float, optional): target root-mean-square value of
|
||||
x_norm in factorization (conceptually).
|
||||
x_norm in factorization (conceptually). Actually this now just becomes
|
||||
a factor in the learning rate, we may remove it at some point.
|
||||
|
||||
|
||||
.. _Adam\: A Method for Stochastic Optimization:
|
||||
@ -245,7 +246,7 @@ class Abel(Optimizer):
|
||||
raise ValueError(
|
||||
"Invalid beta parameter at index 1: {}".format(betas[1])
|
||||
)
|
||||
if not 0 < target_rms <= 10.0:
|
||||
if not 0 < target_rms <= 1.0:
|
||||
raise ValueError("Invalid target_rms value: {}".format(target_rms))
|
||||
defaults = dict(
|
||||
lr=lr,
|
||||
@ -314,7 +315,8 @@ class Abel(Optimizer):
|
||||
torch.zeros(*shape, dtype=p.dtype, device=p.device)
|
||||
for shape in _get_factor_shapes(list(p.shape))
|
||||
]
|
||||
state["factorization"] = _init_factorization(p, eps)
|
||||
# we call _init_factorization inside step().
|
||||
state["factorization"] = torch.zeros_like(p)
|
||||
|
||||
exp_avg_sq = state["exp_avg_sq"]
|
||||
delta, exp_avg_sq = state["delta"], state["exp_avg_sq"]
|
||||
@ -350,15 +352,32 @@ class Abel(Optimizer):
|
||||
factor_grads = _get_factor_grads(p, grad)
|
||||
|
||||
if step < 10 or step % 10 == 1:
|
||||
# do this only every 10 steps, to save time.
|
||||
# update the factorization this only every 10 steps, to save time.
|
||||
if step % 1000 == 0:
|
||||
# Periodically refresh the factorization; this is
|
||||
# out of concern that after a large number of
|
||||
# multiplications, roundoff could cause it to drift
|
||||
# significantly away from being a product of
|
||||
# independent factors, one per dimension.
|
||||
factorization[:] = _init_factorization(p, eps)
|
||||
_update_factorization(p, factorization,
|
||||
speed=0.1,
|
||||
eps=eps)
|
||||
|
||||
|
||||
factors_sum = None
|
||||
numel = p.numel()
|
||||
for g, e in zip(factor_grads, factors_exp_avg_sq):
|
||||
this_numel = numel / g.numel()
|
||||
update_exp_avg_sq(g, e)
|
||||
this_denom = (e/bias_correction2 + eps*eps).sqrt()
|
||||
# this_numel**0.5 will turn into this_numel**-0.25 when we sqrt() and
|
||||
# divide by denom. This becomes a factor in the learning rate.
|
||||
# The technique naturally makes us learn faster by (this_numel**0.5)
|
||||
# in any of the parameter directions corresponding to rows or columns,
|
||||
# and because we don't want to be to aggressive and risk instability or
|
||||
# excess param noise, we reduce this to (this_numel**0.25) by dividing
|
||||
# the effective LR by this_numel**0.25.
|
||||
this_denom = (e*(this_numel**0.5)/bias_correction2 + eps*eps).sqrt()
|
||||
assert g.shape == this_denom.shape
|
||||
factor_delta = g / this_denom
|
||||
factors_sum = (factor_delta if factors_sum is None
|
||||
@ -371,18 +390,7 @@ class Abel(Optimizer):
|
||||
# from updating the factors (before considering momentum!)
|
||||
|
||||
|
||||
# We multiply by target_rms*target_rms here to have the same
|
||||
# effect as multiplying by target_rms after the sqrt, so that later
|
||||
# when we divide by denom it's like dividing the learning rate by
|
||||
# target_rms. This is to simulate the effect of keeping `factorization`
|
||||
# updated in such a way that x_norm had an average root-mean-square
|
||||
# value of `target_rms` instead of 1.0; this would require
|
||||
# `factorization` being multiplied of `1/target_rms` below when setting
|
||||
# `delta`. We do it this way to save a tensor operation.
|
||||
# We are very lazy about when to add `eps`, just doing it whenever is
|
||||
# most convenient, since this is really not going to be critical.
|
||||
|
||||
denom = (exp_avg_sq*target_rms*target_rms/bias_correction2 +
|
||||
denom = (exp_avg_sq/bias_correction2 +
|
||||
eps*eps).sqrt()
|
||||
|
||||
# `grad * factorization / denom` represents the change in parameters x, only considering
|
||||
@ -398,9 +406,8 @@ class Abel(Optimizer):
|
||||
# rate or momentum.
|
||||
this_delta = ((grad * factorization / denom) + p * factors_sum)
|
||||
|
||||
|
||||
# compute the moving-average change in parameters, and add it to p.
|
||||
delta.mul_(beta1).add_(this_delta, alpha=-lr*(1-beta1))
|
||||
lr_eff = lr * (bias_correction2 ** 0.5) / target_rms
|
||||
delta.mul_(beta1).add_(this_delta, alpha=-lr_eff*(1-beta1))
|
||||
|
||||
p.add_(delta)
|
||||
|
||||
@ -749,7 +756,7 @@ def _test_abel():
|
||||
# Abel is working as we expect and is able to adjust scales of
|
||||
# different dims differently.
|
||||
input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
|
||||
output_magnitudes = (0.0 * torch.randn(E, dtype=dtype, device=device)).exp()
|
||||
output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
|
||||
|
||||
for iter in [0,1]:
|
||||
Linear = torch.nn.Linear if iter == 0 else ScaledLinear
|
||||
@ -800,8 +807,10 @@ def _test_abel():
|
||||
print("Normalized input col magnitudes log-stddev: ", stddev(((m[0].weight**2).sum(dim=0).sqrt() * input_magnitudes).log()))
|
||||
|
||||
print("Un-normalized 0-output row magnitudes log-stddev: ", stddev((m[0].weight**2).sum(dim=1).sqrt().log()))
|
||||
print("Un-normalized 2-output col magnitudes log-stddev: ", stddev((m[0].weight**2).sum(dim=0).sqrt().log()))
|
||||
print("Un-normalized 2-output row magnitudes log-stddev: ", stddev((m[0].weight**2).sum(dim=1).sqrt().log()))
|
||||
print("Un-normalized 2-input col magnitudes log-stddev: ", stddev((m[2].weight**2).sum(dim=0).sqrt().log()))
|
||||
print("Un-normalized 2-output row magnitudes log-stddev: ", stddev((m[2].weight**2).sum(dim=1).sqrt().log()))
|
||||
|
||||
print("Normalized output row magnitudes log-stddev: ", stddev(((m[2].weight**2).sum(dim=1).sqrt() / output_magnitudes).log()))
|
||||
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user