Source code for olympus.datasets.split.split

import copy

import numpy

from olympus.datasets.split.balanced_classes import balanced_random_indices, Split


def split_random_indices(rng, indices, n_train, n_valid, n_test, index):
    """Shuffle ``indices`` and carve out the ``index``-th train/valid/test block.

    The shuffled indices are viewed as consecutive blocks of
    ``n_train + n_valid + n_test`` elements; block number ``index`` is cut
    into its train, validation and test parts, in that order.

    Parameters
    ----------
    rng:
        Object exposing an in-place ``shuffle(array)`` method
        (presumably a ``numpy.random.RandomState``; verify against caller).
    indices:
        Sequence of indices to split. It is copied, never modified.
    n_train, n_valid, n_test:
        Number of points in each part of the split.
    index:
        Zero-based number of the block to extract.

    Returns
    -------
    Split
        ``Split(train=..., valid=..., test=...)`` holding the index arrays.

    Raises
    ------
    ValueError
        If ``index`` is negative or the requested block does not fit
        entirely inside ``indices``.
    """
    # numpy.array copies its input by default, so the caller's sequence is
    # left untouched by the in-place shuffle below.
    shuffled = numpy.array(indices)
    rng.shuffle(shuffled)

    n_points = n_train + n_valid + n_test
    start = index * n_points

    # A negative index would make ``start`` negative; slicing would then wrap
    # around the end of the array and, for ``index == -1``, silently produce
    # an empty test set (``shuffled[-n_test:0]``). Reject it explicitly.
    if index < 0 or start + n_points > len(shuffled):
        raise ValueError(
            'Cannot have index `{}` for dataset of size `{}`'.format(
                index, len(shuffled)))

    valid_start = start + n_train
    test_start = valid_start + n_valid
    stop = test_start + n_test

    return Split(
        train=shuffled[start:valid_start],
        valid=shuffled[valid_start:test_start],
        test=shuffled[test_start:stop])
def split(datasets, data_size, seed, ratio, index):
    """Build a split by delegating to ``balanced_random_indices``.

    ``split_random_indices`` is handed over as the ``method`` callback,
    together with the classes of ``datasets``, the requested number of
    points, the seed and the split index.

    Parameters
    ----------
    datasets:
        Object exposing a ``classes`` attribute.
    data_size:
        Forwarded as ``n_points``.
    seed:
        Forwarded as ``seed``.
    ratio:
        Accepted but not used by this function.
    index:
        Forwarded as ``index``.

    Returns
    -------
    Whatever ``balanced_random_indices`` returns.
    """
    # NOTE(review): ``ratio`` is never forwarded -- confirm whether
    # ``balanced_random_indices`` is supposed to receive it.
    class_groups = datasets.classes
    return balanced_random_indices(
        method=split_random_indices,
        classes=class_groups,
        n_points=data_size,
        seed=seed,
        index=index,
    )