Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions go/adk/pkg/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,17 @@ func CreateLLM(ctx context.Context, m adk.Model, log logr.Logger) (adkmodel.LLM,
}
return models.NewSAPAICoreModelWithLogger(cfg, log)

case *adk.Foundry:
cfg := &models.FoundryConfig{
TransportConfig: transportConfigFromBase(m.BaseModel, nil),
Model: m.Model,
Endpoint: m.Endpoint,
Deployment: m.Deployment,
APIVersion: m.APIVersion,
AuthType: string(m.Auth.Type),
}
return models.NewFoundryModelWithLogger(ctx, cfg, log)

default:
return nil, fmt.Errorf("unsupported model type: %s", m.GetType())
}
Expand Down
139 changes: 139 additions & 0 deletions go/adk/pkg/models/foundry.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
package models

import (
"context"
"fmt"
"net/http"
"os"
"strings"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/go-logr/logr"
"github.com/openai/openai-go/v3"
"github.com/openai/openai-go/v3/option"
)

// FoundryConfig holds Foundry configuration.
type FoundryConfig struct {
TransportConfig
Model string
Endpoint string
Deployment string
APIVersion string
AuthType string
}

const (
foundryAuthTypeAPIKey = "APIKey"
foundryAuthTypeWorkloadIdentity = "WorkloadIdentity"
foundryAuthTypeAPIKeyPassthrough = "APIKeyPassthrough"
)

// NewFoundryModelWithLogger creates a Foundry model.
func NewFoundryModelWithLogger(ctx context.Context, config *FoundryConfig, logger logr.Logger) (*OpenAIModel, error) {
endpoint := config.Endpoint
if endpoint == "" {
endpoint = os.Getenv("FOUNDRY_ENDPOINT")
}
if endpoint == "" {
return nil, fmt.Errorf("FOUNDRY_ENDPOINT environment variable is not set")
}
deployment := config.Deployment
if deployment == "" {
deployment = os.Getenv("FOUNDRY_DEPLOYMENT")
}
if deployment == "" {
return nil, fmt.Errorf("FOUNDRY_DEPLOYMENT environment variable is not set")
}
apiVersion := config.APIVersion
if apiVersion == "" {
apiVersion = os.Getenv("FOUNDRY_API_VERSION")
}
if apiVersion == "" {
apiVersion = "2024-10-21"
}

httpClient, err := BuildHTTPClient(config.TransportConfig)
if err != nil {
return nil, err
}
opts := []option.RequestOption{
option.WithBaseURL(strings.TrimSuffix(endpoint, "/") + "/"),
option.WithQueryAdd("api-version", apiVersion),
option.WithMiddleware(azurePathRewriteMiddleware()),
option.WithHTTPClient(httpClient),
}

authType := config.AuthType
if authType == "" {
authType = foundryAuthTypeWorkloadIdentity
}
switch authType {
case foundryAuthTypeAPIKey:
apiKey := os.Getenv("FOUNDRY_API_KEY")
if apiKey == "" {
return nil, fmt.Errorf("FOUNDRY_API_KEY environment variable is not set")
}
opts = append(opts, option.WithHeader("Api-Key", apiKey))
case foundryAuthTypeWorkloadIdentity:
credential, err := azidentity.NewDefaultAzureCredential(nil)
if err != nil {
return nil, fmt.Errorf("failed to create Azure credential: %w", err)
}
opts = append(opts,
option.WithAPIKey("foundry-entra"),
option.WithMiddleware(foundryBearerTokenMiddleware(credential)),
)
case foundryAuthTypeAPIKeyPassthrough:
config.APIKeyPassthrough = true
opts = append(opts, option.WithMiddleware(foundryPassthroughBearerTokenMiddleware()))
default:
return nil, fmt.Errorf("unsupported Foundry auth type: %s", authType)
}

client := openai.NewClient(opts...)
if logger.GetSink() != nil {
logger.Info("Initialized Foundry model", "model", config.Model, "deployment", deployment, "endpoint", endpoint, "apiVersion", apiVersion)
}
return &OpenAIModel{
Config: &OpenAIConfig{
TransportConfig: config.TransportConfig,
Model: deployment,
BaseUrl: strings.TrimSuffix(endpoint, "/") + "/",
},
Client: client,
IsAzure: true,
Logger: logger,
}, nil
}

type foundryTokenCredential interface {
GetToken(context.Context, policy.TokenRequestOptions) (azcore.AccessToken, error)
}

