Source code for olympus.tasks.detection

import torch
from torch.nn import Module

from olympus.utils import select, drop_empty_key
from olympus.tasks.task import Task
from olympus.metrics import OnlineLoss
from olympus.resuming import state_dict, load_state_dict, BadResumeGuard
from olympus.observers import ProgressView, Speed, ElapsedRealTime, CheckPointer, SampleCount


[docs]class ObjectDetection(Task): def __init__(self, detector, optimizer, lr_scheduler, dataloader, criterion=None, device=None, storage=None): super(ObjectDetection, self).__init__(device=device) self._first_epoch = 0 self.current_epoch = 0 self.detector = detector self.optimizer = optimizer self.lr_scheduler = lr_scheduler self.dataloader = dataloader self.criterion = criterion self.storage = storage self.metrics.append(ElapsedRealTime().every(batch=1)) self.metrics.append(SampleCount().every(batch=1, epoch=1)) speed = Speed() self.metrics.append(speed) self.metrics.append(ProgressView(speed)) self.metrics.append(OnlineLoss()) if storage: self.metrics.append(CheckPointer(storage=storage)) # Hyper Parameter Settings # ---------------------------------------------------------------------
[docs] def get_space(self): """Return hyper parameter space""" return drop_empty_key({ 'optimizer': self.optimizer.get_space(), 'lr_schedule': self.lr_scheduler.get_space(), 'model': self.model.get_space() })
[docs] def init(self, optimizer=None, lr_schedule=None, model=None, uid=None): optimizer = select(optimizer, {}) lr_schedule = select(lr_schedule, {}) model = select(model, {}) self.detector.init( **model ) self.set_device(self.device) self.optimizer.init( self.detector.parameters(), override=True, **optimizer ) self.lr_scheduler.init( self.optimizer, override=True, **lr_schedule ) parameters = {} parameters.update(optimizer) parameters.update(lr_schedule) parameters.update(model) # Trial Creation and Trial resume self.metrics.new_trial(parameters, uid) self.set_device(self.device)
# Training # --------------------------------------------------------------------
[docs] def fit(self, epochs, context=None): if self.stopped: return with BadResumeGuard(self): self._start(epochs) for epoch in range(self._first_epoch, epochs): self.epoch(epoch + 1, context) if self.stopped: break self.report(pprint=True, print_fun=print) self.metrics.end_train()
[docs] def epoch(self, epoch, context): self._fix() self.current_epoch = epoch self.metrics.new_epoch(epoch, context) for step, batch in enumerate(self.dataloader): self.metrics.new_batch(step, batch) results = self.step(step, batch, context) self.lr_scheduler.step(step) self.metrics.end_batch(step, batch, results) self.lr_scheduler.epoch(epoch, lambda x: self.metrics.value()['validation_loss']) self.metrics.end_epoch(epoch, context)
[docs] def step(self, step, input, context): images, targets = input images = list(image[0].to(self.device) for image in images) targets = [{k: v.to(self.device) for k, v in t.items()} for t in targets] self.model.train() self.optimizer.zero_grad() loss_dict = self.model(images, targets) loss = self.criterion(loss_dict) loss.backward() self.optimizer.step() results = { # to compute online loss 'loss': loss.detach() } return results
# ---------------------------------------------------------------------
[docs] def load_state_dict(self, state, strict=True): load_state_dict(self, state, strict, force_default=True) self._first_epoch = state['epoch'] self.current_epoch = state['epoch']
[docs] def state_dict(self, destination=None, prefix='', keep_vars=False): state = state_dict(self, destination, prefix, keep_vars, force_default=True) state['epoch'] = self.current_epoch return state
[docs] def eval_loss(self, batch): # Will be fixed in the next-next torchvision release self.model.train() with torch.no_grad(): images, targets = batch images = list(image[0].to(self.device) for image in images) targets = [{k: v.to(self.device) for k, v in t.items()} for t in targets] loss_dict = self.model(images, targets) loss = self.criterion(loss_dict) self.model.train() # do not use item() in the loop it forces cuda to sync if hasattr(loss, 'detach'): return loss.detach() return torch.Tensor(loss)
@property def model(self) -> Module: return self.detector @model.setter def model(self, model): self.detector = model
[docs] def parameters(self): return self.detector.parameters()