# Source code for olympus.utils

import random
import signal
import time
from typing import Callable, Optional, TypeVar, Dict, NoReturn, Union
from urllib.parse import urlparse

import base64
import zlib

import numpy

import torch

from olympus.utils.options import option, set_option
from olympus.utils.chrono import Chrono
from olympus.utils.functional import select, flatten
from olympus.utils.arguments import (
    parse_args,
    show_hyperparameter_space,
    required,
    get_parameters,
    drop_empty_key,
)
from olympus.utils.log import (
    warning,
    info,
    debug,
    error,
    critical,
    exception,
    set_verbose_level,
    set_log_level,
)

# bson is an optional dependency: the import failure is stored instead of
# being raised here, so compress_dict / decompress_dict can re-raise it
# when they are actually called.
try:
    import bson

    BSON_ERROR = None
except ImportError as e:
    BSON_ERROR = e

# Generic placeholders used in TimeThrottler's annotations:
# A is the argument type and R the return type of the throttled callable.
A = TypeVar("A")
R = TypeVar("R")


class MissingArgument(Exception):
    """Error type for a missing argument.

    It is only declared here; nothing in this module raises it.
    """
def fetch_device():
    """Return the ``torch.device`` selected by the ``device.type`` option.

    The value handed to ``option`` as default is ``"cuda"`` when CUDA is
    available and ``"cpu"`` otherwise.
    """
    fallback = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(option("device.type", fallback))
def show_dict(dictionary: Dict, indent: int = 0, print_fun=print) -> None:
    """Print a dictionary as right-aligned ``key: value`` rows between two rulers

    Parameters
    ----------
    dictionary: Dict
        Mapping to display, one row per item

    indent: int
        Number of spaces inserted before the two ruler lines
        (the rows themselves are not indented)

    print_fun:
        Callable receiving each output line, defaults to ``print``
    """
    ruler = " " * indent + "-" * 80

    print_fun(ruler)
    for key, value in dictionary.items():
        # `!s` converts the key to text first: keys such as None or tuples do
        # not support the `>30` format spec and would raise TypeError
        print_fun(f"{key!s:>30}: {value}")
    print_fun(ruler)
def compress_dict(state: Dict) -> Dict:
    """Compress a state dictionary

    The state is encoded with bson, compressed with zlib and wrapped in base64.

    Returns
    -------
    A dictionary with two entries: ``zlib``, the base64 encoded compressed
    payload, and ``crc32``, the checksum of the uncompressed bson payload.

    Raises
    ------
    The exception captured when importing ``bson`` failed, if any.
    """
    if BSON_ERROR:
        raise BSON_ERROR

    payload = bson.encode(state)
    checksum = zlib.crc32(payload)
    encoded = base64.b64encode(zlib.compress(payload))
    return {"zlib": encoded, "crc32": checksum}
def decompress_dict(state: Dict) -> Dict:
    """Decompress a state dictionary produced by ``compress_dict``

    A state without a ``zlib`` entry is considered uncompressed and is
    returned untouched; bson is not required for that path.

    Parameters
    ----------
    state: Dict
        Either a plain state or a dictionary holding the ``zlib`` (base64
        encoded compressed payload) and ``crc32`` entries

    Raises
    ------
    The exception captured when importing ``bson`` failed, when a compressed
    state is given and bson is unavailable.

    AssertionError
        When the checksum of the decompressed payload does not match ``crc32``
    """
    if "zlib" not in state:
        return state

    if BSON_ERROR:
        raise BSON_ERROR

    payload = zlib.decompress(base64.b64decode(state["zlib"]))

    # Explicit raise instead of an `assert` statement so the integrity check
    # is not stripped under `python -O`; AssertionError is kept so existing
    # handlers keep working.
    if zlib.crc32(payload) != state["crc32"]:
        raise AssertionError("State is corrupted")

    return bson.decode(payload)