func foundryBearerTokenMiddleware(credential foundryTokenCredential) option.Middleware {
return func(r *http.Request, next option.MiddlewareNext) (*http.Response, error) {
token, err := credential.GetToken(r.Context(), policy.TokenRequestOptions{Scopes: []string{"https://cognitiveservices.azure.com/.default"}})
if err != nil {
return nil, fmt.Errorf("failed to acquire Foundry token: %w", err)
}
r = r.Clone(r.Context())
r.Header.Set("Authorization", "Bearer "+token.Token)
return next(r)
}
}

func foundryPassthroughBearerTokenMiddleware() option.Middleware {
return func(r *http.Request, next option.MiddlewareNext) (*http.Response, error) {
token, ok := r.Context().Value(BearerTokenKey).(string)
if !ok || token == "" {
return next(r)
}
r = r.Clone(r.Context())
r.Header.Del("Api-Key")
r.Header.Set("Authorization", "Bearer "+token)
return next(r)
}
}
144 changes: 144 additions & 0 deletions go/adk/pkg/models/foundry_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
package models

import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/go-logr/logr"
"github.com/openai/openai-go/v3"
"github.com/openai/openai-go/v3/shared"
)

func TestNewFoundryModelWithLoggerAPIKeyRequiresEnv(t *testing.T) {
t.Setenv("FOUNDRY_API_KEY", "")

_, err := NewFoundryModelWithLogger(context.Background(), &FoundryConfig{
Endpoint: "https://example.openai.azure.com/",
Deployment: "gpt-4-1-nano",
AuthType: foundryAuthTypeAPIKey,
}, logr.Discard())
if err == nil || !strings.Contains(err.Error(), "FOUNDRY_API_KEY environment variable is not set") {
t.Fatalf("NewFoundryModelWithLogger() error = %v, want missing FOUNDRY_API_KEY", err)
}
}

func TestNewFoundryModelWithLoggerAPIKeyPassthrough(t *testing.T) {
model, err := NewFoundryModelWithLogger(context.Background(), &FoundryConfig{
Endpoint: "https://example.openai.azure.com/",
Deployment: "gpt-4-1-nano",
AuthType: foundryAuthTypeAPIKeyPassthrough,
}, logr.Discard())
if err != nil {
t.Fatalf("NewFoundryModelWithLogger() error = %v", err)
}
if model == nil || model.Config == nil || !model.Config.APIKeyPassthrough {
t.Fatalf("APIKeyPassthrough = false, want true")
}
if !model.IsAzure {
t.Fatalf("IsAzure = false, want true")
}
}

func TestFoundryAPIKeyPassthroughSendsAuthorizationHeader(t *testing.T) {
requests := make(chan foundryRequest, 1)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requests <- foundryRequest{
apiKey: r.Header.Get("Api-Key"),
authorization: r.Header.Get("Authorization"),
path: r.URL.Path,
apiVersion: r.URL.Query().Get("api-version"),
}
w.Header().Set("Content-Type", "application/json")
fmt.Fprint(w, `{"id":"chatcmpl-test","object":"chat.completion","created":0,"model":"gpt-4-1-nano","choices":[{"index":0,"message":{"role":"assistant","content":"ok"},"finish_reason":"stop"}]}`)
}))
t.Cleanup(server.Close)

model, err := NewFoundryModelWithLogger(context.Background(), &FoundryConfig{
Endpoint: server.URL,
Deployment: "gpt-4-1-nano",
AuthType: foundryAuthTypeAPIKeyPassthrough,
}, logr.Discard())
if err != nil {
t.Fatalf("NewFoundryModelWithLogger() error = %v", err)
}

ctx := context.WithValue(context.Background(), BearerTokenKey, "incoming-token")
_, err = model.Client.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{
Model: shared.ChatModel("gpt-4-1-nano"),
Messages: []openai.ChatCompletionMessageParamUnion{openai.UserMessage("hello")},
}, openAIPassthroughOpts(ctx, model)...)
if err != nil {
t.Fatalf("Chat completion request error = %v", err)
}
req := <-requests
if req.path != "/openai/deployments/gpt-4-1-nano/chat/completions" {
t.Fatalf("path = %q, want Azure deployment path", req.path)
}
if req.apiVersion != "2024-10-21" {
t.Fatalf("api-version = %q, want 2024-10-21", req.apiVersion)
}
if req.authorization != "Bearer incoming-token" {
t.Fatalf("Authorization header = %q, want Bearer incoming-token", req.authorization)
}
if req.apiKey != "" {
t.Fatalf("Api-Key header = %q, want empty", req.apiKey)
}
}

