Skip to content

Commit 8534dbf

Browse files
authored
fix: Thread-safety improvements (#2186)
* Thread-safety - add locks to engine/memf
* Use reentrant locks
* [skip ci] Types
1 parent 06ac57e commit 8534dbf

File tree

1 file changed

+35
-23
lines changed

1 file changed

+35
-23
lines changed

awswrangler/_distributed.py

Lines changed: 35 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
# pylint: disable=import-outside-toplevel
44

55
import importlib.util
6+
import threading
67
from collections import defaultdict
78
from enum import Enum, unique
89
from functools import wraps
@@ -37,6 +38,7 @@ class Engine:
3738
_engine: Optional[EngineEnum] = None
3839
_initialized_engine: Optional[EngineEnum] = None
3940
_registry: Dict[EngineLiteral, Dict[str, Callable[..., Any]]] = defaultdict(dict)
41+
_lock: threading.RLock = threading.RLock()
4042

4143
@classmethod
4244
def get_installed(cls) -> EngineEnum:
@@ -64,27 +66,31 @@ def get(cls) -> EngineEnum:
6466
str
6567
The distribution engine configured.
6668
"""
67-
return cls._engine if cls._engine else cls.get_installed()
69+
with cls._lock:
70+
return cls._engine if cls._engine else cls.get_installed()
6871

6972
@classmethod
7073
def set(cls, name: EngineLiteral) -> None:
7174
"""Set the distribution engine."""
72-
cls._engine = EngineEnum._member_map_[ # type: ignore[assignment] # pylint: disable=protected-access,no-member
73-
name.upper()
74-
]
75+
with cls._lock:
76+
cls._engine = EngineEnum._member_map_[ # type: ignore[assignment] # pylint: disable=protected-access,no-member
77+
name.upper()
78+
]
7579

7680
@classmethod
7781
def dispatch_func(cls, source_func: FunctionType, value: Optional[EngineLiteral] = None) -> FunctionType:
7882
"""Dispatch a func based on value or the distribution engine and the source function."""
7983
try:
80-
return cls._registry[value or cls.get().value][source_func.__name__] # type: ignore[return-value]
84+
with cls._lock:
85+
return cls._registry[value or cls.get().value][source_func.__name__] # type: ignore[return-value]
8186
except KeyError:
8287
return getattr(source_func, "_source_func", source_func)
8388

8489
@classmethod
8590
def register_func(cls, source_func: Callable[..., Any], destination_func: Callable[..., Any]) -> Callable[..., Any]:
8691
"""Register a func based on the distribution engine and source function."""
87-
cls._registry[cls.get().value][source_func.__name__] = destination_func
92+
with cls._lock:
93+
cls._registry[cls.get().value][source_func.__name__] = destination_func
8894
return destination_func
8995

9096
@classmethod
@@ -102,38 +108,42 @@ def wrapper(*args: Any, **kw: Dict[str, Any]) -> Any:
102108
@classmethod
103109
def register(cls, name: Optional[EngineLiteral] = None) -> None:
104110
"""Register the distribution engine dispatch methods."""
105-
engine_name = cast(EngineLiteral, name or cls.get_installed().value)
106-
cls.set(engine_name)
107-
cls._registry.clear()
111+
with cls._lock:
112+
engine_name = cast(EngineLiteral, name or cls.get_installed().value)
113+
cls.set(engine_name)
114+
cls._registry.clear()
108115

109-
if engine_name == EngineEnum.RAY.value:
110-
from awswrangler.distributed.ray._register import register_ray
116+
if engine_name == EngineEnum.RAY.value:
117+
from awswrangler.distributed.ray._register import register_ray
111118

112-
register_ray()
119+
register_ray()
113120

114121
@classmethod
115122
def initialize(cls, name: Optional[EngineLiteral] = None) -> None:
116123
"""Initialize the distribution engine."""
117-
engine_name = cast(EngineLiteral, name or cls.get_installed().value)
118-
if engine_name == EngineEnum.RAY.value:
119-
from awswrangler.distributed.ray import initialize_ray
124+
with cls._lock:
125+
engine_name = cast(EngineLiteral, name or cls.get_installed().value)
126+
if engine_name == EngineEnum.RAY.value:
127+
from awswrangler.distributed.ray import initialize_ray
120128

121-
initialize_ray()
122-
cls.register(engine_name)
123-
cls._initialized_engine = cls.get()
129+
initialize_ray()
130+
cls.register(engine_name)
131+
cls._initialized_engine = cls.get()
124132

125133
@classmethod
126134
def is_initialized(cls, name: Optional[EngineLiteral] = None) -> bool:
127135
"""Check if the distribution engine is initialized."""
128-
engine_name = cast(EngineLiteral, name or cls.get_installed().value)
136+
with cls._lock:
137+
engine_name = cast(EngineLiteral, name or cls.get_installed().value)
129138

130-
return False if not cls._initialized_engine else cls._initialized_engine.value == engine_name
139+
return False if not cls._initialized_engine else cls._initialized_engine.value == engine_name
131140

132141

133142
class MemoryFormat:
134143
"""Memory format configuration class."""
135144

136145
_enum: Optional[MemoryFormatEnum] = None
146+
_lock: threading.RLock = threading.RLock()
137147

138148
@classmethod
139149
def get_installed(cls) -> MemoryFormatEnum:
@@ -161,14 +171,16 @@ def get(cls) -> MemoryFormatEnum:
161171
Enum
162172
The memory format configured.
163173
"""
164-
return cls._enum if cls._enum else cls.get_installed()
174+
with cls._lock:
175+
return cls._enum if cls._enum else cls.get_installed()
165176

166177
@classmethod
167178
def set(cls, name: EngineLiteral) -> None:
168179
"""Set the memory format."""
169-
cls._enum = MemoryFormatEnum._member_map_[name.upper()] # type: ignore[assignment] # pylint: disable=protected-access,no-member
180+
with cls._lock:
181+
cls._enum = MemoryFormatEnum._member_map_[name.upper()] # type: ignore[assignment] # pylint: disable=protected-access,no-member
170182

171-
_reload()
183+
_reload()
172184

173185

174186
def _reload() -> None:

0 commit comments

Comments
 (0)