Source code for olympus.models.inits

from olympus.utils.factory import fetch_factories

from olympus.utils import set_seeds as init_seed, warning, HyperParameters

from torch.nn import Module
from torch.random import fork_rng

registered_initialization = fetch_factories('olympus.models.inits', __file__)


[docs]def known_initialization(): return registered_initialization.keys()
[docs]def register_initialization(name, factory, override=False): global registered_initialization if name in registered_initialization: warning(f'{name} was already registered, use override=True to ignore') if not override: return registered_initialization[name] = factory
[docs]class RegisteredInitNotFound(Exception): pass
[docs]def get_initializers_space(): space = {} for k, initializer in registered_initialization.items(): space[k] = initializer.get_space() return dict(initializer=space)
[docs]class Initializer: """Lazy Initializer""" def __init__(self, name, seed=0, **kwargs): self.name = name self.hyper_parameters = HyperParameters(space={}) self.seed = seed self._initializer = None self.initializer_ctor = registered_initialization.get(name) if self.initializer_ctor is None: raise RegisteredInitNotFound(name) if hasattr(self.initializer_ctor, 'get_space'): self.hyper_parameters.space = self.initializer_ctor.get_space() self.hyper_parameters.add_parameters(**kwargs)
[docs] def get_space(self): """Return the dimension space of each parameters""" return self.hyper_parameters.missing_parameters()
[docs] def get_current_space(self): """Get currently defined parameter space""" return self.hyper_parameters.parameters(strict=False)
[docs] def init(self, override=False, **kwargs): if self._initializer and not override: warning('Initializer is already set, use override=True to force re initialization') return self self.hyper_parameters.add_parameters(**kwargs) self._initializer = self.initializer_ctor(**self.hyper_parameters.parameters(strict=True)) return self
@property def initializer(self): if not self._initializer: self.init() return self._initializer def __call__(self, model): with fork_rng(enabled=True): init_seed(self.seed) return self.initializer(model)
[docs]def initialize_weights(model, name=None, seed=0, half=False, **kwargs): """Initialize the model weights using a particular method Parameters ---------- model: Module name: str Name of the initializer seed: int seed to use for the PRNGs Returns ------- The initialized model """ # TODO: remove dependency to global PRNG # At the moment we simply fork the PRNG to prevent affecting later calls with fork_rng(enabled=True): init_seed(seed) method_builder = registered_initialization.get(name) if not method_builder: raise RegisteredInitNotFound(name) method = method_builder(**kwargs) return method(model)