func TestFoundryBearerTokenMiddlewareUsesRequestContext(t *testing.T) {
credential := &requestContextCredential{t: t}
middleware := foundryBearerTokenMiddleware(credential)
req := httptest.NewRequest(http.MethodPost, "https://example.com/chat/completions", nil)
req = req.WithContext(context.WithValue(req.Context(), foundryRequestContextKey{}, "request-context"))

_, err := middleware(req, func(r *http.Request) (*http.Response, error) {
if got := r.Header.Get("Authorization"); got != "Bearer request-token" {
t.Fatalf("Authorization = %q, want bearer token", got)
}
return &http.Response{StatusCode: http.StatusOK, Body: http.NoBody}, nil
})
if err != nil {
t.Fatalf("middleware error = %v", err)
}
}

func TestNewFoundryModelWithLoggerUnsupportedAuthType(t *testing.T) {
_, err := NewFoundryModelWithLogger(context.Background(), &FoundryConfig{
Endpoint: "https://example.openai.azure.com/",
Deployment: "gpt-4-1-nano",
AuthType: "Unknown",
}, logr.Discard())
if err == nil || !strings.Contains(err.Error(), "unsupported Foundry auth type: Unknown") {
t.Fatalf("NewFoundryModelWithLogger() error = %v, want unsupported auth type", err)
}
}

type foundryRequestContextKey struct{}

type foundryRequest struct {
apiKey string
authorization string
path string
apiVersion string
}

type requestContextCredential struct {
t *testing.T
}

func (c *requestContextCredential) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) {
c.t.Helper()
if got := ctx.Value(foundryRequestContextKey{}); got != "request-context" {
c.t.Fatalf("GetToken context marker = %v, want request-context", got)
}
if len(opts.Scopes) != 1 || opts.Scopes[0] != "https://cognitiveservices.azure.com/.default" {
c.t.Fatalf("Scopes = %v, want cognitive services scope", opts.Scopes)
}
return azcore.AccessToken{Token: "request-token"}, nil
}
45 changes: 45 additions & 0 deletions go/api/adk/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ const (
ModelTypeGemini = "gemini"
ModelTypeBedrock = "bedrock"
ModelTypeSAPAICore = "sap_ai_core"
ModelTypeFoundry = "foundry"
)

func (o *OpenAI) MarshalJSON() ([]byte, error) {
Expand Down Expand Up @@ -302,6 +303,42 @@ func (s *SAPAICore) GetType() string {
return ModelTypeSAPAICore
}

// Types for Foundry
type FoundryAuthType string

const (
FoundryAuthTypeAPIKey FoundryAuthType = "APIKey"
FoundryAuthTypeWorkloadIdentity FoundryAuthType = "WorkloadIdentity"
FoundryAuthTypeAPIKeyPassthrough FoundryAuthType = "APIKeyPassthrough"
)

type FoundryAuth struct {
Type FoundryAuthType `json:"type"`
}

type Foundry struct {
BaseModel
Endpoint string `json:"endpoint"`
Deployment string `json:"deployment"`
APIVersion string `json:"api_version"`
Auth FoundryAuth `json:"auth"`
}

func (a *Foundry) GetType() string {
return ModelTypeFoundry
}

func (a *Foundry) MarshalJSON() ([]byte, error) {
type Alias Foundry
return json.Marshal(&struct {
Type string `json:"type"`
*Alias
}{
Type: ModelTypeFoundry,
Alias: (*Alias)(a),
})
}

// GenericModel is a catch-all model type used by the Go ADK when the model
// type doesn't match any known constant.
type GenericModel struct {
Expand Down Expand Up @@ -370,6 +407,12 @@ func ParseModel(bytes []byte) (Model, error) {
return nil, err
}
return &sapAICore, nil
case ModelTypeFoundry:
var foundry Foundry
if err := json.Unmarshal(bytes, &foundry); err != nil {
return nil, err
}
return &foundry, nil
}
return nil, fmt.Errorf("unknown model type: %s", model.Type)
}
Expand Down Expand Up @@ -438,6 +481,8 @@ func ModelToEmbeddingConfig(m Model) *EmbeddingConfig {
case *SAPAICore:
e.Model = v.Model
e.BaseUrl = v.BaseUrl
case *Foundry:
e.Model = v.Model
default:
e.Model = ""
}
Expand Down
Loading
Loading