# Custom Schedule

import torch

from olympus.optimizers import Optimizer
from olympus.models import Model
from olympus.optimizers.schedules import LRSchedule
from olympus.optimizers.schedules.base import LRScheduleInterface


class MyExponentialLR(LRScheduleInterface):
    """Exponential learning-rate decay: ``lr = base_lr * gamma ** epoch``.

    The learning rate is computed in closed form from each param group's
    base learning rate. It therefore depends only on ``gamma`` and the
    current epoch index, not on how many times :meth:`epoch` was called.

    NOTE(review): ``base_lrs``, ``last_epoch``, ``_step_count`` and
    ``optimizer`` are read/updated here but never assigned in this class.
    They are assumed to be initialised by ``LRScheduleInterface.__init__``
    — confirm against the base class.

    Args:
        optimizer: optimizer whose ``param_groups`` learning rates are driven.
        gamma: multiplicative decay factor applied per epoch.
    """

    def __init__(self, optimizer, gamma):
        super().__init__(optimizer)
        self.gamma = gamma

    def state_dict(self):
        """Return the base-class state extended with ``gamma``."""
        state = super().state_dict()
        state['gamma'] = self.gamma
        return state

    def load_state_dict(self, state_dict):
        """Restore ``gamma`` and hand the remaining state to the base class.

        The caller's ``state_dict`` is left untouched.

        Raises:
            KeyError: if ``state_dict`` has no ``'gamma'`` entry.
        """
        # Work on a shallow copy: popping from the caller's dict would strip
        # 'gamma' from their checkpoint, and re-loading the same dict would
        # then raise KeyError.
        state = dict(state_dict)
        self.gamma = state.pop('gamma')
        return super().load_state_dict(state)

    def epoch(self, epoch=None, metrics=None):
        """Move the schedule to ``epoch`` and write the new learning rates.

        Args:
            epoch: epoch index to jump to; ``None`` means the next epoch
                (``last_epoch + 1``).
            metrics: ignored by this schedule (accepted presumably so the
                signature matches metric-driven schedules — confirm).
        """
        self._step_count += 1

        if epoch is None:
            epoch = self.last_epoch + 1

        self.last_epoch = epoch
        # get_lr() yields one value per entry of base_lrs, assumed to be in
        # the same order as optimizer.param_groups.
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr

    def get_lr(self):
        """Return ``base_lr * gamma ** last_epoch`` for every base learning rate."""
        return [base_lr * self.gamma ** self.last_epoch
                for base_lr in self.base_lrs]

    @staticmethod
    def get_space():
        """Return the hyper-parameter search space for this schedule.

        ``gamma`` is described as ``'loguniform(0.97, 1)'``; the string is
        presumably parsed by olympus's hyper-parameter optimizer.
        """
        return {'gamma': 'loguniform(0.97, 1)'}


if __name__ == '__main__':
    # Minimal demo: take one SGD step on a logistic-regression model, then
    # advance the custom schedule by one epoch and show the effect on lr.
    model = Model(
        'logreg',
        input_size=(290,),
        output_size=(10,)
    )

    optimizer = Optimizer('sgd', params=model.parameters())

    # If you use a hyper-parameter optimizer, it will generate these values for you
    optimizer.init(lr=1e-4, momentum=0.02, weight_decay=1e-3)

    # init() appears to forward optimizer/gamma to MyExponentialLR.__init__
    # (the keyword names match its signature) — confirm against LRSchedule.
    schedule = LRSchedule(schedule=MyExponentialLR)
    schedule.init(optimizer=optimizer, gamma=0.97)

    optimizer.zero_grad()

    # Random batch: 10 rows x 290 features, matching input_size=(290,).
    # Named `inputs` to avoid shadowing the builtin `input`.
    inputs = torch.randn((10, 290))
    output = model(inputs)
    loss = output.sum()
    loss.backward()

    # Step the optimizer before the schedule so this update uses the initial lr.
    optimizer.step()

    print(optimizer.param_groups[0]['lr'])
    # After epoch 1 the lr should be base_lr * gamma ** 1 (~9.7e-05, assuming
    # the schedule recorded lr=1e-4 as its base rate).
    schedule.epoch(1)
    print(optimizer.param_groups[0]['lr'])