Skip to content

Commit 7d06507

Browse files
authored
Merge pull request #401 fSupport nebius jwt credentials rom UgnineSirdis/support-nebius-jwt
2 parents dc0b847 + e4cd3eb commit 7d06507

File tree

10 files changed

+446
-42
lines changed

10 files changed

+446
-42
lines changed

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@ grpcio>=1.42.0
22
packaging
33
protobuf>=3.13.0,<5.0.0
44
aiohttp<4
5+
pyjwt==2.8.0

test-requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,4 +46,5 @@ pylint-protobuf
4646
cython
4747
freezegun==1.2.2
4848
pytest-cov
49+
yandexcloud
4950
-e .

tests/aio/test_credentials.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import pytest
2+
import time
3+
import grpc
4+
import threading
5+
6+
import tests.auth.test_credentials
7+
import ydb.aio.iam
8+
9+
10+
class TestServiceAccountCredentials(ydb.aio.iam.ServiceAccountCredentials):
11+
def _channel_factory(self):
12+
return grpc.aio.insecure_channel(self._iam_endpoint)
13+
14+
def get_expire_time(self):
15+
return self._expires_in - time.time()
16+
17+
18+
class TestNebiusServiceAccountCredentials(ydb.aio.iam.NebiusServiceAccountCredentials):
19+
def get_expire_time(self):
20+
return self._expires_in - time.time()
21+
22+
23+
@pytest.mark.asyncio
24+
async def test_yandex_service_account_credentials():
25+
server = tests.auth.test_credentials.IamTokenServiceTestServer()
26+
credentials = TestServiceAccountCredentials(
27+
tests.auth.test_credentials.SERVICE_ACCOUNT_ID,
28+
tests.auth.test_credentials.ACCESS_KEY_ID,
29+
tests.auth.test_credentials.PRIVATE_KEY,
30+
server.get_endpoint(),
31+
)
32+
t = (await credentials.auth_metadata())[0][1]
33+
assert t == "test_token"
34+
assert credentials.get_expire_time() <= 42
35+
server.stop()
36+
37+
38+
@pytest.mark.asyncio
39+
async def test_nebius_service_account_credentials():
40+
server = tests.auth.test_credentials.NebiusTokenServiceForTest()
41+
42+
def serve(s):
43+
s.handle_request()
44+
45+
serve_thread = threading.Thread(target=serve, args=(server,))
46+
serve_thread.start()
47+
48+
credentials = TestNebiusServiceAccountCredentials(
49+
tests.auth.test_credentials.SERVICE_ACCOUNT_ID,
50+
tests.auth.test_credentials.ACCESS_KEY_ID,
51+
tests.auth.test_credentials.PRIVATE_KEY,
52+
server.endpoint(),
53+
)
54+
t = (await credentials.auth_metadata())[0][1]
55+
assert t == "test_nebius_token"
56+
assert credentials.get_expire_time() <= 42
57+
58+
serve_thread.join()

tests/auth/__init__.py

Whitespace-only changes.

