Source code for olympus.tasks.task

import torch
from olympus.metrics import MetricList


class Task:
    """Base class holding the state shared by every task.

    A task keeps track of the device it runs on, a ``MetricList``, the epoch
    it starts from (non-zero when resumed), a ``bad_state`` flag and a
    ``dataloader``.

    Subclasses are expected to implement :meth:`eval_loss`, :meth:`fit`,
    :meth:`get_space`, :meth:`init`, :meth:`load_state_dict` and
    :meth:`state_dict`; the versions defined here raise
    ``NotImplementedError``.

    Parameters
    ----------
    device: torch.device, optional
        Device the task lives on. Any falsy value selects the CPU.
    """

    def __init__(self, device=None):
        self._device = device if device else torch.device('cpu')
        self._first_epoch = 0
        self._metrics = MetricList(task=self)
        self.bad_state = False
        self.dataloader = None

    @property
    def device(self):
        """Device the task currently lives on."""
        return self._device

    @device.setter
    def device(self, device):
        # Assigning to ``task.device`` moves every movable attribute as well
        self.set_device(device)

    def set_device(self, device):
        """Move every attribute exposing a ``to`` method to ``device``.

        Every name returned by ``dir(self)`` is inspected; attributes that
        have a ``to`` attribute are replaced by ``attr.to(device=device)``.
        ``self._device`` is only updated once all attributes were moved.

        Parameters
        ----------
        device:
            Target device, forwarded as the ``device`` keyword of ``to``

        Raises
        ------
        Exception
            Whatever ``to`` or ``setattr`` raised; the name of the offending
            attribute is printed before the exception is re-raised.
        """
        for name in dir(self):
            attr = getattr(self, name)

            if not hasattr(attr, 'to'):
                continue

            try:
                setattr(self, name, attr.to(device=device))
            except Exception:
                # Not a bare ``except``: KeyboardInterrupt and SystemExit
                # should propagate without being reported as a failure
                print(f'Cant set attribute on {name} {attr}')
                raise

        self._device = device

    def eval_loss(self, batch):
        """This is used to compute validation and test loss"""
        raise NotImplementedError()

    def _start(self, epochs):
        """Configure the progress view, then start or resume the metrics.

        Parameters
        ----------
        epochs: int
            Stored as ``max_epoch`` on the ``ProgressView`` metric, if any
        """
        progress = self.metrics.get('ProgressView')

        if progress:
            # in case of a resume
            progress.epoch = self._first_epoch
            progress.max_epoch = epochs
            # NOTE: ``self.dataloader`` must have been set at this point,
            # ``len(None)`` raises a TypeError
            progress.max_step = len(self.dataloader)

        if not self.resumed():
            self.metrics.start_train()
        else:
            self.metrics.resume_train(self._first_epoch)

    def fit(self, epoch, context=None):
        """Execute a single batch

        Parameters
        ----------
        epoch: int
            current step in the training process

        context: dict
            Optional Context

        Notes
        -----
        You should wrap whatever code you have here inside a `BadResumeGuard`
        to prevent users from resuming a failed task that may be in a bad
        state.

        To resume a task, you need to create a clean one with the same hyper
        parameters. It will automatically pick up from its last checkpoint.
        """
        raise NotImplementedError()

    @property
    def metrics(self):
        """``MetricList`` attached to this task."""
        return self._metrics

    def report(self, pprint=True, print_fun=print):
        """Forward to ``self.metrics.report`` when the metric list is truthy.

        Returns
        -------
        The value returned by ``metrics.report(pprint, print_fun)``, or
        ``None`` when the metric list is falsy.
        """
        m = self.metrics

        if m:
            return m.report(pprint, print_fun)

        return None

    def summary(self):
        """Print a textual summary of the attributes of this task."""
        print(GenerateSummary().task_summary(self))

    def get_space(self, **fidelities):
        """Return missing hyper parameters that need to be set using `init`"""
        raise NotImplementedError()

    def init(self, **kwargs):
        """Used to initialize the hyperparameters if any"""
        raise NotImplementedError()

    def resumed(self):
        """Return True when the task starts from an epoch greater than 0."""
        return self._first_epoch > 0

    def load_state_dict(self, state, strict=True):
        """Try to load a previous unfinished state to resume

        Notes
        -----
        You should wrap whatever code you have here inside a `BadResumeGuard`
        to prevent users from resuming a failed task that may be in a bad
        state.

        To resume a task, you need to create a clean one with the same hyper
        parameters. It will automatically pick up from its last checkpoint.
        """
        raise NotImplementedError()

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        """Save a state the task can go back to if an error occurs"""
        raise NotImplementedError()

    def _fix(self):
        """Force a garbage collection so tensor memory can be reused."""
        # -----------------------------
        # RL Creates a lot of small torch.tensor
        # They need to be GCed so pytorch can reuse that memory
        import gc

        # NOTE: generation 2 is the *oldest* generation; collecting it also
        # collects generations 0 and 1, so this is a full collection and
        # not a collection of the youngest generation only.
        gc.collect(2)
        # -----------------------------
