From af892038c09e75b376f99d2c4ef2b26be28686d4 Mon Sep 17 00:00:00 2001 From: Narek Amirbekian Date: Thu, 23 Oct 2025 16:47:06 -0700 Subject: [PATCH 1/2] Use graphql query builder --- pyproject.toml | 1 + truss/remote/baseten/api.py | 57 +++++++++++---------- truss/tests/remote/baseten/test_api.py | 68 ++++++++++++++++++++++++-- 3 files changed, 94 insertions(+), 32 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d7c14d73a..8ad1073ea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,7 @@ dependencies = [ "tenacity>=8.0.1", "watchfiles>=0.19.0,<0.20", "truss-transfer>=0.0.37,<0.0.40", + "gql-query-builder (>=0.1.7,<0.2.0)", ] [project.urls] diff --git a/truss/remote/baseten/api.py b/truss/remote/baseten/api.py index 3c7afe2a2..ba2ee4d40 100644 --- a/truss/remote/baseten/api.py +++ b/truss/remote/baseten/api.py @@ -1,8 +1,10 @@ +import json import logging from enum import Enum from typing import Any, Dict, List, Mapping, Optional import requests +from gql_query_builder import GqlQuery from pydantic import BaseModel, Field from truss.remote.baseten import custom_types as b10_types @@ -299,37 +301,38 @@ def deploy_chain_atomic( chain_name: Optional[str] = None, environment: Optional[str] = None, is_draft: bool = False, + original_source_artifact_s3_key: Optional[str] = None, + allow_truss_download: Optional[bool] = True, ): - entrypoint_str = _chainlet_data_atomic_to_graphql_mutation(entrypoint) - - dependencies_str = ", ".join( - [ + if allow_truss_download is None: + allow_truss_download = True + + mutation_params = { + "chain_id": chain_id, + "chain_name": chain_name, + "environment": environment, + "original_source_artifact_s3_key": original_source_artifact_s3_key, + "allow_truss_download": "false" if allow_truss_download is False else None, + "is_draft": is_draft, + "entrypoint": _chainlet_data_atomic_to_graphql_mutation(entrypoint), + "dependencies": [ _chainlet_data_atomic_to_graphql_mutation(dependency) for dependency in dependencies - ] - ) + ], + "truss_user_env": "$trussUserEnv", + } + mutation_params = { + str(k): v for k, v in mutation_params.items() if v is not None + } - query_string = f""" - mutation ($trussUserEnv: String) {{ - deploy_chain_atomic( - {f'chain_id: "{chain_id}"' if chain_id else ""} - {f'chain_name: "{chain_name}"' if chain_name else ""} - {f'environment: "{environment}"' if environment else ""} - is_draft: {str(is_draft).lower()} - entrypoint: {entrypoint_str} - dependencies: [{dependencies_str}] - truss_user_env: $trussUserEnv - ) {{ - chain_deployment {{ - id - chain {{ - id - hostname - }} - }} - }} - }} - """ + gql = GqlQuery() + gql.operation( + "mutation", + "deploy_chain_atomic", + mutation_params, + ["chain_deployment { id chain { id hostname } }"], + ) + query_string = gql.generate() resp = self._post_graphql_query( query_string, variables={"trussUserEnv": truss_user_env.json()} diff --git a/truss/tests/remote/baseten/test_api.py b/truss/tests/remote/baseten/test_api.py index 28cbf6326..798d20415 100644 --- a/truss/tests/remote/baseten/test_api.py +++ b/truss/tests/remote/baseten/test_api.py @@ -353,8 +353,8 @@ def test_deploy_chain_deployment(mock_post, baseten_api): gql_mutation = mock_post.call_args[1]["json"]["query"] - assert 'environment: "production"' in gql_mutation - assert 'chain_id: "chain_id"' in gql_mutation + assert "environment: production" in gql_mutation + assert "chain_id: chain_id" in gql_mutation assert "dependencies:" in gql_mutation assert "entrypoint:" in gql_mutation @@ -378,8 +378,8 @@ def test_deploy_chain_deployment_with_gitinfo(mock_post, baseten_api): gql_mutation = mock_post.call_args[1]["json"]["query"] - assert 'environment: "production"' in gql_mutation - assert 'chain_id: "chain_id"' in gql_mutation + assert "environment: production" in gql_mutation + assert "chain_id: chain_id" in gql_mutation assert "dependencies:" in gql_mutation assert "entrypoint:" in gql_mutation @@ -402,12 +402,70 @@ def test_deploy_chain_deployment_no_environment(mock_post, baseten_api): gql_mutation = mock_post.call_args[1]["json"]["query"] - assert 'chain_id: "chain_id"' in gql_mutation + assert "chain_id: chain_id" in gql_mutation assert "environment" not in gql_mutation assert "dependencies:" in gql_mutation assert "entrypoint:" in gql_mutation +@mock.patch("requests.post", return_value=mock_deploy_chain_deployment_response()) +def test_deploy_chain_deployment_with_dependencies(mock_post, baseten_api): + dependencies = [ + ChainletDataAtomic( + name="dependency-1", + oracle=OracleData( + model_name="dep-model-1", + s3_key="dep-s3-key-1", + encoded_config_str="dep-encoded-config-str-1", + ), + ), + ChainletDataAtomic( + name="dependency-2", + oracle=OracleData( + model_name="dep-model-2", + s3_key="dep-s3-key-2", + encoded_config_str="dep-encoded-config-str-2", + ), + ), + ] + + baseten_api.deploy_chain_atomic( + environment="production", + chain_id="chain_id", + dependencies=dependencies, + entrypoint=ChainletDataAtomic( + name="chainlet-1", + oracle=OracleData( + model_name="model-1", + s3_key="s3-key-1", + encoded_config_str="encoded-config-str-1", + ), + ), + truss_user_env=b10_types.TrussUserEnv.collect(), + ) + + gql_mutation = mock_post.call_args[1]["json"]["query"] + + # Single regex to check all assertions + import re + + pattern = ( + r"(?=.*environment: production)" + r"(?=.*chain_id: chain_id)" + r"(?=.*dependencies:)" + r"(?=.*entrypoint:)" + r'(?=.*name: "dependency-1")' + r'(?=.*name: "dependency-2")' + r'(?=.*model_name: "dep-model-1")' + r'(?=.*model_name: "dep-model-2")' + r'(?=.*s3_key: "dep-s3-key-1")' + r'(?=.*s3_key: "dep-s3-key-2")' + ) + assert re.search(pattern, gql_mutation), ( + f"GraphQL mutation does not contain all expected elements: {gql_mutation}" + ) + + @mock.patch("requests.post", return_value=mock_upsert_training_project_response()) def test_upsert_training_project(mock_post, baseten_api): baseten_api.upsert_training_project( From d17b9081da5adac7dc278a76f9521693f7ce1b35 Mon Sep 17 00:00:00 2001 From: Narek Amirbekian Date: Thu, 23 Oct 2025 17:22:36 -0700 Subject: [PATCH 2/2] Remove json --- pyproject.toml | 3 +++ truss/remote/baseten/api.py | 1 - uv.lock | 10 +++++++++- 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8ad1073ea..77d58b7cb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -103,6 +103,9 @@ default-groups = [ "dev-server", ] +[tool.uv.extra-build-dependencies] +gql-query-builder = ["pip"] + # Simplified Hatchling configuration - let it auto-discover truss, manually include others [tool.hatch.build.targets.sdist] include = [ diff --git a/truss/remote/baseten/api.py b/truss/remote/baseten/api.py index ba2ee4d40..305d8e949 100644 --- a/truss/remote/baseten/api.py +++ b/truss/remote/baseten/api.py @@ -1,4 +1,3 @@ -import json import logging from enum import Enum from typing import Any, Dict, List, Mapping, Optional diff --git a/uv.lock b/uv.lock index 2b411d818..e017788b9 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.9, <3.14" resolution-markers = [ "python_full_version >= '3.13'", @@ -965,6 +965,12 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/86/f1/62a193f0227cf15a920390abe675f386dec35f7ae3ffe6da582d3ade42c7/googleapis_common_protos-1.70.0-py3-none-any.whl", hash = "sha256:b8bfcca8c25a2bb253e0e0b0adaf8c00773e5e6af6fd92397576680b807e0fd8", size = 294530, upload-time = "2025-04-14T10:17:01.271Z" }, ] +[[package]] +name = "gql-query-builder" +version = "0.1.7" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/62/bd/68d07678108ae2038dfd98ed4fdfcde833a2c015af27ef23c2a30020b406/gql-query-builder-0.1.7.tar.gz", hash = "sha256:99fd8e3f883b75fded271ab7957b7da6d6ec197679f0da7ff3d7d74c34b12842", size = 6711, upload-time = "2021-12-07T01:51:40.782Z" } + [[package]] name = "grpcio" version = "1.74.0" @@ -3321,6 +3327,7 @@ dependencies = [ { name = "click", version = "8.1.8", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, { name = "click", version = "8.2.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, { name = "google-cloud-storage" }, + { name = "gql-query-builder" }, { name = "httpx" }, { name = "httpx-ws" }, { name = "huggingface-hub" }, @@ -3393,6 +3400,7 @@ requires-dist = [ { name = "boto3", specifier = ">=1.34.85,<2" }, { name = "click", specifier = ">=8.0.3,<9" }, { name = "google-cloud-storage", specifier = ">=2.10.0" }, + { name = "gql-query-builder", specifier = ">=0.1.7,<0.2.0" }, { name = "httpx", specifier = ">=0.24.1" }, { name = "httpx-ws", specifier = ">=0.7.1,<0.8" }, { name = "huggingface-hub", specifier = ">=0.25.0" },