Source code for olympus.datasets.split.balanced_classes

from collections import OrderedDict
from dataclasses import dataclass

import numpy


def balanced_random_indices(method, classes, n_points, seed, split_ratio=0.1, **kwargs):
    """Sample a class-balanced train/valid/test split of dataset indices.

    The same number of points is drawn from every class; within each class,
    ``method`` decides which of that class's indices go to each subset.

    Args:
        method: Callable invoked once per class as
            ``method(rng, indices, n_train, n_valid, n_test, **kwargs)``. Its
            return value must be indexable by ``'train'``, ``'valid'`` and
            ``'test'``, each value being something a list can be extended with.
        classes: Sequence with one entry per class; each entry holds the
            indices available for that class.
        n_points: Total number of points to sample across all classes. Must be
            a multiple of ``len(classes)``.
        seed: Seed for ``numpy.random.RandomState`` (converted with ``int``).
        split_ratio: Fraction of each class's points placed in the test set;
            the validation set receives the same count. Defaults to 0.1.
        **kwargs: Forwarded unchanged to ``method``.

    Returns:
        Split: ``train``, ``valid`` and ``test`` as numpy arrays of indices,
        shuffled so they are not grouped by class.

    Raises:
        AssertionError: if ``classes`` is empty, ``n_points`` is not a
            multiple of the number of classes, any class has fewer than
            ``n_points // len(classes)`` points, or ``split_ratio`` leaves a
            negative number of training points. These are ``assert``
            statements, so they are skipped under ``python -O``.
    """
    assert len(classes) > 0, "classes must contain at least one class"
    assert n_points % len(classes) == 0, "n_points is not a multiple of number of classes"
    n_points_per_class = n_points // len(classes)
    # Every class must be able to supply its share, not just the first one.
    assert n_points_per_class <= min(len(indices) for indices in classes), \
        "n_points greater than nb of points available"

    # Round away floating-point noise before taking the ceiling: e.g.
    # 30 * 0.1 evaluates to 3.0000000000000004, which would ceil to 4.
    n_test_per_class = int(numpy.ceil(round(n_points_per_class * split_ratio, 9)))
    n_valid_per_class = n_test_per_class
    n_train_per_class = n_points_per_class - n_test_per_class - n_valid_per_class
    # valid and test each take ceil(ratio * n) points, so a large split_ratio
    # can leave nothing (or less than nothing) for training.
    assert n_train_per_class >= 0, "split_ratio too large: no points left for the train set"

    rng = numpy.random.RandomState(int(seed))

    sampled_indices = Split(train=[], valid=[], test=[])
    for indices in classes:
        class_sampled_indices = method(
            rng, indices, n_train_per_class, n_valid_per_class, n_test_per_class, **kwargs)
        for set_name in sampled_indices.keys():
            sampled_indices[set_name].extend(class_sampled_indices[set_name])

    # Samples were appended class by class; shuffle so they are not grouped by class.
    for set_name in sampled_indices.keys():
        rng.shuffle(sampled_indices[set_name])
        sampled_indices[set_name] = numpy.array(sampled_indices[set_name])

    return sampled_indices
@dataclass
class Split:
    """Train/valid/test partition of a dataset, stored as sample indices.

    Each field holds indices of samples inside the data set. Besides attribute
    access, the three subsets can be read, written and iterated like a small
    mapping keyed by ``'train'``, ``'valid'`` and ``'test'``.
    """

    train: numpy.ndarray
    valid: numpy.ndarray
    test: numpy.ndarray

    # Names of the subsets, in canonical order. Not annotated, so dataclass
    # does not turn it into a field.
    _KEYS = ('train', 'valid', 'test')

    def items(self):
        """Return ``(name, indices)`` pairs in train, valid, test order."""
        return ('train', self.train), ('valid', self.valid), ('test', self.test)

    def __getitem__(self, item):
        """Return the subset named ``item``.

        Raises:
            KeyError: if ``item`` is not one of the subset names. This keeps
                ``split['items']`` or ``split['__class__']`` from leaking
                arbitrary attributes.
        """
        if item not in self._KEYS:
            raise KeyError(item)
        return getattr(self, item)

    def __setitem__(self, key, value):
        """Replace the subset named ``key``.

        Raises:
            KeyError: if ``key`` is not one of the subset names, instead of
                silently creating a new attribute.
        """
        if key not in self._KEYS:
            raise KeyError(key)
        setattr(self, key, value)

    def __iter__(self):
        """Iterate over subset names, like iterating a dict's keys.

        Defining this also makes ``name in split`` work; otherwise Python
        would fall back to calling ``__getitem__`` with integers.
        """
        return iter(self._KEYS)

    def keys(self):
        """Return the subset names: ``('train', 'valid', 'test')``."""
        return self._KEYS

    def values(self):
        """Return the subsets in train, valid, test order."""
        return self.train, self.valid, self.test