import torch
import numpy as np
from olympus.reinforcement.replay import ReplayVector
from olympus.reinforcement.utils import SavedAction
from olympus.utils.cuda import Stream, stream
def to_nchw(states):
    """Reorder a batch of images from NHWC to NCHW layout.

    Parameters
    ----------
    states: Tensor[N, H, W, C]
        Batched images with the channel axis last

    Returns
    -------
    A permuted view of ``states`` with shape (N, C, H, W)
    """
    batch_axis, height_axis, width_axis, channel_axis = 0, 1, 2, 3
    return states.permute(batch_axis, channel_axis, height_axis, width_axis)
class RLTorchIterator:
    """Iterates through environment states, one batched environment step per ``next()``.

    Parameters
    ----------
    environment:
        Environment exposing ``reset()``, ``step(actions)`` and ``close()``
    actor: Union[nn.Module, Callable]
        Called on the current state; must return ``(action, log_prob, entropy)``.
        ``action`` must support ``.cpu().numpy()`` and ``log_prob`` is squeezed on dim 1
    critic: Union[nn.Module, Callable]
        Returns the value of the current state; not called when falsy
    device: Optional[torch.device]
        Device on which states, rewards and dones are created (defaults to CPU)
    max_step: Optional[int]
        If unspecified the Iterator is infinite, else we stop after max_step steps
    no_grad: bool
        If True the actor and the critic run under ``torch.no_grad``

    Returns
    -------
    A dictionary representing the transition from one state to another
    state: Tensor[dtype=float]
        Observation before the action is taken, stacked across parallel
        simulations and cast to float. The layout is whatever the environment
        returns (no permutation happens here; see ``to_nchw``)
    new_state: Tensor[dtype=float]
        Observation after the action is taken, same layout as ``state``
    action:
        The actor's action for each parallel simulation
    log_prob:
        The actor's log-probabilities, squeezed on dim 1
    entropy:
        The actor's entropy output, unchanged
    critic:
        The critic's output, or None when no critic is set
    reward: Tensor[num_parallel, dtype=float]
    done: Tensor[num_parallel, dtype=float]
        Done flags converted to float (0/1 for boolean flags)
    info:
        The environment's info object, passed through unchanged
    """

    def __init__(self, environment, actor, critic, device=None, max_step=None, no_grad=False):
        self.step = 0
        self.max_step = max_step
        self.env = environment
        self.actor = actor
        self.critic = critic
        self.actor_stream = Stream()
        self.critic_stream = Stream()
        self.grad_ctx = torch.no_grad if no_grad else torch.enable_grad
        self.dtype = torch.float
        # Python int: number of `done` flags observed since start
        self.completed_simulations = 0
        self.device = torch.device('cpu') if device is None else device
        # _convert already places the tensor on self.device
        self.state = self._convert(self.env.reset())

    def to(self, device):
        """Move the current state to ``device``; later tensors are created there too.

        The actor and critic are not moved.
        """
        self.device = device
        self.state = self.state.to(device=self.device)
        return self

    def close(self):
        return self.env.close()

    def __iter__(self):
        return self

    def __len__(self):
        return 1

    def _convert(self, x):
        """Stack a per-simulation sequence into a ``self.dtype`` tensor on ``self.device``."""
        if x is None:
            return None
        return torch.from_numpy(
            np.stack(x)
        ).to(dtype=self.dtype, device=self.device)

    @staticmethod
    def _count_finished(done):
        """Return the sum of the ``done`` flags as a Python int (count of finished sims)."""
        return int(np.sum(done))

    def __next__(self):
        if self.max_step is not None and self.step >= self.max_step:
            raise StopIteration

        with self.grad_ctx():
            # Actor and critic are launched on separate streams
            with stream(self.actor_stream):
                action, log_prob, entropy = self.actor(self.state)

            critic = None
            if self.critic:
                with stream(self.critic_stream):
                    critic = self.critic(self.state)

            # The action must be ready before it is copied to host for env.step
            self.actor_stream.synchronize()
            state, rew, done, info = self.env.step(action.cpu().numpy())
            state = self._convert(state)
            self.critic_stream.synchronize()

            transition = {
                'state'    : self.state,            # Tensor[float], observation before action
                'new_state': state,                 # Tensor[float], observation after action
                'action'   : action,                # actor output
                'log_prob' : log_prob.squeeze(1),   # actor output, squeezed on dim 1
                'entropy'  : entropy,               # actor output
                'critic'   : critic,                # critic output or None
                'reward'   : self._convert(rew),    # Tensor[N, float]
                'done'     : self._convert(done),   # Tensor[N, float]
                'info'     : info                   # passed through from env.step
            }

        # Keep the counter a plain int instead of an accumulating device tensor
        self.completed_simulations += self._count_finished(done)
        self.step += 1
        self.state = state
        return transition
