77
88from __future__ import annotations
99
10+ import base64
11+ import binascii
1012import json
1113import operator
1214import re
2628from airflow .models .connection import Connection
2729
2830
31+ def try_decode_base64 (s : str ) -> str :
32+ """Attempt to decode a string from base64.
33+
34+ If the string is not valid base64, returns the original value.
35+
36+ Args:
37+ s: The string to decode.
38+
39+ Returns:
40+ The decoded string, or the original value if decoding fails.
41+ """
42+ try :
43+ s = base64 .b64decode (s , validate = True ).decode ("utf-8" )
44+ except binascii .Error :
45+ pass
46+ return s
47+
48+
2949class DbtConnectionParam (NamedTuple ):
3050 """A tuple indicating connection parameters relevant to dbt.
3151
@@ -40,6 +60,7 @@ class DbtConnectionParam(NamedTuple):
4060 name : str
4161 store_override_name : Optional [str ] = None
4262 default : Optional [Any ] = None
63+ converter : Callable [[Any ], Any ] | None = None
4364
4465 @property
4566 def override_name (self ):
@@ -230,6 +251,8 @@ def get_dbt_details_from_connection(self, conn: Connection) -> dict[str, Any]:
230251 if isinstance (param , DbtConnectionParam ):
231252 key = param .override_name
232253 value = getattr (conn , param .name , param .default )
254+ if param .converter :
255+ value = param .converter (value )
233256 elif isinstance (param , DbtConnectionConditionParam ):
234257 key , default = param .resolve (conn )
235258 value = getattr (conn , param .name , default )
@@ -399,7 +422,9 @@ class DbtSnowflakeHook(DbtConnectionHook):
399422 "database" ,
400423 DbtConnectionParam ("refresh_token" , "token" ),
401424 DbtConnectionParam ("private_key_file" , "private_key_path" ),
402- DbtConnectionParam ("private_key_content" , "private_key" ),
425+ DbtConnectionParam (
426+ "private_key_content" , "private_key" , converter = try_decode_base64
427+ ),
403428 ]
404429
405430
0 commit comments