From 1a87e92359502dccd657ac16669c7603a2b5a7ad Mon Sep 17 00:00:00 2001 From: daniel-lxs Date: Mon, 9 Feb 2026 17:32:16 -0500 Subject: [PATCH] feat: migrate Requesty provider to AI SDK (@requesty/ai-sdk) Migrate the Requesty provider from the raw OpenAI client to the dedicated @requesty/ai-sdk package using the Vercel AI SDK pattern, consistent with other migrated providers (DeepSeek, xAI, Fireworks, etc.). Changes: - Replace OpenAI client with createRequesty from @requesty/ai-sdk - Use streamText/generateText from 'ai' package for streaming and completion - Use shared AI SDK utilities (convertToAiSdkMessages, processAiSdkStreamPart, convertToolsForAiSdk, mapToolChoice, handleAiSdkError) - Map reasoning effort and budget to Requesty's providerOptions.reasoningEffort - Pass trace_id and mode metadata via providerOptions.requesty.extraBody - Extract cache metrics from RequestyProviderMetadata via providerMetadata - Switch getModelParams format from 'anthropic' to 'openai' - Add isAiSdkProvider() returning true - Add type declarations for @requesty/ai-sdk (packaging workaround) - Update all tests to mock AI SDK instead of raw OpenAI client --- pnpm-lock.yaml | 19 +- src/api/providers/__tests__/requesty.spec.ts | 595 +++++++++++-------- src/api/providers/requesty.ts | 281 +++++---- src/package.json | 1 + src/types/requesty-ai-sdk.d.ts | 58 ++ 5 files changed, 588 insertions(+), 366 deletions(-) create mode 100644 src/types/requesty-ai-sdk.d.ts diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 304c654ef75..7e643479a6f 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -803,6 +803,9 @@ importers: '@qdrant/js-client-rest': specifier: ^1.14.0 version: 1.14.0(typescript@5.8.3) + '@requesty/ai-sdk': + specifier: ^3.0.0 + version: 3.0.0(vite@6.3.5(@types/node@20.17.50)(jiti@2.4.2)(lightningcss@1.30.1)(tsx@4.19.4)(yaml@2.8.0))(zod@3.25.76) '@roo-code/cloud': specifier: workspace:^ version: link:../packages/cloud @@ -3774,6 +3777,13 @@ packages: peerDependencies: '@redis/client': ^5.5.5 + '@requesty/ai-sdk@3.0.0': + resolution: {integrity: sha512-kGZuOshBx5yjQ3LQKITl/W89Xu8HrektZ6pN0peI64sQD8fpoWionUU+39m0fkdZhOlotebj5zO3emkX7y7O6Q==} + engines: {node: '>=18'} + peerDependencies: + vite: ^7.1.7 + zod: 3.25.76 + '@resvg/resvg-wasm@2.4.0': resolution: {integrity: sha512-C7c51Nn4yTxXFKvgh2txJFNweaVcfUPQxwEUFw4aWsCmfiBDJsTSwviIF8EcwjQ6k8bPyMWCl1vw4BdxE569Cg==} engines: {node: '>= 10'} @@ -13808,6 +13818,13 @@ snapshots: dependencies: '@redis/client': 5.5.5 + '@requesty/ai-sdk@3.0.0(vite@6.3.5(@types/node@20.17.50)(jiti@2.4.2)(lightningcss@1.30.1)(tsx@4.19.4)(yaml@2.8.0))(zod@3.25.76)': + dependencies: + '@ai-sdk/provider': 3.0.8 + '@ai-sdk/provider-utils': 3.0.20(zod@3.25.76) + vite: 6.3.5(@types/node@20.17.50)(jiti@2.4.2)(lightningcss@1.30.1)(tsx@4.19.4)(yaml@2.8.0) + zod: 3.25.76 + '@resvg/resvg-wasm@2.4.0': {} '@rollup/rollup-android-arm-eabi@4.40.2': @@ -14953,7 +14970,7 @@ snapshots: sirv: 3.0.1 tinyglobby: 0.2.14 tinyrainbow: 2.0.0 - vitest: 3.2.4(@types/debug@4.1.12)(@types/node@24.2.1)(@vitest/ui@3.2.4)(jiti@2.4.2)(jsdom@26.1.0)(lightningcss@1.30.1)(tsx@4.19.4)(yaml@2.8.0) + vitest: 3.2.4(@types/debug@4.1.12)(@types/node@20.17.50)(@vitest/ui@3.2.4)(jiti@2.4.2)(jsdom@26.1.0)(lightningcss@1.30.1)(tsx@4.19.4)(yaml@2.8.0) '@vitest/utils@3.2.4': dependencies: diff --git a/src/api/providers/__tests__/requesty.spec.ts b/src/api/providers/__tests__/requesty.spec.ts index ea6a36b4b44..3be0b4451e1 100644 --- a/src/api/providers/__tests__/requesty.spec.ts +++ b/src/api/providers/__tests__/requesty.spec.ts @@ 
-1,31 +1,33 @@ // npx vitest run api/providers/__tests__/requesty.spec.ts -import { Anthropic } from "@anthropic-ai/sdk" -import OpenAI from "openai" - -import { RequestyHandler } from "../requesty" -import { ApiHandlerOptions } from "../../../shared/api" -import { Package } from "../../../shared/package" -import { ApiHandlerCreateMessageMetadata } from "../../index" - -const mockCreate = vitest.fn() +// Use vi.hoisted to define mock functions that can be referenced in hoisted vi.mock() calls +const { mockStreamText, mockGenerateText } = vi.hoisted(() => ({ + mockStreamText: vi.fn(), + mockGenerateText: vi.fn(), +})) -vitest.mock("openai", () => { +vi.mock("ai", async (importOriginal) => { + const actual = await importOriginal() return { - default: vitest.fn().mockImplementation(() => ({ - chat: { - completions: { - create: mockCreate, - }, - }, - })), + ...actual, + streamText: mockStreamText, + generateText: mockGenerateText, } }) -vitest.mock("delay", () => ({ default: vitest.fn(() => Promise.resolve()) })) +vi.mock("@requesty/ai-sdk", () => ({ + createRequesty: vi.fn(() => { + return vi.fn(() => ({ + modelId: "coding/claude-4-sonnet", + provider: "requesty", + })) + }), +})) + +vi.mock("delay", () => ({ default: vi.fn(() => Promise.resolve()) })) -vitest.mock("../fetchers/modelCache", () => ({ - getModels: vitest.fn().mockImplementation(() => { +vi.mock("../fetchers/modelCache", () => ({ + getModels: vi.fn().mockImplementation(() => { return Promise.resolve({ "coding/claude-4-sonnet": { maxTokens: 8192, @@ -42,41 +44,62 @@ vitest.mock("../fetchers/modelCache", () => ({ }), })) +import type { Anthropic } from "@anthropic-ai/sdk" +import { createRequesty } from "@requesty/ai-sdk" + +import { requestyDefaultModelId } from "@roo-code/types" + +import type { ApiHandlerOptions } from "../../../shared/api" +import type { ApiHandlerCreateMessageMetadata } from "../../index" + +import { RequestyHandler } from "../requesty" + describe("RequestyHandler", () => { const mockOptions: ApiHandlerOptions = { requestyApiKey: "test-key", requestyModelId: "coding/claude-4-sonnet", } - beforeEach(() => vitest.clearAllMocks()) + beforeEach(() => vi.clearAllMocks()) - it("initializes with correct options", () => { - const handler = new RequestyHandler(mockOptions) - expect(handler).toBeInstanceOf(RequestyHandler) + describe("constructor", () => { + it("initializes with correct options", () => { + const handler = new RequestyHandler(mockOptions) + expect(handler).toBeInstanceOf(RequestyHandler) - expect(OpenAI).toHaveBeenCalledWith({ - baseURL: "https://router.requesty.ai/v1", - apiKey: mockOptions.requestyApiKey, - defaultHeaders: { - "HTTP-Referer": "https://github.com/RooVetGit/Roo-Cline", - "X-Title": "Roo Code", - "User-Agent": `RooCode/${Package.version}`, - }, + expect(createRequesty).toHaveBeenCalledWith( + expect.objectContaining({ + baseURL: "https://router.requesty.ai/v1", + apiKey: mockOptions.requestyApiKey, + compatibility: "compatible", + }), + ) }) - }) - it("can use a base URL instead of the default", () => { - const handler = new RequestyHandler({ ...mockOptions, requestyBaseUrl: "https://custom.requesty.ai/v1" }) - expect(handler).toBeInstanceOf(RequestyHandler) - - expect(OpenAI).toHaveBeenCalledWith({ - baseURL: "https://custom.requesty.ai/v1", - apiKey: mockOptions.requestyApiKey, - defaultHeaders: { - "HTTP-Referer": "https://github.com/RooVetGit/Roo-Cline", - "X-Title": "Roo Code", - "User-Agent": `RooCode/${Package.version}`, - }, + it("can use a custom base URL", () => { + const 
handler = new RequestyHandler({ + ...mockOptions, + requestyBaseUrl: "https://custom.requesty.ai/v1", + }) + expect(handler).toBeInstanceOf(RequestyHandler) + + expect(createRequesty).toHaveBeenCalledWith( + expect.objectContaining({ + baseURL: "https://custom.requesty.ai/v1", + apiKey: mockOptions.requestyApiKey, + }), + ) + }) + + it("uses 'not-provided' when no API key is given", () => { + const handler = new RequestyHandler({}) + expect(handler).toBeInstanceOf(RequestyHandler) + + expect(createRequesty).toHaveBeenCalledWith( + expect.objectContaining({ + apiKey: "not-provided", + }), + ) }) }) @@ -101,66 +124,50 @@ describe("RequestyHandler", () => { }) }) - it("returns default model info when options are not provided", async () => { - const handler = new RequestyHandler({}) + it("returns default model info when requestyModelId is not provided", async () => { + const handler = new RequestyHandler({ requestyApiKey: "test-key" }) const result = await handler.fetchModel() - expect(result).toMatchObject({ - id: mockOptions.requestyModelId, - info: { - maxTokens: 8192, - contextWindow: 200000, - supportsImages: true, - supportsPromptCache: true, - inputPrice: 3, - outputPrice: 15, - cacheWritesPrice: 3.75, - cacheReadsPrice: 0.3, - description: "Claude 4 Sonnet", - }, - }) + expect(result.id).toBe(requestyDefaultModelId) }) }) describe("createMessage", () => { - it("generates correct stream chunks", async () => { - const handler = new RequestyHandler(mockOptions) + const systemPrompt = "test system prompt" + const messages: Anthropic.Messages.MessageParam[] = [{ role: "user" as const, content: "test message" }] - const mockStream = { - async *[Symbol.asyncIterator]() { - yield { - id: mockOptions.requestyModelId, - choices: [{ delta: { content: "test response" } }], - } - yield { - id: "test-id", - choices: [{ delta: {} }], - usage: { - prompt_tokens: 10, - completion_tokens: 20, - prompt_tokens_details: { - caching_tokens: 5, - cached_tokens: 2, - }, - }, - } - }, + it("generates correct stream chunks", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "test response" } } - mockCreate.mockResolvedValue(mockStream) + const mockUsage = Promise.resolve({ + inputTokens: 10, + outputTokens: 20, + }) - const systemPrompt = "test system prompt" - const messages: Anthropic.Messages.MessageParam[] = [{ role: "user" as const, content: "test message" }] + const mockProviderMetadata = Promise.resolve({ + requesty: { + usage: { + cachingTokens: 5, + cachedTokens: 2, + }, + }, + }) - const generator = handler.createMessage(systemPrompt, messages) - const chunks = [] + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + providerMetadata: mockProviderMetadata, + }) - for await (const chunk of generator) { + const handler = new RequestyHandler(mockOptions) + const chunks: any[] = [] + for await (const chunk of handler.createMessage(systemPrompt, messages)) { chunks.push(chunk) } - // Verify stream chunks - expect(chunks).toHaveLength(2) // One text chunk and one usage chunk + expect(chunks).toHaveLength(2) expect(chunks[0]).toEqual({ type: "text", text: "test response" }) expect(chunks[1]).toEqual({ type: "usage", @@ -168,181 +175,199 @@ describe("RequestyHandler", () => { outputTokens: 20, cacheWriteTokens: 5, cacheReadTokens: 2, + reasoningTokens: undefined, totalCost: expect.any(Number), }) + }) + + it("calls streamText with correct parameters", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: 
"test response" } + } + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 10, outputTokens: 5 }), + providerMetadata: Promise.resolve({}), + }) + + const handler = new RequestyHandler(mockOptions) + const stream = handler.createMessage(systemPrompt, messages) + for await (const _chunk of stream) { + // consume + } - // Verify OpenAI client was called with correct parameters - expect(mockCreate).toHaveBeenCalledWith( + expect(mockStreamText).toHaveBeenCalledWith( expect.objectContaining({ - max_tokens: 8192, - messages: [ - { - role: "system", - content: "test system prompt", - }, - { - role: "user", - content: "test message", - }, - ], - model: "coding/claude-4-sonnet", - stream: true, - stream_options: { include_usage: true }, + system: "test system prompt", temperature: 0, + maxOutputTokens: 8192, }), ) }) - it("handles API errors", async () => { + it("passes trace_id and mode via providerOptions", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "test response" } + } + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 10, outputTokens: 5 }), + providerMetadata: Promise.resolve({}), + }) + + const metadata: ApiHandlerCreateMessageMetadata = { + taskId: "test-task", + mode: "code", + } + const handler = new RequestyHandler(mockOptions) + const stream = handler.createMessage(systemPrompt, messages, metadata) + for await (const _chunk of stream) { + // consume + } + + expect(mockStreamText).toHaveBeenCalledWith( + expect.objectContaining({ + providerOptions: { + requesty: expect.objectContaining({ + extraBody: { + requesty: { + trace_id: "test-task", + extra: { mode: "code" }, + }, + }, + }), + }, + }), + ) + }) + + it("handles API errors", async () => { const mockError = new Error("API Error") - mockCreate.mockRejectedValue(mockError) - const generator = handler.createMessage("test", []) - await expect(generator.next()).rejects.toThrow("API Error") + async function* errorStream() { + yield { type: "text-delta", text: "" } + throw mockError + } + + mockStreamText.mockReturnValue({ + fullStream: errorStream(), + usage: Promise.resolve({}), + providerMetadata: Promise.resolve({}), + }) + + const handler = new RequestyHandler(mockOptions) + const generator = handler.createMessage(systemPrompt, messages) + await generator.next() + await expect(generator.next()).rejects.toThrow() }) describe("native tool support", () => { - const systemPrompt = "test system prompt" - const messages: Anthropic.Messages.MessageParam[] = [ + const toolMessages: Anthropic.Messages.MessageParam[] = [ { role: "user" as const, content: "What's the weather?" 
}, ] - const mockTools: OpenAI.Chat.ChatCompletionTool[] = [ - { - type: "function", - function: { - name: "get_weather", - description: "Get the current weather", - parameters: { - type: "object", - properties: { - location: { type: "string" }, + it("should include tools in request when tools are provided", async () => { + const mockTools = [ + { + type: "function", + function: { + name: "get_weather", + description: "Get the current weather", + parameters: { + type: "object", + properties: { + location: { type: "string" }, + }, + required: ["location"], }, - required: ["location"], }, }, - }, - ] + ] - beforeEach(() => { - const mockStream = { - async *[Symbol.asyncIterator]() { - yield { - id: "test-id", - choices: [{ delta: { content: "test response" } }], - } - }, + async function* mockFullStream() { + yield { type: "text-delta", text: "test response" } } - mockCreate.mockResolvedValue(mockStream) - }) - it("should include tools in request when tools are provided", async () => { + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 10, outputTokens: 5 }), + providerMetadata: Promise.resolve({}), + }) + const metadata: ApiHandlerCreateMessageMetadata = { taskId: "test-task", - tools: mockTools, + tools: mockTools as any, tool_choice: "auto", } const handler = new RequestyHandler(mockOptions) - const iterator = handler.createMessage(systemPrompt, messages, metadata) - await iterator.next() + const stream = handler.createMessage(systemPrompt, toolMessages, metadata) + for await (const _chunk of stream) { + // consume + } - expect(mockCreate).toHaveBeenCalledWith( + expect(mockStreamText).toHaveBeenCalledWith( expect.objectContaining({ - tools: expect.arrayContaining([ - expect.objectContaining({ - type: "function", - function: expect.objectContaining({ - name: "get_weather", - description: "Get the current weather", - }), - }), - ]), - tool_choice: "auto", + tools: expect.any(Object), + toolChoice: expect.any(String), }), ) }) - it("should handle tool_call_partial chunks in streaming response", async () => { - const mockStreamWithToolCalls = { - async *[Symbol.asyncIterator]() { - yield { - id: "test-id", - choices: [ - { - delta: { - tool_calls: [ - { - index: 0, - id: "call_123", - function: { - name: "get_weather", - arguments: '{"location":', - }, - }, - ], - }, - }, - ], - } - yield { - id: "test-id", - choices: [ - { - delta: { - tool_calls: [ - { - index: 0, - function: { - arguments: '"New York"}', - }, - }, - ], - }, - }, - ], - } - yield { - id: "test-id", - choices: [{ delta: {} }], - usage: { prompt_tokens: 10, completion_tokens: 20 }, - } - }, + it("should handle tool call streaming parts", async () => { + async function* mockFullStream() { + yield { + type: "tool-input-start", + id: "call_123", + toolName: "get_weather", + } + yield { + type: "tool-input-delta", + id: "call_123", + delta: '{"location":"New York"}', + } + yield { + type: "tool-input-end", + id: "call_123", + } } - mockCreate.mockResolvedValue(mockStreamWithToolCalls) - const metadata: ApiHandlerCreateMessageMetadata = { - taskId: "test-task", - tools: mockTools, - } + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 10, outputTokens: 5 }), + providerMetadata: Promise.resolve({}), + }) const handler = new RequestyHandler(mockOptions) - const chunks = [] - for await (const chunk of handler.createMessage(systemPrompt, messages, metadata)) { + const chunks: any[] = [] + for await (const chunk of 
handler.createMessage(systemPrompt, toolMessages)) { chunks.push(chunk) } - // Expect two tool_call_partial chunks and one usage chunk - expect(chunks).toHaveLength(3) - expect(chunks[0]).toEqual({ - type: "tool_call_partial", - index: 0, + const startChunks = chunks.filter((c) => c.type === "tool_call_start") + expect(startChunks).toHaveLength(1) + expect(startChunks[0]).toEqual({ + type: "tool_call_start", id: "call_123", name: "get_weather", - arguments: '{"location":', }) - expect(chunks[1]).toEqual({ - type: "tool_call_partial", - index: 0, - id: undefined, - name: undefined, - arguments: '"New York"}', + + const deltaChunks = chunks.filter((c) => c.type === "tool_call_delta") + expect(deltaChunks).toHaveLength(1) + expect(deltaChunks[0]).toEqual({ + type: "tool_call_delta", + id: "call_123", + delta: '{"location":"New York"}', }) - expect(chunks[2]).toMatchObject({ - type: "usage", - inputTokens: 10, - outputTokens: 20, + + const endChunks = chunks.filter((c) => c.type === "tool_call_end") + expect(endChunks).toHaveLength(1) + expect(endChunks[0]).toEqual({ + type: "tool_call_end", + id: "call_123", }) }) }) @@ -350,36 +375,134 @@ describe("RequestyHandler", () => { describe("completePrompt", () => { it("returns correct response", async () => { - const handler = new RequestyHandler(mockOptions) - const mockResponse = { choices: [{ message: { content: "test completion" } }] } - - mockCreate.mockResolvedValue(mockResponse) + mockGenerateText.mockResolvedValue({ + text: "test completion", + }) + const handler = new RequestyHandler(mockOptions) const result = await handler.completePrompt("test prompt") expect(result).toBe("test completion") - - expect(mockCreate).toHaveBeenCalledWith({ - model: mockOptions.requestyModelId, - max_tokens: 8192, - messages: [{ role: "system", content: "test prompt" }], - temperature: 0, - }) + expect(mockGenerateText).toHaveBeenCalledWith( + expect.objectContaining({ + prompt: "test prompt", + maxOutputTokens: 8192, + temperature: 0, + }), + ) }) it("handles API errors", async () => { - const handler = new RequestyHandler(mockOptions) const mockError = new Error("API Error") - mockCreate.mockRejectedValue(mockError) + mockGenerateText.mockRejectedValue(mockError) - await expect(handler.completePrompt("test prompt")).rejects.toThrow("API Error") + const handler = new RequestyHandler(mockOptions) + await expect(handler.completePrompt("test prompt")).rejects.toThrow() }) + }) - it("handles unexpected errors", async () => { - const handler = new RequestyHandler(mockOptions) - mockCreate.mockRejectedValue(new Error("Unexpected error")) + describe("processUsageMetrics", () => { + it("should correctly process usage metrics with Requesty provider metadata", () => { + class TestRequestyHandler extends RequestyHandler { + public testProcessUsageMetrics(usage: any, modelInfo?: any, providerMetadata?: any) { + return this.processUsageMetrics(usage, modelInfo, providerMetadata) + } + } + + const testHandler = new TestRequestyHandler(mockOptions) + + const usage = { + inputTokens: 100, + outputTokens: 50, + } + + const providerMetadata = { + requesty: { + usage: { + cachingTokens: 10, + cachedTokens: 20, + }, + }, + } + + const modelInfo = { + maxTokens: 8192, + contextWindow: 200000, + supportsImages: true, + supportsPromptCache: true, + inputPrice: 3, + outputPrice: 15, + cacheWritesPrice: 3.75, + cacheReadsPrice: 0.3, + } + + const result = testHandler.testProcessUsageMetrics(usage, modelInfo, providerMetadata) - await expect(handler.completePrompt("test 
prompt")).rejects.toThrow("Unexpected error") + expect(result.type).toBe("usage") + expect(result.inputTokens).toBe(100) + expect(result.outputTokens).toBe(50) + expect(result.cacheWriteTokens).toBe(10) + expect(result.cacheReadTokens).toBe(20) + expect(result.totalCost).toBeGreaterThan(0) + }) + + it("should fall back to usage.details when providerMetadata is absent", () => { + class TestRequestyHandler extends RequestyHandler { + public testProcessUsageMetrics(usage: any, modelInfo?: any, providerMetadata?: any) { + return this.processUsageMetrics(usage, modelInfo, providerMetadata) + } + } + + const testHandler = new TestRequestyHandler(mockOptions) + + const usage = { + inputTokens: 100, + outputTokens: 50, + details: { + cachedInputTokens: 15, + reasoningTokens: 25, + }, + } + + const result = testHandler.testProcessUsageMetrics(usage) + + expect(result.type).toBe("usage") + expect(result.inputTokens).toBe(100) + expect(result.outputTokens).toBe(50) + expect(result.cacheWriteTokens).toBe(0) + expect(result.cacheReadTokens).toBe(15) + expect(result.reasoningTokens).toBe(25) + }) + + it("should handle missing cache metrics gracefully", () => { + class TestRequestyHandler extends RequestyHandler { + public testProcessUsageMetrics(usage: any, modelInfo?: any, providerMetadata?: any) { + return this.processUsageMetrics(usage, modelInfo, providerMetadata) + } + } + + const testHandler = new TestRequestyHandler(mockOptions) + + const usage = { + inputTokens: 100, + outputTokens: 50, + } + + const result = testHandler.testProcessUsageMetrics(usage) + + expect(result.type).toBe("usage") + expect(result.inputTokens).toBe(100) + expect(result.outputTokens).toBe(50) + expect(result.cacheWriteTokens).toBe(0) + expect(result.cacheReadTokens).toBe(0) + expect(result.totalCost).toBe(0) + }) + }) + + describe("isAiSdkProvider", () => { + it("returns true", () => { + const handler = new RequestyHandler(mockOptions) + expect(handler.isAiSdkProvider()).toBe(true) }) }) }) diff --git a/src/api/providers/requesty.ts b/src/api/providers/requesty.ts index b241c347b08..630c71345eb 100644 --- a/src/api/providers/requesty.ts +++ b/src/api/providers/requesty.ts @@ -1,60 +1,39 @@ import { Anthropic } from "@anthropic-ai/sdk" -import OpenAI from "openai" +import { createRequesty, type RequestyProviderMetadata } from "@requesty/ai-sdk" +import { streamText, generateText, ToolSet } from "ai" import { type ModelInfo, type ModelRecord, requestyDefaultModelId, requestyDefaultModelInfo } from "@roo-code/types" import type { ApiHandlerOptions } from "../../shared/api" import { calculateApiCostOpenAI } from "../../shared/cost" -import { convertToOpenAiMessages } from "../transform/openai-format" +import { + convertToAiSdkMessages, + convertToolsForAiSdk, + processAiSdkStreamPart, + mapToolChoice, + handleAiSdkError, +} from "../transform/ai-sdk" import { ApiStream, ApiStreamUsageChunk } from "../transform/stream" import { getModelParams } from "../transform/model-params" -import { AnthropicReasoningParams } from "../transform/reasoning" import { DEFAULT_HEADERS } from "./constants" import { getModels } from "./fetchers/modelCache" import { BaseProvider } from "./base-provider" import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" import { toRequestyServiceUrl } from "../../shared/utils/requesty" -import { handleOpenAIError } from "./utils/openai-error-handler" import { applyRouterToolPreferences } from "./utils/router-tool-preferences" -// Requesty usage includes an extra field for 
Anthropic use cases. -// Safely cast the prompt token details section to the appropriate structure. -interface RequestyUsage extends OpenAI.CompletionUsage { - prompt_tokens_details?: { - caching_tokens?: number - cached_tokens?: number - } - total_cost?: number -} - -type RequestyChatCompletionParamsStreaming = OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming & { - requesty?: { - trace_id?: string - extra?: { - mode?: string - } - } - thinking?: AnthropicReasoningParams -} - -type RequestyChatCompletionParams = OpenAI.Chat.ChatCompletionCreateParams & { - requesty?: { - trace_id?: string - extra?: { - mode?: string - } - } - thinking?: AnthropicReasoningParams -} - +/** + * Requesty provider using the dedicated @requesty/ai-sdk package. + * Requesty is a unified LLM gateway providing access to 300+ models. + * This handler uses the Vercel AI SDK for streaming and tool support. + */ export class RequestyHandler extends BaseProvider implements SingleCompletionHandler { protected options: ApiHandlerOptions protected models: ModelRecord = {} - private client: OpenAI + protected provider: ReturnType<typeof createRequesty> private baseURL: string - private readonly providerName = "Requesty" constructor(options: ApiHandlerOptions) { super() @@ -64,10 +43,11 @@ export class RequestyHandler extends BaseProvider implements SingleCompletionHan const apiKey = this.options.requestyApiKey ?? "not-provided" - this.client = new OpenAI({ + this.provider = createRequesty({ baseURL: this.baseURL, apiKey: apiKey, - defaultHeaders: DEFAULT_HEADERS, + headers: DEFAULT_HEADERS, + compatibility: "compatible", }) } @@ -81,11 +61,10 @@ export class RequestyHandler extends BaseProvider implements SingleCompletionHan const cachedInfo = this.models[id] ?? requestyDefaultModelInfo let info: ModelInfo = cachedInfo - // Apply tool preferences for models accessed through routers (OpenAI, Gemini) info = applyRouterToolPreferences(id, info) const params = getModelParams({ - format: "anthropic", + format: "openai", modelId: id, model: info, settings: this.options, @@ -95,125 +74,169 @@ export class RequestyHandler extends BaseProvider implements SingleCompletionHan return { id, info, ...params } } - protected processUsageMetrics(usage: any, modelInfo?: ModelInfo): ApiStreamUsageChunk { - const requestyUsage = usage as RequestyUsage - const inputTokens = requestyUsage?.prompt_tokens || 0 - const outputTokens = requestyUsage?.completion_tokens || 0 - const cacheWriteTokens = requestyUsage?.prompt_tokens_details?.caching_tokens || 0 - const cacheReadTokens = requestyUsage?.prompt_tokens_details?.cached_tokens || 0 + /** + * Get the language model for the configured model ID, including reasoning settings. + * + * Reasoning settings (includeReasoning, reasoningEffort) must be passed as model + * settings — NOT as providerOptions — because the SDK reads them from this.settings + * to populate the top-level `include_reasoning` and `reasoning_effort` request fields. + */ + protected getLanguageModel() { + const { id, reasoningEffort, reasoningBudget } = this.getModel() + + let resolvedEffort: string | undefined + if (reasoningBudget) { + resolvedEffort = String(reasoningBudget) + } else if (reasoningEffort) { + resolvedEffort = reasoningEffort + } + + const needsReasoning = !!resolvedEffort + + return this.provider(id, { + ...(needsReasoning ? { includeReasoning: true, reasoningEffort: resolvedEffort } : {}), + }) + } + + /** + * Get the max output tokens parameter to include in the request. 
+ */ + protected getMaxOutputTokens(): number | undefined { + const { info } = this.getModel() + return this.options.modelMaxTokens || info.maxTokens || undefined + } + + /** + * Build the Requesty provider options for tracing metadata. + * + * Note: providerOptions.requesty gets placed into body.requesty (the Requesty + * metadata field), NOT into top-level body fields. Only tracing/metadata should + * go here — reasoning settings go through model settings in getLanguageModel(). + */ + private getRequestyProviderOptions(metadata?: ApiHandlerCreateMessageMetadata) { + if (!metadata?.taskId && !metadata?.mode) { + return undefined + } + + return { + extraBody: { + requesty: { + trace_id: metadata?.taskId ?? null, + extra: { mode: metadata?.mode ?? null }, + }, + }, + } + } + + /** + * Process usage metrics from the AI SDK response, including Requesty's cache metrics. + * + * Requesty provides cache hit/miss info via providerMetadata, but only when both + * cachingTokens and cachedTokens are non-zero (SDK limitation). We fall back to + * usage.details.cachedInputTokens when providerMetadata is empty. + */ + protected processUsageMetrics( + usage: { + inputTokens?: number + outputTokens?: number + details?: { + cachedInputTokens?: number + reasoningTokens?: number + } + }, + modelInfo?: ModelInfo, + providerMetadata?: RequestyProviderMetadata, + ): ApiStreamUsageChunk { + const inputTokens = usage.inputTokens || 0 + const outputTokens = usage.outputTokens || 0 + const cacheWriteTokens = providerMetadata?.requesty?.usage?.cachingTokens ?? 0 + const cacheReadTokens = providerMetadata?.requesty?.usage?.cachedTokens ?? usage.details?.cachedInputTokens ?? 0 + const { totalCost } = modelInfo ? calculateApiCostOpenAI(modelInfo, inputTokens, outputTokens, cacheWriteTokens, cacheReadTokens) : { totalCost: 0 } return { type: "usage", - inputTokens: inputTokens, - outputTokens: outputTokens, - cacheWriteTokens: cacheWriteTokens, - cacheReadTokens: cacheReadTokens, - totalCost: totalCost, + inputTokens, + outputTokens, + cacheWriteTokens, + cacheReadTokens, + reasoningTokens: usage.details?.reasoningTokens, + totalCost, } } + /** + * Create a message stream using the AI SDK. + */ override async *createMessage( systemPrompt: string, messages: Anthropic.Messages.MessageParam[], metadata?: ApiHandlerCreateMessageMetadata, ): ApiStream { - const { - id: model, - info, - maxTokens: max_tokens, - temperature, - reasoningEffort: reasoning_effort, - reasoning: thinking, - } = await this.fetchModel() - - const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [ - { role: "system", content: systemPrompt }, - ...convertToOpenAiMessages(messages), - ] - - // Map extended efforts to OpenAI Chat Completions-accepted values (omit unsupported) - const allowedEffort = (["low", "medium", "high"] as const).includes(reasoning_effort as any) - ? 
(reasoning_effort as OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming["reasoning_effort"]) - : undefined - - const completionParams: RequestyChatCompletionParamsStreaming = { - messages: openAiMessages, - model, - max_tokens, - temperature, - ...(allowedEffort && { reasoning_effort: allowedEffort }), - ...(thinking && { thinking }), - stream: true, - stream_options: { include_usage: true }, - requesty: { trace_id: metadata?.taskId, extra: { mode: metadata?.mode } }, - tools: this.convertToolsForOpenAI(metadata?.tools), - tool_choice: metadata?.tool_choice, - } + const { info, temperature } = await this.fetchModel() + const languageModel = this.getLanguageModel() - let stream - try { - // With streaming params type, SDK returns an async iterable stream - stream = await this.client.chat.completions.create(completionParams) - } catch (error) { - throw handleOpenAIError(error, this.providerName) - } - let lastUsage: any = undefined + const aiSdkMessages = convertToAiSdkMessages(messages) - for await (const chunk of stream) { - const delta = chunk.choices[0]?.delta + const openAiTools = this.convertToolsForOpenAI(metadata?.tools) + const aiSdkTools = convertToolsForAiSdk(openAiTools) as ToolSet | undefined - if (delta?.content) { - yield { type: "text", text: delta.content } - } + const requestyOptions = this.getRequestyProviderOptions(metadata) - if (delta && "reasoning_content" in delta && delta.reasoning_content) { - yield { type: "reasoning", text: (delta.reasoning_content as string | undefined) || "" } - } + const requestOptions: Parameters<typeof streamText>[0] = { + model: languageModel, + system: systemPrompt, + messages: aiSdkMessages, + temperature: this.options.modelTemperature ?? temperature ?? 0, + maxOutputTokens: this.getMaxOutputTokens(), + tools: aiSdkTools, + toolChoice: mapToolChoice(metadata?.tool_choice), + ...(requestyOptions ? { providerOptions: { requesty: requestyOptions } } : {}), + } + + const result = streamText(requestOptions) - // Handle native tool calls - if (delta && "tool_calls" in delta && Array.isArray(delta.tool_calls)) { - for (const toolCall of delta.tool_calls) { - yield { - type: "tool_call_partial", - index: toolCall.index, - id: toolCall.id, - name: toolCall.function?.name, - arguments: toolCall.function?.arguments, - } + try { + for await (const part of result.fullStream) { + for (const chunk of processAiSdkStreamPart(part)) { + yield chunk } } - if (chunk.usage) { - lastUsage = chunk.usage + const usage = await result.usage + const providerMetadata = await result.providerMetadata + if (usage) { + yield this.processUsageMetrics(usage, info, providerMetadata as RequestyProviderMetadata) } - } - - if (lastUsage) { - yield this.processUsageMetrics(lastUsage, info) + } catch (error) { + throw handleAiSdkError(error, "Requesty") } } + /** + * Complete a prompt using the AI SDK generateText. 
+ */ async completePrompt(prompt: string): Promise<string> { - const { id: model, maxTokens: max_tokens, temperature } = await this.fetchModel() + const { temperature } = await this.fetchModel() + const languageModel = this.getLanguageModel() - let openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [{ role: "system", content: prompt }] - - const completionParams: RequestyChatCompletionParams = { - model, - max_tokens, - messages: openAiMessages, - temperature: temperature, - } - - let response: OpenAI.Chat.ChatCompletion try { - response = await this.client.chat.completions.create(completionParams) + const { text } = await generateText({ + model: languageModel, + prompt, + maxOutputTokens: this.getMaxOutputTokens(), + temperature: this.options.modelTemperature ?? temperature ?? 0, + }) + + return text } catch (error) { - throw handleOpenAIError(error, this.providerName) + throw handleAiSdkError(error, "Requesty") } - return response.choices[0]?.message.content || "" + } + + override isAiSdkProvider(): boolean { + return true } } diff --git a/src/package.json b/src/package.json index 72fb3b5df99..213e864a5da 100644 --- a/src/package.json +++ b/src/package.json @@ -469,6 +469,7 @@ "@mistralai/mistralai": "^1.9.18", "@modelcontextprotocol/sdk": "1.12.0", "@qdrant/js-client-rest": "^1.14.0", + "@requesty/ai-sdk": "^3.0.0", "@roo-code/cloud": "workspace:^", "@roo-code/core": "workspace:^", "@roo-code/ipc": "workspace:^", diff --git a/src/types/requesty-ai-sdk.d.ts b/src/types/requesty-ai-sdk.d.ts new file mode 100644 index 00000000000..d43c1d01715 --- /dev/null +++ b/src/types/requesty-ai-sdk.d.ts @@ -0,0 +1,58 @@ +declare module "@requesty/ai-sdk" { + import type { LanguageModelV2 } from "@ai-sdk/provider" + + type RequestyLanguageModel = LanguageModelV2 + + interface RequestyProviderMetadata { + requesty?: { + usage?: { + cachingTokens?: number + cachedTokens?: number + } + } + } + + type RequestyChatModelId = string + + type RequestyChatSettings = { + logitBias?: Record<number, number> + logprobs?: boolean | number + parallelToolCalls?: boolean + user?: string + includeReasoning?: boolean + reasoningEffort?: "low" | "medium" | "high" | "max" | string + extraBody?: Record<string, unknown> + models?: string[] + } + + interface RequestyProvider { + (modelId: RequestyChatModelId, settings?: RequestyChatSettings): RequestyLanguageModel + languageModel(modelId: RequestyChatModelId, settings?: RequestyChatSettings): RequestyLanguageModel + chat(modelId: RequestyChatModelId, settings?: RequestyChatSettings): RequestyLanguageModel + } + + interface RequestyProviderSettings { + baseURL?: string + baseUrl?: string + apiKey?: string + headers?: Record<string, string> + compatibility?: "strict" | "compatible" + fetch?: typeof fetch + extraBody?: Record<string, unknown> + } + + function createRequesty(options?: RequestyProviderSettings): RequestyProvider + + const requesty: RequestyProvider + + export { + type RequestyChatModelId, + type RequestyChatSettings, + type RequestyLanguageModel, + type RequestyProvider, + type RequestyProviderMetadata, + type RequestyProviderSettings, + createRequesty, + requesty, + } +}
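For context, a minimal sketch of the call pattern the new handler assembles, assuming only the @requesty/ai-sdk surface declared above: reasoning flags travel as model settings while tracing metadata travels through providerOptions.requesty.extraBody. The model ID, API key handling, and trace values are illustrative placeholders, not values taken from this patch.

import { createRequesty } from "@requesty/ai-sdk"
import { streamText } from "ai"

// Illustrative sketch only: model ID, API key, and trace values are placeholders.
const requesty = createRequesty({
	baseURL: "https://router.requesty.ai/v1",
	apiKey: process.env.REQUESTY_API_KEY ?? "not-provided",
	compatibility: "compatible",
})

// Reasoning flags are model settings (they populate the top-level include_reasoning /
// reasoning_effort request fields), mirroring getLanguageModel() in the handler.
const model = requesty("coding/claude-4-sonnet", { includeReasoning: true, reasoningEffort: "medium" })

const result = streamText({
	model,
	system: "You are a helpful assistant.",
	messages: [{ role: "user", content: "Hello" }],
	maxOutputTokens: 8192,
	temperature: 0,
	// Tracing metadata rides in the request's requesty body field via extraBody,
	// mirroring getRequestyProviderOptions() in the handler.
	providerOptions: {
		requesty: { extraBody: { requesty: { trace_id: "task-123", extra: { mode: "code" } } } },
	},
})

for await (const part of result.fullStream) {
	// The handler routes these parts through processAiSdkStreamPart(); here we only print text deltas.
	if (part.type === "text-delta") process.stdout.write(part.text)
}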