class GenerateSummary:
    """Render the attributes of an object as an indented text summary.

    Lines are accumulated in ``self.output`` and joined by
    :meth:`task_summary`.
    """

    # Maps a wrapper type name to the function extracting the object it wraps
    dispatch = {
        'Model': lambda model: model.model,
        'Optimizer': lambda optimizer: optimizer.optimizer,
        'DataLoader': lambda data: data.dataset,
        'TransformedSubset': lambda data: data.dataset,
        'MetricList': lambda metrics: metrics.metrics,
        'LRSchedule': lambda schedule: schedule.lr_scheduler
    }

    # Private attribute names and the names shown in the summary
    _rename = {
        '_metrics': 'metrics',
        '_device': 'device',
        '_first_epoch': 'first_epoch'
    }

    def __init__(self):
        self.output = []

    def print(self, msg='', end='\n'):
        """Append ``msg`` followed by ``end`` to the output buffer."""
        line = f'{msg}{end}'
        self.output.append(line)

    def is_nested(self, name):
        """Return True when ``name`` is a type name that wraps another object."""
        return name in GenerateSummary.dispatch

    def retrieve_nested(self, name, obj):
        """Return the object wrapped by ``obj``, or ``obj`` itself if not a wrapper."""
        unwrap = GenerateSummary.dispatch.get(name)

        if unwrap is None:
            return obj

        return unwrap(obj)

    def rename(self, name):
        """Return the display name of ``name``; unknown names are kept as is."""
        return GenerateSummary._rename.get(name, name)

    def get_name(self, attr, obj, type_name, depth=0):
        """Write the summary entry of a single attribute.

        Parameters
        ----------
        attr: str
            Name of the attribute, renamed through :meth:`rename`

        obj:
            Value of the attribute

        type_name: str
            Name of the type of ``obj``

        depth: int
            Nesting level, drives the indentation
        """
        padding = " " * depth
        self.print(f'{padding} {self.rename(attr)}: ', end='')

        # Wrapper types: describe the wrapped object one level deeper
        if self.is_nested(type_name):
            self.print()
            inner = self.retrieve_nested(type_name, obj)
            self.get_name(type_name, inner, type(inner).__name__, depth + 1)
            return

        # Devices are shown by value instead of by type
        if type_name == 'device':
            self.print(str(obj))
            return

        if type_name != 'list':
            self.print(type_name)
            return

        # Lists show the type of each of their items, one per line
        self.print()
        item_padding = " " * (depth + 1)
        for element in obj:
            self.print(f'{item_padding} - {type(element).__name__}')

    def task_summary(self, obj):
        """Return the summary of every entry of ``obj.__dict__`` as a string."""
        self.output = []
        border = '=' * 80

        self.print(border)
        self.print(type(obj).__name__)
        self.print('-' * 80)

        for name, member in obj.__dict__.items():
            self.get_name(name, member, type(member).__name__)

        self.print(border)
        return ''.join(self.output)