Fix bug for NewBobScheduler

This commit is contained in:
Yifan Yang 2023-06-02 23:42:57 +08:00
parent 4f152327c6
commit c526e958e5

View File

@ -156,7 +156,6 @@ class NewBobScheduler(LRScheduler):
Args:
metric: A number for determining whether to change the lr value.
"""
factor = 1
if self.prev_metric is not None:
if self.prev_metric == 0:
improvement = 0
@ -166,14 +165,14 @@ class NewBobScheduler(LRScheduler):
) / self.prev_metric
if improvement < self.threshold:
if self.current_patient == 0:
factor = self.annealing_factor
self.base_lrs *= self.annealing_factor
self.current_patient = self.patient
else:
self.current_patient -= 1
self.prev_metric = self.current_metric
return [x * factor for x in self.base_lrs]
return self.base_lrs
def state_dict(self):
return {