# Source code for olympus.tasks.segmentation

import torch
import sklearn
import sklearn.metrics
import numpy as np

from torch.nn import Module, CrossEntropyLoss
from torch.nn.functional import log_softmax

from olympus.observers import ElapsedRealTime, SampleCount, ProgressView, Speed, CheckPointer
from olympus.metrics import OnlineLoss
from olympus.tasks.task import Task
from olympus.utils import select, drop_empty_key
from olympus.resuming import state_dict, load_state_dict, BadResumeGuard
from olympus.transforms import Preprocessor


class Segmentation(Task):
    """Train a model to recognize a range of classes

    Attributes
    ----------
    classifier: Module
        Module taking sample image and returning the probability of each pixel
        belonging to a range of classes

    optimizer: Optimizer
        Optimizer taking model's parameters

    criterion: Module
        Function evaluating the quality of the model's predictions,
        also named cost function or loss function

    lr_scheduler: LRSchedule
        Learning Scheduler, updates the learning rates periodically

    dataloader: Iterator
        Batch sample iterator used to train the model

    preprocessor: Preprocessor
        Set of functions that transform the inputs before it is given to the model

    nclasses: int
        Number of classes, used to size the confusion matrix

    device:
        Acceleration device to run the task on

    storage: Storage
        Where to save checkpoints in case of failures

    metrics: Optional[Iterable]
        Additional metrics/observers appended after the default ones
    """

    def __init__(self, classifier, optimizer, lr_scheduler, dataloader, criterion, nclasses,
                 device=None, storage=None, preprocessor=None, metrics=None):
        super().__init__(device=device)
        # NOTE(review): `select` presumably falls back to the second argument
        # when `criterion` is None -- confirm against olympus.utils.select
        criterion = select(criterion, CrossEntropyLoss())

        self._first_epoch = 0
        self.current_epoch = 0
        self.classifier = classifier
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.dataloader = dataloader
        self.criterion = criterion
        self.preprocessor = Preprocessor()
        # ------------------------------------------------------------------

        # TODO: This should go inside user code it will remove 2 arguments
        self.nclasses = nclasses

        # `self.metrics` is not created here; it is expected to be provided
        # by the parent `Task` constructor
        self.metrics.append(ElapsedRealTime().every(batch=1))
        self.metrics.append(SampleCount().every(batch=1, epoch=1))
        self.metrics.append(OnlineLoss())
        self.metrics.append(Speed())

        # All metrics must be before ProgressView and CheckPointer
        if metrics:
            for metric in metrics:
                self.metrics.append(metric)

        self.metrics.append(ProgressView(speed=self.metrics.get('Speed')))

        if storage:
            self.metrics.append(CheckPointer(storage=storage))
        # ------------------------------------------------------------------

        if preprocessor is not None:
            self.preprocessor = preprocessor

        self.hyper_parameters = {}

    # Hyper Parameter Settings
    # ---------------------------------------------------------------------
    def get_space(self):
        """Return hyper parameter space"""
        return drop_empty_key({
            'optimizer': self.optimizer.get_space(),
            'lr_schedule': self.lr_scheduler.get_space(),
            'model': self.model.get_space()
        })

    def get_current_space(self):
        """Get currently defined parameter space"""
        return {
            'optimizer': self.optimizer.get_current_space(),
            'lr_schedule': self.lr_scheduler.get_current_space(),
            'model': self.model.get_current_space()
        }

    def init(self, optimizer=None, lr_schedule=None, model=None, uid=None):
        """Initialize the model, optimizer and lr schedule with their hyper parameters

        Parameters
        ----------
        optimizer: Dict
            Optimizer hyper parameters

        lr_schedule: Dict
            lr schedule hyper parameters

        model: Dict
            model hyper parameters

        uid: Optional[str]
            trial id to use for logging.
            When using orion usually it already created a trial for us we just need to append to it
        """
        optimizer = select(optimizer, {})
        lr_schedule = select(lr_schedule, {})
        model = select(model, {})

        self.classifier.init(
            **model
        )

        # list of all parameters this task has
        parameters = self.preprocessor.parameters()
        parameters.append({
            'params': self.classifier.parameters()
        })

        # We need to set the device now so optimizer receive cuda tensors
        self.set_device(self.device)
        self.optimizer.init(
            parameters,
            override=True,
            **optimizer
        )
        self.lr_scheduler.init(
            self.optimizer,
            override=True,
            **lr_schedule
        )

        self.hyper_parameters = {
            'optimizer': optimizer,
            'lr_schedule': lr_schedule,
            'model': model
        }

        # Get all hyper parameters even the one that were set manually
        hyperparameters = self.get_current_space()

        # Trial Creation and Trial resume
        self.metrics.new_trial(hyperparameters, uid)
        self.set_device(self.device)

    # Training
    # ---------------------------------------------------------------------
    def fit(self, epochs, context=None):
        """Train until `epochs` epochs have been run or the task is stopped

        Training starts at the epoch recorded by a previous `fit` call or by
        `load_state_dict`, so calling `fit` again only runs the missing epochs.
        """
        if self.stopped:
            return

        with BadResumeGuard(self):
            self.classifier.to(self.device)
            self._start(epochs)

            for epoch in range(self._first_epoch, epochs):
                # epochs are reported 1-based
                self.epoch(epoch + 1, context)

                if self.stopped:
                    break

            self.metrics.end_train()
            self._first_epoch = epochs

    def epoch(self, epoch, context):
        """Run one pass over the dataloader, `epoch` is 1-based"""
        self.current_epoch = epoch
        self.metrics.new_epoch(epoch, context)

        # global step index of the first batch of this epoch
        iterations = len(self.dataloader) * (epoch - 1)

        for step, mini_batch in enumerate(self.dataloader, start=iterations):
            self.metrics.new_batch(step, mini_batch, None)

            results = self.step(step, mini_batch, context)
            self.lr_scheduler.step(step)

            self.metrics.end_batch(step, mini_batch, results)

        self.lr_scheduler.epoch(epoch)
        self.metrics.end_epoch(epoch, context)

    def step(self, step, input, context):
        """Run a single optimization step on one mini batch

        The preprocessor output is unpacked as ``(batch, target, ...)`` where
        `batch` is an iterable of tensors forwarded to the classifier as
        positional arguments; any extra field is ignored.

        Returns
        -------
        Dict holding the detached ``loss`` of the mini batch
        """
        self.classifier.train()
        self.optimizer.zero_grad()

        batch, target, *_ = self.preprocessor(input)
        batch = [x.to(device=self.device) for x in batch]

        predictions = self.classifier(*batch)
        loss = self.criterion(predictions, target.to(device=self.device))

        self.optimizer.backward(loss)
        self.optimizer.step()

        results = {
            # to compute online loss
            'loss': loss.detach(),
        }
        return results

    # ---------------------------------------------------------------------
    def eval_loss(self, batch):
        """Compute the loss of a ``(batch, target, ...)`` sample without gradient

        The model is put in evaluation mode for the computation and switched
        back to training mode afterwards, even if the computation fails.
        """
        self.model.eval()

        try:
            with torch.no_grad():
                # extra fields are ignored, same as in `step`
                batch, target, *_ = batch
                batch = [x.to(device=self.device) for x in batch]

                predictions = self.classifier(*batch)
                loss = self.criterion(predictions, target.to(device=self.device))
        finally:
            self.model.train()

        return loss.detach()

    def predict_scores(self, batch):
        """Return the raw classifier output for `batch`, computed without gradient"""
        with torch.no_grad():
            data = [x.to(device=self.device) for x in batch]
            return self.classifier(*data)

    def predict_log_probabilities(self, batch):
        """Return the log-softmax of the scores taken over dimension 1"""
        return log_softmax(self.predict_scores(batch), dim=1)

    def predict(self, batch, target=None):
        """Return the index of the highest score along dimension 1 and the loss

        The loss is ``None`` when no `target` is given.
        """
        scores = self.predict_scores(batch)

        _, predicted = torch.max(scores, 1)

        loss = None
        if target is not None:
            loss = self.criterion(scores, target.to(device=self.device))

        return predicted, loss

    def confusion_matrix(self, batch, target, ignore_index=255):
        """Compute the confusion matrix and the loss of a batch

        Parameters
        ----------
        batch:
            Iterable of tensors forwarded to the classifier

        target: Tensor
            Expected label of each prediction

        ignore_index: int
            Target value excluded from the confusion matrix (default 255);
            it only filters the confusion matrix, the loss is whatever
            `criterion` returns for the unfiltered target

        Returns
        -------
        Tuple ``(confusion matrix of shape (nclasses, nclasses), loss as a float)``

        The model is put in evaluation mode for the computation and switched
        back to training mode afterwards, even if the computation fails.
        """
        self.model.eval()

        try:
            with torch.no_grad():
                predicted, loss = self.predict(batch, target)

                # drop the entries that must not be counted
                keep = target != ignore_index
                target = target[keep]
                predicted = predicted[keep]

                target, predicted = target.cpu().numpy(), predicted.cpu().numpy()
                conf_mtx = sklearn.metrics.confusion_matrix(
                    target, predicted, labels=np.arange(self.nclasses))

                loss = loss.detach().item()
        finally:
            self.model.train()

        return conf_mtx, loss

    def load_state_dict(self, state, strict=True):
        """Restore the task and resume from the epoch stored in `state`"""
        load_state_dict(self, state, strict, force_default=True)
        self._first_epoch = state['epoch']
        self.current_epoch = state['epoch']

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        """Return the state of the task, including the current epoch"""
        state = state_dict(self, destination, prefix, keep_vars, force_default=True)
        state['epoch'] = self.current_epoch
        return state

    def parameters(self):
        """Return the parameters of the classifier"""
        return self.classifier.parameters()

    @property
    def model(self) -> Module:
        """Alias of `classifier`"""
        return self.classifier

    @model.setter
    def model(self, model):
        self.classifier = model