Source code for olympus.datasets.sampling

import copy
import time

import numpy

from torch.utils.data.sampler import Sampler


class _BaseSampler(Sampler):
    """Standalone Sampler that does not use a global PRNG.
    This makes the results more reproducible as no outside call can modify the state of this sampler.
    """
    def __init__(self, seed):
        if seed is None:
            seed = time.time()
        self.seed = seed

        self.rand_engine = numpy.random.RandomState(seed)

    def state_dict(self):
        return dict(rng_state=self.rand_engine.get_state())

    def load_state_dict(self, state_dict):
        self.rand_engine.set_state(state_dict['rng_state'])


[docs]class RandomSamplerWithoutReplacement(_BaseSampler): """Permute the list of items""" def __init__(self, data_source, seed): super(RandomSamplerWithoutReplacement, self).__init__(seed) self.data_source = data_source self.draws = list(range(len(self.data_source))) def _generate_draws(self): draws = copy.deepcopy(self.draws) self.rand_engine.shuffle(draws) return draws def __iter__(self): return iter(self._generate_draws()) def __len__(self): return len(self.data_source)
[docs]class RandomSamplerWithReplacement(_BaseSampler): def __init__(self, data_source, num_samples, seed): super(RandomSamplerWithReplacement, self).__init__(seed) self.data_source = data_source self.num_samples = num_samples def __iter__(self): draws = [self.rand_engine.randrange(start=0, stop=len(self)) for _ in self.num_samples] return iter(draws) def __len__(self): return len(self.data_source)
[docs]class SequentialSampler(Sampler): """Pytorch like Sampler""" def __init__(self, data_source, seed=None, replacement=False, num_samples=None): self.seed = seed if num_samples is None: num_samples = len(data_source) self.num_samples = num_samples def __iter__(self): return iter(range(self.num_samples)) def __len__(self): return self.num_samples
[docs] def state_dict(self): return {}
[docs] def load_state_dict(self, state_dict): pass
[docs]class RandomSampler(Sampler): """Pytorch like Sampler""" def __init__(self, data_source, seed=None, replacement=False, num_samples=None): self.sampler = None self.seed = seed if num_samples is None: num_samples = len(data_source) if replacement: self.sampler = RandomSamplerWithReplacement(data_source, num_samples, seed) else: self.sampler = RandomSamplerWithoutReplacement(data_source, seed) def __iter__(self): return iter(self.sampler) def __len__(self): return len(self.sampler)
[docs] def state_dict(self): return self.sampler.state_dict()
[docs] def load_state_dict(self, state_dict): self.sampler.load_state_dict(state_dict)