Source code for olympus.datasets.mnist

import functools

from filelock import FileLock
import torch
from torchvision import datasets, transforms

from olympus.datasets.dataset import AllDataset
from olympus.transforms import minimize
from olympus.utils import option


class MNIST(AllDataset):
    """The MNIST database (Modified National Institute of Standards and Technology database)
    is a large database of handwritten digits that is commonly used for training various
    image processing systems. The database is also widely used for training and testing
    in the field of machine learning.
    More on `wikipedia <https://en.wikipedia.org/wiki/MNIST_database>`_.

    The full specification can be found `here <http://yann.lecun.com/exdb/mnist/>`_.

    See also :class:`.BalancedEMNIST` and :class:`.FashionMNIST`

    Parameters
    ----------
    data_path: str
        Root folder handed to ``torchvision.datasets.MNIST``; the data is
        downloaded there when it is not already present (``download=True``).

    mini: bool
        When true, ``minimize(7)`` is inserted in front of the normalization
        step of the transformation pipeline.

    train_size: int, optional
        Forwarded unchanged to :class:`AllDataset`.

    valid_size: int, optional
        Forwarded unchanged to :class:`AllDataset`.

    test_size: int, optional
        Forwarded to :class:`AllDataset`; when ``None`` it defaults to the
        length of the torchvision test split.

    input_shape: tuple, optional
        Forwarded unchanged to :class:`AllDataset`.

    target_shape: tuple, optional
        Forwarded unchanged to :class:`AllDataset`.

    kwargs
        Accepted but not used by this constructor.

    Attributes
    ----------
    classes: List[int]
        Return the mapping between samples index and their class

    input_shape: (28, 28)
        Size of a sample returned after transformation

    target_shape: (10,)
        The classes are numbers from 0 to 9

    train_size: 50000
        Size of the train dataset

    valid_size: 10000
        Size of the validation dataset

    test_size: 10000
        Size of the test dataset

    References
    ----------
    .. [1] Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner.
        "Gradient-based learning applied to document recognition."
        Proceedings of the IEEE, 86(11):2278-2324, November 1998.

    """

    def __init__(self, data_path, mini=False, train_size=None, valid_size=None, test_size=None,
                 input_shape=None, target_shape=None, **kwargs):
        # Presumably the mean / standard deviation of the MNIST pixels -- TODO confirm.
        transformations = [
            transforms.Normalize((0.1307,), (0.3081,))
        ]

        if mini:
            # Must run before the normalization, hence the insertion at index 0.
            transformations.insert(0, minimize(7))

        transform = transforms.Compose(transformations)

        # Both splits are fetched inside a single critical section: the lock is
        # taken once and the timeout option is only looked up once.
        # NOTE(review): 'mnist.lock' is a relative path, so the lock file lives in
        # the current working directory rather than next to `data_path`; processes
        # started from different directories will not exclude each other.
        lock_timeout = option('download.lock.timeout', 4 * 60, type=int)
        with FileLock('mnist.lock', timeout=lock_timeout):
            train_dataset = datasets.MNIST(
                data_path, train=True, download=True,
                transform=transforms.ToTensor()
            )
            test_dataset = datasets.MNIST(
                data_path, train=False, download=True,
                transform=transforms.ToTensor()
            )

        if test_size is None:
            test_size = len(test_dataset)

        # The train and test splits are concatenated (train first) and handed to
        # AllDataset together with the requested split sizes.
        super(MNIST, self).__init__(
            torch.utils.data.ConcatDataset([train_dataset, test_dataset]),
            test_size=test_size,
            train_size=train_size,
            valid_size=valid_size,
            transforms=transform,
            input_shape=input_shape,
            target_shape=target_shape
        )

    @staticmethod
    def categories():
        """Return the set of category names of this dataset: ``{'classification'}``."""
        return {'classification'}
# Registry of the dataset constructors exposed by this module, keyed by name.
builders = {
    'mnist': MNIST,
    'mini-mnist': functools.partial(MNIST, mini=True),
    # Technically this should be in the sampler logic
    # but we want to test the usual split method
    'test-mnist': functools.partial(
        MNIST,
        train_size=128,
        valid_size=64,
        test_size=64,
        input_shape=(28, 28),
        target_shape=(10,),
    ),
}