From 1078e4878cca664ea2d48994110f75945484b0a2 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Thu, 9 Sep 2021 14:19:01 +0800
Subject: [PATCH] Add 1/sqrt(t) factor to gloam

---
 egs/librispeech/ASR/conformer_ctc/madam.py | 35 ++++------------------
 egs/librispeech/ASR/conformer_ctc/train.py |  4 +--
 2 files changed, 8 insertions(+), 31 deletions(-)

diff --git a/egs/librispeech/ASR/conformer_ctc/madam.py b/egs/librispeech/ASR/conformer_ctc/madam.py
index a480a4aba..1cf0d322f 100644
--- a/egs/librispeech/ASR/conformer_ctc/madam.py
+++ b/egs/librispeech/ASR/conformer_ctc/madam.py
@@ -960,7 +960,7 @@ class Gloam(object):
         Args:
             params (iterable): iterable of parameters to optimize or dicts defining
                 parameter groups
-            warm_step: number of warmup steps before the learning rate starts to decay
+            warm_step: number of warmup steps before the learning rate starts to decrease
                 (it increases until this point).
             max_lrate: The learning rate at its maximum, on step `warm_step`
            first_decay_epoch: The epoch number on which to start decreasing the
@@ -1028,39 +1028,16 @@ class Gloam(object):
         We define 't = s / warm_step', i.e. t is the step s, normalized so that it
         is 1.0 at warm_step.  Our formula for the learning rate as a function of
         t is:
-             rate = max_lrate * (t <= 1.0 ? t :
-                                 sqrt((2 + alpha) / (1 + t + alpha t^2)))
-        where alpha is chosen so that the 't' and 'alpha t^2' terms are identical
-        at t == knee_factor (this means alpha = 1.0/knee_factor). So the
-        learning rate increases linearly from t=00 to t=1, and decreases
-        after that. You can see
-        that sqrt((2 + alpha) / (1 + t + alpha t^2))) is 1.0 when t == 1,
-        which is why the line and the curve meet at that point.
+             base_rate = max_lrate * (t <= 1.0 ? t : t ** -0.5)
+             epoch_rate = [starts at 1.0 but from first_decay_epoch, start decreasing it
+                           by a factor of decay_per_epoch]
+             rate = base_rate * epoch_rate
 
-        On the denominator of that ratio, the "t" term makes it decrease a
-        bit like 1/sqrt(t) in 1 <= t <= warm_step; the "alpha t^2" term
-        makes it decrease a bit like 1/t for t > warm_step; and the "1"
-        term makes it decrease a bit slower than 1/sqrt(t) when t is quite
-        close to 1.0 (so we linger a little, near the maximum learning rate).
-
-        This learning rate schedule ultimately decreases more aggressively
-        than Noam, i.e. as 1 / t instead of 1 / sqrt(t).  The reason we
-        feel this will work better in conjunction with Madam, is that Madam
-        keeps the norms of the parameters approximately constant throughout
-        training; whereas with Noam, if there is no weight decay, these
-        norms tend to increase as training progresses (although rather
-        unevenly across different parameter tensors).
-        As the norms of the parameters increase, the relative changes
-        in parameters get smaller (the step sizes don't change because
-        Adam normalizes the gradient magnitudes; they'd get smaller otherwise).
-        So Noam doesn't have to decrease the learning rate too aggressively
-        because even with a fixed learning rate, the effective learning rate
-        would be decreasing (again, this only applies without weight decay).
 
         """
         if step is None:
             step = self._step
         t = step / self._warm_step  # floating point division.. t is the normalized step.
-        base_rate = self._max_lrate * (t if t <= 1.0 else 1.0)
+        base_rate = self._max_lrate * (t if t <= 1.0 else t ** -0.5)
         epoch_rate = self._decay_per_epoch ** max(0, self._epoch + 1 - self._first_decay_epoch)
         return base_rate * epoch_rate

diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py
index 918ce9c0d..4afe23215 100755
--- a/egs/librispeech/ASR/conformer_ctc/train.py
+++ b/egs/librispeech/ASR/conformer_ctc/train.py
@@ -150,7 +150,7 @@ def get_params() -> AttributeDict:
     """
     params = AttributeDict(
         {
-            "exp_dir": Path("conformer_ctc/exp_gloam_2e-4_0.85"),
+            "exp_dir": Path("conformer_ctc/exp_gloam_5e-4_0.85"),
             "lang_dir": Path("data/lang_bpe"),
             "feature_dim": 80,
             "subsampling_factor": 4,
@@ -173,7 +173,7 @@ def get_params() -> AttributeDict:
             "is_espnet_structure": True,
             "mmi_loss": False,
             "use_feat_batchnorm": True,
-            "max_lrate": 2.0e-04,
+            "max_lrate": 5.0e-04,
             "first_decay_epoch": 1,
             "decay_per_epoch": 0.85,
             "warm_step": 40000,
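
For reference, the schedule this patch switches to, written out as a small standalone sketch. This is illustrative only: the helper name gloam_rate is invented here, and the default values are simply the settings from train.py above; the actual implementation is the Gloam class in madam.py.

    # Sketch of the learning rate schedule after this change: linear warmup to
    # max_lrate over warm_step steps, 1/sqrt(t) decay after that, multiplied by
    # a per-epoch geometric decay that starts at first_decay_epoch.
    def gloam_rate(step: int,
                   epoch: int,
                   max_lrate: float = 5.0e-04,
                   warm_step: int = 40000,
                   first_decay_epoch: int = 1,
                   decay_per_epoch: float = 0.85) -> float:
        t = step / warm_step  # t == 1.0 exactly at step == warm_step
        base_rate = max_lrate * (t if t <= 1.0 else t ** -0.5)
        epoch_rate = decay_per_epoch ** max(0, epoch + 1 - first_decay_epoch)
        return base_rate * epoch_rate

    # e.g. gloam_rate(40000, 0)  == 5e-4 (the peak, reached at warm_step);
    #      gloam_rate(160000, 1) == 5e-4 * 0.5 * 0.85 = 2.125e-4.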