2024-10-06 19:07:07 +08:00

142 lines
3.8 KiB
Python

import math
from bisect import bisect_right
from torch.optim.lr_scheduler import _LRScheduler
class WarmupLrScheduler(_LRScheduler):
    """Base scheduler that applies a warmup ratio before a main-phase ratio.

    Every base learning rate is multiplied by one scalar ratio. While
    ``last_epoch < warmup_epoch`` the ratio comes from ``get_warmup_ratio``;
    afterwards it comes from ``get_main_ratio``, which subclasses must
    implement.

    Args:
        optimizer: Optimizer forwarded to ``_LRScheduler``.
        warmup_epoch: Value of ``last_epoch`` at which warmup ends.
        warmup_ratio: Ratio used at ``last_epoch == 0``.
        warmup: Warmup curve, either ``"linear"`` or ``"exp"``.
        last_epoch: Forwarded to ``_LRScheduler``.
    """

    def __init__(
        self,
        optimizer,
        warmup_epoch=500,
        warmup_ratio=5e-4,
        warmup="exp",
        last_epoch=-1,
    ):
        # Set the warmup attributes before delegating to the parent
        # constructor, which may already query get_lr().
        self.warmup_epoch = warmup_epoch
        self.warmup_ratio = warmup_ratio
        self.warmup = warmup
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return each base learning rate scaled by the current ratio."""
        ratio = self.get_lr_ratio()
        return [ratio * lr for lr in self.base_lrs]

    def get_lr_ratio(self):
        """Return the warmup ratio during warmup, else the main ratio."""
        if self.last_epoch < self.warmup_epoch:
            return self.get_warmup_ratio()
        return self.get_main_ratio()

    def get_main_ratio(self):
        """Return the post-warmup ratio; subclasses must override this.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def get_warmup_ratio(self):
        """Return the ratio for the current warmup step.

        With ``alpha = last_epoch / warmup_epoch``:

        * ``"linear"``: ``warmup_ratio + (1 - warmup_ratio) * alpha``
        * ``"exp"``: ``warmup_ratio ** (1 - alpha)``

        Both equal ``warmup_ratio`` at ``alpha == 0`` and approach 1 as
        ``alpha`` approaches 1.

        Raises:
            ValueError: If ``self.warmup`` is neither ``"linear"`` nor
                ``"exp"``.
        """
        # Raise explicitly instead of using ``assert``: assertions are
        # stripped under ``python -O``, after which an unknown mode would
        # surface as an UnboundLocalError.
        if self.warmup not in ("linear", "exp"):
            raise ValueError(
                "warmup must be 'linear' or 'exp', got {!r}".format(self.warmup)
            )
        alpha = self.last_epoch / self.warmup_epoch
        if self.warmup == "linear":
            return self.warmup_ratio + (1 - self.warmup_ratio) * alpha
        return self.warmup_ratio ** (1.0 - alpha)
class WarmupPolyLrScheduler(WarmupLrScheduler):
    """Warmup followed by polynomial decay of the learning-rate ratio.

    After warmup the ratio is ``(1 - alpha) ** power``, where ``alpha`` is
    the fraction of the post-warmup window (``max_epoch - warmup_epoch``)
    that has elapsed.

    Args:
        optimizer: Optimizer forwarded to the base scheduler.
        power: Exponent of the polynomial decay.
        max_epoch: Value of ``last_epoch`` at which the decay completes.
        warmup_epoch: Forwarded to ``WarmupLrScheduler``.
        warmup_ratio: Forwarded to ``WarmupLrScheduler``.
        warmup: Forwarded to ``WarmupLrScheduler``.
        last_epoch: Forwarded to ``WarmupLrScheduler``.
    """

    def __init__(
        self,
        optimizer,
        power,
        max_epoch,
        warmup_epoch=500,
        warmup_ratio=5e-4,
        warmup="exp",
        last_epoch=-1,
    ):
        # Set before the parent constructor runs, which may already call
        # get_lr() and therefore get_main_ratio().
        self.power = power
        self.max_epoch = max_epoch
        super().__init__(
            optimizer, warmup_epoch, warmup_ratio, warmup, last_epoch
        )

    def get_main_ratio(self):
        """Return the polynomial decay ratio for the current step.

        Progress is clamped to 1.0. Without the clamp, stepping past
        ``max_epoch`` makes ``1 - alpha`` negative, which yields a complex
        number for a fractional ``power`` and a negative or rebounding ratio
        for an integer ``power``. An empty or negative decay window
        (``max_epoch <= warmup_epoch``) is treated as already finished
        instead of dividing by zero.
        """
        real_epoch = self.last_epoch - self.warmup_epoch
        real_max_epoch = self.max_epoch - self.warmup_epoch
        if real_max_epoch <= 0:
            alpha = 1.0
        else:
            alpha = min(real_epoch / real_max_epoch, 1.0)
        return (1 - alpha) ** self.power
