From af0fc31c78b88f211e55b506b156478bb7031d82 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 22 Oct 2022 15:15:43 +0800 Subject: [PATCH] Introduce warmup schedule in optimizer --- .../ASR/pruned_transducer_stateless7/optim.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py index 0e5808369..8f13a67c5 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py @@ -646,10 +646,14 @@ class LRScheduler(object): class Eden(LRScheduler): """ Eden scheduler. - lr = initial_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 * - (((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25)) + The basic formula (before warmup) is: + lr = initial_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 * + (((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25)) * warmup + where `warmup` increases linearly from 0.5 to 1 over `warmup_batches` batches + and then stays constant at 1. - E.g. suggest initial-lr = 0.003 (passed to optimizer). + + E.g. 
suggest initial-lr = 0.04 (passed to optimizer) if used with ScaledAdam Args: optimizer: the optimizer to change the learning rates on @@ -666,11 +670,13 @@ class Eden(LRScheduler): optimizer: Optimizer, lr_batches: Union[int, float], lr_epochs: Union[int, float], + warmup_batches: Union[int, float] = 500.0, verbose: bool = False, ): super(Eden, self).__init__(optimizer, verbose) self.lr_batches = lr_batches self.lr_epochs = lr_epochs + self.warmup_batches = warmup_batches def get_lr(self): factor = ( @@ -679,7 +685,10 @@ class Eden(LRScheduler): ((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2) ** -0.25 ) - return [x * factor for x in self.base_lrs] + warmup_factor = (1.0 if self.batch >= self.warmup_batches + else 0.5 + 0.5 * (self.batch / self.warmup_batches)) + + return [x * factor * warmup_factor for x in self.base_lrs] def _test_eden():