# Source code for olympus.models

import torch
import torch.nn as nn

from olympus.models.module import Module
from olympus.models.inits import Initializer, known_initialization, get_initializers_space

from olympus.utils import MissingArgument, warning, LazyCall, HyperParameters
from olympus.utils.factory import fetch_factories
from olympus.utils.fp16 import network_to_half

# Registry of model factories keyed by model name: `known_models` lists its keys,
# `register_model` adds entries and `Model` looks factories up by name.
# NOTE(review): presumably `fetch_factories` discovers the factories shipped with this
# package -- confirm in olympus.utils.factory.
registered_models = fetch_factories('olympus.models', __file__)


def known_models():
    """Return the names of all the models currently registered

    Returns
    -------
    The result of ``registered_models.keys()``, i.e. the registered model names
    """
    return registered_models.keys()
def register_model(name, factory, override=False):
    """Register ``factory`` under ``name`` in the model registry

    Parameters
    ----------
    name: str
        Key under which the factory is stored; this is the name later given to ``Model``

    factory:
        Object stored in ``registered_models`` for that name

    override: bool
        When ``name`` is already registered, replace the existing factory.
        When False (default) the existing factory is kept and a warning is emitted.
    """
    # Only item assignment happens below, so no `global` declaration is required.
    if name in registered_models and not override:
        # The warning tells the user how to force the registration, so it is only
        # relevant when the registration is actually being skipped.
        warning(f'{name} was already registered, use override=True to ignore')
        return

    registered_models[name] = factory
class RegisteredModelNotFound(Exception):
    """Raised by ``Model`` when the requested name is not in the model registry"""
def try_convert(x, device, dtype):
    """Convert ``x`` to the given device and dtype when it supports it

    Parameters
    ----------
    x:
        Any object; only objects exposing a ``to`` attribute are converted

    device:
        Forwarded as the ``device`` keyword of ``x.to``

    dtype:
        Forwarded as the ``dtype`` keyword of ``x.to``

    Returns
    -------
    ``x.to(device=device, dtype=dtype)`` if ``x`` has a ``to`` attribute, ``x`` unchanged otherwise
    """
    if hasattr(x, 'to'):
        return x.to(device=device, dtype=dtype)

    return x
# Module-level default initializer: it is the default value of the `weight_init`
# argument of `Model.__init__`, so the same instance is used by every Model that
# does not pass its own initializer.
default_init = Initializer('glorot_uniform', seed=0, gain=1.0)