class WarmupExpLrScheduler(WarmupLrScheduler):
    """Warmup followed by stepwise exponential decay.

    After warmup the ratio is ``gamma`` raised to the number of whole
    ``interval``-sized spans elapsed since warmup ended.

    Args:
        optimizer: Optimizer forwarded to the base scheduler.
        gamma: Multiplicative decay factor applied once per interval.
        interval: Number of post-warmup steps between decays.
        warmup_epoch: Forwarded to ``WarmupLrScheduler``.
        warmup_ratio: Forwarded to ``WarmupLrScheduler``.
        warmup: Forwarded to ``WarmupLrScheduler``.
        last_epoch: Forwarded to ``WarmupLrScheduler``.
    """

    def __init__(
        self,
        optimizer,
        gamma,
        interval=1,
        warmup_epoch=500,
        warmup_ratio=5e-4,
        warmup="exp",
        last_epoch=-1,
    ):
        # Decay settings are stored before the parent constructor runs.
        self.gamma, self.interval = gamma, interval
        super().__init__(
            optimizer,
            warmup_epoch=warmup_epoch,
            warmup_ratio=warmup_ratio,
            warmup=warmup,
            last_epoch=last_epoch,
        )

    def get_main_ratio(self):
        """Return ``gamma ** k`` with ``k`` the count of elapsed intervals."""
        steps_since_warmup = self.last_epoch - self.warmup_epoch
        completed_intervals = steps_since_warmup // self.interval
        return self.gamma ** completed_intervals
class WarmupCosineLrScheduler(WarmupLrScheduler):
    """Warmup followed by cosine annealing of the learning-rate ratio.

    After warmup the ratio follows half a cosine period over the post-warmup
    window (``max_epoch - warmup_epoch``): it starts at 1 when warmup ends
    and reaches ``eta_ratio`` at ``max_epoch``.

    Args:
        optimizer: Optimizer forwarded to the base scheduler.
        max_epoch: Value of ``last_epoch`` at which the ratio reaches
            ``eta_ratio``.
        eta_ratio: Ratio at the bottom of the cosine curve.
        warmup_epoch: Forwarded to ``WarmupLrScheduler``.
        warmup_ratio: Forwarded to ``WarmupLrScheduler``.
        warmup: Forwarded to ``WarmupLrScheduler``.
        last_epoch: Forwarded to ``WarmupLrScheduler``.
    """

    def __init__(
        self,
        optimizer,
        max_epoch,
        eta_ratio=0,
        warmup_epoch=500,
        warmup_ratio=5e-4,
        warmup="exp",
        last_epoch=-1,
    ):
        # Set before the parent constructor runs, which may already call
        # get_lr() and therefore get_main_ratio().
        self.eta_ratio = eta_ratio
        self.max_epoch = max_epoch
        super().__init__(
            optimizer, warmup_epoch, warmup_ratio, warmup, last_epoch
        )

    def get_main_ratio(self):
        """Return the cosine-annealed ratio for the current step.

        The cosine argument uses the step count measured from the end of
        warmup. It is divided by the post-warmup length, so using the
        absolute ``last_epoch`` instead would make the curve jump when
        warmup ends and reach ``eta_ratio`` before ``max_epoch``.
        """
        real_epoch = self.last_epoch - self.warmup_epoch
        real_max_epoch = self.max_epoch - self.warmup_epoch
        cosine = (1 + math.cos(math.pi * real_epoch / real_max_epoch)) / 2
        return self.eta_ratio + (1 - self.eta_ratio) * cosine
class WarmupStepLrScheduler(WarmupLrScheduler):
    """Warmup followed by step decay at given milestones.

    After warmup the ratio is ``gamma ** k``, where ``k`` is the number of
    milestones that are less than or equal to the step count measured from
    the end of warmup.

    Args:
        optimizer: Optimizer forwarded to the base scheduler.
        milestones: Post-warmup step counts at which the ratio is multiplied
            by ``gamma``. They may be given in any order.
        gamma: Multiplicative decay factor applied at each milestone.
        warmup_epoch: Forwarded to ``WarmupLrScheduler``.
        warmup_ratio: Forwarded to ``WarmupLrScheduler``.
        warmup: Forwarded to ``WarmupLrScheduler``.
        last_epoch: Forwarded to ``WarmupLrScheduler``.
    """

    def __init__(
        self,
        optimizer,
        milestones: list,
        gamma=0.1,
        warmup_epoch=500,
        warmup_ratio=5e-4,
        warmup="exp",
        last_epoch=-1,
    ):
        # bisect_right only counts correctly on an ascending sequence, so
        # store a sorted copy. This also avoids aliasing the caller's list.
        self.milestones = sorted(milestones)
        self.gamma = gamma
        super().__init__(
            optimizer, warmup_epoch, warmup_ratio, warmup, last_epoch
        )

    def get_main_ratio(self):
        """Return ``gamma`` raised to the number of milestones passed."""
        real_epoch = self.last_epoch - self.warmup_epoch
        passed = bisect_right(self.milestones, real_epoch)
        return self.gamma ** passed