diff --git a/src/api/providers/anthropic-vertex.ts b/src/api/providers/anthropic-vertex.ts
index 7760587f4f..0684a31404 100644
--- a/src/api/providers/anthropic-vertex.ts
+++ b/src/api/providers/anthropic-vertex.ts
@@ -26,6 +26,7 @@ import {
 	handleAiSdkError,
 } from "../transform/ai-sdk"
 import { calculateApiCostAnthropic } from "../../shared/cost"
+import { applyCacheBreakpoints } from "../transform/cache-breakpoints"
 import { DEFAULT_HEADERS } from "./constants"
 import { BaseProvider } from "./base-provider"
 
@@ -124,36 +125,10 @@ export class AnthropicVertexHandler extends BaseProvider implements SingleComple
 			anthropicProviderOptions.disableParallelToolUse = true
 		}
 
-		/**
-		 * Vertex API has specific limitations for prompt caching:
-		 * 1. Maximum of 4 blocks can have cache_control
-		 * 2. Only text blocks can be cached (images and other content types cannot)
-		 * 3. Cache control can only be applied to user messages, not assistant messages
-		 *
-		 * Our caching strategy:
-		 * - Cache the system prompt (1 block)
-		 * - Cache the last text block of the second-to-last user message (1 block)
-		 * - Cache the last text block of the last user message (1 block)
-		 * This ensures we stay under the 4-block limit while maintaining effective caching
-		 * for the most relevant context.
-		 */
+		// Apply cache control to user messages (Vertex allows up to 4 cache_control blocks;
+		// 1 for system prompt + 2 for the last 2 user message batches).
 		const cacheProviderOption = { anthropic: { cacheControl: { type: "ephemeral" as const } } }
-
-		const userMsgIndices = messages.reduce(
-			(acc, msg, index) => ("role" in msg && msg.role === "user" ? [...acc, index] : acc),
-			[] as number[],
-		)
-
-		const targetIndices = new Set<number>()
-		const lastUserMsgIndex = userMsgIndices[userMsgIndices.length - 1] ?? -1
-		const secondLastUserMsgIndex = userMsgIndices[userMsgIndices.length - 2] ?? -1
-
-		if (lastUserMsgIndex >= 0) targetIndices.add(lastUserMsgIndex)
-		if (secondLastUserMsgIndex >= 0) targetIndices.add(secondLastUserMsgIndex)
-
-		if (targetIndices.size > 0) {
-			this.applyCacheControlToAiSdkMessages(messages as ModelMessage[], targetIndices, cacheProviderOption)
-		}
+		applyCacheBreakpoints(messages, { cacheProviderOption })
 
 		// Build streamText request
 		// Cast providerOptions to any to bypass strict JSONObject typing — the AI SDK accepts the correct runtime values
@@ -260,29 +235,6 @@ export class AnthropicVertexHandler extends BaseProvider implements SingleComple
 		}
 	}
 
-	/**
-	 * Apply cacheControl providerOptions to the correct AI SDK messages by walking
-	 * the original Anthropic messages and converted AI SDK messages in parallel.
-	 *
-	 * convertToAiSdkMessages() can split a single Anthropic user message (containing
-	 * tool_results + text) into 2 AI SDK messages (tool role + user role). This method
-	 * accounts for that split so cache control lands on the right message.
-	 */
-	private applyCacheControlToAiSdkMessages(
-		aiSdkMessages: { role: string; providerOptions?: Record<string, Record<string, unknown>> }[],
-		targetIndices: Set<number>,
-		cacheProviderOption: Record<string, Record<string, unknown>>,
-	): void {
-		for (const idx of targetIndices) {
-			if (idx >= 0 && idx < aiSdkMessages.length) {
-				aiSdkMessages[idx].providerOptions = {
-					...aiSdkMessages[idx].providerOptions,
-					...cacheProviderOption,
-				}
-			}
-		}
-	}
-
 	getModel() {
 		const modelId = this.options.apiModelId
 		let id = modelId && modelId in vertexModels ? (modelId as VertexModelId) : vertexDefaultModelId
diff --git a/src/api/providers/anthropic.ts b/src/api/providers/anthropic.ts
index 16d3424732..b370b80a1b 100644
--- a/src/api/providers/anthropic.ts
+++ b/src/api/providers/anthropic.ts
@@ -25,6 +25,7 @@ import {
 	handleAiSdkError,
 } from "../transform/ai-sdk"
 import { calculateApiCostAnthropic } from "../../shared/cost"
+import { applyCacheBreakpoints } from "../transform/cache-breakpoints"
 import { DEFAULT_HEADERS } from "./constants"
 import { BaseProvider } from "./base-provider"
 
@@ -114,22 +115,7 @@ export class AnthropicHandler extends BaseProvider implements SingleCompletionHa
 		// Apply cache control to user messages
 		// Strategy: cache the last 2 user messages (write-to-cache + read-from-cache)
 		const cacheProviderOption = { anthropic: { cacheControl: { type: "ephemeral" as const } } }
-
-		const userMsgIndices = messages.reduce(
-			(acc, msg, index) => ("role" in msg && msg.role === "user" ? [...acc, index] : acc),
-			[] as number[],
-		)
-
-		const targetIndices = new Set<number>()
-		const lastUserMsgIndex = userMsgIndices[userMsgIndices.length - 1] ?? -1
-		const secondLastUserMsgIndex = userMsgIndices[userMsgIndices.length - 2] ?? -1
-
-		if (lastUserMsgIndex >= 0) targetIndices.add(lastUserMsgIndex)
-		if (secondLastUserMsgIndex >= 0) targetIndices.add(secondLastUserMsgIndex)
-
-		if (targetIndices.size > 0) {
-			this.applyCacheControlToAiSdkMessages(messages as ModelMessage[], targetIndices, cacheProviderOption)
-		}
+		applyCacheBreakpoints(messages, { cacheProviderOption })
 
 		// Build streamText request
 		// Cast providerOptions to any to bypass strict JSONObject typing — the AI SDK accepts the correct runtime values
@@ -236,29 +222,6 @@ export class AnthropicHandler extends BaseProvider implements SingleCompletionHa
 		}
 	}
 
-	/**
-	 * Apply cacheControl providerOptions to the correct AI SDK messages by walking
-	 * the original Anthropic messages and converted AI SDK messages in parallel.
-	 *
-	 * convertToAiSdkMessages() can split a single Anthropic user message (containing
-	 * tool_results + text) into 2 AI SDK messages (tool role + user role). This method
-	 * accounts for that split so cache control lands on the right message.
-	 */
-	private applyCacheControlToAiSdkMessages(
-		aiSdkMessages: { role: string; providerOptions?: Record<string, Record<string, unknown>> }[],
-		targetIndices: Set<number>,
-		cacheProviderOption: Record<string, Record<string, unknown>>,
-	): void {
-		for (const idx of targetIndices) {
-			if (idx >= 0 && idx < aiSdkMessages.length) {
-				aiSdkMessages[idx].providerOptions = {
-					...aiSdkMessages[idx].providerOptions,
-					...cacheProviderOption,
-				}
-			}
-		}
-	}
-
 	getModel() {
 		const modelId = this.options.apiModelId
 		let id = modelId && modelId in anthropicModels ? (modelId as AnthropicModelId) : anthropicDefaultModelId
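The "write-to-cache + read-from-cache" strategy in the hunk above plays out across consecutive requests. A minimal sketch of the breakpoint positions (the `u`/`a` helpers are hypothetical stand-ins, not part of the patch; only the positions matter):

```ts
const u = (text: string) => ({ role: "user" as const, content: text })
const a = (text: string) => ({ role: "assistant" as const, content: text })

// Request N: breakpoints land on u("q1") and u("q2"); the one on u("q2")
// writes the full prefix to the provider-side cache.
const requestN = [u("q1"), a("a1"), u("q2")]

// Request N+1: u("q2") is now the second-to-last breakpoint, so the prefix
// written during request N is read back as a cache hit, while the breakpoint
// on u("q3") writes the extended prefix for the next turn.
const requestN1 = [...requestN, a("a2"), u("q3")]
```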
diff --git a/src/api/providers/bedrock.ts b/src/api/providers/bedrock.ts
index ccfdb8613f..621615729b 100644
--- a/src/api/providers/bedrock.ts
+++ b/src/api/providers/bedrock.ts
@@ -32,6 +32,7 @@ import {
 	handleAiSdkError,
 } from "../transform/ai-sdk"
 import { getModelParams } from "../transform/model-params"
+import { applyCacheBreakpoints } from "../transform/cache-breakpoints"
 import { shouldUseReasoningBudget } from "../../shared/api"
 import { BaseProvider } from "./base-provider"
 import { DEFAULT_HEADERS } from "./constants"
@@ -276,46 +277,12 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
 		if (usePromptCache) {
 			const cachePointOption = { bedrock: { cachePoint: { type: "default" as const } } }
-
-			// Find all user message indices in the original (pre-conversion) message array.
-			const originalUserIndices = filteredMessages.reduce<number[]>(
-				(acc, msg, idx) => ("role" in msg && msg.role === "user" ? [...acc, idx] : acc),
-				[],
-			)
-
-			// Select up to 3 user messages for cache points (system prompt uses the 4th):
-			// - Last user message: write to cache for next request
-			// - Second-to-last user message: read from cache for current request
-			// - An "anchor" message earlier in the conversation for 20-block window coverage
-			const targetOriginalIndices = new Set<number>()
-			const numUserMsgs = originalUserIndices.length
-
-			if (numUserMsgs >= 1) {
-				// Always cache the last user message
-				targetOriginalIndices.add(originalUserIndices[numUserMsgs - 1])
-			}
-			if (numUserMsgs >= 2) {
-				// Cache the second-to-last user message
-				targetOriginalIndices.add(originalUserIndices[numUserMsgs - 2])
-			}
-			if (numUserMsgs >= 5) {
-				// Add an anchor cache point roughly in the first third of user messages.
-				// This ensures that the 20-block lookback from the second-to-last breakpoint
-				// can find a stable cache entry, covering all the assistant and tool messages
-				// in the middle of the conversation. We pick the user message at ~1/3 position.
-				const anchorIdx = Math.floor(numUserMsgs / 3)
-				// Only add if it's not already one of the last-2 targets
-				if (!targetOriginalIndices.has(originalUserIndices[anchorIdx])) {
-					targetOriginalIndices.add(originalUserIndices[anchorIdx])
-				}
-			}
-
-			// Apply cachePoint to the correct AI SDK messages by walking both arrays in parallel.
-			// A single original user message with tool_results becomes [tool-role msg, user-role msg]
-			// in the AI SDK array, while a plain user message becomes [user-role msg].
-			if (targetOriginalIndices.size > 0) {
-				this.applyCachePointsToAiSdkMessages(aiSdkMessages, targetOriginalIndices, cachePointOption)
-			}
+			applyCacheBreakpoints(aiSdkMessages as RooMessage[], {
+				cacheProviderOption: cachePointOption,
+				maxMessageBreakpoints: 3,
+				useAnchor: true,
+				anchorThreshold: 5,
+			})
 		}
 
 		// Build streamText request
@@ -734,29 +701,6 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
 		)
 	}
 
-	/**
-	 * Apply cachePoint providerOptions to the correct AI SDK messages by walking
-	 * the original Anthropic messages and converted AI SDK messages in parallel.
-	 *
-	 * convertToAiSdkMessages() can split a single Anthropic user message (containing
-	 * tool_results + text) into 2 AI SDK messages (tool role + user role). This method
-	 * accounts for that split so cache points land on the right message.
-	 */
-	private applyCachePointsToAiSdkMessages(
-		aiSdkMessages: { role: string; providerOptions?: Record<string, Record<string, unknown>> }[],
-		targetIndices: Set<number>,
-		cachePointOption: Record<string, Record<string, unknown>>,
-	): void {
-		for (const idx of targetIndices) {
-			if (idx >= 0 && idx < aiSdkMessages.length) {
-				aiSdkMessages[idx].providerOptions = {
-					...aiSdkMessages[idx].providerOptions,
-					...cachePointOption,
-				}
-			}
-		}
-	}
-
 	/************************************************************************************
 	 *
 	 *     AMAZON REGIONS
diff --git a/src/api/providers/minimax.ts b/src/api/providers/minimax.ts
index 8bacea7b4a..0938b6789d 100644
--- a/src/api/providers/minimax.ts
+++ b/src/api/providers/minimax.ts
@@ -16,6 +16,7 @@ import {
 	handleAiSdkError,
 } from "../transform/ai-sdk"
 import { calculateApiCostAnthropic } from "../../shared/cost"
+import { applyCacheBreakpoints } from "../transform/cache-breakpoints"
 import { DEFAULT_HEADERS } from "./constants"
 import { BaseProvider } from "./base-provider"
 
@@ -95,21 +96,7 @@ export class MiniMaxHandler extends BaseProvider implements SingleCompletionHand
 		}
 
 		const cacheProviderOption = { anthropic: { cacheControl: { type: "ephemeral" as const } } }
-		const userMsgIndices = mergedMessages.reduce(
-			(acc, msg, index) => (msg.role === "user" ? [...acc, index] : acc),
-			[] as number[],
-		)
-
-		const targetIndices = new Set<number>()
-		const lastUserMsgIndex = userMsgIndices[userMsgIndices.length - 1] ?? -1
-		const secondLastUserMsgIndex = userMsgIndices[userMsgIndices.length - 2] ?? -1
-
-		if (lastUserMsgIndex >= 0) targetIndices.add(lastUserMsgIndex)
-		if (secondLastUserMsgIndex >= 0) targetIndices.add(secondLastUserMsgIndex)
-
-		if (targetIndices.size > 0) {
-			this.applyCacheControlToAiSdkMessages(aiSdkMessages, targetIndices, cacheProviderOption)
-		}
+		applyCacheBreakpoints(aiSdkMessages as RooMessage[], { cacheProviderOption })
 
 		const requestOptions = {
 			model: this.client(modelConfig.id),
@@ -212,21 +199,6 @@ export class MiniMaxHandler extends BaseProvider implements SingleCompletionHand
 		}
 	}
 
-	private applyCacheControlToAiSdkMessages(
-		aiSdkMessages: { role: string; providerOptions?: Record<string, Record<string, unknown>> }[],
-		targetIndices: Set<number>,
-		cacheProviderOption: Record<string, Record<string, unknown>>,
-	): void {
-		for (const idx of targetIndices) {
-			if (idx >= 0 && idx < aiSdkMessages.length) {
-				aiSdkMessages[idx].providerOptions = {
-					...aiSdkMessages[idx].providerOptions,
-					...cacheProviderOption,
-				}
-			}
-		}
-	}
-
 	getModel() {
 		const modelId = this.options.apiModelId
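Taken together, the four provider hunks reduce to a single helper call parameterized by the AI SDK provider-options payload. Both payloads below are copied from the hunks above; the `messages` declaration exists only to make the sketch self-contained:

```ts
import { applyCacheBreakpoints } from "../transform/cache-breakpoints"
import type { RooMessage } from "../../core/task-persistence/rooMessage"

declare const messages: RooMessage[]

// Anthropic / Vertex / MiniMax: ephemeral cache_control, default last-2-batch breakpoints
applyCacheBreakpoints(messages, {
	cacheProviderOption: { anthropic: { cacheControl: { type: "ephemeral" as const } } },
})

// Bedrock: cachePoint markers, up to 3 message breakpoints, plus an anchor at ~1/3
// of the batches once the conversation reaches 5 non-assistant batches
applyCacheBreakpoints(messages, {
	cacheProviderOption: { bedrock: { cachePoint: { type: "default" as const } } },
	maxMessageBreakpoints: 3,
	useAnchor: true,
	anchorThreshold: 5,
})
```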
type: "text", value: content } }], + ts: Date.now(), + } as RooMessage +} + +const anthropicCache: CacheBreakpointConfig["cacheProviderOption"] = { + anthropic: { cacheControl: { type: "ephemeral" } }, +} + +function defaultConfig(overrides?: Partial): CacheBreakpointConfig { + return { cacheProviderOption: anthropicCache, ...overrides } +} + +function getProviderOptions(msg: RooMessage): Record> | undefined { + return (msg as RooMessage & { providerOptions?: Record> }).providerOptions +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +describe("applyCacheBreakpoints", () => { + it("empty messages → no crash", () => { + const messages: RooMessage[] = [] + expect(() => applyCacheBreakpoints(messages, defaultConfig())).not.toThrow() + }) + + it("single user message → gets breakpoint", () => { + const messages = [makeUserMsg()] + applyCacheBreakpoints(messages, defaultConfig()) + + expect(getProviderOptions(messages[0])).toEqual(anthropicCache) + }) + + it("single tool message → gets breakpoint", () => { + const messages = [makeToolMsg()] + applyCacheBreakpoints(messages, defaultConfig()) + + expect(getProviderOptions(messages[0])).toEqual(anthropicCache) + }) + + it("3 batches with maxMessageBreakpoints=2 → only last 2 batches get breakpoints", () => { + // Batch 1: user (idx 0) + // Batch 2: tool (idx 2) + // Batch 3: user (idx 4) + const messages = [ + makeUserMsg("q1"), + makeAssistantMsg("a1"), + makeToolMsg("r1"), + makeAssistantMsg("a2"), + makeUserMsg("q2"), + ] + + applyCacheBreakpoints(messages, defaultConfig({ maxMessageBreakpoints: 2 })) + + // Batch 1 (idx 0) should NOT get a breakpoint + expect(getProviderOptions(messages[0])).toBeUndefined() + // Batch 2 last = tool at idx 2 + expect(getProviderOptions(messages[2])).toEqual(anthropicCache) + // Batch 3 last = user at idx 4 + expect(getProviderOptions(messages[4])).toEqual(anthropicCache) + // Assistants never get breakpoints + expect(getProviderOptions(messages[1])).toBeUndefined() + expect(getProviderOptions(messages[3])).toBeUndefined() + }) + + it("consecutive tool+user in same batch → only last in batch gets breakpoint", () => { + // [tool, user, assistant] → one batch: tool(0), user(1); then assistant(2) + // Last in batch = user at idx 1 + const messages = [makeToolMsg(), makeUserMsg(), makeAssistantMsg()] + + applyCacheBreakpoints(messages, defaultConfig()) + + expect(getProviderOptions(messages[0])).toBeUndefined() + expect(getProviderOptions(messages[1])).toEqual(anthropicCache) + expect(getProviderOptions(messages[2])).toBeUndefined() + }) + + it("long conversation with anchor enabled → anchor at ~1/3", () => { + // Create 6 batches: [user, assistant] × 6, ending with user + // Pattern: user, assistant, user, assistant, user, assistant, user, assistant, user, assistant, user + // Batches: idx 0, 2, 4, 6, 8, 10 → batchLastIndices = [0, 2, 4, 6, 8, 10] + const messages: RooMessage[] = [] + for (let i = 0; i < 6; i++) { + messages.push(makeUserMsg(`q${i}`)) + if (i < 5) { + messages.push(makeAssistantMsg(`a${i}`)) + } + } + // 6 batches total (each single user msg is a batch) + // anchorBatchIdx = Math.floor(6 / 3) = 2 → batchLastIndices[2] = idx 4 + + applyCacheBreakpoints( + messages, + defaultConfig({ maxMessageBreakpoints: 2, useAnchor: true, anchorThreshold: 5 }), + ) + + // Last 2 batches: idx 10 (batch 5) and idx 8 (batch 4) + 
+		expect(getProviderOptions(messages[10])).toEqual(anthropicCache)
+		expect(getProviderOptions(messages[8])).toEqual(anthropicCache)
+		// Anchor at batch 2 → idx 4
+		expect(getProviderOptions(messages[4])).toEqual(anthropicCache)
+		// Other user messages should NOT have breakpoints
+		expect(getProviderOptions(messages[0])).toBeUndefined()
+		expect(getProviderOptions(messages[2])).toBeUndefined()
+		expect(getProviderOptions(messages[6])).toBeUndefined()
+	})
+
+	it("messages with existing providerOptions → preserved", () => {
+		const msg = makeUserMsg()
+		;(msg as RooMessage & { providerOptions?: Record<string, Record<string, unknown>> }).providerOptions = {
+			other: { key: "val" },
+		}
+		const messages = [msg]
+
+		applyCacheBreakpoints(messages, defaultConfig())
+
+		const opts = getProviderOptions(messages[0])
+		expect(opts).toEqual({
+			other: { key: "val" },
+			anthropic: { cacheControl: { type: "ephemeral" } },
+		})
+	})
+
+	it("maxMessageBreakpoints=3 (Bedrock config) → 3 breakpoints", () => {
+		// 4 batches: user, asst, user, asst, user, asst, user
+		const messages: RooMessage[] = []
+		for (let i = 0; i < 4; i++) {
+			messages.push(makeUserMsg(`q${i}`))
+			if (i < 3) {
+				messages.push(makeAssistantMsg(`a${i}`))
+			}
+		}
+		// batchLastIndices = [0, 2, 4, 6] → 4 batches
+		// maxMessageBreakpoints=3 → last 3: idx 6, 4, 2
+
+		applyCacheBreakpoints(messages, defaultConfig({ maxMessageBreakpoints: 3 }))
+
+		// Last 3 batch-end messages get breakpoints
+		expect(getProviderOptions(messages[6])).toEqual(anthropicCache)
+		expect(getProviderOptions(messages[4])).toEqual(anthropicCache)
+		expect(getProviderOptions(messages[2])).toEqual(anthropicCache)
+		// First batch (idx 0) should NOT
+		expect(getProviderOptions(messages[0])).toBeUndefined()
+	})
+})
diff --git a/src/api/transform/cache-breakpoints.ts b/src/api/transform/cache-breakpoints.ts
new file mode 100644
index 0000000000..c79df139a0
--- /dev/null
+++ b/src/api/transform/cache-breakpoints.ts
@@ -0,0 +1,65 @@
+import type { RooMessage } from "../../core/task-persistence/rooMessage"
+
+export interface CacheBreakpointConfig {
+	/** The providerOptions value to apply to targeted messages */
+	cacheProviderOption: Record<string, Record<string, unknown>>
+	/** Max number of message cache breakpoints (excluding system). Default: 2 */
+	maxMessageBreakpoints?: number
+	/** Add an anchor breakpoint in the middle for long conversations. Default: false */
+	useAnchor?: boolean
+	/** Min non-assistant batch count before anchor is added. Default: 5 */
+	anchorThreshold?: number
+}
+
+/**
+ * Apply cache breakpoints to RooMessage[].
+ *
+ * Targets the last message in each non-assistant batch (user/tool).
+ * A "batch" is a consecutive run of non-assistant messages.
+ * We only cache the last message per batch to avoid redundant breakpoints.
+ *
+ * Mutates messages in place by adding providerOptions.
+ */
+export function applyCacheBreakpoints(messages: RooMessage[], config: CacheBreakpointConfig): void {
+	const { cacheProviderOption, maxMessageBreakpoints = 2, useAnchor = false, anchorThreshold = 5 } = config
+
+	// Find the index of the last message in each non-assistant batch
+	const batchLastIndices: number[] = []
+	let inBatch = false
+
+	for (let i = 0; i < messages.length; i++) {
+		const msg = messages[i]
+		const isNonAssistant = "role" in msg && msg.role !== "assistant"
+
+		if (isNonAssistant) {
+			inBatch = true
+		} else if (inBatch) {
+			batchLastIndices.push(i - 1)
+			inBatch = false
+		}
+	}
+	if (inBatch) {
+		batchLastIndices.push(messages.length - 1)
+	}
+
+	// Select targets: last N batches + optional anchor
+	const targets = new Set<number>()
+	const numBatches = batchLastIndices.length
+
+	for (let j = 0; j < Math.min(maxMessageBreakpoints, numBatches); j++) {
+		targets.add(batchLastIndices[numBatches - 1 - j])
+	}
+
+	if (useAnchor && numBatches >= anchorThreshold) {
+		const anchorBatchIdx = Math.floor(numBatches / 3)
+		targets.add(batchLastIndices[anchorBatchIdx])
+	}
+
+	// Apply providerOptions to targeted messages
+	for (const idx of targets) {
+		const msg = messages[idx] as RooMessage & {
+			providerOptions?: Record<string, Record<string, unknown>>
+		}
+		msg.providerOptions = { ...msg.providerOptions, ...cacheProviderOption }
+	}
+}
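To see the helper's observable behavior outside the test suite, here is a minimal sketch (the message literals are simplified stand-ins, since the batch scan only reads `role`; the `as any[]` cast mirrors the providers' own `as RooMessage[]` casts):

```ts
import { applyCacheBreakpoints } from "./src/api/transform/cache-breakpoints"

// Simplified message stand-ins: the batch scan only inspects `role`.
const messages = [
	{ role: "user" },      // batch 1 ends at idx 0
	{ role: "assistant" },
	{ role: "tool" },      // batch 2 is a tool + user run...
	{ role: "user" },      // ...ending at idx 3
	{ role: "assistant" },
	{ role: "user" },      // batch 3 ends at idx 5
] as any[]

applyCacheBreakpoints(messages, {
	cacheProviderOption: { anthropic: { cacheControl: { type: "ephemeral" } } },
})

// Default maxMessageBreakpoints = 2 → breakpoints on idx 3 and idx 5 only.
for (const [i, m] of messages.entries()) {
	console.log(i, "providerOptions" in m ? m.providerOptions : undefined)
}
```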