From 601d753488024ff54df2312a68f803c426ce3741 Mon Sep 17 00:00:00 2001 From: "s.hajyahya" Date: Tue, 14 Apr 2026 22:03:29 +0000 Subject: [PATCH] introduce pkg/ai as core LLM completion layer Add new pkg/ai package that extracts and centralizes model interaction logic from runtime. The package reuses existing types from chat, tools, and provider packages without moving them. --- pkg/ai/completion.go | 474 +++++++++++++++++ pkg/ai/completion_test.go | 486 ++++++++++++++++++ pkg/ai/generate.go | 89 ++++ pkg/ai/generate_test.go | 229 +++++++++ pkg/ai/interceptor.go | 43 ++ pkg/ai/option.go | 117 +++++ pkg/ai/options_test.go | 200 +++++++ .../testdata/content_and_tool_calls_in.json | 23 + .../testdata/content_and_tool_calls_out.json | 19 + pkg/ai/testdata/empty_stream_in.json | 1 + pkg/ai/testdata/empty_stream_out.json | 13 + pkg/ai/testdata/exec_tools_images_in.json | 23 + pkg/ai/testdata/exec_tools_images_out.json | 13 + pkg/ai/testdata/exec_tools_max_turns_in.json | 28 + pkg/ai/testdata/exec_tools_mixed_in.json | 33 ++ pkg/ai/testdata/exec_tools_mixed_out.json | 13 + pkg/ai/testdata/finish_length_in.json | 12 + pkg/ai/testdata/finish_length_out.json | 13 + .../finish_tool_calls_no_tools_in.json | 5 + .../finish_tool_calls_no_tools_out.json | 13 + pkg/ai/testdata/inferred_stop_in.json | 5 + pkg/ai/testdata/inferred_stop_out.json | 13 + pkg/ai/testdata/inferred_tool_calls_in.json | 9 + pkg/ai/testdata/inferred_tool_calls_out.json | 19 + pkg/ai/testdata/reasoning_in.json | 15 + pkg/ai/testdata/reasoning_out.json | 13 + pkg/ai/testdata/simple_text_in.json | 15 + pkg/ai/testdata/simple_text_out.json | 13 + pkg/ai/testdata/thinking_signature_in.json | 18 + pkg/ai/testdata/thinking_signature_out.json | 13 + pkg/ai/testdata/tool_calls_in.json | 27 + pkg/ai/testdata/tool_calls_out.json | 19 + pkg/runtime/fallback.go | 307 ++--------- pkg/runtime/streaming.go | 211 +++----- pkg/sessiontitle/generator.go | 82 +-- 35 files changed, 2153 insertions(+), 473 deletions(-) create 
mode 100644 pkg/ai/completion.go create mode 100644 pkg/ai/completion_test.go create mode 100644 pkg/ai/generate.go create mode 100644 pkg/ai/generate_test.go create mode 100644 pkg/ai/interceptor.go create mode 100644 pkg/ai/option.go create mode 100644 pkg/ai/options_test.go create mode 100644 pkg/ai/testdata/content_and_tool_calls_in.json create mode 100644 pkg/ai/testdata/content_and_tool_calls_out.json create mode 100644 pkg/ai/testdata/empty_stream_in.json create mode 100644 pkg/ai/testdata/empty_stream_out.json create mode 100644 pkg/ai/testdata/exec_tools_images_in.json create mode 100644 pkg/ai/testdata/exec_tools_images_out.json create mode 100644 pkg/ai/testdata/exec_tools_max_turns_in.json create mode 100644 pkg/ai/testdata/exec_tools_mixed_in.json create mode 100644 pkg/ai/testdata/exec_tools_mixed_out.json create mode 100644 pkg/ai/testdata/finish_length_in.json create mode 100644 pkg/ai/testdata/finish_length_out.json create mode 100644 pkg/ai/testdata/finish_tool_calls_no_tools_in.json create mode 100644 pkg/ai/testdata/finish_tool_calls_no_tools_out.json create mode 100644 pkg/ai/testdata/inferred_stop_in.json create mode 100644 pkg/ai/testdata/inferred_stop_out.json create mode 100644 pkg/ai/testdata/inferred_tool_calls_in.json create mode 100644 pkg/ai/testdata/inferred_tool_calls_out.json create mode 100644 pkg/ai/testdata/reasoning_in.json create mode 100644 pkg/ai/testdata/reasoning_out.json create mode 100644 pkg/ai/testdata/simple_text_in.json create mode 100644 pkg/ai/testdata/simple_text_out.json create mode 100644 pkg/ai/testdata/thinking_signature_in.json create mode 100644 pkg/ai/testdata/thinking_signature_out.json create mode 100644 pkg/ai/testdata/tool_calls_in.json create mode 100644 pkg/ai/testdata/tool_calls_out.json diff --git a/pkg/ai/completion.go b/pkg/ai/completion.go new file mode 100644 index 000000000..7c48b60da --- /dev/null +++ b/pkg/ai/completion.go @@ -0,0 +1,474 @@ +package ai + +import ( + "context" + "errors" + 
"fmt" + "io" + "log/slog" + "slices" + "strings" + "sync" + "time" + + "github.com/docker/docker-agent/pkg/backoff" + "github.com/docker/docker-agent/pkg/chat" + "github.com/docker/docker-agent/pkg/model/provider" + "github.com/docker/docker-agent/pkg/modelerrors" + "github.com/docker/docker-agent/pkg/tools" +) + +// ModelResponse is the aggregated response from a model after processing +// a chat completion stream. It contains the model's text output, any tool +// calls requested, token usage, and metadata about which model responded. +type ModelResponse struct { + Calls []tools.ToolCall + FinishReason chat.FinishReason + Usage *chat.Usage + ThoughtSignature []byte + Stopped bool + Content string + ReasoningContent string + ThinkingSignature string + Model string + Turns int + // Messages contains the conversation messages built during tool execution + // turns (assistant messages with tool calls + tool result messages). + // Empty when tool calls are not executed by this package (WithReturnToolRequests). + // Does not include the final assistant message — that is represented by + // the Content and Calls fields of this response. 
+ Messages []chat.Message +} + +type completion struct { + models []provider.Provider + messages []chat.Message + tools []tools.Tool + retries int + retryOnRateLimit bool + yield func(chat.MessageStreamResponse) bool + onModelFallback func(from, to provider.Provider, err error) + streamInterceptor StreamInterceptor + toolCallInterceptor ToolCallInterceptor + maxTurns int + turns int + returnToolRequests bool + requireContent bool + lg *slog.Logger +} + +func (c *completion) logger() *slog.Logger { + if c.lg != nil { + return c.lg + } + + return slog.Default() +} + +func (c *completion) applyOptions(opts ...Option) *completion { + for _, opt := range opts { + opt(c) + } + + return c +} + +func (c *completion) validate() error { + if len(c.models) == 0 { + return errors.New("pkg/ai: at least one model is required") + } + + if len(c.messages) == 0 { + return errors.New("pkg/ai: at least one message is required") + } + + if c.retries < 0 { + return errors.New("pkg/ai: retries cannot be negative") + } + + return nil +} + +func (c *completion) generate(ctx context.Context) (*ModelResponse, error) { + if err := c.validate(); err != nil { + return nil, err + } + + if c.retries == 0 { + c.retries = 1 + } + + var ( + err error + res *ModelResponse + ) + + for i, model := range c.models { + if i > 0 && c.onModelFallback != nil { + c.onModelFallback(c.models[i-1], model, err) + } + + for retry := range c.retries { + res, err = c.stream(ctx, model) + if err == nil { + return c.execTools(ctx, res) + } + + if ctx.Err() != nil { + return nil, err + } + + retryable, rateLimited, retryAfter := modelerrors.ClassifyModelError(err) + + // Gate: only retry on 429 if opt-in is enabled AND no fallbacks exist. + // Default behavior (retryOnRateLimit=false) treats 429 as non-retryable, + // identical to today's behavior before this feature was added. 
+ if rateLimited && (!c.retryOnRateLimit || len(c.models) > 1) { + c.logger().Warn("Rate limited, treating as non-retryable", + "model", model.ID(), + "retry_on_rate_limit_enabled", c.retryOnRateLimit, + "fallbacks_count", len(c.models), + "error", err) + break + } + + if !retryable && !rateLimited { + c.logger().Error("Non-retryable error from model", + "model", model.ID(), + "error", err, + ) + break + } + + // Retryable error (5xx/timeout — or a 429 when retry-on-rate-limit is opted in and no fallbacks exist): retry the same model after honouring Retry-After (or backoff). + if retryAfter > backoff.MaxRetryAfterWait { + c.logger().Warn("Retry-After exceeds maximum, capping", + "model", model.ID(), + "retry_after", retryAfter, + "max", backoff.MaxRetryAfterWait) + retryAfter = backoff.MaxRetryAfterWait + } + + if retryAfter <= 0 { + retryAfter = backoff.Calculate(retry) + } + + c.logger().Warn("Retryable error from model", + "model", model.ID(), + "attempt", retry+1, + "retryAfter", retryAfter, + "error", err) + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(retryAfter): + } + } + } + + prefix := "model failed" + if len(c.models) > 1 { + prefix = "all models failed" + } + + errp := fmt.Errorf("%s: %w", prefix, err) + if modelerrors.IsContextOverflowError(err) { + return nil, modelerrors.NewContextOverflowError(errp) + } + + return nil, errp +} + +func (c *completion) stream(ctx context.Context, model provider.Provider) (*ModelResponse, error) { + if c.streamInterceptor == nil { + c.streamInterceptor = func( + ctx context.Context, + r *StreamRequest, + h StreamHandler, + ) (*ModelResponse, error) { + return h(ctx, r) + } + } + + r := &StreamRequest{ + Model: model, + Messages: c.messages, + Tools: c.tools, + } + + return c.streamInterceptor(ctx, r, func(ctx context.Context, r *StreamRequest) (*ModelResponse, error) { + s, err := r.Model.CreateChatCompletionStream(ctx, r.Messages, r.Tools) + if err != nil { + return nil, err + } + + defer s.Close() + + var ( + content strings.Builder + reasoning strings.Builder
+ ) + + res := &ModelResponse{ + Model: model.ID(), + } + + toolCallIndex := make(map[string]int) + + for { + resp, err := s.Recv() + if errors.Is(err, io.EOF) { + break + } + + if err != nil { + return nil, fmt.Errorf("error receiving from stream: %w", err) + } + + if c.yield != nil && !c.yield(resp) { + // Caller signaled to stop the stream. + return nil, io.EOF + } + + if resp.Usage != nil { + res.Usage = resp.Usage + } + + if len(resp.Choices) == 0 { + continue + } + + choice := resp.Choices[0] + + if len(choice.Delta.ThoughtSignature) > 0 { + res.ThoughtSignature = choice.Delta.ThoughtSignature + } + + if choice.FinishReason == chat.FinishReasonStop || choice.FinishReason == chat.FinishReasonLength { + res.Content = content.String() + res.ReasoningContent = reasoning.String() + res.Stopped = true + res.FinishReason = choice.FinishReason + + if c.requireContent && strings.TrimSpace(res.Content) == "" && len(res.Calls) == 0 { + return nil, errors.New("pkg/ai: model returned empty response") + } + + return res, nil + } + + if choice.FinishReason != "" { + res.FinishReason = choice.FinishReason + } + + // Handle tool call deltas + if len(choice.Delta.ToolCalls) > 0 { + for _, delta := range choice.Delta.ToolCalls { + idx, ok := toolCallIndex[delta.ID] + if !ok { + idx = len(res.Calls) + toolCallIndex[delta.ID] = idx + res.Calls = append(res.Calls, tools.ToolCall{ + ID: delta.ID, + Type: delta.Type, + }) + } + + tc := &res.Calls[idx] + + if delta.Type != "" { + tc.Type = delta.Type + } + + if delta.Function.Name != "" { + tc.Function.Name = delta.Function.Name + } + + if delta.Function.Arguments != "" { + tc.Function.Arguments += delta.Function.Arguments + } + } + + continue + } + + if choice.Delta.ReasoningContent != "" { + reasoning.WriteString(choice.Delta.ReasoningContent) + } + + if choice.Delta.ThinkingSignature != "" { + res.ThinkingSignature = choice.Delta.ThinkingSignature + } + + if choice.Delta.Content != "" { + content.WriteString(choice.Delta.Content) 
+ } + } + + res.Content = content.String() + res.ReasoningContent = reasoning.String() + res.Stopped = content.Len() == 0 && len(res.Calls) == 0 + + // Prefer the provider's explicit finish reason when available. + // Only fall back to inference when no explicit reason was received. + if res.FinishReason == "" { + switch { + case len(res.Calls) > 0: + res.FinishReason = chat.FinishReasonToolCalls + case content.Len() > 0: + res.FinishReason = chat.FinishReasonStop + default: + res.FinishReason = chat.FinishReasonNull + } + } + + // Ensure finish reason agrees with actual stream output. + switch { + case res.FinishReason == chat.FinishReasonToolCalls && len(res.Calls) == 0: + res.FinishReason = chat.FinishReasonNull + case res.FinishReason == chat.FinishReasonStop && len(res.Calls) > 0: + res.FinishReason = chat.FinishReasonToolCalls + } + + if c.requireContent && strings.TrimSpace(res.Content) == "" && len(res.Calls) == 0 { + return nil, errors.New("pkg/ai: model returned empty response") + } + + return res, nil + }) +} + +func (c *completion) execTools(ctx context.Context, r *ModelResponse) (*ModelResponse, error) { + if len(r.Calls) == 0 || c.returnToolRequests { + return r, nil + } + + c.turns++ + + if c.maxTurns > 0 && c.turns > c.maxTurns { + return nil, fmt.Errorf("pkg/ai: max turns reached (%d)", c.maxTurns) + } + + functions := make(map[string]tools.Tool, len(c.tools)) + for _, t := range c.tools { + functions[t.Name] = t + } + + var wg sync.WaitGroup + msgs := make([]chat.Message, len(r.Calls)) + + for i, call := range r.Calls { + wg.Go(func() { + t, ok := functions[call.Function.Name] + if !ok { + c.logger().Warn("Tool call for unavailable tool", "tool", call.Function.Name) + msgs[i] = chat.Message{ + Role: chat.MessageRoleTool, + Content: fmt.Sprintf( + "Tool '%s' is not available. 
You can only use the tools provided to you.", + call.Function.Name, + ), + ToolCallID: call.ID, + IsError: true, + CreatedAt: time.Now().Format(time.RFC3339), + } + + return + } + + fn := c.toolCallInterceptor + if fn == nil { + fn = func( + ctx context.Context, + _ *ModelResponse, + _ tools.ToolCall, + _ tools.Tool, + ) (*tools.ToolCallResult, error) { + return t.Handler(ctx, call) + } + } + + res, err := fn(ctx, r, call, t) + if err != nil { + msgs[i] = chat.Message{ + Role: chat.MessageRoleTool, + Content: "Error calling tool: " + err.Error(), + ToolCallID: call.ID, + IsError: true, + CreatedAt: time.Now().Format(time.RFC3339), + } + return + } + + if strings.TrimSpace(res.Output) == "" { + res.Output = "(no output)" + } + + msg := chat.Message{ + Role: chat.MessageRoleTool, + Content: res.Output, + ToolCallID: call.ID, + CreatedAt: time.Now().Format(time.RFC3339), + } + + if len(res.Images) > 0 { + msg.MultiContent = append(msg.MultiContent, chat.MessagePart{ + Type: chat.MessagePartTypeText, + Text: res.Output, + }) + + for _, img := range res.Images { + msg.MultiContent = append(msg.MultiContent, chat.MessagePart{ + Type: chat.MessagePartTypeImageURL, + ImageURL: &chat.MessageImageURL{ + URL: "data:" + img.MimeType + ";base64," + img.Data, + Detail: chat.ImageURLDetailAuto, + }, + }) + } + } + + msgs[i] = msg + }) + } + + wg.Wait() + + // Append the assistant message with tool calls, then the tool results. + c.messages = append(c.messages, chat.Message{ + Role: chat.MessageRoleAssistant, + Content: r.Content, + ToolCalls: r.Calls, + }) + c.messages = append(c.messages, msgs...) + + if c.models[0].ID() != r.Model { + idx := slices.IndexFunc(c.models, func(m provider.Provider) bool { + return m.ID() == r.Model + }) + + // Rotate to put the responding model first. + c.models = append(c.models[idx:], c.models[:idx]...) 
+ } + + r2, err := c.generate(ctx) + if err != nil { + return nil, err + } + + r2.Turns = c.turns + r2.Messages = c.messages + + if r2.Usage != nil && r.Usage != nil { + r2.Usage = &chat.Usage{ + InputTokens: r.Usage.InputTokens + r2.Usage.InputTokens, + OutputTokens: r.Usage.OutputTokens + r2.Usage.OutputTokens, + CachedInputTokens: r.Usage.CachedInputTokens + r2.Usage.CachedInputTokens, + CacheWriteTokens: r.Usage.CacheWriteTokens + r2.Usage.CacheWriteTokens, + ReasoningTokens: r.Usage.ReasoningTokens + r2.Usage.ReasoningTokens, + } + } + + return r2, nil +} diff --git a/pkg/ai/completion_test.go b/pkg/ai/completion_test.go new file mode 100644 index 000000000..4407c5ce8 --- /dev/null +++ b/pkg/ai/completion_test.go @@ -0,0 +1,486 @@ +package ai + +import ( + "context" + "encoding/json" + "errors" + "io" + "os" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/docker/docker-agent/pkg/chat" + "github.com/docker/docker-agent/pkg/model/provider" + "github.com/docker/docker-agent/pkg/model/provider/base" + "github.com/docker/docker-agent/pkg/tools" +) + +func TestStream(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + in string + out string + err string + p *mockProvider + }{ + { + name: "create stream returns an error", + err: io.ErrUnexpectedEOF.Error(), + p: &mockProvider{ + err: io.ErrUnexpectedEOF, + }, + }, + { + name: "stream returns an error", + err: "error receiving from stream", + p: &mockProvider{ + streamErr: io.ErrUnexpectedEOF, + }, + }, + { + name: "simple text", + in: "testdata/simple_text_in.json", + out: "testdata/simple_text_out.json", + }, + { + name: "tool calls", + in: "testdata/tool_calls_in.json", + out: "testdata/tool_calls_out.json", + }, + { + name: "reasoning content", + in: "testdata/reasoning_in.json", + out: "testdata/reasoning_out.json", + }, + { + name: "empty stream", + in: "testdata/empty_stream_in.json", + out: "testdata/empty_stream_out.json", + }, + { + name: "finish reason length", + in: 
"testdata/finish_length_in.json", + out: "testdata/finish_length_out.json", + }, + { + name: "thinking signature", + in: "testdata/thinking_signature_in.json", + out: "testdata/thinking_signature_out.json", + }, + { + name: "content and tool calls", + in: "testdata/content_and_tool_calls_in.json", + out: "testdata/content_and_tool_calls_out.json", + }, + { + name: "finish reason tool_calls but no tools", + in: "testdata/finish_tool_calls_no_tools_in.json", + out: "testdata/finish_tool_calls_no_tools_out.json", + }, + { + name: "inferred stop from content", + in: "testdata/inferred_stop_in.json", + out: "testdata/inferred_stop_out.json", + }, + { + name: "inferred tool_calls from tools", + in: "testdata/inferred_tool_calls_in.json", + out: "testdata/inferred_tool_calls_out.json", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.p == nil { + tt.p = new(mockProvider) + } + + if tt.in != "" { + data, err := os.ReadFile(tt.in) + require.NoError(t, err) + + var msgs []chat.MessageStreamResponse + require.NoError(t, json.Unmarshal(data, &msgs)) + tt.p.msgs = msgs + } + + c := new(completion).applyOptions(WithReturnToolRequests()) + + resp, err := c.stream(t.Context(), tt.p) + if tt.err != "" { + require.ErrorContains(t, err, tt.err) + return + } + + require.NoError(t, err) + + exp, err := os.ReadFile(tt.out) + require.NoError(t, err) + + resp.Messages = nil + got, err := json.Marshal(resp) + require.NoError(t, err) + + require.JSONEq(t, string(exp), string(got)) + }) + } +} + +func TestGenerate(t *testing.T) { + t.Parallel() + + msgs := []chat.MessageStreamResponse{ + {Choices: []chat.MessageStreamChoice{{Delta: chat.MessageDelta{Content: "ok"}}}}, + {Choices: []chat.MessageStreamChoice{{FinishReason: chat.FinishReasonStop}}}, + } + + tests := []struct { + name string + models []*mockProvider + retries int + retryOnRate bool + err string + expCallCount map[string]int + onFallbackCount int + streamInterceptorCount int + }{ + { + name: 
"validation no models", + models: []*mockProvider{}, + err: "at least one model is required", + }, + { + name: "single model success", + models: []*mockProvider{ + { + id: "primary", + msgs: msgs, + }, + }, + retries: 1, + expCallCount: map[string]int{"primary": 1}, + streamInterceptorCount: 1, + }, + { + name: "single model retryable error then success", + models: []*mockProvider{ + { + id: "primary", + msgs: msgs, + failCount: 1, + err: errors.New("500 internal server error"), + }, + }, + retries: 3, + expCallCount: map[string]int{"primary": 2}, + streamInterceptorCount: 2, + }, + { + name: "single model non-retryable error", + models: []*mockProvider{ + { + id: "primary", + err: errors.New("400 Bad Request"), + }, + }, + retries: 3, + err: "model failed", + expCallCount: map[string]int{"primary": 1}, + }, + { + name: "fallback primary fails then fallback succeeds", + models: []*mockProvider{ + { + id: "primary", + err: errors.New("400 Bad Request"), + }, + { + id: "fallback", + msgs: msgs, + }, + }, + retries: 1, + expCallCount: map[string]int{"primary": 1, "fallback": 1}, + onFallbackCount: 1, + streamInterceptorCount: 2, + }, + { + name: "all models fail", + models: []*mockProvider{ + { + id: "primary", + err: errors.New("400 Bad Request"), + }, + { + id: "fallback", + err: errors.New("400 Bad Request"), + }, + }, + retries: 1, + err: "all models failed", + expCallCount: map[string]int{"primary": 1, "fallback": 1}, + onFallbackCount: 1, + }, + { + name: "rate limited skips to fallback", + models: []*mockProvider{ + { + id: "primary", + err: errors.New("429 Too Many Requests"), + }, + { + id: "fallback", + msgs: msgs, + }, + }, + retries: 3, + expCallCount: map[string]int{"primary": 1, "fallback": 1}, + onFallbackCount: 1, + streamInterceptorCount: 2, + }, + { + name: "rate limited retry opt-in no fallback", + models: []*mockProvider{ + { + id: "primary", + msgs: msgs, + failCount: 1, + err: errors.New("429 Too Many Requests"), + }, + }, + retries: 3, + 
retryOnRate: true, + expCallCount: map[string]int{"primary": 2}, + streamInterceptorCount: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var ( + fallbackCount int + streamStartCount int + ) + + c := &completion{ + messages: []chat.Message{{Role: "user", Content: "test"}}, + retries: tt.retries, + retryOnRateLimit: tt.retryOnRate, + onModelFallback: func(from, to provider.Provider, err error) { + fallbackCount++ + }, + streamInterceptor: func(ctx context.Context, r *StreamRequest, h StreamHandler) (*ModelResponse, error) { + streamStartCount++ + return h(ctx, r) + }, + } + + for _, m := range tt.models { + c.models = append(c.models, m) + } + + res, err := c.generate(t.Context()) + + if tt.err != "" { + require.ErrorContains(t, err, tt.err) + return + } + + require.NoError(t, err) + require.NotNil(t, res) + + for id, count := range tt.expCallCount { + for _, m := range tt.models { + if m.id == id { + require.Equal(t, count, m.callCount, "call count for %s", id) + } + } + } + + require.Equal(t, tt.onFallbackCount, fallbackCount, "onModelFallback count") + require.Equal(t, tt.streamInterceptorCount, streamStartCount, "streamInterceptor count") + }) + } +} + +func TestExecTools(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + in string + out string + err string + tools []tools.Tool + opts []Option + }{ + { + name: "max turns reached", + in: "testdata/exec_tools_max_turns_in.json", + err: "max turns reached", + tools: []tools.Tool{ + { + Name: "greet", + Handler: func(ctx context.Context, call tools.ToolCall) (*tools.ToolCallResult, error) { + return &tools.ToolCallResult{Output: "Hello!"}, nil + }, + }, + }, + opts: []Option{WithMaxTurns(1)}, + }, + { + name: "tool returns images", + in: "testdata/exec_tools_images_in.json", + out: "testdata/exec_tools_images_out.json", + tools: []tools.Tool{ + { + Name: "screenshot", + Handler: func(ctx context.Context, call tools.ToolCall) (*tools.ToolCallResult, error) { + 
return &tools.ToolCallResult{ + Output: "screenshot taken", + Images: []tools.ImageContent{ + {MimeType: "image/png", Data: "iVBOR"}, + }, + }, nil + }, + }, + }, + }, + { + name: "mixed tool calls success not found and error", + in: "testdata/exec_tools_mixed_in.json", + out: "testdata/exec_tools_mixed_out.json", + tools: []tools.Tool{ + { + Name: "greet", + Handler: func(ctx context.Context, call tools.ToolCall) (*tools.ToolCallResult, error) { + return &tools.ToolCallResult{Output: "Hello, Alice!"}, nil + }, + }, + { + Name: "failing_tool", + Handler: func(ctx context.Context, call tools.ToolCall) (*tools.ToolCallResult, error) { + return nil, errors.New("something went wrong") + }, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, err := os.ReadFile(tt.in) + require.NoError(t, err) + + var responses []*ModelResponse + require.NoError(t, json.Unmarshal(data, &responses)) + + var turn int + + c := (&completion{ + models: []provider.Provider{&mockProvider{id: "test"}}, + messages: []chat.Message{{Role: "user", Content: "test"}}, + tools: tt.tools, + streamInterceptor: func(ctx context.Context, r *StreamRequest, h StreamHandler) (*ModelResponse, error) { + if turn >= len(responses) { + return nil, errors.New("unexpected call to stream interceptor") + } + res := responses[turn] + turn++ + return res, nil + }, + }).applyOptions(tt.opts...) + + res, err := c.generate(t.Context()) + + if tt.err != "" { + require.ErrorContains(t, err, tt.err) + return + } + + require.NoError(t, err) + + exp, err := os.ReadFile(tt.out) + require.NoError(t, err) + + // Nil out Messages before comparison — they contain + // timestamps that vary per run. 
+ res.Messages = nil + + got, err := json.Marshal(res) + require.NoError(t, err) + + require.JSONEq(t, string(exp), string(got)) + }) + } +} + +type mockProvider struct { + id string + err error + streamErr error + msgs []chat.MessageStreamResponse + callCount int + failCount int +} + +func (m *mockProvider) ID() string { + if m.id != "" { + return m.id + } + + return "mock" +} + +func (m *mockProvider) BaseConfig() base.Config { + return base.Config{} +} + +func (m *mockProvider) CreateChatCompletionStream( + ctx context.Context, + _ []chat.Message, + _ []tools.Tool, +) (chat.MessageStream, error) { + m.callCount++ + + if m.failCount > 0 { + m.failCount-- + if m.failCount == 0 { + m.failCount = -1 + } + return nil, m.err + } + + if m.failCount == 0 && m.err != nil { + return nil, m.err + } + + return &mockStream{ + err: m.streamErr, + msgs: m.msgs, + }, nil +} + +type mockStream struct { + err error + msgs []chat.MessageStreamResponse +} + +func (m *mockStream) Recv() (chat.MessageStreamResponse, error) { + if m.err != nil { + return chat.MessageStreamResponse{}, m.err + } + + if len(m.msgs) == 0 { + return chat.MessageStreamResponse{}, io.EOF + } + + msg := m.msgs[0] + m.msgs = m.msgs[1:] + return msg, nil +} + +func (m *mockStream) Close() {} diff --git a/pkg/ai/generate.go b/pkg/ai/generate.go new file mode 100644 index 000000000..6254e9350 --- /dev/null +++ b/pkg/ai/generate.go @@ -0,0 +1,89 @@ +package ai + +import ( + "context" + "encoding/json" + "errors" + "io" + "iter" + + "github.com/docker/docker-agent/pkg/chat" +) + +// StreamValue represents a single value yielded during streaming. +type StreamValue[Out, Stream any] struct { + Done bool + Chunk Stream // valid if Done is false + Value Out // valid if Done is true + Response *ModelResponse // valid if Done is true +} + +// ModelStreamValue is a stream value for a model response. +// Out is never set because the value is already available in the Response field. 
+type ModelStreamValue = StreamValue[struct{}, chat.MessageStreamResponse] + +// GenerateStream generates a model response and streams the output. +// It returns an iterator that yields streaming results. +func GenerateStream(ctx context.Context, opts ...Option) iter.Seq2[*ModelStreamValue, error] { + return func(yield func(*ModelStreamValue, error) bool) { + c := &completion{ + yield: func(resp chat.MessageStreamResponse) bool { + return yield(&ModelStreamValue{ + Done: false, + Chunk: resp, + }, nil) + }, + } + + c = c.applyOptions(opts...) + + res, err := c.generate(ctx) + if errors.Is(err, io.EOF) { + return + } + + if err != nil { + yield(nil, err) + return + } + + yield(&ModelStreamValue{ + Done: true, + Response: res, + }, nil) + } +} + +// Generate runs a completion and returns the final model response. +// It handles retry, fallback, tool execution, and streaming internally. +func Generate(ctx context.Context, opts ...Option) (*ModelResponse, error) { + return new(completion).applyOptions(opts...).generate(ctx) +} + +// GenerateText is a convenience wrapper around Generate that returns +// only the text content from the model response. +func GenerateText(ctx context.Context, opts ...Option) (string, error) { + res, err := Generate(ctx, opts...) + if err != nil { + return "", err + } + + return res.Content, nil +} + +// GenerateValue runs a completion and unmarshals the model's response +// content into the provided type. Use with structured output to get +// typed responses from the model. +func GenerateValue[Out any](ctx context.Context, opts ...Option) (*Out, error) { + res, err := Generate(ctx, opts...) 
+ if err != nil { + return nil, err + } + + var out Out + if err := json.Unmarshal([]byte(res.Content), &out); err != nil { + return nil, err + } + + return &out, nil +} diff --git a/pkg/ai/generate_test.go b/pkg/ai/generate_test.go new file mode 100644 index 000000000..6b8afff72 --- /dev/null +++ b/pkg/ai/generate_test.go @@ -0,0 +1,229 @@ +package ai + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/docker/docker-agent/pkg/chat" +) + +func TestGenerateStream(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + p *mockProvider + err string + expContent string + }{ + { + name: "happy path yields chunks then done", + p: &mockProvider{ + id: "test", + msgs: []chat.MessageStreamResponse{ + { + Choices: []chat.MessageStreamChoice{ + {Delta: chat.MessageDelta{Content: "hello"}}, + }, + }, + { + Choices: []chat.MessageStreamChoice{ + {Delta: chat.MessageDelta{Content: " world"}}, + }, + }, + { + Choices: []chat.MessageStreamChoice{ + {FinishReason: chat.FinishReasonStop}, + }, + Usage: &chat.Usage{InputTokens: 10}, + }, + }, + }, + expContent: "hello world", + }, + { + name: "error yields error", + p: &mockProvider{ + id: "test", + err: errors.New("model failed"), + }, + err: "model failed", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + opts := []Option{ + WithModels(tt.p), + WithMessages(chat.Message{Role: "user", Content: "test"}), + } + + var ( + chunks int + res *ModelResponse + ) + + for sv, err := range GenerateStream(t.Context(), opts...) 
{ + if err != nil { + require.ErrorContains(t, err, tt.err) + return + } + + if sv.Done { + res = sv.Response + break + } + + chunks++ + } + + if tt.err != "" { + t.Fatal("expected error but got none") + } + + require.NotNil(t, res) + require.Equal(t, tt.expContent, res.Content) + require.Positive(t, chunks) + }) + } +} + +func TestGenerateText(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + p *mockProvider + err string + expContent string + }{ + { + name: "returns text content", + p: &mockProvider{ + id: "test", + msgs: []chat.MessageStreamResponse{ + { + Choices: []chat.MessageStreamChoice{ + {Delta: chat.MessageDelta{Content: "hello"}}, + }, + }, + { + Choices: []chat.MessageStreamChoice{ + {FinishReason: chat.FinishReasonStop}, + }, + }, + }, + }, + expContent: "hello", + }, + { + name: "error returns empty string", + p: &mockProvider{ + id: "test", + err: errors.New("model failed"), + }, + err: "model failed", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + text, err := GenerateText(t.Context(), + WithModels(tt.p), + WithMessages(chat.Message{Role: "user", Content: "test"}), + ) + + if tt.err != "" { + require.ErrorContains(t, err, tt.err) + require.Empty(t, text) + return + } + + require.NoError(t, err) + require.Equal(t, tt.expContent, text) + }) + } +} + +func TestGenerateValue(t *testing.T) { + t.Parallel() + + type Person struct { + Name string `json:"name"` + Age int `json:"age"` + } + + tests := []struct { + name string + p *mockProvider + err string + exp *Person + }{ + { + name: "unmarshals json response", + p: &mockProvider{ + id: "test", + msgs: []chat.MessageStreamResponse{ + { + Choices: []chat.MessageStreamChoice{ + {Delta: chat.MessageDelta{Content: `{"name":"Alice","age":30}`}}, + }, + }, + { + Choices: []chat.MessageStreamChoice{ + {FinishReason: chat.FinishReasonStop}, + }, + }, + }, + }, + exp: &Person{Name: "Alice", Age: 30}, + }, + { + name: "invalid json returns error", + p: 
&mockProvider{ + id: "test", + msgs: []chat.MessageStreamResponse{ + { + Choices: []chat.MessageStreamChoice{ + {Delta: chat.MessageDelta{Content: "not json"}}, + }, + }, + { + Choices: []chat.MessageStreamChoice{ + {FinishReason: chat.FinishReasonStop}, + }, + }, + }, + }, + err: "invalid character", + }, + { + name: "model error returns error", + p: &mockProvider{ + id: "test", + err: errors.New("model failed"), + }, + err: "model failed", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := GenerateValue[Person](t.Context(), + WithModels(tt.p), + WithMessages(chat.Message{Role: "user", Content: "test"}), + ) + + if tt.err != "" { + require.ErrorContains(t, err, tt.err) + require.Nil(t, result) + return + } + + require.NoError(t, err) + require.Equal(t, tt.exp, result) + }) + } +} diff --git a/pkg/ai/interceptor.go b/pkg/ai/interceptor.go new file mode 100644 index 000000000..fc375650c --- /dev/null +++ b/pkg/ai/interceptor.go @@ -0,0 +1,43 @@ +package ai + +import ( + "context" + + "github.com/docker/docker-agent/pkg/chat" + "github.com/docker/docker-agent/pkg/model/provider" + "github.com/docker/docker-agent/pkg/tools" +) + +// StreamRequest holds the parameters for a single model stream call. +// It is passed through the interceptor chain and can be inspected or +// modified by interceptors before reaching the actual model call. +type StreamRequest struct { + Model provider.Provider + Messages []chat.Message + Tools []tools.Tool +} + +// StreamInterceptor wraps a stream call, allowing callers to observe, +// modify, or short-circuit the request before and after it reaches the +// model. The interceptor receives the request and a handler to call the +// next step in the chain — either another interceptor or the actual +// model call. Returning without calling the handler skips the model call. 
+// +// Example: +// +// func logInterceptor(ctx context.Context, r *StreamRequest, h StreamHandler) (*ModelResponse, error) { +// // before: inspect or modify request +// res, err := h(ctx, r) +// // after: inspect response, record telemetry, etc. +// return res, err +// } +type StreamInterceptor func(context.Context, *StreamRequest, StreamHandler) (*ModelResponse, error) + +// StreamHandler is the function signature for the next step in the +// interceptor chain. Call it to proceed with the stream request. +type StreamHandler func(context.Context, *StreamRequest) (*ModelResponse, error) + +// ToolCallInterceptor wraps an individual tool call execution. +// The interceptor is responsible for calling tool.Handler and can +// add behavior before and after (permissions, logging, telemetry). +type ToolCallInterceptor func(context.Context, *ModelResponse, tools.ToolCall, tools.Tool) (*tools.ToolCallResult, error) diff --git a/pkg/ai/option.go b/pkg/ai/option.go new file mode 100644 index 000000000..083ffbbe1 --- /dev/null +++ b/pkg/ai/option.go @@ -0,0 +1,117 @@ +package ai + +import ( + "log/slog" + + "github.com/docker/docker-agent/pkg/chat" + "github.com/docker/docker-agent/pkg/model/provider" + "github.com/docker/docker-agent/pkg/tools" +) + +// Option configures a completion request. +type Option func(*completion) + +// WithLogger sets the logger used by the completion engine. +// Defaults to slog.Default() if not set. +func WithLogger(lg *slog.Logger) Option { + return func(c *completion) { + c.lg = lg + } +} + +// WithModels sets the model providers for the completion. The first model +// is the primary; any additional models are used as fallbacks in order +// when the primary fails with a non-retryable error. +func WithModels(models ...provider.Provider) Option { + return func(c *completion) { + c.models = append(c.models, models...) + } +} + +// WithMessages sets the conversation messages to send to the model. 
+func WithMessages(messages ...chat.Message) Option { + return func(c *completion) { + c.messages = append(c.messages, messages...) + } +} + +// WithTools sets the tools available for the model to call. +func WithTools(t ...tools.Tool) Option { + return func(c *completion) { + c.tools = append(c.tools, t...) + } +} + +// WithRetries sets the number of retry attempts per model on retryable +// errors (5xx, timeouts). The total attempts per model is n + 1. +func WithRetries(n int) Option { + return func(c *completion) { + c.retries = n + } +} + +// WithRetryOnRateLimit enables retrying on 429 rate limit errors when +// no fallback models are available. By default, 429 errors are treated +// as non-retryable and skip to the next fallback. +func WithRetryOnRateLimit() Option { + return func(c *completion) { + c.retryOnRateLimit = true + } +} + +// WithOnModelFallback sets a callback that is called when switching from +// one model to another due to a failure. The callback receives the previous +// model, the next model, and the error that caused the fallback. +func WithOnModelFallback(fn func(from, to provider.Provider, err error)) Option { + return func(c *completion) { + c.onModelFallback = fn + } +} + +// WithRequireContent causes the completion to treat an empty model +// response (no content and no tool calls) as an error, triggering +// a fallback to the next model in the chain. +func WithRequireContent() Option { + return func(c *completion) { + c.requireContent = true + } +} + +// WithReturnToolRequests configures whether to return tool requests +// instead of making the tool calls and continuing the generation. +func WithReturnToolRequests() Option { + return func(c *completion) { + c.returnToolRequests = true + } +} + +// WithMaxTurns sets the maximum number of tool execution round trips. +// A turn is one cycle of: model returns tool calls → tools execute → +// results sent back to model. 
For example, WithMaxTurns(2) allows up +// to 2 tool round trips (3 total model calls: initial + 2 follow-ups). +// The turn count is available on ModelResponse.Turns after completion. +// A value of 0 means no limit. +func WithMaxTurns(n int) Option { + return func(c *completion) { + c.maxTurns = n + } +} + +// WithToolCallInterceptor sets an interceptor that wraps each individual +// tool call execution. The interceptor is responsible for calling +// tool.Handler and can add behavior before and after. +func WithToolCallInterceptor(fn ToolCallInterceptor) Option { + return func(c *completion) { + c.toolCallInterceptor = fn + } +} + +// WithStreamInterceptor sets an interceptor that wraps every model stream +// call. The interceptor can observe, modify, or short-circuit the request +// and response. It is called on every attempt, including retries and +// fallbacks. +func WithStreamInterceptor(fn StreamInterceptor) Option { + return func(c *completion) { + c.streamInterceptor = fn + } +} diff --git a/pkg/ai/options_test.go b/pkg/ai/options_test.go new file mode 100644 index 000000000..d97de7fb8 --- /dev/null +++ b/pkg/ai/options_test.go @@ -0,0 +1,200 @@ +package ai + +import ( + "context" + "log/slog" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/docker/docker-agent/pkg/chat" + "github.com/docker/docker-agent/pkg/model/provider" + "github.com/docker/docker-agent/pkg/tools" +) + +func TestOptions(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + opts []Option + fn func(t *testing.T, c *completion) + }{ + { + name: "WithModels sets models", + opts: []Option{ + WithModels(new(mockProvider), new(mockProvider)), + }, + fn: func(t *testing.T, c *completion) { + t.Helper() + require.Len(t, c.models, 2) + }, + }, + { + name: "WithModels appends", + opts: []Option{ + WithModels(new(mockProvider)), + WithModels(new(mockProvider)), + }, + fn: func(t *testing.T, c *completion) { + t.Helper() + require.Len(t, c.models, 2) + }, + }, 
+ { + name: "WithMessages sets messages", + opts: []Option{ + WithMessages( + chat.Message{Role: "system", Content: "you are helpful"}, + chat.Message{Role: "user", Content: "hello"}, + ), + }, + fn: func(t *testing.T, c *completion) { + t.Helper() + require.Len(t, c.messages, 2) + require.Equal(t, "system", string(c.messages[0].Role)) + require.Equal(t, "user", string(c.messages[1].Role)) + }, + }, + { + name: "WithMessages appends", + opts: []Option{ + WithMessages(chat.Message{Role: "system", Content: "a"}), + WithMessages(chat.Message{Role: "user", Content: "b"}), + }, + fn: func(t *testing.T, c *completion) { + t.Helper() + require.Len(t, c.messages, 2) + }, + }, + { + name: "WithTools sets tools", + opts: []Option{ + WithTools( + tools.Tool{Name: "read_file"}, + tools.Tool{Name: "write_file"}, + ), + }, + fn: func(t *testing.T, c *completion) { + t.Helper() + require.Len(t, c.tools, 2) + require.Equal(t, "read_file", c.tools[0].Name) + require.Equal(t, "write_file", c.tools[1].Name) + }, + }, + { + name: "WithTools appends", + opts: []Option{ + WithTools(tools.Tool{Name: "a"}), + WithTools(tools.Tool{Name: "b"}), + }, + fn: func(t *testing.T, c *completion) { + t.Helper() + require.Len(t, c.tools, 2) + }, + }, + { + name: "WithRetries sets retries", + opts: []Option{ + WithRetries(5), + }, + fn: func(t *testing.T, c *completion) { + t.Helper() + require.Equal(t, 5, c.retries) + }, + }, + { + name: "WithRetryOnRateLimit enables flag", + opts: []Option{ + WithRetryOnRateLimit(), + }, + fn: func(t *testing.T, c *completion) { + t.Helper() + require.True(t, c.retryOnRateLimit) + }, + }, + { + name: "WithOnModelFallback sets callback", + opts: []Option{ + WithOnModelFallback(func(from, to provider.Provider, err error) {}), + }, + fn: func(t *testing.T, c *completion) { + t.Helper() + require.NotNil(t, c.onModelFallback) + }, + }, + { + name: "WithToolCallInterceptor sets interceptor", + opts: []Option{ + WithToolCallInterceptor(func( + context.Context, 
*ModelResponse, tools.ToolCall, tools.Tool, + ) (*tools.ToolCallResult, error) { + return nil, nil + }), + }, + fn: func(t *testing.T, c *completion) { + t.Helper() + require.NotNil(t, c.toolCallInterceptor) + }, + }, + { + name: "WithMaxTurns sets max turns", + opts: []Option{ + WithMaxTurns(5), + }, + fn: func(t *testing.T, c *completion) { + t.Helper() + require.Equal(t, 5, c.maxTurns) + }, + }, + { + name: "WithReturnToolRequests enables flag", + opts: []Option{ + WithReturnToolRequests(), + }, + fn: func(t *testing.T, c *completion) { + t.Helper() + require.True(t, c.returnToolRequests) + }, + }, + { + name: "WithLogger sets logger", + opts: []Option{ + WithLogger(slog.Default()), + }, + fn: func(t *testing.T, c *completion) { + t.Helper() + require.NotNil(t, c.lg) + }, + }, + { + name: "WithRequireContent enables flag", + opts: []Option{ + WithRequireContent(), + }, + fn: func(t *testing.T, c *completion) { + t.Helper() + require.True(t, c.requireContent) + }, + }, + { + name: "WithStreamInterceptor sets interceptor", + opts: []Option{ + WithStreamInterceptor(func(ctx context.Context, r *StreamRequest, h StreamHandler) (*ModelResponse, error) { + return h(ctx, r) + }), + }, + fn: func(t *testing.T, c *completion) { + t.Helper() + require.NotNil(t, c.streamInterceptor) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := (&completion{}).applyOptions(tt.opts...) + tt.fn(t, c) + }) + } +} diff --git a/pkg/ai/testdata/content_and_tool_calls_in.json b/pkg/ai/testdata/content_and_tool_calls_in.json new file mode 100644 index 000000000..aaafcdff8 --- /dev/null +++ b/pkg/ai/testdata/content_and_tool_calls_in.json @@ -0,0 +1,23 @@ +[ + { + "choices": [{ "delta": { "content": "Let me check that file." 
} }] + }, + { + "choices": [{ + "delta": { + "tool_calls": [{ "id": "call_2", "type": "function", "function": { "name": "read_file" } }] + } + }] + }, + { + "choices": [{ + "delta": { + "tool_calls": [{ "id": "call_2", "type": "function", "function": { "arguments": "{\"path\": \"/tmp\"}" } }] + } + }] + }, + { + "choices": [], + "usage": { "input_tokens": 25, "output_tokens": 12 } + } +] diff --git a/pkg/ai/testdata/content_and_tool_calls_out.json b/pkg/ai/testdata/content_and_tool_calls_out.json new file mode 100644 index 000000000..04fe1e018 --- /dev/null +++ b/pkg/ai/testdata/content_and_tool_calls_out.json @@ -0,0 +1,19 @@ +{ + "Calls": [ + { + "id": "call_2", + "type": "function", + "function": { "name": "read_file", "arguments": "{\"path\": \"/tmp\"}" } + } + ], + "FinishReason": "tool_calls", + "Usage": { "input_tokens": 25, "output_tokens": 12, "cached_input_tokens": 0, "cached_write_tokens": 0 }, + "ThoughtSignature": null, + "Stopped": false, + "Content": "Let me check that file.", + "ReasoningContent": "", + "ThinkingSignature": "", + "Model": "mock", + "Turns": 0, + "Messages": null +} diff --git a/pkg/ai/testdata/empty_stream_in.json b/pkg/ai/testdata/empty_stream_in.json new file mode 100644 index 000000000..fe51488c7 --- /dev/null +++ b/pkg/ai/testdata/empty_stream_in.json @@ -0,0 +1 @@ +[] diff --git a/pkg/ai/testdata/empty_stream_out.json b/pkg/ai/testdata/empty_stream_out.json new file mode 100644 index 000000000..267f53741 --- /dev/null +++ b/pkg/ai/testdata/empty_stream_out.json @@ -0,0 +1,13 @@ +{ + "Calls": null, + "FinishReason": "null", + "Usage": null, + "ThoughtSignature": null, + "Stopped": true, + "Content": "", + "ReasoningContent": "", + "ThinkingSignature": "", + "Model": "mock", + "Turns": 0, + "Messages": null +} diff --git a/pkg/ai/testdata/exec_tools_images_in.json b/pkg/ai/testdata/exec_tools_images_in.json new file mode 100644 index 000000000..0b037b32a --- /dev/null +++ b/pkg/ai/testdata/exec_tools_images_in.json @@ -0,0 +1,23 
@@ +[ + { + "Calls": [ + { + "id": "call_1", + "type": "function", + "function": { "name": "screenshot", "arguments": "{}" } + } + ], + "FinishReason": "tool_calls", + "Usage": { "input_tokens": 10, "output_tokens": 5 }, + "Model": "test", + "Content": "" + }, + { + "Calls": null, + "FinishReason": "stop", + "Usage": { "input_tokens": 30, "output_tokens": 10 }, + "Stopped": true, + "Content": "I can see the screenshot.", + "Model": "test" + } +] diff --git a/pkg/ai/testdata/exec_tools_images_out.json b/pkg/ai/testdata/exec_tools_images_out.json new file mode 100644 index 000000000..494f2b58f --- /dev/null +++ b/pkg/ai/testdata/exec_tools_images_out.json @@ -0,0 +1,13 @@ +{ + "Calls": null, + "FinishReason": "stop", + "Usage": { "input_tokens": 40, "output_tokens": 15, "cached_input_tokens": 0, "cached_write_tokens": 0 }, + "ThoughtSignature": null, + "Stopped": true, + "Content": "I can see the screenshot.", + "ReasoningContent": "", + "ThinkingSignature": "", + "Model": "test", + "Turns": 1, + "Messages": null +} diff --git a/pkg/ai/testdata/exec_tools_max_turns_in.json b/pkg/ai/testdata/exec_tools_max_turns_in.json new file mode 100644 index 000000000..68bfc11c3 --- /dev/null +++ b/pkg/ai/testdata/exec_tools_max_turns_in.json @@ -0,0 +1,28 @@ +[ + { + "Calls": [ + { + "id": "call_1", + "type": "function", + "function": { "name": "greet", "arguments": "{\"name\": \"Alice\"}" } + } + ], + "FinishReason": "tool_calls", + "Usage": { "input_tokens": 10, "output_tokens": 5 }, + "Model": "test", + "Content": "" + }, + { + "Calls": [ + { + "id": "call_2", + "type": "function", + "function": { "name": "greet", "arguments": "{\"name\": \"Bob\"}" } + } + ], + "FinishReason": "tool_calls", + "Usage": { "input_tokens": 20, "output_tokens": 5 }, + "Model": "test", + "Content": "" + } +] diff --git a/pkg/ai/testdata/exec_tools_mixed_in.json b/pkg/ai/testdata/exec_tools_mixed_in.json new file mode 100644 index 000000000..85ccfd35f --- /dev/null +++ 
b/pkg/ai/testdata/exec_tools_mixed_in.json @@ -0,0 +1,33 @@ +[ + { + "Calls": [ + { + "id": "call_1", + "type": "function", + "function": { "name": "greet", "arguments": "{\"name\": \"Alice\"}" } + }, + { + "id": "call_2", + "type": "function", + "function": { "name": "unknown_tool", "arguments": "{}" } + }, + { + "id": "call_3", + "type": "function", + "function": { "name": "failing_tool", "arguments": "{}" } + } + ], + "FinishReason": "tool_calls", + "Usage": { "input_tokens": 20, "output_tokens": 10 }, + "Model": "test", + "Content": "" + }, + { + "Calls": null, + "FinishReason": "stop", + "Usage": { "input_tokens": 50, "output_tokens": 15 }, + "Stopped": true, + "Content": "Hello! I greeted Alice for you.", + "Model": "test" + } +] diff --git a/pkg/ai/testdata/exec_tools_mixed_out.json b/pkg/ai/testdata/exec_tools_mixed_out.json new file mode 100644 index 000000000..b321b5008 --- /dev/null +++ b/pkg/ai/testdata/exec_tools_mixed_out.json @@ -0,0 +1,13 @@ +{ + "Calls": null, + "FinishReason": "stop", + "Usage": { "input_tokens": 70, "output_tokens": 25, "cached_input_tokens": 0, "cached_write_tokens": 0 }, + "ThoughtSignature": null, + "Stopped": true, + "Content": "Hello! 
I greeted Alice for you.", + "ReasoningContent": "", + "ThinkingSignature": "", + "Model": "test", + "Turns": 1, + "Messages": null +} diff --git a/pkg/ai/testdata/finish_length_in.json b/pkg/ai/testdata/finish_length_in.json new file mode 100644 index 000000000..916ba742d --- /dev/null +++ b/pkg/ai/testdata/finish_length_in.json @@ -0,0 +1,12 @@ +[ + { + "choices": [{ "delta": { "content": "This response is trun" } }] + }, + { + "choices": [{ "delta": { "content": "cated due to" } }] + }, + { + "choices": [{ "finish_reason": "length" }], + "usage": { "input_tokens": 50, "output_tokens": 100 } + } +] diff --git a/pkg/ai/testdata/finish_length_out.json b/pkg/ai/testdata/finish_length_out.json new file mode 100644 index 000000000..cd0ea74be --- /dev/null +++ b/pkg/ai/testdata/finish_length_out.json @@ -0,0 +1,13 @@ +{ + "Calls": null, + "FinishReason": "length", + "Usage": { "input_tokens": 50, "output_tokens": 100, "cached_input_tokens": 0, "cached_write_tokens": 0 }, + "ThoughtSignature": null, + "Stopped": true, + "Content": "This response is truncated due to", + "ReasoningContent": "", + "ThinkingSignature": "", + "Model": "mock", + "Turns": 0, + "Messages": null +} diff --git a/pkg/ai/testdata/finish_tool_calls_no_tools_in.json b/pkg/ai/testdata/finish_tool_calls_no_tools_in.json new file mode 100644 index 000000000..c75ea6174 --- /dev/null +++ b/pkg/ai/testdata/finish_tool_calls_no_tools_in.json @@ -0,0 +1,5 @@ +[ + { + "choices": [{ "delta": { "content": "No tools here." 
}, "finish_reason": "tool_calls" }] + } +] diff --git a/pkg/ai/testdata/finish_tool_calls_no_tools_out.json b/pkg/ai/testdata/finish_tool_calls_no_tools_out.json new file mode 100644 index 000000000..578ecac3d --- /dev/null +++ b/pkg/ai/testdata/finish_tool_calls_no_tools_out.json @@ -0,0 +1,13 @@ +{ + "Calls": null, + "FinishReason": "null", + "Usage": null, + "ThoughtSignature": null, + "Stopped": false, + "Content": "No tools here.", + "ReasoningContent": "", + "ThinkingSignature": "", + "Model": "mock", + "Turns": 0, + "Messages": null +} diff --git a/pkg/ai/testdata/inferred_stop_in.json b/pkg/ai/testdata/inferred_stop_in.json new file mode 100644 index 000000000..81a3ad121 --- /dev/null +++ b/pkg/ai/testdata/inferred_stop_in.json @@ -0,0 +1,5 @@ +[ + { + "choices": [{ "delta": { "content": "Just some text." } }] + } +] diff --git a/pkg/ai/testdata/inferred_stop_out.json b/pkg/ai/testdata/inferred_stop_out.json new file mode 100644 index 000000000..b9ebe0a9f --- /dev/null +++ b/pkg/ai/testdata/inferred_stop_out.json @@ -0,0 +1,13 @@ +{ + "Calls": null, + "FinishReason": "stop", + "Usage": null, + "ThoughtSignature": null, + "Stopped": false, + "Content": "Just some text.", + "ReasoningContent": "", + "ThinkingSignature": "", + "Model": "mock", + "Turns": 0, + "Messages": null +} diff --git a/pkg/ai/testdata/inferred_tool_calls_in.json b/pkg/ai/testdata/inferred_tool_calls_in.json new file mode 100644 index 000000000..ad6619070 --- /dev/null +++ b/pkg/ai/testdata/inferred_tool_calls_in.json @@ -0,0 +1,9 @@ +[ + { + "choices": [{ + "delta": { + "tool_calls": [{ "id": "call_5", "type": "function", "function": { "name": "echo", "arguments": "{\"msg\": \"hi\"}" } }] + } + }] + } +] diff --git a/pkg/ai/testdata/inferred_tool_calls_out.json b/pkg/ai/testdata/inferred_tool_calls_out.json new file mode 100644 index 000000000..aa13a6a83 --- /dev/null +++ b/pkg/ai/testdata/inferred_tool_calls_out.json @@ -0,0 +1,19 @@ +{ + "Calls": [ + { + "id": "call_5", + "type": 
"function", + "function": { "name": "echo", "arguments": "{\"msg\": \"hi\"}" } + } + ], + "FinishReason": "tool_calls", + "Usage": null, + "ThoughtSignature": null, + "Stopped": false, + "Content": "", + "ReasoningContent": "", + "ThinkingSignature": "", + "Model": "mock", + "Turns": 0, + "Messages": null +} diff --git a/pkg/ai/testdata/reasoning_in.json b/pkg/ai/testdata/reasoning_in.json new file mode 100644 index 000000000..474f72cf1 --- /dev/null +++ b/pkg/ai/testdata/reasoning_in.json @@ -0,0 +1,15 @@ +[ + { + "choices": [{ "delta": { "reasoning_content": "Let me think" } }] + }, + { + "choices": [{ "delta": { "reasoning_content": " about this..." } }] + }, + { + "choices": [{ "delta": { "content": "The answer is 42." } }] + }, + { + "choices": [{ "finish_reason": "stop" }], + "usage": { "input_tokens": 20, "output_tokens": 10 } + } +] diff --git a/pkg/ai/testdata/reasoning_out.json b/pkg/ai/testdata/reasoning_out.json new file mode 100644 index 000000000..a845b788b --- /dev/null +++ b/pkg/ai/testdata/reasoning_out.json @@ -0,0 +1,13 @@ +{ + "Calls": null, + "FinishReason": "stop", + "Usage": { "input_tokens": 20, "output_tokens": 10, "cached_input_tokens": 0, "cached_write_tokens": 0 }, + "ThoughtSignature": null, + "Stopped": true, + "Content": "The answer is 42.", + "ReasoningContent": "Let me think about this...", + "ThinkingSignature": "", + "Model": "mock", + "Turns": 0, + "Messages": null +} diff --git a/pkg/ai/testdata/simple_text_in.json b/pkg/ai/testdata/simple_text_in.json new file mode 100644 index 000000000..2643b44f7 --- /dev/null +++ b/pkg/ai/testdata/simple_text_in.json @@ -0,0 +1,15 @@ +[ + { + "choices": [{ "delta": { "content": "Hello" } }] + }, + { + "choices": [{ "delta": { "content": " world" } }] + }, + { + "choices": [{ "delta": { "content": "!" 
} }] + }, + { + "choices": [{ "finish_reason": "stop" }], + "usage": { "input_tokens": 10, "output_tokens": 3 } + } +] diff --git a/pkg/ai/testdata/simple_text_out.json b/pkg/ai/testdata/simple_text_out.json new file mode 100644 index 000000000..99bf65bd7 --- /dev/null +++ b/pkg/ai/testdata/simple_text_out.json @@ -0,0 +1,13 @@ +{ + "Calls": null, + "FinishReason": "stop", + "Usage": { "input_tokens": 10, "output_tokens": 3, "cached_input_tokens": 0, "cached_write_tokens": 0 }, + "ThoughtSignature": null, + "Stopped": true, + "Content": "Hello world!", + "ReasoningContent": "", + "ThinkingSignature": "", + "Model": "mock", + "Turns": 0, + "Messages": null +} diff --git a/pkg/ai/testdata/thinking_signature_in.json b/pkg/ai/testdata/thinking_signature_in.json new file mode 100644 index 000000000..c63d6c784 --- /dev/null +++ b/pkg/ai/testdata/thinking_signature_in.json @@ -0,0 +1,18 @@ +[ + { + "choices": [{ "delta": { "thinking_signature": "sig_abc123" } }] + }, + { + "choices": [{ "delta": { "thought_signature": "dGhvdWdodA==" } }] + }, + { + "choices": [{ "delta": { "reasoning_content": "Deep thought..." } }] + }, + { + "choices": [{ "delta": { "content": "Result." 
} }] + }, + { + "choices": [{ "finish_reason": "stop" }], + "usage": { "input_tokens": 30, "output_tokens": 15 } + } +] diff --git a/pkg/ai/testdata/thinking_signature_out.json b/pkg/ai/testdata/thinking_signature_out.json new file mode 100644 index 000000000..5cbbb7a01 --- /dev/null +++ b/pkg/ai/testdata/thinking_signature_out.json @@ -0,0 +1,13 @@ +{ + "Calls": null, + "FinishReason": "stop", + "Usage": { "input_tokens": 30, "output_tokens": 15, "cached_input_tokens": 0, "cached_write_tokens": 0 }, + "ThoughtSignature": "dGhvdWdodA==", + "Stopped": true, + "Content": "Result.", + "ReasoningContent": "Deep thought...", + "ThinkingSignature": "sig_abc123", + "Model": "mock", + "Turns": 0, + "Messages": null +} diff --git a/pkg/ai/testdata/tool_calls_in.json b/pkg/ai/testdata/tool_calls_in.json new file mode 100644 index 000000000..238a6900b --- /dev/null +++ b/pkg/ai/testdata/tool_calls_in.json @@ -0,0 +1,27 @@ +[ + { + "choices": [{ + "delta": { + "tool_calls": [{ "id": "call_1", "type": "function", "function": { "name": "read_file" } }] + } + }] + }, + { + "choices": [{ + "delta": { + "tool_calls": [{ "id": "call_1", "type": "function", "function": { "arguments": "{\"path\":" } }] + } + }] + }, + { + "choices": [{ + "delta": { + "tool_calls": [{ "id": "call_1", "type": "function", "function": { "arguments": " \"/root\"}" } }] + } + }] + }, + { + "choices": [{ "finish_reason": "tool_calls" }], + "usage": { "input_tokens": 15, "output_tokens": 8 } + } +] diff --git a/pkg/ai/testdata/tool_calls_out.json b/pkg/ai/testdata/tool_calls_out.json new file mode 100644 index 000000000..d82a967b6 --- /dev/null +++ b/pkg/ai/testdata/tool_calls_out.json @@ -0,0 +1,19 @@ +{ + "Calls": [ + { + "id": "call_1", + "type": "function", + "function": { "name": "read_file", "arguments": "{\"path\": \"/root\"}" } + } + ], + "FinishReason": "tool_calls", + "Usage": { "input_tokens": 15, "output_tokens": 8, "cached_input_tokens": 0, "cached_write_tokens": 0 }, + "ThoughtSignature": null, 
+ "Stopped": false, + "Content": "", + "ReasoningContent": "", + "ThinkingSignature": "", + "Model": "mock", + "Turns": 0, + "Messages": null +} diff --git a/pkg/runtime/fallback.go b/pkg/runtime/fallback.go index 8d5e8d686..ce2e85ef7 100644 --- a/pkg/runtime/fallback.go +++ b/pkg/runtime/fallback.go @@ -2,13 +2,11 @@ package runtime import ( "context" - "errors" - "fmt" "log/slog" "time" "github.com/docker/docker-agent/pkg/agent" - "github.com/docker/docker-agent/pkg/backoff" + "github.com/docker/docker-agent/pkg/ai" "github.com/docker/docker-agent/pkg/chat" "github.com/docker/docker-agent/pkg/model/provider" "github.com/docker/docker-agent/pkg/modelerrors" @@ -52,32 +50,6 @@ func buildModelChain(primary provider.Provider, fallbacks []provider.Provider) [ } // logFallbackAttempt logs information about a fallback attempt -func logFallbackAttempt(agentName string, model modelWithFallback, attempt, maxRetries int, err error) { - if model.isFallback { - slog.Warn("Fallback model attempt", - "agent", agentName, - "model", model.provider.ID(), - "fallback_index", model.index, - "attempt", attempt+1, - "max_retries", maxRetries+1, - "previous_error", err) - } else { - slog.Warn("Primary model failed, trying fallbacks", - "agent", agentName, - "model", model.provider.ID(), - "error", err) - } -} - -// logRetryBackoff logs when we're backing off before a retry -func logRetryBackoff(agentName, modelID string, attempt int, backoffDelay time.Duration) { - slog.Debug("Backing off before retry", - "agent", agentName, - "model", modelID, - "attempt", attempt+1, - "backoff", backoffDelay) -} - // getCooldownState returns the current cooldown state for an agent (thread-safe). // Returns nil if no cooldown is active or if cooldown has expired. // Expired entries are evicted to prevent stale state accumulation. 
@@ -158,20 +130,11 @@ func getEffectiveRetries(a *agent.Agent) int { return retries } -// tryModelWithFallback attempts to create a stream and get a response using the primary model, +// tryModelWithFallback attempts to get a response using the primary model, // falling back to configured fallback models if the primary fails. // -// Retry behavior: -// - Retryable errors (5xx, timeouts): retry the same model with exponential backoff -// - Non-retryable errors (429, 4xx): skip to the next model in the chain immediately -// -// Cooldown behavior: -// - When the primary fails with a non-retryable error and a fallback succeeds, the runtime -// "sticks" with that fallback for a configurable cooldown period. -// - During cooldown, subsequent calls skip the primary and start from the pinned fallback. -// - When cooldown expires, the primary is tried again; if it succeeds, cooldown is cleared. -// -// Returns the stream result, the model that was used, and any error. +// Retry, fallback, and streaming are delegated to pkg/ai. Cooldown state +// (pinning to a successful fallback) is managed here in the runtime. func (r *LocalRuntime) tryModelWithFallback( ctx context.Context, a *agent.Agent, @@ -184,246 +147,76 @@ func (r *LocalRuntime) tryModelWithFallback( ) (streamResult, provider.Provider, error) { fallbackModels := a.FallbackModels() - fallbackRetries := getEffectiveRetries(a) - - // Build the chain of models to try: primary (index 0) + fallbacks (index 1+) - modelChain := buildModelChain(primaryModel, fallbackModels) + // Build model list respecting cooldown + models := []provider.Provider{primaryModel} + models = append(models, fallbackModels...) 
- // Check if we're in a cooldown period and should skip the primary - startIndex := 0 - inCooldown := false cooldownState := r.getCooldownState(a.Name()) if cooldownState != nil && len(fallbackModels) > cooldownState.fallbackIndex { - // We're in cooldown - start from the pinned fallback (skip primary) - startIndex = cooldownState.fallbackIndex + 1 // +1 because index 0 is primary - inCooldown = true + models = models[cooldownState.fallbackIndex+1:] slog.Debug("Skipping primary due to cooldown", "agent", a.Name(), "start_from_fallback_index", cooldownState.fallbackIndex, "cooldown_until", cooldownState.until.Format(time.RFC3339)) } - var lastErr error - primaryFailedWithNonRetryable := false - hasFallbacks := len(fallbackModels) > 0 - - for chainIdx := startIndex; chainIdx < len(modelChain); chainIdx++ { - modelEntry := modelChain[chainIdx] - - // Each model in the chain gets (1 + retries) attempts for retryable errors. - // Non-retryable errors (429 with fallbacks, 4xx) skip immediately to the next model. - // 429 without fallbacks is retried directly on the same model. 
- maxAttempts := 1 + fallbackRetries - - for attempt := range maxAttempts { - // Check context before each attempt - if ctx.Err() != nil { - return streamResult{}, nil, ctx.Err() - } - - // Apply backoff before retry (not on first attempt of each model) - if attempt > 0 { - backoffDelay := backoff.Calculate(attempt - 1) - logRetryBackoff(a.Name(), modelEntry.provider.ID(), attempt, backoffDelay) - if !backoff.SleepWithContext(ctx, backoffDelay) { - return streamResult{}, nil, ctx.Err() - } - } - - // Emit fallback event when transitioning to a new model (but not when starting in cooldown) - if chainIdx > startIndex && attempt == 0 { - logFallbackAttempt(a.Name(), modelEntry, attempt, fallbackRetries, lastErr) - // Get the previous model's ID for the event - prevModelID := modelChain[chainIdx-1].provider.ID() - reason := "" - if lastErr != nil { - reason = lastErr.Error() - } - events <- ModelFallback( - a.Name(), - prevModelID, - modelEntry.provider.ID(), - reason, - attempt+1, - maxAttempts, - ) - } - - slog.Debug("Creating chat completion stream", - "agent", a.Name(), - "model", modelEntry.provider.ID(), - "is_fallback", modelEntry.isFallback, - "in_cooldown", inCooldown, - "attempt", attempt+1) - - stream, err := modelEntry.provider.CreateChatCompletionStream(ctx, messages, agentTools) + retries := getEffectiveRetries(a) + maxAttempts := retries + 1 + + opts := []ai.Option{ + ai.WithLogger(slog.With("agent", a.Name())), + ai.WithModels(models...), + ai.WithMessages(messages...), + ai.WithTools(agentTools...), + ai.WithRetries(retries), + ai.WithReturnToolRequests(), + ai.WithOnModelFallback(func(from, to provider.Provider, err error) { + reason := "" if err != nil { - lastErr = err - - // Context cancellation is never retryable - if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { - return streamResult{}, nil, err - } - - decision := r.handleModelError(ctx, err, a, modelEntry, attempt, hasFallbacks, &primaryFailedWithNonRetryable) - 
if decision == retryDecisionReturn { - return streamResult{}, nil, ctx.Err() - } else if decision == retryDecisionBreak { - break - } - continue + reason = err.Error() } - - // Stream created successfully, now handle it - slog.Debug("Processing stream", "agent", a.Name(), "model", modelEntry.provider.ID()) - - // If the provider is a rule-based router, notify the sidebar - // of the selected sub-model's YAML-configured name. - if rp, ok := modelEntry.provider.(interface{ LastSelectedModelID() string }); ok { + events <- ModelFallback(a.Name(), from.ID(), to.ID(), reason, 1, maxAttempts) + }), + ai.WithStreamInterceptor(func(ctx context.Context, r *ai.StreamRequest, h ai.StreamHandler) (*ai.ModelResponse, error) { + if rp, ok := r.Model.(interface{ LastSelectedModelID() string }); ok { if selected := rp.LastSelectedModelID(); selected != "" { events <- AgentInfo(a.Name(), selected, a.Description(), a.WelcomeMessage()) } } - - res, err := r.handleStream(ctx, stream, a, agentTools, sess, m, events) - if err != nil { - lastErr = err - - // Context cancellation stops everything - if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { - return streamResult{}, nil, err - } - - decision := r.handleModelError(ctx, err, a, modelEntry, attempt, hasFallbacks, &primaryFailedWithNonRetryable) - if decision == retryDecisionReturn { - return streamResult{}, nil, ctx.Err() - } else if decision == retryDecisionBreak { - break - } - continue - } - - // Success! - // Handle cooldown state based on which model succeeded - switch { - case modelEntry.isFallback && primaryFailedWithNonRetryable: - // Primary failed with non-retryable error, fallback succeeded. - // Set cooldown to stick with this fallback. - r.setCooldownState(a.Name(), modelEntry.index, getEffectiveCooldown(a)) - case !modelEntry.isFallback: - // Primary succeeded - clear any existing cooldown. - // This handles both normal success and recovery after cooldown expires. 
- r.clearCooldownState(a.Name()) - } - - return res, modelEntry.provider, nil - } + return h(ctx, r) + }), } - // All models and retries exhausted. - // If the last error (or any error in the chain) was a context overflow, - // wrap it in a ContextOverflowError so the caller can auto-compact. - if lastErr != nil { - prefix := "model failed" - if hasFallbacks { - prefix = "all models failed" - } - wrapped := fmt.Errorf("%s: %w", prefix, lastErr) - if modelerrors.IsContextOverflowError(lastErr) { - return streamResult{}, nil, modelerrors.NewContextOverflowError(wrapped) - } - return streamResult{}, nil, wrapped + if r.retryOnRateLimit { + opts = append(opts, ai.WithRetryOnRateLimit()) } - return streamResult{}, nil, errors.New("model failed with unknown error") -} - -// retryDecision is the outcome of handleModelError. -type retryDecision int -const ( - // retryDecisionContinue means retry the same model (backoff already applied). - retryDecisionContinue retryDecision = iota - // retryDecisionBreak means skip to the next model in the fallback chain. - retryDecisionBreak - // retryDecisionReturn means context was cancelled; return immediately. - retryDecisionReturn -) - -// handleModelError classifies err and decides what to do next: -// - retryDecisionReturn — context cancelled while sleeping; caller returns ctx.Err() -// - retryDecisionBreak — non-retryable error or 429 with fallbacks; skip to next model -// - retryDecisionContinue — retryable error or 429 without fallbacks; retry same model -// -// Side-effect: sets *primaryFailedWithNonRetryable when the primary model fails with a -// non-retryable (or rate-limited-with-fallbacks) error. 
-func (r *LocalRuntime) handleModelError( - ctx context.Context, - err error, - a *agent.Agent, - modelEntry modelWithFallback, - attempt int, - hasFallbacks bool, - primaryFailedWithNonRetryable *bool, -) retryDecision { - retryable, rateLimited, retryAfter := modelerrors.ClassifyModelError(err) - - if rateLimited { - // Gate: only retry on 429 if opt-in is enabled AND no fallbacks exist. - // Default behavior (retryOnRateLimit=false) treats 429 as non-retryable, - // identical to today's behavior before this feature was added. - if !r.retryOnRateLimit || hasFallbacks { - slog.Warn("Rate limited, treating as non-retryable", - "agent", a.Name(), - "model", modelEntry.provider.ID(), - "retry_on_rate_limit_enabled", r.retryOnRateLimit, - "has_fallbacks", hasFallbacks, - "error", err) - if !modelEntry.isFallback { - *primaryFailedWithNonRetryable = true - } - return retryDecisionBreak - } + seq := ai.GenerateStream(ctx, opts...) + res, err := r.handleStream(ctx, seq, a, agentTools, sess, m, events) + if err != nil { + return streamResult{}, nil, err + } - // Opt-in enabled, no fallbacks → retry same model after honouring Retry-After (or backoff). 
- waitDuration := retryAfter - if waitDuration <= 0 { - waitDuration = backoff.Calculate(attempt) - } else if waitDuration > backoff.MaxRetryAfterWait { - slog.Warn("Retry-After exceeds maximum, capping", - "agent", a.Name(), - "model", modelEntry.provider.ID(), - "retry_after", retryAfter, - "max", backoff.MaxRetryAfterWait) - waitDuration = backoff.MaxRetryAfterWait - } - slog.Warn("Rate limited, retrying (opt-in enabled)", - "agent", a.Name(), - "model", modelEntry.provider.ID(), - "attempt", attempt+1, - "wait", waitDuration, - "retry_after_from_header", retryAfter > 0, - "error", err) - if !backoff.SleepWithContext(ctx, waitDuration) { - return retryDecisionReturn + // Resolve which provider was used + var usedModel provider.Provider + for _, m := range models { + if m.ID() == res.Model { + usedModel = m + break } - return retryDecisionContinue } - if !retryable { - slog.Error("Non-retryable error from model", - "agent", a.Name(), - "model", modelEntry.provider.ID(), - "error", err) - if !modelEntry.isFallback { - *primaryFailedWithNonRetryable = true + // Handle cooldown state based on which model succeeded + if usedModel != nil && usedModel.ID() == primaryModel.ID() { + r.clearCooldownState(a.Name()) + } else if usedModel != nil { + for i, fb := range fallbackModels { + if fb.ID() == usedModel.ID() { + r.setCooldownState(a.Name(), i, getEffectiveCooldown(a)) + break + } } - return retryDecisionBreak } - slog.Warn("Retryable error from model", - "agent", a.Name(), - "model", modelEntry.provider.ID(), - "attempt", attempt+1, - "error", err) - return retryDecisionContinue + return res, usedModel, nil } diff --git a/pkg/runtime/streaming.go b/pkg/runtime/streaming.go index d60fc7554..332f25650 100644 --- a/pkg/runtime/streaming.go +++ b/pkg/runtime/streaming.go @@ -3,12 +3,11 @@ package runtime import ( "context" "errors" - "fmt" - "io" + "iter" "log/slog" - "strings" "github.com/docker/docker-agent/pkg/agent" + "github.com/docker/docker-agent/pkg/ai" 
"github.com/docker/docker-agent/pkg/chat" "github.com/docker/docker-agent/pkg/modelsdev" "github.com/docker/docker-agent/pkg/session" @@ -28,142 +27,97 @@ type streamResult struct { Stopped bool FinishReason chat.FinishReason Usage *chat.Usage + Model string } -// handleStream reads a chat.MessageStream to completion, emitting streaming +// handleStream consumes an ai.GenerateStream sequence, emitting per-chunk // events (content deltas, partial tool calls, reasoning tokens) and returning -// the aggregated streamResult. The caller is responsible for adding the -// resulting assistant message to the session. -func (r *LocalRuntime) handleStream(ctx context.Context, stream chat.MessageStream, a *agent.Agent, agentTools []tools.Tool, sess *session.Session, m *modelsdev.Model, events chan Event) (streamResult, error) { - defer stream.Close() - - var fullContent strings.Builder - var fullReasoningContent strings.Builder - var thinkingSignature string - var thoughtSignature []byte - var toolCalls []tools.ToolCall - var messageUsage *chat.Usage - var providerFinishReason chat.FinishReason - - toolCallIndex := make(map[string]int) // toolCallID -> index in toolCalls slice - emittedPartial := make(map[string]bool) // toolCallID -> whether we've emitted a partial event +// the aggregated streamResult. Stream aggregation (content, tool calls, finish +// reason) is handled by pkg/ai; this method only handles event emission and +// telemetry recording. 
+func (r *LocalRuntime) handleStream( + ctx context.Context, + seq iter.Seq2[*ai.ModelStreamValue, error], + a *agent.Agent, + agentTools []tools.Tool, + sess *session.Session, + m *modelsdev.Model, + events chan Event, +) (streamResult, error) { + emittedPartial := make(map[string]bool) toolDefMap := make(map[string]tools.Tool, len(agentTools)) for _, t := range agentTools { toolDefMap[t.Name] = t } - // recordUsage persists the final token counts and emits telemetry exactly - // once per stream, after we have the most accurate usage snapshot. - usageRecorded := false - recordUsage := func() { - if usageRecorded || messageUsage == nil { - return - } - usageRecorded = true - - sess.InputTokens = messageUsage.InputTokens + messageUsage.CachedInputTokens + messageUsage.CacheWriteTokens - sess.OutputTokens = messageUsage.OutputTokens + // Track partial tool call names for event emission + toolNames := make(map[string]string) - modelName := "unknown" - if m != nil { - modelName = m.Name - } - telemetry.RecordTokenUsage(ctx, modelName, sess.InputTokens, sess.OutputTokens, sess.TotalCost()) - } - - for { - response, err := stream.Recv() - if errors.Is(err, io.EOF) { - break - } + for sv, err := range seq { if err != nil { - return streamResult{Stopped: true}, fmt.Errorf("error receiving from stream: %w", err) + return streamResult{Stopped: true}, err } - if response.Usage != nil { - // Always keep the latest usage snapshot; some providers (e.g. - // Gemini) emit updated usage on every chunk with cumulative - // token counts, so the last value is the most accurate. 
- messageUsage = response.Usage - } + if sv.Done { + res := sv.Response - if len(response.Choices) == 0 { - continue - } - choice := response.Choices[0] + // Record usage and telemetry + if res.Usage != nil { + sess.InputTokens = res.Usage.InputTokens + res.Usage.CachedInputTokens + res.Usage.CacheWriteTokens + sess.OutputTokens = res.Usage.OutputTokens - if len(choice.Delta.ThoughtSignature) > 0 { - thoughtSignature = choice.Delta.ThoughtSignature - } + modelName := "unknown" + if m != nil { + modelName = m.Name + } + telemetry.RecordTokenUsage(ctx, modelName, sess.InputTokens, sess.OutputTokens, sess.TotalCost()) + } - if choice.FinishReason == chat.FinishReasonStop || choice.FinishReason == chat.FinishReasonLength { - recordUsage() return streamResult{ - Calls: toolCalls, - Content: fullContent.String(), - ReasoningContent: fullReasoningContent.String(), - ThinkingSignature: thinkingSignature, - ThoughtSignature: thoughtSignature, - Stopped: true, - FinishReason: choice.FinishReason, - Usage: messageUsage, + Calls: res.Calls, + Content: res.Content, + ReasoningContent: res.ReasoningContent, + ThinkingSignature: res.ThinkingSignature, + ThoughtSignature: res.ThoughtSignature, + Stopped: res.Stopped, + FinishReason: res.FinishReason, + Usage: res.Usage, + Model: res.Model, }, nil } - // Track the provider's explicit finish reason (e.g. tool_calls) so we - // can prefer it over inference after the loop. stop/length are already - // handled by the early return above. 
- if choice.FinishReason != "" { - providerFinishReason = choice.FinishReason + // Process chunk — emit events + chunk := sv.Chunk + + if len(chunk.Choices) == 0 { + continue } - // Handle tool calls + choice := chunk.Choices[0] + + // Emit partial tool calls if len(choice.Delta.ToolCalls) > 0 { - // Process each tool call delta for _, delta := range choice.Delta.ToolCalls { - idx, exists := toolCallIndex[delta.ID] - if !exists { - idx = len(toolCalls) - toolCallIndex[delta.ID] = idx - toolCalls = append(toolCalls, tools.ToolCall{ - ID: delta.ID, - Type: delta.Type, - }) - } - - tc := &toolCalls[idx] + learningName := delta.Function.Name != "" && toolNames[delta.ID] == "" - // Track if we're learning the name for the first time - learningName := delta.Function.Name != "" && tc.Function.Name == "" - - // Update fields from delta - if delta.Type != "" { - tc.Type = delta.Type - } if delta.Function.Name != "" { - tc.Function.Name = delta.Function.Name - } - if delta.Function.Arguments != "" { - tc.Function.Arguments += delta.Function.Arguments + toolNames[delta.ID] = delta.Function.Name } - // Emit PartialToolCall once we have a name, and on subsequent argument deltas. - // Only the newly received argument bytes are sent, not the full - // accumulated arguments, to avoid re-transmitting the entire payload - // on every token. 
- if tc.Function.Name != "" && (learningName || delta.Function.Arguments != "") { + name := toolNames[delta.ID] + if name != "" && (learningName || delta.Function.Arguments != "") { if !emittedPartial[delta.ID] || delta.Function.Arguments != "" { partial := tools.ToolCall{ - ID: tc.ID, - Type: tc.Type, + ID: delta.ID, + Type: delta.Type, Function: tools.FunctionCall{ - Name: tc.Function.Name, + Name: name, Arguments: delta.Function.Arguments, }, } toolDef := tools.Tool{} if !emittedPartial[delta.ID] { - toolDef = toolDefMap[tc.Function.Name] + toolDef = toolDefMap[name] } events <- PartialToolCall(partial, toolDef, a.Name()) emittedPartial[delta.ID] = true @@ -175,61 +129,14 @@ func (r *LocalRuntime) handleStream(ctx context.Context, stream chat.MessageStre if choice.Delta.ReasoningContent != "" { events <- AgentChoiceReasoning(a.Name(), sess.ID, choice.Delta.ReasoningContent) - fullReasoningContent.WriteString(choice.Delta.ReasoningContent) - } - - // Capture thinking signature for Anthropic extended thinking - if choice.Delta.ThinkingSignature != "" { - thinkingSignature = choice.Delta.ThinkingSignature } if choice.Delta.Content != "" { events <- AgentChoice(a.Name(), sess.ID, choice.Delta.Content) - fullContent.WriteString(choice.Delta.Content) } } - recordUsage() - - // If the stream completed without producing any content or tool calls, likely because of a token limit, stop to avoid breaking the request loop - // NOTE(krissetto): this can likely be removed once compaction works properly with all providers (aka dmr) - stoppedDueToNoOutput := fullContent.Len() == 0 && len(toolCalls) == 0 - - // Prefer the provider's explicit finish reason when available (e.g. - // tool_calls). 
Only fall back to inference when no explicit reason was - // received (stream ended with bare EOF): - // - tool calls present → tool_calls (model was requesting tools) - // - content but no tool calls → stop (natural completion) - // - no output at all → null (unknown; likely token limit) - finishReason := providerFinishReason - if finishReason == "" { - switch { - case len(toolCalls) > 0: - finishReason = chat.FinishReasonToolCalls - case fullContent.Len() > 0: - finishReason = chat.FinishReasonStop - default: - finishReason = chat.FinishReasonNull - } - } - // Ensure finish reason agrees with the actual stream output. - switch { - case finishReason == chat.FinishReasonToolCalls && len(toolCalls) == 0: - finishReason = chat.FinishReasonNull - case finishReason == chat.FinishReasonStop && len(toolCalls) > 0: - finishReason = chat.FinishReasonToolCalls - } - - return streamResult{ - Calls: toolCalls, - Content: fullContent.String(), - ReasoningContent: fullReasoningContent.String(), - ThinkingSignature: thinkingSignature, - ThoughtSignature: thoughtSignature, - Stopped: stoppedDueToNoOutput, - FinishReason: finishReason, - Usage: messageUsage, - }, nil + return streamResult{Stopped: true}, errors.New("stream ended without final response") } // stripImageContent returns a copy of messages with all image-related content diff --git a/pkg/sessiontitle/generator.go b/pkg/sessiontitle/generator.go index 0ebccc30d..f2a670801 100644 --- a/pkg/sessiontitle/generator.go +++ b/pkg/sessiontitle/generator.go @@ -5,13 +5,12 @@ package sessiontitle import ( "context" - "errors" "fmt" - "io" "log/slog" "strings" "time" + "github.com/docker/docker-agent/pkg/ai" "github.com/docker/docker-agent/pkg/chat" "github.com/docker/docker-agent/pkg/model/provider" "github.com/docker/docker-agent/pkg/model/provider/options" @@ -67,7 +66,8 @@ func (g *Generator) Generate(ctx context.Context, sessionID string, userMessages return "", nil } - slog.Debug("Generating title for session", 
"session_id", sessionID, "message_count", len(userMessages)) + lg := slog.With("session_id", sessionID) + lg.Debug("Generating title for session", "message_count", len(userMessages)) // Format messages for the prompt var formattedMessages strings.Builder @@ -88,11 +88,8 @@ func (g *Generator) Generate(ctx context.Context, sessionID string, userMessages }, } - var lastErr error - for idx, baseModel := range g.models { - if ctx.Err() != nil { - return "", ctx.Err() - } + models := make([]provider.Provider, 0, len(g.models)) + for _, baseModel := range g.models { if baseModel == nil { continue } @@ -108,65 +105,22 @@ func (g *Generator) Generate(ctx context.Context, sessionID string, userMessages options.WithGeneratingTitle(), ) - // Call the provider directly (no tools needed for title generation) - stream, err := titleModel.CreateChatCompletionStream(ctx, messages, nil) - if err != nil { - lastErr = err - slog.Error("Failed to create title generation stream", - "session_id", sessionID, - "model", baseModel.ID(), - "attempt", idx+1, - "error", err) - continue - } - - // Drain the stream to collect the full title - var title strings.Builder - var streamErr error - for { - response, err := stream.Recv() - if errors.Is(err, io.EOF) { - break - } - if err != nil { - streamErr = err - break - } - if len(response.Choices) > 0 { - title.WriteString(response.Choices[0].Delta.Content) - } - } - stream.Close() - - if streamErr != nil { - lastErr = streamErr - slog.Error("Error receiving from title stream", - "session_id", sessionID, - "model", baseModel.ID(), - "attempt", idx+1, - "error", streamErr) - continue - } - - result := sanitizeTitle(title.String()) - if result == "" { - // Empty/invalid title output - treat as a failure and try fallbacks. 
- lastErr = fmt.Errorf("empty title output from model %q", baseModel.ID()) - slog.Debug("Generated empty title, trying next model", - "session_id", sessionID, - "model", baseModel.ID(), - "attempt", idx+1) - continue - } - - slog.Debug("Generated session title", "session_id", sessionID, "title", result, "model", baseModel.ID()) - return result, nil + models = append(models, titleModel) } - if lastErr != nil { - return "", fmt.Errorf("generating title failed: %w", lastErr) + str, err := ai.GenerateText( + ctx, + ai.WithModels(models...), + ai.WithMessages(messages...), + ai.WithRequireContent(), + ai.WithLogger(lg), + ) + if err != nil { + return "", fmt.Errorf("generating title failed: %w", err) } - return "", nil + + lg.Debug("Generated session title", "title", str) + return sanitizeTitle(str), nil } // sanitizeTitle ensures the title is a single line by taking only the first