Source code for olympus.reinforcement.parallel

# This code is from openai baseline, with slight modifications
# https://github.com/openai/baselines/tree/master/baselines/common/vec_env

from ctypes import Structure, c_double, c_int
from multiprocessing import Process, Pipe, Manager
from multiprocessing.sharedctypes import Value
import time

import numpy as np
import torch

from olympus.utils import debug
from olympus.utils.signals import SignalHandler


class WorkerStat(Structure):
    """C-layout record of counters and timings updated by the worker processes.

    Being a ``ctypes.Structure`` it can be placed in shared memory (see
    ``multiprocessing.sharedctypes.Value``) so that every worker writes to
    the same record.
    """

    _fields_ = [
        # How many times each command was handled
        ('step', c_int),
        ('reset', c_int),
        ('close', c_int),
        # Number of steps that returned ``done``
        ('done', c_int),
        ('get_spaces', c_int),
        # Accumulated wall-clock time (seconds) spent in env.step / env.reset
        ('step_time', c_double),
        ('reset_time', c_double),
        # Number of duplicate states generated;
        # if high this is a huge RED FLAG
        ('duplicates', c_int),
    ]
def worker(remote, parent_remote, env_factory, stat: WorkerStat, unique_set):
    """Entry point of a worker process: own one environment and serve commands.

    The worker first expects an ``('init', args)`` message and builds its
    environment with ``env_factory(*args)``.  It then loops over
    ``(cmd, data)`` messages received on ``remote``:

    * ``'step'``: calls ``env.step(data)``; when the episode is done the
      environment is reset right away and the *reset* observation is sent
      back.  Replies with ``(ob, reward, int(done), info)``.
    * ``'reset'`` / ``'reset_task'``: calls ``env.reset()`` /
      ``env.reset_task()`` and replies with the observation.
    * ``'get_spaces'``: replies with
      ``(env.observation_space, env.action_space)``.
    * ``'close'``: leaves the loop.

    Args:
        remote: worker-side end of the pipe used to talk to the parent.
        parent_remote: parent-side end of the pipe; closed immediately because
            the worker must not keep it open.
        env_factory: callable building the environment.
        stat: statistics record updated while serving commands.
            NOTE(review): when this is a shared ``Value``, ``+=`` is a
            read-modify-write that is not atomic across workers, so the
            counters are best-effort.
        unique_set: dict-like mapping used to remember the observations
            already seen (shared between workers when it is a manager dict).

    Raises:
        NotImplementedError: if an unknown command is received.
    """
    # Local import: only this function needs it
    import hashlib

    parent_remote.close()

    cmd, data = remote.recv()
    assert cmd == 'init', 'first message should be about environment init'
    env = env_factory(*data)

    def check_for_duplicates(state):
        """Check if the state is a duplicate"""
        # `tostring` is a deprecated alias of `tobytes`.  A content digest is
        # used instead of the builtin `hash`, which is salted per interpreter
        # and therefore not comparable between spawned worker processes.
        key = hashlib.blake2b(state.tobytes(), digest_size=16).digest()

        if key not in unique_set:
            unique_set[key] = 1
        else:
            stat.duplicates += 1

    def timed_reset(reset_fn):
        """Call `reset_fn`, account for it in `stat` and return the observation"""
        stat.reset += 1
        start = time.time()
        ob = reset_fn()
        stat.reset_time += time.time() - start
        return ob

    try:
        while True:
            cmd, data = remote.recv()

            if cmd == 'step':
                stat.step += 1
                start = time.time()
                ob, reward, done, info = env.step(data)
                stat.step_time += time.time() - start

                if done:
                    stat.done += 1
                    # The terminal observation is replaced by the first
                    # observation of the next episode
                    ob = timed_reset(env.reset)

                check_for_duplicates(ob)
                remote.send((ob, reward, int(done), info))

            elif cmd == 'reset':
                ob = timed_reset(env.reset)
                check_for_duplicates(ob)
                remote.send(ob)

            elif cmd == 'reset_task':
                ob = timed_reset(env.reset_task)
                check_for_duplicates(ob)
                remote.send(ob)

            elif cmd == 'close':
                stat.close += 1
                debug('closing env')
                break

            elif cmd == 'get_spaces':
                stat.get_spaces += 1
                remote.send((env.observation_space, env.action_space))

            else:
                raise NotImplementedError
    finally:
        # Release the environment and the pipe even when the loop is left
        # through an exception (unknown command, broken pipe, interrupt...)
        debug('cleaning up')
        env.close()
        remote.close()
