import time
from dataclasses import dataclass, field
from datetime import datetime

import torch
from torch.utils.data import DataLoader

from olympus.observers.observer import Metric
from olympus.utils.stat import StatStream
from olympus.utils.cuda import Stream, stream
@dataclass
class Accuracy(Metric):
    """Evaluate ``task.accuracy`` over every batch of ``loader``.

    One evaluation pass runs for each ``on_new_epoch`` call, and also from
    ``on_start_train`` and ``on_end_train``.  Every pass appends one entry to
    ``accuracies`` and one to ``losses`` and records its duration in
    ``eval_time``.
    """
    # Batches to evaluate.  Must be set before the first pass:
    # ``compute_accuracy`` calls ``len()`` on it and iterates over it.
    loader: DataLoader = None
    # Mean per-batch accuracy of each evaluation pass, oldest first
    accuracies: list = field(default_factory=list)
    # Mean per-batch loss of each evaluation pass, oldest first
    losses: list = field(default_factory=list)
    # Prefix of the keys returned by ``value()``
    name: str = 'validation'
    # Duration of each evaluation pass, in seconds
    eval_time: StatStream = field(default_factory=lambda: StatStream(drop_first_obs=0))
    # Not read or written by any method of this class
    total_time: int = 0
    # Stream the evaluation is submitted on
    # NOTE(review): presumably a CUDA stream wrapper -- confirm in olympus.utils.cuda
    metric_stream: Stream = field(default_factory=Stream)
    # NOTE(review): presumably read by the observer dispatcher to decide how
    # often each hook fires -- confirm against ``Metric``
    frequency_new_epoch: int = 1
    frequency_new_batch: int = 0

    def state_dict(self):
        """Return the recorded accuracy and loss histories as a dict."""
        return dict(accuracies=self.accuracies, losses=self.losses)

    def load_state_dict(self, state_dict):
        """Restore the histories from a dict produced by ``state_dict``."""
        self.accuracies = state_dict['accuracies']
        self.losses = state_dict['losses']

    def compute_accuracy(self, task):
        """Run one evaluation pass over ``self.loader``.

        Parameters
        ----------
        task
            Object exposing ``accuracy(data, target)`` returning an
            ``(accuracy, loss)`` pair of tensors.

        Returns
        -------
        tuple
            ``(eval_time, accuracy, loss)``: the duration of the pass in
            seconds, followed by the accuracy and the loss averaged over the
            number of batches.

        Raises
        ------
        ZeroDivisionError
            If the loader yields no batch (``len(self.loader) == 0``).
        """
        # Monotonic clock: a wall-clock adjustment cannot distort the duration
        start = time.perf_counter()

        batch_accs = []
        batch_losses = []
        count = len(self.loader)

        with stream(self.metric_stream):
            with torch.no_grad():
                for data, target in self.loader:
                    accuracy, loss = task.accuracy(data, target)
                    batch_accs.append(accuracy.detach())
                    batch_losses.append(loss.detach())

            # Values are read back only once every batch has been submitted,
            # so the loop above never waits on an individual ``.item()`` call
            acc = sum(a.item() for a in batch_accs)
            loss_acc = sum(b.item() for b in batch_losses)

        eval_time = time.perf_counter() - start
        return eval_time, acc / count, loss_acc / count

    def on_new_epoch(self, task, epoch, context):
        """Run an evaluation pass and record its results."""
        # I would like to make this completely async
        # but I do not think I can do it easily
        # Good enough for now
        eval_time, acc, loss = self.compute_accuracy(task)

        self.eval_time += eval_time
        self.accuracies.append(acc)
        self.losses.append(loss)

    def on_start_train(self, task, step=None):
        """Record a measurement before training starts."""
        self.on_new_epoch(task, step, None)

    def on_end_train(self, task, step=None):
        """Record a measurement once training is over."""
        self.on_new_epoch(task, step, None)

    def value(self):
        """Return the latest measurement, or ``{}`` if none was recorded.

        Keys are ``<name>_accuracy``, ``<name>_loss`` and ``<name>_time``;
        the last one is ``eval_time.avg``.
        """
        if not self.accuracies:
            return {}

        return {
            f'{self.name}_accuracy': self.accuracies[-1],
            f'{self.name}_loss': self.losses[-1],
            f'{self.name}_time': self.eval_time.avg
        }
@dataclass
class OnlineTrainAccuracy(Metric):
    """Reuse the precomputed predictions to get the training accuracy and loss.

    Because the model is updated between batches, this does not return the
    true accuracy on the training set.
    """
    # Mean per-batch accuracy of each finished epoch, oldest first
    accuracies: list = field(default_factory=list)
    # Mean per-batch loss of each finished epoch, oldest first
    losses: list = field(default_factory=list)
    # Sum of the per-batch accuracies seen since the last flush
    accumulator: float = 0
    # Sum of the per-batch losses seen since the last flush
    loss: float = 0
    # Number of batches accumulated since the last flush
    count: int = 0
    # NOTE(review): presumably read by the observer dispatcher to decide how
    # often each hook fires -- confirm against ``Metric``
    frequency_end_epoch: int = 1
    frequency_end_batch: int = 1

    def state_dict(self):
        """Return the histories and the running sums of the current epoch."""
        return dict(
            accuracies=self.accuracies,
            losses=self.losses,
            accumulator=self.accumulator,
            loss=self.loss,
            count=self.count
        )

    def load_state_dict(self, state_dict):
        """Restore the state from a dict produced by ``state_dict``."""
        self.accuracies = state_dict['accuracies']
        self.losses = state_dict['losses']
        self.accumulator = state_dict['accumulator']
        self.loss = state_dict['loss']
        self.count = state_dict['count']

    def on_end_batch(self, task, step, input, context):
        """Add the accuracy and loss of the batch to the running sums.

        Parameters
        ----------
        task
            Object exposing ``device`` and ``criterion(predictions, target)``.
        step
            Unused.
        input
            Pair whose second element holds the batch targets.
        context
            Mapping; the batch is ignored when it has no ``'predictions'``
            entry (or when that entry is ``None``).
        """
        _, targets = input
        predictions = context.get('predictions')

        if predictions is None:
            return

        # Only scalar values are read here, no autograd graph is needed
        with torch.no_grad():
            # index of the highest score along dim 1 is the predicted class
            _, predicted = torch.max(predictions, 1)
            target = targets.to(device=task.device)

            loss = task.criterion(predictions, target).item()
            acc = (predicted == target).sum().item() / target.size(0)

        self.accumulator += acc
        self.loss += loss
        self.count += 1

    def on_end_epoch(self, task, epoch, context):
        """Store the epoch averages and reset the running sums.

        Does nothing when no batch was accumulated since the last flush.
        """
        if self.count > 0:
            # new epoch
            self.accuracies.append(self.accumulator / self.count)
            self.losses.append(self.loss / self.count)
            self.accumulator = 0
            self.loss = 0
            self.count = 0

    def on_end_train(self, task, step=None):
        """Flush the batches of an unfinished epoch when training stops."""
        # The flush is implemented by ``on_end_epoch`` (this class defines no
        # ``on_new_epoch``); it is a no-op when nothing is pending
        self.on_end_epoch(task, None, None)

    def value(self):
        """Return the averages of the last flushed epoch, or ``{}`` if none."""
        if not self.accuracies:
            return {}

        return {
            'online_train_accuracy': self.accuracies[-1],
            'online_train_loss': self.losses[-1]
        }