33import json
44import boto3
55import time
6+ import datetime
67
78from rdl .data_sources .ChangeTrackingInfo import ChangeTrackingInfo
89from rdl .data_sources .SourceTableInfo import SourceTableInfo
1112
1213
1314class AWSLambdaDataSource (object ):
14- # 'aws-lambda://tenant=543_dc2;function=123456789012:function:my-function;'
15+ # 'aws-lambda://tenant=543_dc2;function=123456789012:function:my-function;role=arn:aws:iam::123456789012:role/RoleName; '
1516 CONNECTION_STRING_PREFIX = "aws-lambda://"
1617 CONNECTION_STRING_GROUP_SEPARATOR = ";"
1718 CONNECTION_STRING_KEY_VALUE_SEPARATOR = "="
1819
20+ CONNECTION_DATA_ROLE_KEY = "role"
21+ CONNECTION_DATA_FUNCTION_KEY = "function"
22+ CONNECTION_DATA_TENANT_KEY = "tenant"
23+
24+ AWS_SERVICE_LAMBDA = "lambda"
25+ AWS_SERVICE_S3 = "s3"
26+
1927 def __init__ (self , connection_string , logger = None ):
2028 self .logger = logger or logging .getLogger (__name__ )
29+
2130 if not AWSLambdaDataSource .can_handle_connection_string (connection_string ):
2231 raise ValueError (connection_string )
32+
2333 self .connection_string = connection_string
2434 self .connection_data = dict (
2535 kv .split (AWSLambdaDataSource .CONNECTION_STRING_KEY_VALUE_SEPARATOR )
@@ -29,8 +39,19 @@ def __init__(self, connection_string, logger=None):
2939 .rstrip (AWSLambdaDataSource .CONNECTION_STRING_GROUP_SEPARATOR )
3040 .split (AWSLambdaDataSource .CONNECTION_STRING_GROUP_SEPARATOR )
3141 )
32- self .aws_lambda_client = boto3 .client ("lambda" )
33- self .aws_s3_client = boto3 .client ("s3" )
42+
43+ self .aws_sts_client = boto3 .client ("sts" )
44+ role_credentials = self .__assume_role (
45+ self .connection_data [self .CONNECTION_DATA_ROLE_KEY ],
46+ f"dwp_{ self .connection_data [self .CONNECTION_DATA_TENANT_KEY ]} " ,
47+ )
48+
49+ self .aws_lambda_client = self .__get_aws_client (
50+ self .AWS_SERVICE_LAMBDA , role_credentials
51+ )
52+ self .aws_s3_client = self .__get_aws_client (
53+ self .AWS_SERVICE_S3 , role_credentials
54+ )
3455
3556 @staticmethod
3657 def can_handle_connection_string (connection_string ):
@@ -87,7 +108,7 @@ def get_table_data_frame(
87108 def __get_table_info (self , table_config , last_known_sync_version ):
88109 pay_load = {
89110 "Command" : "GetTableInfo" ,
90- "TenantId" : int (self .connection_data ["tenant" ]),
111+ "TenantId" : int (self .connection_data [self . CONNECTION_DATA_TENANT_KEY ]),
91112 "Table" : {"Schema" : table_config ["schema" ], "Name" : table_config ["name" ]},
92113 "CommandPayload" : {"LastSyncVersion" : last_known_sync_version },
93114 }
@@ -113,7 +134,7 @@ def __get_table_data(
113134 ):
114135 pay_load = {
115136 "Command" : "GetTableData" ,
116- "TenantId" : int (self .connection_data ["tenant" ]),
137+ "TenantId" : int (self .connection_data [self . CONNECTION_DATA_TENANT_KEY ]),
117138 "Table" : {"Schema" : table_config ["schema" ], "Name" : table_config ["name" ]},
118139 "CommandPayload" : {
119140 "AuditColumnNameForChangeVersion" : Providers .AuditColumnsNames .CHANGE_VERSION ,
@@ -125,7 +146,7 @@ def __get_table_data(
125146 {
126147 "Name" : col ["source_name" ],
127148 "DataType" : col ["destination" ]["type" ],
128- "IsPrimaryKey" : col ["destination" ]["primary_key" ]
149+ "IsPrimaryKey" : col ["destination" ]["primary_key" ],
129150 }
130151 for col in columns_config
131152 ],
@@ -148,41 +169,93 @@ def __get_table_data(
def __get_data_frame(self, data: list, column_names: list):
    """Wrap tabular data in a pandas DataFrame.

    Args:
        data: Passed straight through as ``pandas.DataFrame(data=...)``.
            Presumably a list of rows, each row a list of cell values --
            TODO confirm against the Lambda response format.
        column_names: Passed straight through as ``columns=...``; the column
            labels for the resulting frame.

    Returns:
        The ``pandas.DataFrame`` built from ``data`` and ``column_names``.
    """
    # The previous annotations were the list literals ``[[]]`` / ``[]``, which
    # are not types; ``list`` is the closest plain annotation that needs no
    # extra imports.
    return pandas.DataFrame(data=data, columns=column_names)
150171
def __assume_role(self, role_arn, session_name):
    """Assume an IAM role through STS and remember when the session expires.

    Args:
        role_arn: ARN of the role, forwarded as ``RoleArn``.
        session_name: Forwarded as ``RoleSessionName``.

    Returns:
        The ``"Credentials"`` entry of the STS ``assume_role`` response.

    Side effects:
        Stores the credentials' ``"Expiration"`` value on
        ``self.role_session_expiry``.
    """
    self.logger.debug(f"\n Assuming role with ARN: { role_arn } ")

    sts_response = self.aws_sts_client.assume_role(
        RoleSessionName=session_name,
        RoleArn=role_arn,
    )
    credentials = sts_response["Credentials"]

    # Kept on the instance so later calls can tell when a refresh is due.
    self.role_session_expiry = credentials["Expiration"]
    return credentials
184+
def __get_aws_client(self, service, credentials):
    """Create a boto3 client for ``service`` using explicit session credentials.

    Args:
        service: Service name handed to ``boto3.client`` as ``service_name``.
        credentials: Mapping that must contain ``"AccessKeyId"``,
            ``"SecretAccessKey"`` and ``"SessionToken"``.

    Returns:
        Whatever ``boto3.client`` returns for those arguments.

    Raises:
        KeyError: If one of the three credential keys is missing.
    """
    client_options = {
        "service_name": service,
        "aws_access_key_id": credentials["AccessKeyId"],
        "aws_secret_access_key": credentials["SecretAccessKey"],
        "aws_session_token": credentials["SessionToken"],
    }
    return boto3.client(**client_options)
192+
def __refresh_aws_clients_if_expired(self):
    """Re-assume the role and rebuild the AWS clients when the session is stale.

    The session is treated as stale as soon as the current UTC time is within
    five minutes of ``self.role_session_expiry`` -- or already past it. When
    it is stale, the role named in the connection data is assumed again and
    ``self.aws_lambda_client`` / ``self.aws_s3_client`` are replaced with
    clients built from the new credentials. Otherwise nothing happens.

    The earlier condition additionally required ``now < expiry``, so a
    session that had already expired was never refreshed and every later
    call kept using dead credentials.
    """
    # AWS reports the expiry in UTC, so compare against an aware UTC "now".
    # NOTE(review): assumes role_session_expiry is a timezone-aware datetime;
    # comparing it with a naive value would raise TypeError.
    current_datetime = datetime.datetime.now(datetime.timezone.utc)
    refresh_threshold = self.role_session_expiry - datetime.timedelta(minutes=5)

    if current_datetime < refresh_threshold:
        return

    # No trailing whitespace in the session name: STS only accepts
    # alphanumerics and "_+=,.@-" for RoleSessionName.
    role_credentials = self.__assume_role(
        self.connection_data[self.CONNECTION_DATA_ROLE_KEY],
        f"dwp_{self.connection_data[self.CONNECTION_DATA_TENANT_KEY]}",
    )

    self.aws_lambda_client = self.__get_aws_client(
        self.AWS_SERVICE_LAMBDA, role_credentials
    )
    self.aws_s3_client = self.__get_aws_client(
        self.AWS_SERVICE_S3, role_credentials
    )
212+
151213 def __invoke_lambda (self , pay_load ):
152214 max_attempts = Constants .MAX_AWS_LAMBDA_INVOKATION_ATTEMPTS
153215 retry_delay = Constants .AWS_LAMBDA_RETRY_DELAY_SECONDS
154216 response_payload = None
155217
156- for current_attempt in list (range (1 , max_attempts + 1 , 1 )):
218+ for current_attempt in list (range (1 , max_attempts + 1 , 1 )):
219+
220+ self .__refresh_aws_clients_if_expired ()
221+
157222 if current_attempt > 1 :
158- self .logger .debug (f"\n Delaying retry for { (current_attempt - 1 ) ^ retry_delay } seconds" )
223+ self .logger .debug (
224+ f"\n Delaying retry for { (current_attempt - 1 ) ^ retry_delay } seconds"
225+ )
159226 time .sleep ((current_attempt - 1 ) ^ retry_delay )
160227
161- self .logger .debug (f"\n Request being sent to Lambda, attempt { current_attempt } of { max_attempts } :" )
228+ self .logger .debug (
229+ f"\n Request being sent to Lambda, attempt { current_attempt } of { max_attempts } :"
230+ )
162231 self .logger .debug (pay_load )
163232
164233 lambda_response = self .aws_lambda_client .invoke (
165- FunctionName = self .connection_data ["function" ],
234+ FunctionName = self .connection_data [self . CONNECTION_DATA_FUNCTION_KEY ],
166235 InvocationType = "RequestResponse" ,
167236 LogType = "None" , # |'Tail', Set to Tail to include the execution log in the response
168237 Payload = json .dumps (pay_load ).encode (),
169238 )
170239
171240 response_status_code = int (lambda_response ["StatusCode" ])
172241 response_function_error = lambda_response .get ("FunctionError" )
173- self .logger .debug (f"\n Response received from Lambda, attempt { current_attempt } of { max_attempts } :" )
242+ self .logger .debug (
243+ f"\n Response received from Lambda, attempt { current_attempt } of { max_attempts } :"
244+ )
174245 self .logger .debug (f'Response - StatusCode = "{ response_status_code } "' )
175246 self .logger .debug (f'Response - FunctionError = "{ response_function_error } "' )
176247
177248 response_payload = json .loads (lambda_response ["Payload" ].read ())
178249
179250 if response_status_code != 200 or response_function_error :
180251 self .logger .error (
181- f' Error in response from aws lambda \ '{ self .connection_data ["function" ] } \ ' , '
182- f' attempt { current_attempt } of { max_attempts } '
252+ f" Error in response from aws lambda '{ self .connection_data [self . CONNECTION_DATA_FUNCTION_KEY ] } ', "
253+ f" attempt { current_attempt } of { max_attempts } "
183254 )
184255 self .logger .error (f"Response - Status Code = { response_status_code } " )
185- self .logger .error (f"Response - Error Function = { response_function_error } " )
256+ self .logger .error (
257+ f"Response - Error Function = { response_function_error } "
258+ )
186259 self .logger .error (f"Response - Error Details:" )
187260 # the below is risky as it may contain actual data if this line is reached in case of success
188261 # however, the same Payload field is used to return actual error details in case of failure
0 commit comments