Skip to content

Commit 627ea84

Browse files
committed
feat(connection): support redshift serverless
1 parent 93141e6 commit 627ea84

File tree

7 files changed

+503
-118
lines changed

7 files changed

+503
-118
lines changed

redshift_connector/iam_helper.py

Lines changed: 94 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -51,14 +51,14 @@ def set_iam_properties(info: RedshiftProperty) -> RedshiftProperty:
5151
Helper function to handle IAM connection properties and ensure required parameters are specified.
5252
Parameters
5353
"""
54+
import pkg_resources
55+
from packaging.version import Version
56+
5457
if info is None:
5558
raise InterfaceError("Invalid connection property setting. info must be specified")
5659

5760
# Check for IAM keys and AuthProfile first
5861
if info.auth_profile is not None:
59-
import pkg_resources
60-
from packaging.version import Version
61-
6262
if Version(pkg_resources.get_distribution("boto3").version) < Version("1.17.111"):
6363
raise pkg_resources.VersionConflict(
6464
"boto3 >= 1.17.111 required for authentication via Amazon Redshift authentication profile. "
@@ -83,6 +83,18 @@ def set_iam_properties(info: RedshiftProperty) -> RedshiftProperty:
8383
)
8484
info.put_all(resp)
8585

86+
if info.is_serverless_host and info.iam:
87+
raise ProgrammingError("This feature is not yet available")
88+
# if Version(pkg_resources.get_distribution("boto3").version) <= Version("1.20.22"):
89+
# raise pkg_resources.VersionConflict(
90+
# "boto3 >= XXX required for authentication with Amazon Redshift serverless. "
91+
# "Please upgrade the installed version of boto3 to use this functionality."
92+
# )
93+
94+
if info.is_serverless_host:
95+
info.set_account_id_from_host()
96+
info.set_region_from_host()
97+
8698
# IAM requires an SSL connection to work.
8799
# Make sure that is set to SSL level VERIFY_CA or higher.
88100
if info.ssl is True:
@@ -109,8 +121,10 @@ def set_iam_properties(info: RedshiftProperty) -> RedshiftProperty:
109121
"AWS credentials, Amazon Redshift authentication profile, or AWS profile"
110122
)
111123
elif info.iam is True:
124+
_logger.debug("boto3 version: {}".format(Version(pkg_resources.get_distribution("boto3").version)))
125+
_logger.debug("botocore version: {}".format(Version(pkg_resources.get_distribution("botocore").version)))
112126

113-
if info.cluster_identifier is None:
127+
if info.cluster_identifier is None and not info.is_serverless_host:
114128
raise InterfaceError(
115129
"Invalid connection property setting. cluster_identifier must be provided when IAM is enabled"
116130
)
@@ -174,7 +188,7 @@ def set_iam_properties(info: RedshiftProperty) -> RedshiftProperty:
174188
info.put("db_groups", [group.lower() for group in info.db_groups])
175189

176190
if info.iam is True:
177-
if info.cluster_identifier is None:
191+
if info.cluster_identifier is None and not info.is_serverless_host:
178192
raise InterfaceError(
179193
"Invalid connection property setting. cluster_identifier must be provided when IAM is enabled"
180194
)
@@ -305,20 +319,42 @@ def set_iam_credentials(info: RedshiftProperty) -> None:
305319
IamHelper.set_cluster_credentials(provider, info)
306320

307321
@staticmethod
308-
def get_credentials_cache_key(info: RedshiftProperty):
322+
def get_credentials_cache_key(
323+
info: RedshiftProperty, cred_provider: typing.Union[SamlCredentialsProvider, AWSCredentialsProvider]
324+
):
309325
db_groups: str = ""
310326

311327
if len(info.db_groups) > 0:
312328
info.put("db_groups", sorted(info.db_groups))
313329
db_groups = ",".join(info.db_groups)
314330

331+
cred_key: str = ""
332+
333+
if cred_provider:
334+
cred_key = str(cred_provider.get_cache_key())
335+
315336
return ";".join(
316-
(
317-
typing.cast(str, info.db_user),
318-
info.db_name,
319-
db_groups,
320-
typing.cast(str, info.cluster_identifier),
321-
str(info.auto_create),
337+
filter(
338+
None,
339+
(
340+
cred_key,
341+
typing.cast(str, info.db_user if info.db_user else info.user_name),
342+
info.db_name,
343+
db_groups,
344+
typing.cast(str, info.account_id if info.is_serverless_host else info.cluster_identifier),
345+
str(info.auto_create),
346+
str(info.duration),
347+
# v2 api parameters
348+
info.preferred_role,
349+
info.web_identity_token,
350+
info.role_arn,
351+
info.role_session_name,
352+
# providers
353+
info.profile,
354+
info.access_key_id,
355+
info.secret_access_key,
356+
info.session_token,
357+
),
322358
)
323359
)
324360

@@ -339,6 +375,9 @@ def set_cluster_credentials(
339375
] = cred_provider.get_credentials()
340376
session_credentials: typing.Dict[str, str] = credentials_holder.get_session_credentials()
341377

378+
redshift_client: str = "redshift-serverless" if info.is_serverless_host else "redshift"
379+
_logger.debug("boto3.client(service_name={}) being used for IAM auth".format(redshift_client))
380+
342381
for opt_key, opt_val in (("region_name", info.region), ("endpoint_url", info.endpoint_url)):
343382
if opt_val is not None:
344383
session_credentials[opt_key] = opt_val
@@ -348,21 +387,28 @@ def set_cluster_credentials(
348387
cached_session: boto3.Session = typing.cast(
349388
ABCAWSCredentialsHolder, credentials_holder
350389
).get_boto_session()
351-
client = cached_session.client(service_name="redshift", region_name=info.region)
390+
client = cached_session.client(service_name=redshift_client, region_name=info.region)
352391
else:
353-
client = boto3.client(service_name="redshift", **session_credentials)
392+
client = boto3.client(service_name=redshift_client, **session_credentials)
354393

355394
if info.host is None or info.host == "" or info.port is None or info.port == "":
356-
response = client.describe_clusters(ClusterIdentifier=info.cluster_identifier)
395+
response: dict
357396

358-
info.put("host", response["Clusters"][0]["Endpoint"]["Address"])
359-
info.put("port", response["Clusters"][0]["Endpoint"]["Port"])
397+
if info.is_serverless_host:
398+
response = client.describe_configuration()
399+
info.put("host", response["endpoint"]["address"])
400+
info.put("port", response["endpoint"]["port"])
401+
else:
402+
response = client.describe_clusters(ClusterIdentifier=info.cluster_identifier)
403+
info.put("host", response["Clusters"][0]["Endpoint"]["Address"])
404+
info.put("port", response["Clusters"][0]["Endpoint"]["Port"])
360405

361406
cred: typing.Optional[typing.Dict[str, typing.Union[str, datetime.datetime]]] = None
362407

363408
if info.iam_disable_cache is False:
409+
_logger.debug("iam_disable_cache=False")
364410
# temporary credentials are cached by redshift_connector and will be used if they have not expired
365-
cache_key: str = IamHelper.get_credentials_cache_key(info)
411+
cache_key: str = IamHelper.get_credentials_cache_key(info, cred_provider)
366412
cred = IamHelper.credentials_cache.get(cache_key, None)
367413

368414
_logger.debug(
@@ -375,26 +421,42 @@ def set_cluster_credentials(
375421
if cred is None or typing.cast(datetime.datetime, cred["Expiration"]) < datetime.datetime.now(tz=tzutc()):
376422
# retries will occur by default ref:
377423
# https://boto3.amazonaws.com/v1/documentation/api/latest/guide/retries.html#legacy-retry-mode
378-
cred = typing.cast(
379-
typing.Dict[str, typing.Union[str, datetime.datetime]],
380-
client.get_cluster_credentials(
381-
DbUser=info.db_user,
382-
DbName=info.db_name,
383-
DbGroups=info.db_groups,
384-
ClusterIdentifier=info.cluster_identifier,
385-
AutoCreate=info.auto_create,
386-
),
387-
)
424+
_logger.debug("Credentials expired or not found...requesting from boto")
425+
if info.is_serverless_host:
426+
cred = typing.cast(
427+
typing.Dict[str, typing.Union[str, datetime.datetime]],
428+
client.get_credentials(
429+
dbName=info.db_name,
430+
),
431+
)
432+
# re-map expiration for compatibility with redshift credential response
433+
cred["Expiration"] = cred["expiration"]
434+
del cred["expiration"]
435+
else:
436+
cred = typing.cast(
437+
typing.Dict[str, typing.Union[str, datetime.datetime]],
438+
client.get_cluster_credentials(
439+
DbUser=info.db_user,
440+
DbName=info.db_name,
441+
DbGroups=info.db_groups,
442+
ClusterIdentifier=info.cluster_identifier,
443+
AutoCreate=info.auto_create,
444+
),
445+
)
388446

389447
if info.iam_disable_cache is False:
390448
IamHelper.credentials_cache[cache_key] = typing.cast(
391449
typing.Dict[str, typing.Union[str, datetime.datetime]], cred
392450
)
451+
# redshift-serverless api json response payload slightly differs
452+
if info.is_serverless_host:
453+
info.put("user_name", typing.cast(str, cred["dbUser"]))
454+
info.put("password", typing.cast(str, cred["dbPassword"]))
455+
else:
456+
info.put("user_name", typing.cast(str, cred["DbUser"]))
457+
info.put("password", typing.cast(str, cred["DbPassword"]))
393458

394-
info.put("user_name", typing.cast(str, cred["DbUser"]))
395-
info.put("password", typing.cast(str, cred["DbPassword"]))
396-
397-
_logger.debug("Using temporary aws credentials with expiration: {}".format(cred["Expiration"]))
459+
_logger.debug("Using temporary aws credentials with expiration: {}".format(cred.get("Expiration")))
398460

399461
except botocore.exceptions.ClientError as e:
400462
_logger.error("ClientError: %s", e)

redshift_connector/redshift_property.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import typing
22

33
from redshift_connector.config import DEFAULT_PROTOCOL_VERSION
4+
from redshift_connector.error import ProgrammingError
5+
6+
# Regular expression describing a Redshift serverless endpoint host name.
# Group 1 is read as the AWS account id and group 2 as the AWS region by the
# host-parsing helpers on RedshiftProperty. Every literal dot is escaped so
# that it matches only a "." separator rather than an arbitrary character.
SERVERLESS_HOST_PATTERN: str = r"(.+)\.(.+)\.redshift-serverless(-dev)?\.amazonaws\.com(.)*"
47

58

69
class RedshiftProperty:
@@ -105,13 +108,17 @@ def __init__(self: "RedshiftProperty", **kwargs):
105108
# The user name.
106109
self.user_name: str = ""
107110
self.web_identity_token: typing.Optional[str] = None
111+
# The AWS Account Id
112+
self.account_id: typing.Optional[str] = None
108113

109114
else:
110115
for k, v in kwargs.items():
111116
setattr(self, k, v)
112117

113118
def __str__(self: "RedshiftProperty") -> str:
114-
return str(self.__dict__)
119+
rp = self.__dict__
120+
rp["is_serverless_host"] = self.is_serverless_host
121+
return str(rp)
115122

116123
def put_all(self, other):
117124
"""
@@ -128,3 +135,37 @@ def put(self: "RedshiftProperty", key: str, value: typing.Any):
128135
"""
129136
if value is not None:
130137
setattr(self, key, value)
138+
139+
@property
140+
def is_serverless_host(self: "RedshiftProperty") -> bool:
141+
"""
142+
Whether the host indicates that Redshift serverless will be used for the connection.
143+
"""
144+
if not self.host:
145+
return False
146+
147+
import re
148+
149+
return bool(re.fullmatch(pattern=SERVERLESS_HOST_PATTERN, string=str(self.host)))
150+
151+
def set_account_id_from_host(self: "RedshiftProperty") -> None:
152+
"""
153+
Sets the AWS account id as parsed from the Redshift serverless endpoint.
154+
"""
155+
import re
156+
157+
m2 = re.fullmatch(pattern=SERVERLESS_HOST_PATTERN, string=self.host)
158+
159+
if m2:
160+
self.put(key="account_id", value=m2.group(1))
161+
162+
def set_region_from_host(self: "RedshiftProperty") -> None:
163+
"""
164+
Sets the AWS region as parsed from the Redshift serverless endpoint.
165+
"""
166+
import re
167+
168+
m2 = re.fullmatch(pattern=SERVERLESS_HOST_PATTERN, string=self.host)
169+
170+
if m2:
171+
self.put(key="region", value=m2.group(2))

test/conftest.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,37 @@ def perf_db_kwargs() -> typing.Dict[str, typing.Union[str, bool]]:
7676
return db_connect
7777

7878

79+
@pytest.fixture(scope="class")
80+
def serverless_native_db_kwargs() -> typing.Dict[str, str]:
81+
db_connect = {
82+
"database": conf.get("redshift-serverless", "database", fallback="mock_database"),
83+
"host": conf.get(
84+
"redshift-serverless", "host", fallback="012345678901.us-east-2.redshift-serverless.amazonaws.com"
85+
),
86+
"user": conf.get("redshift-serverless", "user", fallback="mock_user"),
87+
"password": conf.get("redshift-serverless", "password", fallback="mock_password"),
88+
}
89+
90+
return db_connect
91+
92+
93+
@pytest.fixture(scope="class")
94+
def serverless_iam_db_kwargs() -> typing.Dict[str, typing.Union[str, bool]]:
95+
db_connect = {
96+
"database": conf.get("redshift-serverless", "database", fallback="mock_database"),
97+
"iam": conf.getboolean("redshift-serverless", "iam", fallback=True),
98+
"access_key_id": conf.get("redshift-serverless", "access_key_id", fallback="mock_access_key_id"),
99+
"secret_access_key": conf.get("redshift-serverless", "secret_access_key", fallback="mock_secret_access_key"),
100+
"session_token": conf.get("redshift-serverless", "session_token", fallback="mock_session_token"),
101+
"region": conf.get("redshift-serverless", "region", fallback="mock_region"),
102+
"host": conf.get(
103+
"redshift-serverless", "host", fallback="012345678901.us-east-2.redshift-serverless.amazonaws.com"
104+
),
105+
}
106+
107+
return db_connect # type: ignore
108+
109+
79110
@pytest.fixture(scope="class")
80111
def okta_idp() -> typing.Dict[str, typing.Union[str, bool, int]]:
81112
db_connect = {
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import pytest
2+
3+
import redshift_connector
4+
5+
"""
6+
These functional tests ensure connections to Redshift serverless can be established when
7+
using various authentication methods.
8+
9+
Please note the pre-requisites were documented while this feature is under public preview,
10+
and are subject to change.
11+
12+
Pre-requisites:
13+
1) Redshift serverless configuration
14+
2) EC2 instance configured for accessing Redshift serverless (i.e. in compatible VPC, subnet)
15+
3) Perform a sanity check using psql to ensure Redshift serverless connection can be established
16+
4) EC2 instance has Python installed
17+
5) Clone redshift_connector on EC2 instance and install
18+
19+
How to use:
20+
1) Populate config.ini with the Redshift serverless endpoint and user authentication information
21+
2) Run this file with pytest
22+
"""
23+
24+
25+
def test_native_auth(serverless_native_db_kwargs):
26+
with redshift_connector.connect(**serverless_native_db_kwargs):
27+
pass
28+
29+
30+
def test_iam_auth(serverless_iam_db_kwargs):
31+
with redshift_connector.connect(**serverless_iam_db_kwargs):
32+
pass
33+
34+
35+
def test_idp_auth(okta_idp):
36+
okta_idp["host"] = "my_redshift_serverless_endpoint"
37+
38+
with redshift_connector.connect(**okta_idp):
39+
pass

0 commit comments

Comments
 (0)