import random
import signal
import time
from typing import Callable, Optional, TypeVar, Dict, NoReturn, Union
from urllib.parse import urlparse
import base64
import zlib
import numpy
import torch
from olympus.utils.options import option, set_option
from olympus.utils.chrono import Chrono
from olympus.utils.functional import select, flatten
from olympus.utils.arguments import (
parse_args,
show_hyperparameter_space,
required,
get_parameters,
drop_empty_key,
)
from olympus.utils.log import (
warning,
info,
debug,
error,
critical,
exception,
set_verbose_level,
set_log_level,
)
try:
import bson
BSON_ERROR = None
except ImportError as e:
BSON_ERROR = e
# Type variables used in TimeThrottler's annotations:
# A is the wrapped function's argument type and R its return type.
A = TypeVar("A")
R = TypeVar("R")
class MissingArgument(Exception):
    """Exception type signalling a missing argument.

    Not raised within this module section; presumably raised by callers.
    """
def fetch_device():
    """Return the torch device named by the ``device.type`` option.

    When the option is not set, fall back to ``"cuda"`` if a CUDA device is
    available and to ``"cpu"`` otherwise.
    """
    fallback = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(option("device.type", fallback))
def show_dict(dictionary: Dict, indent: int = 0, print_fun=print) -> None:
    """Print ``dictionary`` as a key/value table framed by two separator lines.

    Parameters
    ----------
    dictionary: Dict
        mapping to display; each key is right-aligned in a 30-character column
    indent: int
        number of spaces prepended to every printed line (separators and entries)
    print_fun:
        function called once per output line, ``print`` by default
    """
    prefix = " " * indent
    separator = prefix + "-" * 80
    print_fun(separator)
    for k, v in dictionary.items():
        # Entries get the same prefix as the separators so the table stays aligned
        print_fun(f"{prefix}{k:>30}: {v}")
    print_fun(separator)
def compress_dict(state: Dict) -> Dict:
    """Compress a state dictionary and return a JSON-friendly compressed state.

    The state is BSON-encoded, zlib-compressed and base64-encoded. The result
    has two keys:

    * ``zlib``: the base64 payload as an ASCII ``str`` (``bytes`` would not be
      JSON serializable); ``decompress_dict`` accepts it as-is
    * ``crc32``: CRC-32 of the uncompressed BSON bytes, verified on decompression

    Raises
    ------
    ImportError
        the error recorded when importing ``bson`` failed
    """
    if BSON_ERROR:
        raise BSON_ERROR

    binary = bson.encode(state)
    payload = base64.b64encode(zlib.compress(binary)).decode("ascii")
    return dict(zlib=payload, crc32=zlib.crc32(binary))
def decompress_dict(state: Dict) -> Dict:
    """Decompress a state produced by ``compress_dict`` and return the original dict.

    A ``state`` without a ``zlib`` key is considered uncompressed and is
    returned unchanged.

    Raises
    ------
    ImportError
        the error recorded when importing ``bson`` failed
    AssertionError
        if the CRC-32 of the decompressed payload does not match ``state["crc32"]``
    """
    if BSON_ERROR:
        raise BSON_ERROR

    if "zlib" not in state:
        return state

    decompressed_bson = zlib.decompress(base64.b64decode(state["zlib"]))

    # Explicit raise rather than `assert`: the integrity check must survive
    # `python -O`. AssertionError is kept so existing handlers still catch it.
    if zlib.crc32(decompressed_bson) != state["crc32"]:
        raise AssertionError("State is corrupted")

    return bson.decode(decompressed_bson)
def encode_rng_state(state):
    """Return a copy of an RNG state tuple whose second item is converted with ``.tolist()``.

    Intended for the tuple returned by ``numpy.random.get_state()``, whose
    second item is the key array; the other items are kept unchanged.
    """
    fields = list(state)
    return (fields[0], fields[1].tolist(), *fields[2:])
