Skip to content

Commit 4926a6e

Browse files
committed
test(IdP, JwtCredentialsProvider): JWT SSO IdP support
1 parent 2cb3ee3 commit 4926a6e

File tree

9 files changed

+342
-26
lines changed

9 files changed

+342
-26
lines changed

test/__init__.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,10 @@
1-
from test.integration import azure_browser_idp, idp_arg, okta_browser_idp, okta_idp
1+
from test.integration import (
2+
adfs_idp,
3+
azure_browser_idp,
4+
azure_idp,
5+
idp_arg,
6+
jwt_azure_v2_idp,
7+
jwt_google_idp,
8+
okta_browser_idp,
9+
okta_idp,
10+
)

test/integration/__init__.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,10 @@
1-
from .conftest import azure_browser_idp, idp_arg, okta_browser_idp, okta_idp
1+
from .conftest import (
2+
adfs_idp,
3+
azure_browser_idp,
4+
azure_idp,
5+
idp_arg,
6+
jwt_azure_v2_idp,
7+
jwt_google_idp,
8+
okta_browser_idp,
9+
okta_idp,
10+
)

test/integration/conftest.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,48 @@ def adfs_idp():
145145
return db_connect
146146

147147

148+
@pytest.fixture(scope="class")
149+
def jwt_google_idp():
150+
db_connect = {
151+
"database": conf.get("database", "database"),
152+
"host": conf.get("database", "host"),
153+
"port": conf.getint("database", "port"),
154+
"db_user": conf.get("database", "user"),
155+
"ssl": conf.getboolean("database", "ssl"),
156+
"sslmode": conf.get("database", "sslmode"),
157+
"iam": conf.getboolean("jwt-google-idp", "iam"),
158+
"user": conf.get("database", "user"),
159+
"password": conf.get("jwt-google-idp", "password"),
160+
"credentials_provider": conf.get("jwt-google-idp", "credentials_provider"),
161+
"region": conf.get("jwt-google-idp", "region"),
162+
"cluster_identifier": conf.get("jwt-google-idp", "cluster_identifier"),
163+
"web_identity_token": conf.get("jwt-google-idp", "web_identity_token"),
164+
"preferred_role": conf.get("jwt-google-idp", "preferred_role"),
165+
}
166+
return db_connect
167+
168+
169+
@pytest.fixture(scope="class")
170+
def jwt_azure_v2_idp():
171+
db_connect = {
172+
"database": conf.get("database", "database"),
173+
"host": conf.get("database", "host"),
174+
"port": conf.getint("database", "port"),
175+
"db_user": conf.get("database", "user"),
176+
"ssl": conf.getboolean("database", "ssl"),
177+
"sslmode": conf.get("database", "sslmode"),
178+
"iam": conf.getboolean("jwt-azure-v2-idp", "iam"),
179+
"user": conf.get("database", "user"),
180+
"password": conf.get("jwt-azure-v2-idp", "password"),
181+
"credentials_provider": conf.get("jwt-azure-v2-idp", "credentials_provider"),
182+
"region": conf.get("jwt-azure-v2-idp", "region"),
183+
"cluster_identifier": conf.get("jwt-azure-v2-idp", "cluster_identifier"),
184+
"web_identity_token": conf.get("jwt-azure-v2-idp", "web_identity_token"),
185+
"preferred_role": conf.get("jwt-azure-v2-idp", "preferred_role"),
186+
}
187+
return db_connect
188+
189+
148190
@pytest.fixture
149191
def con(request, db_kwargs):
150192
conn = redshift_connector.connect(**db_kwargs)

test/integration/plugin/test_credentials_providers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
conf.read(root_path + "/config.ini")
1515

1616

17-
NON_BROWSER_IDP: typing.List[str] = ["okta_idp"]
17+
NON_BROWSER_IDP: typing.List[str] = ["okta_idp", "azure_idp"]
1818
ALL_IDP: typing.List[str] = ["okta_browser_idp", "azure_browser_idp"] + NON_BROWSER_IDP
1919

2020

test/unit/helpers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .idp_helpers import make_redshift_property

