|
19 | 19 | import json |
20 | 20 | import logging |
21 | 21 | import mimetypes |
| 22 | +import random |
| 23 | +import time |
22 | 24 | from typing import Any, Iterator, Optional, Union |
23 | 25 | from urllib.parse import urlencode |
24 | 26 |
|
| 27 | +from google import genai |
| 28 | +from google.cloud import iam_credentials_v1 |
25 | 29 | from google.genai import _api_module |
26 | 30 | from google.genai import _common |
| 31 | +from google.genai import types as genai_types |
27 | 32 | from google.genai._common import get_value_by_path as getv |
28 | 33 | from google.genai._common import set_value_by_path as setv |
29 | 34 | from google.genai.pagers import Pager |
@@ -704,6 +709,52 @@ def delete( |
704 | 709 | """ |
705 | 710 | return self._delete(name=name, config=config) |
706 | 711 |
|
| 712 | + def generate_access_token( |
| 713 | + self, |
| 714 | + service_account_email: str, |
| 715 | + sandbox_id: str, |
| 716 | + port: str = "8080", |
| 717 | + timeout: int = 3600, |
| 718 | + ) -> str: |
| 719 | + """Signs a JWT with a Google Cloud service account.""" |
| 720 | + client = iam_credentials_v1.IAMCredentialsClient() |
| 721 | + name = f"projects/-/serviceAccounts/{service_account_email}" |
| 722 | + custom_claims = {"port": port, "sandbox_id": sandbox_id} |
| 723 | + payload = { |
| 724 | + "iat": int(time.time()), |
| 725 | + "exp": int(time.time()) + timeout, |
| 726 | + "iss": service_account_email, |
| 727 | + "nonce": random.randint(1, 1000000000), |
| 728 | + "aud": "vmaas-proxy-api", # default audience for sandbox proxy |
| 729 | + **custom_claims, |
| 730 | + } |
| 731 | + request = iam_credentials_v1.SignJwtRequest( |
| 732 | + name=name, |
| 733 | + payload=json.dumps(payload), |
| 734 | + ) |
| 735 | + response = client.sign_jwt(request=request) |
| 736 | + return response.signed_jwt |
| 737 | + |
| 738 | + def send_command( |
| 739 | + self, |
| 740 | + http_method: str, |
| 741 | + path: str, |
| 742 | + query_params: Any, |
| 743 | + access_token: str, |
| 744 | + headers: dict[str, str], |
| 745 | + request_dict: dict[str, object], |
| 746 | + sandbox_environment: Optional[types.SandboxEnvironment] = None, |
| 747 | + ) -> str | None: |
| 748 | + """Sends a command to the sandbox.""" |
| 749 | + # TODO(tenghuil): Get connection info from sandbox environment when it's ready. |
| 750 | + endpoint = "https://test-us-central1.autopush-sandbox.vertexai.goog/" + path |
| 751 | + |
| 752 | + headers = {"Authorization": f"Bearer {access_token}"} |
| 753 | + http_options = genai_types.HttpOptions(headers=headers, base_url=endpoint) |
| 754 | + http_client = genai.Client(vertexai=True, http_options=http_options) |
| 755 | + response = http_client._api_client.request(http_method, path, request_dict) |
| 756 | + return response |
| 757 | + |
707 | 758 |
|
708 | 759 | class AsyncSandboxes(_api_module.BaseModule): |
709 | 760 |
|
|
0 commit comments