Skip to content

Commit 9caa1e2

Browse files
committed
Resolve conflicts with main
2 parents 589c114 + fb1d228 commit 9caa1e2

File tree

2 files changed

+314
-70
lines changed

2 files changed

+314
-70
lines changed

awswrangler/timestream.py

Lines changed: 156 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import pandas as pd
1010
from botocore.config import Config
1111

12-
from awswrangler import _data_types, _utils
12+
from awswrangler import _data_types, _utils, exceptions
1313
from awswrangler._distributed import engine
1414
from awswrangler._threading import _get_executor, _ThreadPoolExecutor
1515
from awswrangler.distributed.ray import ray_get
@@ -34,75 +34,126 @@ def _df2list(df: pd.DataFrame) -> List[List[Any]]:
3434
return parameters
3535

3636

37+
def _format_timestamp(timestamp: Union[int, datetime]) -> str:
38+
if isinstance(timestamp, int):
39+
return str(round(timestamp / 1_000_000))
40+
if isinstance(timestamp, datetime):
41+
return str(round(timestamp.timestamp() * 1_000))
42+
raise exceptions.InvalidArgumentType("`time_col` must be of type timestamp.")
43+
44+
3745
def _format_measure(measure_name: str, measure_value: Any, measure_type: str) -> Dict[str, str]:
3846
return {
3947
"Name": measure_name,
40-
"Value": str(round(measure_value.timestamp() * 1_000) if measure_type == "TIMESTAMP" else measure_value),
48+
"Value": _format_timestamp(measure_value) if measure_type == "TIMESTAMP" else str(measure_value),
4149
"Type": measure_type,
4250
}
4351

4452

53+
def _sanitize_common_attributes(
54+
common_attributes: Optional[Dict[str, Any]],
55+
version: int,
56+
measure_name: Optional[str],
57+
) -> Dict[str, Any]:
58+
common_attributes = {} if not common_attributes else common_attributes
59+
# Values in common_attributes take precedence
60+
common_attributes.setdefault("Version", version)
61+
62+
if "Time" not in common_attributes:
63+
# TimeUnit is MILLISECONDS by default for Timestream writes
64+
# But if a time_col is supplied (i.e. Time is not in common_attributes)
65+
# then TimeUnit must be set to MILLISECONDS explicitly
66+
common_attributes["TimeUnit"] = "MILLISECONDS"
67+
68+
if "MeasureValue" in common_attributes and "MeasureValueType" not in common_attributes:
69+
raise exceptions.InvalidArgumentCombination(
70+
"MeasureValueType must be supplied alongside MeasureValue in common_attributes."
71+
)
72+
73+
if measure_name:
74+
common_attributes.setdefault("MeasureName", measure_name)
75+
elif "MeasureName" not in common_attributes:
76+
raise exceptions.InvalidArgumentCombination(
77+
"MeasureName must be supplied with the `measure_name` argument or in common_attributes."
78+
)
79+
return common_attributes
80+
81+
4582
@engine.dispatch_on_engine
4683
def _write_batch(
4784
timestream_client: Optional["TimestreamWriteClient"],
4885
database: str,
4986
table: str,
50-
cols_names: List[str],
51-
measure_cols_names: List[str],
87+
common_attributes: Dict[str, Any],
88+
cols_names: List[Optional[str]],
89+
measure_cols: List[Optional[str]],
5290
measure_types: List[str],
53-
version: int,
91+
dimensions_cols: List[Optional[str]],
5492
batch: List[Any],
55-
measure_name: Optional[str] = None,
5693
) -> List[Dict[str, str]]:
57-
timestream_client = timestream_client if timestream_client else _utils.client(service_name="timestream-write")
58-
try:
59-
time_loc = 0
60-
measure_cols_loc = 1
61-
dimensions_cols_loc = 1 + len(measure_cols_names)
62-
records: List[Dict[str, Any]] = []
63-
for rec in batch:
64-
record: Dict[str, Any] = {
65-
"Dimensions": [
66-
{"Name": name, "DimensionValueType": "VARCHAR", "Value": str(value)}
67-
for name, value in zip(cols_names[dimensions_cols_loc:], rec[dimensions_cols_loc:])
68-
],
69-
"Time": str(round(rec[time_loc].timestamp() * 1_000)),
70-
"TimeUnit": "MILLISECONDS",
71-
"Version": version,
72-
}
73-
if len(measure_cols_names) == 1:
74-
measure_value = rec[measure_cols_loc]
75-
if pd.isnull(measure_value):
76-
continue
77-
record["MeasureName"] = measure_name if measure_name else measure_cols_names[0]
78-
record["MeasureValueType"] = measure_types[0]
79-
record["MeasureValue"] = str(measure_value)
80-
else:
81-
record["MeasureName"] = measure_name if measure_name else measure_cols_names[0]
82-
record["MeasureValueType"] = "MULTI"
83-
record["MeasureValues"] = [
84-
_format_measure(measure_name, measure_value, measure_value_type)
85-
for measure_name, measure_value, measure_value_type in zip(
86-
measure_cols_names, rec[measure_cols_loc:dimensions_cols_loc], measure_types
87-
)
88-
if not pd.isnull(measure_value)
89-
]
90-
if len(record["MeasureValues"]) == 0:
91-
continue
94+
client_timestream = timestream_client if timestream_client else _utils.client(service_name="timestream-write")
95+
records: List[Dict[str, Any]] = []
96+
scalar = bool(len(measure_cols) == 1 and "MeasureValues" not in common_attributes)
97+
time_loc = 0
98+
measure_cols_loc = 1 if cols_names[0] else 0
99+
dimensions_cols_loc = 1 if len(measure_cols) == 1 else 1 + len(measure_cols)
100+
if all(cols_names):
101+
# Time and Measures are supplied in the data frame
102+
dimensions_cols_loc = 1 + len(measure_cols)
103+
elif all(v is None for v in cols_names[:2]):
104+
# Time and Measures are supplied in common_attributes
105+
dimensions_cols_loc = 0
106+
107+
for row in batch:
108+
record: Dict[str, Any] = {}
109+
if "Time" not in common_attributes:
110+
record["Time"] = _format_timestamp(row[time_loc])
111+
if scalar and "MeasureValue" not in common_attributes:
112+
measure_value = row[measure_cols_loc]
113+
if pd.isnull(measure_value):
114+
continue
115+
record["MeasureValue"] = str(measure_value)
116+
elif not scalar and "MeasureValues" not in common_attributes:
117+
record["MeasureValues"] = [
118+
_format_measure(measure_name, measure_value, measure_value_type) # type: ignore[arg-type]
119+
for measure_name, measure_value, measure_value_type in zip(
120+
measure_cols, row[measure_cols_loc:dimensions_cols_loc], measure_types
121+
)
122+
if not pd.isnull(measure_value)
123+
]
124+
if len(record["MeasureValues"]) == 0:
125+
continue
126+
if "MeasureValueType" not in common_attributes:
127+
record["MeasureValueType"] = measure_types[0] if scalar else "MULTI"
128+
# Dimensions can be specified in both common_attributes and the data frame
129+
dimensions = (
130+
[
131+
{"Name": name, "DimensionValueType": "VARCHAR", "Value": str(value)}
132+
for name, value in zip(dimensions_cols, row[dimensions_cols_loc:])
133+
]
134+
if all(dimensions_cols)
135+
else []
136+
)
137+
if dimensions:
138+
record["Dimensions"] = dimensions
139+
if record:
92140
records.append(record)
93-
if len(records) > 0:
141+
142+
try:
143+
if records:
94144
_utils.try_it(
95-
f=timestream_client.write_records,
145+
f=client_timestream.write_records,
96146
ex=(
97-
timestream_client.exceptions.ThrottlingException,
98-
timestream_client.exceptions.InternalServerException,
147+
client_timestream.exceptions.ThrottlingException,
148+
client_timestream.exceptions.InternalServerException,
99149
),
100150
max_num_tries=5,
101151
DatabaseName=database,
102152
TableName=table,
153+
CommonAttributes=common_attributes,
103154
Records=records,
104155
)
105-
except timestream_client.exceptions.RejectedRecordsException as ex:
156+
except client_timestream.exceptions.RejectedRecordsException as ex:
106157
return cast(List[Dict[str, str]], ex.response["RejectedRecords"])
107158
return []
108159

