From 454a0cd72bcb4fc6430eb48d803618e9804637e1 Mon Sep 17 00:00:00 2001
From: inter
Date: Sun, 21 Sep 2025 20:19:58 +0800
Subject: [PATCH] Add File

---
 .../optimization/learning_schedules_fastai.py | 162 ++++++++++++++++++
 1 file changed, 162 insertions(+)
 create mode 100644 tools/train_utils/optimization/learning_schedules_fastai.py

diff --git a/tools/train_utils/optimization/learning_schedules_fastai.py b/tools/train_utils/optimization/learning_schedules_fastai.py
new file mode 100644
index 0000000..15f7d23
--- /dev/null
+++ b/tools/train_utils/optimization/learning_schedules_fastai.py
@@ -0,0 +1,162 @@
+# This file is modified from https://github.com/traveller59/second.pytorch
+
+import math
+from functools import partial
+
+import numpy as np
+import torch.optim.lr_scheduler as lr_sched
+
+from .fastai_optim import OptimWrapper
+
+
+class LRSchedulerStep(object):
+    def __init__(self, fai_optimizer: OptimWrapper, total_step, lr_phases,
+                 mom_phases):
+        # if not isinstance(fai_optimizer, OptimWrapper):
+        #     raise TypeError('{} is not a fastai OptimWrapper'.format(
+        #         type(fai_optimizer).__name__))
+        self.optimizer = fai_optimizer
+        self.total_step = total_step
+        self.lr_phases = []
+
+        for i, (start, lambda_func) in enumerate(lr_phases):
+            if len(self.lr_phases) != 0:
+                assert self.lr_phases[-1][0] < start
+            if isinstance(lambda_func, str):
+                lambda_func = eval(lambda_func)
+            if i < len(lr_phases) - 1:
+                self.lr_phases.append((int(start * total_step), int(lr_phases[i + 1][0] * total_step), lambda_func))
+            else:
+                self.lr_phases.append((int(start * total_step), total_step, lambda_func))
+        assert self.lr_phases[0][0] == 0
+        self.mom_phases = []
+        for i, (start, lambda_func) in enumerate(mom_phases):
+            if len(self.mom_phases) != 0:
+                assert self.mom_phases[-1][0] < start
+            if isinstance(lambda_func, str):
+                lambda_func = eval(lambda_func)
+            if i < len(mom_phases) - 1:
+                self.mom_phases.append((int(start * total_step), int(mom_phases[i + 1][0] * total_step), lambda_func))
+            else:
+                self.mom_phases.append((int(start * total_step), total_step, lambda_func))
+        assert self.mom_phases[0][0] == 0
+
+    def step(self, step, epoch=None):
+        for start, end, func in self.lr_phases:
+            if step >= start:
+                self.optimizer.lr = func((step - start) / (end - start))
+        for start, end, func in self.mom_phases:
+            if step >= start:
+                self.optimizer.mom = func((step - start) / (end - start))
+
+
+def annealing_cos(start, end, pct):
+    # print(pct, start, end)
+    "Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0."
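+    # Interpolates between `start` (at pct=0) and `end` (at pct=1) along a half
+    # cosine: the return value is end + (start - end) * (1 + cos(pi * pct)) / 2.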
+    cos_out = np.cos(np.pi * pct) + 1
+    return end + (start - end) / 2 * cos_out
+
+
+class OneCycle(LRSchedulerStep):
+    def __init__(self, fai_optimizer, total_step, lr_max, moms, div_factor,
+                 pct_start):
+        self.lr_max = lr_max
+        self.moms = moms
+        self.div_factor = div_factor
+        self.pct_start = pct_start
+        a1 = int(total_step * self.pct_start)
+        a2 = total_step - a1
+        low_lr = self.lr_max / self.div_factor
+        lr_phases = ((0, partial(annealing_cos, low_lr, self.lr_max)),
+                     (self.pct_start,
+                      partial(annealing_cos, self.lr_max, low_lr / 1e4)))
+        mom_phases = ((0, partial(annealing_cos, *self.moms)),
+                      (self.pct_start, partial(annealing_cos,
+                                               *self.moms[::-1])))
+        fai_optimizer.lr, fai_optimizer.mom = low_lr, self.moms[0]
+        super().__init__(fai_optimizer, total_step, lr_phases, mom_phases)
+
+
+class CosineWarmupLR(lr_sched._LRScheduler):
+    def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1):
+        self.T_max = T_max
+        self.eta_min = eta_min
+        super(CosineWarmupLR, self).__init__(optimizer, last_epoch)
+
+    def get_lr(self, epoch=None):
+        return [self.eta_min + (base_lr - self.eta_min) *
+                (1 - math.cos(math.pi * self.last_epoch / self.T_max)) / 2
+                for base_lr in self.base_lrs]
+
+
+def linear_warmup(end, lr_max, pct):
+    k = (1 - pct / end) * (1 - 0.33333333)
+    warmup_lr = lr_max * (1 - k)
+    return warmup_lr
+
+
+class CosineAnnealing(LRSchedulerStep):
+    def __init__(self, fai_optimizer, total_step, total_epoch, lr_max, moms, pct_start, warmup_iter):
+        self.lr_max = lr_max
+        self.moms = moms
+        self.pct_start = pct_start
+
+        mom_phases = ((0, partial(annealing_cos, *self.moms)),
+                      (self.pct_start, partial(annealing_cos,
+                                               *self.moms[::-1])))
+        fai_optimizer.lr, fai_optimizer.mom = lr_max, self.moms[0]
+
+        self.optimizer = fai_optimizer
+        self.total_step = total_step
+        self.warmup_iter = warmup_iter
+        self.total_epoch = total_epoch
+
+        self.mom_phases = []
+        for i, (start, lambda_func) in enumerate(mom_phases):
+            if len(self.mom_phases) != 0:
+                assert self.mom_phases[-1][0] < start
+            if isinstance(lambda_func, str):
+                lambda_func = eval(lambda_func)
+            if i < len(mom_phases) - 1:
+                self.mom_phases.append((int(start * total_step), int(mom_phases[i + 1][0] * total_step), lambda_func))
+            else:
+                self.mom_phases.append((int(start * total_step), total_step, lambda_func))
+        assert self.mom_phases[0][0] == 0
+
+    def step(self, step, epoch):
+        # update lr
+        if step < self.warmup_iter:
+            self.optimizer.lr = linear_warmup(self.warmup_iter, self.lr_max, step)
+        else:
+            target_lr = self.lr_max * 0.001
+            cos_lr = annealing_cos(self.lr_max, target_lr, epoch / self.total_epoch)
+            self.optimizer.lr = cos_lr
+        # update mom
+        for start, end, func in self.mom_phases:
+            if step >= start:
+                self.optimizer.mom = func((step - start) / (end - start))
+
+
+class FakeOptim:
+    def __init__(self):
+        self.lr = 0
+        self.mom = 0
+
+
+if __name__ == "__main__":
+    import matplotlib.pyplot as plt
+
+    opt = FakeOptim()  # 3e-3, wd=0.4, div_factor=10
+    schd = OneCycle(opt, 100, 3e-3, (0.95, 0.85), 10.0, 0.1)
+
+    lrs = []
+    moms = []
+    for i in range(100):
+        schd.step(i)
+        lrs.append(opt.lr)
+        moms.append(opt.mom)
+    plt.plot(lrs)
+    # plt.plot(moms)
+    plt.show()
+    plt.plot(moms)
+    plt.show()
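+
+    # Illustrative sketch only: exercise CosineAnnealing the same way, assuming
+    # 10 epochs of 10 iterations each, a 5-iteration linear warmup, and the same
+    # lr_max / momentum settings as the OneCycle demo above; the epoch index
+    # drives the cosine decay while the step index drives the warmup.
+    opt2 = FakeOptim()
+    total_epoch, iters_per_epoch = 10, 10
+    schd2 = CosineAnnealing(opt2, total_epoch * iters_per_epoch, total_epoch,
+                            3e-3, (0.95, 0.85), 0.1, warmup_iter=5)
+    lrs2, moms2 = [], []
+    for i in range(total_epoch * iters_per_epoch):
+        schd2.step(i, i // iters_per_epoch)
+        lrs2.append(opt2.lr)
+        moms2.append(opt2.mom)
+    plt.plot(lrs2)
+    plt.show()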