diff --git a/cmd/thv-operator/api/v1beta1/virtualmcpserver_types.go b/cmd/thv-operator/api/v1beta1/virtualmcpserver_types.go
index 32e9c795f9..d2b37439a7 100644
--- a/cmd/thv-operator/api/v1beta1/virtualmcpserver_types.go
+++ b/cmd/thv-operator/api/v1beta1/virtualmcpserver_types.go
@@ -19,6 +19,7 @@ import (
 // +kubebuilder:validation:XValidation:rule="!has(self.config) || !has(self.config.rateLimiting) || (has(self.sessionStorage) && self.sessionStorage.provider == 'redis')",message="config.rateLimiting requires sessionStorage with provider 'redis'"
 // +kubebuilder:validation:XValidation:rule="!(has(self.config) && has(self.config.rateLimiting) && has(self.config.rateLimiting.perUser)) || (has(self.incomingAuth) && self.incomingAuth.type == 'oidc')",message="config.rateLimiting.perUser requires incomingAuth.type oidc"
 // +kubebuilder:validation:XValidation:rule="!has(self.config) || !has(self.config.rateLimiting) || !has(self.config.rateLimiting.tools) || self.config.rateLimiting.tools.all(t, !has(t.perUser)) || (has(self.incomingAuth) && self.incomingAuth.type == 'oidc')",message="per-tool perUser rate limiting requires incomingAuth.type oidc"
+// +kubebuilder:validation:XValidation:rule="!(has(self.embeddingServerRef) && has(self.config) && has(self.config.optimizer) && has(self.config.optimizer.embeddingProvider) && self.config.optimizer.embeddingProvider == 'openai')",message="embeddingServerRef provisions a managed TEI server and cannot be combined with optimizer.embeddingProvider 'openai'; openai mode uses embeddingService directly"
 //
 //nolint:lll // CEL validation rules exceed line length limit
 type VirtualMCPServerSpec struct {
diff --git a/cmd/thv-operator/test-integration/virtualmcp/virtualmcpserver_embedding_cel_test.go b/cmd/thv-operator/test-integration/virtualmcp/virtualmcpserver_embedding_cel_test.go
new file mode 100644
index 0000000000..1512d7adf3
--- /dev/null
+++ b/cmd/thv-operator/test-integration/virtualmcp/virtualmcpserver_embedding_cel_test.go
@@ -0,0 +1,59 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+// Package controllers contains integration tests for the VirtualMCPServer controller
+package controllers
+
+import (
+	. "github.com/onsi/ginkgo/v2"
+	. "github.com/onsi/gomega"
+
+	mcpv1beta1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1beta1"
+	"github.com/stacklok/toolhive/cmd/thv-operator/api/v1beta1/v1beta1test"
+	vmcpconfig "github.com/stacklok/toolhive/pkg/vmcp/config"
+)
+
+// newVirtualMCPServerWithOptimizer builds a VirtualMCPServer named name in the
+// "default" namespace with group "test-group", anonymous incoming auth, and the
+// given optimizer config. Any opts are appended after those base options.
+func newVirtualMCPServerWithOptimizer(name string, optimizer *vmcpconfig.OptimizerConfig,
+	opts ...v1beta1test.VirtualMCPServerOption) *mcpv1beta1.VirtualMCPServer {
+	base := []v1beta1test.VirtualMCPServerOption{
+		v1beta1test.WithVMCPGroupRef("test-group"),
+		v1beta1test.WithVMCPIncomingAuth(&mcpv1beta1.IncomingAuthConfig{Type: "anonymous"}),
+		v1beta1test.WithVMCPConfig(vmcpconfig.Config{Group: "test-group", Optimizer: optimizer}),
+	}
+	return v1beta1test.NewVirtualMCPServer(name, "default", append(base, opts...)...)
+}
+
+var _ = Describe("CEL Validation for embedding provider on VirtualMCPServer",
+	Label("k8s", "cel", "validation"), func() {
+		It("should reject embeddingServerRef combined with embeddingProvider openai", func() {
+			vmcp := newVirtualMCPServerWithOptimizer("vmcp-ref-openai",
+				&vmcpconfig.OptimizerConfig{EmbeddingProvider: "openai", EmbeddingModel: "text-embedding-3-small"},
+				v1beta1test.WithVMCPEmbeddingServerRef("managed-tei"))
+			err := k8sClient.Create(ctx, vmcp)
+			Expect(err).To(HaveOccurred())
+			Expect(err.Error()).To(ContainSubstring(
+				"embeddingServerRef provisions a managed TEI server and cannot be combined with optimizer.embeddingProvider 'openai'"))
+		})
+
+		It("should accept embeddingServerRef with the default (tei) provider", func() {
+			vmcp := newVirtualMCPServerWithOptimizer("vmcp-ref-tei",
+				&vmcpconfig.OptimizerConfig{EmbeddingProvider: "tei"},
+				v1beta1test.WithVMCPEmbeddingServerRef("managed-tei"))
+			err := k8sClient.Create(ctx, vmcp)
+			Expect(err).NotTo(HaveOccurred())
+		})
+
+		It("should accept embeddingProvider openai without an embeddingServerRef", func() {
+			vmcp := newVirtualMCPServerWithOptimizer("vmcp-openai-no-ref",
+				&vmcpconfig.OptimizerConfig{
+					EmbeddingProvider: "openai",
+					EmbeddingService:  "http://gateway.example:8080",
+					EmbeddingModel:    "text-embedding-3-small",
+				})
+			err := k8sClient.Create(ctx, vmcp)
+			Expect(err).NotTo(HaveOccurred())
+		})
+	})
diff --git a/deploy/charts/operator-crds/files/crds/toolhive.stacklok.dev_virtualmcpservers.yaml b/deploy/charts/operator-crds/files/crds/toolhive.stacklok.dev_virtualmcpservers.yaml
index 035a10a98b..d075d36855 100644
--- a/deploy/charts/operator-crds/files/crds/toolhive.stacklok.dev_virtualmcpservers.yaml
+++ b/deploy/charts/operator-crds/files/crds/toolhive.stacklok.dev_virtualmcpservers.yaml
@@ -1792,6 +1792,34 @@ spec:
                       instead of all backend tools directly.
This reduces token usage by allowing LLMs to discover relevant tools on demand rather than receiving all tool definitions. properties: + embeddingModel: + description: |- + EmbeddingModel is the model name requested from the embedding service + (e.g. "text-embedding-3-small"). Required when EmbeddingProvider is + "openai". Ignored for the "tei" provider, where the model is fixed by the + running TEI container. + + The API key for an OpenAI-compatible service is not configured here: it is + read from the OPENAI_API_KEY environment variable so the secret never + lands in a CRD spec or ConfigMap. An empty key omits the Authorization + header, which supports keyless in-cluster gateways. + type: string + embeddingProvider: + default: tei + description: |- + EmbeddingProvider selects the wire protocol used to talk to the embedding + service. "tei" speaks the HuggingFace Text Embeddings Inference API; + "openai" speaks the OpenAI-compatible /embeddings API, which lets the + optimizer use OpenAI, Azure OpenAI, or another OpenAI-compatible gateway. + Defaults to "tei" when empty. + + The "openai" provider reads EmbeddingService directly and cannot be combined + with EmbeddingServerRef, which provisions a managed TEI server; the operator + rejects that combination at admission. 
+ enum: + - tei + - openai + type: string embeddingService: description: |- EmbeddingService is the full base URL of the embedding service endpoint @@ -2936,6 +2964,12 @@ spec: rule: '!has(self.config) || !has(self.config.rateLimiting) || !has(self.config.rateLimiting.tools) || self.config.rateLimiting.tools.all(t, !has(t.perUser)) || (has(self.incomingAuth) && self.incomingAuth.type == ''oidc'')' + - message: embeddingServerRef provisions a managed TEI server and cannot + be combined with optimizer.embeddingProvider 'openai'; openai mode + uses embeddingService directly + rule: '!(has(self.embeddingServerRef) && has(self.config) && has(self.config.optimizer) + && has(self.config.optimizer.embeddingProvider) && self.config.optimizer.embeddingProvider + == ''openai'')' status: description: VirtualMCPServerStatus defines the observed state of VirtualMCPServer properties: @@ -4880,6 +4914,34 @@ spec: instead of all backend tools directly. This reduces token usage by allowing LLMs to discover relevant tools on demand rather than receiving all tool definitions. properties: + embeddingModel: + description: |- + EmbeddingModel is the model name requested from the embedding service + (e.g. "text-embedding-3-small"). Required when EmbeddingProvider is + "openai". Ignored for the "tei" provider, where the model is fixed by the + running TEI container. + + The API key for an OpenAI-compatible service is not configured here: it is + read from the OPENAI_API_KEY environment variable so the secret never + lands in a CRD spec or ConfigMap. An empty key omits the Authorization + header, which supports keyless in-cluster gateways. + type: string + embeddingProvider: + default: tei + description: |- + EmbeddingProvider selects the wire protocol used to talk to the embedding + service. 
"tei" speaks the HuggingFace Text Embeddings Inference API; + "openai" speaks the OpenAI-compatible /embeddings API, which lets the + optimizer use OpenAI, Azure OpenAI, or another OpenAI-compatible gateway. + Defaults to "tei" when empty. + + The "openai" provider reads EmbeddingService directly and cannot be combined + with EmbeddingServerRef, which provisions a managed TEI server; the operator + rejects that combination at admission. + enum: + - tei + - openai + type: string embeddingService: description: |- EmbeddingService is the full base URL of the embedding service endpoint @@ -6024,6 +6086,12 @@ spec: rule: '!has(self.config) || !has(self.config.rateLimiting) || !has(self.config.rateLimiting.tools) || self.config.rateLimiting.tools.all(t, !has(t.perUser)) || (has(self.incomingAuth) && self.incomingAuth.type == ''oidc'')' + - message: embeddingServerRef provisions a managed TEI server and cannot + be combined with optimizer.embeddingProvider 'openai'; openai mode + uses embeddingService directly + rule: '!(has(self.embeddingServerRef) && has(self.config) && has(self.config.optimizer) + && has(self.config.optimizer.embeddingProvider) && self.config.optimizer.embeddingProvider + == ''openai'')' status: description: VirtualMCPServerStatus defines the observed state of VirtualMCPServer properties: diff --git a/deploy/charts/operator-crds/templates/toolhive.stacklok.dev_virtualmcpservers.yaml b/deploy/charts/operator-crds/templates/toolhive.stacklok.dev_virtualmcpservers.yaml index d23ab42957..b23002cd86 100644 --- a/deploy/charts/operator-crds/templates/toolhive.stacklok.dev_virtualmcpservers.yaml +++ b/deploy/charts/operator-crds/templates/toolhive.stacklok.dev_virtualmcpservers.yaml @@ -1795,6 +1795,34 @@ spec: instead of all backend tools directly. This reduces token usage by allowing LLMs to discover relevant tools on demand rather than receiving all tool definitions. 
properties: + embeddingModel: + description: |- + EmbeddingModel is the model name requested from the embedding service + (e.g. "text-embedding-3-small"). Required when EmbeddingProvider is + "openai". Ignored for the "tei" provider, where the model is fixed by the + running TEI container. + + The API key for an OpenAI-compatible service is not configured here: it is + read from the OPENAI_API_KEY environment variable so the secret never + lands in a CRD spec or ConfigMap. An empty key omits the Authorization + header, which supports keyless in-cluster gateways. + type: string + embeddingProvider: + default: tei + description: |- + EmbeddingProvider selects the wire protocol used to talk to the embedding + service. "tei" speaks the HuggingFace Text Embeddings Inference API; + "openai" speaks the OpenAI-compatible /embeddings API, which lets the + optimizer use OpenAI, Azure OpenAI, or another OpenAI-compatible gateway. + Defaults to "tei" when empty. + + The "openai" provider reads EmbeddingService directly and cannot be combined + with EmbeddingServerRef, which provisions a managed TEI server; the operator + rejects that combination at admission. 
+ enum: + - tei + - openai + type: string embeddingService: description: |- EmbeddingService is the full base URL of the embedding service endpoint @@ -2939,6 +2967,12 @@ spec: rule: '!has(self.config) || !has(self.config.rateLimiting) || !has(self.config.rateLimiting.tools) || self.config.rateLimiting.tools.all(t, !has(t.perUser)) || (has(self.incomingAuth) && self.incomingAuth.type == ''oidc'')' + - message: embeddingServerRef provisions a managed TEI server and cannot + be combined with optimizer.embeddingProvider 'openai'; openai mode + uses embeddingService directly + rule: '!(has(self.embeddingServerRef) && has(self.config) && has(self.config.optimizer) + && has(self.config.optimizer.embeddingProvider) && self.config.optimizer.embeddingProvider + == ''openai'')' status: description: VirtualMCPServerStatus defines the observed state of VirtualMCPServer properties: @@ -4883,6 +4917,34 @@ spec: instead of all backend tools directly. This reduces token usage by allowing LLMs to discover relevant tools on demand rather than receiving all tool definitions. properties: + embeddingModel: + description: |- + EmbeddingModel is the model name requested from the embedding service + (e.g. "text-embedding-3-small"). Required when EmbeddingProvider is + "openai". Ignored for the "tei" provider, where the model is fixed by the + running TEI container. + + The API key for an OpenAI-compatible service is not configured here: it is + read from the OPENAI_API_KEY environment variable so the secret never + lands in a CRD spec or ConfigMap. An empty key omits the Authorization + header, which supports keyless in-cluster gateways. + type: string + embeddingProvider: + default: tei + description: |- + EmbeddingProvider selects the wire protocol used to talk to the embedding + service. 
"tei" speaks the HuggingFace Text Embeddings Inference API; + "openai" speaks the OpenAI-compatible /embeddings API, which lets the + optimizer use OpenAI, Azure OpenAI, or another OpenAI-compatible gateway. + Defaults to "tei" when empty. + + The "openai" provider reads EmbeddingService directly and cannot be combined + with EmbeddingServerRef, which provisions a managed TEI server; the operator + rejects that combination at admission. + enum: + - tei + - openai + type: string embeddingService: description: |- EmbeddingService is the full base URL of the embedding service endpoint @@ -6027,6 +6089,12 @@ spec: rule: '!has(self.config) || !has(self.config.rateLimiting) || !has(self.config.rateLimiting.tools) || self.config.rateLimiting.tools.all(t, !has(t.perUser)) || (has(self.incomingAuth) && self.incomingAuth.type == ''oidc'')' + - message: embeddingServerRef provisions a managed TEI server and cannot + be combined with optimizer.embeddingProvider 'openai'; openai mode + uses embeddingService directly + rule: '!(has(self.embeddingServerRef) && has(self.config) && has(self.config.optimizer) + && has(self.config.optimizer.embeddingProvider) && self.config.optimizer.embeddingProvider + == ''openai'')' status: description: VirtualMCPServerStatus defines the observed state of VirtualMCPServer properties: diff --git a/docs/operator/crd-api.md b/docs/operator/crd-api.md index a1cc135267..6072c90e35 100644 --- a/docs/operator/crd-api.md +++ b/docs/operator/crd-api.md @@ -502,6 +502,8 @@ _Appears in:_ | --- | --- | --- | --- | | `embeddingService` _string_ | EmbeddingService is the full base URL of the embedding service endpoint
(e.g., http://my-embedding.default.svc.cluster.local:8080) for semantic
tool discovery.
In a Kubernetes environment, it is more convenient to use the
VirtualMCPServerSpec.EmbeddingServerRef field instead of setting this
directly. EmbeddingServerRef references an EmbeddingServer CRD by name,
and the operator automatically resolves the referenced resource's
Status.URL to populate this field. This provides managed lifecycle
(the operator watches the EmbeddingServer for readiness and URL changes)
and avoids hardcoding service URLs in the config. If both
EmbeddingServerRef and this field are set, EmbeddingServerRef takes
precedence and this value is overridden with a warning. | | Optional: \{\}
| | `embeddingServiceTimeout` _[vmcp.config.Duration](#vmcpconfigduration)_ | EmbeddingServiceTimeout is the HTTP request timeout for calls to the embedding service.
Defaults to 30s if not specified. | 30s | Pattern: `^([0-9]+(\.[0-9]+)?(ns\|us\|µs\|ms\|s\|m\|h))+$`
Type: string
Optional: \{\}
| +| `embeddingProvider` _string_ | EmbeddingProvider selects the wire protocol used to talk to the embedding
service. "tei" speaks the HuggingFace Text Embeddings Inference API;
"openai" speaks the OpenAI-compatible /embeddings API, which lets the
optimizer use OpenAI, Azure OpenAI, or another OpenAI-compatible gateway.
Defaults to "tei" when empty.
The "openai" provider reads EmbeddingService directly and cannot be combined
with EmbeddingServerRef, which provisions a managed TEI server; the operator
rejects that combination at admission. | tei | Enum: [tei openai]
Optional: \{\}
| +| `embeddingModel` _string_ | EmbeddingModel is the model name requested from the embedding service
(e.g. "text-embedding-3-small"). Required when EmbeddingProvider is
"openai". Ignored for the "tei" provider, where the model is fixed by the
running TEI container.
The API key for an OpenAI-compatible service is not configured here: it is
read from the OPENAI_API_KEY environment variable so the secret never
lands in a CRD spec or ConfigMap. An empty key omits the Authorization
header, which supports keyless in-cluster gateways. | | Optional: \{\}
| | `maxToolsToReturn` _integer_ | MaxToolsToReturn is the maximum number of tool results returned by a search query.
Defaults to 8 if not specified or zero. | | Maximum: 50
Minimum: 1
Optional: \{\}
| | `hybridSearchSemanticRatio` _string_ | HybridSearchSemanticRatio controls the balance between semantic (meaning-based)
and keyword search results. 0.0 = all keyword, 1.0 = all semantic.
Defaults to "0.5" if not specified or empty.
Serialized as a string because CRDs do not support float types portably. | | Pattern: `^([0-9]*[.])?[0-9]+$`
Optional: \{\}
| | `semanticDistanceThreshold` _string_ | SemanticDistanceThreshold is the maximum distance for semantic search results.
Results exceeding this threshold are filtered out from semantic search.
This threshold does not apply to keyword search.
Range: 0 = identical, 2 = completely unrelated.
Defaults to "1.0" if not specified or empty.
Serialized as a string because CRDs do not support float types portably. | | Pattern: `^([0-9]*[.])?[0-9]+$`
Optional: \{\}
| diff --git a/examples/operator/virtual-mcps/vmcp_optimizer_openai.yaml b/examples/operator/virtual-mcps/vmcp_optimizer_openai.yaml new file mode 100644 index 0000000000..b51a740c06 --- /dev/null +++ b/examples/operator/virtual-mcps/vmcp_optimizer_openai.yaml @@ -0,0 +1,87 @@ +# Example: VirtualMCPServer optimizer using an OpenAI-compatible embedding API +# +# Instead of a managed TEI EmbeddingServer, this points the optimizer at an +# external service that speaks the OpenAI /embeddings API — OpenAI, Azure +# OpenAI, or another OpenAI-compatible gateway. There is no EmbeddingServer +# or embeddingServerRef: the endpoint is reached directly via embeddingService. +# +# The API key is read from the OPENAI_API_KEY environment variable so it never +# lands in the CRD spec or the generated ConfigMap. Inject it into the vmcp +# container from a Secret via podTemplateSpec (omit it for keyless gateways). +# +# Note: unlike the TEI backend, the OpenAI API does not silently truncate +# over-long inputs; a tool description exceeding the model's context window +# returns an error rather than being truncated. +# +# Usage: +# kubectl apply -f vmcp_optimizer_openai.yaml + +--- +apiVersion: toolhive.stacklok.dev/v1beta1 +kind: MCPGroup +metadata: + name: optimizer-services + namespace: default +spec: + description: Backend services for an OpenAI-embedding optimizer + +--- +apiVersion: toolhive.stacklok.dev/v1beta1 +kind: MCPServer +metadata: + name: fetch + namespace: default +spec: + groupRef: + name: optimizer-services + image: ghcr.io/stackloklabs/gofetch/server + transport: streamable-http + proxyPort: 8080 + mcpPort: 8080 + +--- +# Secret holding the embedding API key. Omit for keyless in-cluster gateways. 
+apiVersion: v1 +kind: Secret +metadata: + name: embedding-api-key + namespace: default +type: Opaque +stringData: + apiKey: "sk-replace-me" + +--- +apiVersion: toolhive.stacklok.dev/v1beta1 +kind: VirtualMCPServer +metadata: + name: optimizer-vmcp + namespace: default +spec: + groupRef: + name: optimizer-services + config: + optimizer: + # Speak the OpenAI /embeddings API instead of TEI. + embeddingProvider: openai + # Base URL of the OpenAI-compatible service; "/embeddings" is appended. + embeddingService: http://llm-gateway.default.svc.cluster.local:8080/v1 + # Model requested from the service (required for the openai provider). + embeddingModel: text-embedding-3-small + embeddingServiceTimeout: 15s + + incomingAuth: + type: anonymous + outgoingAuth: + source: discovered + + # Inject the API key into the vmcp container as OPENAI_API_KEY. + podTemplateSpec: + spec: + containers: + - name: vmcp + env: + - name: OPENAI_API_KEY + valueFrom: + secretKeyRef: + name: embedding-api-key + key: apiKey diff --git a/pkg/vmcp/config/config.go b/pkg/vmcp/config/config.go index de91d6cd49..669e3665ed 100644 --- a/pkg/vmcp/config/config.go +++ b/pkg/vmcp/config/config.go @@ -933,6 +933,32 @@ type OptimizerConfig struct { // +optional EmbeddingServiceTimeout Duration `json:"embeddingServiceTimeout,omitempty" yaml:"embeddingServiceTimeout,omitempty"` + // EmbeddingProvider selects the wire protocol used to talk to the embedding + // service. "tei" speaks the HuggingFace Text Embeddings Inference API; + // "openai" speaks the OpenAI-compatible /embeddings API, which lets the + // optimizer use OpenAI, Azure OpenAI, or another OpenAI-compatible gateway. + // Defaults to "tei" when empty. + // + // The "openai" provider reads EmbeddingService directly and cannot be combined + // with EmbeddingServerRef, which provisions a managed TEI server; the operator + // rejects that combination at admission. 
+	// +kubebuilder:validation:Enum=tei;openai
+	// +kubebuilder:default="tei"
+	// +optional
+	EmbeddingProvider string `json:"embeddingProvider,omitempty" yaml:"embeddingProvider,omitempty"`
+
+	// EmbeddingModel is the model name requested from the embedding service
+	// (e.g. "text-embedding-3-small"). Required when EmbeddingProvider is
+	// "openai". Ignored for the "tei" provider, where the model is fixed by the
+	// running TEI container.
+	//
+	// The API key for an OpenAI-compatible service is not configured here: it is
+	// read from the OPENAI_API_KEY environment variable so the secret never
+	// lands in a CRD spec or ConfigMap. An empty key omits the Authorization
+	// header, which supports keyless in-cluster gateways.
+	// +optional
+	EmbeddingModel string `json:"embeddingModel,omitempty" yaml:"embeddingModel,omitempty"`
+
+	// MaxToolsToReturn is the maximum number of tool results returned by a search query.
 	// Defaults to 8 if not specified or zero.
 	// +kubebuilder:validation:Minimum=1
diff --git a/pkg/vmcp/optimizer/internal/similarity/embedding_client.go b/pkg/vmcp/optimizer/internal/similarity/embedding_client.go
new file mode 100644
index 0000000000..1d94d206de
--- /dev/null
+++ b/pkg/vmcp/optimizer/internal/similarity/embedding_client.go
@@ -0,0 +1,43 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package similarity
+
+import (
+	"fmt"
+
+	"github.com/stacklok/toolhive/pkg/vmcp/optimizer/internal/types"
+)
+
+// NewEmbeddingClient creates an EmbeddingClient from the given optimizer
+// configuration, selecting the backend implementation from EmbeddingProvider.
+// It returns (nil, nil) if cfg is nil or no embedding service URL is configured,
+// meaning semantic search will be disabled.
+//
+// When a backend constructor fails, the returned client is the nil interface
+// value rather than a typed-nil pointer wrapped in the interface, so callers
+// can safely use "client != nil" to decide whether semantic search is enabled.
+func NewEmbeddingClient(cfg *types.OptimizerConfig) (types.EmbeddingClient, error) {
+	if cfg == nil || cfg.EmbeddingService == "" {
+		return nil, nil
+	}
+
+	switch cfg.EmbeddingProvider {
+	case "", types.EmbeddingProviderTEI:
+		client, err := newTEIClient(cfg.EmbeddingService, cfg.EmbeddingServiceTimeout)
+		if err != nil {
+			return nil, err
+		}
+		return client, nil
+	case types.EmbeddingProviderOpenAI:
+		client, err := newOpenAIClient(
+			cfg.EmbeddingService, cfg.EmbeddingModel, cfg.EmbeddingAPIKey, cfg.EmbeddingServiceTimeout)
+		if err != nil {
+			return nil, err
+		}
+		return client, nil
+	default:
+		return nil, fmt.Errorf("unsupported embedding provider %q (supported: %q, %q)",
+			cfg.EmbeddingProvider, types.EmbeddingProviderTEI, types.EmbeddingProviderOpenAI)
+	}
+}
diff --git a/pkg/vmcp/optimizer/internal/similarity/embedding_client_test.go b/pkg/vmcp/optimizer/internal/similarity/embedding_client_test.go
new file mode 100644
index 0000000000..f38a5302be
--- /dev/null
+++ b/pkg/vmcp/optimizer/internal/similarity/embedding_client_test.go
@@ -0,0 +1,95 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package similarity
+
+import (
+	"net/http"
+	"net/http/httptest"
+	"testing"
+
+	"github.com/stretchr/testify/require"
+
+	"github.com/stacklok/toolhive/pkg/vmcp/optimizer/internal/types"
+)
+
+func TestNewEmbeddingClient(t *testing.T) {
+	t.Parallel()
+
+	// TEI selection queries the /info endpoint on construction, so a stub server
+	// is needed for that case.
+	teiInfo := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		if r.URL.Path == infoPath {
+			w.Header().Set("Content-Type", "application/json")
+			_, _ = w.Write([]byte(`{"max_client_batch_size": 16}`))
+			return
+		}
+		w.WriteHeader(http.StatusNotFound)
+	}))
+	t.Cleanup(teiInfo.Close)
+
+	t.Run("nil config disables semantic search", func(t *testing.T) {
+		t.Parallel()
+		client, err := NewEmbeddingClient(nil)
+		require.NoError(t, err)
+		require.Nil(t, client)
+	})
+
+	t.Run("empty service disables semantic search", func(t *testing.T) {
+		t.Parallel()
+		client, err := NewEmbeddingClient(&types.OptimizerConfig{EmbeddingProvider: types.EmbeddingProviderOpenAI})
+		require.NoError(t, err)
+		require.Nil(t, client)
+	})
+
+	t.Run("empty provider defaults to TEI", func(t *testing.T) {
+		t.Parallel()
+		client, err := NewEmbeddingClient(&types.OptimizerConfig{EmbeddingService: teiInfo.URL})
+		require.NoError(t, err)
+		require.IsType(t, &teiClient{}, client)
+	})
+
+	t.Run("tei provider", func(t *testing.T) {
+		t.Parallel()
+		client, err := NewEmbeddingClient(&types.OptimizerConfig{
+			EmbeddingService:  teiInfo.URL,
+			EmbeddingProvider: types.EmbeddingProviderTEI,
+		})
+		require.NoError(t, err)
+		require.IsType(t, &teiClient{}, client)
+	})
+
+	t.Run("openai provider", func(t *testing.T) {
+		t.Parallel()
+		client, err := NewEmbeddingClient(&types.OptimizerConfig{
+			EmbeddingService:  "http://embeddings:8080/v1",
+			EmbeddingProvider: types.EmbeddingProviderOpenAI,
+			EmbeddingModel:    "text-embedding-3-small",
+		})
+		require.NoError(t, err)
+		require.IsType(t, &openAIClient{}, client)
+	})
+
+	t.Run("unsupported provider returns error", func(t *testing.T) {
+		t.Parallel()
+		client, err := NewEmbeddingClient(&types.OptimizerConfig{
+			EmbeddingService:  "http://embeddings:8080",
+			EmbeddingProvider: "cohere",
+		})
+		require.ErrorContains(t, err, "unsupported embedding provider")
+		require.Nil(t, client)
+	})
+
+	t.Run("constructor error yields an untyped nil client", func(t *testing.T) {
+		t.Parallel()
+		// The openai provider requires a model, so construction fails here.
+		client, err := NewEmbeddingClient(&types.OptimizerConfig{
+			EmbeddingService:  "http://embeddings:8080/v1",
+			EmbeddingProvider: types.EmbeddingProviderOpenAI,
+		})
+		require.ErrorContains(t, err, "OpenAI embedding model is required")
+		// require.Nil also accepts a typed-nil pointer boxed in the interface, so
+		// compare against the nil interface value directly.
+		require.True(t, client == nil, "client must be a nil interface on error")
+	})
+}
diff --git
a/pkg/vmcp/optimizer/internal/similarity/openai_client.go b/pkg/vmcp/optimizer/internal/similarity/openai_client.go
new file mode 100644
index 0000000000..c9f3d1fdec
--- /dev/null
+++ b/pkg/vmcp/optimizer/internal/similarity/openai_client.go
@@ -0,0 +1,192 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package similarity
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"fmt"
+	"io"
+	"log/slog"
+	"net/http"
+	"strings"
+	"time"
+)
+
+const (
+	embeddingsPath = "/embeddings"
+
+	// openAIMaxBatchSize is the OpenAI cap on inputs per /embeddings request;
+	// compatible gateways generally honor the same limit.
+	openAIMaxBatchSize = 2048
+
+	// maxErrorBodyBytes caps how much of a non-200 response body is copied into
+	// the returned error, so a misbehaving endpoint cannot force a large read.
+	maxErrorBodyBytes = 4 << 10
+)
+
+// openAIClient implements types.EmbeddingClient against an OpenAI-compatible
+// /embeddings API (OpenAI, Azure OpenAI, or another OpenAI-compatible gateway).
+type openAIClient struct {
+	baseURL      string
+	apiKey       string
+	model        string
+	httpClient   *http.Client
+	maxBatchSize int
+}
+
+// newOpenAIClient creates a client that POSTs to baseURL+"/embeddings" using the
+// given model. A non-empty apiKey is sent as a Bearer token; an empty apiKey
+// omits the Authorization header so keyless endpoints work. A zero or negative
+// timeout uses defaultTimeout, because http.Client treats any non-positive
+// Timeout as "no timeout".
+func newOpenAIClient(baseURL, model, apiKey string, timeout time.Duration) (*openAIClient, error) {
+	if baseURL == "" {
+		return nil, fmt.Errorf("OpenAI embedding base URL is required")
+	}
+	if model == "" {
+		return nil, fmt.Errorf("OpenAI embedding model is required")
+	}
+	baseURL = strings.TrimSuffix(baseURL, "/")
+
+	if timeout <= 0 {
+		timeout = defaultTimeout
+	}
+
+	slog.Debug("OpenAI embedding client created",
+		"base_url", baseURL, "model", model, "timeout", timeout)
+
+	return &openAIClient{
+		baseURL:      baseURL,
+		apiKey:       apiKey,
+		model:        model,
+		httpClient:   &http.Client{Timeout: timeout},
+		maxBatchSize: openAIMaxBatchSize,
+	}, nil
+}
+
+type openAIEmbedRequest struct {
+	Model string   `json:"model"`
+	Input []string `json:"input"`
+	// EncodingFormat pins the response to float arrays, since we decode into
+	// []float32; without it a compatible server may return base64.
+	EncodingFormat string `json:"encoding_format"`
+}
+
+type openAIEmbedResponse struct {
+	Data []openAIEmbedding `json:"data"`
+}
+
+type openAIEmbedding struct {
+	Index     int       `json:"index"`
+	Embedding []float32 `json:"embedding"`
+}
+
+// Embed returns a vector embedding for the given text.
+func (c *openAIClient) Embed(ctx context.Context, text string) ([]float32, error) {
+	results, err := c.EmbedBatch(ctx, []string{text})
+	if err != nil {
+		return nil, err
+	}
+	if len(results) == 0 {
+		return nil, fmt.Errorf("OpenAI returned empty response for single input")
+	}
+	return results[0], nil
+}
+
+// EmbedBatch returns embeddings for multiple texts, chunking to respect the
+// OpenAI /embeddings input batch size.
+func (c *openAIClient) EmbedBatch(ctx context.Context, texts []string) ([][]float32, error) {
+	if len(texts) == 0 {
+		return nil, nil
+	}
+
+	allEmbeddings := make([][]float32, 0, len(texts))
+
+	for start := 0; start < len(texts); start += c.maxBatchSize {
+		end := min(start+c.maxBatchSize, len(texts))
+		embeddings, err := c.embedChunk(ctx, texts[start:end])
+		if err != nil {
+			return nil, err
+		}
+		allEmbeddings = append(allEmbeddings, embeddings...)
+	}
+
+	slog.Debug("OpenAI embedding batch completed",
+		"inputs", len(texts), "chunks", (len(texts)+c.maxBatchSize-1)/c.maxBatchSize,
+		"dimensions", len(allEmbeddings[0]))
+
+	return allEmbeddings, nil
+}
+
+// embedChunk sends one batch to the /embeddings endpoint and returns the
+// embeddings ordered to match texts. It fails on a non-200 status, an
+// undecodable body, a result count that differs from len(texts), or an
+// out-of-range or repeated index.
+func (c *openAIClient) embedChunk(ctx context.Context, texts []string) ([][]float32, error) {
+	bodyBytes, err := json.Marshal(openAIEmbedRequest{Model: c.model, Input: texts, EncodingFormat: "float"})
+	if err != nil {
+		return nil, fmt.Errorf("failed to marshal OpenAI request: %w", err)
+	}
+
+	url := c.baseURL + embeddingsPath
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(bodyBytes))
+	if err != nil {
+		return nil, fmt.Errorf("failed to create OpenAI request: %w", err)
+	}
+	req.Header.Set("Content-Type", "application/json")
+	if c.apiKey != "" {
+		req.Header.Set("Authorization", "Bearer "+c.apiKey)
+	}
+
+	resp, err := c.httpClient.Do(req) // #nosec G704 -- URL is built from the configured embedding base URL
+	if err != nil {
+		return nil, fmt.Errorf("OpenAI request failed: %w", err)
+	}
+	defer func() {
+		_, _ = io.Copy(io.Discard, resp.Body)
+		_ = resp.Body.Close()
+	}()
+
+	if resp.StatusCode != http.StatusOK {
+		// Best-effort, bounded read for diagnostics: a read error only shortens
+		// the message, so it is deliberately ignored.
+		body, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodyBytes))
+		return nil, fmt.Errorf("OpenAI returned status %d: %s", resp.StatusCode, string(body))
+	}
+
+	var embedResp openAIEmbedResponse
+	if err := json.NewDecoder(resp.Body).Decode(&embedResp); err != nil {
+		return nil, fmt.Errorf("failed to decode OpenAI response: %w", err)
+	}
+
+	if len(embedResp.Data) != len(texts) {
+		return nil, fmt.Errorf("OpenAI returned %d embeddings for %d inputs", len(embedResp.Data), len(texts))
+	}
+
+	// Place each embedding at its reported index; the API is free to return
+	// entries out of order. The count check above only guarantees that every
+	// slot is filled when no index repeats, so a duplicate is rejected instead
+	// of silently leaving another input with a nil embedding.
+	embeddings := make([][]float32, len(texts))
+	seen := make([]bool, len(texts))
+	for _, d := range embedResp.Data {
+		if d.Index < 0 || d.Index >= len(texts) {
+			return nil, fmt.Errorf("OpenAI returned out-of-range embedding index %d for %d inputs", d.Index, len(texts))
+		}
+		if seen[d.Index] {
+			return nil, fmt.Errorf("OpenAI returned duplicate embedding index %d for %d inputs", d.Index, len(texts))
+		}
+		seen[d.Index] = true
+		embeddings[d.Index] = d.Embedding
+	}
+
+	return embeddings, nil
+}
+
+// Close is a no-op for the OpenAI client.
+func (*openAIClient) Close() error {
+	return nil
+}
diff --git a/pkg/vmcp/optimizer/internal/similarity/openai_client_integration_test.go b/pkg/vmcp/optimizer/internal/similarity/openai_client_integration_test.go
new file mode 100644
index 0000000000..59efacaad2
--- /dev/null
+++ b/pkg/vmcp/optimizer/internal/similarity/openai_client_integration_test.go
@@ -0,0 +1,50 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package similarity
+
+import (
+	"cmp"
+	"context"
+	"os"
+	"testing"
+
+	"github.com/stretchr/testify/require"
+)
+
+// TestOpenAIClient_Live exercises the real /embeddings wire path against an
+// OpenAI-compatible endpoint. It is skipped unless OPENAI_API_KEY is set, so the
+// default `task test` run stays green. Override OPENAI_EMBEDDING_BASE_URL and
+// OPENAI_EMBEDDING_MODEL to point it at a compatible gateway.
+func TestOpenAIClient_Live(t *testing.T) { + t.Parallel() + + apiKey := os.Getenv("OPENAI_API_KEY") + if apiKey == "" { + t.Skip("OPENAI_API_KEY not set; skipping live OpenAI embedding test") + } + + baseURL := cmp.Or(os.Getenv("OPENAI_EMBEDDING_BASE_URL"), "https://api.openai.com/v1") + model := cmp.Or(os.Getenv("OPENAI_EMBEDDING_MODEL"), "text-embedding-3-small") + + client, err := newOpenAIClient(baseURL, model, apiKey, 0) + require.NoError(t, err) + t.Cleanup(func() { _ = client.Close() }) + + ctx := context.Background() + + vec, err := client.Embed(ctx, "the quick brown fox") + require.NoError(t, err) + require.NotEmpty(t, vec, "embedding vector must not be empty") + + // Repeat one input so we can confirm results land back in request order: + // identical inputs must produce identical vectors at their own indices. + inputs := []string{"the quick brown fox", "lorem ipsum", "the quick brown fox"} + batch, err := client.EmbedBatch(ctx, inputs) + require.NoError(t, err) + require.Len(t, batch, len(inputs)) + for i, e := range batch { + require.Lenf(t, e, len(vec), "embedding %d has unexpected dimension", i) + } + require.Equal(t, batch[0], batch[2], "identical inputs must map to identical embeddings (order preserved)") +} diff --git a/pkg/vmcp/optimizer/internal/similarity/openai_client_test.go b/pkg/vmcp/optimizer/internal/similarity/openai_client_test.go new file mode 100644 index 0000000000..af7da66bc1 --- /dev/null +++ b/pkg/vmcp/optimizer/internal/similarity/openai_client_test.go @@ -0,0 +1,320 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. 
+// SPDX-License-Identifier: Apache-2.0 + +package similarity + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func Test_newOpenAIClient(t *testing.T) { + t.Parallel() + + t.Run("empty URL returns error", func(t *testing.T) { + t.Parallel() + client, err := newOpenAIClient("", "text-embedding-3-small", "key", 0) + require.ErrorContains(t, err, "OpenAI embedding base URL is required") + require.Nil(t, client) + }) + + t.Run("empty model returns error", func(t *testing.T) { + t.Parallel() + client, err := newOpenAIClient("http://embeddings:8080/v1", "", "key", 0) + require.ErrorContains(t, err, "OpenAI embedding model is required") + require.Nil(t, client) + }) + + t.Run("valid args create client with default batch size", func(t *testing.T) { + t.Parallel() + client, err := newOpenAIClient("http://embeddings:8080/v1", "text-embedding-3-small", "key", 0) + require.NoError(t, err) + require.NotNil(t, client) + require.Equal(t, openAIMaxBatchSize, client.maxBatchSize) + require.Equal(t, defaultTimeout, client.httpClient.Timeout) + }) + + t.Run("custom timeout", func(t *testing.T) { + t.Parallel() + client, err := newOpenAIClient("http://embeddings:8080/v1", "text-embedding-3-small", "key", 5*time.Second) + require.NoError(t, err) + require.NotNil(t, client) + require.Equal(t, 5*time.Second, client.httpClient.Timeout) + }) +} + +func TestOpenAIClient_Embed(t *testing.T) { + t.Parallel() + + expected := []float32{0.1, 0.2, 0.3} + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodPost, r.Method) + require.Equal(t, embeddingsPath, r.URL.Path) + require.Equal(t, "application/json", r.Header.Get("Content-Type")) + require.Equal(t, "Bearer test-key", r.Header.Get("Authorization")) + + var req openAIEmbedRequest + require.NoError(t, json.NewDecoder(r.Body).Decode(&req)) + require.Equal(t, 
"text-embedding-3-small", req.Model) + require.Equal(t, "float", req.EncodingFormat) + require.Len(t, req.Input, 1) + require.Equal(t, "hello world", req.Input[0]) + + writeOpenAIEmbeddings(t, w, [][]float32{expected}) + })) + t.Cleanup(srv.Close) + + client := newTestOpenAIClient(t, srv.URL, "test-key") + + result, err := client.Embed(context.Background(), "hello world") + require.NoError(t, err) + require.Equal(t, expected, result) +} + +func TestOpenAIClient_EmbedBatch(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + texts []string + handler http.HandlerFunc + wantErr string + wantLen int + wantResult [][]float32 + }{ + { + name: "empty input", + texts: nil, + }, + { + name: "single input", + texts: []string{"hello"}, + handler: func(w http.ResponseWriter, _ *http.Request) { + writeOpenAIEmbeddings(t, w, [][]float32{{0.1, 0.2}}) + }, + wantLen: 1, + wantResult: [][]float32{{0.1, 0.2}}, + }, + { + name: "multiple inputs", + texts: []string{"hello", "world"}, + handler: func(w http.ResponseWriter, _ *http.Request) { + writeOpenAIEmbeddings(t, w, [][]float32{{0.1, 0.2}, {0.3, 0.4}}) + }, + wantLen: 2, + wantResult: [][]float32{{0.1, 0.2}, {0.3, 0.4}}, + }, + { + name: "out-of-order data is reordered by index", + texts: []string{"hello", "world"}, + handler: func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(openAIEmbedResponse{Data: []openAIEmbedding{ + {Index: 1, Embedding: []float32{0.3, 0.4}}, + {Index: 0, Embedding: []float32{0.1, 0.2}}, + }}) + }, + wantLen: 2, + wantResult: [][]float32{{0.1, 0.2}, {0.3, 0.4}}, + }, + { + name: "server error", + texts: []string{"hello"}, + handler: func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte("internal error")) + }, + wantErr: "OpenAI returned status 500", + }, + { + name: "mismatched count", + texts: []string{"hello", "world"}, + handler: func(w 
http.ResponseWriter, _ *http.Request) { + writeOpenAIEmbeddings(t, w, [][]float32{{0.1, 0.2}}) + }, + wantErr: "OpenAI returned 1 embeddings for 2 inputs", + }, + { + name: "out-of-range index", + texts: []string{"hello"}, + handler: func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(openAIEmbedResponse{Data: []openAIEmbedding{ + {Index: 5, Embedding: []float32{0.1, 0.2}}, + }}) + }, + wantErr: "out-of-range embedding index 5", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var srv *httptest.Server + if tt.handler != nil { + srv = httptest.NewServer(tt.handler) + t.Cleanup(srv.Close) + } + + baseURL := "http://localhost:0" + if srv != nil { + baseURL = srv.URL + } + + client := newTestOpenAIClient(t, baseURL, "test-key") + + results, err := client.EmbedBatch(context.Background(), tt.texts) + if tt.wantErr != "" { + require.ErrorContains(t, err, tt.wantErr) + return + } + + require.NoError(t, err) + if tt.wantLen > 0 { + require.Len(t, results, tt.wantLen) + require.Equal(t, tt.wantResult, results) + } else { + require.Nil(t, results) + } + }) + } +} + +func TestOpenAIClient_EmbedBatch_Chunking(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + maxBatchSize int + numInputs int + wantChunks int + }{ + {name: "inputs fit in single batch", maxBatchSize: 5, numInputs: 3, wantChunks: 1}, + {name: "inputs exactly fill one batch", maxBatchSize: 4, numInputs: 4, wantChunks: 1}, + {name: "inputs split into two batches", maxBatchSize: 3, numInputs: 5, wantChunks: 2}, + {name: "inputs split into many batches", maxBatchSize: 2, numInputs: 7, wantChunks: 4}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var chunkCount int + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var req openAIEmbedRequest + require.NoError(t, 
json.NewDecoder(r.Body).Decode(&req)) + require.LessOrEqual(t, len(req.Input), tt.maxBatchSize, + "chunk size should not exceed maxBatchSize") + chunkCount++ + + embeddings := make([][]float32, len(req.Input)) + for i := range embeddings { + embeddings[i] = []float32{float32(i) * 0.1} + } + writeOpenAIEmbeddings(t, w, embeddings) + })) + t.Cleanup(srv.Close) + + texts := make([]string, tt.numInputs) + for i := range texts { + texts[i] = fmt.Sprintf("text-%d", i) + } + + client := newTestOpenAIClientWithBatch(t, srv.URL, tt.maxBatchSize) + results, err := client.EmbedBatch(context.Background(), texts) + require.NoError(t, err) + require.Len(t, results, tt.numInputs) + require.Equal(t, tt.wantChunks, chunkCount) + }) + } +} + +func TestOpenAIClient_EmbedBatch_ChunkErrorStopsEarly(t *testing.T) { + t.Parallel() + + var callCount int + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + callCount++ + if callCount == 2 { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte("server overloaded")) + return + } + writeOpenAIEmbeddings(t, w, [][]float32{{0.1}, {0.2}}) + })) + t.Cleanup(srv.Close) + + texts := make([]string, 6) // 3 chunks of 2 + for i := range texts { + texts[i] = fmt.Sprintf("text-%d", i) + } + + client := newTestOpenAIClientWithBatch(t, srv.URL, 2) + _, err := client.EmbedBatch(context.Background(), texts) + require.ErrorContains(t, err, "OpenAI returned status 500") + require.Equal(t, 2, callCount, "should stop after the failing chunk") +} + +func TestOpenAIClient_OmitsAuthHeaderWhenKeyless(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Empty(t, r.Header.Get("Authorization")) + writeOpenAIEmbeddings(t, w, [][]float32{{0.1}}) + })) + t.Cleanup(srv.Close) + + client := newTestOpenAIClient(t, srv.URL, "") + + _, err := client.Embed(context.Background(), "hello") + require.NoError(t, err) +} + +func 
TestOpenAIClient_Close(t *testing.T) { + t.Parallel() + + client := newTestOpenAIClient(t, "http://my-embedding:8080/v1", "key") + require.NoError(t, client.Close()) +} + +// writeOpenAIEmbeddings encodes embeddings as an OpenAI /embeddings response, +// assigning each entry its slice position as the index. +func writeOpenAIEmbeddings(t *testing.T, w http.ResponseWriter, embeddings [][]float32) { + t.Helper() + resp := openAIEmbedResponse{Data: make([]openAIEmbedding, len(embeddings))} + for i, e := range embeddings { + resp.Data[i] = openAIEmbedding{Index: i, Embedding: e} + } + w.Header().Set("Content-Type", "application/json") + require.NoError(t, json.NewEncoder(w).Encode(resp)) +} + +// newTestOpenAIClient creates an openAIClient pointing at the given URL for +// testing. It defaults to a large batch size so requests are single-chunk. +func newTestOpenAIClient(t *testing.T, baseURL, apiKey string) *openAIClient { + t.Helper() + client := newTestOpenAIClientWithBatch(t, baseURL, 1000) + client.apiKey = apiKey + return client +} + +// newTestOpenAIClientWithBatch creates an openAIClient with a specific max batch +// size for testing, using a fixed API key. 
+func newTestOpenAIClientWithBatch(t *testing.T, baseURL string, maxBatchSize int) *openAIClient { + t.Helper() + return &openAIClient{ + baseURL: baseURL, + apiKey: "test-key", + model: "text-embedding-3-small", + httpClient: &http.Client{Timeout: defaultTimeout}, + maxBatchSize: maxBatchSize, + } +} diff --git a/pkg/vmcp/optimizer/internal/similarity/tei_client.go b/pkg/vmcp/optimizer/internal/similarity/tei_client.go index fe58ba4fd1..035f3e0f9d 100644 --- a/pkg/vmcp/optimizer/internal/similarity/tei_client.go +++ b/pkg/vmcp/optimizer/internal/similarity/tei_client.go @@ -12,8 +12,6 @@ import ( "log/slog" "net/http" "time" - - "github.com/stacklok/toolhive/pkg/vmcp/optimizer/internal/types" ) const ( @@ -38,16 +36,6 @@ type teiClient struct { maxBatchSize int } -// NewEmbeddingClient creates an EmbeddingClient from the given optimizer -// configuration. It returns (nil, nil) if cfg is nil or no embedding service -// URL is configured, meaning semantic search will be disabled. -func NewEmbeddingClient(cfg *types.OptimizerConfig) (types.EmbeddingClient, error) { - if cfg == nil || cfg.EmbeddingService == "" { - return nil, nil - } - return newTEIClient(cfg.EmbeddingService, cfg.EmbeddingServiceTimeout) -} - // newTEIClient creates a new TEI embedding client that calls the specified endpoint. // It queries the TEI /info endpoint to discover the server's maximum batch size. func newTEIClient(baseURL string, timeout time.Duration) (*teiClient, error) { diff --git a/pkg/vmcp/optimizer/internal/types/types.go b/pkg/vmcp/optimizer/internal/types/types.go index dac8beaa1e..1eb90f424a 100644 --- a/pkg/vmcp/optimizer/internal/types/types.go +++ b/pkg/vmcp/optimizer/internal/types/types.go @@ -37,6 +37,16 @@ type ToolStore interface { Close() error } +// Embedding provider identifiers select the wire protocol used to talk to the +// embedding service. They match config.OptimizerConfig.EmbeddingProvider. 
+const ( + // EmbeddingProviderTEI speaks the HuggingFace Text Embeddings Inference API. + EmbeddingProviderTEI = "tei" + + // EmbeddingProviderOpenAI speaks the OpenAI-compatible /embeddings API. + EmbeddingProviderOpenAI = "openai" +) + // EmbeddingClient generates vector embeddings from text. // Implementations may use local models, remote APIs, or deterministic fakes. // The dimensionality of embeddings can be inferred from the returned vectors. @@ -70,6 +80,20 @@ type OptimizerConfig struct { // Zero means use the default timeout (30s). EmbeddingServiceTimeout time.Duration + // EmbeddingProvider selects the embedding backend wire protocol + // (EmbeddingProviderTEI or EmbeddingProviderOpenAI). Empty defaults to TEI. + EmbeddingProvider string + + // EmbeddingModel is the model name requested from an OpenAI-compatible + // embedding service (e.g. "text-embedding-3-small"). Unused by the TEI + // provider, where the model is fixed by the running container. + EmbeddingModel string + + // EmbeddingAPIKey is the bearer token sent to an OpenAI-compatible embedding + // service. Empty means no Authorization header is sent, which supports + // keyless in-cluster gateways. Never populated for the TEI provider. + EmbeddingAPIKey string + // MaxToolsToReturn limits the number of tools returned by FindTool. MaxToolsToReturn *int diff --git a/pkg/vmcp/optimizer/optimizer.go b/pkg/vmcp/optimizer/optimizer.go index 2fe3f85587..7e41d0d0ed 100644 --- a/pkg/vmcp/optimizer/optimizer.go +++ b/pkg/vmcp/optimizer/optimizer.go @@ -16,6 +16,7 @@ import ( "context" "fmt" "log/slog" + "os" "strconv" "time" @@ -29,6 +30,12 @@ import ( "github.com/stacklok/toolhive/pkg/vmcp/optimizer/internal/types" ) +// embeddingAPIKeyEnvVar holds the bearer token for an OpenAI-compatible +// embedding service. It is an env var, not a config field, so the secret never +// lands in a CRD spec or ConfigMap. 
+// #nosec G101 -- This is an environment variable name, not a hardcoded credential +const embeddingAPIKeyEnvVar = "OPENAI_API_KEY" + // Config defines configuration options for the Optimizer. // It is defined in the internal/types package and aliased here so that // external consumers continue to use optimizer.Config. @@ -45,6 +52,12 @@ func GetAndValidateConfig(cfg *vmcpconfig.OptimizerConfig) (*Config, error) { optCfg := &Config{ EmbeddingService: cfg.EmbeddingService, EmbeddingServiceTimeout: time.Duration(cfg.EmbeddingServiceTimeout), + EmbeddingProvider: cfg.EmbeddingProvider, + EmbeddingModel: cfg.EmbeddingModel, + } + + if err := resolveEmbeddingProvider(optCfg); err != nil { + return nil, err } if cfg.MaxToolsToReturn != 0 { @@ -85,6 +98,32 @@ func GetAndValidateConfig(cfg *vmcpconfig.OptimizerConfig) (*Config, error) { return optCfg, nil } +// resolveEmbeddingProvider normalizes and validates the embedding provider on +// optCfg in place. An empty provider defaults to TEI so existing configs keep +// working; the OpenAI provider requires a service and model and reads its API +// key from the environment. 
+func resolveEmbeddingProvider(optCfg *Config) error { + switch optCfg.EmbeddingProvider { + case "": + optCfg.EmbeddingProvider = types.EmbeddingProviderTEI + case types.EmbeddingProviderTEI: + case types.EmbeddingProviderOpenAI: + if optCfg.EmbeddingService == "" { + return fmt.Errorf("optimizer.embeddingService is required when optimizer.embeddingProvider is %q", + types.EmbeddingProviderOpenAI) + } + if optCfg.EmbeddingModel == "" { + return fmt.Errorf("optimizer.embeddingModel is required when optimizer.embeddingProvider is %q", + types.EmbeddingProviderOpenAI) + } + optCfg.EmbeddingAPIKey = os.Getenv(embeddingAPIKeyEnvVar) + default: + return fmt.Errorf("optimizer.embeddingProvider must be %q or %q, got %q", + types.EmbeddingProviderTEI, types.EmbeddingProviderOpenAI, optCfg.EmbeddingProvider) + } + return nil +} + // Optimizer defines the interface for intelligent tool discovery and invocation. // // The default implementation delegates search to a ToolStore (SQLite FTS5 with diff --git a/pkg/vmcp/optimizer/optimizer_test.go b/pkg/vmcp/optimizer/optimizer_test.go index c389c0e17f..8ccef87b0c 100644 --- a/pkg/vmcp/optimizer/optimizer_test.go +++ b/pkg/vmcp/optimizer/optimizer_test.go @@ -18,6 +18,7 @@ import ( vmcpconfig "github.com/stacklok/toolhive/pkg/vmcp/config" "github.com/stacklok/toolhive/pkg/vmcp/optimizer/internal/tokencounter" + "github.com/stacklok/toolhive/pkg/vmcp/optimizer/internal/types" "github.com/stacklok/toolhive/pkg/vmcp/optimizer/internal/types/mocks" ) @@ -52,6 +53,54 @@ func TestGetAndValidateConfig(t *testing.T) { EmbeddingService: "http://embeddings:8080", }, }, + { + name: "explicit tei provider", + cfg: &vmcpconfig.OptimizerConfig{ + EmbeddingService: "http://embeddings:8080", + EmbeddingProvider: types.EmbeddingProviderTEI, + }, + expected: &Config{ + EmbeddingService: "http://embeddings:8080", + EmbeddingProvider: types.EmbeddingProviderTEI, + }, + }, + { + name: "openai provider with service and model", + cfg: 
&vmcpconfig.OptimizerConfig{ + EmbeddingService: "http://gateway:8080/v1", + EmbeddingProvider: types.EmbeddingProviderOpenAI, + EmbeddingModel: "text-embedding-3-small", + }, + expected: &Config{ + EmbeddingService: "http://gateway:8080/v1", + EmbeddingProvider: types.EmbeddingProviderOpenAI, + EmbeddingModel: "text-embedding-3-small", + }, + }, + { + name: "error: openai provider without service", + cfg: &vmcpconfig.OptimizerConfig{ + EmbeddingProvider: types.EmbeddingProviderOpenAI, + EmbeddingModel: "text-embedding-3-small", + }, + errContains: "optimizer.embeddingService is required", + }, + { + name: "error: openai provider without model", + cfg: &vmcpconfig.OptimizerConfig{ + EmbeddingService: "http://gateway:8080/v1", + EmbeddingProvider: types.EmbeddingProviderOpenAI, + }, + errContains: "optimizer.embeddingModel is required", + }, + { + name: "error: unknown provider", + cfg: &vmcpconfig.OptimizerConfig{ + EmbeddingService: "http://embeddings:8080", + EmbeddingProvider: "cohere", + }, + errContains: "optimizer.embeddingProvider must be", + }, { name: "all valid values are parsed", cfg: &vmcpconfig.OptimizerConfig{ @@ -208,6 +257,13 @@ func TestGetAndValidateConfig(t *testing.T) { require.NotNil(t, result) assert.Equal(t, tt.expected.EmbeddingService, result.EmbeddingService) + wantProvider := tt.expected.EmbeddingProvider + if wantProvider == "" { + wantProvider = types.EmbeddingProviderTEI + } + assert.Equal(t, wantProvider, result.EmbeddingProvider) + assert.Equal(t, tt.expected.EmbeddingModel, result.EmbeddingModel) + if tt.expected.MaxToolsToReturn != nil { require.NotNil(t, result.MaxToolsToReturn) assert.Equal(t, *tt.expected.MaxToolsToReturn, *result.MaxToolsToReturn) @@ -232,6 +288,39 @@ func TestGetAndValidateConfig(t *testing.T) { } } +func TestGetAndValidateConfig_OpenAIAPIKeyFromEnv(t *testing.T) { + openAICfg := func() *vmcpconfig.OptimizerConfig { + return &vmcpconfig.OptimizerConfig{ + EmbeddingService: "http://gateway:8080/v1", + 
EmbeddingProvider: types.EmbeddingProviderOpenAI, + EmbeddingModel: "text-embedding-3-small", + } + } + + t.Run("key is read from the environment", func(t *testing.T) { + t.Setenv(embeddingAPIKeyEnvVar, "sk-test") + result, err := GetAndValidateConfig(openAICfg()) + require.NoError(t, err) + assert.Equal(t, "sk-test", result.EmbeddingAPIKey) + }) + + t.Run("unset key yields a keyless client", func(t *testing.T) { + t.Setenv(embeddingAPIKeyEnvVar, "") + result, err := GetAndValidateConfig(openAICfg()) + require.NoError(t, err) + assert.Empty(t, result.EmbeddingAPIKey) + }) + + t.Run("tei provider never reads the key", func(t *testing.T) { + t.Setenv(embeddingAPIKeyEnvVar, "sk-test") + result, err := GetAndValidateConfig(&vmcpconfig.OptimizerConfig{ + EmbeddingService: "http://embeddings:8080", + }) + require.NoError(t, err) + assert.Empty(t, result.EmbeddingAPIKey) + }) +} + // newMockStoreWithSubstringSearch returns a gomock MockToolStore configured with // DoAndReturn handlers that accumulate tools via UpsertTools and perform // case-insensitive substring matching on Search. Suitable for tests that need