From c5cf98f004952e43897297a03d0a96dbee0af8a9 Mon Sep 17 00:00:00 2001 From: David Gageot Date: Tue, 21 Apr 2026 10:03:05 +0200 Subject: [PATCH] Route Anthropic models on Vertex AI through the native endpoint Claude models on Google Cloud's Vertex AI Model Garden do not support the OpenAI-compatible /chat/completions endpoint and fail with: 'FAILED_PRECONDITION: The deployed model does not support ChatCompletions.' When provider_opts.publisher == "anthropic", requests are now routed through the Anthropic-native :rawPredict / :streamRawPredict endpoints (with the vertex-2023-10-16 body schema) using the anthropic-sdk-go/vertex subpackage, authenticated via Google Application Default Credentials. Other Model Garden publishers (meta, mistral, ...) continue to use the OpenAI-compatible path, and Gemini-on-Vertex is unchanged (it never enters this code path). - pkg/model/provider/anthropic/vertex.go: new NewVertexClient constructor. - pkg/model/provider/vertexai/modelgarden.go: single NewClient entry that dispatches on publisher and returns a small Client interface satisfied by both anthropic.Client and openai.Client; shared resolveProjectLocation helper with URL-injection-safe validation. - pkg/model/provider/provider.go: one-line dispatch into vertexai.NewClient. - Tests: cover publisher extraction, project/location resolution (env-var fallback, ${VAR} expansion, URL-injection attempts, uppercase rejection). - Docs: docs/providers/google/index.md explains the two endpoint paths. Fixes #2469 Assisted-By: docker-agent --- docs/providers/google/index.md | 10 +- go.mod | 3 + go.sum | 15 ++ pkg/model/provider/anthropic/vertex.go | 101 ++++++++++ pkg/model/provider/provider.go | 2 +- pkg/model/provider/vertexai/modelgarden.go | 158 ++++++++++------ .../provider/vertexai/modelgarden_test.go | 175 ++++++++++++++++++ 7 files changed, 403 insertions(+), 61 deletions(-) create mode 100644 pkg/model/provider/anthropic/vertex.go diff --git a/docs/providers/google/index.md b/docs/providers/google/index.md index 577df4341..e672c2053 100644 --- a/docs/providers/google/index.md +++ b/docs/providers/google/index.md @@ -117,8 +117,14 @@ models: You can use non-Gemini models (e.g. Claude, Llama) hosted on Google Cloud's [Vertex AI Model Garden](https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-partner-models) through the `google` provider. When a `publisher` is specified in `provider_opts`, -requests are routed through Vertex AI's OpenAI-compatible endpoint instead of the -Gemini SDK. +requests are routed through the appropriate Vertex AI endpoint instead of the +Gemini SDK: + +- **Anthropic Claude** (`publisher: anthropic`) uses the Anthropic-native + `:rawPredict` / `:streamRawPredict` endpoints. Claude models on Vertex AI do + not support the OpenAI `/chat/completions` path. +- **Other publishers** (e.g. `meta`, `mistral`) use Vertex AI's + OpenAI-compatible `/chat/completions` endpoint. ### Authentication diff --git a/go.mod b/go.mod index 6837d4ce2..dc1ff70b9 100644 --- a/go.mod +++ b/go.mod @@ -75,6 +75,7 @@ require ( ) require ( + cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect github.com/99designs/go-keychain v0.0.0-20191008050251-8e49817e8af4 // indirect github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.23 // indirect github.com/danieljoos/wincred v1.2.2 // indirect @@ -85,6 +86,8 @@ require ( github.com/mtibben/percent v0.2.1 // indirect github.com/pb33f/jsonpath v0.8.2 // indirect github.com/pb33f/ordered-map/v2 v2.3.1 // indirect + go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.63.0 // indirect + google.golang.org/api v0.252.0 // indirect ) require ( diff --git a/go.sum b/go.sum index 7cfa26cfe..bd2ac8650 100644 --- a/go.sum +++ b/go.sum @@ -10,6 +10,8 @@ cloud.google.com/go v0.123.0 h1:2NAUJwPR47q+E35uaJeYoNhuNEM9kM8SjgRgdeOJUSE= cloud.google.com/go v0.123.0/go.mod h1:xBoMV08QcqUGuPW65Qfm1o9Y4zKZBpGS+7bImXLTAZU= cloud.google.com/go/auth v0.17.0 h1:74yCm7hCj2rUyyAocqnFzsAYXgJhrG26XCFimrc/Kz4= cloud.google.com/go/auth v0.17.0/go.mod h1:6wv/t5/6rOPAX4fJiRjKkJCvswLwdet7G8+UGXt7nCQ= +cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIiLpZnkHRbnc= +cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c= cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs= cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10= dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8= @@ -164,6 +166,8 @@ github.com/clipperhouse/uax29/v2 v2.7.0 h1:+gs4oBZ2gPfVrKPthwbMzWZDaAFPGYK72F0NJ github.com/clipperhouse/uax29/v2 v2.7.0/go.mod h1:EFJ2TJMRUaplDxHKj1qAEhCtQPW2tJSwu5BF98AuoVM= github.com/cloudflare/circl v1.6.3 h1:9GPOhQGF9MCYUeXyMYlqTR6a5gTrgR/fBLXvUgtVcg8= github.com/cloudflare/circl v1.6.3/go.mod h1:2eXP6Qfat4O/Yhh8BznvKnJ+uzEoTQ6jVKJRn81BiS4= +github.com/cncf/xds/go v0.0.0-20251210132809-ee656c7534f5 h1:6xNmx7iTtyBRev0+D/Tv1FZd4SCg8axKApyNyRsAt/w= +github.com/cncf/xds/go v0.0.0-20251210132809-ee656c7534f5/go.mod h1:KdCmV+x/BuvyMxRnYBlmVaq4OLiKW6iRQfvC62cvdkI= github.com/coder/acp-go-sdk v0.6.3 h1:LsXQytehdjKIYJnoVWON/nf7mqbiarnyuyE3rrjBsXQ= github.com/coder/acp-go-sdk v0.6.3/go.mod h1:yKzM/3R9uELp4+nBAwwtkS0aN1FOFjo11CNPy37yFko= github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI= @@ -216,6 +220,11 @@ github.com/elazarl/goproxy v1.7.2 h1:Y2o6urb7Eule09PjlhQRGNsqRfPmYI3KKQLFpCAV3+o github.com/elazarl/goproxy v1.7.2/go.mod h1:82vkLNir0ALaW14Rc399OTTjyNREgmdL2cVoIbS6XaE= github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc= github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ= +github.com/envoyproxy/go-control-plane v0.14.0 h1:hbG2kr4RuFj222B6+7T83thSPqLjwBIfQawTkC++2HA= +github.com/envoyproxy/go-control-plane/envoy v1.36.0 h1:yg/JjO5E7ubRyKX3m07GF3reDNEnfOboJ0QySbH736g= +github.com/envoyproxy/go-control-plane/envoy v1.36.0/go.mod h1:ty89S1YCCVruQAm9OtKeEkQLTb+Lkz0k8v9W0Oxsv98= +github.com/envoyproxy/protoc-gen-validate v1.3.0 h1:TvGH1wof4H33rezVKWSpqKz5NXWg5VPuZ0uONDT6eb4= +github.com/envoyproxy/protoc-gen-validate v1.3.0/go.mod h1:HvYl7zwPa5mffgyeTUHA9zHIH36nmrm7oCbo4YKoSWA= github.com/fatih/color v1.19.0 h1:Zp3PiM21/9Ld6FzSKyL5c/BULoe/ONr9KlbYVOfG8+w= github.com/fatih/color v1.19.0/go.mod h1:zNk67I0ZUT1bEGsSGyCZYZNrHuTkJJB+r6Q9VuMi0LE= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= @@ -404,6 +413,8 @@ github.com/pjbgf/sha1cd v0.3.2/go.mod h1:zQWigSxVmsHEZow5qaLtPYxpcKMMQpa09ixqBxu github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 h1:GFCKgmp0tecUJ0sJuv4pzYCqS9+RGSn52M3FUwPs+uo= +github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10/go.mod h1:t/avpk3KcrXxUnYOhZhMXJlSEyie6gQbtLq5NM3loB8= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= @@ -506,6 +517,8 @@ go.etcd.io/bbolt v1.4.0 h1:TU77id3TnN/zKr7CO/uk+fBCwF2jGcMuw2B/FMAzYIk= go.etcd.io/bbolt v1.4.0/go.mod h1:AsD+OCi/qPN1giOX1aiLAha3o1U8rAz65bvN4j0sRuk= go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.63.0 h1:YH4g8lQroajqUwWbq/tr2QX1JFmEXaDLgG+ew9bLMWo= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.63.0/go.mod h1:fvPi2qXDqFs8M4B4fmJhE92TyQs9Ydjlg3RvfUp+NbQ= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0 h1:RbKq8BG0FI8OiXhBfcRtqqHcZcka+gU3cskNuf05R18= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0/go.mod h1:h06DGIukJOevXaj/xrNjhi/2098RZzcLTbc0jDAUbsg= go.opentelemetry.io/otel v1.43.0 h1:mYIM03dnh5zfN7HautFE4ieIig9amkNANT+xcVxAj9I= @@ -592,6 +605,8 @@ gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4= gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E= google.golang.org/adk v1.1.0 h1:UNyIb604EWWJTCsmKg5tuo/oaESGiR9DHgFzTICN3zM= google.golang.org/adk v1.1.0/go.mod h1:26tMHp0FXeHshfHACOuH1RZJgtZ06GPaReu20q70KMU= +google.golang.org/api v0.252.0 h1:xfKJeAJaMwb8OC9fesr369rjciQ704AjU/psjkKURSI= +google.golang.org/api v0.252.0/go.mod h1:dnHOv81x5RAmumZ7BWLShB/u7JZNeyalImxHmtTHxqw= google.golang.org/genai v1.54.0 h1:ZQCa70WMTJDI11FdqWCzGvZ5PanpcpfoO6jl/lrSnGU= google.golang.org/genai v1.54.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk= google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 h1:VPWxll4HlMw1Vs/qXtN7BvhZqsS9cdAittCNvVENElA= diff --git a/pkg/model/provider/anthropic/vertex.go b/pkg/model/provider/anthropic/vertex.go new file mode 100644 index 000000000..ca1ae92eb --- /dev/null +++ b/pkg/model/provider/anthropic/vertex.go @@ -0,0 +1,101 @@ +package anthropic + +import ( + "context" + "errors" + "fmt" + "log/slog" + + "github.com/anthropics/anthropic-sdk-go" + "github.com/anthropics/anthropic-sdk-go/option" + "github.com/anthropics/anthropic-sdk-go/vertex" + "golang.org/x/oauth2/google" + + "github.com/docker/docker-agent/pkg/config/latest" + "github.com/docker/docker-agent/pkg/environment" + "github.com/docker/docker-agent/pkg/model/provider/base" + "github.com/docker/docker-agent/pkg/model/provider/options" +) + +// vertexCloudPlatformScope is the OAuth2 scope required for Vertex AI API access. +const vertexCloudPlatformScope = "https://www.googleapis.com/auth/cloud-platform" + +// NewVertexClient creates a new Anthropic client that talks to Claude models +// hosted on Google Cloud's Vertex AI via the Anthropic-native endpoints +// (`:rawPredict` and `:streamRawPredict`), authenticated with Google +// Application Default Credentials. +// +// This is required because Anthropic models on Vertex AI do not support the +// OpenAI-compatible `/chat/completions` endpoint and fail with +// `FAILED_PRECONDITION: The deployed model does not support ChatCompletions.` +// +// See: https://docs.anthropic.com/en/api/claude-on-vertex-ai +func NewVertexClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Provider, project, location string, opts ...options.Opt) (*Client, error) { + if cfg == nil { + return nil, errors.New("model configuration is required") + } + if env == nil { + return nil, errors.New("environment provider is required") + } + if project == "" { + return nil, errors.New("vertex AI requires a GCP project") + } + if location == "" { + return nil, errors.New("vertex AI requires a GCP location") + } + + var globalOptions options.ModelOptions + for _, opt := range opts { + if opt != nil { + opt(&globalOptions) + } + } + + // Resolve GCP credentials up front so we can return a descriptive error + // instead of the panic that vertex.WithGoogleAuth would raise. + creds, err := google.FindDefaultCredentials(ctx, vertexCloudPlatformScope) + if err != nil { + return nil, fmt.Errorf("failed to obtain GCP credentials for Vertex AI: %w (run 'gcloud auth application-default login')", err) + } + + slog.Debug("Creating Anthropic client for Vertex AI", + "project", project, + "location", location, + "model", cfg.Model, + ) + + // vertex.WithCredentials configures the base URL, Google-authenticated + // HTTP client, and middleware that rewrites /v1/messages requests to the + // Anthropic-native Vertex AI endpoints (`:rawPredict` / `:streamRawPredict`) + // and injects the `anthropic_version: vertex-2023-10-16` body field. + // + // The explicit option.WithAPIKey("") is REQUIRED (do not remove): the + // anthropic SDK's NewClient applies DefaultClientOptions() first, which + // auto-reads ANTHROPIC_API_KEY from the environment and sets the + // X-Api-Key header. On Vertex AI the request is authenticated with + // OAuth2 (via the google transport in vertex.WithCredentials), so we + // must clear the stray X-Api-Key header that would otherwise leak a + // direct-API credential into Google's infrastructure. + client := anthropic.NewClient( + vertex.WithCredentials(ctx, location, project, creds), + option.WithAPIKey(""), + ) + + anthropicClient := &Client{ + Config: base.Config{ + ModelConfig: *cfg, + ModelOptions: globalOptions, + Env: env, + }, + clientFn: func(context.Context) (anthropic.Client, error) { + return client, nil + }, + } + + // File uploads via Anthropic's Files API are not supported on Vertex AI, + // but the FileManager is lazy and harmless if unused. + anthropicClient.fileManager = NewFileManager(anthropicClient.clientFn) + + slog.Debug("Anthropic (Vertex AI) client created successfully", "model", cfg.Model) + return anthropicClient, nil +} diff --git a/pkg/model/provider/provider.go b/pkg/model/provider/provider.go index 1d008a59f..eb27964a4 100644 --- a/pkg/model/provider/provider.go +++ b/pkg/model/provider/provider.go @@ -244,7 +244,7 @@ func createDirectProvider(ctx context.Context, cfg *latest.ModelConfig, env envi return anthropic.NewClient(ctx, enhancedCfg, env, opts...) case "google": // Route non-Gemini models on Vertex AI (Model Garden) through the - // OpenAI-compatible endpoint instead of the Gemini SDK. + // vertexai package, which picks the right endpoint per publisher. if vertexai.IsModelGardenConfig(enhancedCfg) { return vertexai.NewClient(ctx, enhancedCfg, env, opts...) } diff --git a/pkg/model/provider/vertexai/modelgarden.go b/pkg/model/provider/vertexai/modelgarden.go index a6f6424a9..dba259d22 100644 --- a/pkg/model/provider/vertexai/modelgarden.go +++ b/pkg/model/provider/vertexai/modelgarden.go @@ -1,10 +1,16 @@ // Package vertexai provides support for non-Gemini models hosted on -// Google Cloud's Vertex AI Model Garden via the OpenAI-compatible endpoint. +// Google Cloud's Vertex AI Model Garden. // // Vertex AI Model Garden hosts models from various publishers (Anthropic, -// Meta, Mistral, etc.) and exposes them through an OpenAI-compatible API. -// This package configures the OpenAI provider to talk to that endpoint -// using Google Cloud Application Default Credentials for authentication. +// Meta, Mistral, etc.) and exposes them through two different APIs: +// +// - Anthropic Claude models: the Anthropic-native `:rawPredict` / +// `:streamRawPredict` endpoints. Claude models do not support the +// OpenAI-compatible path. +// - Other publishers: Vertex AI's OpenAI-compatible `/chat/completions` +// endpoint. +// +// Authentication uses Google Cloud Application Default Credentials. // // Usage in agent config: // @@ -31,10 +37,14 @@ import ( "golang.org/x/oauth2" "golang.org/x/oauth2/google" + "github.com/docker/docker-agent/pkg/chat" "github.com/docker/docker-agent/pkg/config/latest" "github.com/docker/docker-agent/pkg/environment" + "github.com/docker/docker-agent/pkg/model/provider/anthropic" + "github.com/docker/docker-agent/pkg/model/provider/base" "github.com/docker/docker-agent/pkg/model/provider/openai" "github.com/docker/docker-agent/pkg/model/provider/options" + "github.com/docker/docker-agent/pkg/tools" ) // cloudPlatformScope is the OAuth2 scope required for Vertex AI API access. @@ -45,36 +55,66 @@ const cloudPlatformScope = "https://www.googleapis.com/auth/cloud-platform" // Locations: lowercase letters, digits, hyphens (e.g. us-central1). var validGCPIdentifier = regexp.MustCompile(`^[a-z][a-z0-9-]{1,29}$`) +// Client is the subset of provider.Provider returned by NewClient. Both +// anthropic.Client and openai.Client satisfy it, so the caller can treat +// the two Model Garden code paths uniformly. +type Client interface { + ID() string + CreateChatCompletionStream(ctx context.Context, messages []chat.Message, tools []tools.Tool) (chat.MessageStream, error) + BaseConfig() base.Config +} + // IsModelGardenConfig returns true when the ModelConfig describes a -// non-Gemini model on Vertex AI (i.e. the "publisher" provider_opt is set). +// non-Gemini model on Vertex AI (i.e. the "publisher" provider_opt is set +// to something other than "google"). func IsModelGardenConfig(cfg *latest.ModelConfig) bool { + p := publisher(cfg) + return p != "" && !strings.EqualFold(p, "google") +} + +// NewClient creates a client for a non-Gemini model on Vertex AI Model Garden, +// choosing the right endpoint based on the publisher: +// +// - publisher: anthropic → Anthropic-native `:streamRawPredict` endpoint. +// - other publishers → Vertex AI's OpenAI-compatible `/chat/completions`. +func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Provider, opts ...options.Opt) (Client, error) { + project, location, err := resolveProjectLocation(ctx, cfg, env) + if err != nil { + return nil, err + } + if strings.EqualFold(publisher(cfg), "anthropic") { + return anthropic.NewVertexClient(ctx, cfg, env, project, location, opts...) + } + return newOpenAIClient(ctx, cfg, env, project, location, opts...) +} + +// publisher returns the provider_opts.publisher string, or "" if unset. +func publisher(cfg *latest.ModelConfig) string { if cfg == nil || cfg.ProviderOpts == nil { - return false + return "" } - publisher, _ := cfg.ProviderOpts["publisher"].(string) - return publisher != "" && !strings.EqualFold(publisher, "google") + p, _ := cfg.ProviderOpts["publisher"].(string) + return p } -// NewClient creates an OpenAI-compatible client pointing at the Vertex AI -// Model Garden endpoint. It uses Google Application Default Credentials -// for authentication. -func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Provider, opts ...options.Opt) (*openai.Client, error) { - project, _ := cfg.ProviderOpts["project"].(string) - location, _ := cfg.ProviderOpts["location"].(string) - publisher, _ := cfg.ProviderOpts["publisher"].(string) - - // Expand env vars in project/location. - var err error - project, err = environment.Expand(ctx, project, env) - if err != nil { - return nil, fmt.Errorf("expanding project: %w", err) +// resolveProjectLocation reads project and location from provider_opts, falls +// back to GOOGLE_CLOUD_PROJECT / GOOGLE_CLOUD_LOCATION, expands env var +// references, and validates the resulting values. +func resolveProjectLocation(ctx context.Context, cfg *latest.ModelConfig, env environment.Provider) (project, location string, err error) { + if cfg == nil { + return "", "", errors.New("model configuration is required") } - location, err = environment.Expand(ctx, location, env) - if err != nil { - return nil, fmt.Errorf("expanding location: %w", err) + + project, _ = cfg.ProviderOpts["project"].(string) + location, _ = cfg.ProviderOpts["location"].(string) + + if project, err = environment.Expand(ctx, project, env); err != nil { + return "", "", fmt.Errorf("expanding project: %w", err) + } + if location, err = environment.Expand(ctx, location, env); err != nil { + return "", "", fmt.Errorf("expanding location: %w", err) } - // Fall back to environment variables if not set in provider_opts. if project == "" { project, _ = env.Get(ctx, "GOOGLE_CLOUD_PROJECT") } @@ -83,34 +123,38 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro } if project == "" { - return nil, errors.New("vertex AI Model Garden requires a GCP project (set provider_opts.project or GOOGLE_CLOUD_PROJECT)") + return "", "", errors.New("vertex AI Model Garden requires a GCP project (set provider_opts.project or GOOGLE_CLOUD_PROJECT)") } if location == "" { - return nil, errors.New("vertex AI Model Garden requires a GCP location (set provider_opts.location or GOOGLE_CLOUD_LOCATION)") + return "", "", errors.New("vertex AI Model Garden requires a GCP location (set provider_opts.location or GOOGLE_CLOUD_LOCATION)") } - // Validate project and location to prevent URL path manipulation. + // Validate to prevent URL path manipulation. if !validGCPIdentifier.MatchString(project) { - return nil, fmt.Errorf("invalid GCP project ID: %q", project) + return "", "", fmt.Errorf("invalid GCP project ID: %q", project) } if !validGCPIdentifier.MatchString(location) { - return nil, fmt.Errorf("invalid GCP location: %q", location) + return "", "", fmt.Errorf("invalid GCP location: %q", location) } - // Build the base URL for the OpenAI-compatible endpoint. + return project, location, nil +} + +// newOpenAIClient creates a client pointing at Vertex AI's OpenAI-compatible +// endpoint. It uses Google Application Default Credentials for authentication. +func newOpenAIClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Provider, project, location string, opts ...options.Opt) (*openai.Client, error) { // https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-partner-models#openai_sdk baseURL := "https://" + location + "-aiplatform.googleapis.com/v1beta1/projects/" + url.PathEscape(project) + "/locations/" + url.PathEscape(location) + "/endpoints/openapi" slog.Debug("Creating Vertex AI Model Garden client", - "publisher", publisher, + "publisher", publisher(cfg), "project", project, "location", location, "model", cfg.Model, "base_url", baseURL, ) - // Get a GCP access token using Application Default Credentials. tokenSource, err := google.DefaultTokenSource(ctx, cloudPlatformScope) if err != nil { return nil, fmt.Errorf("failed to obtain GCP credentials for Vertex AI: %w (run 'gcloud auth application-default login')", err) @@ -120,26 +164,24 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro return nil, fmt.Errorf("failed to get GCP access token: %w", err) } - // Build a modified config that the OpenAI provider can use. - // We override the base URL and set the token directly. + // Build a config for the OpenAI provider with the Vertex base URL and a + // synthetic token env var that the wrapping env provider resolves to a + // fresh GCP access token. + const tokenEnvVar = "_VERTEX_AI_ACCESS_TOKEN" oaiCfg := cfg.Clone() oaiCfg.BaseURL = baseURL - // Use a synthetic token key env var — we'll set it in a wrapper env provider. - const tokenEnvVar = "_VERTEX_AI_ACCESS_TOKEN" oaiCfg.TokenKey = tokenEnvVar - // Remove provider_opts that are specific to Vertex AI / not relevant for OpenAI. - delete(oaiCfg.ProviderOpts, "project") - delete(oaiCfg.ProviderOpts, "location") - delete(oaiCfg.ProviderOpts, "publisher") - - // Force chat completions API type (Vertex AI OpenAI endpoint uses this). + // Strip Vertex-specific provider_opts before handing off to the OpenAI + // provider, and force the chat-completions API type. if oaiCfg.ProviderOpts == nil { oaiCfg.ProviderOpts = map[string]any{} } + delete(oaiCfg.ProviderOpts, "project") + delete(oaiCfg.ProviderOpts, "location") + delete(oaiCfg.ProviderOpts, "publisher") oaiCfg.ProviderOpts["api_type"] = "openai_chatcompletions" - // Wrap the environment provider to inject the GCP access token. wrappedEnv := &tokenEnv{ Provider: env, key: tokenEnvVar, @@ -150,8 +192,8 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro return openai.NewClient(ctx, oaiCfg, wrappedEnv, opts...) } -// tokenEnv wraps an environment.Provider to inject a GCP access token. -// It refreshes the token on each Get call to handle token expiry. +// tokenEnv wraps an environment.Provider to inject a GCP access token, +// refreshing it on each Get call (TokenSource handles caching internally). type tokenEnv struct { environment.Provider @@ -162,18 +204,18 @@ type tokenEnv struct { } func (e *tokenEnv) Get(ctx context.Context, name string) (string, bool) { - if name == e.key { - e.mu.Lock() - defer e.mu.Unlock() - - // Refresh token if needed — TokenSource handles caching. - tok, err := e.ts.Token() - if err != nil { - slog.Warn("Failed to refresh GCP access token, using cached", "error", err) - return e.tok, true - } - e.tok = tok.AccessToken + if name != e.key { + return e.Provider.Get(ctx, name) + } + + e.mu.Lock() + defer e.mu.Unlock() + + tok, err := e.ts.Token() + if err != nil { + slog.Warn("Failed to refresh GCP access token, using cached", "error", err) return e.tok, true } - return e.Provider.Get(ctx, name) + e.tok = tok.AccessToken + return e.tok, true } diff --git a/pkg/model/provider/vertexai/modelgarden_test.go b/pkg/model/provider/vertexai/modelgarden_test.go index 0be8764f9..5e9a3b878 100644 --- a/pkg/model/provider/vertexai/modelgarden_test.go +++ b/pkg/model/provider/vertexai/modelgarden_test.go @@ -1,9 +1,11 @@ package vertexai import ( + "strings" "testing" "github.com/docker/docker-agent/pkg/config/latest" + "github.com/docker/docker-agent/pkg/environment" ) func TestIsModelGardenConfig(t *testing.T) { @@ -70,6 +72,35 @@ func TestIsModelGardenConfig(t *testing.T) { } } +func TestPublisher(t *testing.T) { + tests := []struct { + name string + cfg *latest.ModelConfig + want string + }{ + {name: "nil config", cfg: nil, want: ""}, + {name: "no provider_opts", cfg: &latest.ModelConfig{}, want: ""}, + { + name: "anthropic", + cfg: &latest.ModelConfig{ProviderOpts: map[string]any{"publisher": "anthropic"}}, + want: "anthropic", + }, + { + name: "non-string value", + cfg: &latest.ModelConfig{ProviderOpts: map[string]any{"publisher": 42}}, + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := publisher(tt.cfg); got != tt.want { + t.Errorf("publisher() = %q, want %q", got, tt.want) + } + }) + } +} + func TestValidGCPIdentifier(t *testing.T) { valid := []string{"my-project", "us-central1", "project123", "ab"} for _, s := range valid { @@ -85,3 +116,147 @@ func TestValidGCPIdentifier(t *testing.T) { } } } + +func TestResolveProjectLocation(t *testing.T) { + tests := []struct { + name string + cfg *latest.ModelConfig + env map[string]string + wantProject string + wantLoc string + wantErrSub string // substring expected in the error message; empty means no error + }{ + { + name: "nil config", + cfg: nil, + wantErrSub: "model configuration is required", + }, + { + name: "from provider_opts", + cfg: &latest.ModelConfig{ + ProviderOpts: map[string]any{ + "project": "my-project", + "location": "us-central1", + }, + }, + wantProject: "my-project", + wantLoc: "us-central1", + }, + { + name: "from env vars", + cfg: &latest.ModelConfig{}, + env: map[string]string{ + "GOOGLE_CLOUD_PROJECT": "env-project", + "GOOGLE_CLOUD_LOCATION": "europe-west1", + }, + wantProject: "env-project", + wantLoc: "europe-west1", + }, + { + name: "provider_opts wins over env", + cfg: &latest.ModelConfig{ + ProviderOpts: map[string]any{ + "project": "opts-project", + "location": "us-east5", + }, + }, + env: map[string]string{ + "GOOGLE_CLOUD_PROJECT": "env-project", + "GOOGLE_CLOUD_LOCATION": "europe-west1", + }, + wantProject: "opts-project", + wantLoc: "us-east5", + }, + { + name: "env var expansion in provider_opts", + cfg: &latest.ModelConfig{ + ProviderOpts: map[string]any{ + "project": "${MY_PROJECT}", + "location": "${MY_LOC}", + }, + }, + env: map[string]string{ + "MY_PROJECT": "expanded-project", + "MY_LOC": "us-central1", + }, + wantProject: "expanded-project", + wantLoc: "us-central1", + }, + { + name: "unset env var in expansion fails", + cfg: &latest.ModelConfig{ + ProviderOpts: map[string]any{ + "project": "${MISSING}", + "location": "us-central1", + }, + }, + wantErrSub: "expanding project", + }, + { + name: "missing project", + cfg: &latest.ModelConfig{ProviderOpts: map[string]any{"location": "us-central1"}}, + wantErrSub: "requires a GCP project", + }, + { + name: "missing location", + cfg: &latest.ModelConfig{ProviderOpts: map[string]any{"project": "my-project"}}, + wantErrSub: "requires a GCP location", + }, + { + name: "url-injection attempt in project", + cfg: &latest.ModelConfig{ + ProviderOpts: map[string]any{ + "project": "../../evil", + "location": "us-central1", + }, + }, + wantErrSub: "invalid GCP project ID", + }, + { + name: "url-injection attempt in location", + cfg: &latest.ModelConfig{ + ProviderOpts: map[string]any{ + "project": "my-project", + "location": "us-central1/../evil", + }, + }, + wantErrSub: "invalid GCP location", + }, + { + name: "uppercase rejected", + cfg: &latest.ModelConfig{ + ProviderOpts: map[string]any{ + "project": "MY-PROJECT", + "location": "us-central1", + }, + }, + wantErrSub: "invalid GCP project ID", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + env := environment.NewMapEnvProvider(tt.env) + gotProject, gotLoc, err := resolveProjectLocation(t.Context(), tt.cfg, env) + + if tt.wantErrSub != "" { + if err == nil { + t.Fatalf("expected error containing %q, got nil", tt.wantErrSub) + } + if !strings.Contains(err.Error(), tt.wantErrSub) { + t.Fatalf("expected error containing %q, got %q", tt.wantErrSub, err.Error()) + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if gotProject != tt.wantProject { + t.Errorf("project = %q, want %q", gotProject, tt.wantProject) + } + if gotLoc != tt.wantLoc { + t.Errorf("location = %q, want %q", gotLoc, tt.wantLoc) + } + }) + } +}