# Source code for olympus.observers

import json

from olympus.observers.observer import Observer
from olympus.observers.progress import *
from olympus.observers.checkpointer import CheckPointer
from olympus.observers.msgtracker import metric_logger
from olympus.utils import warning, error


# Event names; each matches the event string that the ObserverList method
# of the same name passes to `broadcast_event`.
new_trial = 'new_trial'        # HPO created a trial
start_train = 'start_train'    # Train is starting from scratch
resume_train = 'resume_train'  # Train is being resumed
new_epoch = 'new_epoch'        # New epoch is starting
end_epoch = 'end_epoch'        # Counterpart of `new_epoch`
new_batch = 'new_batch'        # New Batch is starting
end_batch = 'end_batch'        # Counterpart of `new_batch`
end_train = 'end_train'        # Train has finished


class ObserverList:
    """Relay events to a priority-ordered list of metrics/observers.

    Observers are kept in ``self.metrics`` sorted by decreasing ``priority``
    and are additionally indexed by key in ``self._metrics_mapping``.
    An event named ``<event>`` is delivered by calling the observer's
    ``on_<event>`` method, if it has one, subject to the observer's optional
    ``frequency_<event>`` attribute (see :meth:`should_run`).

    Parameters
    ----------
    *args: Observer
        observers to register, in order, through :meth:`append`

    task: any
        object forwarded as first argument to every event handler

    name: Optional[str]
        when set, used as a ``<name>_`` prefix for the keys returned by
        :meth:`value`
    """

    def __init__(self, *args, task=None, name=None):
        self.name = name
        self._metrics_mapping = dict()
        self.metrics = list()

        for arg in args:
            self.append(arg)

        self.batch_id: int = 0
        self.trial_id: int = 0
        self._epoch: int = 0
        self._previous_step = 0
        self.task = task

    @staticmethod
    def should_run(metric, name, step):
        """Tell whether `metric` should receive the event `name` at `step`.

        Parameters
        ----------
        metric: Observer
            observer whose ``frequency_<name>`` attribute is inspected;
            a missing attribute counts as a frequency of 1

        name: str
            event name (without the ``on_`` prefix)

        step: Optional[int]
            current step; ``None`` logs a warning and disables the event

        Returns
        -------
        True when the frequency is strictly positive and divides `step`,
        False otherwise
        """
        if step is None:
            warning(f'step is none; cannot run (metric: {metric}) with (event: {name})')
            return False

        frequency = getattr(metric, f'frequency_{name}', 1)

        # a frequency of zero (or less) disables the event for this observer
        if frequency > 0:
            return step % frequency == 0

        return False

    def broadcast_event(self, event_name, task, step, *args, **kwargs):
        """Call ``on_<event_name>(task, step, *args, **kwargs)`` on every
        observer that defines it and that :meth:`should_run` at `step`.

        Raises
        ------
        TypeError
            re-raised after logging which observer and event were involved
        """
        for m in self.metrics:
            if ObserverList.should_run(m, event_name, step):
                fun = getattr(m, f'on_{event_name}', None)

                if fun is not None:
                    try:
                        fun(task, step, *args, **kwargs)
                    except TypeError:
                        # log the culprit before propagating the error
                        error(f'(metric: {m}) (event: {event_name})')
                        raise

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        """Save all the children states

        Returns
        -------
        list holding the ``state_dict()`` of each observer, in
        ``self.metrics`` order; the parameters are accepted but not used
        """
        return [m.state_dict() for m in self.metrics]

    def load_state_dict(self, state_dict, strict=True):
        """Resume all children metrics using a state_dict

        Parameters
        ----------
        state_dict: list
            per-observer states, paired positionally with ``self.metrics``;
            surplus entries on either side are ignored by ``zip``

        strict: bool
            accepted but not used
        """
        for m, m_state_dict in zip(self.metrics, state_dict):
            m.load_state_dict(m_state_dict)

    def __getitem__(self, item):
        """Return the observer stored under `item` (str key or int position).

        Raises
        ------
        RuntimeError
            when the lookup yields ``None``

        IndexError
            when `item` is an out-of-range integer position
        """
        v = self.get(item)

        if v is None:
            raise RuntimeError('Not found')

        # the lookup result has to be handed back to the caller,
        # otherwise `observers[key]` always evaluates to None
        return v

    def __setitem__(self, key, value):
        """Register `value` under `key`, see :meth:`append`"""
        self.append(m=value, key=key)

    def get(self, key, default=None):
        """Retrieve a metric from its key

        Parameters
        ----------
        key: Union[str, int]
            a string is looked up in the key mapping, an integer is used
            as a position in the priority-sorted list

        default: any
            default object returned if not found

        Raises
        ------
        IndexError
            when `key` is an out-of-range integer; `default` is not used
            in that case
        """
        if isinstance(key, int):
            return self.metrics[key]

        if isinstance(key, str):
            return self._metrics_mapping.get(key, default)

        return default

    def append(self, m: 'Observer', key=None):
        """Insert a new metric to compute

        Parameters
        ----------
        m: Metric
            new metric to insert

        key: Optional[str]
            optional key used to retrieve the metric;
            by default the ``name`` attribute of the metric is used, or its
            type name when it has no such attribute
        """
        if not key:
            # Use name attribute as key, else fall back on the type name
            key = m.name if hasattr(m, 'name') else type(m).__name__

        # only insert if there are no conflicts
        if key not in self._metrics_mapping:
            self._metrics_mapping[key] = m

        # NOTE(review): the flattened source does not show whether the list
        # insertion below belongs inside the conflict check above; it is kept
        # unconditional so that no observer is silently dropped -- confirm
        # against the upstream file.
        self.metrics.append(m)
        self.metrics.sort(key=lambda met: met.priority, reverse=True)

    def new_epoch(self, epoch, context=None):
        """Broadcast a `new_epoch` event to all metrics"""
        self.broadcast_event('new_epoch', self.task, epoch, context)

    def new_batch(self, step, input=None, context=None):
        """Broadcast a `new_batch` event to all metrics"""
        self.broadcast_event('new_batch', self.task, step, input, context)

    def end_epoch(self, epoch, context=None):
        """Broadcast an `end_epoch` event to all metrics"""
        self.broadcast_event('end_epoch', self.task, epoch, context)

    def end_batch(self, step, input=None, context=None):
        """Broadcast an `end_batch` event to all metrics"""
        self.broadcast_event('end_batch', self.task, step, input, context)

    def new_trial(self, parameters, uid):
        """Broadcast a `new_trial` event (at step 0) to all metrics"""
        self.broadcast_event('new_trial', self.task, 0, parameters, uid)

    def start_train(self):
        """Broadcast a `start_train` event (at step 0) to all metrics"""
        self.broadcast_event('start_train', self.task, 0)

    def resume_train(self, start_epoch):
        """Broadcast a `resume_train` event to all metrics, using
        `start_epoch` as the step"""
        self.broadcast_event('resume_train', self.task, start_epoch)

    def end_train(self):
        """Broadcast an `end_train` event (at step 0) to all metrics"""
        self.broadcast_event('end_train', self.task, 0)

    def value(self):
        """Returns a dictionary of all computed metrics

        The ``value()`` dictionaries of all observers are merged; when
        ``self.name`` is set every key is prefixed with ``<name>_``.
        On key collision the observer visited last wins.
        """
        metrics = {}

        for metric in self.metrics:
            for key, value in metric.value().items():
                if self.name:
                    key = f'{self.name}_{key}'

                metrics[key] = value

        return metrics

    def report(self, pprint=True, print_fun=print):
        """Pretty prints all the metrics

        Parameters
        ----------
        pprint: bool
            when true the metrics are sent to `print_fun` as indented JSON

        print_fun: Callable[[str], Any]
            function receiving the JSON text

        Returns
        -------
        the dictionary produced by :meth:`value`
        """
        metrics = self.value()

        if pprint:
            print_fun(json.dumps(metrics, indent=2))

        return metrics
# Alias: MetricList is bound to the same class object as ObserverList
# (presumably kept for call sites using the older name -- verify).
MetricList = ObserverList