import torch
import torch.nn as nn
import torchvision.transforms as transforms
from olympus.transforms import Denormalize
def imagenet_preprocess(img):
    """Turn one image, or a list/tuple of images, into a normalized batch tensor.

    Every image is passed through ``Resize(224)`` (shorter side to 224),
    ``ToTensor()`` and ``Normalize`` with the ImageNet channel mean/std, and
    the results are stacked along a new leading batch dimension.

    Parameters
    ----------
    img: PIL.Image or List[PIL.Image] or Tuple[PIL.Image, ...]
        A single image or a sequence of images. Images in a sequence must end
        up the same size after resizing, otherwise ``torch.stack`` fails.

    Returns
    -------
    Tensor
        The stacked batch; it holds a single entry when ``img`` is one image.
    """
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
    preprocessor = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        normalize,
    ])
    # Any list or tuple is a batch; anything else is treated as one image
    # (previously a tuple of images was wrongly fed to the transforms as-is).
    images = img if isinstance(img, (list, tuple)) else [img]
    return torch.stack([preprocessor(i) for i in images])
def imagenet_postprocessor(image):
    """Convert an ImageNet-normalized tensor (or batch) back into PIL image(s).

    Each image is de-normalized with the ImageNet mean/std, linearly rescaled
    to [0, 1] and converted with ``ToPILImage``.

    Parameters
    ----------
    image: Tensor
        A single image, or a 4-D batch whose first dimension indexes images.

    Returns
    -------
    PIL.Image or List[PIL.Image]
        A list with one image per batch entry when ``image`` is 4-D,
        otherwise a single image.
    """
    denorm = Denormalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )

    def renorm(img):
        # Rescale to [0, 1] for ToPILImage. Done out-of-place so this step
        # never writes into its input, and guarded so a constant image maps
        # to zeros instead of NaN (0 / 0).
        img = img - img.min()
        span = img.max()
        if span > 0:
            img = img / span
        return img

    post = transforms.Compose([
        denorm,
        renorm,
        transforms.ToPILImage(),
    ])

    def to_pil(img):
        # ToPILImage needs a CPU tensor that does not require grad.
        return post(img.detach().cpu())

    if image.dim() == 4:
        return [to_pil(img) for img in image]
    return to_pil(image)
