# This file is modified from https://github.com/traveller59/second.pytorch
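#
# Learning-rate / momentum schedulers for the fastai-style OptimWrapper defined
# in .fastai_optim: a generic phase-based stepper (LRSchedulerStep), a one-cycle
# policy (OneCycle), a warmup-then-cosine schedule (CosineAnnealing), and a
# torch.optim.lr_scheduler-based cosine warmup (CosineWarmupLR).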

import math
from functools import partial

import numpy as np
import torch.optim.lr_scheduler as lr_sched

from .fastai_optim import OptimWrapper


class LRSchedulerStep(object):
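    """Step-based LR / momentum scheduler driven by (start_fraction, anneal_func) phases.

    Each phase starts at `start_fraction * total_step`, and its function maps the
    progress within that phase (0.0 -> 1.0) to an lr or momentum value.

    Illustrative usage (a sketch, not taken from this repo; `wrapped_optimizer`
    stands for an OptimWrapper instance and the numbers are arbitrary):

        lr_phases = ((0.0, partial(annealing_cos, 1e-4, 1e-3)),
                     (0.4, partial(annealing_cos, 1e-3, 1e-7)))
        mom_phases = ((0.0, partial(annealing_cos, 0.95, 0.85)),
                      (0.4, partial(annealing_cos, 0.85, 0.95)))
        sched = LRSchedulerStep(wrapped_optimizer, 1000, lr_phases, mom_phases)
        sched.step(cur_step)  # updates wrapped_optimizer.lr / .mom in place
    """
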
    def __init__(self, fai_optimizer: OptimWrapper, total_step, lr_phases,
                 mom_phases):
        # if not isinstance(fai_optimizer, OptimWrapper):
        #     raise TypeError('{} is not a fastai OptimWrapper'.format(
        #         type(fai_optimizer).__name__))
        self.optimizer = fai_optimizer
        self.total_step = total_step
        self.lr_phases = []

        for i, (start, lambda_func) in enumerate(lr_phases):
            if len(self.lr_phases) != 0:
                assert self.lr_phases[-1][0] < start
            if isinstance(lambda_func, str):
                lambda_func = eval(lambda_func)
            if i < len(lr_phases) - 1:
                self.lr_phases.append((int(start * total_step), int(lr_phases[i + 1][0] * total_step), lambda_func))
            else:
                self.lr_phases.append((int(start * total_step), total_step, lambda_func))
        assert self.lr_phases[0][0] == 0

        self.mom_phases = []
        for i, (start, lambda_func) in enumerate(mom_phases):
            if len(self.mom_phases) != 0:
                assert self.mom_phases[-1][0] < start
            if isinstance(lambda_func, str):
                lambda_func = eval(lambda_func)
            if i < len(mom_phases) - 1:
                self.mom_phases.append((int(start * total_step), int(mom_phases[i + 1][0] * total_step), lambda_func))
            else:
                self.mom_phases.append((int(start * total_step), total_step, lambda_func))
        assert self.mom_phases[0][0] == 0

    def step(self, step, epoch=None):
        for start, end, func in self.lr_phases:
            if step >= start:
                self.optimizer.lr = func((step - start) / (end - start))
        for start, end, func in self.mom_phases:
            if step >= start:
                self.optimizer.mom = func((step - start) / (end - start))


def annealing_cos(start, end, pct):
    """Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0."""
    # print(pct, start, end)
    cos_out = np.cos(np.pi * pct) + 1
    return end + (start - end) / 2 * cos_out

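# Sanity check of annealing_cos (follows directly from the formula above):
#   pct=0.0 -> start,  pct=0.5 -> (start + end) / 2,  pct=1.0 -> end.

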
class OneCycle(LRSchedulerStep):
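    """One-cycle policy: lr rises from lr_max / div_factor to lr_max over the first
    `pct_start` fraction of steps, then cosine-anneals down to lr_max / div_factor / 1e4,
    while momentum moves from moms[0] to moms[1] and back (see the demo under
    __main__ below for a plotted example)."""
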
    def __init__(self, fai_optimizer, total_step, lr_max, moms, div_factor,
                 pct_start):
        self.lr_max = lr_max
        self.moms = moms
        self.div_factor = div_factor
        self.pct_start = pct_start
        a1 = int(total_step * self.pct_start)
        a2 = total_step - a1
        low_lr = self.lr_max / self.div_factor
        lr_phases = ((0, partial(annealing_cos, low_lr, self.lr_max)),
                     (self.pct_start,
                      partial(annealing_cos, self.lr_max, low_lr / 1e4)))
        mom_phases = ((0, partial(annealing_cos, *self.moms)),
                      (self.pct_start, partial(annealing_cos,
                                               *self.moms[::-1])))
        fai_optimizer.lr, fai_optimizer.mom = low_lr, self.moms[0]
        super().__init__(fai_optimizer, total_step, lr_phases, mom_phases)


class CosineWarmupLR(lr_sched._LRScheduler):
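    """Cosine warmup as a standard torch lr scheduler: lr rises from eta_min up to
    each param group's base lr over T_max calls to step().

    Illustrative usage (a sketch; `model` and `warmup_epochs` are placeholders):

        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        warmup = CosineWarmupLR(optimizer, T_max=warmup_epochs, eta_min=1e-5)
        for epoch in range(warmup_epochs):
            ...  # train one epoch
            warmup.step()
    """
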
    def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1):
        self.T_max = T_max
        self.eta_min = eta_min
        super(CosineWarmupLR, self).__init__(optimizer, last_epoch)

    def get_lr(self, epoch=None):
        return [self.eta_min + (base_lr - self.eta_min) *
                (1 - math.cos(math.pi * self.last_epoch / self.T_max)) / 2
                for base_lr in self.base_lrs]


def linear_warmup(end, lr_max, pct):
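    """Linearly ramp the lr from roughly lr_max / 3 (at pct=0) to lr_max (at pct=end).

    Worked values: pct=0 gives k ~= 2/3, so warmup_lr ~= lr_max / 3;
    pct=end gives k = 0, so warmup_lr = lr_max.
    """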
    k = (1 - pct / end) * (1 - 0.33333333)
    warmup_lr = lr_max * (1 - k)
    return warmup_lr


class CosineAnnealing(LRSchedulerStep):
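    """Warmup-then-cosine schedule: linear_warmup for the first `warmup_iter` steps,
    then an epoch-level cosine decay from lr_max down to lr_max * 0.001; momentum
    follows the same two-phase (one-cycle style) annealing as OneCycle."""
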
    def __init__(self, fai_optimizer, total_step, total_epoch, lr_max, moms, pct_start, warmup_iter):
        self.lr_max = lr_max
        self.moms = moms
        self.pct_start = pct_start

        mom_phases = ((0, partial(annealing_cos, *self.moms)),
                      (self.pct_start, partial(annealing_cos,
                                               *self.moms[::-1])))
        fai_optimizer.lr, fai_optimizer.mom = lr_max, self.moms[0]

        self.optimizer = fai_optimizer
        self.total_step = total_step
        self.warmup_iter = warmup_iter
        self.total_epoch = total_epoch

        self.mom_phases = []
        for i, (start, lambda_func) in enumerate(mom_phases):
            if len(self.mom_phases) != 0:
                assert self.mom_phases[-1][0] < start
            if isinstance(lambda_func, str):
                lambda_func = eval(lambda_func)
            if i < len(mom_phases) - 1:
                self.mom_phases.append((int(start * total_step), int(mom_phases[i + 1][0] * total_step), lambda_func))
            else:
                self.mom_phases.append((int(start * total_step), total_step, lambda_func))
        assert self.mom_phases[0][0] == 0

    def step(self, step, epoch):
        # update lr
        if step < self.warmup_iter:
            self.optimizer.lr = linear_warmup(self.warmup_iter, self.lr_max, step)
        else:
            target_lr = self.lr_max * 0.001
            cos_lr = annealing_cos(self.lr_max, target_lr, epoch / self.total_epoch)
            self.optimizer.lr = cos_lr
        # update mom
        for start, end, func in self.mom_phases:
            if step >= start:
                self.optimizer.mom = func((step - start) / (end - start))


class FakeOptim:
    """Minimal stand-in for OptimWrapper exposing only `lr` and `mom`, used to plot
    the schedules in the demo below."""

    def __init__(self):
        self.lr = 0
        self.mom = 0


if __name__ == "__main__":
    import matplotlib.pyplot as plt

    # Visual sanity check: plot the OneCycle lr and momentum curves over 100 steps.
    opt = FakeOptim()  # 3e-3, wd=0.4, div_factor=10
    schd = OneCycle(opt, 100, 3e-3, (0.95, 0.85), 10.0, 0.1)

    lrs = []
    moms = []
    for i in range(100):
        schd.step(i)
        lrs.append(opt.lr)
        moms.append(opt.mom)
    plt.plot(lrs)
    # plt.plot(moms)
    plt.show()
    plt.plot(moms)
    plt.show()