# Source code for olympus.utils

import random
import time
from typing import Callable, Optional, TypeVar
from urllib.parse import urlparse

import numpy

import torch

from olympus.utils.options import option, set_option
from olympus.utils.chrono import Chrono
from olympus.utils.functional import select, flatten
from olympus.utils.arguments import parse_args, show_hyperparameter_space, required, get_parameters, drop_empty_key
from olympus.utils.log import warning, info, debug, error, critical, exception, set_verbose_level, set_log_level

# Generic placeholders used in callable annotations below:
# A stands for an argument type, R for a return type.
A, R = TypeVar('A'), TypeVar('R')


class MissingArgument(Exception):
    """Exception type for a missing required argument.

    Not raised anywhere in this module; presumably raised by callers elsewhere.
    """
def fetch_device():
    """Return the ``torch.device`` named by the ``device.type`` option.

    The fallback passed to ``option`` is ``'cuda'`` when CUDA is available
    and ``'cpu'`` otherwise.
    """
    fallback = 'cuda' if torch.cuda.is_available() else 'cpu'
    device_type = option('device.type', fallback)
    return torch.device(device_type)
def show_dict(dictionary, indent=0):
    """Print ``dictionary`` as a table of right-aligned ``key: value`` rows.

    The rows are framed by two 80-dash rules. Every printed line, rules and
    rows alike, is prefixed with ``indent`` spaces.

    Parameters
    ----------
    dictionary: Mapping
        Items to print, in iteration order.
    indent: int
        Number of leading spaces on every line.
    """
    pad = ' ' * indent
    rule = pad + '-' * 80
    print(rule)
    for key, value in dictionary.items():
        # `!s` converts the key to text before alignment. Without it, keys whose
        # __format__ rejects a width spec (e.g. tuples) raise TypeError, and
        # bools are rendered as 1/0.
        print(f'{pad}{key!s:>30}: {value}')
    print(rule)
class TimeThrottler:
    """Call ``fun`` at most once per ``every`` seconds.

    A call arriving more than ``every`` seconds after the last forwarded call
    is passed through to ``fun`` and its result is returned. Earlier calls are
    dropped and return ``None``.
    """

    def __init__(self, fun: 'Callable[..., R]', every=10):
        self.fun = fun
        # -inf guarantees the very first call goes through, whatever the
        # origin of the monotonic clock is.
        self.last_time: float = float('-inf')
        self.every: float = every

    def __call__(self, *args, **kwargs) -> 'Optional[R]':
        # Use a monotonic clock: wall-clock time (time.time) can jump backwards
        # or forwards (NTP, manual changes) and stall or burst the throttle.
        now = time.monotonic()
        if now - self.last_time > self.every:
            self.last_time = now
            return self.fun(*args, **kwargs)
        return None
def parse_uri_options(options):
    """Parse a query string ``k1=v1&k2=v2`` into a ``dict`` of strings.

    Values are not URL-decoded. Only the first ``=`` of each item separates
    key from value, so values may themselves contain ``=``. Empty items (as in
    ``a=1&&b=2`` or a trailing ``&``) are skipped. An item without ``=`` maps
    to ``''``.
    """
    if not options:
        return dict()

    opt = dict()
    for item in options.split('&'):
        if not item:
            continue
        key, _, value = item.partition('=')
        opt[key] = value
    return opt


def parse_uri(uri):
    """Split ``uri`` into its components.

    Returns a dict with ``scheme``, ``path``, ``query`` (parsed by
    ``parse_uri_options``), ``fragment`` and ``params``. When the URI has a
    network location, it also contains ``address``. It contains ``port`` (a
    string) when one is given, ``username`` when user info precedes ``@``, and
    ``password`` when that user info contains ``:``.
    """
    parsed = urlparse(uri)
    arguments = {
        'scheme': parsed.scheme,
        'path': parsed.path,
        'query': parse_uri_options(parsed.query),
        'fragment': parsed.fragment,
        'params': parsed.params,
    }

    netloc = parsed.netloc
    if netloc:
        # The last '@' separates user info from host, so '@' may appear in a password.
        userinfo, has_userinfo, hostport = netloc.rpartition('@')
        if has_userinfo:
            username, has_password, password = userinfo.partition(':')
            arguments['username'] = username
            if has_password:
                arguments['password'] = password

        # A trailing ']' is a bracketed IPv6 literal without a port.
        if ':' in hostport and not hostport.endswith(']'):
            address, _, port = hostport.rpartition(':')
            arguments['port'] = port
        else:
            address = hostport
        arguments['address'] = address

    return arguments


