33from __future__ import annotations
44
55import json
6+ import operator
67import re
78import warnings
89from abc import ABC , ABCMeta
@@ -34,7 +35,6 @@ class DbtConnectionParam(NamedTuple):
3435 name : str
3536 store_override_name : Optional [str ] = None
3637 default : Optional [Any ] = None
37- depends_on : Callable [[Connection ], bool ] = lambda x : True
3838
3939 @property
4040 def override_name (self ):
@@ -50,6 +50,88 @@ def override_name(self):
5050 return self .store_override_name
5151
5252
53+ class ResolverCondition (NamedTuple ):
54+ """Condition for resolving connection parameters based on extra_dejson.
55+
56+ Attributes:
57+ condition_key: The key in `extra_dejson` to check.
58+ comparison_operator: A function to compare the actual value
59+ with the expected value.
60+ expected: The expected value for the condition to be satisfied.
61+ """
62+
63+ condition_key : str
64+ comparison_operator : Callable [[Any , Any ], bool ]
65+ expected : Any
66+
67+
68+ class ResolverResult (NamedTuple ):
69+ """Result of resolving a connection parameter.
70+
71+ Attributes:
72+ override_name: The name to override the parameter with, if applicable.
73+ default: The default value to use if no value is found.
74+ """
75+
76+ override_name : Optional [str ]
77+ default : Optional [Any ]
78+
79+
80+ def make_extra_dejson_resolver (
81+ * conditions : tuple [ResolverCondition , ResolverResult ],
82+ default : ResolverResult = ResolverResult (None , None ),
83+ ) -> Callable [[Connection ], ResolverResult ]:
84+ """Creates a resolver function for override names and defaults.
85+
86+ Args:
87+ *conditions: A sequence of conditions and their corresponding results.
88+ default: The default result if no condition is met.
89+
90+ Returns:
91+ A function that takes a `Connection` object and returns
92+ the appropriate `ResolverResult`.
93+ """
94+
95+ def extra_dejson_resolver (conn : Connection ) -> ResolverResult :
96+ for (
97+ condition_key ,
98+ comparison_operator ,
99+ expected ,
100+ ), resolver_result in conditions :
101+ if comparison_operator (conn .extra_dejson .get (condition_key ), expected ):
102+ return resolver_result
103+ return default
104+
105+ return extra_dejson_resolver
106+
107+
108+ class DbtConnectionConditionParam (NamedTuple ):
109+ """Connection parameter with dynamic override name and default value.
110+
111+ Attributes:
112+ name: The original name of the parameter.
113+ resolver: A function that resolves the parameter
114+ name and default value based on the connection's `extra_dejson`.
115+ """
116+
117+ name : str
118+ resolver : Callable [[Connection ], ResolverResult ]
119+
120+ def resolve (self , connection : Connection ) -> ResolverResult :
121+ """Resolves the override name and default value for this parameter.
122+
123+ Args:
124+ connection: The Airflow connection object.
125+
126+ Returns:
127+ The resolved override name and default value.
128+ """
129+ override_name , default = self .resolver (connection )
130+ if override_name is None :
131+ return ResolverResult (self .name , default )
132+ return ResolverResult (override_name , default )
133+
134+
53135class DbtConnectionHookMeta (ABCMeta ):
54136 """A hook metaclass to collect all subclasses of DbtConnectionHook."""
55137
@@ -78,15 +160,17 @@ class DbtConnectionHook(BaseHook, ABC, metaclass=DbtConnectionHookMeta):
78160 hook_name = "dbt Hook"
79161 airflow_conn_types : tuple [str , ...] = ()
80162
81- conn_params : list [Union [DbtConnectionParam , str ]] = [
163+ conn_params : list [Union [DbtConnectionParam , DbtConnectionConditionParam , str ]] = [
82164 DbtConnectionParam ("conn_type" , "type" ),
83165 "host" ,
84166 "schema" ,
85167 "login" ,
86168 "password" ,
87169 "port" ,
88170 ]
89- conn_extra_params : list [Union [DbtConnectionParam , str ]] = []
171+ conn_extra_params : list [
172+ Union [DbtConnectionParam , DbtConnectionConditionParam , str ]
173+ ] = []
90174
91175 def __init__ (
92176 self ,
@@ -139,10 +223,11 @@ def get_dbt_details_from_connection(self, conn: Connection) -> dict[str, Any]:
139223 dbt_details = {"type" : self .conn_type }
140224 for param in self .conn_params :
141225 if isinstance (param , DbtConnectionParam ):
142- if not param .depends_on (conn ):
143- continue
144226 key = param .override_name
145227 value = getattr (conn , param .name , param .default )
228+ elif isinstance (param , DbtConnectionConditionParam ):
229+ key , default = param .resolve (conn )
230+ value = getattr (conn , param .name , default )
146231 else :
147232 key = param
148233 value = getattr (conn , key , None )
@@ -159,10 +244,11 @@ def get_dbt_details_from_connection(self, conn: Connection) -> dict[str, Any]:
159244
160245 for param in self .conn_extra_params :
161246 if isinstance (param , DbtConnectionParam ):
162- if not param .depends_on (conn ):
163- continue
164247 key = param .override_name
165248 value = extra .get (param .name , param .default )
249+ elif isinstance (param , DbtConnectionConditionParam ):
250+ key , default = param .resolve (conn )
251+ value = extra .get (param .name , default )
166252 else :
167253 key = param
168254 value = extra .get (key , None )
@@ -220,7 +306,8 @@ def get_dbt_details_from_connection(self, conn: Connection) -> dict[str, Any]:
220306 conn = copy (conn )
221307 extra_dejson = conn .extra_dejson
222308 options = extra_dejson .pop ("options" )
223- # This is to pass options (e.g. `-c search_path=myschema`) to dbt in the required form
309+ # This is to pass options (e.g. `-c search_path=myschema`) to dbt
310+ # in the required form
224311 for k , v in re .findall (r"-c (\w+)=(.*)$" , options ):
225312 extra_dejson [k ] = v
226313 conn .extra = json .dumps (extra_dejson )
@@ -235,11 +322,14 @@ class DbtRedshiftHook(DbtPostgresHook):
235322 airflow_conn_types = (conn_type ,)
236323
237324 conn_extra_params = DbtPostgresHook .conn_extra_params + [
238- "method" ,
239- DbtConnectionParam (
325+ DbtConnectionConditionParam (
240326 "method" ,
241- default = "iam" ,
242- depends_on = lambda x : x .extra_dejson .get ("iam_profile" ) is not None ,
327+ resolver = make_extra_dejson_resolver (
328+ (
329+ ResolverCondition ("iam_profile" , operator .is_not , None ),
330+ ResolverResult (None , "iam" ),
331+ )
332+ ),
243333 ),
244334 "cluster_id" ,
245335 "iam_profile" ,
@@ -262,40 +352,33 @@ class DbtSnowflakeHook(DbtConnectionHook):
262352 conn_params = [
263353 "host" ,
264354 "schema" ,
265- DbtConnectionParam (
266- "login" ,
267- "user" ,
268- depends_on = lambda x : x .extra_dejson .get ("authenticator" , "" ) != "oauth" ,
269- ),
270- DbtConnectionParam (
355+ DbtConnectionConditionParam (
271356 "login" ,
272- "oauth_client_id" ,
273- depends_on = lambda x : x .extra_dejson .get ("authenticator" , "" ) == "oauth" ,
274- ),
275- DbtConnectionParam (
276- "password" ,
277- depends_on = lambda x : not any (
357+ resolver = make_extra_dejson_resolver (
278358 (
279- * (
280- k in x .extra_dejson
281- for k in ("private_key_file" , "private_key_content" )
282- ),
283- x .extra_dejson .get ("authenticator" , "" ) == "oauth" ,
359+ ResolverCondition ("authenticator" , operator .eq , "oauth" ),
360+ ResolverResult ("oauth_client_id" , None ),
284361 ),
362+ default = ResolverResult ("user" , None ),
285363 ),
286364 ),
287- DbtConnectionParam (
365+ DbtConnectionConditionParam (
288366 "password" ,
289- "private_key_passphrase" ,
290- depends_on = lambda x : any (
291- k in x .extra_dejson for k in ("private_key_file" , "private_key_content" )
367+ resolver = make_extra_dejson_resolver (
368+ (
369+ ResolverCondition ("authenticator" , operator .eq , "oauth" ),
370+ ResolverResult ("oauth_client_secret" , None ),
371+ ),
372+ (
373+ ResolverCondition ("private_key_file" , operator .is_not , None ),
374+ ResolverResult ("private_key_passphrase" , None ),
375+ ),
376+ (
377+ ResolverCondition ("private_key_content" , operator .is_not , None ),
378+ ResolverResult ("private_key_passphrase" , None ),
379+ ),
292380 ),
293381 ),
294- DbtConnectionParam (
295- "password" ,
296- "oauth_client_secret" ,
297- depends_on = lambda x : x .extra_dejson .get ("authenticator" , "" ) == "oauth" ,
298- ),
299382 ]
300383 conn_extra_params = [
301384 "warehouse" ,
@@ -327,20 +410,22 @@ class DbtBigQueryHook(DbtConnectionHook):
327410 ]
328411 conn_extra_params = [
329412 DbtConnectionParam ("method" , default = "oauth" ),
330- DbtConnectionParam (
413+ DbtConnectionConditionParam (
331414 "method" ,
332- default = "oauth-secrets" ,
333- depends_on = lambda x : x .extra_dejson .get ("refresh_token" ) is not None ,
334- ),
335- DbtConnectionParam (
336- "method" ,
337- default = "service-account-json" ,
338- depends_on = lambda x : x .extra_dejson .get ("keyfile_dict" ) is not None ,
339- ),
340- DbtConnectionParam (
341- "method" ,
342- default = "service-account" ,
343- depends_on = lambda x : x .extra_dejson .get ("key_path" ) is not None ,
415+ resolver = make_extra_dejson_resolver (
416+ (
417+ ResolverCondition ("refresh_token" , operator .is_not , None ),
418+ ResolverResult (None , "oauth-secrets" ),
419+ ),
420+ (
421+ ResolverCondition ("keyfile_dict" , operator .is_not , None ),
422+ ResolverResult (None , "service-account-json" ),
423+ ),
424+ (
425+ ResolverCondition ("key_path" , operator .is_not , None ),
426+ ResolverResult (None , "service-account" ),
427+ ),
428+ ),
344429 ),
345430 DbtConnectionParam ("key_path" , "keyfile" ),
346431 DbtConnectionParam ("keyfile_dict" , "keyfile_json" ),
@@ -366,14 +451,14 @@ class DbtSparkHook(DbtConnectionHook):
366451 "port" ,
367452 "schema" ,
368453 DbtConnectionParam ("login" , "user" ),
369- DbtConnectionParam (
370- "password" ,
371- depends_on = lambda x : x .extra_dejson .get ("method" , "" ) == "thrift" ,
372- ),
373- DbtConnectionParam (
454+ DbtConnectionConditionParam (
374455 "password" ,
375- "token" ,
376- depends_on = lambda x : x .extra_dejson .get ("method" , "" ) != "thrift" ,
456+ resolver = make_extra_dejson_resolver (
457+ (
458+ ResolverCondition ("method" , operator .ne , "thrift" ),
459+ ResolverResult ("token" , None ),
460+ ),
461+ ),
377462 ),
378463 ]
379464 conn_extra_params = []
0 commit comments