test/unit/helpers/idp_helpers.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from redshift_connector import RedshiftProperty
2+
from redshift_connector.config import ClientProtocolVersion
3+
4+
5+
def make_redshift_property() -> RedshiftProperty:
6+
rp: RedshiftProperty = RedshiftProperty()
7+
rp.user_name = "mario@luigi.com"
8+
rp.password = "bowser"
9+
rp.db_name = "dev"
10+
rp.cluster_identifier = "something"
11+
rp.idp_host = "8000"
12+
rp.duration = 100
13+
rp.preferred_role = "analyst"
14+
rp.sslInsecure = False
15+
rp.db_user = "primary"
16+
rp.db_groups = ["employees"]
17+
rp.force_lowercase = True
18+
rp.auto_create = False
19+
rp.region = "us-west-1"
20+
rp.principal = "arn:aws:iam::123456789012:user/Development/product_1234/*"
21+
rp.client_protocol_version = ClientProtocolVersion.BASE_SERVER
22+
return rp

test/unit/plugin/test_credentials_providers.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,16 @@
11
import configparser
22
import os
33
import typing
4-
from test import azure_browser_idp, idp_arg, okta_browser_idp, okta_idp
4+
from test import (
5+
adfs_idp,
6+
azure_browser_idp,
7+
azure_idp,
8+
idp_arg,
9+
jwt_azure_v2_idp,
10+
jwt_google_idp,
11+
okta_browser_idp,
12+
okta_idp,
13+
)
514

615
import pytest # type: ignore
716

@@ -12,9 +21,14 @@
1221
conf.read(root_path + "/config.ini")
1322

1423

15-
NON_BROWSER_IDP: typing.List[str] = ["okta_idp"]
24+
NON_BROWSER_IDP: typing.List[str] = ["okta_idp", "azure_idp", "adfs_idp"]
1625

17-
ALL_IDP: typing.List[str] = ["okta_browser_idp", "azure_browser_idp"] + NON_BROWSER_IDP
26+
ALL_IDP: typing.List[str] = [
27+
"okta_browser_idp",
28+
"azure_browser_idp",
29+
"jwt_google_idp",
30+
"jwt_azure_v2_idp",
31+
] + NON_BROWSER_IDP
1832

1933

