from dataclasses import dataclass, field
from olympus.utils import warning
@dataclass
class Observer:
    """Metrics are observers that receive events periodically.

    Subclasses override the ``on_*`` hooks they care about; every hook here
    is a no-op by default.

    Attributes
    ----------
    frequency_new_epoch: int
        Controls how often `on_new_epoch` is called, 0 disables it
    frequency_new_batch: int
        Controls how often `on_new_batch` is called, 0 disables it
    frequency_new_trial: int
        Controls how often `on_new_trial` is called, 0 disables it
    priority: int
        Controls which metric is called first
    """
    frequency_new_epoch: int = field(default=0)
    frequency_new_batch: int = field(default=0)
    frequency_new_trial: int = field(default=0)
    priority: int = field(default=0)

    def on_new_epoch(self, task, epoch, context):
        """Called at the end of an epoch, before a new epoch starts."""

    def on_new_batch(self, task, step, input=None, context=None):
        """Called after a batch has been processed."""

    def on_new_trial(self, task, step, parameters, uid):
        """Called after a trial has been processed."""

    def on_start_train(self, task, step=None):
        """Called once the training starts.

        Notes
        -----
        You should not rely on this function to initialize your metric as it will
        not be called if the training is resumed from a previous state
        """

    def on_end_train(self, task, step=None):
        """Called at the end of training after the last epoch."""

    def value(self):
        """Return the key values that metrics computes (empty by default)."""
        return dict()

    def every(self, *args, epoch=None, batch=None, trial=None):
        """Define how often this metric should be called.

        Parameters
        ----------
        epoch: int, optional
            New value for `frequency_new_epoch`; left unchanged when None
        batch: int, optional
            New value for `frequency_new_batch`; left unchanged when None
        trial: int, optional
            New value for `frequency_new_trial`; left unchanged when None

        Returns
        -------
        Observer
            ``self``, so the call can be chained, e.g. ``MyMetric().every(epoch=1)``

        Raises
        ------
        TypeError
            If any positional argument is given; frequencies must be passed
            by keyword so it is unambiguous which one is being set.
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if args:
            raise TypeError(
                f'every() only accepts keyword arguments (epoch=, batch=, trial=), '
                f'got {len(args)} positional argument(s)'
            )

        # Assign the declared dataclass fields (`frequency_new_*`) so the new
        # frequency is visible wherever those fields are read.
        if epoch is not None:
            self.frequency_new_epoch = epoch
        if batch is not None:
            self.frequency_new_batch = batch
        if trial is not None:
            self.frequency_new_trial = trial
        return self

    def state_dict(self):
        """Return a state dictionary used for checkpointing and resuming.

        The default implementation emits a warning and returns an empty dict.
        """
        warning(f'This metric {type(self)} does not support resuming')
        return {}

    def load_state_dict(self, state_dict):
        """Load a state dictionary to resume a previous training.

        The default implementation ignores ``state_dict`` and emits a warning.
        """
        warning(f'This metric {type(self)} does not support resuming')
# `Metric` is an alias: both names refer to the same class object.
Metric = Observer