import numpy
from olympus.utils.log import info
from olympus.datasets.split.balanced_classes import balanced_random_indices, Split
[docs]def bootstrap_random_indices(rng, indices, n_train, n_valid, n_test):
indices = set(indices)
train_indices = rng.choice(list(indices), size=n_train, replace=True)
indices -= set(train_indices)
valid_indices = rng.choice(list(indices), size=n_valid, replace=True)
indices -= set(valid_indices)
test_indices = rng.choice(list(indices), size=n_test, replace=True)
indices -= set(test_indices)
return Split(train=train_indices, valid=valid_indices, test=test_indices)
[docs]def split(datasets, data_size, seed, ratio, index, balanced=True):
if balanced:
info('Using balanced bootstrap')
return balanced_random_indices(
method=bootstrap_random_indices,
classes=datasets.classes,
n_points=data_size, seed=seed,
split_ratio=ratio)
else:
info('Using unbalanced bootstrap')
n_points = len(datasets)
n_test = int(numpy.ceil(n_points * ratio))
n_valid = n_test
n_train = n_points - n_test - n_valid
rng = numpy.random.RandomState(int(seed))
return bootstrap_random_indices(rng, range(n_points), n_train, n_valid, n_test)