https://github.com/k2-fsa/icefall.git

commit 6306f24430 (parent 67c402a369)

Some changes to algorithm; more diagnostics printing
@@ -214,7 +214,8 @@ class Abel(Optimizer):
         eps (float, optional): term added to the denominator to improve
             numerical stability (default: 1e-08).
         target_rms (float, optional): target root-mean-square value of
-            x_norm in factorization (conceptually).
+            x_norm in factorization (conceptually).  Actually this now just becomes
+            a factor in the learning rate, we may remove it at some point.

     .. _Adam\: A Method for Stochastic Optimization:
@@ -245,7 +246,7 @@ class Abel(Optimizer):
            raise ValueError(
                "Invalid beta parameter at index 1: {}".format(betas[1])
            )
-        if not 0 < target_rms <= 10.0:
+        if not 0 < target_rms <= 1.0:
            raise ValueError("Invalid target_rms value: {}".format(target_rms))
        defaults = dict(
            lr=lr,
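A quick sanity check of the tightened bound. This is a hypothetical usage sketch: it assumes Abel is importable and that its remaining constructor arguments have defaults, which the surrounding code implies but this hunk does not show.

```python
import torch

model = torch.nn.Linear(4, 4)
Abel(model.parameters(), target_rms=0.1)      # OK: 0 < 0.1 <= 1.0
try:
    Abel(model.parameters(), target_rms=2.0)  # previously allowed (bound was 10.0)
except ValueError as e:
    print(e)  # "Invalid target_rms value: 2.0"
```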
@@ -314,7 +315,8 @@ class Abel(Optimizer):
                        torch.zeros(*shape, dtype=p.dtype, device=p.device)
                        for shape in _get_factor_shapes(list(p.shape))
                    ]
-                    state["factorization"] = _init_factorization(p, eps)
+                    # we call _init_factorization inside step().
+                    state["factorization"] = torch.zeros_like(p)

                exp_avg_sq = state["exp_avg_sq"]
                delta, exp_avg_sq = state["delta"], state["exp_avg_sq"]
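Initializing the state to zeros is safe only because step() now refreshes the factorization itself, in the next hunk. A standalone toy sketch of that schedule, assuming the step counter starts at 0 (which this diff does not show):

```python
# `initialized` stands in for the tensor assignment
# factorization[:] = _init_factorization(p, eps) from the next hunk.
initialized = False
for step in range(25):
    if step < 10 or step % 10 == 1:
        if step % 1000 == 0:   # true at step 0, so the zeros are overwritten
            initialized = True
print(initialized)  # True: the factorization is populated on the first update
```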
@@ -350,15 +352,32 @@ class Abel(Optimizer):
                factor_grads = _get_factor_grads(p, grad)

                if step < 10 or step % 10 == 1:
-                    # do this only every 10 steps, to save time.
+                    # update the factorization only every 10 steps, to save time.
+                    if step % 1000 == 0:
+                        # Periodically refresh the factorization; this is
+                        # out of concern that after a large number of
+                        # multiplications, roundoff could cause it to drift
+                        # significantly away from being a product of
+                        # independent factors, one per dimension.
+                        factorization[:] = _init_factorization(p, eps)
                    _update_factorization(p, factorization,
                                          speed=0.1,
                                          eps=eps)

                factors_sum = None
+                numel = p.numel()
                for g, e in zip(factor_grads, factors_exp_avg_sq):
+                    this_numel = numel / g.numel()
                    update_exp_avg_sq(g, e)
-                    this_denom = (e/bias_correction2 + eps*eps).sqrt()
+                    # this_numel**0.5 will turn into this_numel**-0.25 when we sqrt() and
+                    # divide by denom.  This becomes a factor in the learning rate.
+                    # The technique naturally makes us learn faster by (this_numel**0.5)
+                    # in any of the parameter directions corresponding to rows or columns,
+                    # and because we don't want to be too aggressive and risk instability or
+                    # excess param noise, we reduce this to (this_numel**0.25) by dividing
+                    # the effective LR by this_numel**0.25.
+                    this_denom = (e*(this_numel**0.5)/bias_correction2 + eps*eps).sqrt()
                    assert g.shape == this_denom.shape
                    factor_delta = g / this_denom
                    factors_sum = (factor_delta if factors_sum is None
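The exponent arithmetic in the new comment checks out numerically. A minimal sketch with made-up tensors (eps and bias_correction2 dropped for clarity, so this is not the optimizer's real state):

```python
import torch

g = torch.randn(8)             # stand-in for one factor gradient
e = g * g                      # stand-in for its exp_avg_sq
this_numel = 256.0             # numel / g.numel() for a square-ish matrix

plain  = g / e.sqrt()
scaled = g / (e * this_numel ** 0.5).sqrt()

# Scaling e by this_numel**0.5 inside the sqrt divides the update by
# this_numel**0.25, i.e. it acts as a per-factor learning-rate factor.
print(torch.allclose(scaled, plain * this_numel ** -0.25))  # True
```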
@@ -371,18 +390,7 @@ class Abel(Optimizer):
                # from updating the factors (before considering momentum!)

-                # We multiply by target_rms*target_rms here to have the same
-                # effect as multiplying by target_rms after the sqrt, so that later
-                # when we divide by denom it's like dividing the learning rate by
-                # target_rms.  This is to simulate the effect of keeping `factorization`
-                # updated in such a way that x_norm had an average root-mean-square
-                # value of `target_rms` instead of 1.0; this would require
-                # `factorization` being multiplied of `1/target_rms` below when setting
-                # `delta`.  We do it this way to save a tensor operation.
-                # We are very lazy about when to add `eps`, just doing it whenever is
-                # most convenient, since this is really not going to be critical.
-                denom = (exp_avg_sq*target_rms*target_rms/bias_correction2 +
+                denom = (exp_avg_sq/bias_correction2 +
                         eps*eps).sqrt()

                # `grad * factorization / denom` represents the change in parameters x, only considering
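Dropping target_rms² from denom is not a behavior change by itself: the same 1/target_rms factor reappears in lr_eff in the following hunk (the extra bias_correction2**0.5 there is a separate, deliberate change). A quick check with toy values:

```python
import torch

exp_avg_sq = torch.rand(5) + 0.5
grad_fact = torch.randn(5)     # stand-in for grad * factorization
bc2, target_rms, lr = 0.99, 0.1, 3e-4

# Old: target_rms**2 inside the sqrt (eps dropped for clarity).
old = lr * grad_fact / (exp_avg_sq * target_rms**2 / bc2).sqrt()
# New: plain denom, with 1/target_rms folded into the learning rate instead.
new = (lr / target_rms) * grad_fact / (exp_avg_sq / bc2).sqrt()

print(torch.allclose(old, new))  # True: same update, one fewer tensor op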
@@ -398,9 +406,8 @@ class Abel(Optimizer):
                # rate or momentum.
                this_delta = ((grad * factorization / denom) + p * factors_sum)

+                lr_eff = lr * (bias_correction2 ** 0.5) / target_rms
                # compute the moving-average change in parameters, and add it to p.
-                delta.mul_(beta1).add_(this_delta, alpha=-lr*(1-beta1))
+                delta.mul_(beta1).add_(this_delta, alpha=-lr_eff*(1-beta1))

                p.add_(delta)
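On the (1 - beta1) factor in the momentum update: delta is an exponential moving average whose weights sum to one, so under a steady gradient the per-step parameter change converges to -lr_eff * this_delta. A scalar sketch of the same recurrence:

```python
beta1, lr_eff, this_delta = 0.9, 0.01, 2.0
delta = 0.0
for _ in range(200):
    # same recurrence as delta.mul_(beta1).add_(this_delta, alpha=-lr_eff*(1-beta1))
    delta = beta1 * delta - lr_eff * (1 - beta1) * this_delta
print(abs(delta + lr_eff * this_delta) < 1e-9)  # True: fixed point is -lr_eff * this_delta
```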
@@ -749,7 +756,7 @@ def _test_abel():
    # Abel is working as we expect and is able to adjust scales of
    # different dims differently.
    input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
-    output_magnitudes = (0.0 * torch.randn(E, dtype=dtype, device=device)).exp()
+    output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()

    for iter in [0,1]:
        Linear = torch.nn.Linear if iter == 0 else ScaledLinear
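The old 0.0 multiplier made every output magnitude exactly exp(0) = 1, so the test never exercised differing output scales; with 1.0 they are log-normal, like input_magnitudes. A quick illustration:

```python
import torch

E = 100
trivial = (0.0 * torch.randn(E)).exp()  # all ones: nothing for Abel to adapt to
varied  = (1.0 * torch.randn(E)).exp()  # log-normal spread across dimensions
print(trivial.std().item(), varied.std().item())  # 0.0 vs. roughly 2
```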
@@ -800,8 +807,10 @@ def _test_abel():
        print("Normalized input col magnitudes log-stddev: ", stddev(((m[0].weight**2).sum(dim=0).sqrt() * input_magnitudes).log()))

        print("Un-normalized 0-output row magnitudes log-stddev: ", stddev((m[0].weight**2).sum(dim=1).sqrt().log()))
-        print("Un-normalized 2-output col magnitudes log-stddev: ", stddev((m[0].weight**2).sum(dim=0).sqrt().log()))
-        print("Un-normalized 2-output row magnitudes log-stddev: ", stddev((m[0].weight**2).sum(dim=1).sqrt().log()))
+        print("Un-normalized 2-input col magnitudes log-stddev: ", stddev((m[2].weight**2).sum(dim=0).sqrt().log()))
+        print("Un-normalized 2-output row magnitudes log-stddev: ", stddev((m[2].weight**2).sum(dim=1).sqrt().log()))
+
+        print("Normalized output row magnitudes log-stddev: ", stddev(((m[2].weight**2).sum(dim=1).sqrt() / output_magnitudes).log()))
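These diagnostics all funnel through a stddev helper defined elsewhere in the test code; a minimal sketch consistent with how it is called here (its real definition may differ):

```python
import torch

def stddev(x: torch.Tensor) -> torch.Tensor:
    # Assumed shape-agnostic helper: standard deviation over all elements.
    # A small log-stddev of row/column norms means Abel has equalized the
    # scales along that dimension, which is what the test is probing.
    return x.flatten().std()
```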