99import pandas as pd
1010from botocore .config import Config
1111
12- from awswrangler import _data_types , _utils
12+ from awswrangler import _data_types , _utils , exceptions
1313from awswrangler ._distributed import engine
1414from awswrangler ._threading import _get_executor , _ThreadPoolExecutor
1515from awswrangler .distributed .ray import ray_get
@@ -34,75 +34,126 @@ def _df2list(df: pd.DataFrame) -> List[List[Any]]:
3434 return parameters
3535
3636
37+ def _format_timestamp (timestamp : Union [int , datetime ]) -> str :
38+ if isinstance (timestamp , int ):
39+ return str (round (timestamp / 1_000_000 ))
40+ if isinstance (timestamp , datetime ):
41+ return str (round (timestamp .timestamp () * 1_000 ))
42+ raise exceptions .InvalidArgumentType ("`time_col` must be of type timestamp." )
43+
44+
3745def _format_measure (measure_name : str , measure_value : Any , measure_type : str ) -> Dict [str , str ]:
3846 return {
3947 "Name" : measure_name ,
40- "Value" : str ( round ( measure_value . timestamp () * 1_000 ) if measure_type == "TIMESTAMP" else measure_value ),
48+ "Value" : _format_timestamp ( measure_value ) if measure_type == "TIMESTAMP" else str ( measure_value ),
4149 "Type" : measure_type ,
4250 }
4351
4452
53+ def _sanitize_common_attributes (
54+ common_attributes : Optional [Dict [str , Any ]],
55+ version : int ,
56+ measure_name : Optional [str ],
57+ ) -> Dict [str , Any ]:
58+ common_attributes = {} if not common_attributes else common_attributes
59+ # Values in common_attributes take precedence
60+ common_attributes .setdefault ("Version" , version )
61+
62+ if "Time" not in common_attributes :
63+ # TimeUnit is MILLISECONDS by default for Timestream writes
64+ # But if a time_col is supplied (i.e. Time is not in common_attributes)
65+ # then TimeUnit must be set to MILLISECONDS explicitly
66+ common_attributes ["TimeUnit" ] = "MILLISECONDS"
67+
68+ if "MeasureValue" in common_attributes and "MeasureValueType" not in common_attributes :
69+ raise exceptions .InvalidArgumentCombination (
70+ "MeasureValueType must be supplied alongside MeasureValue in common_attributes."
71+ )
72+
73+ if measure_name :
74+ common_attributes .setdefault ("MeasureName" , measure_name )
75+ elif "MeasureName" not in common_attributes :
76+ raise exceptions .InvalidArgumentCombination (
77+ "MeasureName must be supplied with the `measure_name` argument or in common_attributes."
78+ )
79+ return common_attributes
80+
81+
@engine.dispatch_on_engine
def _write_batch(
    timestream_client: Optional["TimestreamWriteClient"],
    database: str,
    table: str,
    common_attributes: Dict[str, Any],
    cols_names: List[Optional[str]],
    measure_cols: List[Optional[str]],
    measure_types: List[str],
    dimensions_cols: List[Optional[str]],
    batch: List[Any],
) -> List[Dict[str, str]]:
    """Convert a batch of rows into Timestream records and write them in one request.

    Anything already present in ``common_attributes`` (``Time``, ``MeasureValue``,
    ``MeasureValues``, ``MeasureValueType``) is not repeated on the individual records.

    Parameters
    ----------
    timestream_client : Optional[TimestreamWriteClient]
        Client to use. When falsy, one is created with ``_utils.client``.
    database : str
        Passed as ``DatabaseName``.
    table : str
        Passed as ``TableName``.
    common_attributes : Dict[str, Any]
        Passed as ``CommonAttributes``.
    cols_names : List[Optional[str]]
        Time column, then measure columns, then dimension columns. Entries that were not
        supplied are ``None``.
    measure_cols : List[Optional[str]]
        Measure column names (``[None]`` when no measure column was supplied).
    measure_types : List[str]
        Timestream types of the measure columns, in the same order.
    dimensions_cols : List[Optional[str]]
        Dimension column names.
    batch : List[Any]
        Rows to write. Each row is assumed to contain only the columns that were actually
        supplied, in ``cols_names`` order - verify against the caller.

    Returns
    -------
    List[Dict[str, str]]
        The ``RejectedRecords`` entries reported by Timestream, or an empty list.
    """
    client_timestream = timestream_client if timestream_client else _utils.client(service_name="timestream-write")
    records: List[Dict[str, Any]] = []
    scalar = bool(len(measure_cols) == 1 and "MeasureValues" not in common_attributes)
    time_loc = 0
    # Measures start right after the time column, or at 0 when there is no time column.
    measure_cols_loc = 1 if cols_names[0] else 0
    # Dimensions start after the measure columns that are really present in the row.
    # Counting only non-None names keeps the offset correct when the time column is
    # absent but several measure columns are given (previously off by one, which
    # silently dropped every dimension).
    dimensions_cols_loc = measure_cols_loc + sum(1 for col in measure_cols if col)
    # Loop-invariant: dimensions are read from the row only if every name is set.
    has_dimensions_cols = all(dimensions_cols)

    for row in batch:
        record: Dict[str, Any] = {}
        if "Time" not in common_attributes:
            record["Time"] = _format_timestamp(row[time_loc])
        if scalar and "MeasureValue" not in common_attributes:
            measure_value = row[measure_cols_loc]
            if pd.isnull(measure_value):
                # A record without a measure value is skipped.
                continue
            record["MeasureValue"] = str(measure_value)
        elif not scalar and "MeasureValues" not in common_attributes:
            record["MeasureValues"] = [
                _format_measure(measure_name, measure_value, measure_value_type)  # type: ignore[arg-type]
                for measure_name, measure_value, measure_value_type in zip(
                    measure_cols, row[measure_cols_loc:dimensions_cols_loc], measure_types
                )
                if not pd.isnull(measure_value)
            ]
            if len(record["MeasureValues"]) == 0:
                # Every measure in the row was null.
                continue
        if "MeasureValueType" not in common_attributes:
            record["MeasureValueType"] = measure_types[0] if scalar else "MULTI"
        # Dimensions can be specified in both common_attributes and the data frame.
        dimensions = (
            [
                {"Name": name, "DimensionValueType": "VARCHAR", "Value": str(value)}
                for name, value in zip(dimensions_cols, row[dimensions_cols_loc:])
            ]
            if has_dimensions_cols
            else []
        )
        if dimensions:
            record["Dimensions"] = dimensions
        if record:
            records.append(record)

    try:
        if records:
            # Presumably try_it retries on the listed exceptions, up to max_num_tries - confirm in _utils.
            _utils.try_it(
                f=client_timestream.write_records,
                ex=(
                    client_timestream.exceptions.ThrottlingException,
                    client_timestream.exceptions.InternalServerException,
                ),
                max_num_tries=5,
                DatabaseName=database,
                TableName=table,
                CommonAttributes=common_attributes,
                Records=records,
            )
    except client_timestream.exceptions.RejectedRecordsException as ex:
        # Rejections are returned to the caller instead of being raised.
        return cast(List[Dict[str, str]], ex.response["RejectedRecords"])
    return []