@@ -113,12 +164,12 @@ def _write_df(
113164
executor: _ThreadPoolExecutor,
114165
database: str,
115166
table: str,
116-
cols_names: List[str],
117-
measure_cols_names: List[str],
167+
common_attributes: Dict[str, Any],
168+
cols_names: List[Optional[str]],
169+
measure_cols: List[Optional[str]],
118170
measure_types: List[str],
119-
version: int,
171+
dimensions_cols: List[Optional[str]],
120172
boto3_session: Optional[boto3.Session],
121-
measure_name: Optional[str] = None,
122173
) -> List[Dict[str, str]]:
123174
timestream_client = _utils.client(
124175
service_name="timestream-write",
@@ -132,12 +183,12 @@ def _write_df(
132183
timestream_client,
133184
itertools.repeat(database),
134185
itertools.repeat(table),
186+
itertools.repeat(common_attributes),
135187
itertools.repeat(cols_names),
136-
itertools.repeat(measure_cols_names),
188+
itertools.repeat(measure_cols),
137189
itertools.repeat(measure_types),
138-
itertools.repeat(version),
190+
itertools.repeat(dimensions_cols),
139191
batches,
140-
itertools.repeat(measure_name),
141192
)
142193

143194

@@ -241,12 +292,13 @@ def write(
241292
df: pd.DataFrame,
242293
database: str,
243294
table: str,
244-
time_col: str,
245-
measure_col: Union[str, List[str]],
246-
dimensions_cols: List[str],
295+
time_col: Optional[str] = None,
296+
measure_col: Union[str, List[Optional[str]], None] = None,
297+
dimensions_cols: Optional[List[Optional[str]]] = None,
247298
version: int = 1,
248299
use_threads: Union[bool, int] = True,
249300
measure_name: Optional[str] = None,
301+
common_attributes: Optional[Dict[str, Any]] = None,
250302
boto3_session: Optional[boto3.Session] = None,
251303
) -> List[Dict[str, str]]:
252304
"""Store a Pandas DataFrame into an Amazon Timestream table.
@@ -259,6 +311,16 @@ def write(
259311
this function will not throw a Python exception.
260312
Instead it will return the rejection information.
261313
314+
Note
315+
----
316+
Values in `common_attributes` take precedence over all other arguments and data frame values.
317+
Dimension attributes are merged with attributes in record objects.
318+
Example: common_attributes = {"Dimensions": [{"Name": "device_id", "Value": "12345"}], "MeasureValueType": "DOUBLE"}.
319+
320+
Note
321+
----
322+
If the `time_col` column is supplied it must be of type timestamp. `TimeUnit` is set to MILLISECONDS by default.
323+
262324
Parameters
263325
----------
264326
df : pandas.DataFrame
@@ -267,11 +329,11 @@ def write(
267329
Amazon Timestream database name.
268330
table : str
269331
Amazon Timestream table name.
270-
time_col : str
332+
time_col : Optional[str]
271333
DataFrame column name to be used as time. MUST be a timestamp column.
272-
measure_col : Union[str, List[str]]
334+
measure_col : Union[str, List[str], None]
273335
DataFrame column name(s) to be used as measure.
274-
dimensions_cols : List[str]
336+
dimensions_cols : Optional[List[str]]
275337
List of DataFrame column names to be used as dimensions.
276338
version : int
277339
Version number used for upserts.
@@ -283,6 +345,9 @@ def write(
283345
measure_name : Optional[str]
284346
Name that represents the data attribute of the time series.
285347
Overrides ``measure_col`` if specified.
348+
common_attributes : Optional[Dict[str, Any]]
349+
Dictionary of attributes that is shared across all records in the request.
350+
Using common attributes can optimize the cost of writes by reducing the size of request payloads.
286351
boto3_session : boto3.Session(), optional
287352
Boto3 Session. The default boto3 Session will be used if boto3_session receives None.
288353
@@ -329,13 +394,35 @@ def write(
329394
>>> ]
330395
331396
"""
332-
measure_cols_names = measure_col if isinstance(measure_col, list) else [measure_col]
333-
_logger.debug("measure_cols_names: %s", measure_cols_names)
334-
measure_types: List[str] = _data_types.timestream_type_from_pandas(df.loc[:, measure_cols_names])
335-
_logger.debug("measure_types: %s", measure_types)
336-
cols_names: List[str] = [time_col] + measure_cols_names + dimensions_cols
337-
_logger.debug("cols_names: %s", cols_names)
338-
dfs = _utils.split_pandas_frame(df.loc[:, cols_names], _utils.ensure_cpu_count(use_threads=use_threads))
397+
measure_cols = measure_col if isinstance(measure_col, list) else [measure_col]
398+
measure_types: List[str] = (
399+
_data_types.timestream_type_from_pandas(df.loc[:, measure_cols]) if all(measure_cols) else []
400+
)
401+
dimensions_cols = dimensions_cols if dimensions_cols else [dimensions_cols] # type: ignore[list-item]
402+
cols_names: List[Optional[str]] = [time_col] + measure_cols + dimensions_cols
403+
measure_name = measure_name if measure_name else measure_cols[0]
404+
common_attributes = _sanitize_common_attributes(common_attributes, version, measure_name)
405+
406+
_logger.debug(
407+
"common_attributes: %s\n, cols_names: %s\n, measure_types: %s",
408+
common_attributes,
409+
cols_names,
410+
measure_types,
411+
)
412+
413+
# User can supply arguments in one of two ways:
414+
# 1. With the `common_attributes` dictionary which takes precedence
415+
# 2. With data frame columns
416+
# However, the data frame cannot be completely empty.
417+
# So if all values in `cols_names` are None, an exception is raised.
418+
if any(cols_names):
419+
dfs = _utils.split_pandas_frame(
420+
df.loc[:, [c for c in cols_names if c]], _utils.ensure_cpu_count(use_threads=use_threads)
421+
)
422+
else:
423+
raise exceptions.InvalidArgumentCombination(
424+
"At least one of `time_col`, `measure_col` or `dimensions_cols` must be specified."
425+
)
339426
_logger.debug("len(dfs): %s", len(dfs))
340427

341428
executor = _get_executor(use_threads=use_threads)
@@ -348,12 +435,12 @@ def write(
348435
executor=executor,
349436
database=database,
350437
table=table,
438+
common_attributes=common_attributes,
351439
cols_names=cols_names,
352-
measure_cols_names=measure_cols_names,
440+
measure_cols=measure_cols,
353441
measure_types=measure_types,
354-
version=version,
442+
dimensions_cols=dimensions_cols,
355443
boto3_session=boto3_session,
356-
measure_name=measure_name,
357444
)
358445
for df in dfs
359446
]

0 commit comments

Comments
 (0)