2034
@pytest.mark.parametrize("idp_arg", ALL_IDP, indirect=True)
Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
import base64
2+
import typing
3+
from test.unit.helpers import make_redshift_property
4+
from unittest.mock import MagicMock, patch
5+
6+
import pytest # type: ignore
7+
8+
from redshift_connector import RedshiftProperty
9+
from redshift_connector.credentials_holder import CredentialsHolder
10+
from redshift_connector.error import InterfaceError
11+
from redshift_connector.plugin import (
12+
BasicJwtCredentialsProvider,
13+
JwtCredentialsProvider,
14+
)
15+
16+
17+
@patch.multiple(JwtCredentialsProvider, __abstractmethods__=set())
18+
def make_jwtcredentialsprovider() -> JwtCredentialsProvider:
19+
return JwtCredentialsProvider() # type: ignore
20+
21+
22+
def test_make_jwtcredentialsprovider():
23+
jwtcp: JwtCredentialsProvider = make_jwtcredentialsprovider()
24+
assert hasattr(jwtcp, "role_arn")
25+
assert jwtcp.role_arn is None
26+
assert hasattr(jwtcp, "duration")
27+
assert jwtcp.duration is None
28+
assert hasattr(jwtcp, "db_groups_filter")
29+
assert jwtcp.db_groups_filter is None
30+
31+
32+
def test_jwtcredentialsprovider_add_parameter():
33+
jwtcp: JwtCredentialsProvider = make_jwtcredentialsprovider()
34+
rp: RedshiftProperty = make_redshift_property()
35+
36+
_wit: str = "hooplah"
37+
_duration: int = 1234
38+
_role: str = "my_role"
39+
_session: str = "my_session"
40+
41+
rp.role_arn = _role
42+
rp.role_session_name = _session
43+
rp.duration = _duration
44+
rp.web_identity_token = _wit
45+
46+
jwtcp.add_parameter(rp)
47+
assert jwtcp.jwt == _wit
48+
assert jwtcp.duration == _duration
49+
assert jwtcp.role_arn == _role
50+
assert jwtcp.role_session_name == _session
51+
52+
53+
cache_key_vals: typing.List[typing.Tuple] = [("a", "b", "c", "d"), ()]
54+
55+
56+
@pytest.mark.parametrize("_input", cache_key_vals)
57+
def test_get_cache_key_no_attributes(_input):
58+
jwtcp: JwtCredentialsProvider = make_jwtcredentialsprovider()
59+
if len(_input) == 4:
60+
jwtcp.role_arn = _input[0]
61+
jwtcp.jwt = _input[1]
62+
jwtcp.role_session_name = _input[2]
63+
jwtcp.duration = _input[3]
64+
assert jwtcp.get_cache_key() == "".join(_input)
65+
else:
66+
assert jwtcp.get_cache_key() == "NoneNone{}None".format(JwtCredentialsProvider.DEFAULT_ROLE_SESSION_NAME)
67+
68+
69+
@pytest.mark.parametrize("param", [JwtCredentialsProvider.KEY_ROLE_ARN, JwtCredentialsProvider.KEY_WEB_IDENTITY_TOKEN])
70+
@pytest.mark.parametrize("invalid_val", [None, ""])
71+
def test_check_required_parameters_missing_raises_exception(param, invalid_val):
72+
jwtcp: JwtCredentialsProvider = make_jwtcredentialsprovider()
73+
valid_val: str = "hello world!"
74+
75+
if param == JwtCredentialsProvider.KEY_ROLE_ARN:
76+
jwtcp.role_arn = invalid_val
77+
jwtcp.jwt = valid_val
78+
elif param == JwtCredentialsProvider.KEY_WEB_IDENTITY_TOKEN:
79+
jwtcp.role_arn = valid_val
80+
jwtcp.jwt = invalid_val
81+
else:
82+
raise pytest.PytestWarning("Invalid arg supplied for param: {}".format(param))
83+
84+
with pytest.raises(InterfaceError, match="Missing required property: {}".format(param)):
85+
jwtcp.check_required_parameters()
86+
87+
88+
def make_jwt(
89+
v1: str, v2: str, v3: typing.Optional[str]
90+
) -> typing.Tuple[typing.Optional[typing.List[typing.Union[bytes, str]]], str]:
91+
input_val: str = ""
92+
93+
for _input in (v1, v2, v3):
94+
if _input is None:
95+
continue
96+
encoded_input: str = str(base64.b64encode(_input.encode("ascii")), "ascii")
97+
input_val += "{}\\.".format(encoded_input)
98+
99+
exp_result: typing.Optional[typing.List[typing.Union[str, bytes]]] = None
100+
101+
if all((v1, v2, v3)):
102+
exp_result = [
103+
v1.encode("ascii"),
104+
v2.encode("ascii"),
105+
str(base64.b64encode(v3.encode("ascii")), "ascii"), # type: ignore
106+
]
107+
108+
return exp_result, input_val[:-2]
109+
110+
111+
@pytest.mark.parametrize(
112+
"_input",
113+
[None, make_jwt("abc", "def", "ghi"), make_jwt("hithere", "hello", "goodbye"), make_jwt("hello", "world", None)],
114+
)
115+
def test_decode_jwt(_input):
116+
jwtcp: JwtCredentialsProvider = make_jwtcredentialsprovider()
117+
if _input is None:
118+
assert jwtcp.decode_jwt(_input) is None
119+
else:
120+
exp_result, jwt = _input
121+
assert jwtcp.decode_jwt(jwt) == exp_result
122+
123+
124+
@pytest.mark.parametrize("_input", ["get_saml_assertion", "do_verify_ssl_cert", "get_form_action", "read_metadata"])
125+
def test_get_saml_assertion_not_implemented(_input):
126+
jwtcp: JwtCredentialsProvider = make_jwtcredentialsprovider()
127+
method_to_call = jwtcp.__getattribute__(_input)
128+
129+
with pytest.raises(NotImplementedError):
130+
try:
131+
method_to_call()
132+
except TypeError:
133+
method_to_call("trash")
134+
135+
136+
def test_get_credentials_handles_exception(mocker):
137+
jwtcp: JwtCredentialsProvider = make_jwtcredentialsprovider()
138+
mock_error_msg: str = "bad robot"
139+
with patch("redshift_connector.plugin.jwt_credentials_provider.JwtCredentialsProvider.refresh") as buggy_refresh:
140+
buggy_refresh.side_effect = Exception(mock_error_msg)
141+
142+
with pytest.raises(InterfaceError, match=mock_error_msg):
143+
jwtcp.get_credentials()
144+
145+
146+
def test_get_credentials_returns_credentials(mocker):
147+
jwtcp: JwtCredentialsProvider = make_jwtcredentialsprovider()
148+
mock_cache_key = MagicMock()
149+
mock_cred_provider = MagicMock()
150+
151+
def mock_set_cache(key, val):
152+
jwtcp.cache[key] = val
153+
154+
mocker.patch(
155+
"redshift_connector.plugin.jwt_credentials_provider.JwtCredentialsProvider.get_cache_key",
156+
return_value=mock_cache_key,
157+
)
158+
with patch("redshift_connector.plugin.jwt_credentials_provider.JwtCredentialsProvider.refresh") as mock_refresh:
159+
mock_refresh.side_effect = mock_set_cache(mock_cache_key, mock_cred_provider)
160+
assert jwtcp.get_credentials() == mock_cred_provider
161+
162+
163+
def test_get_credentials_none_found_raises_exception(mocker):
164+
jwtcp: JwtCredentialsProvider = make_jwtcredentialsprovider()
165+
166+
mocker.patch(
167+
"redshift_connector.plugin.jwt_credentials_provider.JwtCredentialsProvider.get_cache_key",
168+
return_value=MagicMock(),
169+
)
170+
mocker.patch("redshift_connector.plugin.jwt_credentials_provider.JwtCredentialsProvider.refresh")
171+
with pytest.raises(InterfaceError, match="Unable to load AWS credentials from IDP"):
172+
jwtcp.get_credentials()
173+
174+
175+
def test_refresh_no_jwt():
176+
jwtcp: JwtCredentialsProvider = make_jwtcredentialsprovider()
177+
178+
with pytest.raises(InterfaceError, match="no jwt provided"):
179+
jwtcp.refresh()
180+
181+
182+
def test_refresh_passes_jwt_to_boto3(mocker):
183+
mocked_botocore_client = MagicMock()
184+
185+
mocked_processed_jwt: str = "processed value"
186+
mocker.patch("boto3.client", return_value=mocked_botocore_client)
187+
mocker.patch(
188+
"redshift_connector.plugin.jwt_credentials_provider.JwtCredentialsProvider.process_jwt",
189+
return_value=mocked_processed_jwt,
190+
)
191+
mocker.patch(
192+
"redshift_connector.plugin.jwt_credentials_provider.JwtCredentialsProvider.decode_jwt", return_value=None
193+
)
194+
195+
jwtcp: JwtCredentialsProvider = make_jwtcredentialsprovider()
196+
mocked_orig_jwt: str = "initial value"
197+
mocked_role_arn: str = "my_role"
198+
jwtcp.jwt = mocked_orig_jwt
199+
jwtcp.role_arn = mocked_role_arn
200+
process_jwt_spy = mocker.spy(jwtcp, "process_jwt")
201+
decode_jwt_spy = mocker.spy(jwtcp, "decode_jwt")
202+
boto_spy = mocker.spy(mocked_botocore_client, "assume_role_with_web_identity")
203+
204+
jwtcp.refresh()
205+
assert process_jwt_spy.called is True
206+
assert process_jwt_spy.call_count == 1
207+
assert process_jwt_spy.call_args[0][0] == mocked_orig_jwt
208+
209+
assert decode_jwt_spy.called is True
210+
assert decode_jwt_spy.call_count == 1
211+
assert decode_jwt_spy.call_args[0][0] == mocked_orig_jwt
212+
213+
assert boto_spy.called is True
214+
assert boto_spy.call_count == 1
215+
assert boto_spy.call_args[1]["RoleArn"] == mocked_role_arn
216+
assert boto_spy.call_args[1]["RoleSessionName"] == JwtCredentialsProvider.DEFAULT_ROLE_SESSION_NAME
217+
assert boto_spy.call_args[1]["WebIdentityToken"] == mocked_processed_jwt
218+
219+
assert len(jwtcp.cache) == 1
220+
assert jwtcp.get_cache_key() in jwtcp.cache
221+
assert isinstance(jwtcp.cache[jwtcp.get_cache_key()], CredentialsHolder)
222+
223+
224+
def test_basic_jwt_credential_provider(mocker):
225+
bjwtcp: BasicJwtCredentialsProvider = BasicJwtCredentialsProvider()
226+
bjwtcp.jwt = "hi"
227+
bjwtcp.role_arn = "buttered bread role"
228+
229+
checker_spy = mocker.spy(bjwtcp, "check_required_parameters")
230+
assert bjwtcp.process_jwt(bjwtcp.jwt) == bjwtcp.jwt
231+
assert checker_spy.called is True
232+
assert checker_spy.call_count == 1

0 commit comments

Comments
 (0)