|
| 1 | +"""Module multiprocess_env. Specifies |
| 2 | +a vectorized environment where each instance |
| 3 | +of the environment is run independently |
| 4 | +
|
| 5 | +""" |
| 6 | + |
| 7 | +import numpy as np |
| 8 | +from typing import TypeVar, Callable, Any |
| 9 | +import torch.multiprocessing as mp |
| 10 | + |
| 11 | +from src.spaces import TimeStep, VectorTimeStep |
| 12 | +from src.parallel import TorchProcsHandler |
| 13 | + |
| 14 | + |
# Placeholder type for the batch of actions accepted by
# ``MultiprocessEnv.step``. The only operations applied to it are
# ``len()`` and integer indexing by worker rank.
ActionVector = TypeVar('ActionVector')


class MultiprocessEnv(object):
    """Vectorized environment with one environment copy per worker process.

    The manager (this object) talks to each worker over its own pipe using
    ``(cmd, kwargs)`` tuples. Workers understand the commands ``'reset'``,
    ``'step'`` and ``'_past_limit'``; any other command (including
    ``'close'``) shuts the worker down.
    """

    def __init__(self, env_builder: Callable, env_args: dict, n_workers: int):
        """Store the configuration and create the communication pipes.

        No process is started here; call ``make`` for that.

        Parameters
        ----------
        env_builder: Callable invoked in every worker as
            ``env_builder(env_args)`` to build that worker's environment
        env_args: The single argument passed to ``env_builder``
        n_workers: Number of worker processes (and pipes) to use
        """
        self.env_builder = env_builder
        self.env_args = env_args
        self.n_workers = n_workers
        self.workers = TorchProcsHandler(n_procs=n_workers)
        # One pipe per worker: index 0 is the end used by the manager,
        # index 1 is the end handed to the worker in ``make``.
        self.pipes = [mp.Pipe() for _ in range(self.n_workers)]

    def make(self):
        """Create and start one worker process per rank.

        Every worker runs ``work`` with its rank, the environment builder,
        the builder arguments and the worker end of its pipe.

        Returns
        -------
        None
        """
        for rank in range(self.n_workers):
            _, worker_end = self.pipes[rank]
            self.workers.create_process_and_start(
                target=self.work,
                args=(rank, self.env_builder, self.env_args, worker_end))

    def work(self, rank, env_builder: Callable, env_args: dict, pipe_end) -> None:
        """The worker function

        Builds the environment and then serves commands received on
        ``pipe_end`` until asked to stop.

        Parameters
        ----------
        rank: The rank of the worker (not used by the loop itself)
        env_builder: The callable that builds the worker environment
        env_args: The callable arguments
        pipe_end: The worker end of the pipe shared with the manager

        Returns
        -------
        None
        """

        # create the environment
        env = env_builder(env_args)
        try:
            while True:
                # receive a new cmd from the manager in order to execute it
                try:
                    cmd, kwargs = pipe_end.recv()
                except EOFError:
                    # The manager end is gone, so no further command can
                    # arrive: release the environment instead of crashing.
                    env.close()
                    break

                if cmd == 'reset':
                    pipe_end.send(env.reset(**kwargs))
                elif cmd == 'step':
                    pipe_end.send(env.step(**kwargs))
                elif cmd == '_past_limit':
                    pipe_end.send(env._elapsed_steps >= env._max_episode_steps)
                else:
                    # any other command, including 'close', stops the worker
                    env.close(**kwargs)
                    break
        finally:
            # Always release the worker end, also when the environment
            # raised while serving a command.
            pipe_end.close()

    def reset(self) -> TimeStep:
        """Reset the environments (not implemented yet).

        Currently this does nothing and returns ``None``.

        TODO: the workers already understand the ``'reset'`` command, so
        this presumably should broadcast it and gather the replies.
        """
        pass

    def step(self, actions: ActionVector) -> VectorTimeStep:
        """Step every worker environment with its own action.

        ``actions[rank]`` is sent to the worker with that rank, then one
        reply per worker is collected in rank order. The call blocks until
        every worker has answered, so the workers must have been started
        with ``make`` beforehand.

        Parameters
        ----------
        actions: One action per worker, indexable by worker rank

        Returns
        -------
        A ``VectorTimeStep`` holding the reply of every worker, in rank order

        Raises
        ------
        ValueError
            If the number of actions differs from the number of workers
        """
        # Explicit check rather than ``assert`` so that it is not stripped
        # under ``python -O``; done before anything is sent so a bad call
        # leaves no worker with a pending command.
        if len(actions) != self.n_workers:
            raise ValueError(
                f"Expected {self.n_workers} actions (one per worker), "
                f"got {len(actions)}")

        # send the messages to the workers
        for rank in range(self.n_workers):
            self._send_msg(('step', {'action': actions[rank]}), rank)

        # collect the results from all processes
        time_step = VectorTimeStep()
        for rank in range(self.n_workers):
            parent_end, _ = self.pipes[rank]
            time_step.append(parent_end.recv())
        return time_step

    def _close(self, **kwargs):
        """Broadcast the ``'close'`` command to every worker.

        Parameters
        ----------
        kwargs: Forwarded by each worker to its environment's ``close``
        """
        self._broadcast_msg(('close', kwargs))

    def _send_msg(self, msg: Any, rank: int):
        """Send the message to the process with the
        given rank

        Parameters
        ----------
        msg: The message to send
        rank: The rank of the process to send the message to

        Returns
        -------
        None
        """
        parent_end, _ = self.pipes[rank]
        parent_end.send(msg)

    def _broadcast_msg(self, msg):
        """Send the same message to every worker.

        Parameters
        ----------
        msg: The message to send
        """
        for parent_end, _ in self.pipes:
            parent_end.send(msg)
0 commit comments