diff --git a/pyproject.toml b/pyproject.toml index d7c14d73a..77d58b7cb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,7 @@ dependencies = [ "tenacity>=8.0.1", "watchfiles>=0.19.0,<0.20", "truss-transfer>=0.0.37,<0.0.40", + "gql-query-builder (>=0.1.7,<0.2.0)", ] [project.urls] @@ -102,6 +103,9 @@ default-groups = [ "dev-server", ] +[tool.uv.extra-build-dependencies] +gql-query-builder = ["pip"] + # Simplified Hatchling configuration - let it auto-discover truss, manually include others [tool.hatch.build.targets.sdist] include = [ diff --git a/truss/remote/baseten/api.py b/truss/remote/baseten/api.py index 7af664d81..36fed8ac1 100644 --- a/truss/remote/baseten/api.py +++ b/truss/remote/baseten/api.py @@ -3,6 +3,7 @@ from typing import Any, Dict, List, Mapping, Optional import requests +from gql_query_builder import GqlQuery from pydantic import BaseModel, Field from truss.base.custom_types import SafeModel @@ -328,52 +329,34 @@ def deploy_chain_atomic( ): if allow_truss_download is None: allow_truss_download = True - entrypoint_str = _chainlet_data_atomic_to_graphql_mutation(entrypoint) - dependencies_str = ", ".join( - [ + mutation_params = { + "chain_id": chain_id, + "chain_name": chain_name, + "environment": environment, + "original_source_artifact_s3_key": original_source_artifact_s3_key, + "allow_truss_download": "false" if allow_truss_download is False else None, + "is_draft": is_draft, + "entrypoint": _chainlet_data_atomic_to_graphql_mutation(entrypoint), + "dependencies": [ _chainlet_data_atomic_to_graphql_mutation(dependency) for dependency in dependencies - ] - ) - - params = [] - if chain_id: - params.append(f'chain_id: "{chain_id}"') - if chain_name: - params.append(f'chain_name: "{chain_name}"') - if environment: - params.append(f'environment: "{environment}"') - if original_source_artifact_s3_key: - params.append( - f'original_source_artifact_s3_key: "{original_source_artifact_s3_key}"' - ) - - params.append(f"is_draft: {str(is_draft).lower()}") - if allow_truss_download is False: - params.append("allow_truss_download: false") - - params_str = PARAMS_INDENT.join(params) + ], + "truss_user_env": "$trussUserEnv", + } + mutation_params = { + str(k): v for k, v in mutation_params.items() if v is not None + } - query_string = f""" - mutation ($trussUserEnv: String) {{ - deploy_chain_atomic( - {params_str} - entrypoint: {entrypoint_str} - dependencies: [{dependencies_str}] - truss_user_env: $trussUserEnv - ) {{ - chain_deployment {{ - id - chain {{ - id - hostname - }} - }} - }} - }} - """ + gql = GqlQuery() + gql.operation( + "mutation", + "deploy_chain_atomic", + mutation_params, + ["chain_deployment { id chain { id hostname } }"], + ) + query_string = gql.generate() resp = self._post_graphql_query( query_string, variables={"trussUserEnv": truss_user_env.json()} ) diff --git a/truss/tests/remote/baseten/test_api.py b/truss/tests/remote/baseten/test_api.py index 28cbf6326..798d20415 100644 --- a/truss/tests/remote/baseten/test_api.py +++ b/truss/tests/remote/baseten/test_api.py @@ -353,8 +353,8 @@ def test_deploy_chain_deployment(mock_post, baseten_api): gql_mutation = mock_post.call_args[1]["json"]["query"] - assert 'environment: "production"' in gql_mutation - assert 'chain_id: "chain_id"' in gql_mutation + assert "environment: production" in gql_mutation + assert "chain_id: chain_id" in gql_mutation assert "dependencies:" in gql_mutation assert "entrypoint:" in gql_mutation @@ -378,8 +378,8 @@ def test_deploy_chain_deployment_with_gitinfo(mock_post, baseten_api): gql_mutation = mock_post.call_args[1]["json"]["query"] - assert 'environment: "production"' in gql_mutation - assert 'chain_id: "chain_id"' in gql_mutation + assert "environment: production" in gql_mutation + assert "chain_id: chain_id" in gql_mutation assert "dependencies:" in gql_mutation assert "entrypoint:" in gql_mutation @@ -402,12 +402,70 @@ def test_deploy_chain_deployment_no_environment(mock_post, baseten_api): gql_mutation = mock_post.call_args[1]["json"]["query"] - assert 'chain_id: "chain_id"' in gql_mutation + assert "chain_id: chain_id" in gql_mutation assert "environment" not in gql_mutation assert "dependencies:" in gql_mutation assert "entrypoint:" in gql_mutation +@mock.patch("requests.post", return_value=mock_deploy_chain_deployment_response()) +def test_deploy_chain_deployment_with_dependencies(mock_post, baseten_api): + dependencies = [ + ChainletDataAtomic( + name="dependency-1", + oracle=OracleData( + model_name="dep-model-1", + s3_key="dep-s3-key-1", + encoded_config_str="dep-encoded-config-str-1", + ), + ), + ChainletDataAtomic( + name="dependency-2", + oracle=OracleData( + model_name="dep-model-2", + s3_key="dep-s3-key-2", + encoded_config_str="dep-encoded-config-str-2", + ), + ), + ] + + baseten_api.deploy_chain_atomic( + environment="production", + chain_id="chain_id", + dependencies=dependencies, + entrypoint=ChainletDataAtomic( + name="chainlet-1", + oracle=OracleData( + model_name="model-1", + s3_key="s3-key-1", + encoded_config_str="encoded-config-str-1", + ), + ), + truss_user_env=b10_types.TrussUserEnv.collect(), + ) + + gql_mutation = mock_post.call_args[1]["json"]["query"] + + # Single regex to check all assertions + import re + + pattern = ( + r"(?=.*environment: production)" + r"(?=.*chain_id: chain_id)" + r"(?=.*dependencies:)" + r"(?=.*entrypoint:)" + r'(?=.*name: "dependency-1")' + r'(?=.*name: "dependency-2")' + r'(?=.*model_name: "dep-model-1")' + r'(?=.*model_name: "dep-model-2")' + r'(?=.*s3_key: "dep-s3-key-1")' + r'(?=.*s3_key: "dep-s3-key-2")' + ) + assert re.search(pattern, gql_mutation), ( + f"GraphQL mutation does not contain all expected elements: {gql_mutation}" + ) + + @mock.patch("requests.post", return_value=mock_upsert_training_project_response()) def test_upsert_training_project(mock_post, baseten_api): baseten_api.upsert_training_project( diff --git a/uv.lock b/uv.lock index 2b411d818..e017788b9 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.9, <3.14" resolution-markers = [ "python_full_version >= '3.13'", @@ -965,6 +965,12 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/86/f1/62a193f0227cf15a920390abe675f386dec35f7ae3ffe6da582d3ade42c7/googleapis_common_protos-1.70.0-py3-none-any.whl", hash = "sha256:b8bfcca8c25a2bb253e0e0b0adaf8c00773e5e6af6fd92397576680b807e0fd8", size = 294530, upload-time = "2025-04-14T10:17:01.271Z" }, ] +[[package]] +name = "gql-query-builder" +version = "0.1.7" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/62/bd/68d07678108ae2038dfd98ed4fdfcde833a2c015af27ef23c2a30020b406/gql-query-builder-0.1.7.tar.gz", hash = "sha256:99fd8e3f883b75fded271ab7957b7da6d6ec197679f0da7ff3d7d74c34b12842", size = 6711, upload-time = "2021-12-07T01:51:40.782Z" } + [[package]] name = "grpcio" version = "1.74.0" @@ -3321,6 +3327,7 @@ dependencies = [ { name = "click", version = "8.1.8", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, { name = "click", version = "8.2.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, { name = "google-cloud-storage" }, + { name = "gql-query-builder" }, { name = "httpx" }, { name = "httpx-ws" }, { name = "huggingface-hub" }, @@ -3393,6 +3400,7 @@ requires-dist = [ { name = "boto3", specifier = ">=1.34.85,<2" }, { name = "click", specifier = ">=8.0.3,<9" }, { name = "google-cloud-storage", specifier = ">=2.10.0" }, + { name = "gql-query-builder", specifier = ">=0.1.7,<0.2.0" }, { name = "httpx", specifier = ">=0.24.1" }, { name = "httpx-ws", specifier = ">=0.7.1,<0.8" }, { name = "huggingface-hub", specifier = ">=0.25.0" },