tests/auth/test_credentials.py

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
import jwt
2+
import concurrent.futures
3+
import grpc
4+
import time
5+
import http.server
6+
import urllib
7+
import threading
8+
import json
9+
10+
import ydb.iam
11+
12+
from yandex.cloud.iam.v1 import iam_token_service_pb2_grpc
13+
from yandex.cloud.iam.v1 import iam_token_service_pb2
14+
15+
SERVICE_ACCOUNT_ID = "sa_id"
16+
ACCESS_KEY_ID = "key_id"
17+
PRIVATE_KEY = "-----BEGIN PRIVATE KEY-----\nMIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQC75/JS3rMcLJxv\nFgpOzF5+2gH+Yig3RE2MTl9uwC0BZKAv6foYr7xywQyWIK+W1cBhz8R4LfFmZo2j\nM0aCvdRmNBdW0EDSTnHLxCsFhoQWLVq+bI5f5jzkcoiioUtaEpADPqwgVULVtN/n\nnPJiZ6/dU30C3jmR6+LUgEntUtWt3eq3xQIn5lG3zC1klBY/HxtfH5Hu8xBvwRQT\nJnh3UpPLj8XwSmriDgdrhR7o6umWyVuGrMKlLHmeivlfzjYtfzO1MOIMG8t2/zxG\nR+xb4Vwks73sH1KruH/0/JMXU97npwpe+Um+uXhpldPygGErEia7abyZB2gMpXqr\nWYKMo02NAgMBAAECggEAO0BpC5OYw/4XN/optu4/r91bupTGHKNHlsIR2rDzoBhU\nYLd1evpTQJY6O07EP5pYZx9mUwUdtU4KRJeDGO/1/WJYp7HUdtxwirHpZP0lQn77\nuccuX/QQaHLrPekBgz4ONk+5ZBqukAfQgM7fKYOLk41jgpeDbM2Ggb6QUSsJISEp\nzrwpI/nNT/wn+Hvx4DxrzWU6wF+P8kl77UwPYlTA7GsT+T7eKGVH8xsxmK8pt6lg\nsvlBA5XosWBWUCGLgcBkAY5e4ZWbkdd183o+oMo78id6C+PQPE66PLDtHWfpRRmN\nm6XC03x6NVhnfvfozoWnmS4+e4qj4F/emCHvn0GMywKBgQDLXlj7YPFVXxZpUvg/\nrheVcCTGbNmQJ+4cZXx87huqwqKgkmtOyeWsRc7zYInYgraDrtCuDBCfP//ZzOh0\nLxepYLTPk5eNn/GT+VVrqsy35Ccr60g7Lp/bzb1WxyhcLbo0KX7/6jl0lP+VKtdv\nmto+4mbSBXSM1Y5BVVoVgJ3T/wKBgQDsiSvPRzVi5TTj13x67PFymTMx3HCe2WzH\nJUyepCmVhTm482zW95pv6raDr5CTO6OYpHtc5sTTRhVYEZoEYFTM9Vw8faBtluWG\nBjkRh4cIpoIARMn74YZKj0C/0vdX7SHdyBOU3bgRPHg08Hwu3xReqT1kEPSI/B2V\n4pe5fVrucwKBgQCNFgUxUA3dJjyMES18MDDYUZaRug4tfiYouRdmLGIxUxozv6CG\nZnbZzwxFt+GpvPUV4f+P33rgoCvFU+yoPctyjE6j+0aW0DFucPmb2kBwCu5J/856\nkFwCx3blbwFHAco+SdN7g2kcwgmV2MTg/lMOcU7XwUUcN0Obe7UlWbckzQKBgQDQ\nnXaXHL24GGFaZe4y2JFmujmNy1dEsoye44W9ERpf9h1fwsoGmmCKPp90az5+rIXw\nFXl8CUgk8lXW08db/r4r+ma8Lyx0GzcZyplAnaB5/6j+pazjSxfO4KOBy4Y89Tb+\nTP0AOcCi6ws13bgY+sUTa/5qKA4UVw+c5zlb7nRpgwKBgGXAXhenFw1666482iiN\ncHSgwc4ZHa1oL6aNJR1XWH+aboBSwR+feKHUPeT4jHgzRGo/aCNHD2FE5I8eBv33\nof1kWYjAO0YdzeKrW0rTwfvt9gGg+CS397aWu4cy+mTI+MNfBgeDAIVBeJOJXLlX\nhL8bFAuNNVrCOp79TNnNIsh7\n-----END PRIVATE KEY-----\n" # noqa: E501
18+
PUBLIC_KEY = "-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAu+fyUt6zHCycbxYKTsxe\nftoB/mIoN0RNjE5fbsAtAWSgL+n6GK+8csEMliCvltXAYc/EeC3xZmaNozNGgr3U\nZjQXVtBA0k5xy8QrBYaEFi1avmyOX+Y85HKIoqFLWhKQAz6sIFVC1bTf55zyYmev\n3VN9At45kevi1IBJ7VLVrd3qt8UCJ+ZRt8wtZJQWPx8bXx+R7vMQb8EUEyZ4d1KT\ny4/F8Epq4g4Ha4Ue6OrplslbhqzCpSx5nor5X842LX8ztTDiDBvLdv88RkfsW+Fc\nJLO97B9Sq7h/9PyTF1Pe56cKXvlJvrl4aZXT8oBhKxImu2m8mQdoDKV6q1mCjKNN\njQIDAQAB\n-----END PUBLIC KEY-----\n" # noqa: E501
19+
20+
21+
def test_metadata_credentials():
22+
credentials = ydb.iam.MetadataUrlCredentials()
23+
raised = False
24+
try:
25+
credentials.auth_metadata()
26+
except Exception:
27+
raised = True
28+
29+
assert raised
30+
31+
32+
class IamTokenServiceForTest(iam_token_service_pb2_grpc.IamTokenServiceServicer):
33+
def Create(self, request, context):
34+
print("IAM token service request: {}".format(request))
35+
# Validate jwt:
36+
decoded = jwt.decode(
37+
request.jwt, key=PUBLIC_KEY, algorithms=["PS256"], audience="https://iam.api.cloud.yandex.net/iam/v1/tokens"
38+
)
39+
assert decoded["iss"] == SERVICE_ACCOUNT_ID
40+
assert decoded["aud"] == "https://iam.api.cloud.yandex.net/iam/v1/tokens"
41+
assert abs(decoded["iat"] - time.time()) <= 60
42+
assert abs(decoded["exp"] - time.time()) <= 3600
43+
44+
response = iam_token_service_pb2.CreateIamTokenResponse(iam_token="test_token")
45+
response.expires_at.seconds = int(time.time() + 42)
46+
return response
47+
48+
49+
class IamTokenServiceTestServer(object):
50+
def __init__(self):
51+
self.server = grpc.server(concurrent.futures.ThreadPoolExecutor(max_workers=2))
52+
iam_token_service_pb2_grpc.add_IamTokenServiceServicer_to_server(IamTokenServiceForTest(), self.server)
53+
self.server.add_insecure_port(self.get_endpoint())
54+
self.server.start()
55+
56+
def stop(self):
57+
self.server.stop(1)
58+
self.server.wait_for_termination()
59+
60+
def get_endpoint(self):
61+
return "localhost:54321"
62+
63+
64+
class TestServiceAccountCredentials(ydb.iam.ServiceAccountCredentials):
65+
def _channel_factory(self):
66+
return grpc.insecure_channel(self._iam_endpoint)
67+
68+
def get_expire_time(self):
69+
return self._expires_in - time.time()
70+
71+
72+
class TestNebiusServiceAccountCredentials(ydb.iam.NebiusServiceAccountCredentials):
73+
def get_expire_time(self):
74+
return self._expires_in - time.time()
75+
76+
77+
class NebiusTokenServiceHandler(http.server.BaseHTTPRequestHandler):
78+
def do_POST(self):
79+
assert self.headers["Content-Type"] == "application/x-www-form-urlencoded"
80+
assert self.path == "/token/exchange"
81+
content_length = int(self.headers["Content-Length"])
82+
post_data = self.rfile.read(content_length).decode("utf8")
83+
print("NebiusTokenServiceHandler.POST data: {}".format(post_data))
84+
parsed_request = urllib.parse.parse_qs(str(post_data))
85+
assert len(parsed_request["grant_type"]) == 1
86+
assert parsed_request["grant_type"][0] == "urn:ietf:params:oauth:grant-type:token-exchange"
87+
88+
assert len(parsed_request["requested_token_type"]) == 1
89+
assert parsed_request["requested_token_type"][0] == "urn:ietf:params:oauth:token-type:access_token"
90+
91+
assert len(parsed_request["subject_token_type"]) == 1
92+
assert parsed_request["subject_token_type"][0] == "urn:ietf:params:oauth:token-type:jwt"
93+
94+
assert len(parsed_request["subject_token"]) == 1
95+
jwt_token = parsed_request["subject_token"][0]
96+
decoded = jwt.decode(
97+
jwt_token, key=PUBLIC_KEY, algorithms=["RS256"], audience="token-service.iam.new.nebiuscloud.net"
98+
)
99+
assert decoded["iss"] == SERVICE_ACCOUNT_ID
100+
assert decoded["sub"] == SERVICE_ACCOUNT_ID
101+
assert decoded["aud"] == "token-service.iam.new.nebiuscloud.net"
102+
assert abs(decoded["iat"] - time.time()) <= 60
103+
assert abs(decoded["exp"] - time.time()) <= 3600
104+
105+
response = {
106+
"access_token": "test_nebius_token",
107+
"issued_token_type": "urn:ietf:params:oauth:token-type:access_token",
108+
"token_type": "Bearer",
109+
"expires_in": 42,
110+
}
111+
112+
self.send_response(200)
113+
self.send_header("Content-type", "application/json")
114+
self.end_headers()
115+
self.wfile.write(json.dumps(response).encode("utf8"))
116+
117+
118+
class NebiusTokenServiceForTest(http.server.HTTPServer):
119+
def __init__(self):
120+
http.server.HTTPServer.__init__(self, ("localhost", 54322), NebiusTokenServiceHandler)
121+
122+
def endpoint(self):
123+
return "http://localhost:54322/token/exchange"
124+
125+
126+
def test_yandex_service_account_credentials():
127+
server = IamTokenServiceTestServer()
128+
credentials = TestServiceAccountCredentials(SERVICE_ACCOUNT_ID, ACCESS_KEY_ID, PRIVATE_KEY, server.get_endpoint())
129+
t = credentials.get_auth_token()
130+
assert t == "test_token"
131+
assert credentials.get_expire_time() <= 42
132+
server.stop()
133+
134+
135+
def test_nebius_service_account_credentials():
136+
server = NebiusTokenServiceForTest()
137+
138+
def serve(s):
139+
s.handle_request()
140+
141+
serve_thread = threading.Thread(target=serve, args=(server,))
142+
serve_thread.start()
143+
144+
credentials = TestNebiusServiceAccountCredentials(SERVICE_ACCOUNT_ID, ACCESS_KEY_ID, PRIVATE_KEY, server.endpoint())
145+
t = credentials.get_auth_token()
146+
assert t == "test_nebius_token"
147+
assert credentials.get_expire_time() <= 42
148+
149+
serve_thread.join()