def decode_rng_state(state):
    """Inverse of ``encode_rng_state``: turn the second item back into a numpy array.

    All other items of the tuple are kept unchanged.
    """
    fields = list(state)
    return (fields[0], numpy.array(fields[1]), *fields[2:])
class TimeThrottler:
    """Limit how often the function ``fun`` is called, in seconds.

    Parameters
    ----------
    fun:
        function to throttle
    every: float
        minimum time in seconds between two calls of ``fun``
    callback:
        function called with the same arguments instead of ``fun`` when a call
        is throttled; its result is returned. Without a callback, throttled
        calls return ``None``.

    Examples
    --------
    .. code-block:: python

        throttled_print = TimeThrottler(print, every=1)

        # Only prints 0
        for i in range(0, 10):
            throttled_print(i)

        # Calls spaced by more than `every` seconds all go through
        for i in range(0, 10):
            throttled_print(i)
            time.sleep(1)
    """

    def __init__(self, fun: "Callable[[A], R]", every=10, callback=None):
        self.fun = fun
        # Timestamp (time.monotonic clock) of the last call that reached `fun`.
        # -inf guarantees the very first call is never throttled, whatever `every` is.
        self.last_time: float = float("-inf")
        self.every: float = every
        self.callback = callback

    def __call__(self, *args, **kwargs) -> "Optional[R]":
        # monotonic() is unaffected by wall-clock adjustments, which could
        # otherwise suppress calls for a long time or produce negative intervals
        now = time.monotonic()
        if now - self.last_time > self.every:
            self.last_time = now
            return self.fun(*args, **kwargs)

        if self.callback:
            return self.callback(*args, **kwargs)
        return None
[docs]def parse_uri_options(options: str) -> Dict:
if not options:
return dict()
opt = dict()
for item in options.split("&"):
k, v = item.split("=")
opt[k] = v
return opt
def _split_host_port(hostport: str):
    """Split ``host[:port]`` into ``(address, port)``; ``port`` is ``None`` when absent.

    A bracketed IPv6 literal (``[::1]:8080``) is returned without its brackets.
    An unbracketed host containing several ``:`` is returned whole, without a port.
    """
    if hostport.startswith("["):
        closing = hostport.find("]")
        if closing != -1:
            after = hostport[closing + 1:]
            port = after[1:] if after.startswith(":") else None
            return hostport[1:closing], port

    if hostport.count(":") == 1:
        address, port = hostport.split(":")
        return address, port

    return hostport, None


def parse_uri(uri: str) -> Dict:
    """Split ``uri`` into a dictionary of its components.

    Always present: ``scheme``, ``path``, ``query`` (dict built by
    ``parse_uri_options``), ``fragment`` and ``params``. When the URI has a
    network location, ``address`` is added, plus ``port``, ``username`` and
    ``password`` when they appear in it. All values are strings: the port is
    not converted to an int, and nothing is percent-decoded.

    Example: ``mongodb://user:pwd@localhost:27017/db`` gives
    ``username="user"``, ``password="pwd"``, ``address="localhost"``,
    ``port="27017"`` and ``path="/db"``.
    """
    parsed = urlparse(uri)
    arguments = {
        "scheme": parsed.scheme,
        "path": parsed.path,
        "query": parse_uri_options(parsed.query),
        "fragment": parsed.fragment,
        "params": parsed.params,
    }

    netloc = parsed.netloc
    if not netloc:
        return arguments

    # The host part cannot contain '@', so the LAST '@' ends the credentials,
    # even when the password holds an unescaped '@'
    userinfo, has_userinfo, hostport = netloc.rpartition("@")
    if has_userinfo:
        # Only the first ':' separates the user from the password, which may contain ':'
        username, has_password, password = userinfo.partition(":")
        if has_password:
            arguments["password"] = password
        arguments["username"] = username

    address, port = _split_host_port(hostport)
    if port is not None:
        arguments["port"] = port
    arguments["address"] = address
    return arguments