# TODO: Make Model distributed here
class Model(nn.Module):
    """Olympus standardized Model interface

    Parameters
    ----------
    name: str
        Name of a registered model

    half: bool
        Convert the network to half/fp16

    model: Model
        Custom model to use, mutually exclusive with :param name

    input_size:
        Forwarded to the model factory as ``input_size``

    output_size:
        Forwarded to the model factory as ``output_size``

    weight_init: str or Initializer
        Weight initializer, a string is converted to ``Initializer(weight_init)``.
        ``None`` disables the weight initialization step.

    kwargs:
        Additional arguments, they are recorded as hyper parameters

    Examples
    --------

    Model wrappers that provide a wide range of utility built-in.

    Can instantiate common model directly

    >>> model = Model('resnet18', input_size=(1, 28, 28), output_size=(10,))

    Handles mixed precision conversion for you

    >>> model = Model('resnet18', input_size=(1, 28, 28), output_size=(10,), half=True)

    Handles weight initialization

    >>> model = Model('resnet18', input_size=(1, 28, 28), output_size=(10,), weight_init='glorot_uniform')

    Supports your custom model

    >>> class MyModel(nn.Module):
    ...     def __init__(self, input_size, output_size):
    ...         super().__init__()
    ...         self.main = nn.Linear(input_size[0], output_size[0])
    ...
    ...     def forward(self, x):
    ...         return self.main(x)
    >>>
    >>> model = Model(model=MyModel, input_size=(1, 28, 28), output_size=(10,))

    Raises
    ------
    RegisteredModelNotFound
        when using the name of an unknown model

    MissingArgument:
        if neither name nor model was set
    """
    # Defaults used by `transform` until `to` is called with an explicit device / dtype
    _dtype = torch.float32
    _device = torch.device('cpu')

    def __init__(self, name=None, *, half=False, model=None, input_size=None, output_size=None,
                 weight_init=default_init, **kwargs):
        super().__init__()
        # Save all the args that were passed down so we can instantiate it again in standalone
        self.replay_args = dict(
            name=name, half=half, model=model, input_size=input_size, output_size=output_size,
            weight_init=weight_init, kwargs=kwargs)

        # Inputs are converted to the device/dtype of the wrapper before reaching the network
        self.transform = lambda x: try_convert(x, self.device, self.dtype)
        # NOTE: this instance attribute shadows `nn.Module.half()` on this object
        self.half = half
        self._model = None

        # Track defined hyper parameters
        self.hyper_parameters = HyperParameters(space=dict())

        # If init is set then we can add its hyper parameters
        # NOTE(review): the default initializer is a single module-level object shared by
        # every Model and `init` calls `.init(...)` on it -- confirm sharing is intended
        self.weight_init = weight_init
        if weight_init is not None:
            if isinstance(weight_init, str):
                self.weight_init = Initializer(weight_init)

            # replace weight init by its own hyper parameters
            space = self.weight_init.get_space()
            if space:
                self.hyper_parameters.space.update(dict(initializer=space))

        # Make a Lazy Model that will be initialized once all the hyper parameters are set
        # `is not None` rather than truthiness: module containers define `__len__`,
        # so an empty container instance would otherwise be treated as "no model given"
        if model is not None:
            if hasattr(model, 'get_space'):
                self.hyper_parameters.space.update(model.get_space())

            if isinstance(model, type):
                # A class: instantiate it lazily with the requested sizes
                self.model_builder = LazyCall(
                    model, input_size=input_size, output_size=output_size)
            else:
                # An instance: the builder ignores its arguments and returns it as is
                self.model_builder = LazyCall(lambda *args, **kwargs: model)

        elif name:
            # load an olympus model
            model_fun = registered_models.get(name)

            if not model_fun:
                raise RegisteredModelNotFound(name)

            self.model_builder = LazyCall(
                model_fun, input_size=input_size, output_size=output_size)

            if hasattr(model_fun, 'get_space'):
                self.hyper_parameters.space.update(model_fun.get_space())
        else:
            raise MissingArgument('Model or Name need to be set')

        # Any Additional parameters set Hyper parameters
        self.other_params = self.hyper_parameters.add_parameters(strict=False, **kwargs)

    @property
    def dtype(self):
        """dtype applied to the inputs by ``transform``"""
        return self._dtype

    @property
    def device(self):
        """device applied to the inputs by ``transform``"""
        return self._device

    def get_space(self):
        """Return hyper parameter space"""
        return self.hyper_parameters.missing_parameters()

    def get_current_space(self):
        """Get currently defined parameter space"""
        return self.hyper_parameters.parameters(strict=False)

    def init(self, override=False, **model_hyperparams):
        """Build the underlying network and initialize its weights

        Parameters
        ----------
        override: bool
            Currently unused, kept for interface compatibility

        model_hyperparams:
            Hyper parameters to set before building the network

        Returns
        -------
        self, so calls can be chained
        """
        others = self.hyper_parameters.add_parameters(strict=False, **model_hyperparams)
        self.other_params.update(others)

        params = self.hyper_parameters.parameters(strict=True)

        # The initializer hyper parameters are not arguments of the network factory
        initializer = params.pop('initializer', {})
        # `weight_init=None` is accepted by the constructor, so it must be skipped here too
        if self.weight_init is not None and isinstance(initializer, dict):
            self.weight_init.init(**initializer)

        self._model = self.model_builder.invoke(**self.other_params, **params)

        if self.weight_init is not None:
            self.weight_init(self._model)

        if self.half:
            self._model = network_to_half(self._model)

        # Register module so we can use all the parent methods
        self.add_module('_model', self._model)
        self.replay_args['kwargs'].update(model_hyperparams)
        return self

    @property
    def model(self):
        """Underlying network, built on first access if ``init`` was not called yet"""
        # Identity check: a module container with no children is falsy but is still a model
        if self._model is None:
            self.init()

        return self._model

    def forward(self, *args, **kwargs):
        """Convert the first positional input with ``transform`` and run the network

        Only the first positional argument is converted, the remaining positional and
        keyword arguments are forwarded untouched.
        """
        return self.model(self.transform(args[0]), *args[1:], **kwargs)

    def __call__(self, *args, **kwargs):
        return super().__call__(*args, **kwargs)

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        """Return the state of the wrapper

        The ``destination`` argument is ignored, a new dictionary is always returned.

        Returns
        -------
        A dictionary with the network state (``model``), the ``half`` flag,
        the constructor arguments (``replay``) and the network type (``types``)
        """
        network = self.model
        return {
            # keyword arguments: newer PyTorch deprecates positional arguments here
            'model': network.state_dict(prefix=prefix, keep_vars=keep_vars),
            'half': self.half,
            'replay': self.replay_args,
            'types': {
                'model': type(network)
            }
        }

    @staticmethod
    def from_state(state):
        """Rebuild a Model from a dictionary produced by ``state_dict``"""
        # Work on a copy: `state['replay']` can be the live `replay_args` of the model
        # that produced the state, popping `kwargs` from it would break that model
        replay = dict(state['replay'])
        replay.update(replay.pop('kwargs', dict()))

        m = Model(**replay)
        m.init()
        m.load_state_dict(state)
        return m

    def load_state_dict(self, state_dict, strict=True):
        """Restore the ``half`` flag and the network weights from ``state_dict``

        Returns
        -------
        The result of the underlying network ``load_state_dict``
        """
        self.half = state_dict['half']
        return self.model.load_state_dict(state_dict['model'], strict=strict)

    def parameters(self, recurse: bool = True):
        """Return the parameters of the underlying network"""
        return self.model.parameters(recurse)

    def to(self, *args, **kwargs):
        """Move the model and remember the device / dtype so inputs are converted as well"""
        device, dtype, *_ = torch._C._nn._parse_to(*args, **kwargs)

        # `_parse_to` returns None for whatever was not specified;
        # keep the previous value instead of forgetting it
        if device is not None:
            self._device = device

        if dtype is not None:
            self._dtype = dtype

        super().to(*args, **kwargs)
        return self

    def act(self, *args, **kwargs):
        """Forward the call to the ``act`` method of the underlying network"""
        return self.model.act(*args, **kwargs)

    def critic(self, *args, **kwargs):
        """Forward the call to the ``critic`` method of the underlying network"""
        return self.model.critic(*args, **kwargs)