from olympus.utils import warning
from olympus.utils.factory import fetch_factories
# Module-level registry of environment factories, keyed by environment name.
# It is read by `known_environments` and `Environment`, and extended at runtime
# by `register_environment`.
# NOTE(review): presumably `fetch_factories` discovers the backends shipped
# next to this file under `olympus.reinforcement` -- confirm against its docs.
registered_environment = fetch_factories('olympus.reinforcement', __file__)
def known_environments(*category_filters, include_unknown=False):
    """List the names of the registered environments.

    Parameters
    ----------
    *category_filters
        Categories to select. An environment is returned when at least one
        of the filters is found in the value returned by its factory's
        ``categories()``. When no filter is given, every registered name is
        returned.

    include_unknown: bool
        Also return environments whose factory has no ``categories``
        attribute, i.e. for which a match cannot be decided.

    Returns
    -------
    The registry's key view when no filter is given, otherwise a list of
    names. Each name appears at most once, in the order it was first matched.
    """
    if not category_filters:
        return registered_environment.keys()

    matching = []
    # Names already selected; prevents an environment that satisfies several
    # filters (or an untagged one with `include_unknown`) from being listed
    # once per filter.
    selected = set()

    for category in category_filters:
        for name, factory in registered_environment.items():
            if name in selected:
                continue

            if hasattr(factory, 'categories'):
                if category not in factory.categories():
                    continue
            elif not include_unknown:
                # we don't know if it matches because it does not have the
                # categories method, and the caller did not ask for those
                continue

            selected.add(name)
            matching.append(name)

    return matching
def register_environment(name, factory, override=False):
    """Register a new environment backend.

    Parameters
    ----------
    name: str
        Key under which the factory is stored in the registry.

    factory: callable
        Object stored in the registry; `Environment` calls it with keyword
        arguments to build the backend.

    override: bool
        When ``name`` is already registered, replace the existing factory
        instead of keeping it.

    Notes
    -----
    If ``name`` is already registered and ``override`` is false, a warning is
    emitted and the registry is left untouched. The warning is only emitted in
    that case: it tells the caller to pass ``override=True``, which would be
    misleading when the caller already did.
    """
    if name in registered_environment and not override:
        warning(f'{name} was already registered, use override=True to ignore')
        return

    # The registry is mutated in place, so no `global` declaration is needed.
    registered_environment[name] = factory
class RegisteredEnvironmentNotFound(Exception):
    """Raised when the requested environment name has no registered factory."""
class Environment:
    """Generic Reinforcement Learning Environment, can run multiple simulation in parallel

    Thin wrapper that looks up a backend factory in the module-level registry
    and forwards every call and attribute access to the backend it builds.

    Parameters
    ----------
    env_name: str
        Name of the registered environment factory to instantiate.

    transforms, rand_seed, train_size, valid_size, test_size, parallel_env, num_thread, distribution_mode
        Forwarded unchanged, as keyword arguments, to the backend factory.
        Their meaning is defined by the backend.

    Raises
    ------
    RegisteredEnvironmentNotFound
        When ``env_name`` is not present in the registry.
    """

    def __init__(self, env_name, transforms=None, rand_seed=None, train_size=1024, valid_size=128, test_size=128,
                 parallel_env=8, num_thread=4, distribution_mode='easy'):
        env_ctor = registered_environment.get(env_name)

        if env_ctor is None:
            raise RegisteredEnvironmentNotFound(env_name)

        self.env = env_ctor(
            transforms=transforms,
            rand_seed=rand_seed,
            train_size=train_size,
            valid_size=valid_size,
            test_size=test_size,
            parallel_env=parallel_env,
            num_thread=num_thread,
            distribution_mode=distribution_mode)

    def close(self):
        """Close the backend environment and return whatever it returns."""
        return self.env.close()

    @property
    def state_space(self):
        """Alias of `input_size`."""
        return self.input_size

    @property
    def action_space(self):
        """Alias of `target_size`."""
        return self.target_size

    @property
    def input_size(self):
        """Return the size of the samples"""
        return self.env.input_size

    @property
    def target_size(self):
        """Return the size of the target"""
        # NOTE(review): reads the backend's `action_space`, not a
        # `target_size` attribute, unlike `input_size` above -- confirm
        # this asymmetry is intended.
        return self.env.action_space

    @property
    def train(self):
        """Return the backend's `train` attribute."""
        return self.env.train

    @property
    def valid(self):
        """Return the backend's `valid` attribute."""
        return self.env.valid

    @property
    def test(self):
        """Return the backend's `test` attribute."""
        return self.env.test

    def categories(self):
        """Dataset tags so we can filter what we want depending on the task"""
        tags = self.env.categories
        # `known_environments` calls `factory.categories()`, so the backend
        # exposes it as a callable; call it here so the tags themselves are
        # returned rather than the method object. A backend that stores the
        # tags as a plain attribute is still supported.
        if callable(tags):
            return tags()
        return tags

    def state_dict(self):
        """Return the backend's state, as produced by its `state_dict()`."""
        return self.env.state_dict()

    def load_state_dict(self, data):
        """Forward `data` to the backend's `load_state_dict`."""
        return self.env.load_state_dict(data)

    def sample_action(self):
        """Return the result of the backend's `sample_action()`."""
        return self.env.sample_action()

    def max(self):
        """Return the backend's `train` attribute."""
        # NOTE(review): this is identical to the `train` property, which looks
        # like a copy-paste slip; kept as is because the intended value
        # cannot be determined from this file -- confirm with the backends.
        return self.env.train