def get_value(item: Union[float, torch.Tensor]) -> float:
    """Return ``item`` as a plain Python number.

    Tensors are unwrapped with ``.item()``; any other value is returned as-is.
    """
    return item.item() if isinstance(item, torch.Tensor) else item
def find_batch_size(model, shape, low, high, dtype=torch.float32):
    """Find the highest batch size that can fit in memory using binary search.

    Candidate batch sizes are the multiples of 8 from ``low`` rounded down to a
    multiple of 8, up to ``high`` rounded down to a multiple of 8, inclusive.
    The smallest candidate is assumed to fit and is never probed.

    Parameters
    ----------
    model:
        callable invoked as ``model(tensor)`` on a random batch
    shape:
        shape of a single sample; the probe tensor has shape ``(batch_size,) + shape``
    low, high: int
        bounds of the search
    dtype:
        dtype of the random probe tensor

    Returns
    -------
    The largest candidate for which ``model`` did not raise an "out of memory"
    ``RuntimeError``, or the smallest candidate if every probe failed.

    Raises
    ------
    ValueError
        if the bounds leave no candidate batch size
    RuntimeError
        any ``RuntimeError`` raised by ``model`` that is not an out-of-memory error
    """
    batches = list(range((low // 8) * 8, (1 + high // 8) * 8, 8))
    if not batches:
        raise ValueError(f"No candidate batch size for low={low}, high={high}")

    # Invariant: batches[a] is assumed to fit; batches[b] (or past the end) does not
    a, b = 0, len(batches)
    while b - a > 1:
        mid = a + (b - a) // 2
        try:
            tensor = torch.randn((batches[mid],) + tuple(shape), dtype=dtype)
            model(tensor)
        except RuntimeError as e:
            if "out of memory" not in str(e):
                raise
            b = mid
        else:
            a = mid

    # Return the largest size known to fit, not the last size probed
    # (the last probe may be the one that ran out of memory)
    return batches[a]
class CircularDependencies(Exception):
    """Raised by ``LazyCall`` when its object is requested while it is still being built."""
class LazyCall:
    """Save the call parameters of a function so it can be invoked at a later date.

    ``fun(*args, **kwargs)`` is evaluated by the first of:

    * ``invoke()``;
    * calling the ``LazyCall`` itself;
    * accessing an attribute that ``LazyCall`` does not define.

    The result is stored in ``obj`` and later accesses reuse it. If ``fun``
    returns ``None``, nothing is cached and the next access calls ``fun`` again.
    """

    # Attributes set in __init__. __getattr__ must never resolve these through
    # `obj`: when they are missing (instance made via __new__, copy, unpickling)
    # that would recurse forever.
    _OWN_ATTRIBUTES = frozenset(("fun", "args", "kwargs", "obj", "is_processing"))

    def __init__(self, fun, *args, **kwargs):
        self.fun = fun
        self.args = args
        self.kwargs = kwargs
        self.obj = None
        # True while `fun` is running; used to detect circular dependencies
        self.is_processing = False

    def __call__(self, *args, **kwargs):
        """Invoke the deferred function if needed, then call its result."""
        self.invoke()
        return self.obj(*args, **kwargs)

    def add_arguments(self, *args, **kwargs):
        """Append positional arguments and merge keyword arguments for the deferred call."""
        self.args = self.args + args
        self.kwargs.update(kwargs)

    def invoke(self, **kwargs):
        """Call ``fun`` once with the saved arguments (plus ``kwargs``) and return the result.

        ``kwargs`` are only used if the function has not been invoked yet.

        Raises
        ------
        CircularDependencies
            if the object is requested again while ``fun`` is still running
        """
        if self.obj is None:
            if self.is_processing:
                raise CircularDependencies("Circular dependencies")

            self.is_processing = True
            try:
                self.obj = self.fun(*self.args, **self.kwargs, **kwargs)
            finally:
                # Reset even if `fun` raised; otherwise every later access would
                # be misreported as a circular dependency
                self.is_processing = False
        return self.obj

    def __getattr__(self, item):
        # Only reached when normal attribute lookup fails
        if item in self._OWN_ATTRIBUTES:
            raise AttributeError(item)
        self.invoke()
        return getattr(self.obj, item)

    def was_invoked(self):
        """Return True once ``fun`` has produced a non-``None`` object."""
        return self.obj is not None
class MissingParameters(Exception):
    """Raised by ``HyperParameters.parameters(strict=True)`` when some parameters have no value."""
class WrongParameter(Exception):
    """Raised by ``update_params`` in strict mode when a parameter is not declared in the space."""
def missing_params(space, kwargs, missing):
    """Collect the entries of ``space`` that have no value in ``kwargs``.

    Nested dictionaries in ``space`` are parameter groups and are handled
    recursively:

    * a group absent from ``kwargs`` is reported whole;
    * a partially filled group reports only its missing entries;
    * a complete group is left out.

    ``missing`` is the accumulator: it is updated in place and returned.
    """
    for name, dim in space.items():
        if name not in kwargs:
            missing[name] = dim
            continue

        if not isinstance(dim, dict):
            continue

        # Presumably kwargs[name] is a dict for a group -- it is searched with `in`
        nested = missing.setdefault(name, {})
        if not missing_params(dim, kwargs[name], nested):
            # Nothing missing in this group: drop the empty placeholder
            del missing[name]
    return missing
def update_params(space, kwargs, params, strict=True, unknown_params=None):
    """Copy the values of ``kwargs`` that ``space`` declares into ``params``.

    A dictionary value is merged recursively only when the matching entry of
    ``space`` is itself a dictionary (a parameter group). Otherwise the value
    is stored as-is, even if it happens to be a dictionary.

    Parameters
    ----------
    space: dict
        parameter name -> space/dimension; nested dicts declare groups
    kwargs: dict
        values to insert
    params: dict
        destination, updated in place
    strict: bool
        raise ``WrongParameter`` if any key of ``kwargs`` is not declared in ``space``
    unknown_params: dict, optional
        accumulator for undeclared entries, updated in place

    Returns
    -------
    dict of the undeclared entries (nested under their group name for groups)
    """
    if unknown_params is None:
        unknown_params = dict()

    for k, v in kwargs.items():
        if k not in space:
            unknown_params[k] = v
            continue

        group_space = space[k]
        # Recurse only into real groups: recursing into a leaf dimension (e.g. a
        # string) would turn `k in space` into a substring test and then fail
        if isinstance(v, dict) and isinstance(group_space, dict):
            # Start a fresh group if nothing (or a non-dict value) was stored before
            if not isinstance(params.get(k), dict):
                params[k] = {}
            ukwn = update_params(group_space, v, params[k], strict=False)
            if ukwn:
                unknown_params[k] = ukwn
        else:
            params[k] = v

    if strict and len(unknown_params) > 0:
        raise WrongParameter(
            f"{list(unknown_params.keys())} is not a valid parameter, pick from: {list(space.keys())}"
        )
    return unknown_params
class HyperParameters:
    """Keep track of hyper-parameter values against the space that declares them.

    Parameters
    ----------
    space: Dict[str, Space]
        mapping from each parameter name to its space/dimension; nested dicts
        declare parameter groups
    kwargs:
        initial parameter values, checked strictly against ``space``
    """

    def __init__(self, space, **kwargs):
        self.space = space
        self.current_parameters = {}
        self.add_parameters(**kwargs)

    def missing_parameters(self):
        """Return the part of ``space`` that has no value yet."""
        found = {}
        return missing_params(self.space, self.current_parameters, found)

    def add_parameters(self, strict=True, **kwargs):
        """Insert new parameter values.

        Note that a hyper-parameter literally named ``strict`` is captured by
        the ``strict`` argument and cannot be set through this method.

        Parameters
        ----------
        strict: bool
            raise an exception when a parameter is not declared in ``space``

        Returns
        -------
        a dictionary of the unknown parameters
        """
        return update_params(
            self.space, kwargs, self.current_parameters, strict=strict
        )

    def parameters(self, strict=False):
        """Return the current parameter values (the live dict, not a copy).

        With ``strict=True``, first check that no declared parameter is missing
        and raise ``MissingParameters`` otherwise.
        """
        if not strict:
            return self.current_parameters

        missing = self.missing_parameters()
        if missing:
            names = ", ".join(missing.keys())
            raise MissingParameters("Parameters are missing: {}".format(names))
        return self.current_parameters
# Tracks global seeds: maps each seed name registered by new_seed() to its
# value; get_seeds() returns this dict.
SEEDS = {}
def new_seed(**kwargs):
    """Global seed management: register one named seed in ``SEEDS`` and return it.

    Call with exactly one keyword argument, e.g. ``new_seed(init=0)``.

    When the ``seeding.random`` option is enabled (it is disabled by default),
    the given value is ignored and a fresh random 32-bit seed is drawn instead,
    so every run of the program uses different seeds. Registering a name that
    already exists logs a warning and overwrites the previous seed.

    Returns
    -------
    the seed value stored in ``SEEDS`` for that name

    Raises
    ------
    AssertionError
        if not exactly one seed is given
    """
    if len(kwargs) != 1:
        # Explicit raise (not `assert`) so the check survives `python -O`
        raise AssertionError("Only single seed can be registered")

    # Allow the user to force seeds to change automatically each time the program runs
    automatic_seeding = option("seeding.random", default=False, type=bool)

    name, value = next(iter(kwargs.items()))
    if name in SEEDS:
        warning(f"Resetting a global seed for {name}")

    if automatic_seeding:
        # 32 bits: numpy.random.seed (used by set_seeds) rejects seeds >= 2**32
        value = random.getrandbits(32)

    SEEDS[name] = value
    return value
def get_seeds():
    """Return the global seed registry filled by ``new_seed``.

    The result is a dict mapping seed names to values. It is the live
    registry, not a copy.
    """
    registry = SEEDS
    return registry
def set_seeds(seed):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch with ``seed``.

    When CUDA is available, all CUDA devices are seeded too, cuDNN
    benchmarking is turned off and cuDNN deterministic mode is turned on.
    """
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = False
        torch.cuda.manual_seed_all(seed)
        # https://pytorch.org/docs/stable/notes/randomness.html
        # Leaving deterministic mode off was tried, but it made significant
        # differences after a few epochs
        torch.backends.cudnn.deterministic = True

    for seed_fn in (random.seed, numpy.random.seed, torch.manual_seed):
        seed_fn(seed)
def get_rng_states():
    """Snapshot the global RNG states so ``set_rng_states`` can restore them.

    Keys: ``torch_cuda`` (only when CUDA is available), ``random``, ``numpy``
    and ``torch_cpu``.
    """
    states = {}
    if torch.cuda.is_available():
        states["torch_cuda"] = torch.cuda.get_rng_state_all()
    states.update(
        random=random.getstate(),
        numpy=numpy.random.get_state(),
        torch_cpu=torch.get_rng_state(),
    )
    return states
def set_rng_states(state):
    """Restore the global RNG states captured by ``get_rng_states``.

    The CUDA generators are restored only when ``state`` contains a
    ``torch_cuda`` entry. A state captured on a CPU-only machine can therefore
    be restored on a GPU machine; its CUDA generators are left untouched.

    Raises
    ------
    RuntimeError
        if ``state`` contains CUDA RNG states but no GPU is available
    """
    if "torch_cuda" in state:
        if not torch.cuda.is_available():
            raise RuntimeError("Cannot restore state without a GPU.")
        torch.cuda.set_rng_state_all(state["torch_cuda"])

    random.setstate(state["random"])
    numpy.random.set_state(state["numpy"])
    torch.set_rng_state(state["torch_cpu"])