|
1 | 1 | import typing as t |
2 | | -from datetime import datetime, timedelta |
| 2 | +from datetime import datetime |
3 | 3 |
|
4 | | -import requests |
5 | 4 |
|
6 | | - |
7 | | -class OAuth2ClientCredentialsAuth(requests.auth.AuthBase): |
8 | | - """ |
9 | | - This implements the OAuth2 ClientCredentials Grant authentication flow. |
10 | | - https://datatracker.ietf.org/doc/html/rfc6749#section-4.4 |
| 5 | +class AuthProviderBase: |
11 | 6 | """ |
| 7 | + Base class for auth providers. |
12 | 8 |
|
13 | | - def __init__( |
14 | | - self, |
15 | | - client_id: str, |
16 | | - client_secret: str, |
17 | | - token_url: str, |
18 | | - scopes: list[str] | None = None, |
19 | | - verify_ssl: str | bool | None = None, |
20 | | - ): |
21 | | - self._token_server_auth = requests.auth.HTTPBasicAuth(client_id, client_secret) |
22 | | - self._token_url = token_url |
23 | | - self._scopes = scopes |
24 | | - self._verify_ssl = verify_ssl |
| 9 | + This abstract base class will analyze the authentication proposals of the openapi specs. |
| 10 | + Different authentication schemes can be implemented in subclasses. |
| 11 | + """ |
25 | 12 |
|
26 | | - self._access_token: str | None = None |
27 | | - self._expire_at: datetime | None = None |
| 13 | + def __init__(self) -> None: |
| 14 | + self._oauth2_token: str | None = None |
| 15 | + self._oauth2_expires: datetime = datetime.now() |
| 16 | + |
| 17 | + def can_complete_http_basic(self) -> bool: |
| 18 | + return False |
| 19 | + |
| 20 | + def can_complete_mutualTLS(self) -> bool: |
| 21 | + return False |
| 22 | + |
| 23 | + def can_complete_oauth2_client_credentials(self, scopes: list[str]) -> bool: |
| 24 | + return False |
| 25 | + |
| 26 | + def can_complete_scheme(self, scheme: dict[str, t.Any], scopes: list[str]) -> bool: |
| 27 | + if scheme["type"] == "http": |
| 28 | + if scheme["scheme"] == "basic": |
| 29 | + return self.can_complete_http_basic() |
| 30 | + elif scheme["type"] == "mutualTLS": |
| 31 | + return self.can_complete_mutualTLS() |
| 32 | + elif scheme["type"] == "oauth2": |
| 33 | + for flow_name, flow in scheme["flows"].items(): |
| 34 | + if ( |
| 35 | + flow_name == "clientCredentials" |
| 36 | + and self.can_complete_oauth2_client_credentials(flow["scopes"]) |
| 37 | + ): |
| 38 | + return True |
| 39 | + return False |
| 40 | + |
| 41 | + def can_complete( |
| 42 | + self, proposal: dict[str, list[str]], security_schemes: dict[str, dict[str, t.Any]] |
| 43 | + ) -> bool: |
| 44 | + for name, scopes in proposal.items(): |
| 45 | + scheme = security_schemes.get(name) |
| 46 | + if scheme is None or not self.can_complete_scheme(scheme, scopes): |
| 47 | + return False |
| 48 | + # This covers the case where `[]` allows for no auth at all. |
| 49 | + return True |
| 50 | + |
| 51 | + async def auth_success_hook( |
| 52 | + self, proposal: dict[str, list[str]], security_schemes: dict[str, dict[str, t.Any]] |
| 53 | + ) -> None: |
| 54 | + pass |
| 55 | + |
| 56 | + async def auth_failure_hook( |
| 57 | + self, proposal: dict[str, list[str]], security_schemes: dict[str, dict[str, t.Any]] |
| 58 | + ) -> None: |
| 59 | + pass |
| 60 | + |
| 61 | + async def http_basic_credentials(self) -> tuple[bytes, bytes]: |
| 62 | + raise NotImplementedError() |
| 63 | + |
| 64 | + async def oauth2_client_credentials(self) -> tuple[bytes, bytes]: |
| 65 | + raise NotImplementedError() |
| 66 | + |
| 67 | + def tls_credentials(self) -> tuple[str, str | None]: |
| 68 | + raise NotImplementedError() |
| 69 | + |
| 70 | + |
| 71 | +class BasicAuthProvider(AuthProviderBase): |
| 72 | + """ |
| 73 | + AuthProvider providing basic auth with fixed `username`, `password`. |
| 74 | + """ |
28 | 75 |
|
29 | | - def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest: |
30 | | - if self._expire_at is None or self._expire_at < datetime.now(): |
31 | | - self._retrieve_token() |
| 76 | + def __init__(self, username: t.AnyStr, password: t.AnyStr): |
| 77 | + super().__init__() |
| 78 | + self.username: bytes = username.encode("latin1") if isinstance(username, str) else username |
| 79 | + self.password: bytes = password.encode("latin1") if isinstance(password, str) else password |
32 | 80 |
|
33 | | - assert self._access_token is not None |
| 81 | + def can_complete_http_basic(self) -> bool: |
| 82 | + return True |
34 | 83 |
|
35 | | - request.headers["Authorization"] = f"Bearer {self._access_token}" |
| 84 | + async def http_basic_credentials(self) -> tuple[bytes, bytes]: |
| 85 | + return self.username, self.password |
36 | 86 |
|
37 | | - # Call to untyped function "register_hook" in typed context |
38 | | - request.register_hook("response", self._handle401) # type: ignore[no-untyped-call] |
39 | 87 |
|
40 | | - return request |
| 88 | +class GlueAuthProvider(AuthProviderBase): |
| 89 | + """ |
| 90 | + AuthProvider allowing to be used with prepared credentials. |
| 91 | + """ |
41 | 92 |
|
42 | | - def _handle401( |
| 93 | + def __init__( |
43 | 94 | self, |
44 | | - response: requests.Response, |
45 | | - **kwargs: t.Any, |
46 | | - ) -> requests.Response: |
47 | | - if response.status_code != 401: |
48 | | - return response |
49 | | - |
50 | | - # If we get this far, probably the token is not valid anymore. |
51 | | - |
52 | | - # Try to reach for a new token once. |
53 | | - self._retrieve_token() |
54 | | - |
55 | | - assert self._access_token is not None |
56 | | - |
57 | | - # Consume content and release the original connection |
58 | | - # to allow our new request to reuse the same one. |
59 | | - response.content |
60 | | - response.close() |
61 | | - prepared_new_request = response.request.copy() |
62 | | - |
63 | | - prepared_new_request.headers["Authorization"] = f"Bearer {self._access_token}" |
64 | | - |
65 | | - # Avoid to enter into an infinity loop. |
66 | | - # Call to untyped function "deregister_hook" in typed context |
67 | | - prepared_new_request.deregister_hook( # type: ignore[no-untyped-call] |
68 | | - "response", self._handle401 |
69 | | - ) |
70 | | - |
71 | | - # "Response" has no attribute "connection" |
72 | | - new_response: requests.Response = response.connection.send(prepared_new_request, **kwargs) |
73 | | - new_response.history.append(response) |
74 | | - new_response.request = prepared_new_request |
75 | | - |
76 | | - return new_response |
77 | | - |
78 | | - def _retrieve_token(self) -> None: |
79 | | - data = { |
80 | | - "grant_type": "client_credentials", |
81 | | - } |
82 | | - |
83 | | - if self._scopes: |
84 | | - data["scope"] = " ".join(self._scopes) |
85 | | - |
86 | | - response: requests.Response = requests.post( |
87 | | - self._token_url, |
88 | | - data=data, |
89 | | - auth=self._token_server_auth, |
90 | | - verify=self._verify_ssl, |
91 | | - ) |
92 | | - |
93 | | - response.raise_for_status() |
94 | | - |
95 | | - token = response.json() |
96 | | - self._expire_at = datetime.now() + timedelta(seconds=token["expires_in"]) |
97 | | - self._access_token = token["access_token"] |
| 95 | + *, |
| 96 | + username: t.AnyStr | None = None, |
| 97 | + password: t.AnyStr | None = None, |
| 98 | + client_id: t.AnyStr | None = None, |
| 99 | + client_secret: t.AnyStr | None = None, |
| 100 | + cert: str | None = None, |
| 101 | + key: str | None = None, |
| 102 | + ): |
| 103 | + super().__init__() |
| 104 | + self.username: bytes | None = None |
| 105 | + self.password: bytes | None = None |
| 106 | + self.client_id: bytes | None = None |
| 107 | + self.client_secret: bytes | None = None |
| 108 | + self.cert: str | None = cert |
| 109 | + self.key: str | None = key |
| 110 | + |
| 111 | + if username is not None: |
| 112 | + assert password is not None |
| 113 | + self.username = username.encode("latin1") if isinstance(username, str) else username |
| 114 | + self.password = password.encode("latin1") if isinstance(password, str) else password |
| 115 | + if client_id is not None: |
| 116 | + assert client_secret is not None |
| 117 | + self.client_id = client_id.encode("latin1") if isinstance(client_id, str) else client_id |
| 118 | + self.client_secret = ( |
| 119 | + client_secret.encode("latin1") if isinstance(client_secret, str) else client_secret |
| 120 | + ) |
| 121 | + |
| 122 | + if cert is None and key is not None: |
| 123 | + raise RuntimeError("Key can only be used together with a cert.") |
| 124 | + |
| 125 | + def can_complete_http_basic(self) -> bool: |
| 126 | + return self.username is not None |
| 127 | + |
| 128 | + def can_complete_oauth2_client_credentials(self, scopes: list[str]) -> bool: |
| 129 | + return self.client_id is not None |
| 130 | + |
| 131 | + def can_complete_mutualTLS(self) -> bool: |
| 132 | + return self.cert is not None |
| 133 | + |
| 134 | + async def http_basic_credentials(self) -> tuple[bytes, bytes]: |
| 135 | + assert self.username is not None |
| 136 | + assert self.password is not None |
| 137 | + return self.username, self.password |
| 138 | + |
| 139 | + async def oauth2_client_credentials(self) -> tuple[bytes, bytes]: |
| 140 | + assert self.client_id is not None |
| 141 | + assert self.client_secret is not None |
| 142 | + return self.client_id, self.client_secret |
| 143 | + |
| 144 | + def tls_credentials(self) -> tuple[str, str | None]: |
| 145 | + assert self.cert is not None |
| 146 | + return (self.cert, self.key) |
0 commit comments