def get_storage(uri, objective=None):
    """Shorten the storage config from orion that is super long and super confusing.

    Format: ``<storage_type>:<database>:<file or address>``, e.g.::

        legacy:pickleddb:my_data.pkl
        legacy:mongodb://user:pass@192.168.0.0:8989

    ``legacy`` returns an orion ``database`` config whose ``host`` is the path
    or, when there is no path, the address. ``track`` returns the URI with an
    ``objective`` query option appended. Any other storage type returns
    ``None``.
    """
    storage_type, storage_uri = uri.split(':', maxsplit=1)
    arguments = parse_uri(storage_uri)

    # parse_uri always fills 'scheme' and 'path' (possibly with ''), so
    # dict.get defaults would never apply; treat empty strings as missing.
    database = arguments.get('scheme') or 'pickleddb'
    database_resource = arguments.get('path') or arguments.get('address')

    if storage_type == 'legacy':
        # TODO: make it work for mongodb
        return {
            'type': storage_type,
            'database': {
                'type': database,
                'host': database_resource,
            }
        }

    if storage_type == 'track':
        # Append to an existing query instead of starting a second one.
        separator = '&' if '?' in storage_uri else '?'
        return {
            'type': 'track',
            'uri': f'{storage_uri}{separator}objective={objective}'
        }

    return None
def get_value(item):
    """Return ``item.item()`` for a ``torch.Tensor``; return any other value unchanged."""
    return item.item() if isinstance(item, torch.Tensor) else item
