Source code for olympus.datasets.fashionmnist

from filelock import FileLock
import torch
from torchvision import datasets, transforms

from olympus.datasets.dataset import AllDataset
from olympus.utils import option


[docs]class FashionMNIST(AllDataset): """Fashion-MNIST, is a dataset comprising of 28x28 grayscale images of 70,000 fashion products from 10 categories, with 7,000 images per category. The training set has 60,000 images and the test set has 10,000 images. Fashion-MNIST is intended to serve as a direct drop-in replacement for the original MNIST dataset for benchmarking machine learning algorithms, as it shares the same image size, data format and the structure of training and testing splits. More on `arxiv <https://arxiv.org/abs/1708.07747>`_. The full specification can be found at `here <https://github.com/zalandoresearch/fashion-mnist>`_. See also :class:`.BalancedEMNIST` and :class:`.MNIST` Attributes ---------- classes: List[int] Return the mapping between samples index and their class input_shape: (28, 28) Size of a sample stored in this dataset output_shape: (10,) The classes are (T-shirt, Trouser, Pullover, Dress, Coat, Sandals, Shirt, Sneaker, Bag, Ankle Boot) train_size: 50000 Size of the train dataset valid_size: 10000 Size of the validation dataset test_size: 10000 Size of the test dataset References ---------- .. [1] Han Xiao, Kashif Rasul, Roland Vollgraf. "Fashion-MNIST: a Novel Image Dataset for Benchmarking Machine Learning Algorithms" Aug 2017 """ def __init__(self, data_path): with FileLock('FashionMNIST.lock', timeout=option('download.lock.timeout', 4 * 60, type=int)): train_dataset = datasets.FashionMNIST( data_path, train=True, download=True, transform=transforms.ToTensor() ) with FileLock('FashionMNIST.lock', timeout=option('download.lock.timeout', 4 * 60, type=int)): test_dataset = datasets.FashionMNIST( data_path, train=False, download=True, transform=transforms.ToTensor() ) super(FashionMNIST, self).__init__( torch.utils.data.ConcatDataset([train_dataset, test_dataset]), test_size=len(test_dataset) )
[docs] @staticmethod def categories(): return set(['classification'])
builders = { 'fashion_mnist': FashionMNIST}