Source code for olympus.optimizers.schedules

from olympus.utils import MissingArgument, warning, HyperParameters
from olympus.utils.factory import fetch_factories


registered_schedules = fetch_factories('olympus.optimizers.schedules', __file__)


[docs]class RegisteredLRSchedulerNotFound(Exception): pass
[docs]class UninitializedLRScheduler(Exception): pass
[docs]def known_schedule(): return registered_schedules.keys()
[docs]def register_schedule(name, factory, override=False): global registered_schedules if name in registered_schedules: warning(f'{name} was already registered, use override=True to ignore') if not override: return registered_schedules[name] = factory
[docs]class LRSchedule: """Lazy LRSchedule that allows you to first fetch the supported parameters using ``get_space`` and then initialize the underlying schedule using ``init_optimizer`` Parameters ---------- name: str Name of a registered schedule schedule: LRSchedule Custom schedule, mutually exclusive with :param name Examples -------- .. code-block:: python from olympus.optimizers import Optimizer optimizer = Optimizer('sgd') schedule = LRSchedule('exponential') schedule.get_space() # {'gamma': 'loguniform(0.97, 1)'} schedule.init(optimizer, gamma=0.97) Raises ------ RegisteredLRSchedulerNotFound when using a name of an known schedule MissingArgument: if name nor schedule were not set """ def __init__(self, name=None, *, schedule=None, optimizer=None, **kwargs): self._schedule = None self._schedule_builder = None self._optimizer = optimizer self.hyper_parameters = HyperParameters(space={}) if schedule: if isinstance(schedule, type): self._schedule_builder = schedule if hasattr(schedule, 'get_space'): self.hyper_parameters.space = schedule.get_space() else: self._schedule = schedule if hasattr(self._schedule, 'get_space'): self.hyper_parameters.space = self._schedule.get_space() elif name: # load an olympus model builder = registered_schedules.get(name) if not builder: raise RegisteredLRSchedulerNotFound(name) self._schedule_builder = builder if hasattr(self._schedule_builder, 'get_space'): self.hyper_parameters.space = self._schedule_builder.get_space() else: raise MissingArgument('None or name needs to be set') self.hyper_parameters.add_parameters(**kwargs)
[docs] def init(self, optimizer=None, override=False, **kwargs): """Initialize the LR schedule with the given hyper parameters""" if self._schedule: warning('LRSchedule is already set, use override=True to force re initialization') if not override: return self._schedule if optimizer is None: optimizer = self._optimizer if optimizer is None: raise MissingArgument('Missing optimizer argument!') self.hyper_parameters.add_parameters(**kwargs) self._schedule = self._schedule_builder( optimizer, **self.hyper_parameters.parameters(strict=True)) return self
[docs] def get_space(self): """Return the missing hyper parameters required to initialize the LR schedule""" if self._schedule: warning('LRSchedule is already set') return self.hyper_parameters.missing_parameters()
[docs] def get_current_space(self): """Get currently defined parameter space""" return self.hyper_parameters.parameters(strict=False)
@property def defaults(self): """Return default hyper parameters""" return self._schedule_builder.defaults() @property def lr_scheduler(self): if not self._schedule: self.init() return self._schedule
[docs] def state_dict(self, destination=None, prefix='', keep_vars=False): return self.lr_scheduler.state_dict()
[docs] def load_state_dict(self, state_dict, strict=True): self.lr_scheduler.load_state_dict(state_dict)
[docs] def epoch(self, epoch, metrics=None): """Called after every epoch to update LR""" self.lr_scheduler.epoch(epoch, metrics)
[docs] def step(self, step, metrics=None): """Called every step/batch to update LR""" self.lr_scheduler.step(step, metrics)
[docs] def get_lr(self): return self.lr_scheduler.get_lr()