Add 1/sqrt(t) factor to gloam

This commit is contained in:
Daniel Povey 2021-09-09 14:19:01 +08:00
parent c810e67342
commit 1078e4878c
2 changed files with 8 additions and 31 deletions

View File

@ -960,7 +960,7 @@ class Gloam(object):
Args: Args:
params (iterable): iterable of parameters to optimize or dicts defining parameter groups params (iterable): iterable of parameters to optimize or dicts defining parameter groups
warm_step: number of warmup steps before the learning rate starts to decay warm_step: number of warmup steps before the learning rate starts to decrease
(it increases until this point). (it increases until this point).
max_lrate: The learning rate at its maximum, on step `warm_step` max_lrate: The learning rate at its maximum, on step `warm_step`
first_decay_epoch: The epoch number on which to start decreasing the first_decay_epoch: The epoch number on which to start decreasing the
@ -1028,39 +1028,16 @@ class Gloam(object):
We define 't = s / warm_step', i.e. t is the step s, normalized so that it We define 't = s / warm_step', i.e. t is the step s, normalized so that it
is 1.0 at warm_step. Our formula for the learning rate as a function of is 1.0 at warm_step. Our formula for the learning rate as a function of
t is: t is:
rate = max_lrate * (t <= 1.0 ? t : base_rate = max_lrate * (t <= 1.0 ? t : t ** -0.5)
sqrt((2 + alpha) / (1 + t + alpha t^2))) epoch_rate = [starts at 1.0 but from first_decay_epoch, start decreasing it
where alpha is chosen so that the 't' and 'alpha t^2' terms are identical by a factor of decay_per_epoch]
at t == knee_factor (this means alpha = 1.0/knee_factor). So the rate = base_rate * epoch_rate
learning rate increases linearly from t=00 to t=1, and decreases
after that. You can see
that sqrt((2 + alpha) / (1 + t + alpha t^2))) is 1.0 when t == 1,
which is why the line and the curve meet at that point.
On the denominator of that ratio, the "t" term makes it decrease a
bit like 1/sqrt(t) in 1 <= t <= warm_step; the "alpha t^2" term
makes it decrease a bit like 1/t for t > warm_step; and the "1"
term makes it decrease a bit slower than 1/sqrt(t) when t is quite
close to 1.0 (so we linger a little, near the maximum learning rate).
This learning rate schedule ultimately decreases more aggressively
than Noam, i.e. as 1 / t instead of 1 / sqrt(t). The reason we
feel this will work better in conjunction with Madam, is that Madam
keeps the norms of the parameters approximately constant throughout
training; whereas with Noam, if there is no weight decay, these
norms tend to increase as training progresses (although rather
unevenly across different parameter tensors).
As the norms of the parameters increase, the relative changes
in parameters get smaller (the step sizes don't change because
Adam normalizes the gradient magnitudes; they'd get smaller otherwise).
So Noam doesn't have to decrease the learning rate too aggressively
because even with a fixed learning rate, the effective learning rate
would be decreasing (again, this only applies without weight decay).
""" """
if step is None: if step is None:
step = self._step step = self._step
t = step / self._warm_step # floating point division.. t is the normalized step. t = step / self._warm_step # floating point division.. t is the normalized step.
base_rate = self._max_lrate * (t if t <= 1.0 else 1.0) base_rate = self._max_lrate * (t if t <= 1.0 else t ** -0.5)
epoch_rate = self._decay_per_epoch ** max(0, self._epoch + 1 - self._first_decay_epoch) epoch_rate = self._decay_per_epoch ** max(0, self._epoch + 1 - self._first_decay_epoch)
return base_rate * epoch_rate return base_rate * epoch_rate

View File

@ -150,7 +150,7 @@ def get_params() -> AttributeDict:
""" """
params = AttributeDict( params = AttributeDict(
{ {
"exp_dir": Path("conformer_ctc/exp_gloam_2e-4_0.85"), "exp_dir": Path("conformer_ctc/exp_gloam_5e-4_0.85"),
"lang_dir": Path("data/lang_bpe"), "lang_dir": Path("data/lang_bpe"),
"feature_dim": 80, "feature_dim": 80,
"subsampling_factor": 4, "subsampling_factor": 4,
@ -173,7 +173,7 @@ def get_params() -> AttributeDict:
"is_espnet_structure": True, "is_espnet_structure": True,
"mmi_loss": False, "mmi_loss": False,
"use_feat_batchnorm": True, "use_feat_batchnorm": True,
"max_lrate": 2.0e-04, "max_lrate": 5.0e-04,
"first_decay_epoch": 1, "first_decay_epoch": 1,
"decay_per_epoch": 0.85, "decay_per_epoch": 0.85,
"warm_step": 40000, "warm_step": 40000,