from dataclasses import dataclass, field
from datetime import datetime
import torch
from torch.utils.data import DataLoader
from torch.autograd import Variable
from olympus.observers.observer import Metric
from olympus.utils.stat import StatStream
@dataclass
class ClassifierAdversary(Metric):
    """Simple adversary generator (Fast Gradient Sign Method) from `arxiv <https://arxiv.org/pdf/1412.6572.pdf>`_

    Measure how robust a network is to adversarial attacks.

    .. math::

        adversary(image) = image + epsilon * sign(grad(cost(theta, image, t), image))

    An adversary takes as input an image and returns a modified image that
    will try to induce an error on the classifier.

    Attributes
    ----------
    epsilon: float = 0.25
        Magnitude of the perturbation. The paper uses 0.25 for MNIST and 0.007 for
        ImageNet; the latter corresponds to the magnitude of the smallest bit of an
        8-bit image encoding after conversion to real numbers.

    accuracies: list
        Adversarial accuracy recorded at the end of each epoch

    losses: list
        Adversarial loss recorded at the end of each epoch

    distortions: list
        Average ratio ``std(perturbation) / std(image)`` recorded at the end of each epoch

    loader: DataLoader = None
        If set, the statistics of an epoch are computed over this loader; otherwise
        they are averaged over the batches seen by ``on_end_batch`` during the epoch

    References
    ----------
    .. [1] Ian J. Goodfellow, Jonathon Shlens, Christian Szegedy.
        "Explaining and Harnessing Adversarial Examples", 20 Dec 2014
    """
    epsilon: float = 0.25
    accuracies: list = field(default_factory=list)
    losses: list = field(default_factory=list)
    distortions: list = field(default_factory=list)

    # Running sums for the current epoch. These are plain class attributes (not
    # dataclass fields); ``self.x += ...`` rebinds them on the instance.
    distortion = 0
    loss = 0
    accumulator = 0
    count = 0

    loader: DataLoader = None

    # Per-instance timer: a class-level StatStream would be shared (and mutated)
    # by every instance. init=False keeps the generated __init__ signature unchanged.
    time: StatStream = field(
        default_factory=lambda: StatStream(drop_first_obs=1),
        init=False, repr=False, compare=False)

    def on_end_epoch(self, task, epoch, context):
        """Record the adversarial statistics of the finished epoch and reset the running sums.

        If ``loader`` is set, accuracy, loss and distortion are averaged over every
        batch of that loader; otherwise they are averaged over the batches accumulated
        by ``on_end_batch``. Nothing is recorded when there is no batch to average.
        """
        # ``is not None`` rather than truthiness: ``bool(loader)`` calls ``len()``,
        # which fails for iterable-style loaders.
        if self.loader is not None:
            # ``adversarial`` adds to ``self.distortion``; snapshot it so the loader
            # distortion is not mixed with what on_end_batch accumulated.
            distortion_start = self.distortion
            accuracy = 0
            total_loss = 0
            batches = 0

            for data, target in self.loader:
                acc, loss = self.adversarial(task, data, target)
                accuracy += acc.item()
                total_loss += loss.item()
                batches += 1

            if batches > 0:
                self.accuracies.append(accuracy / batches)
                self.losses.append(total_loss / batches)
                self.distortions.append((self.distortion - distortion_start) / batches)

        elif self.count > 0:
            self.accuracies.append(self.accumulator / self.count)
            self.losses.append(self.loss / self.count)
            self.distortions.append(self.distortion / self.count)

        self._reset_accumulators()

    def _reset_accumulators(self):
        """Zero the per-epoch running sums."""
        self.count = 0
        self.distortion = 0
        self.accumulator = 0
        self.loss = 0

    def adversarial(self, task, batch, target):
        """Build FGSM adversarial examples for ``batch`` and evaluate the model on them.

        Parameters
        ----------
        task:
            Object providing ``model``, ``criterion``, ``device`` and ``accuracy``

        batch: Tensor
            Input images

        target: Tensor
            Labels of ``batch``

        Returns
        -------
        The ``(acc, loss)`` pair returned by ``task.accuracy`` on the adversarial images.
        As a side effect, adds this batch's ``std(perturbation) / std(image)`` ratio to
        ``self.distortion``.
        """
        # Fresh leaf tensor sharing ``batch``'s data (replaces the deprecated Variable);
        # backward() stores d(loss)/d(image) in its ``.grad``.
        original_images = batch.detach().requires_grad_(True)

        # Freeze the model so backward() only computes the gradient w.r.t. the images
        # and does not touch the parameters' ``.grad``. Remember the original flags so
        # parameters that were frozen on purpose stay frozen afterwards.
        params = list(task.model.parameters())
        previous_flags = [param.requires_grad for param in params]
        for param in params:
            param.requires_grad_(False)

        try:
            probabilities = task.model(original_images.to(device=task.device))
            loss = task.criterion(probabilities, target.to(device=task.device))
            loss.backward()
        finally:
            # Restore even if the forward/backward pass raises
            for param, flag in zip(params, previous_flags):
                param.requires_grad_(flag)

        with torch.no_grad():
            perturbation = self.epsilon * torch.sign(original_images.grad)
            self.distortion += (perturbation.std() / original_images.std()).item()

        adversarial_images = batch + perturbation
        acc, loss = task.accuracy(adversarial_images, target)
        return acc, loss

    def on_end_batch(self, task, step, input, context):
        """Run the adversary on the batch just processed and accumulate its statistics."""
        from time import perf_counter

        batch, target = input

        # Monotonic clock: suited to measuring elapsed time
        start = perf_counter()
        acc, loss = self.adversarial(task, batch, target)
        self.time += perf_counter() - start

        self.loss += loss.item()
        self.accumulator += acc.item()
        self.count += 1

    def on_end_train(self, task, step=None):
        """Flush the batches accumulated since the last epoch end, if any."""
        if self.count > 0:
            self.on_end_epoch(task, step, None)

    def value(self):
        """Return the latest recorded adversarial statistics (empty until an epoch is recorded)."""
        results = {}

        if self.accuracies:
            results['adversary_accuracy'] = self.accuracies[-1]
            results['adversary_loss'] = self.losses[-1]
            results['adversary_distortion'] = self.distortions[-1]

        if self.time.count > 0:
            results['adversary_time'] = self.time.avg

        return results

    def state_dict(self):
        """Return the per-epoch history needed to resume this metric."""
        return dict(
            accuracies=self.accuracies,
            losses=self.losses,
            distortions=self.distortions
        )

    def load_state_dict(self, state_dict):
        """Restore the per-epoch history produced by ``state_dict``."""
        self.accuracies = state_dict['accuracies']
        self.losses = state_dict['losses']
        self.distortions = state_dict['distortions']