import logging
import numpy
import torch.nn as nn
from olympus.utils import info
class LeNet(nn.Module):
    """LeNet-style convolutional classifier for a fixed set of input sizes.

    `Paper <http://vision.stanford.edu/cs598_spring07/papers/Lecun98.pdf>`_.

    The network is two ``Conv2d(kernel=5) -> ReLU -> MaxPool2d`` stages
    (20 then 50 output channels) followed by a 500-unit fully connected
    layer with ReLU and a final linear layer with ``num_classes`` outputs.

    Parameters
    ----------
    input_size: sequence of int
        ``(channels, height, width)`` of the input; must be one of the
        supported input sizes listed below.
    num_classes: int or sequence of int
        Number of outputs of the last layer.  Anything that is not an
        ``int`` is reduced to the product of its elements.

    Attributes
    ----------
    input_size: (1, 28, 28), (3, 32, 32), (3, 64, 64)
        Supported input sizes

    Raises
    ------
    ValueError
        If ``input_size`` is not one of the supported input sizes.

    References
    ----------
    .. [1] Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner.
        "Gradient-based learning applied to document recognition."
        Proceedings of the IEEE, 86(11):2278-2324, November 1998.
    """

    # Supported input size -> (log message, pooling kernel/stride, fc1 inputs).
    # fc1 inputs are the 50 conv2 channels times the spatial size left after
    # the second pooling stage (4x4, 5x5 and 5x5 respectively).
    _ARCHITECTURES = {
        (1, 28, 28): ('Using LeNet architecture for MNIST', 2, 50 * 4 * 4),
        (3, 32, 32): ('Using LeNet architecture for CIFAR10/100', 2, 50 * 5 * 5),
        (3, 64, 64): ('Using LeNet architecture for TinyImageNet', 3, 50 * 5 * 5),
    }

    def __init__(self, input_size, num_classes):
        super().__init__()

        if not isinstance(num_classes, int):
            # `numpy.product` was removed in NumPy 2.0; `numpy.prod` is the
            # supported spelling.  Cast so nn.Linear receives a plain int.
            num_classes = int(numpy.prod(num_classes))

        n_channels = input_size[0]
        requested = tuple(input_size)

        # Compare with `==` (rather than hashing) so that any sequence the
        # original equality checks accepted is still accepted.
        for supported, (message, pool, fc1_inputs) in self._ARCHITECTURES.items():
            if requested == supported:
                break
        else:
            raise ValueError(
                'There is no LeNet architecture for an input size {}'.format(input_size))

        info(message)

        # Layers are registered in the same order for every architecture.
        self.conv1 = nn.Conv2d(n_channels, 20, 5, 1)
        self.pool1 = nn.MaxPool2d(pool, pool)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.pool2 = nn.MaxPool2d(pool, pool)
        self.fc1 = nn.Linear(fc1_inputs, 500)
        self.fc2 = nn.Linear(500, num_classes)

    def forward(self, x):
        """Run the network on ``x`` and return the output of ``fc2``.

        ``x`` is assumed to be a batch whose first dimension is the batch
        size; every other dimension is flattened before the linear layers.
        No activation is applied to the returned values.
        """
        out = nn.functional.relu(self.conv1(x))
        out = self.pool1(out)
        out = nn.functional.relu(self.conv2(out))
        out = self.pool2(out)
        # Flatten everything except the leading (batch) dimension.
        out = out.view(out.size(0), -1)
        out = nn.functional.relu(self.fc1(out))
        out = self.fc2(out)
        return out
def build(input_size, output_size):
    """Create a :class:`LeNet` for ``input_size`` with ``output_size`` outputs.

    ``output_size`` is forwarded as the ``num_classes`` argument of
    :class:`LeNet`.
    """
    model = LeNet(input_size=input_size, num_classes=output_size)
    return model
# Maps a model name to the factory function that constructs it.
builders = dict(lenet=build)