def encode_rng_state(state):
    """Return a copy of ``state`` as a tuple whose second item is a plain list

    The second item must expose ``tolist()`` (e.g. a numpy array); every
    other item is kept as is.
    """
    fields = list(state)
    return tuple(fields[:1] + [fields[1].tolist()] + fields[2:])
def decode_rng_state(state):
    """Return a copy of ``state`` as a tuple whose second item is a numpy array

    Inverse of ``encode_rng_state``: only the second item is converted,
    every other item is kept as is.
    """
    fields = list(state)
    return (*fields[:1], numpy.array(fields[1]), *fields[2:])
class TimeThrottler:
    """Limit how often the function `fun` is called, in seconds

    Parameters
    ----------
    fun:
        function to throttle

    every: int
        Time in seconds that must elapse between two calls of `fun`

    callback:
        function called, with the same arguments, when the call is throttled

    Examples
    --------

    .. code-block:: python

        throttled_print = TimeThrottler(print, every=1)

        # Only prints 0
        for i in range(0, 10):
            throttled_print(i)

        # Prints 0 to 9
        for i in range(0, 10):
            throttled_print(i)
            time.sleep(1)

    """

    def __init__(self, fun: Callable[[A], R], every=10, callback=None):
        self.fun = fun
        self.callback = callback
        self.every: float = every
        # Timestamp (``time.time()``) of the last call that went through
        self.last_time: float = 0

    def __call__(self, *args, **kwargs) -> Optional[R]:
        """Forward the call to `fun` if enough time elapsed since the last one

        Otherwise the result of `callback` is returned, or None when no
        callback was given.
        """
        current = time.time()

        if current - self.last_time > self.every:
            self.last_time = current
            return self.fun(*args, **kwargs)

        return self.callback(*args, **kwargs) if self.callback else None
def parse_uri_options(options: str) -> Dict:
    """Parse the query part of a URI (``key=value`` pairs joined by ``&``)

    Only the first ``=`` of a pair separates the key from its value, so values
    may contain ``=`` themselves.  Empty items (e.g. from a trailing ``&``)
    are skipped and a key given without ``=`` is mapped to an empty string.
    Keys and values are returned as is, no percent-decoding is applied.

    Parameters
    ----------
    options: str
        Query string without the leading ``?``; may be empty or None

    Returns
    -------
    A dictionary mapping each key to its string value
    """
    parsed = {}
    if not options:
        return parsed

    for item in options.split("&"):
        if not item:
            continue

        key, _, value = item.partition("=")
        parsed[key] = value

    return parsed
def parse_uri(uri: str) -> Dict:
    """Split a URI into its components

    Parameters
    ----------
    uri: str
        URI of the form ``scheme://username:password@address:port/path?query#fragment``

    Returns
    -------
    A dictionary that always holds ``scheme``, ``path``, ``query`` (parsed by
    ``parse_uri_options``), ``fragment`` and ``params``.  When the URI has a
    network location it also holds ``address`` and, only when present in the
    URI, ``username``, ``password`` and ``port`` (all kept as strings).
    A bracketed IPv6 address keeps its brackets.
    """
    parsed = urlparse(uri)

    arguments = {
        "scheme": parsed.scheme,
        "path": parsed.path,
        "query": parse_uri_options(parsed.query),
        "fragment": parsed.fragment,
        "params": parsed.params,
    }

    netloc = parsed.netloc
    if not netloc:
        return arguments

    # Split on the LAST `@`: a password may itself contain `@`
    userinfo, has_userinfo, host_port = netloc.rpartition("@")
    if has_userinfo:
        # Split on the FIRST `:`: a password may itself contain `:`
        username, has_password, password = userinfo.partition(":")
        if has_password:
            arguments["password"] = password
        arguments["username"] = username

    # The port follows the LAST `:`; when what follows contains `]` that colon
    # belongs to a bracketed IPv6 address and no port was given
    address, has_port, port = host_port.rpartition(":")
    if has_port and "]" not in port:
        arguments["port"] = port
        arguments["address"] = address
    else:
        arguments["address"] = host_port

    return arguments
