Skip to content

Commit 904a76e

Browse files
committed
Add multiprocess environment
1 parent aab8cd0 commit 904a76e

File tree

2 files changed

+138
-0
lines changed

2 files changed

+138
-0
lines changed

src/spaces/multiprocess_env.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
"""Module multiprocess_env. Specifies
2+
a vectorized environment where each instance
3+
of the environment is run independently
4+
5+
"""
6+
7+
import numpy as np
8+
from typing import TypeVar, Callable, Any
9+
import torch.multiprocessing as mp
10+
11+
from src.spaces import TimeStep, VectorTimeStep
12+
from src.parallel import TorchProcsHandler
13+
14+
15+
ActionVector = TypeVar('ActionVector')
16+
17+
18+
class MultiprocessEnv:
    """Vectorized environment whose copies each run in their own worker.

    The manager (this object) owns one pipe per worker. Commands travel
    over the pipe as ``(cmd, kwargs)`` tuples. A worker answers the
    ``'reset'``, ``'step'`` and ``'_past_limit'`` commands with exactly one
    reply; any other command (``'close'`` included) shuts the worker down.
    """

    def __init__(self, env_builder: Callable, env_args: dict, n_workers: int):
        """Store the configuration and create one pipe per worker.

        Parameters
        ----------
        env_builder: Callable invoked in every worker as ``env_builder(env_args)``
        env_args: The argument handed to ``env_builder``
        n_workers: The number of workers (and pipes) to create
        """
        self.env_builder = env_builder
        self.env_args = env_args
        self.n_workers = n_workers
        self.workers = TorchProcsHandler(n_procs=n_workers)
        # Each entry is the pair returned by mp.Pipe(): index 0 is used by
        # the manager, index 1 is handed to the worker in make().
        self.pipes = [mp.Pipe() for _ in range(self.n_workers)]

    def make(self) -> None:
        """Create the workers

        Asks the process handler to create and start one process per rank,
        each running ``work`` with its own end of the matching pipe.

        Returns
        -------
        None
        """
        for rank in range(self.n_workers):
            _, worker_end = self.pipes[rank]
            self.workers.create_process_and_start(
                target=self.work,
                args=(rank, self.env_builder, self.env_args, worker_end),
            )

    def work(self, rank, env_builder: Callable, env_args: dict, pipe_end) -> None:
        """The worker function

        Builds the environment and then serves commands received on
        ``pipe_end`` until a terminating command arrives.

        Parameters
        ----------
        rank: The rank of the worker (not used inside the loop)
        env_builder: The callable that builds the worker environment
        env_args: The callable arguments
        pipe_end: The pipe end this worker receives commands on and
            sends its replies to

        Returns
        -------
        None
        """
        # create the environment
        env = env_builder(env_args)
        while True:

            # receive a new command from the manager in order to execute it
            cmd, kwargs = pipe_end.recv()

            if cmd == 'reset':
                pipe_end.send(env.reset(**kwargs))
            elif cmd == 'step':
                pipe_end.send(env.step(**kwargs))
            elif cmd == '_past_limit':
                pipe_end.send(env._elapsed_steps >= env._max_episode_steps)
            else:
                # Any other command, including 'close', terminates the
                # worker: close the environment, then the pipe end.
                env.close(**kwargs)
                del env
                pipe_end.close()
                break

    def reset(self) -> TimeStep:
        """Reset the environments.

        TODO: not implemented yet. No message is sent to the workers and
        ``None`` is returned.
        """
        pass

    def step(self, actions: ActionVector) -> VectorTimeStep:
        """Step every worker environment with its own action.

        Parameters
        ----------
        actions: Indexable collection holding one action per worker;
            ``actions[rank]`` is sent to the worker with that rank

        Returns
        -------
        A ``VectorTimeStep`` with the workers' replies appended in rank order
        """
        assert len(actions) == self.n_workers, \
            "The number of actions must match the number of workers"

        # Send every action before reading any reply so that the workers
        # can step concurrently.
        for rank in range(self.n_workers):
            self._send_msg(('step', {'action': actions[rank]}), rank)

        # collect the results from all processes
        time_step = VectorTimeStep()
        for rank in range(self.n_workers):
            parent_end, _ = self.pipes[rank]
            time_step.append(parent_end.recv())
        return time_step

    def _close(self, **kwargs) -> None:
        """Broadcast the ``'close'`` command to every worker.

        Parameters
        ----------
        kwargs: Forwarded by each worker to ``env.close``
        """
        self._broadcast_msg(('close', kwargs))

    def _send_msg(self, msg: Any, rank: int) -> None:
        """Send the message to the process with the
        given rank

        Parameters
        ----------
        msg: The message to send
        rank: The rank of the process to send the message to

        Returns
        -------
        None
        """
        parent_end, _ = self.pipes[rank]
        parent_end.send(msg)

    def _broadcast_msg(self, msg) -> None:
        """Send the same message to every worker.

        Parameters
        ----------
        msg: The message to send
        """
        for parent_end, _ in self.pipes:
            parent_end.send(msg)

tests/test_multiprocess_env.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import unittest
2+
import pytest
3+
4+
from src.spaces import MultiprocessEnv
5+
6+
7+
class TestMultiprocessEnv(unittest.TestCase):
    """Test case for ``MultiprocessEnv``."""

    def test_make(self):
        """Placeholder for the ``make`` test; it performs no checks yet."""


# Run the test case when the module is executed as a script.
if __name__ == '__main__':
    unittest.main()

0 commit comments

Comments
 (0)