Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
474 changes: 474 additions & 0 deletions pkg/ai/completion.go

Large diffs are not rendered by default.

486 changes: 486 additions & 0 deletions pkg/ai/completion_test.go

Large diffs are not rendered by default.

89 changes: 89 additions & 0 deletions pkg/ai/generate.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package ai

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"iter"

	"github.com/docker/docker-agent/pkg/chat"
)

// StreamValue represents a single value yielded during streaming.
// A stream yields zero or more chunk values (Done == false) followed
// by exactly one final value (Done == true); consumers should check
// Done to decide which fields are meaningful.
type StreamValue[Out, Stream any] struct {
	Done     bool           // true only on the final yield of the stream
	Chunk    Stream         // valid if Done is false
	Value    Out            // valid if Done is true
	Response *ModelResponse // valid if Done is true
}

// ModelStreamValue is a stream value for a model response.
// Out is never set because the value is already available in the Response
// field; the struct{} type parameter makes that explicit at compile time.
type ModelStreamValue = StreamValue[struct{}, chat.MessageStreamResponse]

// GenerateStream generates a model response and streams the output.
// It returns an iterator that yields one value per streamed chunk
// (Done == false) and then a single final value (Done == true)
// carrying the complete ModelResponse.
func GenerateStream(ctx context.Context, opts ...Option) iter.Seq2[*ModelStreamValue, error] {
	return func(yield func(*ModelStreamValue, error) bool) {
		comp := &completion{
			// Forward each streamed chunk to the consumer. The bool
			// returned by yield propagates early termination back into
			// the completion.
			yield: func(chunk chat.MessageStreamResponse) bool {
				return yield(&ModelStreamValue{Chunk: chunk}, nil)
			},
		}
		comp = comp.applyOptions(opts...)

		res, err := comp.generate(ctx)
		switch {
		case errors.Is(err, io.EOF):
			// io.EOF marks a stream that ended with nothing further to
			// report — stop without yielding a final value.
			return
		case err != nil:
			yield(nil, err)
			return
		}

		yield(&ModelStreamValue{Done: true, Response: res}, nil)
	}
}

// Generate runs a completion and returns the final model response.
// It handles retry, fallback, tool execution, and streaming internally.
func Generate(ctx context.Context, opts ...Option) (*ModelResponse, error) {
	// Build a fresh completion, apply the caller's options, then run it.
	c := new(completion)
	c = c.applyOptions(opts...)
	return c.generate(ctx)
}

// GenerateText is a convenience wrapper around Generate that returns
// only the text content from the model response.
func GenerateText(ctx context.Context, opts ...Option) (string, error) {
	resp, err := Generate(ctx, opts...)
	if err != nil {
		// T is meaningless on error — return the zero string.
		return "", err
	}
	return resp.Content, nil
}

// GenerateValue runs a completion and unmarshals the model's response
// content into the provided type. Use with structured output to get
// typed responses from the model.
//
// A malformed response is reported with context so callers can tell an
// unmarshaling failure apart from a model failure; the underlying JSON
// error remains reachable via errors.Is/errors.As.
func GenerateValue[Out any](ctx context.Context, opts ...Option) (*Out, error) {
	res, err := Generate(ctx, opts...)
	if err != nil {
		return nil, err
	}

	var out Out
	if err := json.Unmarshal([]byte(res.Content), &out); err != nil {
		// Wrap with the target type so the failure names what this
		// layer was doing, not just the raw JSON syntax error.
		return nil, fmt.Errorf("unmarshaling model response into %T: %w", out, err)
	}

	return &out, nil
}
229 changes: 229 additions & 0 deletions pkg/ai/generate_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
package ai

import (
"errors"
"testing"

"github.com/stretchr/testify/require"

"github.com/docker/docker-agent/pkg/chat"
)

func TestGenerateStream(t *testing.T) {
	t.Parallel()

	cases := []struct {
		name       string
		p          *mockProvider
		err        string
		expContent string
	}{
		{
			name: "happy path yields chunks then done",
			p: &mockProvider{
				id: "test",
				msgs: []chat.MessageStreamResponse{
					{Choices: []chat.MessageStreamChoice{{Delta: chat.MessageDelta{Content: "hello"}}}},
					{Choices: []chat.MessageStreamChoice{{Delta: chat.MessageDelta{Content: " world"}}}},
					{
						Choices: []chat.MessageStreamChoice{{FinishReason: chat.FinishReasonStop}},
						Usage:   &chat.Usage{InputTokens: 10},
					},
				},
			},
			expContent: "hello world",
		},
		{
			name: "error yields error",
			p:    &mockProvider{id: "test", err: errors.New("model failed")},
			err:  "model failed",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			var (
				chunks int
				final  *ModelResponse
			)

			stream := GenerateStream(t.Context(),
				WithModels(tc.p),
				WithMessages(chat.Message{Role: "user", Content: "test"}),
			)
			for sv, err := range stream {
				if err != nil {
					require.ErrorContains(t, err, tc.err)
					return
				}
				if sv.Done {
					final = sv.Response
					break
				}
				chunks++
			}

			if tc.err != "" {
				t.Fatal("expected error but got none")
			}

			require.NotNil(t, final)
			require.Equal(t, tc.expContent, final.Content)
			require.Positive(t, chunks)
		})
	}
}