def get_value(item: Union[float, torch.Tensor]) -> float:
    """Return the python scalar held by a tensor, any other value is returned as is"""
    return item.item() if isinstance(item, torch.Tensor) else item
def find_batch_size(model, shape, low, high, dtype=torch.float32):
    """Find the highest batch size that can fit in memory using binary search

    Candidate batch sizes are the multiples of 8 between ``low`` and ``high``
    (both rounded down to a multiple of 8).  The smallest candidate is never
    probed: it is assumed to fit.

    Parameters
    ----------
    model:
        Callable invoked with a random tensor of shape ``(batch_size,) + shape``

    shape: tuple
        Shape of a single sample

    low: int
        Lower bound of the search

    high: int
        Upper bound of the search

    dtype:
        dtype of the probing tensor

    Returns
    -------
    The largest probed batch size that ran without an out of memory error
    (or the smallest candidate if none did)

    Raises
    ------
    ValueError
        When there is no candidate between ``low`` and ``high``

    RuntimeError
        Any error raised by the model that is not an out of memory error
    """
    first = (low // 8) * 8
    stop = (1 + high // 8) * 8
    batches = list(range(first, stop, 8))

    if not batches:
        raise ValueError(f"No batch size to try between {low} and {high}")

    # Invariant: batches[fits] is the largest size known (or assumed) to fit,
    # batches[too_big] failed (or is one past the end of the candidates)
    fits = 0
    too_big = len(batches)

    while too_big != fits + 1:
        mid = fits + (too_big - fits) // 2

        try:
            tensor = torch.randn((batches[mid],) + shape, dtype=dtype)
            model(tensor)
        except RuntimeError as e:
            if "out of memory" not in str(e):
                raise
            too_big = mid
        else:
            fits = mid

    # Return the last size that succeeded, not the last size that was probed:
    # the final probe may be the one that ran out of memory
    return batches[fits]
class CircularDependencies(Exception):
    """Raised by ``LazyCall`` when an attribute is requested while the lazy object is still being built"""
class LazyCall:
    """Save the call parameters of a function so it can be invoked at a later date

    The wrapped function is called at most once with the stored arguments
    (as long as it returns something other than None); its result is kept in
    ``obj``.  Calling the instance or reading an unknown attribute triggers
    the invocation and is forwarded to the result.

    Parameters
    ----------
    fun:
        Function to call lazily

    args, kwargs:
        Arguments stored for the call
    """

    # Attributes set by __init__; they only reach __getattr__ when the
    # instance was created without running __init__
    _OWN_ATTRIBUTES = ("fun", "args", "kwargs", "obj", "is_processing")

    def __init__(self, fun, *args, **kwargs):
        self.fun = fun
        self.args = args
        self.kwargs = kwargs
        self.obj = None
        self.is_processing = False

    def __call__(self, *args, **kwargs):
        """Build the object if needed then call it with the given arguments"""
        self.invoke()
        return self.obj(*args, **kwargs)

    def add_arguments(self, *args, **kwargs):
        """Append positional arguments and add/override keyword arguments of the stored call"""
        self.args = self.args + args
        self.kwargs.update(kwargs)

    def invoke(self, **kwargs):
        """Call the function with the stored arguments (plus ``kwargs``) and return the result

        Once a result is available it is returned directly and ``kwargs`` is ignored.
        """
        if self.obj is None:
            self.is_processing = True
            try:
                self.obj = self.fun(*self.args, **self.kwargs, **kwargs)
            finally:
                # Reset the flag even if `fun` raised, otherwise every later
                # attribute access would report a false circular dependency
                self.is_processing = False

        return self.obj

    def __getattr__(self, item):
        # __getattr__ only runs when the regular lookup failed.  If one of our
        # own attributes is missing the instance was never initialised (e.g.
        # it is being rebuilt by copy/pickle); reading `self.obj` below would
        # then recurse forever instead of raising AttributeError.
        if item in LazyCall._OWN_ATTRIBUTES:
            raise AttributeError(item)

        if self.obj is None and self.is_processing:
            raise CircularDependencies("Circular dependencies")

        self.invoke()
        return getattr(self.obj, item)

    def was_invoked(self):
        """Return True once the function produced a result that is not None"""
        return self.obj is not None
class MissingParameters(Exception):
    """Raised by ``HyperParameters.parameters`` in strict mode when parameters are missing"""
class WrongParameter(Exception):
    """Raised by ``update_params`` in strict mode when a parameter is not part of the space"""
def missing_params(space, kwargs, missing):
    """Collect the entries of ``space`` that have no value in ``kwargs``

    Parameters
    ----------
    space: Dict
        Parameter space; a dictionary value is a nested space

    kwargs: Dict
        Values defined so far, nested like ``space``

    missing: Dict
        Accumulator updated in place with the missing entries

    Returns
    -------
    ``missing``, mapping each missing name to its space entry; nested spaces
    only appear when at least one of their entries is missing
    """
    for name, dim in space.items():
        if name not in kwargs:
            # Leaf or whole nested space absent: report the space entry itself
            missing[name] = dim
            continue

        if not isinstance(dim, dict):
            continue

        nested = missing_params(dim, kwargs[name], missing.setdefault(name, {}))

        # If nothing is missing in the nested space drop its entry
        if len(nested) == 0:
            del missing[name]

    return missing
def update_params(space, kwargs, params, strict=True, unknown_params=None):
    """Copy the values of ``kwargs`` that belong to ``space`` into ``params``

    Parameters
    ----------
    space: Dict
        Parameter space; nested dictionaries describe nested spaces

    kwargs: Dict
        New values, a dictionary value is merged recursively

    params: Dict
        Current values, updated in place

    strict: bool
        Raise ``WrongParameter`` when a value does not belong to the space

    unknown_params: Dict
        Optional accumulator for the values that do not belong to the space

    Returns
    -------
    The dictionary of unknown values, nested like ``kwargs``
    """
    if unknown_params is None:
        unknown_params = {}

    for name, value in kwargs.items():
        if name not in space:
            unknown_params[name] = value
        elif not isinstance(value, dict):
            params[name] = value
        else:
            # Nested unknowns are only collected here; the outermost call
            # decides whether they are an error
            nested_unknown = update_params(
                space[name], value, params.setdefault(name, {}), strict=False
            )
            if nested_unknown:
                unknown_params[name] = nested_unknown

    if strict and len(unknown_params) > 0:
        raise WrongParameter(
            f"{list(unknown_params.keys())} is not a valid parameter, pick from: {list(space.keys())}"
        )

    return unknown_params
class HyperParameters:
    """Keeps track of mandatory hyper parameters

    Parameters
    ----------
    space: Dict[str, Space]
        A dictionary defining each parameter and its respective space/dim

    kwargs:
        Hyper parameters that are already defined
    """

    def __init__(self, space, **kwargs):
        self.space = space
        self.current_parameters = {}
        self.add_parameters(**kwargs)

    def missing_parameters(self):
        """Returns a dictionary of the parameters of the space that have no value yet"""
        not_set = {}
        return missing_params(self.space, self.current_parameters, not_set)

    def add_parameters(self, strict=True, **kwargs):
        """Insert new parameter values

        Parameters
        ----------
        strict: bool
            control if an exception is raised or not when an unknown parameter is encountered

        Returns
        -------
        a dictionary of unknown parameters
        """
        return update_params(
            self.space, kwargs, self.current_parameters, strict=strict
        )

    def parameters(self, strict=False):
        """Returns the current parameters

        When ``strict`` is set ``MissingParameters`` is raised if any
        parameter of the space has no value.
        """
        if not strict:
            return self.current_parameters

        missing = self.missing_parameters()
        if missing:
            raise MissingParameters(
                "Parameters are missing: {}".format(", ".join(missing.keys()))
            )

        return self.current_parameters
# Tracks global seeds SEEDS = {}
def new_seed(**kwargs):
    """Global seed management

    Register a single named seed (``new_seed(name=value)``) in ``SEEDS`` and
    return the value that was registered.  A seed that is already registered
    is overridden after a warning.

    When the ``seeding.random`` option is enabled the given value is ignored
    and a random 64 bit seed is drawn instead.
    """
    assert len(kwargs) == 1, "Only single seed can be registered"

    # Allow user to force seed to change seeds automatically each time the program is ran
    # Disabled by default
    randomize = option("seeding.random", default=False, type=bool)

    for name in list(kwargs):
        if name in SEEDS:
            warning(f"Resetting a global seed for {name}")

        if randomize:
            kwargs[name] = random.getrandbits(64)

        SEEDS[name] = kwargs[name]

    return kwargs.popitem()[1]
def get_seeds():
    """Returns the dictionary (name -> value) of the seeds registered with ``new_seed``

    The global dictionary itself is returned, not a copy.
    """
    return SEEDS
def set_seeds(seed):
    """Set most commonly used global seeds

    Seeds python's ``random``, numpy and torch (CPU, and every CUDA device
    when CUDA is available).  With CUDA, cudnn benchmarking is disabled and
    cudnn is put in deterministic mode.

    Parameters
    ----------
    seed: int
        Seed to use; values that do not fit numpy's 32 bit range are reduced
        modulo 2 ** 32 for numpy only
    """
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = False
        torch.cuda.manual_seed_all(seed)

        # https://pytorch.org/docs/stable/notes/randomness.html
        # We actually tried to not set it to True but it does make
        # significant differences after a few epochs
        torch.backends.cudnn.deterministic = True

    random.seed(seed)
    # numpy only accepts seeds in [0, 2 ** 32 - 1] while new_seed can hand out
    # 64 bit seeds: fold the value into range instead of raising ValueError
    numpy.random.seed(seed % 2 ** 32)
    torch.manual_seed(seed)
def get_rng_states():
    """Capture the state of the random number generators

    Returns
    -------
    A dictionary with the states of python's ``random`` (``random``), numpy
    (``numpy``) and torch (``torch_cpu``), plus the state of every CUDA
    device (``torch_cuda``) when CUDA is available.
    """
    states = {}

    if torch.cuda.is_available():
        states["torch_cuda"] = torch.cuda.get_rng_state_all()

    states.update(
        random=random.getstate(),
        numpy=numpy.random.get_state(),
        torch_cpu=torch.get_rng_state(),
    )
    return states
def set_rng_states(state):
    """Restore the random number generators from a dictionary made by ``get_rng_states``

    Parameters
    ----------
    state: Dict
        Must hold the ``random``, ``numpy`` and ``torch_cpu`` entries;
        ``torch_cuda`` is restored only when present.  A state captured
        without CUDA can therefore be restored on a machine with CUDA, in
        which case the CUDA generators are left untouched.

    Raises
    ------
    RuntimeError
        When the state holds CUDA generator states but CUDA is not available
    """
    has_cuda_state = "torch_cuda" in state

    # Checked first so nothing is modified when the state cannot be restored
    if has_cuda_state and not torch.cuda.is_available():
        raise RuntimeError("Cannot restore state without a GPU.")

    if has_cuda_state:
        torch.cuda.set_rng_state_all(state["torch_cuda"])

    random.setstate(state["random"])
    numpy.random.set_state(state["numpy"])
    torch.set_rng_state(state["torch_cpu"])