|
| 1 | +import logging |
| 2 | +import pandas |
| 3 | +import json |
| 4 | +import boto3 |
| 5 | + |
| 6 | +from rdl.data_sources.ChangeTrackingInfo import ChangeTrackingInfo |
| 7 | +from rdl.data_sources.SourceTableInfo import SourceTableInfo |
| 8 | +from rdl.shared import Providers |
| 9 | +from rdl.shared.Utils import prevent_senstive_data_logging |
| 10 | + |
| 11 | + |
| 12 | +class AWSLambdaDataSource(object): |
| 13 | + # 'aws-lambda://tenant=543_dc2;function=123456789012:function:my-function;' |
| 14 | + CONNECTION_STRING_PREFIX = "aws-lambda://" |
| 15 | + CONNECTION_STRING_GROUP_SEPARATOR = ";" |
| 16 | + CONNECTION_STRING_KEY_VALUE_SEPARATOR = "=" |
| 17 | + |
| 18 | + def __init__(self, connection_string, logger=None): |
| 19 | + self.logger = logger or logging.getLogger(__name__) |
| 20 | + if not AWSLambdaDataSource.can_handle_connection_string(connection_string): |
| 21 | + raise ValueError(connection_string) |
| 22 | + self.connection_string = connection_string |
| 23 | + self.connection_data = dict( |
| 24 | + kv.split(AWSLambdaDataSource.CONNECTION_STRING_KEY_VALUE_SEPARATOR) |
| 25 | + for kv in self.connection_string.lstrip( |
| 26 | + AWSLambdaDataSource.CONNECTION_STRING_PREFIX |
| 27 | + ) |
| 28 | + .rstrip(AWSLambdaDataSource.CONNECTION_STRING_GROUP_SEPARATOR) |
| 29 | + .split(AWSLambdaDataSource.CONNECTION_STRING_GROUP_SEPARATOR) |
| 30 | + ) |
| 31 | + self.aws_lambda_client = boto3.client("lambda") |
| 32 | + |
| 33 | + @staticmethod |
| 34 | + def can_handle_connection_string(connection_string): |
| 35 | + return connection_string.startswith( |
| 36 | + AWSLambdaDataSource.CONNECTION_STRING_PREFIX |
| 37 | + ) and len(connection_string) != len( |
| 38 | + AWSLambdaDataSource.CONNECTION_STRING_PREFIX |
| 39 | + ) |
| 40 | + |
| 41 | + @staticmethod |
| 42 | + def get_connection_string_prefix(): |
| 43 | + return AWSLambdaDataSource.CONNECTION_STRING_PREFIX |
| 44 | + |
| 45 | + def get_table_info(self, table_config, last_known_sync_version): |
| 46 | + column_names, last_sync_version, sync_version, full_refresh_required, data_changed_since_last_sync \ |
| 47 | + = self.__get_table_info(table_config, last_known_sync_version) |
| 48 | + columns_in_database = column_names |
| 49 | + change_tracking_info = ChangeTrackingInfo( |
| 50 | + last_sync_version=last_sync_version, |
| 51 | + sync_version=sync_version, |
| 52 | + force_full_load=full_refresh_required, |
| 53 | + data_changed_since_last_sync=data_changed_since_last_sync, |
| 54 | + ) |
| 55 | + source_table_info = SourceTableInfo(columns_in_database, change_tracking_info) |
| 56 | + return source_table_info |
| 57 | + |
| 58 | + @prevent_senstive_data_logging |
| 59 | + def get_table_data_frame( |
| 60 | + self, |
| 61 | + table_config, |
| 62 | + columns_config, |
| 63 | + batch_config, |
| 64 | + batch_tracker, |
| 65 | + batch_key_tracker, |
| 66 | + full_refresh, |
| 67 | + change_tracking_info, |
| 68 | + ): |
| 69 | + self.logger.debug(f"Starting read data from lambda.. : \n{None}") |
| 70 | + column_names, data = self.__get_table_data( |
| 71 | + table_config, |
| 72 | + batch_config, |
| 73 | + change_tracking_info, |
| 74 | + full_refresh, |
| 75 | + columns_config, |
| 76 | + batch_key_tracker, |
| 77 | + ) |
| 78 | + self.logger.debug(f"Finished read data from lambda.. : \n{None}") |
| 79 | + # should we log size of data extracted? |
| 80 | + data_frame = self.__get_data_frame(data, column_names) |
| 81 | + batch_tracker.extract_completed_successfully(len(data_frame)) |
| 82 | + return data_frame |
| 83 | + |
| 84 | + def __get_table_info(self, table_config, last_known_sync_version): |
| 85 | + pay_load = { |
| 86 | + "Command": "GetTableInfo", |
| 87 | + "TenantId": int(self.connection_data["tenant"]), |
| 88 | + "Table": {"Schema": table_config["schema"], "Name": table_config["name"]}, |
| 89 | + "CommandPayload": {"LastSyncVersion": last_known_sync_version}, |
| 90 | + } |
| 91 | + |
| 92 | + result = self.__invoke_lambda(pay_load) |
| 93 | + |
| 94 | + return result["ColumnNames"], \ |
| 95 | + result["LastSyncVersion"], \ |
| 96 | + result["CurrentSyncVersion"], \ |
| 97 | + result["FullRefreshRequired"], \ |
| 98 | + result["DataChangedSinceLastSync"] |
| 99 | + |
| 100 | + def __get_table_data( |
| 101 | + self, |
| 102 | + table_config, |
| 103 | + batch_config, |
| 104 | + change_tracking_info, |
| 105 | + full_refresh, |
| 106 | + columns_config, |
| 107 | + batch_key_tracker, |
| 108 | + ): |
| 109 | + pay_load = { |
| 110 | + "Command": "GetTableData", |
| 111 | + "TenantId": int(self.connection_data["tenant"]), |
| 112 | + "Table": {"Schema": table_config["schema"], "Name": table_config["name"]}, |
| 113 | + "CommandPayload": { |
| 114 | + "AuditColumnNameForChangeVersion": Providers.AuditColumnsNames.CHANGE_VERSION, |
| 115 | + "AuditColumnNameForDeletionFlag": Providers.AuditColumnsNames.IS_DELETED, |
| 116 | + "BatchSize": batch_config["size"], |
| 117 | + "LastSyncVersion": change_tracking_info.last_sync_version, |
| 118 | + "FullRefresh": full_refresh, |
| 119 | + "ColumnNames": list(map(lambda cfg: cfg['source_name'], columns_config)), |
| 120 | + "PrimaryKeyColumnNames": table_config["primary_keys"], |
| 121 | + "LastBatchPrimaryKeys": [ |
| 122 | + {"Key": k, "Value": v} for k, v in batch_key_tracker.bookmarks.items() |
| 123 | + ], |
| 124 | + }, |
| 125 | + } |
| 126 | + |
| 127 | + result = self.__invoke_lambda(pay_load) |
| 128 | + |
| 129 | + return result["ColumnNames"], result["Data"] |
| 130 | + |
| 131 | + def __get_data_frame(self, data: [[]], column_names: []): |
| 132 | + return pandas.DataFrame(data=data, columns=column_names) |
| 133 | + |
| 134 | + def __invoke_lambda(self, pay_load): |
| 135 | + self.logger.debug('\nRequest being sent to Lambda:') |
| 136 | + self.logger.debug(pay_load) |
| 137 | + |
| 138 | + lambda_response = self.aws_lambda_client.invoke( |
| 139 | + FunctionName=self.connection_data["function"], |
| 140 | + InvocationType="RequestResponse", |
| 141 | + LogType="None", # |'Tail', Set to Tail to include the execution log in the response |
| 142 | + Payload=json.dumps(pay_load).encode(), |
| 143 | + ) |
| 144 | + |
| 145 | + response_status_code = int(lambda_response['StatusCode']) |
| 146 | + response_function_error = lambda_response.get("FunctionError") |
| 147 | + self.logger.debug('\nResponse received from Lambda:') |
| 148 | + self.logger.debug(f'Response - StatusCode = "{response_status_code}"') |
| 149 | + self.logger.debug(f'Response - FunctionError = "{response_function_error}"') |
| 150 | + |
| 151 | + response_payload = json.loads(lambda_response['Payload'].read()) |
| 152 | + |
| 153 | + if response_status_code != 200 or response_function_error: |
| 154 | + self.logger.error(F'Error in response from aws lambda {self.connection_data["function"]}') |
| 155 | + self.logger.error(f'Response - Status Code = {response_status_code}') |
| 156 | + self.logger.error(f'Response - Error Function = {response_function_error}') |
| 157 | + self.logger.error(f'Response - Error Details:') |
| 158 | + # the below is risky as it may contain actual data if this line is reached in case of a successful result |
| 159 | + # however, the same Payload field is used to return actual error details in case of real errors |
| 160 | + # i.e. StatusCode is 200 (since AWS could invoke the lambda) |
| 161 | + # BUT the lambda barfed with an error and therefore the FunctionError would not be None |
| 162 | + self.logger.error(response_payload) |
| 163 | + raise Exception('Error received when invoking AWS Lambda. See logs for further details.') |
| 164 | + |
| 165 | + return response_payload |
0 commit comments