1010import pandas as pd
1111from botocore .config import Config
1212
13- from awswrangler import _data_types , _utils
13+ from awswrangler import _data_types , _utils , exceptions
1414
1515_logger : logging .Logger = logging .getLogger (__name__ )
1616
@@ -27,61 +27,111 @@ def _df2list(df: pd.DataFrame) -> List[List[Any]]:
2727 return parameters
2828
2929
30+ def _format_timestamp (timestamp : Union [int , datetime ]) -> str :
31+ if isinstance (timestamp , int ):
32+ return str (round (timestamp / 1_000_000 ))
33+ if isinstance (timestamp , datetime ):
34+ return str (round (timestamp .timestamp () * 1_000 ))
35+ raise exceptions .InvalidArgumentType ("`time_col` must be of type timestamp." )
36+
37+
3038def _format_measure (measure_name : str , measure_value : Any , measure_type : str ) -> Dict [str , str ]:
3139 return {
3240 "Name" : measure_name ,
33- "Value" : str ( round ( measure_value . timestamp () * 1_000 ) if measure_type == "TIMESTAMP" else measure_value ),
41+ "Value" : _format_timestamp ( measure_value ) if measure_type == "TIMESTAMP" else str ( measure_value ),
3442 "Type" : measure_type ,
3543 }
3644
3745
46+ def _sanitize_common_attributes (
47+ common_attributes : Optional [Dict [str , Any ]],
48+ version : int ,
49+ measure_name : Optional [str ],
50+ ) -> Dict [str , Any ]:
51+ common_attributes = {} if not common_attributes else common_attributes
52+ # Values in common_attributes take precedence
53+ common_attributes .setdefault ("Version" , version )
54+
55+ if "Time" not in common_attributes :
56+ # TimeUnit is MILLISECONDS by default for Timestream writes
57+ # But if a time_col is supplied (i.e. Time is not in common_attributes)
58+ # then TimeUnit must be set to MILLISECONDS explicitly
59+ common_attributes ["TimeUnit" ] = "MILLISECONDS"
60+
61+ if "MeasureValue" in common_attributes and "MeasureValueType" not in common_attributes :
62+ raise exceptions .InvalidArgumentCombination (
63+ "MeasureValueType must be supplied alongside MeasureValue in common_attributes."
64+ )
65+
66+ if measure_name :
67+ common_attributes .setdefault ("MeasureName" , measure_name )
68+ elif "MeasureName" not in common_attributes :
69+ raise exceptions .InvalidArgumentCombination (
70+ "MeasureName must be supplied with the `measure_name` argument or in common_attributes."
71+ )
72+ return common_attributes
73+
74+
3875def _write_batch (
76+ timestream_client : boto3 .client ,
3977 database : str ,
4078 table : str ,
41- cols_names : List [str ],
42- measure_cols_names : List [str ],
79+ common_attributes : Dict [str , Any ],
80+ cols_names : List [Optional [str ]],
81+ measure_cols : List [Optional [str ]],
4382 measure_types : List [str ],
44- version : int ,
83+ dimension_cols : List [ Optional [ str ]] ,
4584 batch : List [Any ],
46- timestream_client : boto3 .client ,
47- measure_name : Optional [str ] = None ,
4885) -> List [Dict [str , str ]]:
49- try :
50- time_loc = 0
51- measure_cols_loc = 1
52- dimensions_cols_loc = 1 + len (measure_cols_names )
53- records : List [Dict [str , Any ]] = []
54- for rec in batch :
55- record : Dict [str , Any ] = {
56- "Dimensions" : [
57- {"Name" : name , "DimensionValueType" : "VARCHAR" , "Value" : str (value )}
58- for name , value in zip (cols_names [dimensions_cols_loc :], rec [dimensions_cols_loc :])
59- ],
60- "Time" : str (round (rec [time_loc ].timestamp () * 1_000 )),
61- "TimeUnit" : "MILLISECONDS" ,
62- "Version" : version ,
63- }
64- if len (measure_cols_names ) == 1 :
65- measure_value = rec [measure_cols_loc ]
66- if pd .isnull (measure_value ):
67- continue
68- record ["MeasureName" ] = measure_name if measure_name else measure_cols_names [0 ]
69- record ["MeasureValueType" ] = measure_types [0 ]
70- record ["MeasureValue" ] = str (measure_value )
71- else :
72- record ["MeasureName" ] = measure_name if measure_name else measure_cols_names [0 ]
73- record ["MeasureValueType" ] = "MULTI"
74- record ["MeasureValues" ] = [
75- _format_measure (measure_name , measure_value , measure_value_type )
76- for measure_name , measure_value , measure_value_type in zip (
77- measure_cols_names , rec [measure_cols_loc :dimensions_cols_loc ], measure_types
78- )
79- if not pd .isnull (measure_value )
80- ]
81- if len (record ["MeasureValues" ]) == 0 :
82- continue
86+ records : List [Dict [str , Any ]] = []
87+ scalar = bool (len (measure_cols ) == 1 and "MeasureValues" not in common_attributes )
88+ time_loc = 0
89+ measure_cols_loc = 1 if cols_names [0 ] else 0
90+ dimensions_cols_loc = 1 if len (measure_cols ) == 1 else 1 + len (measure_cols )
91+ if all (cols_names ):
92+ # Time and Measures are supplied in the data frame
93+ dimensions_cols_loc = 1 + len (measure_cols )
94+ elif all (v is None for v in cols_names [:2 ]):
95+ # Time and Measures are supplied in common_attributes
96+ dimensions_cols_loc = 0
97+
98+ for row in batch :
99+ record : Dict [str , Any ] = {}
100+ if "Time" not in common_attributes :
101+ record ["Time" ] = _format_timestamp (row [time_loc ])
102+ if scalar and "MeasureValue" not in common_attributes :
103+ measure_value = row [measure_cols_loc ]
104+ if pd .isnull (measure_value ):
105+ continue
106+ record ["MeasureValue" ] = str (measure_value )
107+ elif not scalar and "MeasureValues" not in common_attributes :
108+ record ["MeasureValues" ] = [
109+ _format_measure (measure_name , measure_value , measure_value_type ) # type: ignore[arg-type]
110+ for measure_name , measure_value , measure_value_type in zip (
111+ measure_cols , row [measure_cols_loc :dimensions_cols_loc ], measure_types
112+ )
113+ if not pd .isnull (measure_value )
114+ ]
115+ if len (record ["MeasureValues" ]) == 0 :
116+ continue
117+ if "MeasureValueType" not in common_attributes :
118+ record ["MeasureValueType" ] = measure_types [0 ] if scalar else "MULTI"
119+ # Dimensions can be specified in both common_attributes and the data frame
120+ dimensions = (
121+ [
122+ {"Name" : name , "DimensionValueType" : "VARCHAR" , "Value" : str (value )}
123+ for name , value in zip (dimension_cols , row [dimensions_cols_loc :])
124+ ]
125+ if all (dimension_cols )
126+ else []
127+ )
128+ if dimensions :
129+ record ["Dimensions" ] = dimensions
130+ if record :
83131 records .append (record )
84- if len (records ) > 0 :
132+
133+ try :
134+ if records :
85135 _utils .try_it (
86136 f = timestream_client .write_records ,
87137 ex = (
@@ -91,6 +141,7 @@ def _write_batch(
91141 max_num_tries = 5 ,
92142 DatabaseName = database ,
93143 TableName = table ,
144+ CommonAttributes = common_attributes ,
94145 Records = records ,
95146 )
96147 except timestream_client .exceptions .RejectedRecordsException as ex :
@@ -192,12 +243,13 @@ def write(
192243 df : pd .DataFrame ,
193244 database : str ,
194245 table : str ,
195- time_col : str ,
196- measure_col : Union [str , List [str ]],
197- dimensions_cols : List [str ],
246+ time_col : Optional [ str ] = None ,
247+ measure_col : Union [str , List [Optional [ str ]], None ] = None ,
248+ dimensions_cols : Optional [ List [Optional [ str ]]] = None ,
198249 version : int = 1 ,
199250 num_threads : int = 32 ,
200251 measure_name : Optional [str ] = None ,
252+ common_attributes : Optional [Dict [str , Any ]] = None ,
201253 boto3_session : Optional [boto3 .Session ] = None ,
202254) -> List [Dict [str , str ]]:
203255 """Store a Pandas DataFrame into an Amazon Timestream table.
@@ -206,6 +258,16 @@ def write(
206258 this function will not throw a Python exception.
207259 Instead it will return the rejection information.
208260
261+ Note
262+ ----
263+ Values in `common_attributes` take precedence over all other arguments and data frame values.
264+ Dimension attributes are merged with attributes in record objects.
265+ Example: common_attributes = {"Dimensions": [{"Name": "device_id", "Value": "12345"}], "MeasureValueType": "DOUBLE"}.
266+
267+ Note
268+ ----
269+ If the `time_col` column is supplied it must be of type timestamp. `TimeUnit` is set to MILLISECONDS by default.
270+
209271 Parameters
210272 ----------
211273 df: pandas.DataFrame
@@ -214,18 +276,21 @@ def write(
214276 Amazon Timestream database name.
215277 table : str
216278 Amazon Timestream table name.
217- time_col : str
279+ time_col : Optional[ str]
218280 DataFrame column name to be used as time. MUST be a timestamp column.
219- measure_col : Union[str, List[str]]
281+ measure_col : Union[str, List[str], None ]
220282 DataFrame column name(s) to be used as measure.
221- dimensions_cols : List[str]
283+ dimensions_cols : Optional[ List[str] ]
222284 List of DataFrame column names to be used as dimensions.
223285 version : int
224286 Version number used for upserts.
225287 Documentation https://docs.aws.amazon.com/timestream/latest/developerguide/API_WriteRecords.html.
226288 measure_name : Optional[str]
227289 Name that represents the data attribute of the time series.
228290 Overrides ``measure_col`` if specified.
291+ common_attributes : Optional[Dict[str, Any]]
292+ Dictionary of attributes that is shared across all records in the request.
293+ Using common attributes can optimize the cost of writes by reducing the size of request payloads.
229294 num_threads : int
230295 Number of threads to be used for concurrent writing.
231296 boto3_session : boto3.Session(), optional
@@ -279,30 +344,50 @@ def write(
279344 session = boto3_session ,
280345 botocore_config = Config (read_timeout = 20 , max_pool_connections = 5000 , retries = {"max_attempts" : 10 }),
281346 )
282- measure_cols_names = measure_col if isinstance (measure_col , list ) else [measure_col ]
283- _logger .debug ("measure_cols_names: %s" , measure_cols_names )
284347
285- measure_types : List [str ] = [
286- _data_types .timestream_type_from_pandas (df [[measure_col_name ]]) for measure_col_name in measure_cols_names
348+ measure_cols = measure_col if isinstance (measure_col , list ) else [measure_col ]
349+ measure_types = [
350+ _data_types .timestream_type_from_pandas (df [[measure_col_name ]])
351+ for measure_col_name in measure_cols
352+ if measure_col_name
287353 ]
288- _logger .debug ("measure_types: %s" , measure_types )
289- cols_names : List [str ] = [time_col ] + measure_cols_names + dimensions_cols
290- _logger .debug ("cols_names: %s" , cols_names )
291- batches : List [List [Any ]] = _utils .chunkify (lst = _df2list (df = df [cols_names ]), max_length = 100 )
292- _logger .debug ("len(batches): %s" , len (batches ))
354+ dimensions_cols = dimensions_cols if dimensions_cols else [dimensions_cols ] # type: ignore[list-item]
355+ cols_names : List [Optional [str ]] = [time_col ] + measure_cols + dimensions_cols
356+ measure_name = measure_name if measure_name else measure_cols [0 ]
357+ common_attributes = _sanitize_common_attributes (common_attributes , version , measure_name )
358+
359+ _logger .debug (
360+ "common_attributes: %s\n , cols_names: %s\n , measure_types: %s" ,
361+ common_attributes ,
362+ cols_names ,
363+ measure_types ,
364+ )
365+
366+ # User can supply arguments in one of two ways:
367+ # 1. With the `common_attributes` dictionary which takes precedence
368+ # 2. With data frame columns
369+ # However, the data frame cannot be completely empty.
370+ # So if all values in `cols_names` are None, an exception is raised.
371+ if any (cols_names ):
372+ batches : List [List [Any ]] = _utils .chunkify (lst = _df2list (df = df [[c for c in cols_names if c ]]), max_length = 100 )
373+ else :
374+ raise exceptions .InvalidArgumentCombination (
375+ "At least one of `time_col`, `measure_col` or `dimensions_cols` must be specified."
376+ )
377+
293378 with concurrent .futures .ThreadPoolExecutor (max_workers = num_threads ) as executor :
294379 res : List [List [Any ]] = list (
295380 executor .map (
296381 _write_batch ,
382+ itertools .repeat (timestream_client ),
297383 itertools .repeat (database ),
298384 itertools .repeat (table ),
385+ itertools .repeat (common_attributes ),
299386 itertools .repeat (cols_names ),
300- itertools .repeat (measure_cols_names ),
387+ itertools .repeat (measure_cols ),
301388 itertools .repeat (measure_types ),
302- itertools .repeat (version ),
389+ itertools .repeat (dimensions_cols ),
303390 batches ,
304- itertools .repeat (timestream_client ),
305- itertools .repeat (measure_name ),
306391 )
307392 )
308393 return [item for sublist in res for item in sublist ]
0 commit comments