def find_batch_size(model, shape, low, high, dtype=torch.float32, device=None):
    """Find the highest batch size that can fit in memory using binary search.

    Candidates are the multiples of 8 from ``(low // 8) * 8`` up to
    ``(high // 8) * 8`` inclusive. A candidate is probed by calling ``model``
    on ``torch.randn((batch_size,) + shape)``. A ``RuntimeError`` whose message
    contains ``'out of memory'`` marks it as too large. Any other
    ``RuntimeError`` is re-raised.

    Parameters
    ----------
    model: Callable
        Called with a single input tensor.
    shape: tuple
        Per-sample shape; the batch dimension is prepended.
    low, high: int
        Search bounds, rounded to multiples of 8 as described above.
    dtype: torch.dtype
        dtype of the probe tensor.
    device: optional
        Device of the probe tensor (``None`` keeps torch's default).

    Returns
    -------
    The largest candidate that ran successfully. The smallest candidate is
    assumed to fit: it is returned, without being probed, if every probe fails.

    Raises
    ------
    ValueError
        If the rounded range contains no candidate (``high < low``).
    """
    low = (low // 8) * 8
    high = (1 + high // 8) * 8
    batches = list(range(low, high, 8))
    if not batches:
        raise ValueError(f'No batch size to try between low={low} and high={high}')

    # Invariant: batches[a] is the best size known (or assumed) to fit;
    # every size from batches[b] upward is known not to fit.
    a = 0
    b = len(batches)
    while b - a > 1:
        mid = a + (b - a) // 2
        if _runs_without_oom(model, (batches[mid],) + tuple(shape), dtype, device):
            a = mid
        else:
            b = mid

    # Return the best size that fit, not the last one probed (which may have failed).
    return batches[a]


def _runs_without_oom(model, full_shape, dtype, device):
    """Return True if ``model`` runs on a random input of ``full_shape``, False on OOM."""
    try:
        model(torch.randn(full_shape, dtype=dtype, device=device))
        return True
    except RuntimeError as e:
        if 'out of memory' not in str(e):
            raise
    # Leaving the except block drops the exception and the frames that
    # referenced the input, so the cached allocator blocks can now be released.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return False
class CircularDependencies(Exception):
    """Raised when a ``LazyCall`` is requested again while it is still being built."""


class LazyCall:
    """Save the call parameters of a function so it can be invoked at a later date.

    ``invoke()`` calls ``fun(*args, **kwargs)`` and stores the result in
    ``obj``. While ``obj`` is ``None`` (not yet built, or ``fun`` returned
    ``None``), each ``invoke()`` calls ``fun`` again. Calling the ``LazyCall``
    and reading attributes it does not define are forwarded to ``obj``,
    invoking first if needed.
    """

    def __init__(self, fun, *args, **kwargs):
        self.fun = fun
        self.args = args
        self.kwargs = kwargs
        self.obj = None
        self.is_processing = False  # True while fun is running

    def __call__(self, *args, **kwargs):
        """Invoke if needed, then call the resulting object with the given arguments."""
        self.invoke()
        return self.obj(*args, **kwargs)

    def add_arguments(self, *args, **kwargs):
        """Append positional arguments and merge keyword arguments for the pending call."""
        self.args = self.args + args
        self.kwargs.update(kwargs)

    def invoke(self, **kwargs):
        """Build ``obj`` if needed and return it.

        Extra ``kwargs`` are only used by a call that actually builds ``obj``.

        Raises
        ------
        CircularDependencies
            If the object is requested again while ``fun`` is still running.
        """
        if self.obj is not None:
            return self.obj

        if self.is_processing:
            raise CircularDependencies('Circular dependencies')

        self.is_processing = True
        try:
            self.obj = self.fun(*self.args, **self.kwargs, **kwargs)
        finally:
            # Reset even when fun raises, so a failed build can be retried
            # instead of being reported as a circular dependency afterwards.
            self.is_processing = False
        return self.obj

    def __getattr__(self, item):
        # Only reached when normal attribute lookup fails. Instances created
        # without __init__ (copy/pickle go through __new__) have no state yet;
        # reading `self.obj` there would re-enter __getattr__ forever.
        if 'fun' not in self.__dict__:
            raise AttributeError(item)
        self.invoke()
        return getattr(self.obj, item)

    def was_invoked(self):
        """Return True once ``obj`` holds a (non-None) result."""
        return self.obj is not None
class MissingParameters(Exception):
    """Raised by ``HyperParameters.parameters(strict=True)`` when values are missing."""


class WrongParameter(Exception):
    """Raised when a parameter name is not a key of the ``HyperParameters`` space."""


class HyperParameters:
    """Keeps track of mandatory hyper parameters.

    Parameters
    ----------
    space: Dict[str, Space]
        Maps each parameter name to its space/dimension definition.

    kwargs:
        Values already chosen; every key must be a key of ``space``.
    """

    def __init__(self, space, **kwargs):
        self.space = space
        self.check_correct_parameters(kwargs)
        self.current_parameters = kwargs

    def check_correct_parameters(self, kwargs):
        """Raise ``WrongParameter`` for the first key of ``kwargs`` not found in ``space``."""
        unknown = next((name for name in kwargs if name not in self.space), None)
        if unknown is not None:
            raise WrongParameter(f'{unknown} is not a valid parameter, pick from: {self.space.keys()}')

    def missing_parameters(self):
        """Return the ``space`` entries (name -> dimension) that have no value yet."""
        return {
            name: dim
            for name, dim in self.space.items()
            if name not in self.current_parameters
        }

    def add_parameters(self, **kwargs):
        """Validate then record additional parameter values (later values overwrite)."""
        self.check_correct_parameters(kwargs)
        self.current_parameters.update(kwargs)

    def parameters(self, strict=False):
        """Return the recorded values (the internal dict itself, not a copy).

        With ``strict=True``, raise ``MissingParameters`` if any space entry is unset.
        """
        if not strict:
            return self.current_parameters

        missing = self.missing_parameters()
        if missing:
            raise MissingParameters('Parameters are missing: {}'.format(', '.join(missing)))
        return self.current_parameters
# Registry of seeds by name; the first value registered for a name is kept.
SEEDS = {}


def new_seed(**kwargs):
    """Global seed management: register and return the seed for a single name.

    Call as ``new_seed(name=value)``. If ``name`` was already registered, the
    stored seed is returned and ``value`` is ignored. Otherwise the seed is
    ``value``, or a random 32-bit integer when the ``seeding.random`` option is
    enabled; it is stored and returned.

    Raises
    ------
    ValueError
        If not exactly one keyword argument is given.
    """
    # An explicit check, not `assert`, so the validation survives `python -O`.
    if len(kwargs) != 1:
        raise ValueError('Only single seed can be registered')

    automatic_seeding = option('seeding.random', default=False, type=bool)
    (name, value), = kwargs.items()

    # do not change the seed if it was already set
    if name in SEEDS:
        return SEEDS[name]

    if automatic_seeding:
        # 32 bits: numpy.random.seed (used by set_seeds) rejects seeds >= 2**32.
        value = random.getrandbits(32)

    SEEDS[name] = value
    return value


def get_seeds():
    """Return the live seed registry (name -> seed)."""
    return SEEDS
def set_seeds(seed):
    """Seed every random number generator the project uses.

    Each generator's seed goes through ``new_seed`` under its own name
    (``torch_cuda``, ``python_rand``, ``numpy``, ``torch_cpu``), so a seed
    already registered for that name takes precedence over ``seed``. CUDA
    generators are seeded only when CUDA is available, and in that case cuDNN
    benchmark mode is also turned off.
    """
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = False
        torch.cuda.manual_seed_all(new_seed(torch_cuda=seed))
        # NOTE: torch.backends.cudnn.deterministic is intentionally left untouched.

    # Seeded in this order: python random, numpy, torch CPU.
    seeders = (
        ('python_rand', random.seed),
        ('numpy', numpy.random.seed),
        ('torch_cpu', torch.manual_seed),
    )
    for name, seeder in seeders:
        seeder(new_seed(**{name: seed}))