class VectorStat(Structure):
    """C-layout record of the statistics gathered on the parent side."""

    _fields_ = [
        # Number of vectorized steps performed
        ('step', c_int),
        # Accumulated wall-clock time (seconds) spent in those steps
        ('step_time', c_double),
    ]
class _ParallelEnvironmentCleaner(SignalHandler):
    """Close a parallel environment when the process is asked to stop.

    Each hook below simply forwards to ``close()`` of the wrapped environment.
    NOTE(review): the hooks are presumably invoked by ``SignalHandler`` on
    SIGTERM, SIGINT and interpreter exit -- confirm against its implementation.
    """

    def __init__(self, penv):
        super().__init__()
        # Environment whose `close` is called by every hook
        self.env = penv

    def sigterm(self, signum, frame):
        """Close the environment; `signum` and `frame` are ignored"""
        return self.env.close()

    def sigint(self, signum, frame):
        """Close the environment; `signum` and `frame` are ignored"""
        return self.env.close()

    def atexit(self):
        """Close the environment"""
        return self.env.close()


# Create a manager for everybody
# Because it opens socket and does not close them
_manager = None

# TODO: Manage Seeding HERE
class ParallelEnvironment:
    """A group of environments that are computed in parallel.

    One process is started per environment; commands are sent to the workers
    through pipes and the results are stacked into batched tensors.

    Args:
        num_workers: number of worker processes (one environment each),
            must be at least 1.
        transforms: callable applied to the batched observation tensor, or
            ``None`` to leave observations untouched.
        env_factory: callable building an environment; every worker calls it
            as ``env_factory(*env_args)``.
        *env_args: positional arguments forwarded to ``env_factory``.

    Raises:
        ValueError: if ``num_workers`` is smaller than 1.
    """

    def __init__(self, num_workers, transforms, env_factory, *env_args):
        global _manager

        if num_workers < 1:
            raise ValueError('num_workers must be at least 1')

        # A single manager is shared by every instance (see `_manager`)
        if _manager is None:
            _manager = Manager()

        self.manager = _manager
        self.unique_set = self.manager.dict()
        # Shared between all the workers; shared memory starts zeroed
        self.worker_stat = Value(WorkerStat)
        self.vector_stat = VectorStat()

        self.remotes, work_remotes = zip(*[Pipe() for _ in range(num_workers)])

        self.ps = [
            Process(
                target=worker,
                args=(work_remote, remote, env_factory, self.worker_stat, self.unique_set))
            for work_remote, remote in zip(work_remotes, self.remotes)
        ]

        for p in self.ps:
            p.start()

        # The worker ends of the pipes now belong to the children
        for remote in work_remotes:
            remote.close()

        for remote in self.remotes:
            remote.send(('init', env_args))

        self.waiting = False
        self.closed = False
        self.transforms = transforms

        if self.transforms is None:
            self.transforms = lambda x: x

        self.observation_space, self.action_space = self._get_spaces()

        # do not hang in case of error
        self._cleaner = _ParallelEnvironmentCleaner(self)
        self.dtype = torch.float

    def _get_spaces(self):
        """Ask the first worker for its spaces and return them as shapes.

        The observation shape is the one obtained *after* `transforms`; it is
        measured by pushing a random batch of size 1 through `transforms`.
        """
        self.remotes[0].send(('get_spaces', None))
        observation_space, action_space = self.remotes[0].recv()

        observation_space = self._to_shape(observation_space)
        action_space = self._to_shape(action_space)

        probe = torch.randn((1,) + observation_space)
        observation_space = self.transforms(probe).shape[1:]

        return observation_space, action_space

    def _to_shape(self, space):
        """Convert a space into a shape; tuples and `torch.Size` are returned as is"""
        if isinstance(space, (tuple, torch.Size)):
            return space

        # Imported lazily so gym is only required when a gym space is received
        import gym

        if isinstance(space, gym.spaces.Discrete):
            return space.n,

        return space.shape

    @property
    def state_space(self):
        """Alias of `observation_space`"""
        return self.observation_space

    def step_async(self, actions):
        """Send one action to each worker without waiting for the results"""
        for remote, action in zip(self.remotes, actions):
            remote.send(('step', action))

        self.waiting = True

    def step_wait(self):
        """Collect the results of a previous `step_async`.

        Returns:
            ``(observations, rewards, dones, infos)`` where the first three are
            batched tensors (observations went through `transforms`) and
            ``infos`` is a tuple holding one entry per worker.
        """
        results = [remote.recv() for remote in self.remotes]
        self.waiting = False

        obs, rewards, dones, infos = zip(*results)

        # `_convert` stacks the per-worker values itself
        return (
            self.transforms(self._convert(obs)),
            self._convert(rewards),
            self._convert(dones),
            infos,
        )

    def step(self, actions):
        """Step every environment once and return the batched results.

        See `step_wait` for the returned values.
        """
        assert self.waiting is not True, 'should call step_wait after step_async!'

        self.vector_stat.step += 1
        start = time.time()

        self.step_async(actions)
        result = self.step_wait()

        self.vector_stat.step_time += time.time() - start
        return result

    def report(self):
        """Return the statistics gathered so far as ``{'rl_loader': {...}}``.

        The averages are reported as 0 when no step has been timed yet instead
        of dividing by zero.
        """
        stat = self.worker_stat
        steps = self.vector_stat.step

        avg = self.vector_stat.step_time / steps if steps > 0 else 0.0
        throughput = len(self.remotes) / avg if avg > 0 else 0.0

        stats = {
            'worker_step': stat.step,
            'worker_done': stat.done,
            'worker_reset': stat.reset,
            'worker_step_time': stat.step_time,
            'worker_reset_time': stat.reset_time,
            'worker_duplicates': stat.duplicates,
            'step_time_avg': avg,
            'unique_states': len(self.unique_set),
            'item/sec': throughput,
        }

        return {
            'rl_loader': stats
        }

    def close(self):
        """Stop the workers and release the pipes; later calls do nothing.

        A worker that already died must not prevent the others from being
        closed, which is why pipe errors are ignored here.
        """
        if self.closed:
            return

        if self.waiting:
            # Drain the pending step results so the workers read `close` next
            for remote in self.remotes:
                try:
                    remote.recv()
                except (EOFError, OSError):
                    # worker is gone, nothing to drain
                    pass

            self.waiting = False

        for remote in self.remotes:
            try:
                remote.send(('close', None))
            except OSError:
                # worker is gone (broken pipe), it does not need the message
                pass

        for p in self.ps:
            p.join()
            p.close()

        for remote in self.remotes:
            remote.close()

        self.closed = True

    def _convert(self, x):
        """Stack a sequence of arrays/scalars into a tensor of type `self.dtype`"""
        return torch.from_numpy(np.stack(x)).to(dtype=self.dtype)

    def render(self, mode='human'):
        """Rendering is not supported; does nothing"""
        pass

    def _reset_with(self, command):
        """Send `command` to every worker and return the batched observations"""
        for remote in self.remotes:
            remote.send((command, None))

        observations = [remote.recv() for remote in self.remotes]
        return self.transforms(self._convert(observations))

    def reset(self):
        """Reset every environment and return the batched observations"""
        return self._reset_with('reset')

    def reset_task(self):
        """Call `reset_task` on every environment and return the batched observations"""
        return self._reset_with('reset_task')

    def get_spaces(self):
        """Return ``(observation_space, action_space)``"""
        return self.observation_space, self.action_space

    def __len__(self):
        """Number of environments"""
        return len(self.remotes)