func TestGenerateText(t *testing.T) {
	t.Parallel()

	cases := []struct {
		name       string
		p          *mockProvider
		err        string
		expContent string
	}{
		{
			name: "returns text content",
			p: &mockProvider{
				id: "test",
				msgs: []chat.MessageStreamResponse{
					{Choices: []chat.MessageStreamChoice{{Delta: chat.MessageDelta{Content: "hello"}}}},
					{Choices: []chat.MessageStreamChoice{{FinishReason: chat.FinishReasonStop}}},
				},
			},
			expContent: "hello",
		},
		{
			name: "error returns empty string",
			p:    &mockProvider{id: "test", err: errors.New("model failed")},
			err:  "model failed",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := GenerateText(t.Context(),
				WithModels(tc.p),
				WithMessages(chat.Message{Role: "user", Content: "test"}),
			)

			if tc.err != "" {
				require.ErrorContains(t, err, tc.err)
				require.Empty(t, got)
				return
			}

			require.NoError(t, err)
			require.Equal(t, tc.expContent, got)
		})
	}
}

func TestGenerateValue(t *testing.T) {
	t.Parallel()

	type Person struct {
		Name string `json:"name"`
		Age  int    `json:"age"`
	}

	cases := []struct {
		name string
		p    *mockProvider
		err  string
		exp  *Person
	}{
		{
			name: "unmarshals json response",
			p: &mockProvider{
				id: "test",
				msgs: []chat.MessageStreamResponse{
					{Choices: []chat.MessageStreamChoice{{Delta: chat.MessageDelta{Content: `{"name":"Alice","age":30}`}}}},
					{Choices: []chat.MessageStreamChoice{{FinishReason: chat.FinishReasonStop}}},
				},
			},
			exp: &Person{Name: "Alice", Age: 30},
		},
		{
			name: "invalid json returns error",
			p: &mockProvider{
				id: "test",
				msgs: []chat.MessageStreamResponse{
					{Choices: []chat.MessageStreamChoice{{Delta: chat.MessageDelta{Content: "not json"}}}},
					{Choices: []chat.MessageStreamChoice{{FinishReason: chat.FinishReasonStop}}},
				},
			},
			err: "invalid character",
		},
		{
			name: "model error returns error",
			p:    &mockProvider{id: "test", err: errors.New("model failed")},
			err:  "model failed",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := GenerateValue[Person](t.Context(),
				WithModels(tc.p),
				WithMessages(chat.Message{Role: "user", Content: "test"}),
			)

			if tc.err != "" {
				require.ErrorContains(t, err, tc.err)
				require.Nil(t, got)
				return
			}

			require.NoError(t, err)
			require.Equal(t, tc.exp, got)
		})
	}
}
43 changes: 43 additions & 0 deletions pkg/ai/interceptor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package ai

import (
"context"

"github.com/docker/docker-agent/pkg/chat"
"github.com/docker/docker-agent/pkg/model/provider"
"github.com/docker/docker-agent/pkg/tools"
)

// StreamRequest holds the parameters for a single model stream call.
// It is passed through the interceptor chain and can be inspected or
// modified by interceptors before reaching the actual model call.
type StreamRequest struct {
	Model    provider.Provider // provider that serves this call
	Messages []chat.Message    // messages sent with the request
	Tools    []tools.Tool      // tools offered to the model for this call
}

// StreamInterceptor wraps a stream call, allowing callers to observe,
// modify, or short-circuit the request before and after it reaches the
// model. The interceptor receives the request and a handler to call the
// next step in the chain — either another interceptor or the actual
// model call. Returning without calling the handler skips the model call.
// Interceptors compose: each one wraps the next in the chain.
//
// Example:
//
//	func logInterceptor(ctx context.Context, r *StreamRequest, h StreamHandler) (*ModelResponse, error) {
//		// before: inspect or modify request
//		res, err := h(ctx, r)
//		// after: inspect response, record telemetry, etc.
//		return res, err
//	}
type StreamInterceptor func(context.Context, *StreamRequest, StreamHandler) (*ModelResponse, error)

// StreamHandler is the function signature for the next step in the
// interceptor chain. Call it to proceed with the stream request; the
// innermost handler performs the actual model call.
type StreamHandler func(context.Context, *StreamRequest) (*ModelResponse, error)

// ToolCallInterceptor wraps an individual tool call execution.
// The interceptor is responsible for calling tool.Handler and can
// add behavior before and after (permissions, logging, telemetry).
// It receives the model response that requested the call, the specific
// tool call to execute, and the resolved tool definition.
type ToolCallInterceptor func(context.Context, *ModelResponse, tools.ToolCall, tools.Tool) (*tools.ToolCallResult, error)
Loading
Loading