108159
@@ -113,12 +164,12 @@ def _write_df(
113164 executor : _ThreadPoolExecutor ,
114165 database : str ,
115166 table : str ,
116- cols_names : List [str ],
117- measure_cols_names : List [str ],
167+ common_attributes : Dict [str , Any ],
168+ cols_names : List [Optional [str ]],
169+ measure_cols : List [Optional [str ]],
118170 measure_types : List [str ],
119- version : int ,
171+ dimensions_cols : List [ Optional [ str ]] ,
120172 boto3_session : Optional [boto3 .Session ],
121- measure_name : Optional [str ] = None ,
122173) -> List [Dict [str , str ]]:
123174 timestream_client = _utils .client (
124175 service_name = "timestream-write" ,
@@ -132,12 +183,12 @@ def _write_df(
132183 timestream_client ,
133184 itertools .repeat (database ),
134185 itertools .repeat (table ),
186+ itertools .repeat (common_attributes ),
135187 itertools .repeat (cols_names ),
136- itertools .repeat (measure_cols_names ),
188+ itertools .repeat (measure_cols ),
137189 itertools .repeat (measure_types ),
138- itertools .repeat (version ),
190+ itertools .repeat (dimensions_cols ),
139191 batches ,
140- itertools .repeat (measure_name ),
141192 )
142193
143194
@@ -241,12 +292,13 @@ def write(
241292 df : pd .DataFrame ,
242293 database : str ,
243294 table : str ,
244- time_col : str ,
245- measure_col : Union [str , List [str ]],
246- dimensions_cols : List [str ],
295+ time_col : Optional [ str ] = None ,
296+ measure_col : Union [str , List [Optional [ str ]], None ] = None ,
297+ dimensions_cols : Optional [ List [Optional [ str ]]] = None ,
247298 version : int = 1 ,
248299 use_threads : Union [bool , int ] = True ,
249300 measure_name : Optional [str ] = None ,
301+ common_attributes : Optional [Dict [str , Any ]] = None ,
250302 boto3_session : Optional [boto3 .Session ] = None ,
251303) -> List [Dict [str , str ]]:
252304 """Store a Pandas DataFrame into an Amazon Timestream table.
@@ -259,6 +311,16 @@ def write(
259311 this function will not throw a Python exception.
260312 Instead it will return the rejection information.
261313
314+ Note
315+ ----
316+ Values in `common_attributes` take precedence over all other arguments and data frame values.
317+ Dimension attributes are merged with attributes in record objects.
318+ Example: common_attributes = {"Dimensions": [{"Name": "device_id", "Value": "12345"}], "MeasureValueType": "DOUBLE"}.
319+
320+ Note
321+ ----
322+ If the `time_col` column is supplied it must be of type timestamp. `TimeUnit` is set to MILLISECONDS by default.
323+
262324 Parameters
263325 ----------
264326 df : pandas.DataFrame
@@ -267,11 +329,11 @@ def write(
267329 Amazon Timestream database name.
268330 table : str
269331 Amazon Timestream table name.
270- time_col : str
332+ time_col : Optional[ str]
271333 DataFrame column name to be used as time. MUST be a timestamp column.
272- measure_col : Union[str, List[str]]
334+ measure_col : Union[str, List[str], None ]
273335 DataFrame column name(s) to be used as measure.
274- dimensions_cols : List[str]
336+ dimensions_cols : Optional[ List[str] ]
275337 List of DataFrame column names to be used as dimensions.
276338 version : int
277339 Version number used for upserts.
@@ -283,6 +345,9 @@ def write(
283345 measure_name : Optional[str]
284346 Name that represents the data attribute of the time series.
285347 Overrides ``measure_col`` if specified.
348+ common_attributes : Optional[Dict[str, Any]]
349+ Dictionary of attributes that is shared across all records in the request.
350+ Using common attributes can optimize the cost of writes by reducing the size of request payloads.
286351 boto3_session : boto3.Session(), optional
287352 Boto3 Session. The default boto3 Session will be used if boto3_session receive None.
288353
@@ -329,13 +394,35 @@ def write(
329394 >>> ]
330395
331396 """
332- measure_cols_names = measure_col if isinstance (measure_col , list ) else [measure_col ]
333- _logger .debug ("measure_cols_names: %s" , measure_cols_names )
334- measure_types : List [str ] = _data_types .timestream_type_from_pandas (df .loc [:, measure_cols_names ])
335- _logger .debug ("measure_types: %s" , measure_types )
336- cols_names : List [str ] = [time_col ] + measure_cols_names + dimensions_cols
337- _logger .debug ("cols_names: %s" , cols_names )
338- dfs = _utils .split_pandas_frame (df .loc [:, cols_names ], _utils .ensure_cpu_count (use_threads = use_threads ))
397+ measure_cols = measure_col if isinstance (measure_col , list ) else [measure_col ]
398+ measure_types : List [str ] = (
399+ _data_types .timestream_type_from_pandas (df .loc [:, measure_cols ]) if all (measure_cols ) else []
400+ )
401+ dimensions_cols = dimensions_cols if dimensions_cols else [dimensions_cols ] # type: ignore[list-item]
402+ cols_names : List [Optional [str ]] = [time_col ] + measure_cols + dimensions_cols
403+ measure_name = measure_name if measure_name else measure_cols [0 ]
404+ common_attributes = _sanitize_common_attributes (common_attributes , version , measure_name )
405+
406+ _logger .debug (
407+ "common_attributes: %s\n , cols_names: %s\n , measure_types: %s" ,
408+ common_attributes ,
409+ cols_names ,
410+ measure_types ,
411+ )
412+
413+ # User can supply arguments in one of two ways:
414+ # 1. With the `common_attributes` dictionary which takes precedence
415+ # 2. With data frame columns
416+ # However, the data frame cannot be completely empty.
417+ # So if all values in `cols_names` are None, an exception is raised.
418+ if any (cols_names ):
419+ dfs = _utils .split_pandas_frame (
420+ df .loc [:, [c for c in cols_names if c ]], _utils .ensure_cpu_count (use_threads = use_threads )
421+ )
422+ else :
423+ raise exceptions .InvalidArgumentCombination (
424+ "At least one of `time_col`, `measure_col` or `dimensions_cols` must be specified."
425+ )
339426 _logger .debug ("len(dfs): %s" , len (dfs ))
340427
341428 executor = _get_executor (use_threads = use_threads )
@@ -348,12 +435,12 @@ def write(
348435 executor = executor ,
349436 database = database ,
350437 table = table ,
438+ common_attributes = common_attributes ,
351439 cols_names = cols_names ,
352- measure_cols_names = measure_cols_names ,
440+ measure_cols = measure_cols ,
353441 measure_types = measure_types ,
354- version = version ,
442+ dimensions_cols = dimensions_cols ,
355443 boto3_session = boto3_session ,
356- measure_name = measure_name ,
357444 )
358445 for df in dfs
359446 ]
0 commit comments