diff --git a/src/ucode/agents/claude.py b/src/ucode/agents/claude.py index d0d0380..1afe962 100644 --- a/src/ucode/agents/claude.py +++ b/src/ucode/agents/claude.py @@ -62,7 +62,11 @@ def _resolve_web_search_model(state: dict) -> str | None: WEB_SEARCH_MCP_NAME = "web_search" -_CLAUDE_MODEL_RE = re.compile(r"^databricks-claude-(opus|sonnet)-(\d+)-(\d+)(.*)$") +# Matches both the AI Gateway form (`databricks-claude-opus-4-8`) and the UC +# model-services form (`system.ai.claude-opus-4-8`). +_CLAUDE_MODEL_RE = re.compile( + r"^(?:system\.ai\.)?(?:databricks-)?claude-(opus|sonnet)-(\d+)-(\d+)(.*)$" +) # Env keys the MLflow Stop hook reads to route traces. Written into the # settings `env` block alongside the hook itself. diff --git a/src/ucode/agents/codex.py b/src/ucode/agents/codex.py index e0bb64b..e8d5eb9 100644 --- a/src/ucode/agents/codex.py +++ b/src/ucode/agents/codex.py @@ -255,6 +255,10 @@ def _openai_model_id(model: str | None) -> str | None: def _codex_model_id(model: str | None) -> str | None: + # UC model-services ids (`system.ai.gpt-5`) route by name through the + # gateway, so they must be sent verbatim — not rewritten to an OpenAI id. + if model and model.startswith("system.ai."): + return model if model in CODEX_OPENAI_ID_INCOMPATIBLE_MODELS: return model return _openai_model_id(model) @@ -263,7 +267,12 @@ def _codex_model_id(model: str | None) -> str | None: def _parse_gpt(model: str | None) -> tuple[int, int | None, int | None, str] | None: if not model: return None - match = _GPT_RE.fullmatch(model.split("/")[-1]) + # Strip the UC model-services prefix so `system.ai.gpt-5` parses for version + # selection; the original id is preserved by callers that need it verbatim. + tail = model.split("/")[-1] + if tail.startswith("system.ai."): + tail = tail[len("system.ai.") :] + match = _GPT_RE.fullmatch(tail) if not match: return None major, minor, patch, suffix = match.groups() diff --git a/src/ucode/cli.py b/src/ucode/cli.py index c363e22..9c722b5 100644 --- a/src/ucode/cli.py +++ b/src/ucode/cli.py @@ -33,6 +33,7 @@ discover_claude_models, discover_codex_models, discover_gemini_models, + discover_model_services, ensure_ai_gateway_v2, ensure_databricks_auth, find_profile_name_for_host, @@ -41,6 +42,7 @@ install_databricks_cli, normalize_workspace_url, run_databricks_login, + use_model_services, ) from ucode.mcp import ( MCP_CLIENTS, @@ -160,7 +162,13 @@ def configure_shared_state( don't error out. If ``None``, we resolve it from the host after login. """ workspace = normalize_workspace_url(workspace) - previous_workspace = load_state().get("workspace") + prior_state = load_state() + previous_workspace = prior_state.get("workspace") + # The flag is sticky: an explicit env var wins, otherwise fall back to what + # was persisted when the workspace was configured. Without this, every + # launch re-runs discovery and a missing env var would silently revert a + # model-services workspace to the databricks-* gateway names. + model_services = use_model_services(default=bool(prior_state.get("use_model_services"))) fetch_all = tools is None if force_login: run_databricks_login(workspace, profile) @@ -184,19 +192,29 @@ def configure_shared_state( claude_reason: str | None = None gemini_reason: str | None = None codex_reason: str | None = None - with spinner("Fetching available models..."): + claude_models = {} + gemini_models = [] + codex_models = [] + if model_services: + # Opt-in: one UC model-services call yields all families as + # `system.ai.` ids, bucketed by name. The single reason is + # shared across the families that were requested. + with spinner("Fetching available models (model services)..."): + ms_claude, ms_codex, ms_gemini, ms_reason = discover_model_services(workspace, token) if want_claude: - claude_models, claude_reason = discover_claude_models(workspace, token) - else: - claude_models = {} + claude_models, claude_reason = ms_claude, ms_reason if want_gemini: - gemini_models, gemini_reason = discover_gemini_models(workspace, token) - else: - gemini_models = [] + gemini_models, gemini_reason = ms_gemini, ms_reason if want_codex: - codex_models, codex_reason = discover_codex_models(workspace, token) - else: - codex_models = [] + codex_models, codex_reason = ms_codex, ms_reason + else: + with spinner("Fetching available models..."): + if want_claude: + claude_models, claude_reason = discover_claude_models(workspace, token) + if want_gemini: + gemini_models, gemini_reason = discover_gemini_models(workspace, token) + if want_codex: + codex_models, codex_reason = discover_codex_models(workspace, token) opencode_models: dict[str, list[str]] = {} if claude_models: opencode_models["anthropic"] = list(claude_models.values()) @@ -210,6 +228,9 @@ def configure_shared_state( state["profile"] = profile else: state.pop("profile", None) + # Persist the resolved flag so subsequent launches stay on the same + # discovery path without the env var being re-exported. + state["use_model_services"] = model_services state["base_urls"] = build_shared_base_urls(workspace) if want_claude: state["claude_models"] = claude_models diff --git a/src/ucode/databricks.py b/src/ucode/databricks.py index 2d45feb..d2cdc20 100644 --- a/src/ucode/databricks.py +++ b/src/ucode/databricks.py @@ -17,7 +17,7 @@ from typing import Literal, cast, overload from urllib import error as urllib_error from urllib import request as urllib_request -from urllib.parse import urlparse +from urllib.parse import urlencode, urlparse from databricks.sql.exc import ServerOperationError @@ -977,6 +977,174 @@ def build_auth_shell_command(workspace: str, profile: str | None = None) -> str: ) +def use_model_services(default: bool = False) -> bool: + """True when the opt-in UC model-services discovery path is enabled. + + Set ``UCODE_USE_MODEL_SERVICES=1`` (or true/yes/on) to discover models via + the Unity Catalog model-services API and address them as + ``system.ai.`` instead of the per-family AI Gateway listings. + + The env var, when set to any value, wins. ``default`` is the fallback used + when the env var is unset — callers pass the value persisted in state so a + workspace configured with the flag keeps using model services on later + launches without the env var being re-exported each time. + """ + raw = os.environ.get("UCODE_USE_MODEL_SERVICES") + if raw is None or not raw.strip(): + return default + return raw.strip().lower() in {"1", "true", "yes", "on"} + + +# A model-service's `name` is `model-services/system.ai.`; the +# part after the prefix is exactly the model string agents send (no +# `databricks-` infix — that only appears on the inner destination name). +_MODEL_SERVICE_NAME_PREFIX = "model-services/" +# The metastore-scope listing returns services from EVERY schema (e.g. +# `main.user.foo`, `temp.*`, internal DLT schemas). We only want the +# Databricks-managed foundation models under `system.ai`. +_MODEL_SERVICE_REQUIRED_PREFIX = "system.ai." + + +def _model_service_id(service: dict) -> str | None: + """Extract the `system.ai.` id from one model-service entry. + + Returns None for services in any other schema, so user/internal model + services don't leak into the family buckets.""" + name = service.get("name") + if not isinstance(name, str): + return None + name = name.strip() + if name.startswith(_MODEL_SERVICE_NAME_PREFIX): + name = name[len(_MODEL_SERVICE_NAME_PREFIX) :] + if not name.startswith(_MODEL_SERVICE_REQUIRED_PREFIX): + return None + return name or None + + +# The model-services metastore listing is slow and flaky — large pages +# routinely 504 with `Timeout listing model services under metastore`. A small +# page is far more likely to come back, and each page gets a few retries before +# we give up. +_MODEL_SERVICES_PAGE_SIZE = 10 +_MODEL_SERVICES_PAGE_RETRIES = 4 + + +def _get_model_services_page( + url: str, token: str, *, retries: int = _MODEL_SERVICES_PAGE_RETRIES +) -> tuple[dict | list | None, str | None]: + """GET one model-services page, retrying on failure. + + The endpoint frequently 504s under load; a retry usually succeeds. Returns + the same (payload, reason) shape as ``_http_get_json`` — the last attempt's + result when all retries are exhausted.""" + payload: dict | list | None = None + reason: str | None = None + for attempt in range(retries): + payload, reason = _http_get_json(url, token, timeout=30) + if payload is not None: + return payload, None + _debug("model-services page", f"attempt {attempt + 1}/{retries} failed: {reason}") + return payload, reason + + +def list_model_services( + workspace: str, + token: str, + *, + page_size: int = _MODEL_SERVICES_PAGE_SIZE, + max_pages: int = 100, +) -> tuple[list[str], str | None]: + """List all `system.ai.*` model ids via the UC model-services API. + + Pages through ``/api/2.1/unity-catalog/model-services`` (metastore scope) + and returns the de-duplicated, sorted list of ``system.ai.`` + ids. Uses a small page size with per-page retries because the endpoint is + slow and frequently 504s. Returns (ids, reason); reason is None on success, + otherwise it describes why the list is empty (HTTP/network error or no + services). + """ + hostname = workspace_hostname(workspace) + ids: list[str] = [] + page_token: str | None = None + seen_tokens: set[str] = set() + last_reason: str | None = None + for _ in range(max_pages): + params: dict[str, str] = {"page_size": str(page_size)} + if page_token: + params["page_token"] = page_token + url = f"https://{hostname}/api/2.1/unity-catalog/model-services?{urlencode(params)}" + payload, reason = _get_model_services_page(url, token) + if payload is None: + # Surface the failure only if we have nothing yet; a mid-pagination + # blip still returns whatever we collected. + last_reason = reason + break + data = cast(dict, payload) if isinstance(payload, dict) else {} + for service in data.get("model_services", []): + if isinstance(service, dict): + model_id = _model_service_id(service) + if model_id: + ids.append(model_id) + page_token = data.get("next_page_token") or None + if not page_token: + last_reason = None + break + if page_token in seen_tokens: + break + seen_tokens.add(page_token) + + deduped = sorted(set(ids)) + if deduped: + return deduped, None + return [], last_reason or "model-services listing returned no models" + + +def discover_model_services( + workspace: str, token: str +) -> tuple[dict[str, str], list[str], list[str], str | None]: + """Discover models via UC model-services and bucket them by family name. + + Returns (claude_models, codex_models, gemini_models, reason): + + - ``claude_models`` maps ``opus``/``sonnet``/``haiku`` to the newest + matching ``system.ai.claude-*`` id (mirrors ``discover_claude_models``). + - ``codex_models`` is the list of ``system.ai.*gpt-*`` ids. + - ``gemini_models`` is the list of ``system.ai.*gemini-*`` ids, newest first. + + ``reason`` is None on success, else explains why nothing was found. Family + bucketing is by name substring because the model-services API does not + expose per-model API dialects. + """ + ids, reason = list_model_services(workspace, token) + if not ids: + return {}, [], [], reason + + claude_models: dict[str, str] = {} + for family in ("opus", "sonnet", "haiku"): + candidates = sorted( + [m for m in ids if f"claude-{family}-" in m], + reverse=True, + ) + if candidates: + claude_models[family] = candidates[0] + + codex_models = [m for m in ids if "gpt-" in m] + gemini_models = sorted([m for m in ids if "gemini-" in m], key=model_version_sort_key) + + if not (claude_models or codex_models or gemini_models): + sample = ", ".join(ids[:5]) + return ( + {}, + [], + [], + ( + "model-services returned model ids but none matched " + f"claude/gpt/gemini families (got: {sample})" + ), + ) + return claude_models, codex_models, gemini_models, None + + def discover_claude_models(workspace: str, token: str) -> tuple[dict[str, str], str | None]: """Discover Claude families on this workspace's AI Gateway. diff --git a/tests/test_agent_claude.py b/tests/test_agent_claude.py index ea33c63..9888efd 100644 --- a/tests/test_agent_claude.py +++ b/tests/test_agent_claude.py @@ -41,6 +41,14 @@ def test_does_not_duplicate_1m_suffix(self): overlay, _ = claude.render_overlay(WS, "databricks-claude-opus-4-7[1m]") assert overlay["env"]["ANTHROPIC_MODEL"] == "databricks-claude-opus-4-7[1m]" + def test_adds_1m_suffix_for_model_services_name(self): + overlay, _ = claude.render_overlay(WS, "system.ai.claude-opus-4-8") + assert overlay["env"]["ANTHROPIC_MODEL"] == "system.ai.claude-opus-4-8[1m]" + + def test_no_1m_suffix_for_model_services_haiku(self): + overlay, _ = claude.render_overlay(WS, "system.ai.claude-haiku-4-6") + assert overlay["env"]["ANTHROPIC_MODEL"] == "system.ai.claude-haiku-4-6" + def test_sets_anthropic_base_url(self): overlay, _ = claude.render_overlay(WS, "s4") assert overlay["env"]["ANTHROPIC_BASE_URL"] == f"{WS}/ai-gateway/anthropic" diff --git a/tests/test_agent_codex.py b/tests/test_agent_codex.py index b84b667..f8d6baf 100644 --- a/tests/test_agent_codex.py +++ b/tests/test_agent_codex.py @@ -337,6 +337,17 @@ def test_openai_model_id_maps_databricks_naming(self): def test_codex_model_id_preserves_openai_incompatible_models(self): assert codex._codex_model_id("databricks-gpt-5-2-codex") == "databricks-gpt-5-2-codex" assert codex._codex_model_id("databricks-gpt-5-4-nano") == "databricks-gpt-5-4-nano" + + def test_codex_model_id_passes_model_services_id_verbatim(self): + # UC model-services ids route by name, so they must not be rewritten + # to the OpenAI id form. + assert codex._codex_model_id("system.ai.gpt-5") == "system.ai.gpt-5" + assert codex._codex_model_id("system.ai.gpt-5-2-codex") == "system.ai.gpt-5-2-codex" + + def test_default_model_selects_model_services_gpt(self): + models = ["system.ai.gpt-5", "system.ai.gpt-5-5", "system.ai.claude-opus-4-8"] + + assert codex.default_model({"codex_models": models}) == "system.ai.gpt-5-5" assert codex._codex_model_id("databricks-gpt-5-5") == "gpt-5.5" diff --git a/tests/test_databricks.py b/tests/test_databricks.py index d3feeba..c55ee85 100644 --- a/tests/test_databricks.py +++ b/tests/test_databricks.py @@ -132,6 +132,160 @@ def test_selects_opus_4_8_when_advertised(self, monkeypatch): assert models["opus"] == "databricks-claude-opus-4-8" +def _model_service(model_id: str) -> dict: + """A model-services entry whose `name` strips to `model_id`.""" + return {"name": f"model-services/{model_id}"} + + +class TestUseModelServices: + def test_off_by_default(self, monkeypatch): + monkeypatch.delenv("UCODE_USE_MODEL_SERVICES", raising=False) + assert db_mod.use_model_services() is False + + def test_truthy_values_enable(self, monkeypatch): + for value in ("1", "true", "TRUE", "yes", "on"): + monkeypatch.setenv("UCODE_USE_MODEL_SERVICES", value) + assert db_mod.use_model_services() is True + + def test_falsey_values_disable(self, monkeypatch): + # A non-empty, non-truthy value explicitly disables — even over a + # persisted default of True. + for value in ("0", "false", "no"): + monkeypatch.setenv("UCODE_USE_MODEL_SERVICES", value) + assert db_mod.use_model_services(default=True) is False + + def test_unset_falls_back_to_default(self, monkeypatch): + # Sticky behavior: when the env var is unset (or blank), the persisted + # default decides. + monkeypatch.delenv("UCODE_USE_MODEL_SERVICES", raising=False) + assert db_mod.use_model_services(default=True) is True + assert db_mod.use_model_services(default=False) is False + monkeypatch.setenv("UCODE_USE_MODEL_SERVICES", "") + assert db_mod.use_model_services(default=True) is True + + def test_env_var_overrides_default(self, monkeypatch): + monkeypatch.setenv("UCODE_USE_MODEL_SERVICES", "1") + assert db_mod.use_model_services(default=False) is True + + +class TestDiscoverModelServices: + def test_buckets_families_by_name(self, monkeypatch): + payload = { + "model_services": [ + _model_service("system.ai.claude-opus-4-7"), + _model_service("system.ai.claude-opus-4-8"), + _model_service("system.ai.claude-sonnet-4-6"), + _model_service("system.ai.gpt-5"), + _model_service("system.ai.gemini-2-5-flash"), + _model_service("system.ai.gemini-3-5-flash"), + _model_service("system.ai.llama-4-maverick"), + ] + } + monkeypatch.setattr( + db_mod, "_http_get_json", lambda url, token, timeout=10: (payload, None) + ) + + claude, codex, gemini, reason = db_mod.discover_model_services(WS, "token") + + assert reason is None + # Newest opus wins; sonnet bucketed; haiku absent. + assert claude == { + "opus": "system.ai.claude-opus-4-8", + "sonnet": "system.ai.claude-sonnet-4-6", + } + assert codex == ["system.ai.gpt-5"] + # Gemini ordered newest-first via the shared sort key. + assert gemini[0] == "system.ai.gemini-3-5-flash" + # llama is not bucketed into any of the three families. + assert "system.ai.llama-4-maverick" not in codex + gemini + + def test_paginates_via_next_page_token(self, monkeypatch): + pages = { + None: { + "model_services": [_model_service("system.ai.gpt-5")], + "next_page_token": "tok2", + }, + "tok2": { + "model_services": [_model_service("system.ai.claude-opus-4-8")], + }, + } + + def fake_get(url, token, timeout=10): + token_param = None + if "page_token=" in url: + token_param = url.split("page_token=")[1].split("&")[0] + return pages[token_param], None + + monkeypatch.setattr(db_mod, "_http_get_json", fake_get) + + claude, codex, _, reason = db_mod.discover_model_services(WS, "token") + + assert reason is None + assert codex == ["system.ai.gpt-5"] + assert claude == {"opus": "system.ai.claude-opus-4-8"} + + def test_http_failure_returns_reason(self, monkeypatch): + monkeypatch.setattr( + db_mod, "_http_get_json", lambda url, token, timeout=10: (None, "HTTP 500 Server Error") + ) + + claude, codex, gemini, reason = db_mod.discover_model_services(WS, "token") + + assert (claude, codex, gemini) == ({}, [], []) + assert reason == "HTTP 500 Server Error" + + def test_no_matching_families_reports_sample(self, monkeypatch): + payload = {"model_services": [_model_service("system.ai.llama-4-maverick")]} + monkeypatch.setattr( + db_mod, "_http_get_json", lambda url, token, timeout=10: (payload, None) + ) + + claude, codex, gemini, reason = db_mod.discover_model_services(WS, "token") + + assert (claude, codex, gemini) == ({}, [], []) + assert reason is not None and "llama-4-maverick" in reason + + def test_ignores_non_system_ai_schemas(self, monkeypatch): + # The metastore listing returns services from every schema; only + # system.ai.* foundation models should be picked up. + payload = { + "model_services": [ + _model_service("system.ai.gpt-5"), + _model_service("main.svenwb.gpt-5-5"), + _model_service("temp.erni.claude-opus-4-8"), + _model_service("dnasi_agent_cuj.default.dnasi-gpt55-test"), + ] + } + monkeypatch.setattr( + db_mod, "_http_get_json", lambda url, token, timeout=10: (payload, None) + ) + + claude, codex, gemini, reason = db_mod.discover_model_services(WS, "token") + + assert reason is None + assert codex == ["system.ai.gpt-5"] + assert claude == {} # temp.erni.claude-* must not be bucketed + assert gemini == [] + + def test_retries_page_before_giving_up(self, monkeypatch): + payload = {"model_services": [_model_service("system.ai.gpt-5")]} + calls = {"n": 0} + + def flaky_get(url, token, timeout=10): + calls["n"] += 1 + if calls["n"] < 3: + return None, "HTTP 504 Gateway Timeout" + return payload, None + + monkeypatch.setattr(db_mod, "_http_get_json", flaky_get) + + ids, reason = db_mod.list_model_services(WS, "token") + + assert reason is None + assert ids == ["system.ai.gpt-5"] + assert calls["n"] == 3 # two failures, third succeeds + + def _foundation_models_payload(names): return { "endpoints": [