tests/table/test_tx.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -38,17 +38,6 @@ def test_tx_begin(driver_sync, database):
3838
tx.rollback()
3939

4040

41-
def test_credentials():
42-
credentials = ydb.iam.MetadataUrlCredentials()
43-
raised = False
44-
try:
45-
credentials.auth_metadata()
46-
except Exception:
47-
raised = True
48-
49-
assert raised
50-
51-
5241
def test_tx_snapshot_ro(driver_sync, database):
5342
session = driver_sync.table_client.session().create()
5443
description = (

ydb/aio/iam.py

Lines changed: 99 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,19 @@
55
import logging
66
from ydb.iam import auth
77
from .credentials import AbstractExpiringTokenCredentials
8+
from ydb import issues
89

910
logger = logging.getLogger(__name__)
1011

1112
try:
12-
from yandex.cloud.iam.v1 import iam_token_service_pb2_grpc
13-
from yandex.cloud.iam.v1 import iam_token_service_pb2
1413
import jwt
1514
except ImportError:
1615
jwt = None
16+
17+
try:
18+
from yandex.cloud.iam.v1 import iam_token_service_pb2_grpc
19+
from yandex.cloud.iam.v1 import iam_token_service_pb2
20+
except ImportError:
1721
iam_token_service_pb2_grpc = None
1822
iam_token_service_pb2 = None
1923

@@ -55,6 +59,51 @@ async def _make_token_request(self):
5559
IamTokenCredentials = TokenServiceCredentials
5660

5761

62+
class OAuth2JwtTokenExchangeCredentials(AbstractExpiringTokenCredentials, auth.BaseJWTCredentials):
63+
def __init__(
64+
self,
65+
token_exchange_url,
66+
account_id,
67+
access_key_id,
68+
private_key,
69+
algorithm,
70+
token_service_url,
71+
subject=None,
72+
):
73+
super(OAuth2JwtTokenExchangeCredentials, self).__init__()
74+
auth.BaseJWTCredentials.__init__(
75+
self, account_id, access_key_id, private_key, algorithm, token_service_url, subject
76+
)
77+
assert aiohttp is not None, "Install aiohttp library to use OAuth 2.0 token exchange credentials provider"
78+
self._token_exchange_url = token_exchange_url
79+
80+
async def _make_token_request(self):
81+
params = {
82+
"grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
83+
"requested_token_type": "urn:ietf:params:oauth:token-type:access_token",
84+
"subject_token": self._get_jwt(),
85+
"subject_token_type": "urn:ietf:params:oauth:token-type:jwt",
86+
}
87+
headers = {"Content-Type": "application/x-www-form-urlencoded"}
88+
89+
timeout = aiohttp.ClientTimeout(total=2)
90+
async with aiohttp.ClientSession(timeout=timeout) as session:
91+
async with session.post(self._token_exchange_url, data=params, headers=headers) as response:
92+
if response.status == 403:
93+
raise issues.Unauthenticated(await response.text())
94+
if response.status >= 500:
95+
raise issues.Unavailable(await response.text())
96+
if response.status >= 400:
97+
raise issues.BadRequest(await response.text())
98+
if response.status != 200:
99+
raise issues.Error(await response.text())
100+
101+
response_json = await response.json()
102+
access_token = response_json["access_token"]
103+
expires_in = response_json["expires_in"]
104+
return {"access_token": access_token, "expires_in": expires_in}
105+
106+
58107
class JWTIamCredentials(TokenServiceCredentials, auth.BaseJWTCredentials):
59108
def __init__(
60109
self,
@@ -65,16 +114,39 @@ def __init__(
65114
iam_channel_credentials=None,
66115
):
67116
TokenServiceCredentials.__init__(self, iam_endpoint, iam_channel_credentials)
68-
auth.BaseJWTCredentials.__init__(self, account_id, access_key_id, private_key)
117+
auth.BaseJWTCredentials.__init__(
118+
self,
119+
account_id,
120+
access_key_id,
121+
private_key,
122+
auth.YANDEX_CLOUD_JWT_ALGORITHM,
123+
auth.YANDEX_CLOUD_IAM_TOKEN_SERVICE_URL,
124+
)
69125

70126
def _get_token_request(self):
71-
return iam_token_service_pb2.CreateIamTokenRequest(
72-
jwt=auth.get_jwt(
73-
self._account_id,
74-
self._access_key_id,
75-
self._private_key,
76-
self._jwt_expiration_timeout,
77-
)
127+
return iam_token_service_pb2.CreateIamTokenRequest(jwt=self._get_jwt())
128+
129+
130+
class NebiusJWTIamCredentials(OAuth2JwtTokenExchangeCredentials):
131+
def __init__(
132+
self,
133+
account_id,
134+
access_key_id,
135+
private_key,
136+
token_exchange_url=None,
137+
):
138+
url = token_exchange_url
139+
if url is None:
140+
url = auth.NEBIUS_CLOUD_IAM_TOKEN_EXCHANGE_URL
141+
OAuth2JwtTokenExchangeCredentials.__init__(
142+
self,
143+
url,
144+
account_id,
145+
access_key_id,
146+
private_key,
147+
auth.NEBIUS_CLOUD_JWT_ALGORITHM,
148+
auth.NEBIUS_CLOUD_IAM_TOKEN_SERVICE_AUDIENCE,
149+
account_id,
78150
)
79151

80152

@@ -130,3 +202,20 @@ def __init__(
130202
iam_endpoint,
131203
iam_channel_credentials,
132204
)
205+
206+
207+
class NebiusServiceAccountCredentials(NebiusJWTIamCredentials):
208+
def __init__(
209+
self,
210+
service_account_id,
211+
access_key_id,
212+
private_key,
213+
iam_endpoint=None,
214+
iam_channel_credentials=None,
215+
):
216+
super(NebiusServiceAccountCredentials, self).__init__(
217+
service_account_id,
218+
access_key_id,
219+
private_key,
220+
iam_endpoint,
221+
)

ydb/driver.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,13 @@ def credentials_from_env_variables(tracer=None):
3838

3939
return ydb.iam.ServiceAccountCredentials.from_file(service_account_key_file)
4040

41+
nebius_service_account_key_file = os.getenv("YDB_NEBIUS_SERVICE_ACCOUNT_KEY_FILE_CREDENTIALS")
42+
if nebius_service_account_key_file is not None:
43+
ctx.trace({"credentials.nebius_service_account_key_file": True})
44+
import ydb.iam
45+
46+
return ydb.iam.NebiusServiceAccountCredentials.from_file(nebius_service_account_key_file)
47+
4148
anonymous_credetials = os.getenv("YDB_ANONYMOUS_CREDENTIALS", "0") == "1"
4249
if anonymous_credetials:
4350
ctx.trace({"credentials.anonymous": True})

ydb/iam/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
# -*- coding: utf-8 -*-
22
from .auth import ServiceAccountCredentials # noqa
3+
from .auth import NebiusServiceAccountCredentials # noqa
34
from .auth import MetadataUrlCredentials # noqa

0 commit comments

Comments
 (0)