Source code for olympus.optimizers.base

import torch.optim as optim


[docs]class OptimizerInterface(optim.Optimizer): """Base Olympus Optimizer""" def __init__(self, params): super(OptimizerInterface, self).__init__(params)
[docs] def state_dict(self, destination=None, prefix='', keep_vars=False): return super(OptimizerInterface, self).state_dict()
[docs] def load_state_dict(self, state_dict, strict=True): return super(OptimizerInterface, self).load_state_dict(state_dict)
[docs] def zero_grad(self): return super(OptimizerInterface, self).zero_grad()
[docs] def step(self, closure=None): return super(OptimizerInterface, self).step(closure)
[docs] def backward(self, loss): """This method comes from FP16 Optimizer, for consistency we add it everywhere""" loss.backward()
[docs] def add_param_group(self, param_group): return super(OptimizerInterface, self).add_param_group(param_group)
[docs] @staticmethod def get_space(): """Specifies the hyper parameters that are supported by this optimizer""" return {}
[docs] @staticmethod def defaults(): """Specifies the hyper parameters defaults""" return {}
[docs]class OptimizerAdapter(OptimizerInterface): """Wraps an existing Pytorch Optimizer into an Olympus optimizer""" def __init__(self, factory, *args, **kwargs): self.optimizer = factory(*args, **kwargs)
[docs] def state_dict(self, destination=None, prefix='', keep_vars=False): return self.optimizer.state_dict()
[docs] def load_state_dict(self, state_dict, strict=True): return self.optimizer.load_state_dict(state_dict)
[docs] def zero_grad(self): return self.optimizer.zero_grad()
[docs] def step(self, closure=None): return self.optimizer.step(closure)
[docs] def backward(self, loss): """This method comes from FP16 Optimizer, for consistency we add it everywhere""" loss.backward()
[docs] def add_param_group(self, param_group): return self.optimizer.add_param_group(param_group)
@property def state(self): return self.optimizer.state @property def param_groups(self): return self.optimizer.param_groups
[docs] @staticmethod def get_space(): """Specifies the hyper parameters that are supported by this optimizer""" raise NotImplementedError()
[docs] @staticmethod def defaults(): """Specifies the hyper parameters defaults""" raise NotImplementedError()