Change to exp_5, 1/sqrt(t) component.
parent 56a88badd1
commit d0e5b9b8a5
@@ -1028,39 +1028,16 @@ class Gloam(object):
We define 't = s / warm_step', i.e. t is the step s, normalized so that it
is 1.0 at warm_step. Our formula for the learning rate as a function of
t is:
    rate = max_lrate * (t <= 1.0 ? t :
                        sqrt((2 + alpha) / (1 + t + alpha t^2)))
where alpha is chosen so that the 't' and 'alpha t^2' terms are identical
at t == knee_factor (this means alpha = 1.0/knee_factor). So the
learning rate increases linearly from t=0 to t=1, and decreases
after that. You can see
that sqrt((2 + alpha) / (1 + t + alpha t^2)) is 1.0 when t == 1,
which is why the line and the curve meet at that point.
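For concreteness, a minimal standalone sketch of the curve described above (the function name and the default knee_factor / max_lrate values are illustrative, not taken from the source):

    import math

    def old_rate(t, max_lrate=5e-04, knee_factor=10.0):
        # alpha = 1/knee_factor makes the 't' and 'alpha t^2' terms equal at t == knee_factor.
        alpha = 1.0 / knee_factor
        if t <= 1.0:
            return max_lrate * t  # linear warm-up up to t == 1
        return max_lrate * math.sqrt((2 + alpha) / (1 + t + alpha * t ** 2))

    # At t == 1 the square root is sqrt((2 + alpha) / (2 + alpha)) == 1.0, so the
    # warm-up line and the decay curve both give max_lrate there.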
    base_rate = max_lrate * (t <= 1.0 ? t : t ** -0.5)
    epoch_rate = [starts at 1.0 but from first_decrease_epoch, start decreasing it
                  by a factor of decay_per_epoch]
    rate = base_rate * epoch_rate
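The same schedule as executable Python (a sketch only; the function name is made up, and the default first_decrease_epoch / decay_per_epoch values are the ones passed to Gloam later in this commit):

    def gloam_rate(step, epoch, warm_step, max_lrate,
                   first_decrease_epoch=1, decay_per_epoch=0.85):
        t = step / warm_step                                    # normalized step
        base_rate = max_lrate * (t if t <= 1.0 else t ** -0.5)  # warm-up, then 1/sqrt(t)
        epoch_rate = decay_per_epoch ** max(0, epoch + 1 - first_decrease_epoch)
        return base_rate * epoch_rate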
On the denominator of that ratio, the "t" term makes it decrease a
bit like 1/sqrt(t) in 1 <= t <= knee_factor; the "alpha t^2" term
makes it decrease a bit like 1/t for t > knee_factor; and the "1"
term makes it decrease a bit slower than 1/sqrt(t) when t is quite
close to 1.0 (so we linger a little, near the maximum learning rate).
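A small numeric check of those three regimes, using knee_factor = 10 (i.e. alpha = 0.1) purely for illustration:

    alpha = 0.1  # knee_factor = 10
    for t in (1.5, 5.0, 100.0):
        print(f"t={t}: denominator terms 1, {t}, {alpha * t * t:.3g}")
    # Near t = 1 the constant "1" is comparable to t, so the decay is gentler than
    # 1/sqrt(t); for 1 < t < knee_factor the "t" term dominates (roughly 1/sqrt(t));
    # for t > knee_factor the "alpha t^2" term dominates (roughly 1/t).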
This learning rate schedule ultimately decreases more aggressively
than Noam, i.e. as 1 / t instead of 1 / sqrt(t). The reason we
feel this will work better in conjunction with Madam, is that Madam
keeps the norms of the parameters approximately constant throughout
training; whereas with Noam, if there is no weight decay, these
norms tend to increase as training progresses (although rather
unevenly across different parameter tensors).
As the norms of the parameters increase, the relative changes
in parameters get smaller (the step sizes don't change because
Adam normalizes the gradient magnitudes; they'd get smaller otherwise).
So Noam doesn't have to decrease the learning rate too aggressively
because even with a fixed learning rate, the effective learning rate
would be decreasing (again, this only applies without weight decay).
"""
if step is None:
    step = self._step
t = step / self._warm_step  # floating point division.. t is the normalized step.
base_rate = self._max_lrate * (t if t <= 1.0 else 1.0)
base_rate = self._max_lrate * (t if t <= 1.0 else t ** -0.5)
epoch_rate = self._decay_per_epoch ** max(0, self._epoch + 1 - self._first_decrease_epoch)
return base_rate * epoch_rate

@@ -134,7 +134,9 @@ def get_params() -> AttributeDict:
{
    # exp_3, vs. exp_2, is using 5e-04 not 2e-04 as max learning rate.
    # exp_4, vs. exp_3, is using the Gloam optimizer with
    "exp_dir": Path("conformer_lm/exp_4"),
    # in exp_5, vs. exp_4, we change Gloam to have a 1/sqrt(t) factor
    # as well as the exponential part.
    "exp_dir": Path("conformer_lm/exp_5"),
    "lm_dataset": Path("data/lm_training_5000/lm_data.pt"),
    "num_tokens": 5000,
    "blank_sym": 0,

@@ -526,7 +528,7 @@ def run(rank, world_size, args):
optimizer = Gloam(
    model.parameters(),
    max_lrate=params.max_lrate,
    first_decrease_epoch=2,
    first_decrease_epoch=1,
    decay_per_epoch=0.85
)
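To see what moving first_decrease_epoch from 2 to 1 does to the per-epoch factor, here is the epoch_rate expression from the first hunk evaluated for the first few epochs (the loop and printout are illustrative, not part of the source):

    decay_per_epoch = 0.85
    for first_decrease_epoch in (2, 1):
        factors = [decay_per_epoch ** max(0, epoch + 1 - first_decrease_epoch)
                   for epoch in range(4)]
        print(first_decrease_epoch, factors)
    # first_decrease_epoch=2 gives 1.0, 1.0, 0.85, 0.85**2, ...
    # first_decrease_epoch=1 gives 1.0, 0.85, 0.85**2, 0.85**3, ...
    # i.e. the exponential decay now starts one epoch earlier.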