Source code for olympus.tasks.task

import torch
from olympus.metrics import MetricList


[docs]class Task: def __init__(self, device=None): self._device = device if device else torch.device("cpu") self._first_epoch = 0 self._metrics = MetricList(task=self) self.bad_state = False self.dataloader = None self.stopped = False @property def events(self): return self._metrics @property def device(self): return self._device @device.setter def device(self, device): self.set_device(device)
[docs] def set_device(self, device): for name in dir(self): attr = getattr(self, name) if hasattr(attr, "to"): try: setattr(self, name, attr.to(device=device)) except: print(f"Cant set attribute on {name} {attr}") raise self._device = device
[docs] def eval_loss(self, batch): """This is used to compute validation and test loss""" raise NotImplementedError()
def _start(self, epochs): progress = self.metrics.get("ProgressView") if progress: # in case of a resume progress.epoch = self._first_epoch progress.set_max_epochs(epochs) if not self.resumed(): self.metrics.start_train() else: self.metrics.resume_train(self._first_epoch)
[docs] def fit(self, epoch, context=None): """Execute a single batch Parameters ---------- epoch: int current step in the training process context: dict Optional Context Notes ----- You should wrap whatever code you have here inside a `BadResumeGuard` to prevent users from resuming a failed task that can have a bad states To resume a task, you need to create a clean one with the same hyper parameters. It will pickup automatically where at its last checkpoint """ raise NotImplementedError()
@property def metrics(self): return self._metrics
[docs] def report(self, pprint=True, print_fun=print): m = self.metrics if m: return self.metrics.report(pprint, print_fun)
[docs] def summary(self): print(GenerateSummary().task_summary(self))
[docs] def get_space(self): """Return missing hyper parameters that need to be set using `init`""" raise NotImplementedError()
[docs] def init(self, **kwargs): """Used to initialize the hyperparameters is any""" raise NotImplementedError()
[docs] def resumed(self): return self._first_epoch > 0
[docs] def load_state_dict(self, state, strict=True): """Try to load a previous unfinished state to resume Notes ----- You should wrap whatever code you have here inside a `BadResumeGuard` to prevent users from resuming a failed task that can have a bad states To resume a task, you need to create a clean one with the same hyper parameters. It will pickup automatically where at its last checkpoint """ raise NotImplementedError()
[docs] def state_dict(self, destination=None, prefix="", keep_vars=False): """Save a state the task can go back to if an error occur""" raise NotImplementedError()
def _fix(self): # ----------------------------- # RL Creates a lot of small torch.tensor # They need to be GCed so pytorch can reuse that memory import gc # Only GC the most recent gen because that where the small tensors are gc.collect(2)
# -----------------------------
[docs]class GenerateSummary: dispatch = { "Model": lambda model: model.model, "Optimizer": lambda optimizer: optimizer.optimizer, "DataLoader": lambda data: data.dataset, "TransformedSubset": lambda data: data.dataset, "MetricList": lambda metrics: metrics.metrics, "LRSchedule": lambda schedule: schedule.lr_scheduler, } _rename = { "_metrics": "metrics", "_device": "device", "_first_epoch": "first_epoch", } def __init__(self): self.output = []
[docs] def print(self, msg="", end="\n"): self.output.append(f"{msg}{end}")
[docs] def is_nested(self, name): return name in GenerateSummary.dispatch
[docs] def retrieve_nested(self, name, obj): return GenerateSummary.dispatch.get(name, lambda x: x)(obj)
[docs] def rename(self, name): return GenerateSummary._rename.get(name, name)
[docs] def get_name(self, attr, obj, type_name, depth=0): self.print(f'{" " * depth} {self.rename(attr)}: ', end="") if not self.is_nested(type_name): if type_name == "device": self.print(str(obj)) elif type_name == "list": self.print() for item in obj: self.print(f'{" " * (depth + 1)} - {type(item).__name__}') else: self.print(type_name) else: self.print() nested = self.retrieve_nested(type_name, obj) nested_type = type(nested).__name__ self.get_name(type_name, nested, nested_type, depth + 1)
[docs] def task_summary(self, obj): self.output = [] self.print("=" * 80) self.print(type(obj).__name__) self.print("-" * 80) for attr, value in obj.__dict__.items(): type_name = type(value).__name__ self.get_name(attr, value, type_name) self.print("=" * 80) return "".join(self.output)