diff --git a/src/api/providers/__tests__/bedrock.spec.ts b/src/api/providers/__tests__/bedrock.spec.ts
index 2cb09fc56db..51c019f08fd 100644
--- a/src/api/providers/__tests__/bedrock.spec.ts
+++ b/src/api/providers/__tests__/bedrock.spec.ts
@@ -1287,4 +1287,67 @@ describe("AwsBedrockHandler", () => {
 			expect(mockCaptureException).toHaveBeenCalled()
 		})
 	})
+
+	describe("prompt cache default behavior", () => {
+		function setupMockStreamText() {
+			async function* mockFullStream() {
+				yield { type: "text-delta", text: "response" }
+			}
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: Promise.resolve({ inputTokens: 10, outputTokens: 5 }),
+				providerMetadata: Promise.resolve({}),
+			})
+		}
+
+		const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hello" }]
+
+		it("should enable prompt caching by default when awsUsePromptCache is undefined", async () => {
+			setupMockStreamText()
+
+			const defaultHandler = new AwsBedrockHandler({
+				apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
+				awsAccessKey: "test-access-key",
+				awsSecretKey: "test-secret-key",
+				awsRegion: "us-east-1",
+				// awsUsePromptCache is intentionally omitted (undefined)
+			})
+
+			const generator = defaultHandler.createMessage("You are a helpful assistant", messages)
+			for await (const _chunk of generator) {
+				// consume the stream
+			}
+
+			expect(mockStreamText).toHaveBeenCalledTimes(1)
+			const callArgs = mockStreamText.mock.calls[0][0]
+
+			// systemProviderOptions should include cachePoint since prompt caching defaults to ON
+			expect(callArgs.systemProviderOptions).toEqual({
+				bedrock: { cachePoint: { type: "default" } },
+			})
+		})
+
+		it("should disable prompt caching when awsUsePromptCache is explicitly false", async () => {
+			setupMockStreamText()
+
+			const disabledHandler = new AwsBedrockHandler({
+				apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
+				awsAccessKey: "test-access-key",
+				awsSecretKey: "test-secret-key",
+				awsRegion: "us-east-1",
+				awsUsePromptCache: false,
+			})
+
+			const generator = disabledHandler.createMessage("You are a helpful assistant", messages)
+			for await (const _chunk of generator) {
+				// consume the stream
+			}
+
+			expect(mockStreamText).toHaveBeenCalledTimes(1)
+			const callArgs = mockStreamText.mock.calls[0][0]
+
+			// systemProviderOptions should NOT include cachePoint since caching is explicitly disabled
+			expect(callArgs.systemProviderOptions).toBeUndefined()
+		})
+	})
 })
diff --git a/src/api/providers/bedrock.ts b/src/api/providers/bedrock.ts
index 0bb5936c2ba..ca15d2578c3 100644
--- a/src/api/providers/bedrock.ts
+++ b/src/api/providers/bedrock.ts
@@ -271,7 +271,9 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
 		// We identify targets in the ORIGINAL Anthropic messages (before AI SDK conversion)
 		// because convertToAiSdkMessages() splits user messages containing tool_results into
 		// separate "tool" + "user" role messages, which would skew naive counting.
-		const usePromptCache = Boolean(this.options.awsUsePromptCache && this.supportsAwsPromptCache(modelConfig))
+		const usePromptCache = Boolean(
+			(this.options.awsUsePromptCache ?? true) && this.supportsAwsPromptCache(modelConfig),
+		)
 
 		if (usePromptCache) {
 			const cachePointOption = { bedrock: { cachePoint: { type: "default" as const } } }
diff --git a/webview-ui/src/components/settings/providers/Bedrock.tsx b/webview-ui/src/components/settings/providers/Bedrock.tsx
index d9c69f8a8e6..ed554f126d9 100644
--- a/webview-ui/src/components/settings/providers/Bedrock.tsx
+++ b/webview-ui/src/components/settings/providers/Bedrock.tsx
@@ -198,7 +198,7 @@ export const Bedrock = ({ apiConfiguration, setApiConfigurationField, selectedMo
 			{selectedModelInfo?.supportsPromptCache && (
 				<>
{t("settings:providers.enablePromptCaching")} diff --git a/webview-ui/src/components/ui/hooks/__tests__/useSelectedModel.spec.ts b/webview-ui/src/components/ui/hooks/__tests__/useSelectedModel.spec.ts index 8925adf5fda..c09bf215979 100644 --- a/webview-ui/src/components/ui/hooks/__tests__/useSelectedModel.spec.ts +++ b/webview-ui/src/components/ui/hooks/__tests__/useSelectedModel.spec.ts @@ -464,6 +464,52 @@ describe("useSelectedModel", () => { }) }) + describe("bedrock provider with custom ARN", () => { + beforeEach(() => { + mockUseRouterModels.mockReturnValue({ + data: { + openrouter: {}, + requesty: {}, + litellm: {}, + }, + isLoading: false, + isError: false, + } as any) + + mockUseOpenRouterModelProviders.mockReturnValue({ + data: {}, + isLoading: false, + isError: false, + } as any) + }) + + it("should enable supportsPromptCache for custom-arn model", () => { + const apiConfiguration: ProviderSettings = { + apiProvider: "bedrock", + apiModelId: "custom-arn", + } + + const wrapper = createWrapper() + const { result } = renderHook(() => useSelectedModel(apiConfiguration), { wrapper }) + + expect(result.current.id).toBe("custom-arn") + expect(result.current.info?.supportsPromptCache).toBe(true) + }) + + it("should enable supportsImages for custom-arn model", () => { + const apiConfiguration: ProviderSettings = { + apiProvider: "bedrock", + apiModelId: "custom-arn", + } + + const wrapper = createWrapper() + const { result } = renderHook(() => useSelectedModel(apiConfiguration), { wrapper }) + + expect(result.current.id).toBe("custom-arn") + expect(result.current.info?.supportsImages).toBe(true) + }) + }) + describe("litellm provider", () => { beforeEach(() => { mockUseOpenRouterModelProviders.mockReturnValue({ diff --git a/webview-ui/src/components/ui/hooks/useSelectedModel.ts b/webview-ui/src/components/ui/hooks/useSelectedModel.ts index fa28684e243..5f0feb1fc7f 100644 --- a/webview-ui/src/components/ui/hooks/useSelectedModel.ts +++ b/webview-ui/src/components/ui/hooks/useSelectedModel.ts @@ -183,7 +183,7 @@ function getSelectedModel({ if (id === "custom-arn") { return { id, - info: { maxTokens: 5000, contextWindow: 128_000, supportsPromptCache: false, supportsImages: true }, + info: { maxTokens: 5000, contextWindow: 128_000, supportsPromptCache: true, supportsImages: true }, } }