44"""
55from collections import deque
66from concurrent .futures import ThreadPoolExecutor
7+ from multiprocessing import Manager
8+ import threading
79from pydatastructs .utils .misc_util import (
810 _comp , raise_if_backend_is_not_python , Backend , AdjacencyListGraphNode )
911from pydatastructs .miscellaneous_data_structures import (
@@ -1407,7 +1409,7 @@ def maximum_matching(graph: Graph, algorithm: str, **kwargs) -> set:
14071409 >>> graph.add_edge('v_2', 'v_3')
14081410 >>> graph.add_edge('v_4', 'v_1')
14091411 >>> maximum_matching(graph, 'hopcroft_karp', make_undirected=True)
1410- >>> {('v_3 ', 'v_2 '), ('v_1 ', 'v_4 ')}
1412+ >>> {('v_1 ', 'v_4 '), ('v_3 ', 'v_2 ')}
14111413
14121414 References
14131415 ==========
@@ -1431,6 +1433,7 @@ def maximum_matching(graph: Graph, algorithm: str, **kwargs) -> set:
14311433 return getattr (algorithms , func )(graph )
14321434
14331435def _maximum_matching_hopcroft_karp_parallel (graph : Graph , num_threads : int ) -> set :
1436+
14341437 U = set ()
14351438 V = set ()
14361439 bipartiteness , coloring = bipartite_coloring (graph )
@@ -1444,20 +1447,22 @@ def _maximum_matching_hopcroft_karp_parallel(graph: Graph, num_threads: int) ->
14441447 else :
14451448 V .add (node )
14461449
1447-
1448- pair_U = {u : None for u in U }
1449- pair_V = {v : None for v in V }
1450- dist = {}
1450+ manager = Manager ()
1451+ pair_U = manager . dict ( {u : None for u in U })
1452+ pair_V = manager . dict ( {v : None for v in V })
1453+ lock = threading . RLock ()
14511454
14521455 def bfs ():
14531456 queue = Queue ()
1457+ dist = {}
14541458 for u in U :
14551459 if pair_U [u ] is None :
14561460 dist [u ] = 0
14571461 queue .append (u )
14581462 else :
14591463 dist [u ] = float ('inf' )
14601464 dist [None ] = float ('inf' )
1465+
14611466 while queue :
14621467 u = queue .popleft ()
14631468 if dist [u ] < dist [None ]:
@@ -1470,36 +1475,77 @@ def bfs():
14701475 elif dist .get (alt , float ('inf' )) == float ('inf' ):
14711476 dist [alt ] = dist [u ] + 1
14721477 queue .append (alt )
1473- return dist .get (None , float ('inf' )) != float ('inf' )
14741478
1475- def dfs (u ):
1479+ return dist , dist .get (None , float ('inf' )) != float ('inf' )
1480+
1481+ def dfs_worker (u , dist , local_pair_U , local_pair_V , thread_results ):
1482+ if dfs (u , dist , local_pair_U , local_pair_V ) and u in local_pair_U and local_pair_U [u ] is not None :
1483+ thread_results .append ((u , local_pair_U [u ]))
1484+ return True
1485+ return False
1486+
1487+ def dfs (u , dist , local_pair_U , local_pair_V ):
14761488 if u is None :
14771489 return True
1490+
14781491 for v in graph .neighbors (u ):
1479- if v .name in pair_V :
1480- alt = pair_V [v .name ]
1492+ if v .name in local_pair_V :
1493+ alt = local_pair_V [v .name ]
14811494 if alt is None :
1482- pair_V [v .name ] = u
1483- pair_U [u ] = v .name
1495+ local_pair_V [v .name ] = u
1496+ local_pair_U [u ] = v .name
14841497 return True
14851498 elif dist .get (alt , float ('inf' )) == dist .get (u , float ('inf' )) + 1 :
1486- if dfs (alt ):
1487- pair_V [v .name ] = u
1488- pair_U [u ] = v .name
1499+ if dfs (alt , dist , local_pair_U , local_pair_V ):
1500+ local_pair_V [v .name ] = u
1501+ local_pair_U [u ] = v .name
14891502 return True
1503+
14901504 dist [u ] = float ('inf' )
14911505 return False
14921506
14931507 matching = set ()
14941508
1495- while bfs ():
1496- unmatched_nodes = [u for u in U if pair_U [u ] is None ]
1509+ while True :
1510+ dist , has_path = bfs ()
1511+ if not has_path :
1512+ break
14971513
1498- with ThreadPoolExecutor (max_workers = num_threads ) as Executor :
1499- results = Executor .map (dfs , unmatched_nodes )
1514+ unmatched = [u for u in U if pair_U [u ] is None ]
1515+ if not unmatched :
1516+ break
1517+
1518+ batch_size = max (1 , len (unmatched ) // num_threads )
1519+ batches = [unmatched [i :i + batch_size ] for i in range (0 , len (unmatched ), batch_size )]
1520+
1521+ for batch in batches :
1522+ all_results = []
1523+
1524+ with ThreadPoolExecutor (max_workers = num_threads ) as executor :
1525+ futures = []
1526+ for u in batch :
1527+ local_pair_U = dict (pair_U )
1528+ local_pair_V = dict (pair_V )
1529+ thread_results = []
15001530
1501- for u , success in zip (unmatched_nodes , results ):
1502- if success and pair_U [u ] is not None :
1531+ futures .append (executor .submit (
1532+ dfs_worker , u , dist .copy (), local_pair_U , local_pair_V , thread_results
1533+ ))
1534+
1535+ for future in futures :
1536+ future .result ()
1537+
1538+ with lock :
1539+ for u in batch :
1540+ if pair_U [u ] is None :
1541+ result = dfs (u , dist .copy (), pair_U , pair_V )
1542+ if result and pair_U [u ] is not None :
1543+ matching .add ((u , pair_U [u ]))
1544+
1545+ with lock :
1546+ matching = set ()
1547+ for u in U :
1548+ if pair_U [u ] is not None :
15031549 matching .add ((u , pair_U [u ]))
15041550
15051551 return matching
@@ -1548,7 +1594,7 @@ def maximum_matching_parallel(graph: Graph, algorithm: str, num_threads: int, **
15481594 >>> graph.add_bidirectional_edge('v_2', 'v_3')
15491595 >>> graph.add_bidirectional_edge('v_4', 'v_1')
15501596 >>> maximum_matching_parallel(graph, 'hopcroft_karp', 1, make_undirected=True)
1551- >>> {('v_3 ', 'v_2 '), ('v_1 ', 'v_4 ')}
1597+ >>> {('v_1 ', 'v_4 '), ('v_3 ', 'v_2 ')}
15521598
15531599 References
15541600 ==========
0 commit comments