from dataclasses import dataclass, field
from datetime import datetime
import torch
from torch.utils.data import DataLoader
from olympus.observers.observer import Metric
from olympus.utils import error
from olympus.utils.stat import StatStream
from olympus.utils.cuda import Stream, stream
import math
[docs]class NotFittedError(Exception):
pass
[docs]def detach(f):
if isinstance(f, torch.Tensor):
return f.detach()
return f
[docs]def item(f):
if isinstance(f, torch.Tensor):
return f.item()
return f
[docs]@dataclass
class Accuracy(Metric):
loader: DataLoader = None
accuracies: list = field(default_factory=list)
losses: list = field(default_factory=list)
name: str = 'validation'
eval_time: StatStream = field(default_factory=lambda: StatStream(drop_first_obs=0))
total_time: int = 0
metric_stream: Stream = field(default_factory=Stream)
frequency_new_epoch: int = 1
frequency_new_batch: int = 0
[docs] def state_dict(self):
return dict(accuracies=self.accuracies, losses=self.losses)
[docs] def load_state_dict(self, state_dict):
self.accuracies = state_dict['accuracies']
self.losses = state_dict['losses']
[docs] def on_new_trial(self, task, step, parameters, uid):
if not hasattr(task, 'accuracy'):
raise AttributeError('Task need the accuracy method to compute the AccuracyMetric')
[docs] def compute_accuracy(self, task):
start = datetime.utcnow()
losses = []
accs = []
count = len(self.loader)
with stream(self.metric_stream):
with torch.no_grad():
total = 0
for data, target in self.loader:
accuracy, loss = task.accuracy(data, target)
batch_size = target.shape[0]
total += batch_size
accs.append(detach(accuracy))
losses.append(detach(loss)*batch_size)
acc = math.fsum([item(a) for a in accs]) / total
loss_acc = math.fsum([item(l) for l in losses]) / total
end = datetime.utcnow()
eval_time = (end - start).total_seconds()
loss = (loss_acc / count)
return eval_time, acc, loss
[docs] def get_accuracy(self, task, epoch, context):
# I would like to make this completely async
# but I do not think I can do it easily
# Good enough for now
eval_time, acc, loss = self.compute_accuracy(task)
self.eval_time += eval_time
self.accuracies.append(acc)
self.losses.append(loss)
[docs] def on_end_epoch(self, task, epoch, context):
self.get_accuracy(task, epoch, context)
[docs] def on_end_train(self, task, step=None):
self.get_accuracy(task, step, None)
[docs] def on_start_train(self, task, step=None):
try:
self.get_accuracy(task, step, None)
except NotFittedError:
# error('')
pass
[docs] def value(self):
if not self.accuracies:
return {}
return {
f'{self.name}_accuracy': self.accuracies[-1],
f'{self.name}_error_rate': 1 - self.accuracies[-1],
f'{self.name}_loss': self.losses[-1],
f'{self.name}_time': self.eval_time.avg
}
[docs]@dataclass
class OnlineTrainAccuracy(Metric):
"""Reuse precomputed loss and prediction to get accuracy
because the model is updated in between each batch, this does not return the true accuracy on the training set,
"""
accuracies: list = field(default_factory=list)
losses: list = field(default_factory=list)
accumulator: int = 0
loss: int = 0
count: int = 0
frequency_end_epoch: int = 1
frequency_end_batch: int = 1
[docs] def state_dict(self):
return dict(
accuracies=self.accuracies,
losses=self.losses,
accumulator=self.accumulator,
loss=self.loss,
count=self.count
)
[docs] def load_state_dict(self, state_dict):
self.accuracies = state_dict['accuracies']
self.losses = state_dict['losses']
self.accumulator = state_dict['accumulator']
self.loss = state_dict['loss']
self.count = state_dict['count']
[docs] def on_end_batch(self, task, step, input, context):
_, targets, *_ = input
predictions = context.get('predictions')
# compute accuracy for the current batch
if predictions is not None:
_, predicted = torch.max(predictions, 1)
target = input[1].to(device=task.device)
loss = task.criterion(predictions, target).item()
acc = (predicted == target).sum().item() / target.size(0)
self.accumulator += acc
self.loss += loss
self.count += 1
[docs] def on_end_epoch(self, task, epoch, context):
if self.count > 0:
# new epoch
self.accuracies.append(self.accumulator / self.count)
self.losses.append(self.loss / self.count)
self.accumulator = 0
self.loss = 0
self.count = 0
[docs] def on_end_train(self, task, step=None):
if self.count > 0:
self.on_new_epoch(task, None, None)
[docs] def value(self):
if not self.accuracies:
return {}
return {
'online_train_accuracy': self.accuracies[-1],
'online_train_loss': self.losses[-1]
}
[docs]@dataclass
class AUC(Metric):
loader: DataLoader = None
aucs: list = field(default_factory=list)
pccs: list = field(default_factory=list)
name: str = 'validation'
eval_time: StatStream = field(default_factory=lambda: StatStream(drop_first_obs=0))
total_time: int = 0
metric_stream: Stream = field(default_factory=Stream)
frequency_new_epoch: int = 1
frequency_new_batch: int = 0
[docs] def state_dict(self):
return dict(aucs=self.aucs, pccs=self.pccs)
[docs] def load_state_dict(self, state_dict):
self.aucs = state_dict['aucs']
self.pccs = state_dict['pccs']
[docs] def compute_auc(self, task):
start = datetime.utcnow()
data = self.loader[0]
#auc, pcc = task.auc(data[0][0], data[1])
auc, pcc = task.auc(data[0], data[1])
end = datetime.utcnow()
eval_time = (end - start).total_seconds()
return eval_time, auc, pcc
[docs] def get_auc(self, task, epoch, context):
# I would like to make this completely async
# but I do not think I can do it easily
# Good enough for now
eval_time, auc, pcc = self.compute_auc(task)
self.eval_time += eval_time
self.aucs.append(auc)
self.pccs.append(pcc)
[docs] def on_end_epoch(self, task, epoch, context):
self.get_auc(task, epoch, context)
[docs] def on_end_train(self, task, step=None):
self.get_auc(task, step, None)
[docs] def on_start_train(self, task, step=None):
try:
self.get_auc(task, step, None)
except NotFittedError:
# error('')
pass
[docs] def value(self):
if not self.aucs:
return {}
return {
f'{self.name}_auc': self.aucs[-1],
f'{self.name}_aac': 1 - self.aucs[-1], # Area above the curve... ¯\_(ツ)_/¯
f'{self.name}_pcc': self.pccs[-1],
f'{self.name}_time': self.eval_time.avg
}