33# pylint: disable=import-outside-toplevel
44
55import importlib .util
6+ import threading
67from collections import defaultdict
78from enum import Enum , unique
89from functools import wraps
@@ -37,6 +38,7 @@ class Engine:
3738 _engine : Optional [EngineEnum ] = None
3839 _initialized_engine : Optional [EngineEnum ] = None
3940 _registry : Dict [EngineLiteral , Dict [str , Callable [..., Any ]]] = defaultdict (dict )
41+ _lock : threading .RLock = threading .RLock ()
4042
4143 @classmethod
4244 def get_installed (cls ) -> EngineEnum :
@@ -64,27 +66,31 @@ def get(cls) -> EngineEnum:
6466 str
6567 The distribution engine configured.
6668 """
67- return cls ._engine if cls ._engine else cls .get_installed ()
69+ with cls ._lock :
70+ return cls ._engine if cls ._engine else cls .get_installed ()
6871
6972 @classmethod
7073 def set (cls , name : EngineLiteral ) -> None :
7174 """Set the distribution engine."""
72- cls ._engine = EngineEnum ._member_map_ [ # type: ignore[assignment] # pylint: disable=protected-access,no-member
73- name .upper ()
74- ]
75+ with cls ._lock :
76+ cls ._engine = EngineEnum ._member_map_ [ # type: ignore[assignment] # pylint: disable=protected-access,no-member
77+ name .upper ()
78+ ]
7579
7680 @classmethod
7781 def dispatch_func (cls , source_func : FunctionType , value : Optional [EngineLiteral ] = None ) -> FunctionType :
7882 """Dispatch a func based on value or the distribution engine and the source function."""
7983 try :
80- return cls ._registry [value or cls .get ().value ][source_func .__name__ ] # type: ignore[return-value]
84+ with cls ._lock :
85+ return cls ._registry [value or cls .get ().value ][source_func .__name__ ] # type: ignore[return-value]
8186 except KeyError :
8287 return getattr (source_func , "_source_func" , source_func )
8388
8489 @classmethod
8590 def register_func (cls , source_func : Callable [..., Any ], destination_func : Callable [..., Any ]) -> Callable [..., Any ]:
8691 """Register a func based on the distribution engine and source function."""
87- cls ._registry [cls .get ().value ][source_func .__name__ ] = destination_func
92+ with cls ._lock :
93+ cls ._registry [cls .get ().value ][source_func .__name__ ] = destination_func
8894 return destination_func
8995
9096 @classmethod
@@ -102,38 +108,42 @@ def wrapper(*args: Any, **kw: Dict[str, Any]) -> Any:
102108 @classmethod
103109 def register (cls , name : Optional [EngineLiteral ] = None ) -> None :
104110 """Register the distribution engine dispatch methods."""
105- engine_name = cast (EngineLiteral , name or cls .get_installed ().value )
106- cls .set (engine_name )
107- cls ._registry .clear ()
111+ with cls ._lock :
112+ engine_name = cast (EngineLiteral , name or cls .get_installed ().value )
113+ cls .set (engine_name )
114+ cls ._registry .clear ()
108115
109- if engine_name == EngineEnum .RAY .value :
110- from awswrangler .distributed .ray ._register import register_ray
116+ if engine_name == EngineEnum .RAY .value :
117+ from awswrangler .distributed .ray ._register import register_ray
111118
112- register_ray ()
119+ register_ray ()
113120
114121 @classmethod
115122 def initialize (cls , name : Optional [EngineLiteral ] = None ) -> None :
116123 """Initialize the distribution engine."""
117- engine_name = cast (EngineLiteral , name or cls .get_installed ().value )
118- if engine_name == EngineEnum .RAY .value :
119- from awswrangler .distributed .ray import initialize_ray
124+ with cls ._lock :
125+ engine_name = cast (EngineLiteral , name or cls .get_installed ().value )
126+ if engine_name == EngineEnum .RAY .value :
127+ from awswrangler .distributed .ray import initialize_ray
120128
121- initialize_ray ()
122- cls .register (engine_name )
123- cls ._initialized_engine = cls .get ()
129+ initialize_ray ()
130+ cls .register (engine_name )
131+ cls ._initialized_engine = cls .get ()
124132
125133 @classmethod
126134 def is_initialized (cls , name : Optional [EngineLiteral ] = None ) -> bool :
127135 """Check if the distribution engine is initialized."""
128- engine_name = cast (EngineLiteral , name or cls .get_installed ().value )
136+ with cls ._lock :
137+ engine_name = cast (EngineLiteral , name or cls .get_installed ().value )
129138
130- return False if not cls ._initialized_engine else cls ._initialized_engine .value == engine_name
139+ return False if not cls ._initialized_engine else cls ._initialized_engine .value == engine_name
131140
132141
133142class MemoryFormat :
134143 """Memory format configuration class."""
135144
136145 _enum : Optional [MemoryFormatEnum ] = None
146+ _lock : threading .RLock = threading .RLock ()
137147
138148 @classmethod
139149 def get_installed (cls ) -> MemoryFormatEnum :
@@ -161,14 +171,16 @@ def get(cls) -> MemoryFormatEnum:
161171 Enum
162172 The memory format configured.
163173 """
164- return cls ._enum if cls ._enum else cls .get_installed ()
174+ with cls ._lock :
175+ return cls ._enum if cls ._enum else cls .get_installed ()
165176
166177 @classmethod
167178 def set (cls , name : EngineLiteral ) -> None :
168179 """Set the memory format."""
169- cls ._enum = MemoryFormatEnum ._member_map_ [name .upper ()] # type: ignore[assignment] # pylint: disable=protected-access,no-member
180+ with cls ._lock :
181+ cls ._enum = MemoryFormatEnum ._member_map_ [name .upper ()] # type: ignore[assignment] # pylint: disable=protected-access,no-member
170182
171- _reload ()
183+ _reload ()
172184
173185
174186def _reload () -> None :
0 commit comments