class GuidedBackprop:
    """Guided-backpropagation saliency maps, as introduced in [1]_.

    Constructing the object registers hooks on ``model``; they stay active
    until :meth:`remove_hooks` is called. Calling the object runs a forward
    and backward pass for the given images and stores the gradient of the
    first layer's input in ``self.gradients``. :meth:`positive_saliency` and
    :meth:`negative_saliency` then turn it into per-image saliency maps.

    Parameters
    ----------
    model:
        Pytorch model
    activation:
        Type of the activation layer to hook, defaults to ReLU
    preprocessor: Callable[[List[PIL.Image]], Tensor]
        used to apply preprocessing to images before the forward pass
    postprocessor: Callable[[Tensor], List[PIL.Image]]
        used to reconstruct the image from a tensor; stored as
        ``self.postprocessor`` for callers, not applied automatically

    References
    ----------
    .. [1] J. T. Springenberg, A. Dosovitskiy, T. Brox, and M. Riedmiller.
       "Striving for Simplicity: The All Convolutional Net"
       https://arxiv.org/abs/1412.6806
    .. [2] K. Simonyan, A. Vedaldi, A. Zisserman.
       "Deep Inside Convolutional Networks: Visualising Image Classification Models and Saliency Maps"
       https://arxiv.org/abs/1312.6034
    .. [3] https://arxiv.org/pdf/1810.03292v1.pdf

    Examples
    --------
    >>> import torchvision.models as models
    >>> from torchvision import transforms
    >>> from PIL import Image
    >>> path = 'docs/_static/images/cat.jpg'
    >>> img = Image.open(path)
    >>> model = models.alexnet(pretrained=True)
    >>> guided = GuidedBackprop(model)
    >>> _ = guided([img], [285])
    >>> for i, grad in enumerate(guided.negative_saliency()):
    ...     img = imagenet_postprocessor(grad)
    ...     img.save(f'negative_saliency_{i}.jpg')

    .. image:: ../../../docs/_static/images/cat.jpg
        :width: 45 %
    .. image:: ../../../docs/_static/images/negative_saliency.jpg
        :width: 45 %
    """

    def __init__(self, model, activation=nn.ReLU, preprocessor=imagenet_preprocess,
                 postprocessor=imagenet_postprocessor):
        self.gradients = None
        # ``input`` is the attribute __call__ fills in; ``inputs`` is kept as
        # an alias because it is the name this class used to initialise.
        self.input = None
        self.inputs = None
        # Detached forward outputs of the hooked activations, consumed LIFO
        # by the backward hook (backward visits layers in reverse order).
        self.activation_stack = []
        # Handles of every hook registered on the model, see remove_hooks().
        self._hook_handles = []
        self.model = model
        self.activation = activation
        self._register_hooks(model, activation)
        self.preprocessor = preprocessor
        self.postprocessor = postprocessor

    def _register_hooks(self, model, activation=nn.ReLU):
        """Hook the first layer (gradient capture) and every ``activation`` module.

        The "first layer" is the first module in ``model.modules()`` order
        that is not a plain ``nn.Sequential`` and is not the model itself.
        NOTE(review): ``modules()`` follows registration order, which is
        assumed to match the order layers are applied to the input.
        """
        layers = [module for module in model.modules()
                  if type(module) is not nn.Sequential]
        # modules() yields the model itself first; drop it unless it was
        # already filtered out for being a plain nn.Sequential.
        if layers and isinstance(layers[0], type(model)):
            layers = layers[1:]
        # A bare leaf module has no submodules left: hook the model itself.
        first_layer = layers[0] if layers else model
        self._hook_handles.append(
            first_layer.register_backward_hook(self.fetch_gradient))

        for module in layers:
            if isinstance(module, activation):
                self._hook_handles.append(
                    module.register_backward_hook(self.activation_backward))
                self._hook_handles.append(
                    module.register_forward_hook(self.activation_forward))
        return self

    def remove_hooks(self):
        """Remove every hook this instance registered, restoring the model."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles.clear()

    def fetch_gradient(self, module, grad_in, grad_out):
        """Backward hook on the first layer: store ``grad_in[0]`` in ``self.gradients``.

        For a convolution/linear first layer this is presumably the gradient
        w.r.t. the network input.
        """
        self.gradients = grad_in[0].detach()

    def activation_forward(self, module, ten_in, ten_out):
        """Forward hook: remember the activation output for the backward hook."""
        # Backward hooks only fire for outputs that take part in autograd, so
        # skip recording otherwise (e.g. plain inference through the hooked
        # model), which would only grow the stack.
        if ten_out.requires_grad:
            self.activation_stack.append(ten_out.detach())

    def activation_backward(self, module, grad_in, grad_out):
        """Backward hook: apply the guided-backprop rule to the activation gradient.

        Only gradient that is positive AND flows through a unit whose forward
        output was positive is kept; everything else is zeroed.
        """
        output = self.activation_stack.pop()
        grad = grad_in[0]
        # Build a fresh 0/1 mask rather than writing 1s into ``output``: that
        # tensor shares storage with the forward activation, and the in-place
        # version also left non-positive outputs (e.g. LeakyReLU/ELU) in the mask.
        mask = (output > 0).to(grad.dtype)
        return (mask * torch.clamp(grad, min=0.0),)

    @staticmethod
    def _scale_to_unit(tensor):
        """Divide a non-negative tensor by its maximum; return it unchanged if that is 0."""
        peak = tensor.max()
        return tensor / peak if peak > 0 else tensor

    def positive_saliency(self):
        """Return the positive part of each image's gradient, scaled to [0, 1].

        An image without any positive gradient yields zeros instead of NaN.
        """
        return [self._scale_to_unit(grad.clamp(min=0)) for grad in self.gradients]

    def negative_saliency(self):
        """Return the magnitude of each image's negative gradient, scaled to [0, 1].

        An image without any negative gradient yields zeros instead of NaN.
        """
        return [self._scale_to_unit((-grad).clamp(min=0)) for grad in self.gradients]

    def __call__(self, images, classes=None):
        """Run a guided backward pass and return the model output.

        Parameters
        ----------
        images: List
            list of images used to generate saliency map (anything accepted
            by ``self.preprocessor``)
        classes: Optional[List]
            class index of each image; if not provided, defaults to class 0
            for every image.
            `ImageNet class to index map <https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a>`_

        Returns
        -------
        Tensor
            The model output; the input gradient is stored in ``self.gradients``.
        """
        # Drop leftovers from a previous forward that never ran backward.
        self.activation_stack.clear()

        self.input = self.preprocessor(images)
        self.inputs = self.input
        self.input.requires_grad = True
        out = self.model(self.input)
        self.model.zero_grad()

        # If no classes are provided use a dummy one per batch entry; the
        # batch size comes from the output so a single image also works.
        if classes is None:
            classes = [0] * out.shape[0]

        # One-hot seed gradient; it must match the output's dtype for backward.
        gradient = torch.zeros_like(out)
        for i, cls in enumerate(classes):
            gradient[i, cls] = 1
        out.backward(gradient=gradient)
        return out