diff --git a/cmd/agentcli/prestage.go b/cmd/agentcli/prestage.go index 5f79d84..42bcb76 100644 --- a/cmd/agentcli/prestage.go +++ b/cmd/agentcli/prestage.go @@ -154,7 +154,7 @@ func runPreStage(cfg cliConfig, messages []oai.Message, stderr io.Writer) ([]oai } } prepMessages = append(prepMessages, applyTranscriptHygiene(normalizedIn, cfg.debug)...) - req := oai.ChatCompletionsRequest{ + req := oai.ChatCompletionsRequest{ Model: prepModel, Messages: prepMessages, } @@ -168,7 +168,14 @@ func runPreStage(cfg cliConfig, messages []oai.Message, stderr io.Writer) ([]oai } else if effectiveTemp != nil { req.Temperature = effectiveTemp } - // Create a dedicated client honoring pre-stage timeout and normal retry policy + // Enforce prompt to fit context window for pre-stage as well + window := oai.ContextWindowForModel(prepModel) + promptBudget := oai.PromptTokenBudget(window, 0) + if oai.EstimateTokens(req.Messages) > promptBudget { + req.Messages = oai.TrimMessagesToFit(req.Messages, promptBudget) + } + + // Create a dedicated client honoring pre-stage timeout and normal retry policy httpClient := oai.NewClientWithRetry(prepBaseURL, prepAPIKey, cfg.prepHTTPTimeout, oai.RetryPolicy{MaxRetries: retries, Backoff: backoff}) dumpJSONIfDebug(stderr, "prep.request", req, cfg.debug) // Tag context with audit stage so HTTP audit lines include stage: "prep" diff --git a/cmd/agentcli/run_agent.go b/cmd/agentcli/run_agent.go index 519823a..4560099 100644 --- a/cmd/agentcli/run_agent.go +++ b/cmd/agentcli/run_agent.go @@ -179,7 +179,7 @@ func runAgent(cfg cliConfig, stdout io.Writer, stderr io.Writer) int { for { // Apply transcript hygiene before sending to the API when -debug is off hygienic := applyTranscriptHygiene(messages, cfg.debug) - req := oai.ChatCompletionsRequest{ + req := oai.ChatCompletionsRequest{ Model: cfg.model, Messages: hygienic, } @@ -203,11 +203,19 @@ func runAgent(cfg cliConfig, stdout io.Writer, stderr io.Writer) int { req.ToolChoice = "auto" } - // Include 
MaxTokens only when a positive completionCap is set. + // Include MaxTokens only when a positive completionCap is set. if completionCap > 0 { req.MaxTokens = completionCap } + // Enforce prompt to fit the context window, leaving room for completionCap + // Compute prompt budget and trim deterministically when needed + window := oai.ContextWindowForModel(cfg.model) + promptBudget := oai.PromptTokenBudget(window, req.MaxTokens) + if oai.EstimateTokens(req.Messages) > promptBudget { + req.Messages = oai.TrimMessagesToFit(req.Messages, promptBudget) + } + // Pre-flight validate message sequence to avoid API 400s for stray tool messages if err := oai.ValidateMessageSequence(req.Messages); err != nil { safeFprintf(stderr, "error: %v\n", err) @@ -276,7 +284,7 @@ func runAgent(cfg cliConfig, stdout io.Writer, stderr io.Writer) int { callCtx, cancel = context.WithTimeout(context.Background(), cfg.httpTimeout) } - // Fallback: non-streaming request + // Fallback: non-streaming request resp, err := httpClient.CreateChatCompletion(callCtx, req) cancel() if err != nil { diff --git a/internal/oai/context_window.go b/internal/oai/context_window.go index 6e2dfd6..a8ccfa3 100644 --- a/internal/oai/context_window.go +++ b/internal/oai/context_window.go @@ -45,3 +45,14 @@ func ClampCompletionCap(messages []Message, requestedCap int, window int) int { } return requestedCap } + +// PromptTokenBudget returns a safe token budget for the prompt given a +// model context window and a desired completion cap. A small safety margin +// of 32 tokens is reserved for reply/control tokens. 
+func PromptTokenBudget(window int, completionCap int) int {
+	budget := window - completionCap - 32
+	if budget < 1 {
+		return 1
+	}
+	return budget
+}
diff --git a/internal/oai/trim.go b/internal/oai/trim.go
new file mode 100644
index 0000000..ac8f736
--- /dev/null
+++ b/internal/oai/trim.go
@@ -0,0 +1,152 @@
+package oai
+
+// TrimMessagesToFit reduces a transcript so its estimated tokens do not exceed
+// the provided limit. Policy:
+//   - Pin the first system and developer messages when present.
+//   - Drop oldest non-pinned messages first until within limit.
+//   - If only pinned remain and still exceed limit, truncate their content
+//     proportionally but keep both messages.
+//   - As a last resort, keep only the newest message, truncated to fit.
+//
+// The input slice is never modified. It is returned unchanged when it already
+// fits; a non-positive limit or an empty input yields an empty slice.
+func TrimMessagesToFit(in []Message, limit int) []Message {
+	if limit <= 0 || len(in) == 0 {
+		return []Message{}
+	}
+
+	// Fast path: already fits
+	if EstimateTokens(in) <= limit {
+		return in
+	}
+
+	// Work on a copy so removals never touch the caller's backing array.
+	out := append([]Message(nil), in...)
+
+	// Drop oldest non-pinned messages until within limit.
+	for EstimateTokens(out) > limit {
+		var removed bool
+		if out, removed = dropOldestNonPinned(out); !removed {
+			// Only pinned remain (or nothing at all); stop dropping.
+			break
+		}
+	}
+
+	if len(out) == 0 {
+		// No pinned messages exist and even the newest message is over budget
+		// on its own. Apply the documented last resort instead of returning an
+		// empty transcript; this also keeps every later out[len(out)-1] in
+		// bounds.
+		return []Message{truncateMessageToBudget(in[len(in)-1], limit)}
+	}
+	if EstimateTokens(out) <= limit {
+		return out
+	}
+
+	// Truncation path: only pinned messages remain and they are still too large.
+	sysIdx, devIdx := pinnedIndices(out)
+	cur := EstimateTokens(out)
+	if sysIdx != -1 && devIdx != -1 {
+		sysTok := EstimateTokens([]Message{out[sysIdx]})
+		devTok := EstimateTokens([]Message{out[devIdx]})
+		totalPinned := sysTok + devTok
+		if totalPinned == 0 {
+			totalPinned = 1
+		}
+		// Anything the estimate charges beyond the two pinned messages comes
+		// off their shared budget.
+		targetPinned := limit - (cur - totalPinned)
+		if targetPinned < 2 { // ensure at least 1 per pinned
+			targetPinned = 2
+		}
+		// Allocate 1 token to each, distribute the remainder proportionally.
+		remaining := targetPinned - 2
+		var extraSys, extraDev int
+		if sysTok+devTok > 0 && remaining > 0 {
+			extraSys = (sysTok * remaining) / (sysTok + devTok)
+			extraDev = remaining - extraSys
+		}
+		out[sysIdx] = truncateMessageToBudget(out[sysIdx], 1+extraSys)
+		out[devIdx] = truncateMessageToBudget(out[devIdx], 1+extraDev)
+	} else {
+		// Exactly one pinned message: it gets the limit minus whatever else
+		// the estimate charges.
+		idx := sysIdx
+		if idx == -1 {
+			idx = devIdx
+		}
+		budget := limit - (cur - EstimateTokens([]Message{out[idx]}))
+		if budget < 1 {
+			budget = 1
+		}
+		out[idx] = truncateMessageToBudget(out[idx], budget)
+	}
+
+	// Final guard: truncation is best-effort, so re-check. Drop a non-pinned
+	// message if any is left; otherwise keep only the newest message, truncated
+	// to the whole limit.
+	for EstimateTokens(out) > limit {
+		var removed bool
+		if out, removed = dropOldestNonPinned(out); !removed {
+			out = []Message{truncateMessageToBudget(out[len(out)-1], limit)}
+			break
+		}
+	}
+
+	return out
+}
+
+// pinnedIndices returns the positions of the first system and the first
+// developer message in msgs; a role that does not occur is reported as -1.
+func pinnedIndices(msgs []Message) (sysIdx, devIdx int) {
+	sysIdx, devIdx = -1, -1
+	for i := range msgs {
+		if sysIdx == -1 && msgs[i].Role == RoleSystem {
+			sysIdx = i
+		}
+		if devIdx == -1 && msgs[i].Role == RoleDeveloper {
+			devIdx = i
+		}
+	}
+	return sysIdx, devIdx
+}
+
+// dropOldestNonPinned removes the earliest message that is not pinned and
+// reports whether one was found. The removal shifts elements within msgs'
+// backing array, so callers must continue with the returned slice.
+func dropOldestNonPinned(msgs []Message) ([]Message, bool) {
+	sysIdx, devIdx := pinnedIndices(msgs)
+	for j := range msgs {
+		if j != sysIdx && j != devIdx {
+			return append(msgs[:j], msgs[j+1:]...), true
+		}
+	}
+	return msgs, false
+}
+
+// truncateMessageToBudget returns a copy of msg with content truncated such that
+// the single-message token estimate is <= budget (best-effort heuristic). A
+// budget of 1 or less empties the content outright.
+func truncateMessageToBudget(msg Message, budget int) Message {
+	if budget <= 1 {
+		msg.Content = ""
+		return msg
+	}
+	// Binary search on content length, using EstimateTokens heuristic. best is
+	// the longest length seen to fit; it stays 0 when nothing fits.
+	lo, hi := 0, len(msg.Content)
+	best := 0
+	for lo <= hi {
+		mid := (lo + hi) / 2
+		test := msg
+		test.Content = truncate(msg.Content, mid)
+		if EstimateTokens([]Message{test}) <= budget {
+			best = mid
+			lo = mid + 1
+		} else {
+			hi = mid - 1
+		}
+	}
+	msg.Content = truncate(msg.Content, best)
+	return msg
+}
diff --git a/internal/oai/trim_test.go b/internal/oai/trim_test.go
new file mode 100644
index 0000000..aa210dd
--- /dev/null
+++ b/internal/oai/trim_test.go
@@ -0,0 +1,106 @@
+package oai
+
+import (
+	"strings"
+	"testing"
+)
+
+// m builds a Message with the given role and content.
+func m(role, content string) Message { return Message{Role: role, Content: content} }
+
+func TestTrimMessagesToFit_PreservesSystemAndDeveloper(t *testing.T) {
+	sys := m(RoleSystem, repeat("S", 4000))    // ~1000 tokens
+	dev := m(RoleDeveloper, repeat("D", 4000)) // ~1000 tokens
+	u1 := m(RoleUser, repeat("u", 4000))       // ~1000 tokens
+	a1 := m(RoleAssistant, repeat("a", 4000))  // ~1000 tokens
+	u2 := m(RoleUser, repeat("u", 4000))       // ~1000 tokens
+	in := []Message{sys, dev, u1, a1, u2}
+
+	// Limit so that we cannot keep all messages; must drop from the front (u1,a1)
+	limit := EstimateTokens(in) - 1500
+	out := TrimMessagesToFit(in, limit)
+
+	if EstimateTokens(out) > limit {
+		t.Fatalf("trim did not reduce to limit: got=%d limit=%d", EstimateTokens(out), limit)
+	}
+	if len(out) < 2 {
+		t.Fatalf("expected to preserve at least system and developer; got %d", len(out))
+	}
+	if out[0].Role != RoleSystem {
+		t.Fatalf("first message should be system; got %q", out[0].Role)
+	}
+	if out[1].Role != RoleDeveloper {
+		t.Fatalf("second message should be developer; got %q", out[1].Role)
+	}
+}
+
+func TestTrimMessagesToFit_DropsOldestNonPinned(t *testing.T) {
+	sys := m(RoleSystem, "policy")
+	// 5 alternating user/assistant messages
+	msgs := []Message{sys}
+	for i := 0; i < 5; i++ {
+		msgs = append(msgs, m(RoleUser, repeat("U", 2000)))
+		msgs = append(msgs, m(RoleAssistant, repeat("A", 2000)))
+	}
+	// Force heavy trim
+	limit := EstimateTokens(msgs) / 2
+	out := TrimMessagesToFit(msgs, limit)
+	if EstimateTokens(out) > limit {
+		t.Fatalf("expected tokens <= limit; got=%d limit=%d", EstimateTokens(out), limit)
+	}
+	// Ensure the newest non-pinned message remains (the last assistant)
+	if len(out) == 0 {
+		t.Fatal("expected a non-empty transcript")
+	}
+	if out[len(out)-1].Role != RoleAssistant {
+		t.Fatalf("expected newest assistant at tail; got %q", out[len(out)-1].Role)
+	}
+}
+
+func TestTrimMessagesToFit_OnlySystemDeveloperTooLarge_TruncatesContent(t *testing.T) {
+	sys := m(RoleSystem, repeat("S", 20000))    // ~5000 tokens
+	dev := m(RoleDeveloper, repeat("D", 20000)) // ~5000 tokens
+	in := []Message{sys, dev}
+	limit := 3000 // far below combined estimate
+	out := TrimMessagesToFit(in, limit)
+	if EstimateTokens(out) > limit {
+		t.Fatalf("expected tokens <= limit after truncation; got=%d limit=%d", EstimateTokens(out), limit)
+	}
+	if len(out) != 2 {
+		t.Fatalf("should keep both system and developer; got %d", len(out))
+	}
+	if len(out[0].Content) >= len(sys.Content) {
+		t.Fatalf("system content was not truncated")
+	}
+	if len(out[1].Content) >= len(dev.Content) {
+		t.Fatalf("developer content was not truncated")
+	}
+}
+
+func TestTrimMessagesToFit_NoPinnedOversizedNewest_KeepsNewestTruncated(t *testing.T) {
+	old := m(RoleAssistant, repeat("o", 4000)) // ~1000 tokens
+	newest := m(RoleUser, repeat("n", 40000))  // ~10000 tokens
+	// Below even the newest message alone, so dropping cannot be enough.
+	limit := 100
+	out := TrimMessagesToFit([]Message{old, newest}, limit)
+	if len(out) != 1 {
+		t.Fatalf("expected exactly the newest message to survive; got %d", len(out))
+	}
+	if out[0].Role != RoleUser {
+		t.Fatalf("survivor should be the newest (user) message; got %q", out[0].Role)
+	}
+	if len(out[0].Content) >= len(newest.Content) {
+		t.Fatalf("newest content was not truncated")
+	}
+	if EstimateTokens(out) > limit {
+		t.Fatalf("expected tokens <= limit; got=%d limit=%d", EstimateTokens(out), limit)
+	}
+}
+
+// repeat returns a string consisting of count repetitions of s.
+func repeat(s string, count int) string {
+	if count <= 0 {
+		return ""
+	}
+	return strings.Repeat(s, count)
+}