Skip to content

Commit aab8cd0

Browse files
committed
Add parallel module
1 parent 788ea92 commit aab8cd0

File tree

2 files changed

+45
-0
lines changed

2 files changed

+45
-0
lines changed

src/parallel/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from src.parallel.processes_manager import TorchProcsHandler

src/parallel/processes_manager.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
from typing import Callable
2+
3+
import torch.multiprocessing as mp
4+
5+
6+
class TorchProcsHandler(object):
7+
"""The TorchProcsHandler class. Utility
8+
class to handle PyTorch processe
9+
10+
"""
11+
12+
def __init__(self, n_procs: int) -> None:
13+
"""Constructor
14+
15+
Parameters
16+
----------
17+
n_procs: The number of processes to handle
18+
19+
"""
20+
self.n_procs = n_procs
21+
self.processes = []
22+
23+
def create_and_start(self, target: Callable, *args) -> None:
24+
for i in range(self.n_procs):
25+
p = mp.Process(target=target, args=args)
26+
p.start()
27+
self.processes.append(p)
28+
29+
def create_process_and_start(self, target: Callable, args) -> None:
30+
p = mp.Process(target=target, args=args)
31+
p.start()
32+
self.processes.append(p)
33+
34+
def join(self) -> None:
35+
for p in self.processes:
36+
p.join()
37+
38+
def terminate(self) -> None:
39+
for p in self.processes:
40+
p.terminate()
41+
42+
def join_and_terminate(self):
43+
self.join()
44+
self.terminate()

0 commit comments

Comments
 (0)