class ReplayVectorIterator:
    """Aggregate transitions into a ReplayVector to be used for later.

    Each ``next()`` call pulls up to ``num_steps`` transitions from the wrapped
    iterator (at least one is requested, even if ``num_steps < 1``). It returns
    them as a ``ReplayVector`` of ``SavedAction``. A final chunk may be shorter
    if the wrapped iterator ends mid-chunk. ``StopIteration`` is raised once
    the wrapped iterator yields nothing.
    """

    def __init__(self, iterator: 'RLTorchIterator', num_steps):
        self.iterator = iterator
        self.num_steps = num_steps

    def to(self, device):
        self.iterator.to(device)
        return self

    def __iter__(self):
        return self

    def __len__(self):
        return len(self.iterator)

    def _collect(self):
        """Pull up to ``num_steps`` transitions (always asking for at least one)."""
        transitions = []
        for transition in self.iterator:
            transitions.append(transition)
            if len(transitions) >= self.num_steps:
                break
        return transitions

    def __next__(self):
        transitions = self._collect()
        if not transitions:
            # Wrapped iterator is exhausted (e.g. max_step reached). Returning an
            # empty replay here would make consumers loop forever.
            raise StopIteration

        replay = ReplayVector()
        for transition in transitions:
            replay.append(SavedAction(
                action=transition.get('action'),
                reward=transition.get('reward'),
                log_prob=transition.get('log_prob'),
                entropy=transition.get('entropy'),
                critic=transition.get('critic'),
                mask=(1 - transition.get('done')),
                state=transition.get('state'),
                next_state=transition.get('new_state'),
                info=transition.get('info')
            ))
        return replay

    def close(self):
        return self.iterator.close()

    @property
    def state(self):
        """Return the latest state"""
        return self.iterator.state

    @property
    def completed_simulations(self):
        """Number of completed simulations since start"""
        return self.iterator.completed_simulations
def simple_replay_vector(num_steps):
    """Build a replay constructor that groups transitions ``num_steps`` at a time.

    The returned callable takes an iterator and wraps it in a
    ``ReplayVectorIterator(iterator, num_steps)``; it is suitable as the
    ``replay`` argument of ``RLDataLoader``.
    """
    return lambda iterator: ReplayVectorIterator(iterator, num_steps)
class RLDataLoader:
    """Build replay iterators over the train/valid/test environments of a dataset.

    Parameters
    ----------
    dataset_environment:
        Generic Reinforcement Learning environment exposing ``train``, ``valid``
        and ``test`` environments and a ``close()`` method
    actor: Union[nn.Module, Callable]
        Passed to every ``RLTorchIterator``; returns ``(action, log_prob, entropy)``
    critic: Union[nn.Module, Callable]
        Passed to every ``RLTorchIterator``; returns the value of a state
    replay:
        Replay Vector iterator constructor, called with an ``RLTorchIterator``.
        Defaults to ``simple_replay_vector(1)``
    device: Optional[torch.device]
        Device on which iterator tensors are created; defaults to CPU
    """

    def __init__(self, dataset_environment, actor, critic, replay=None, device=None):
        self.dataset = dataset_environment
        self.replay = replay
        self.actor = actor
        self.critic = critic
        self.device = torch.device('cpu') if device is None else device
        if self.replay is None:
            self.replay = simple_replay_vector(1)

    def _make(self, environment, no_grad):
        """Wrap ``environment`` in an ``RLTorchIterator`` and then in the replay constructor."""
        return self.replay(RLTorchIterator(
            environment,
            self.actor,
            self.critic,
            device=self.device,
            no_grad=no_grad))

    def train(self, no_grad=False):
        return self._make(self.dataset.train, no_grad)

    def valid(self):
        return self._make(self.dataset.valid, True)

    def test(self):
        return self._make(self.dataset.test, True)

    def shutdown(self):
        self.dataset.close()

    def close(self):
        self.shutdown()
if __name__ == '__main__':
    from olympus.reinforcement.procgenenv import ProcgenEnvironment

    num_parallel = 4
    env = ProcgenEnvironment('coinrun', parallel_env=num_parallel)

    def dummy_actor(*args, **kwargs):
        # RLTorchIterator calls ``action.cpu()`` and ``log_prob.squeeze(1)``, so the
        # dummy must return tensors rather than plain lists.
        # NOTE(review): assumes env.sample_action() returns something torch.as_tensor accepts.
        action = torch.as_tensor(env.sample_action())
        log_prob = torch.zeros(num_parallel, 1)
        entropy = torch.zeros(num_parallel)
        return action, log_prob, entropy

    loader = RLDataLoader(
        env,
        replay=simple_replay_vector(num_steps=2),
        actor=dummy_actor,
        critic=lambda x: torch.zeros(num_parallel),
    )

    try:
        train_set = loader.train()
        replay = next(train_set)
        for k, v in replay.to_dict().items():
            if isinstance(v, torch.Tensor):
                print(f'{k:>30} :', v.shape, v.dtype)
            else:
                print(f'{k:>30} :', v)
    finally:
        loader.close()