import torch
from torch.nn import Module, CrossEntropyLoss
from torch.nn.functional import log_softmax
from olympus.observers import (
ElapsedRealTime,
SampleCount,
ProgressView,
Speed,
CheckPointer,
)
from olympus.metrics import OnlineTrainAccuracy
from olympus.tasks.task import Task
from olympus.utils import select, drop_empty_key
from olympus.resuming import state_dict, load_state_dict, BadResumeGuard
from olympus.transforms import Preprocessor
[docs]class Classification(Task):
"""Train a model to recognize a range of classes
Attributes
----------
classifier: Module
Module taking sample data and returning the probability of the sample belonging to a range of classes
optimizer: Optimizer
Optimizer taking model's parameters
criterion: Module
Function evaluating the quality of the model's predictions, also named cost function or loss function
lr_scheduler: LRSchedule
Learning Scheduler, updates the learning rates periodically
dataloader: Iterator
Batch sample iterator used to train the model
preprocessor: Preprocessor
Set of functions that transform the inputs before it is given to the model
device:
Acceleration device to run the task on
storage: Storage
Where to save checkpoints in case of failures
"""
def __init__(
self,
classifier,
optimizer,
lr_scheduler,
dataloader,
criterion=None,
device=None,
storage=None,
preprocessor=None,
metrics=None,
):
super(Classification, self).__init__(device=device)
criterion = select(criterion, CrossEntropyLoss())
self._first_epoch = 0
self.current_epoch = 0
self.classifier = classifier
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
self.dataloader = dataloader
self.criterion = criterion
self.preprocessor = Preprocessor()
# ------------------------------------------------------------------
self.metrics.append(ElapsedRealTime().every(batch=1))
self.metrics.append(SampleCount().every(batch=1, epoch=1))
self.metrics.append(OnlineTrainAccuracy())
self.metrics.append(Speed())
# All metrics must be before ProgressView and CheckPointer
if metrics:
for metric in metrics:
self.metrics.append(metric)
self.metrics.append(ProgressView(self.metrics.get("Speed")))
if storage:
self.metrics.append(CheckPointer(storage=storage))
# ------------------------------------------------------------------
if preprocessor is not None:
self.preprocessor = preprocessor
self.hyper_parameters = {}
# Hyper Parameter Settings
# ---------------------------------------------------------------------
[docs] def get_space(self):
"""Return hyper parameter space"""
return drop_empty_key(
{
"optimizer": self.optimizer.get_space(),
"lr_schedule": self.lr_scheduler.get_space(),
"model": self.model.get_space(),
}
)
[docs] def get_current_space(self):
"""Get currently defined parameter space"""
return {
"optimizer": self.optimizer.get_current_space(),
"lr_schedule": self.lr_scheduler.get_current_space(),
"model": self.model.get_current_space(),
}
[docs] def init(self, optimizer=None, lr_schedule=None, model=None, uid=None):
"""
Parameters
----------
optimizer: Dict
Optimizer hyper parameters!s
lr_schedule: Dict
lr schedule hyper parameters
model: Dict
model hyper parameters
uid: Optional[str]
trial id to use for logging.
When using orion usually it already created a trial for us we just need to append to it
"""
optimizer = select(optimizer, {})
lr_schedule = select(lr_schedule, {})
model = select(model, {})
self.classifier.init(**model)
# list of all parameters this task has
parameters = self.preprocessor.parameters()
parameters.append({"params": self.classifier.parameters()})
# We need to set the device now so optimizer receive cuda tensors
self.set_device(self.device)
self.optimizer.init(params=parameters, override=True, **optimizer)
self.lr_scheduler.init(self.optimizer, override=True, **lr_schedule)
self.hyper_parameters = {
"optimizer": optimizer,
"lr_schedule": lr_schedule,
"model": model,
}
# Get all hyper parameters even the one that were set manually
hyperparameters = self.get_current_space()
# Trial Creation and Trial resume
self.metrics.new_trial(hyperparameters, uid)
self.set_device(self.device)
# Training
# ---------------------------------------------------------------------
[docs] def fit(self, epochs, context=None):
if self.stopped:
return
with BadResumeGuard(self):
self.classifier.to(self.device)
self._start(epochs)
for epoch in range(self._first_epoch, epochs):
self.epoch(epoch + 1, context)
if self.stopped:
break
self.metrics.end_train()
self._first_epoch = epochs
[docs] def epoch(self, epoch, context):
self.current_epoch = epoch
self.metrics.new_epoch(epoch, context)
# iterations = len(self.dataloader) * (epoch - 1)
for step, mini_batch in enumerate(self.dataloader):
# why is this there
# step += iterations
self.metrics.new_batch(step, mini_batch, None)
results = self.step(step, mini_batch, context)
self.lr_scheduler.step(step)
self.metrics.end_batch(step, mini_batch, results)
self.lr_scheduler.epoch(epoch, self._get_validation_accuracy)
self.metrics.end_epoch(epoch, context)
[docs] def step(self, step, input, context):
self.classifier.train()
self.optimizer.zero_grad()
batch, target, *_ = self.preprocessor(input)
batch = [x.to(device=self.device) for x in batch]
predictions = self.classifier(*batch)
loss = self.criterion(predictions, target.to(device=self.device))
self.optimizer.backward(loss)
self.optimizer.step()
results = {
# to compute online loss
"loss": loss.detach(),
# to compute only accuracy
"predictions": predictions.detach(),
}
return results
# ---------------------------------------------------------------------
def _get_validation_accuracy(self, x):
return self.metrics.value().get("validation_accuracy", None)
[docs] def eval_loss(self, batch):
self.model.eval()
with torch.no_grad():
batch, target = batch
batch = [x.to(device=self.device) for x in batch]
predictions = self.classifier(*batch)
loss = self.criterion(predictions, target.to(device=self.device))
self.model.train()
return loss.detach()
[docs] def predict_scores(self, batch):
with torch.no_grad():
data = [x.to(device=self.device) for x in batch]
return self.classifier(*data)
[docs] def predict_log_probabilities(self, batch):
return log_softmax(self.predict_scores(batch), dim=1)
[docs] def predict(self, batch, target=None):
scores = self.predict_scores(batch)
_, predicted = torch.max(scores, 1)
loss = None
if target is not None:
loss = self.criterion(scores, target.to(device=self.device))
return predicted, loss
[docs] def accuracy(self, batch, target):
self.model.eval()
with torch.no_grad():
predicted, loss = self.predict(batch, target)
acc = (predicted == target.to(device=self.device)).sum()
self.model.train()
return acc.float(), loss
[docs] def load_state_dict(self, state, strict=True):
load_state_dict(self, state, strict, force_default=True)
self._first_epoch = state["epoch"]
self.current_epoch = state["epoch"]
[docs] def state_dict(self, destination=None, prefix="", keep_vars=False):
state = state_dict(self, destination, prefix, keep_vars, force_default=True)
state["epoch"] = self.current_epoch
return state
[docs] def parameters(self):
return self.classifier.parameters()
@property
def model(self) -> Module:
return self.classifier
@model.setter
def model(self, model):
self.classifier = model