From 9fd6218d8481d65fb155d337cde181d57f3efdd5 Mon Sep 17 00:00:00 2001 From: daniel-lxs Date: Mon, 9 Feb 2026 18:02:39 -0500 Subject: [PATCH 1/2] feat: migrate OpenAiHandler to AI SDK Replace direct OpenAI SDK usage in the OpenAI Compatible provider with Vercel AI SDK, following the same pattern as already-migrated providers. Provider strategy: - Standard endpoints: @ai-sdk/openai (createOpenAI) with .chat() - Azure AI Inference: @ai-sdk/openai-compatible (createOpenAICompatible) with adjusted baseURL (/models) and queryParams for api-version - Azure OpenAI: @ai-sdk/azure (createAzure) with useDeploymentBasedUrls Key changes: - Streaming uses streamText() + processAiSdkStreamPart() - Non-streaming uses generateText() - O3/O1/O4 family: providerOptions.openai.systemMessageMode='developer' and reasoningEffort - DeepSeek R1 format: system prompt prepended as user message - TagMatcher retained for tag extraction in streaming text - reasoning_content handled natively by AI SDK providers - isAiSdkProvider() returns true - getOpenAiModels() standalone function unchanged Tests updated across all 4 test files (58 tests passing). --- .../__tests__/openai-native-tools.spec.ts | 53 +- .../__tests__/openai-timeout.spec.ts | 106 +- .../__tests__/openai-usage-tracking.spec.ts | 205 +--- src/api/providers/__tests__/openai.spec.ts | 1038 ++++++----------- src/api/providers/openai.ts | 596 ++++------ 5 files changed, 707 insertions(+), 1291 deletions(-) diff --git a/src/api/providers/__tests__/openai-native-tools.spec.ts b/src/api/providers/__tests__/openai-native-tools.spec.ts index d873b7457bb..371f6500dd3 100644 --- a/src/api/providers/__tests__/openai-native-tools.spec.ts +++ b/src/api/providers/__tests__/openai-native-tools.spec.ts @@ -4,15 +4,27 @@ import OpenAI from "openai" import { OpenAiHandler } from "../openai" +vi.mock("@ai-sdk/openai-compatible", () => ({ + createOpenAICompatible: vi.fn(() => vi.fn((modelId: string) => ({ modelId, provider: "openai-compatible" }))), +})) + +vi.mock("@ai-sdk/azure", () => ({ + createAzure: vi.fn(() => ({ + chat: vi.fn((modelId: string) => ({ modelId, provider: "azure.chat" })), + })), +})) + describe("OpenAiHandler native tools", () => { it("includes tools in request when tools are provided via metadata (regression test)", async () => { - const mockCreate = vi.fn().mockImplementationOnce(() => ({ - [Symbol.asyncIterator]: async function* () { - yield { - choices: [{ delta: { content: "Test response" } }], - } - }, - })) + async function* mockFullStream() { + yield { type: "text-delta", text: "Test response" } + } + + mockStreamText.mockReturnValueOnce({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 10, outputTokens: 5 }), + providerMetadata: Promise.resolve(undefined), + }) // Set openAiCustomModelInfo without any tool capability flags; tools should // still be passed whenever metadata.tools is present. 
@@ -26,16 +38,6 @@ describe("OpenAiHandler native tools", () => { }, } as unknown as import("../../../shared/api").ApiHandlerOptions) - // Patch the OpenAI client call - const mockClient = { - chat: { - completions: { - create: mockCreate, - }, - }, - } as unknown as OpenAI - ;(handler as unknown as { client: OpenAI }).client = mockClient - const tools: OpenAI.Chat.ChatCompletionTool[] = [ { type: "function", @@ -53,17 +55,12 @@ describe("OpenAiHandler native tools", () => { }) await stream.next() - expect(mockCreate).toHaveBeenCalledWith( + expect(mockStreamText).toHaveBeenCalledWith( expect.objectContaining({ - tools: expect.arrayContaining([ - expect.objectContaining({ - type: "function", - function: expect.objectContaining({ name: "test_tool" }), - }), - ]), - parallel_tool_calls: true, + tools: expect.objectContaining({ + test_tool: expect.anything(), + }), }), - expect.anything(), ) }) }) @@ -92,6 +89,10 @@ vi.mock("@ai-sdk/openai", () => ({ modelId: "gpt-4o", provider: "openai.responses", })) + ;(provider as any).chat = vi.fn((modelId: string) => ({ + modelId, + provider: "openai.chat", + })) return provider }), })) diff --git a/src/api/providers/__tests__/openai-timeout.spec.ts b/src/api/providers/__tests__/openai-timeout.spec.ts index 2a09fd94ffa..02805ec1853 100644 --- a/src/api/providers/__tests__/openai-timeout.spec.ts +++ b/src/api/providers/__tests__/openai-timeout.spec.ts @@ -3,51 +3,34 @@ import { OpenAiHandler } from "../openai" import { ApiHandlerOptions } from "../../../shared/api" -// Mock the timeout config utility -vitest.mock("../utils/timeout-config", () => ({ - getApiRequestTimeout: vitest.fn(), +const mockCreateOpenAI = vi.hoisted(() => vi.fn()) +const mockCreateOpenAICompatible = vi.hoisted(() => vi.fn()) +const mockCreateAzure = vi.hoisted(() => vi.fn()) + +vi.mock("@ai-sdk/openai", () => ({ + createOpenAI: mockCreateOpenAI.mockImplementation(() => ({ + chat: vi.fn(() => ({ modelId: "test", provider: "openai.chat" })), + })), })) -import { getApiRequestTimeout } from "../utils/timeout-config" - -// Mock OpenAI and AzureOpenAI -const mockOpenAIConstructor = vitest.fn() -const mockAzureOpenAIConstructor = vitest.fn() - -vitest.mock("openai", () => { - return { - __esModule: true, - default: vitest.fn().mockImplementation((config) => { - mockOpenAIConstructor(config) - return { - chat: { - completions: { - create: vitest.fn(), - }, - }, - } - }), - AzureOpenAI: vitest.fn().mockImplementation((config) => { - mockAzureOpenAIConstructor(config) - return { - chat: { - completions: { - create: vitest.fn(), - }, - }, - } - }), - } -}) +vi.mock("@ai-sdk/openai-compatible", () => ({ + createOpenAICompatible: mockCreateOpenAICompatible.mockImplementation(() => + vi.fn((modelId: string) => ({ modelId, provider: "openai-compatible" })), + ), +})) + +vi.mock("@ai-sdk/azure", () => ({ + createAzure: mockCreateAzure.mockImplementation(() => ({ + chat: vi.fn((modelId: string) => ({ modelId, provider: "azure.chat" })), + })), +})) -describe("OpenAiHandler timeout configuration", () => { +describe("OpenAiHandler provider configuration", () => { beforeEach(() => { - vitest.clearAllMocks() + vi.clearAllMocks() }) - it("should use default timeout for standard OpenAI", () => { - ;(getApiRequestTimeout as any).mockReturnValue(600000) - + it("should use createOpenAI for standard OpenAI endpoints", () => { const options: ApiHandlerOptions = { apiModelId: "gpt-4", openAiModelId: "gpt-4", @@ -56,19 +39,15 @@ describe("OpenAiHandler timeout configuration", () => { new OpenAiHandler(options) 
- expect(getApiRequestTimeout).toHaveBeenCalled() - expect(mockOpenAIConstructor).toHaveBeenCalledWith( + expect(mockCreateOpenAI).toHaveBeenCalledWith( expect.objectContaining({ baseURL: "https://api.openai.com/v1", apiKey: "test-key", - timeout: 600000, // 600 seconds in milliseconds }), ) }) - it("should use custom timeout for OpenAI-compatible providers", () => { - ;(getApiRequestTimeout as any).mockReturnValue(1800000) // 30 minutes - + it("should use createOpenAI for custom OpenAI-compatible providers", () => { const options: ApiHandlerOptions = { apiModelId: "custom-model", openAiModelId: "custom-model", @@ -78,17 +57,14 @@ describe("OpenAiHandler timeout configuration", () => { new OpenAiHandler(options) - expect(mockOpenAIConstructor).toHaveBeenCalledWith( + expect(mockCreateOpenAI).toHaveBeenCalledWith( expect.objectContaining({ baseURL: "http://localhost:8080/v1", - timeout: 1800000, // 1800 seconds in milliseconds }), ) }) - it("should use timeout for Azure OpenAI", () => { - ;(getApiRequestTimeout as any).mockReturnValue(900000) // 15 minutes - + it("should use createAzure for Azure OpenAI", () => { const options: ApiHandlerOptions = { apiModelId: "gpt-4", openAiModelId: "gpt-4", @@ -99,16 +75,16 @@ describe("OpenAiHandler timeout configuration", () => { new OpenAiHandler(options) - expect(mockAzureOpenAIConstructor).toHaveBeenCalledWith( + expect(mockCreateAzure).toHaveBeenCalledWith( expect.objectContaining({ - timeout: 900000, // 900 seconds in milliseconds + baseURL: "https://myinstance.openai.azure.com/openai", + apiKey: "test-key", + useDeploymentBasedUrls: true, }), ) }) - it("should use timeout for Azure AI Inference", () => { - ;(getApiRequestTimeout as any).mockReturnValue(1200000) // 20 minutes - + it("should use createOpenAICompatible for Azure AI Inference", () => { const options: ApiHandlerOptions = { apiModelId: "deepseek", openAiModelId: "deepseek", @@ -118,26 +94,32 @@ describe("OpenAiHandler timeout configuration", () => { new OpenAiHandler(options) - expect(mockOpenAIConstructor).toHaveBeenCalledWith( + expect(mockCreateOpenAICompatible).toHaveBeenCalledWith( expect.objectContaining({ - timeout: 1200000, // 1200 seconds in milliseconds + baseURL: "https://myinstance.services.ai.azure.com/models", + apiKey: "test-key", + queryParams: expect.objectContaining({ + "api-version": expect.any(String), + }), }), ) }) - it("should handle zero timeout (no timeout)", () => { - ;(getApiRequestTimeout as any).mockReturnValue(0) - + it("should include custom headers in provider configuration", () => { const options: ApiHandlerOptions = { apiModelId: "gpt-4", openAiModelId: "gpt-4", + openAiApiKey: "test-key", + openAiHeaders: { "X-Custom": "value" }, } new OpenAiHandler(options) - expect(mockOpenAIConstructor).toHaveBeenCalledWith( + expect(mockCreateOpenAI).toHaveBeenCalledWith( expect.objectContaining({ - timeout: 0, // No timeout + headers: expect.objectContaining({ + "X-Custom": "value", + }), }), ) }) diff --git a/src/api/providers/__tests__/openai-usage-tracking.spec.ts b/src/api/providers/__tests__/openai-usage-tracking.spec.ts index fc80360eee7..151d71a595e 100644 --- a/src/api/providers/__tests__/openai-usage-tracking.spec.ts +++ b/src/api/providers/__tests__/openai-usage-tracking.spec.ts @@ -5,89 +5,38 @@ import { Anthropic } from "@anthropic-ai/sdk" import { ApiHandlerOptions } from "../../../shared/api" import { OpenAiHandler } from "../openai" -const mockCreate = vitest.fn() +const { mockStreamText } = vi.hoisted(() => ({ + mockStreamText: vi.fn(), +})) 
-vitest.mock("openai", () => { +vi.mock("ai", async (importOriginal) => { + const actual = await importOriginal() return { - __esModule: true, - default: vitest.fn().mockImplementation(() => ({ - chat: { - completions: { - create: mockCreate.mockImplementation(async (options) => { - if (!options.stream) { - return { - id: "test-completion", - choices: [ - { - message: { role: "assistant", content: "Test response", refusal: null }, - finish_reason: "stop", - index: 0, - }, - ], - usage: { - prompt_tokens: 10, - completion_tokens: 5, - total_tokens: 15, - }, - } - } - - // Return a stream with multiple chunks that include usage metrics - return { - [Symbol.asyncIterator]: async function* () { - // First chunk with partial usage - yield { - choices: [ - { - delta: { content: "Test " }, - index: 0, - }, - ], - usage: { - prompt_tokens: 10, - completion_tokens: 2, - total_tokens: 12, - }, - } - - // Second chunk with updated usage - yield { - choices: [ - { - delta: { content: "response" }, - index: 0, - }, - ], - usage: { - prompt_tokens: 10, - completion_tokens: 4, - total_tokens: 14, - }, - } - - // Final chunk with complete usage - yield { - choices: [ - { - delta: {}, - index: 0, - }, - ], - usage: { - prompt_tokens: 10, - completion_tokens: 5, - total_tokens: 15, - }, - } - }, - } - }), - }, - }, - })), + ...actual, + streamText: mockStreamText, + generateText: vi.fn(), } }) +vi.mock("@ai-sdk/openai", () => ({ + createOpenAI: vi.fn(() => ({ + chat: vi.fn(() => ({ + modelId: "gpt-4", + provider: "openai.chat", + })), + })), +})) + +vi.mock("@ai-sdk/openai-compatible", () => ({ + createOpenAICompatible: vi.fn(() => vi.fn((modelId: string) => ({ modelId, provider: "openai-compatible" }))), +})) + +vi.mock("@ai-sdk/azure", () => ({ + createAzure: vi.fn(() => ({ + chat: vi.fn((modelId: string) => ({ modelId, provider: "azure.chat" })), + })), +})) + describe("OpenAiHandler with usage tracking fix", () => { let handler: OpenAiHandler let mockOptions: ApiHandlerOptions @@ -99,7 +48,7 @@ describe("OpenAiHandler with usage tracking fix", () => { openAiBaseUrl: "https://api.openai.com/v1", } handler = new OpenAiHandler(mockOptions) - mockCreate.mockClear() + vi.clearAllMocks() }) describe("usage metrics with streaming", () => { @@ -117,19 +66,31 @@ describe("OpenAiHandler with usage tracking fix", () => { ] it("should only yield usage metrics once at the end of the stream", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Test " } + yield { type: "text-delta", text: "response" } + } + + mockStreamText.mockReturnValueOnce({ + fullStream: mockFullStream(), + usage: Promise.resolve({ + inputTokens: 10, + outputTokens: 5, + }), + providerMetadata: Promise.resolve(undefined), + }) + const stream = handler.createMessage(systemPrompt, messages) const chunks: any[] = [] for await (const chunk of stream) { chunks.push(chunk) } - // Check we have text chunks const textChunks = chunks.filter((chunk) => chunk.type === "text") expect(textChunks).toHaveLength(2) expect(textChunks[0].text).toBe("Test ") expect(textChunks[1].text).toBe("response") - // Check we only have one usage chunk and it's the last one const usageChunks = chunks.filter((chunk) => chunk.type === "usage") expect(usageChunks).toHaveLength(1) expect(usageChunks[0]).toEqual({ @@ -138,49 +99,25 @@ describe("OpenAiHandler with usage tracking fix", () => { outputTokens: 5, }) - // Check the usage chunk is the last one reported from the API const lastChunk = chunks[chunks.length - 1] 
expect(lastChunk.type).toBe("usage") expect(lastChunk.inputTokens).toBe(10) expect(lastChunk.outputTokens).toBe(5) }) - it("should handle case where usage is only in the final chunk", async () => { - // Override the mock for this specific test - mockCreate.mockImplementationOnce(async (options) => { - if (!options.stream) { - return { - id: "test-completion", - choices: [{ message: { role: "assistant", content: "Test response" } }], - usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 }, - } - } - - return { - [Symbol.asyncIterator]: async function* () { - // First chunk with no usage - yield { - choices: [{ delta: { content: "Test " }, index: 0 }], - usage: null, - } - - // Second chunk with no usage - yield { - choices: [{ delta: { content: "response" }, index: 0 }], - usage: null, - } - - // Final chunk with usage data - yield { - choices: [{ delta: {}, index: 0 }], - usage: { - prompt_tokens: 10, - completion_tokens: 5, - total_tokens: 15, - }, - } - }, - } + it("should handle case where usage is provided after stream completes", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Test " } + yield { type: "text-delta", text: "response" } + } + + mockStreamText.mockReturnValueOnce({ + fullStream: mockFullStream(), + usage: Promise.resolve({ + inputTokens: 10, + outputTokens: 5, + }), + providerMetadata: Promise.resolve(undefined), }) const stream = handler.createMessage(systemPrompt, messages) @@ -189,7 +126,6 @@ describe("OpenAiHandler with usage tracking fix", () => { chunks.push(chunk) } - // Check usage metrics const usageChunks = chunks.filter((chunk) => chunk.type === "usage") expect(usageChunks).toHaveLength(1) expect(usageChunks[0]).toEqual({ @@ -200,28 +136,14 @@ describe("OpenAiHandler with usage tracking fix", () => { }) it("should handle case where no usage is provided", async () => { - // Override the mock for this specific test - mockCreate.mockImplementationOnce(async (options) => { - if (!options.stream) { - return { - id: "test-completion", - choices: [{ message: { role: "assistant", content: "Test response" } }], - usage: null, - } - } - - return { - [Symbol.asyncIterator]: async function* () { - yield { - choices: [{ delta: { content: "Test response" }, index: 0 }], - usage: null, - } - yield { - choices: [{ delta: {}, index: 0 }], - usage: null, - } - }, - } + async function* mockFullStream() { + yield { type: "text-delta", text: "Test response" } + } + + mockStreamText.mockReturnValueOnce({ + fullStream: mockFullStream(), + usage: Promise.resolve(undefined), + providerMetadata: Promise.resolve(undefined), }) const stream = handler.createMessage(systemPrompt, messages) @@ -230,7 +152,6 @@ describe("OpenAiHandler with usage tracking fix", () => { chunks.push(chunk) } - // Check we don't have any usage chunks const usageChunks = chunks.filter((chunk) => chunk.type === "usage") expect(usageChunks).toHaveLength(0) }) diff --git a/src/api/providers/__tests__/openai.spec.ts b/src/api/providers/__tests__/openai.spec.ts index 73b542dbc73..2399cbb4397 100644 --- a/src/api/providers/__tests__/openai.spec.ts +++ b/src/api/providers/__tests__/openai.spec.ts @@ -1,81 +1,113 @@ // npx vitest run api/providers/__tests__/openai.spec.ts -import { OpenAiHandler, getOpenAiModels } from "../openai" -import { ApiHandlerOptions } from "../../../shared/api" -import { Anthropic } from "@anthropic-ai/sdk" -import OpenAI from "openai" -import { openAiModelInfoSaneDefaults } from "@roo-code/types" -import { Package } from 
"../../../shared/package" -import axios from "axios" - -const mockCreate = vitest.fn() +const { mockStreamText, mockGenerateText } = vi.hoisted(() => ({ + mockStreamText: vi.fn(), + mockGenerateText: vi.fn(), +})) -vitest.mock("openai", () => { - const mockConstructor = vitest.fn() +vi.mock("ai", async (importOriginal) => { + const actual = await importOriginal() return { - __esModule: true, - default: mockConstructor.mockImplementation(() => ({ - chat: { - completions: { - create: mockCreate.mockImplementation(async (options) => { - if (!options.stream) { - return { - id: "test-completion", - choices: [ - { - message: { role: "assistant", content: "Test response", refusal: null }, - finish_reason: "stop", - index: 0, - }, - ], - usage: { - prompt_tokens: 10, - completion_tokens: 5, - total_tokens: 15, - }, - } - } - - return { - [Symbol.asyncIterator]: async function* () { - yield { - choices: [ - { - delta: { content: "Test response" }, - index: 0, - }, - ], - usage: null, - } - yield { - choices: [ - { - delta: {}, - index: 0, - }, - ], - usage: { - prompt_tokens: 10, - completion_tokens: 5, - total_tokens: 15, - }, - } - }, - } - }), - }, - }, - })), + ...actual, + streamText: mockStreamText, + generateText: mockGenerateText, } }) +const mockCreateOpenAI = vi.hoisted(() => + vi.fn(() => { + return { + chat: vi.fn(() => ({ + modelId: "gpt-4", + provider: "openai.chat", + })), + responses: vi.fn(() => ({ + modelId: "gpt-4", + provider: "openai.responses", + })), + } + }), +) + +vi.mock("@ai-sdk/openai", () => ({ + createOpenAI: mockCreateOpenAI, +})) + +const mockCreateOpenAICompatible = vi.hoisted(() => + vi.fn(() => { + return vi.fn((modelId: string) => ({ + modelId, + provider: "openai-compatible", + })) + }), +) + +vi.mock("@ai-sdk/openai-compatible", () => ({ + createOpenAICompatible: mockCreateOpenAICompatible, +})) + +const mockCreateAzure = vi.hoisted(() => + vi.fn(() => { + return { + chat: vi.fn((modelId: string) => ({ + modelId, + provider: "azure.chat", + })), + } + }), +) + +vi.mock("@ai-sdk/azure", () => ({ + createAzure: mockCreateAzure, +})) + // Mock axios for getOpenAiModels tests -vitest.mock("axios", () => ({ +vi.mock("axios", () => ({ default: { - get: vitest.fn(), + get: vi.fn(), }, })) +import { OpenAiHandler, getOpenAiModels } from "../openai" +import { ApiHandlerOptions } from "../../../shared/api" +import { Anthropic } from "@anthropic-ai/sdk" +import { openAiModelInfoSaneDefaults } from "@roo-code/types" +import axios from "axios" + +function createMockStreamResult(options?: { + textChunks?: string[] + reasoningChunks?: string[] + toolCalls?: Array<{ id: string; name: string; delta: string }> + usage?: { inputTokens: number; outputTokens: number } +}) { + const { + textChunks = ["Test response"], + reasoningChunks = [], + toolCalls = [], + usage = { inputTokens: 10, outputTokens: 5 }, + } = options ?? 
{} + + async function* mockFullStream() { + for (const text of textChunks) { + yield { type: "text-delta", text } + } + for (const text of reasoningChunks) { + yield { type: "reasoning-delta", text } + } + for (const tc of toolCalls) { + yield { type: "tool-input-start", id: tc.id, toolName: tc.name } + yield { type: "tool-input-delta", id: tc.id, delta: tc.delta } + yield { type: "tool-input-end", id: tc.id } + } + } + + return { + fullStream: mockFullStream(), + usage: Promise.resolve(usage), + providerMetadata: Promise.resolve(undefined), + } +} + describe("OpenAiHandler", () => { let handler: OpenAiHandler let mockOptions: ApiHandlerOptions @@ -87,7 +119,7 @@ describe("OpenAiHandler", () => { openAiBaseUrl: "https://api.openai.com/v1", } handler = new OpenAiHandler(mockOptions) - mockCreate.mockClear() + vi.clearAllMocks() }) describe("constructor", () => { @@ -105,18 +137,18 @@ describe("OpenAiHandler", () => { expect(handlerWithCustomUrl).toBeInstanceOf(OpenAiHandler) }) - it("should set default headers correctly", () => { - // Check that the OpenAI constructor was called with correct parameters - expect(vi.mocked(OpenAI)).toHaveBeenCalledWith({ - baseURL: expect.any(String), - apiKey: expect.any(String), - defaultHeaders: { - "HTTP-Referer": "https://github.com/RooVetGit/Roo-Cline", - "X-Title": "Roo Code", - "User-Agent": `RooCode/${Package.version}`, - }, - timeout: expect.any(Number), - }) + it("should create an OpenAI provider with correct configuration", () => { + new OpenAiHandler(mockOptions) + expect(mockCreateOpenAI).toHaveBeenCalledWith( + expect.objectContaining({ + baseURL: "https://api.openai.com/v1", + apiKey: "test-api-key", + }), + ) + }) + + it("should report as AI SDK provider", () => { + expect(handler.isAiSdkProvider()).toBe(true) }) }) @@ -134,7 +166,35 @@ describe("OpenAiHandler", () => { }, ] + it("should handle streaming responses", async () => { + mockStreamText.mockReturnValueOnce(createMockStreamResult()) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + expect(chunks.length).toBeGreaterThan(0) + const textChunks = chunks.filter((chunk) => chunk.type === "text") + expect(textChunks).toHaveLength(1) + expect(textChunks[0].text).toBe("Test response") + + expect(mockStreamText).toHaveBeenCalledWith( + expect.objectContaining({ + system: systemPrompt, + temperature: 0, + }), + ) + }) + it("should handle non-streaming mode", async () => { + mockGenerateText.mockResolvedValueOnce({ + text: "Test response", + toolCalls: [], + usage: { inputTokens: 10, outputTokens: 5 }, + }) + const handler = new OpenAiHandler({ ...mockOptions, openAiStreamingEnabled: false, @@ -158,31 +218,16 @@ describe("OpenAiHandler", () => { }) it("should handle tool calls in non-streaming mode", async () => { - mockCreate.mockResolvedValueOnce({ - choices: [ + mockGenerateText.mockResolvedValueOnce({ + text: "", + toolCalls: [ { - message: { - role: "assistant", - content: null, - tool_calls: [ - { - id: "call_1", - type: "function", - function: { - name: "test_tool", - arguments: '{"arg":"value"}', - }, - }, - ], - }, - finish_reason: "tool_calls", + toolCallId: "call_1", + toolName: "test_tool", + args: { arg: "value" }, }, ], - usage: { - prompt_tokens: 10, - completion_tokens: 5, - total_tokens: 15, - }, + usage: { inputTokens: 10, outputTokens: 5 }, }) const handler = new OpenAiHandler({ @@ -206,62 +251,13 @@ describe("OpenAiHandler", () => { }) }) - it("should handle streaming 
responses", async () => { - const stream = handler.createMessage(systemPrompt, messages) - const chunks: any[] = [] - for await (const chunk of stream) { - chunks.push(chunk) - } - - expect(chunks.length).toBeGreaterThan(0) - const textChunks = chunks.filter((chunk) => chunk.type === "text") - expect(textChunks).toHaveLength(1) - expect(textChunks[0].text).toBe("Test response") - }) - it("should handle tool calls in streaming responses", async () => { - mockCreate.mockImplementation(async (options) => { - return { - [Symbol.asyncIterator]: async function* () { - yield { - choices: [ - { - delta: { - tool_calls: [ - { - index: 0, - id: "call_1", - function: { name: "test_tool", arguments: "" }, - }, - ], - }, - finish_reason: null, - }, - ], - } - yield { - choices: [ - { - delta: { - tool_calls: [{ index: 0, function: { arguments: '{"arg":' } }], - }, - finish_reason: null, - }, - ], - } - yield { - choices: [ - { - delta: { - tool_calls: [{ index: 0, function: { arguments: '"value"}' } }], - }, - finish_reason: "tool_calls", - }, - ], - } - }, - } - }) + mockStreamText.mockReturnValueOnce( + createMockStreamResult({ + textChunks: [], + toolCalls: [{ id: "call_1", name: "test_tool", delta: '{"arg":"value"}' }], + }), + ) const stream = handler.createMessage(systemPrompt, messages) const chunks: any[] = [] @@ -269,90 +265,24 @@ describe("OpenAiHandler", () => { chunks.push(chunk) } - // Provider now yields tool_call_partial chunks, NativeToolCallParser handles reassembly - const toolCallPartialChunks = chunks.filter((chunk) => chunk.type === "tool_call_partial") - expect(toolCallPartialChunks).toHaveLength(3) - // First chunk has id and name - expect(toolCallPartialChunks[0]).toEqual({ - type: "tool_call_partial", - index: 0, + const toolStartChunks = chunks.filter((chunk) => chunk.type === "tool_call_start") + expect(toolStartChunks).toHaveLength(1) + expect(toolStartChunks[0]).toEqual({ + type: "tool_call_start", id: "call_1", name: "test_tool", - arguments: "", - }) - // Subsequent chunks have arguments - expect(toolCallPartialChunks[1]).toEqual({ - type: "tool_call_partial", - index: 0, - id: undefined, - name: undefined, - arguments: '{"arg":', - }) - expect(toolCallPartialChunks[2]).toEqual({ - type: "tool_call_partial", - index: 0, - id: undefined, - name: undefined, - arguments: '"value"}', }) - // Verify tool_call_end event is emitted when finish_reason is "tool_calls" - const toolCallEndChunks = chunks.filter((chunk) => chunk.type === "tool_call_end") - expect(toolCallEndChunks).toHaveLength(1) - }) + const toolDeltaChunks = chunks.filter((chunk) => chunk.type === "tool_call_delta") + expect(toolDeltaChunks).toHaveLength(1) - it("should yield tool calls even when finish_reason is not set (fallback behavior)", async () => { - mockCreate.mockImplementation(async (options) => { - return { - [Symbol.asyncIterator]: async function* () { - yield { - choices: [ - { - delta: { - tool_calls: [ - { - index: 0, - id: "call_fallback", - function: { name: "fallback_tool", arguments: '{"test":"fallback"}' }, - }, - ], - }, - finish_reason: null, - }, - ], - } - // Stream ends without finish_reason being set to "tool_calls" - yield { - choices: [ - { - delta: {}, - finish_reason: "stop", // Different finish reason - }, - ], - } - }, - } - }) - - const stream = handler.createMessage(systemPrompt, messages) - const chunks: any[] = [] - for await (const chunk of stream) { - chunks.push(chunk) - } - - // Provider now yields tool_call_partial chunks, NativeToolCallParser handles reassembly - const 
toolCallPartialChunks = chunks.filter((chunk) => chunk.type === "tool_call_partial") - expect(toolCallPartialChunks).toHaveLength(1) - expect(toolCallPartialChunks[0]).toEqual({ - type: "tool_call_partial", - index: 0, - id: "call_fallback", - name: "fallback_tool", - arguments: '{"test":"fallback"}', - }) + const toolEndChunks = chunks.filter((chunk) => chunk.type === "tool_call_end") + expect(toolEndChunks).toHaveLength(1) }) it("should include reasoning_effort when reasoning effort is enabled", async () => { + mockStreamText.mockReturnValueOnce(createMockStreamResult()) + const reasoningOptions: ApiHandlerOptions = { ...mockOptions, enableReasoningEffort: true, @@ -365,16 +295,23 @@ describe("OpenAiHandler", () => { } const reasoningHandler = new OpenAiHandler(reasoningOptions) const stream = reasoningHandler.createMessage(systemPrompt, messages) - // Consume the stream to trigger the API call for await (const _chunk of stream) { } - // Assert the mockCreate was called with reasoning_effort - expect(mockCreate).toHaveBeenCalled() - const callArgs = mockCreate.mock.calls[0][0] - expect(callArgs.reasoning_effort).toBe("high") + + expect(mockStreamText).toHaveBeenCalledWith( + expect.objectContaining({ + providerOptions: expect.objectContaining({ + openai: expect.objectContaining({ + reasoningEffort: "high", + }), + }), + }), + ) }) it("should not include reasoning_effort when reasoning effort is disabled", async () => { + mockStreamText.mockReturnValueOnce(createMockStreamResult()) + const noReasoningOptions: ApiHandlerOptions = { ...mockOptions, enableReasoningEffort: false, @@ -382,16 +319,17 @@ describe("OpenAiHandler", () => { } const noReasoningHandler = new OpenAiHandler(noReasoningOptions) const stream = noReasoningHandler.createMessage(systemPrompt, messages) - // Consume the stream to trigger the API call for await (const _chunk of stream) { } - // Assert the mockCreate was called without reasoning_effort - expect(mockCreate).toHaveBeenCalled() - const callArgs = mockCreate.mock.calls[0][0] - expect(callArgs.reasoning_effort).toBeUndefined() + + expect(mockStreamText).toHaveBeenCalled() + const callArgs = mockStreamText.mock.calls[0][0] + expect(callArgs.providerOptions).toBeUndefined() }) - it("should include max_tokens when includeMaxTokens is true", async () => { + it("should include maxOutputTokens when includeMaxTokens is true", async () => { + mockStreamText.mockReturnValueOnce(createMockStreamResult()) + const optionsWithMaxTokens: ApiHandlerOptions = { ...mockOptions, includeMaxTokens: true, @@ -403,16 +341,19 @@ describe("OpenAiHandler", () => { } const handlerWithMaxTokens = new OpenAiHandler(optionsWithMaxTokens) const stream = handlerWithMaxTokens.createMessage(systemPrompt, messages) - // Consume the stream to trigger the API call for await (const _chunk of stream) { } - // Assert the mockCreate was called with max_tokens - expect(mockCreate).toHaveBeenCalled() - const callArgs = mockCreate.mock.calls[0][0] - expect(callArgs.max_completion_tokens).toBe(4096) + + expect(mockStreamText).toHaveBeenCalledWith( + expect.objectContaining({ + maxOutputTokens: 4096, + }), + ) }) - it("should not include max_tokens when includeMaxTokens is false", async () => { + it("should not include maxOutputTokens when includeMaxTokens is false", async () => { + mockStreamText.mockReturnValueOnce(createMockStreamResult()) + const optionsWithoutMaxTokens: ApiHandlerOptions = { ...mockOptions, includeMaxTokens: false, @@ -424,19 +365,19 @@ describe("OpenAiHandler", () => { } const 
handlerWithoutMaxTokens = new OpenAiHandler(optionsWithoutMaxTokens) const stream = handlerWithoutMaxTokens.createMessage(systemPrompt, messages) - // Consume the stream to trigger the API call for await (const _chunk of stream) { } - // Assert the mockCreate was called without max_tokens - expect(mockCreate).toHaveBeenCalled() - const callArgs = mockCreate.mock.calls[0][0] - expect(callArgs.max_completion_tokens).toBeUndefined() + + expect(mockStreamText).toHaveBeenCalled() + const callArgs = mockStreamText.mock.calls[0][0] + expect(callArgs.maxOutputTokens).toBeUndefined() }) - it("should not include max_tokens when includeMaxTokens is undefined", async () => { + it("should not include maxOutputTokens when includeMaxTokens is undefined", async () => { + mockStreamText.mockReturnValueOnce(createMockStreamResult()) + const optionsWithUndefinedMaxTokens: ApiHandlerOptions = { ...mockOptions, - // includeMaxTokens is not set, should not include max_tokens openAiCustomModelInfo: { contextWindow: 128_000, maxTokens: 4096, @@ -445,57 +386,61 @@ describe("OpenAiHandler", () => { } const handlerWithDefaultMaxTokens = new OpenAiHandler(optionsWithUndefinedMaxTokens) const stream = handlerWithDefaultMaxTokens.createMessage(systemPrompt, messages) - // Consume the stream to trigger the API call for await (const _chunk of stream) { } - // Assert the mockCreate was called without max_tokens - expect(mockCreate).toHaveBeenCalled() - const callArgs = mockCreate.mock.calls[0][0] - expect(callArgs.max_completion_tokens).toBeUndefined() + + expect(mockStreamText).toHaveBeenCalled() + const callArgs = mockStreamText.mock.calls[0][0] + expect(callArgs.maxOutputTokens).toBeUndefined() }) it("should use user-configured modelMaxTokens instead of model default maxTokens", async () => { + mockStreamText.mockReturnValueOnce(createMockStreamResult()) + const optionsWithUserMaxTokens: ApiHandlerOptions = { ...mockOptions, includeMaxTokens: true, - modelMaxTokens: 32000, // User-configured value + modelMaxTokens: 32000, openAiCustomModelInfo: { contextWindow: 128_000, - maxTokens: 4096, // Model's default value (should not be used) + maxTokens: 4096, supportsPromptCache: false, }, } const handlerWithUserMaxTokens = new OpenAiHandler(optionsWithUserMaxTokens) const stream = handlerWithUserMaxTokens.createMessage(systemPrompt, messages) - // Consume the stream to trigger the API call for await (const _chunk of stream) { } - // Assert the mockCreate was called with user-configured modelMaxTokens (32000), not model default maxTokens (4096) - expect(mockCreate).toHaveBeenCalled() - const callArgs = mockCreate.mock.calls[0][0] - expect(callArgs.max_completion_tokens).toBe(32000) + + expect(mockStreamText).toHaveBeenCalledWith( + expect.objectContaining({ + maxOutputTokens: 32000, + }), + ) }) it("should fallback to model default maxTokens when user modelMaxTokens is not set", async () => { + mockStreamText.mockReturnValueOnce(createMockStreamResult()) + const optionsWithoutUserMaxTokens: ApiHandlerOptions = { ...mockOptions, includeMaxTokens: true, - // modelMaxTokens is not set openAiCustomModelInfo: { contextWindow: 128_000, - maxTokens: 4096, // Model's default value (should be used as fallback) + maxTokens: 4096, supportsPromptCache: false, }, } const handlerWithoutUserMaxTokens = new OpenAiHandler(optionsWithoutUserMaxTokens) const stream = handlerWithoutUserMaxTokens.createMessage(systemPrompt, messages) - // Consume the stream to trigger the API call for await (const _chunk of stream) { } - // Assert the mockCreate 
was called with model default maxTokens (4096) as fallback - expect(mockCreate).toHaveBeenCalled() - const callArgs = mockCreate.mock.calls[0][0] - expect(callArgs.max_completion_tokens).toBe(4096) + + expect(mockStreamText).toHaveBeenCalledWith( + expect.objectContaining({ + maxOutputTokens: 4096, + }), + ) }) }) @@ -512,56 +457,69 @@ describe("OpenAiHandler", () => { }, ] - it("should handle API errors", async () => { - mockCreate.mockRejectedValueOnce(new Error("API Error")) + it("should handle API errors in streaming", async () => { + const errorStream = { + fullStream: { + [Symbol.asyncIterator]() { + return { + next: () => Promise.reject(new Error("API Error")), + } + }, + }, + usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }), + providerMetadata: Promise.resolve(undefined), + } + mockStreamText.mockReturnValueOnce(errorStream) const stream = handler.createMessage("system prompt", testMessages) await expect(async () => { for await (const _chunk of stream) { - // Should not reach here } }).rejects.toThrow("API Error") }) - it("should handle rate limiting", async () => { - const rateLimitError = new Error("Rate limit exceeded") - rateLimitError.name = "Error" - ;(rateLimitError as any).status = 429 - mockCreate.mockRejectedValueOnce(rateLimitError) + it("should handle API errors in non-streaming", async () => { + mockGenerateText.mockRejectedValueOnce(new Error("API Error")) + + const handler = new OpenAiHandler({ + ...mockOptions, + openAiStreamingEnabled: false, + }) const stream = handler.createMessage("system prompt", testMessages) await expect(async () => { for await (const _chunk of stream) { - // Should not reach here } - }).rejects.toThrow("Rate limit exceeded") + }).rejects.toThrow("API Error") }) }) describe("completePrompt", () => { it("should complete prompt successfully", async () => { + mockGenerateText.mockResolvedValueOnce({ + text: "Test response", + }) + const result = await handler.completePrompt("Test prompt") expect(result).toBe("Test response") - expect(mockCreate).toHaveBeenCalledWith( - { - model: mockOptions.openAiModelId, - messages: [{ role: "user", content: "Test prompt" }], - }, - {}, + expect(mockGenerateText).toHaveBeenCalledWith( + expect.objectContaining({ + prompt: "Test prompt", + }), ) }) it("should handle API errors", async () => { - mockCreate.mockRejectedValueOnce(new Error("API Error")) + mockGenerateText.mockRejectedValueOnce(new Error("API Error")) await expect(handler.completePrompt("Test prompt")).rejects.toThrow("OpenAI completion error: API Error") }) it("should handle empty response", async () => { - mockCreate.mockImplementationOnce(() => ({ - choices: [{ message: { content: "" } }], - })) + mockGenerateText.mockResolvedValueOnce({ + text: "", + }) const result = await handler.completePrompt("Test prompt") expect(result).toBe("") }) @@ -588,21 +546,37 @@ describe("OpenAiHandler", () => { }) describe("Azure AI Inference Service", () => { - const azureOptions = { - ...mockOptions, - openAiBaseUrl: "https://test.services.ai.azure.com", - openAiModelId: "deepseek-v3", - azureApiVersion: "2024-05-01-preview", + function makeAzureOptions() { + return { + ...mockOptions, + openAiBaseUrl: "https://test.services.ai.azure.com", + openAiModelId: "deepseek-v3", + azureApiVersion: "2024-05-01-preview", + } } it("should initialize with Azure AI Inference Service configuration", () => { + const azureOptions = makeAzureOptions() const azureHandler = new OpenAiHandler(azureOptions) expect(azureHandler).toBeInstanceOf(OpenAiHandler) 
expect(azureHandler.getModel().id).toBe(azureOptions.openAiModelId) }) + it("should use createOpenAICompatible for Azure AI Inference", () => { + new OpenAiHandler(makeAzureOptions()) + expect(mockCreateOpenAICompatible).toHaveBeenCalledWith( + expect.objectContaining({ + baseURL: "https://test.services.ai.azure.com/models", + apiKey: "test-api-key", + queryParams: { "api-version": "2024-05-01-preview" }, + }), + ) + }) + it("should handle streaming responses with Azure AI Inference Service", async () => { - const azureHandler = new OpenAiHandler(azureOptions) + mockStreamText.mockReturnValueOnce(createMockStreamResult()) + + const azureHandler = new OpenAiHandler(makeAzureOptions()) const systemPrompt = "You are a helpful assistant." const messages: Anthropic.Messages.MessageParam[] = [ { @@ -621,33 +595,17 @@ describe("OpenAiHandler", () => { const textChunks = chunks.filter((chunk) => chunk.type === "text") expect(textChunks).toHaveLength(1) expect(textChunks[0].text).toBe("Test response") - - // Verify the API call was made with correct Azure AI Inference Service path - expect(mockCreate).toHaveBeenCalledWith( - { - model: azureOptions.openAiModelId, - messages: [ - { role: "system", content: systemPrompt }, - { role: "user", content: "Hello!" }, - ], - stream: true, - stream_options: { include_usage: true }, - temperature: 0, - tools: undefined, - tool_choice: undefined, - parallel_tool_calls: true, - }, - { path: "/models/chat/completions" }, - ) - - // Verify max_tokens is NOT included when not explicitly set - const callArgs = mockCreate.mock.calls[0][0] - expect(callArgs).not.toHaveProperty("max_completion_tokens") }) it("should handle non-streaming responses with Azure AI Inference Service", async () => { + mockGenerateText.mockResolvedValueOnce({ + text: "Test response", + toolCalls: [], + usage: { inputTokens: 10, outputTokens: 5 }, + }) + const azureHandler = new OpenAiHandler({ - ...azureOptions, + ...makeAzureOptions(), openAiStreamingEnabled: false, }) const systemPrompt = "You are a helpful assistant." @@ -673,82 +631,35 @@ describe("OpenAiHandler", () => { expect(usageChunk).toBeDefined() expect(usageChunk?.inputTokens).toBe(10) expect(usageChunk?.outputTokens).toBe(5) - - // Verify the API call was made with correct Azure AI Inference Service path - expect(mockCreate).toHaveBeenCalledWith( - { - model: azureOptions.openAiModelId, - messages: [ - { role: "system", content: systemPrompt }, - { role: "user", content: "Hello!" 
}, - ], - tools: undefined, - tool_choice: undefined, - parallel_tool_calls: true, - }, - { path: "/models/chat/completions" }, - ) - - // Verify max_tokens is NOT included when not explicitly set - const callArgs = mockCreate.mock.calls[0][0] - expect(callArgs).not.toHaveProperty("max_completion_tokens") }) it("should handle completePrompt with Azure AI Inference Service", async () => { - const azureHandler = new OpenAiHandler(azureOptions) + mockGenerateText.mockResolvedValueOnce({ + text: "Test response", + }) + + const azureHandler = new OpenAiHandler(makeAzureOptions()) const result = await azureHandler.completePrompt("Test prompt") expect(result).toBe("Test response") - expect(mockCreate).toHaveBeenCalledWith( - { - model: azureOptions.openAiModelId, - messages: [{ role: "user", content: "Test prompt" }], - }, - { path: "/models/chat/completions" }, - ) - - // Verify max_tokens is NOT included when includeMaxTokens is not set - const callArgs = mockCreate.mock.calls[0][0] - expect(callArgs).not.toHaveProperty("max_completion_tokens") }) }) - describe("Grok xAI Provider", () => { - const grokOptions = { - ...mockOptions, - openAiBaseUrl: "https://api.x.ai/v1", - openAiModelId: "grok-1", - } - - it("should initialize with Grok xAI configuration", () => { - const grokHandler = new OpenAiHandler(grokOptions) - expect(grokHandler).toBeInstanceOf(OpenAiHandler) - expect(grokHandler.getModel().id).toBe(grokOptions.openAiModelId) - }) - - it("should exclude stream_options when streaming with Grok xAI", async () => { - const grokHandler = new OpenAiHandler(grokOptions) - const systemPrompt = "You are a helpful assistant." - const messages: Anthropic.Messages.MessageParam[] = [ - { - role: "user", - content: "Hello!", - }, - ] - - const stream = grokHandler.createMessage(systemPrompt, messages) - await stream.next() - - expect(mockCreate).toHaveBeenCalledWith( + describe("Azure OpenAI", () => { + it("should use createAzure for Azure OpenAI", () => { + new OpenAiHandler({ + ...mockOptions, + openAiBaseUrl: "https://myinstance.openai.azure.com", + openAiUseAzure: true, + azureApiVersion: "2024-06-01", + }) + expect(mockCreateAzure).toHaveBeenCalledWith( expect.objectContaining({ - model: grokOptions.openAiModelId, - stream: true, + baseURL: "https://myinstance.openai.azure.com/openai", + apiKey: "test-api-key", + apiVersion: "2024-06-01", + useDeploymentBasedUrls: true, }), - {}, ) - - const mockCalls = mockCreate.mock.calls - const lastCall = mockCalls[mockCalls.length - 1] - expect(lastCall[0]).not.toHaveProperty("stream_options") }) }) @@ -764,12 +675,13 @@ describe("OpenAiHandler", () => { }, } - it("should handle O3 model with streaming and include max_completion_tokens when includeMaxTokens is true", async () => { + it("should handle O3 model with streaming and developer role", async () => { + mockStreamText.mockReturnValueOnce(createMockStreamResult()) + const o3Handler = new OpenAiHandler({ ...o3Options, includeMaxTokens: true, modelMaxTokens: 32000, - modelTemperature: 0.5, }) const systemPrompt = "You are a helpful assistant." const messages: Anthropic.Messages.MessageParam[] = [ @@ -785,150 +697,27 @@ describe("OpenAiHandler", () => { chunks.push(chunk) } - expect(mockCreate).toHaveBeenCalledWith( + expect(mockStreamText).toHaveBeenCalledWith( expect.objectContaining({ - model: "o3-mini", - messages: [ - { - role: "developer", - content: "Formatting re-enabled\nYou are a helpful assistant.", - }, - { role: "user", content: "Hello!" 
}, - ], - stream: true, - stream_options: { include_usage: true }, - reasoning_effort: "medium", + system: "Formatting re-enabled\nYou are a helpful assistant.", temperature: undefined, - // O3 models do not support deprecated max_tokens but do support max_completion_tokens - max_completion_tokens: 32000, + maxOutputTokens: 32000, + providerOptions: expect.objectContaining({ + openai: expect.objectContaining({ + systemMessageMode: "developer", + reasoningEffort: "medium", + }), + }), }), - {}, ) }) - it("should handle tool calls with O3 model in streaming mode", async () => { - const o3Handler = new OpenAiHandler(o3Options) + it("should handle O3 model with streaming and exclude maxOutputTokens when includeMaxTokens is false", async () => { + mockStreamText.mockReturnValueOnce(createMockStreamResult()) - mockCreate.mockImplementation(async (options) => { - return { - [Symbol.asyncIterator]: async function* () { - yield { - choices: [ - { - delta: { - tool_calls: [ - { - index: 0, - id: "call_1", - function: { name: "test_tool", arguments: "" }, - }, - ], - }, - finish_reason: null, - }, - ], - } - yield { - choices: [ - { - delta: { - tool_calls: [{ index: 0, function: { arguments: "{}" } }], - }, - finish_reason: "tool_calls", - }, - ], - } - }, - } - }) - - const stream = o3Handler.createMessage("system", []) - const chunks: any[] = [] - for await (const chunk of stream) { - chunks.push(chunk) - } - - // Provider now yields tool_call_partial chunks, NativeToolCallParser handles reassembly - const toolCallPartialChunks = chunks.filter((chunk) => chunk.type === "tool_call_partial") - expect(toolCallPartialChunks).toHaveLength(2) - expect(toolCallPartialChunks[0]).toEqual({ - type: "tool_call_partial", - index: 0, - id: "call_1", - name: "test_tool", - arguments: "", - }) - expect(toolCallPartialChunks[1]).toEqual({ - type: "tool_call_partial", - index: 0, - id: undefined, - name: undefined, - arguments: "{}", - }) - - // Verify tool_call_end event is emitted when finish_reason is "tool_calls" - const toolCallEndChunks = chunks.filter((chunk) => chunk.type === "tool_call_end") - expect(toolCallEndChunks).toHaveLength(1) - }) - - it("should yield tool calls for O3 model even when finish_reason is not set (fallback behavior)", async () => { - const o3Handler = new OpenAiHandler(o3Options) - - mockCreate.mockImplementation(async (options) => { - return { - [Symbol.asyncIterator]: async function* () { - yield { - choices: [ - { - delta: { - tool_calls: [ - { - index: 0, - id: "call_o3_fallback", - function: { name: "o3_fallback_tool", arguments: '{"o3":"test"}' }, - }, - ], - }, - finish_reason: null, - }, - ], - } - // Stream ends with different finish reason - yield { - choices: [ - { - delta: {}, - finish_reason: "length", // Different finish reason - }, - ], - } - }, - } - }) - - const stream = o3Handler.createMessage("system", []) - const chunks: any[] = [] - for await (const chunk of stream) { - chunks.push(chunk) - } - - // Provider now yields tool_call_partial chunks, NativeToolCallParser handles reassembly - const toolCallPartialChunks = chunks.filter((chunk) => chunk.type === "tool_call_partial") - expect(toolCallPartialChunks).toHaveLength(1) - expect(toolCallPartialChunks[0]).toEqual({ - type: "tool_call_partial", - index: 0, - id: "call_o3_fallback", - name: "o3_fallback_tool", - arguments: '{"o3":"test"}', - }) - }) - - it("should handle O3 model with streaming and exclude max_tokens when includeMaxTokens is false", async () => { const o3Handler = new OpenAiHandler({ ...o3Options, 
includeMaxTokens: false, - modelTemperature: 0.7, }) const systemPrompt = "You are a helpful assistant." const messages: Anthropic.Messages.MessageParam[] = [ @@ -944,35 +733,21 @@ describe("OpenAiHandler", () => { chunks.push(chunk) } - expect(mockCreate).toHaveBeenCalledWith( - expect.objectContaining({ - model: "o3-mini", - messages: [ - { - role: "developer", - content: "Formatting re-enabled\nYou are a helpful assistant.", - }, - { role: "user", content: "Hello!" }, - ], - stream: true, - stream_options: { include_usage: true }, - reasoning_effort: "medium", - temperature: undefined, - }), - {}, - ) - - // Verify max_tokens is NOT included - const callArgs = mockCreate.mock.calls[0][0] - expect(callArgs).not.toHaveProperty("max_completion_tokens") + const callArgs = mockStreamText.mock.calls[0][0] + expect(callArgs.maxOutputTokens).toBeUndefined() }) - it("should handle O3 model non-streaming with reasoning_effort and max_completion_tokens when includeMaxTokens is true", async () => { + it("should handle O3 model non-streaming with reasoning_effort and maxOutputTokens", async () => { + mockGenerateText.mockResolvedValueOnce({ + text: "Test response", + toolCalls: [], + usage: { inputTokens: 10, outputTokens: 5 }, + }) + const o3Handler = new OpenAiHandler({ ...o3Options, openAiStreamingEnabled: false, includeMaxTokens: true, - modelTemperature: 0.3, }) const systemPrompt = "You are a helpful assistant." const messages: Anthropic.Messages.MessageParam[] = [ @@ -988,60 +763,62 @@ describe("OpenAiHandler", () => { chunks.push(chunk) } - expect(mockCreate).toHaveBeenCalledWith( + expect(mockGenerateText).toHaveBeenCalledWith( expect.objectContaining({ - model: "o3-mini", - messages: [ - { - role: "developer", - content: "Formatting re-enabled\nYou are a helpful assistant.", - }, - { role: "user", content: "Hello!" 
}, - ], - reasoning_effort: "medium", + system: "Formatting re-enabled\nYou are a helpful assistant.", temperature: undefined, - // O3 models do not support deprecated max_tokens but do support max_completion_tokens - max_completion_tokens: 65536, // Using default maxTokens from o3Options + maxOutputTokens: 65536, + providerOptions: expect.objectContaining({ + openai: expect.objectContaining({ + systemMessageMode: "developer", + reasoningEffort: "medium", + }), + }), }), - {}, ) - - // Verify stream is not set - const callArgs = mockCreate.mock.calls[0][0] - expect(callArgs).not.toHaveProperty("stream") }) - it("should handle tool calls with O3 model in non-streaming mode", async () => { - const o3Handler = new OpenAiHandler({ - ...o3Options, - openAiStreamingEnabled: false, + it("should handle tool calls with O3 model in streaming mode", async () => { + mockStreamText.mockReturnValueOnce( + createMockStreamResult({ + textChunks: [], + toolCalls: [{ id: "call_1", name: "test_tool", delta: "{}" }], + }), + ) + + const o3Handler = new OpenAiHandler(o3Options) + + const stream = o3Handler.createMessage("system", []) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + const toolStartChunks = chunks.filter((chunk) => chunk.type === "tool_call_start") + expect(toolStartChunks).toHaveLength(1) + expect(toolStartChunks[0]).toEqual({ + type: "tool_call_start", + id: "call_1", + name: "test_tool", }) + }) - mockCreate.mockResolvedValueOnce({ - choices: [ + it("should handle tool calls with O3 model in non-streaming mode", async () => { + mockGenerateText.mockResolvedValueOnce({ + text: "", + toolCalls: [ { - message: { - role: "assistant", - content: null, - tool_calls: [ - { - id: "call_1", - type: "function", - function: { - name: "test_tool", - arguments: "{}", - }, - }, - ], - }, - finish_reason: "tool_calls", + toolCallId: "call_1", + toolName: "test_tool", + args: {}, }, ], - usage: { - prompt_tokens: 10, - completion_tokens: 5, - total_tokens: 15, - }, + usage: { inputTokens: 10, outputTokens: 5 }, + }) + + const o3Handler = new OpenAiHandler({ + ...o3Options, + openAiStreamingEnabled: false, }) const stream = o3Handler.createMessage("system", []) @@ -1059,85 +836,6 @@ describe("OpenAiHandler", () => { arguments: "{}", }) }) - - it("should use default temperature of 0 when not specified for O3 models", async () => { - const o3Handler = new OpenAiHandler({ - ...o3Options, - // No modelTemperature specified - }) - const systemPrompt = "You are a helpful assistant." - const messages: Anthropic.Messages.MessageParam[] = [ - { - role: "user", - content: "Hello!", - }, - ] - - const stream = o3Handler.createMessage(systemPrompt, messages) - await stream.next() - - expect(mockCreate).toHaveBeenCalledWith( - expect.objectContaining({ - temperature: undefined, // Temperature is not supported for O3 models - }), - {}, - ) - }) - - it("should handle O3 model with Azure AI Inference Service respecting includeMaxTokens", async () => { - const o3AzureHandler = new OpenAiHandler({ - ...o3Options, - openAiBaseUrl: "https://test.services.ai.azure.com", - includeMaxTokens: false, // Should NOT include max_tokens - }) - const systemPrompt = "You are a helpful assistant." 
- const messages: Anthropic.Messages.MessageParam[] = [ - { - role: "user", - content: "Hello!", - }, - ] - - const stream = o3AzureHandler.createMessage(systemPrompt, messages) - await stream.next() - - expect(mockCreate).toHaveBeenCalledWith( - expect.objectContaining({ - model: "o3-mini", - }), - { path: "/models/chat/completions" }, - ) - - // Verify max_tokens is NOT included when includeMaxTokens is false - const callArgs = mockCreate.mock.calls[0][0] - expect(callArgs).not.toHaveProperty("max_completion_tokens") - }) - - it("should NOT include max_tokens for O3 model with Azure AI Inference Service even when includeMaxTokens is true", async () => { - const o3AzureHandler = new OpenAiHandler({ - ...o3Options, - openAiBaseUrl: "https://test.services.ai.azure.com", - includeMaxTokens: true, // Should include max_tokens - }) - const systemPrompt = "You are a helpful assistant." - const messages: Anthropic.Messages.MessageParam[] = [ - { - role: "user", - content: "Hello!", - }, - ] - - const stream = o3AzureHandler.createMessage(systemPrompt, messages) - await stream.next() - - expect(mockCreate).toHaveBeenCalledWith( - expect.objectContaining({ - model: "o3-mini", - // O3 models do not support max_tokens - }), - { path: "/models/chat/completions" }, - ) - }) }) }) diff --git a/src/api/providers/openai.ts b/src/api/providers/openai.ts index 33b29abcafe..ab3f7f33070 100644 --- a/src/api/providers/openai.ts +++ b/src/api/providers/openai.ts @@ -1,5 +1,8 @@ import { Anthropic } from "@anthropic-ai/sdk" -import OpenAI, { AzureOpenAI } from "openai" +import { createOpenAI } from "@ai-sdk/openai" +import { createOpenAICompatible } from "@ai-sdk/openai-compatible" +import { createAzure } from "@ai-sdk/azure" +import { streamText, generateText, ToolSet, LanguageModel } from "ai" import axios from "axios" import { @@ -7,31 +10,35 @@ import { azureOpenAiDefaultApiVersion, openAiModelInfoSaneDefaults, DEEP_SEEK_DEFAULT_TEMPERATURE, - OPENAI_AZURE_AI_INFERENCE_PATH, } from "@roo-code/types" import type { ApiHandlerOptions } from "../../shared/api" import { TagMatcher } from "../../utils/tag-matcher" -import { convertToOpenAiMessages } from "../transform/openai-format" -import { convertToR1Format } from "../transform/r1-format" +import { + convertToAiSdkMessages, + convertToolsForAiSdk, + processAiSdkStreamPart, + mapToolChoice, + handleAiSdkError, +} from "../transform/ai-sdk" import { ApiStream, ApiStreamUsageChunk } from "../transform/stream" import { getModelParams } from "../transform/model-params" import { DEFAULT_HEADERS } from "./constants" import { BaseProvider } from "./base-provider" import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" -import { getApiRequestTimeout } from "./utils/timeout-config" -import { handleOpenAIError } from "./utils/openai-error-handler" // TODO: Rename this to OpenAICompatibleHandler. Also, I think the // `OpenAINativeHandler` can subclass from this, since it's obviously // compatible with the OpenAI API. We can also rename it to `OpenAIHandler`. 
export class OpenAiHandler extends BaseProvider implements SingleCompletionHandler { protected options: ApiHandlerOptions - protected client: OpenAI private readonly providerName = "OpenAI" + private readonly isAzureAiInference: boolean + private readonly isAzureOpenAi: boolean + private readonly languageModelFactory: (modelId: string) => LanguageModel constructor(options: ApiHandlerOptions) { super() @@ -39,243 +46,245 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl const baseURL = this.options.openAiBaseUrl || "https://api.openai.com/v1" const apiKey = this.options.openAiApiKey ?? "not-provided" - const isAzureAiInference = this._isAzureAiInference(this.options.openAiBaseUrl) - const urlHost = this._getUrlHost(this.options.openAiBaseUrl) - const isAzureOpenAi = urlHost === "azure.com" || urlHost.endsWith(".azure.com") || options.openAiUseAzure + this.isAzureAiInference = this._isAzureAiInference(baseURL) + const urlHost = this._getUrlHost(baseURL) + this.isAzureOpenAi = + !this.isAzureAiInference && + (urlHost === "azure.com" || urlHost.endsWith(".azure.com") || !!options.openAiUseAzure) const headers = { ...DEFAULT_HEADERS, ...(this.options.openAiHeaders || {}), } - const timeout = getApiRequestTimeout() - - if (isAzureAiInference) { - // Azure AI Inference Service (e.g., for DeepSeek) uses a different path structure - this.client = new OpenAI({ - baseURL, + if (this.isAzureAiInference) { + const provider = createOpenAICompatible({ + name: "OpenAI", + baseURL: `${baseURL}/models`, apiKey, - defaultHeaders: headers, - defaultQuery: { "api-version": this.options.azureApiVersion || "2024-05-01-preview" }, - timeout, + headers, + queryParams: { "api-version": this.options.azureApiVersion || "2024-05-01-preview" }, }) - } else if (isAzureOpenAi) { - // Azure API shape slightly differs from the core API shape: - // https://github.com/openai/openai-node?tab=readme-ov-file#microsoft-azure-openai - this.client = new AzureOpenAI({ - baseURL, + this.languageModelFactory = (modelId: string) => provider(modelId) + } else if (this.isAzureOpenAi) { + const azureBaseURL = baseURL.endsWith("/openai") ? baseURL : `${baseURL}/openai` + const provider = createAzure({ + baseURL: azureBaseURL, apiKey, apiVersion: this.options.azureApiVersion || azureOpenAiDefaultApiVersion, - defaultHeaders: headers, - timeout, + headers, + useDeploymentBasedUrls: true, }) + this.languageModelFactory = (modelId: string) => provider.chat(modelId) } else { - this.client = new OpenAI({ + const provider = createOpenAI({ baseURL, apiKey, - defaultHeaders: headers, - timeout, + headers, }) + this.languageModelFactory = (modelId: string) => provider.chat(modelId) } } + protected getLanguageModel(): LanguageModel { + const { id } = this.getModel() + return this.languageModelFactory(id) + } + override async *createMessage( systemPrompt: string, messages: Anthropic.Messages.MessageParam[], metadata?: ApiHandlerCreateMessageMetadata, ): ApiStream { - const { info: modelInfo, reasoning } = this.getModel() - const modelUrl = this.options.openAiBaseUrl ?? "" + const { info: modelInfo, temperature, reasoning } = this.getModel() const modelId = this.options.openAiModelId ?? "" const enabledR1Format = this.options.openAiR1FormatEnabled ?? 
false - const isAzureAiInference = this._isAzureAiInference(modelUrl) const deepseekReasoner = modelId.includes("deepseek-reasoner") || enabledR1Format + const isO3Family = modelId.includes("o1") || modelId.includes("o3") || modelId.includes("o4") - if (modelId.includes("o1") || modelId.includes("o3") || modelId.includes("o4")) { - yield* this.handleO3FamilyMessage(modelId, systemPrompt, messages, metadata) - return - } + const languageModel = this.getLanguageModel() - let systemMessage: OpenAI.Chat.ChatCompletionSystemMessageParam = { - role: "system", - content: systemPrompt, - } + const aiSdkMessages = convertToAiSdkMessages(messages) - if (this.options.openAiStreamingEnabled ?? true) { - let convertedMessages - - if (deepseekReasoner) { - convertedMessages = convertToR1Format([{ role: "user", content: systemPrompt }, ...messages]) - } else { - if (modelInfo.supportsPromptCache) { - systemMessage = { - role: "system", - content: [ - { - type: "text", - text: systemPrompt, - // @ts-ignore-next-line - cache_control: { type: "ephemeral" }, - }, - ], - } - } + const openAiTools = this.convertToolsForOpenAI(metadata?.tools) + const aiSdkTools = convertToolsForAiSdk(openAiTools) as ToolSet | undefined - convertedMessages = [systemMessage, ...convertToOpenAiMessages(messages)] + let effectiveSystemPrompt: string | undefined = systemPrompt + let effectiveTemperature: number | undefined = + this.options.modelTemperature ?? (deepseekReasoner ? DEEP_SEEK_DEFAULT_TEMPERATURE : (temperature ?? 0)) - if (modelInfo.supportsPromptCache) { - // Note: the following logic is copied from openrouter: - // Add cache_control to the last two user messages - // (note: this works because we only ever add one user message at a time, but if we added multiple we'd need to mark the user message before the last assistant message) - const lastTwoUserMessages = convertedMessages.filter((msg) => msg.role === "user").slice(-2) + const providerOptions: Record = {} - lastTwoUserMessages.forEach((msg) => { - if (typeof msg.content === "string") { - msg.content = [{ type: "text", text: msg.content }] - } + if (isO3Family) { + effectiveSystemPrompt = `Formatting re-enabled\n${systemPrompt}` + effectiveTemperature = undefined - if (Array.isArray(msg.content)) { - // NOTE: this is fine since env details will always be added at the end. but if it weren't there, and the user added a image_url type message, it would pop a text part before it and then move it after to the end. - let lastTextPart = msg.content.filter((part) => part.type === "text").pop() - - if (!lastTextPart) { - lastTextPart = { type: "text", text: "..." } - msg.content.push(lastTextPart) - } - - // @ts-ignore-next-line - lastTextPart["cache_control"] = { type: "ephemeral" } - } - }) - } + const openaiOpts: Record = { + systemMessageMode: "developer", + parallelToolCalls: metadata?.parallelToolCalls ?? true, } - const isGrokXAI = this._isGrokXAI(this.options.openAiBaseUrl) - - const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = { - model: modelId, - temperature: this.options.modelTemperature ?? (deepseekReasoner ? DEEP_SEEK_DEFAULT_TEMPERATURE : 0), - messages: convertedMessages, - stream: true as const, - ...(isGrokXAI ? {} : { stream_options: { include_usage: true } }), - ...(reasoning && reasoning), - tools: this.convertToolsForOpenAI(metadata?.tools), - tool_choice: metadata?.tool_choice, - parallel_tool_calls: metadata?.parallelToolCalls ?? 
true, + const effort = modelInfo.reasoningEffort as string | undefined + if (effort) { + openaiOpts.reasoningEffort = effort } - // Add max_tokens if needed - this.addMaxTokensIfNeeded(requestOptions, modelInfo) - - let stream - try { - stream = await this.client.chat.completions.create( - requestOptions, - isAzureAiInference ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {}, - ) - } catch (error) { - throw handleOpenAIError(error, this.providerName) + providerOptions.openai = openaiOpts + } else if (reasoning?.reasoning_effort) { + providerOptions.openai = { + reasoningEffort: reasoning.reasoning_effort, + parallelToolCalls: metadata?.parallelToolCalls ?? true, } + } + + if (deepseekReasoner) { + effectiveSystemPrompt = undefined + aiSdkMessages.unshift({ role: "user", content: systemPrompt }) + } - const matcher = new TagMatcher( - "think", - (chunk) => - ({ - type: chunk.matched ? "reasoning" : "text", - text: chunk.data, - }) as const, + if (this.options.openAiStreamingEnabled ?? true) { + yield* this.handleStreaming( + languageModel, + effectiveSystemPrompt, + aiSdkMessages, + effectiveTemperature, + aiSdkTools, + metadata, + providerOptions, + modelInfo, + ) + } else { + yield* this.handleNonStreaming( + languageModel, + effectiveSystemPrompt, + aiSdkMessages, + effectiveTemperature, + aiSdkTools, + metadata, + providerOptions, + modelInfo, ) + } + } - let lastUsage - const activeToolCallIds = new Set() + private async *handleStreaming( + languageModel: LanguageModel, + systemPrompt: string | undefined, + messages: ReturnType, + temperature: number | undefined, + tools: ToolSet | undefined, + metadata: ApiHandlerCreateMessageMetadata | undefined, + providerOptions: Record, + modelInfo: ModelInfo, + ): ApiStream { + const result = streamText({ + model: languageModel, + system: systemPrompt, + messages, + temperature, + maxOutputTokens: this.getMaxOutputTokens(), + tools, + toolChoice: mapToolChoice(metadata?.tool_choice), + providerOptions: Object.keys(providerOptions).length > 0 ? providerOptions : undefined, + }) - for await (const chunk of stream) { - const delta = chunk.choices?.[0]?.delta ?? {} - const finishReason = chunk.choices?.[0]?.finish_reason + const matcher = new TagMatcher( + "think", + (chunk) => + ({ + type: chunk.matched ? "reasoning" : "text", + text: chunk.data, + }) as const, + ) - if (delta.content) { - for (const chunk of matcher.update(delta.content)) { + try { + for await (const part of result.fullStream) { + for (const chunk of processAiSdkStreamPart(part)) { + if (chunk.type === "text") { + for (const matchedChunk of matcher.update(chunk.text)) { + yield matchedChunk + } + } else { yield chunk } } - - if ("reasoning_content" in delta && delta.reasoning_content) { - yield { - type: "reasoning", - text: (delta.reasoning_content as string | undefined) || "", - } - } - - yield* this.processToolCalls(delta, finishReason, activeToolCallIds) - - if (chunk.usage) { - lastUsage = chunk.usage - } } for (const chunk of matcher.final()) { yield chunk } - if (lastUsage) { - yield this.processUsageMetrics(lastUsage, modelInfo) - } - } else { - const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = { - model: modelId, - messages: deepseekReasoner - ? 
convertToR1Format([{ role: "user", content: systemPrompt }, ...messages]) - : [systemMessage, ...convertToOpenAiMessages(messages)], - // Tools are always present (minimum ALWAYS_AVAILABLE_TOOLS) - tools: this.convertToolsForOpenAI(metadata?.tools), - tool_choice: metadata?.tool_choice, - parallel_tool_calls: metadata?.parallelToolCalls ?? true, - } - - // Add max_tokens if needed - this.addMaxTokensIfNeeded(requestOptions, modelInfo) - - let response - try { - response = await this.client.chat.completions.create( - requestOptions, - this._isAzureAiInference(modelUrl) ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {}, - ) - } catch (error) { - throw handleOpenAIError(error, this.providerName) + const usage = await result.usage + if (usage) { + yield this.processUsageMetrics(usage, modelInfo) } + } catch (error) { + throw handleAiSdkError(error, this.providerName) + } + } - const message = response.choices?.[0]?.message + private async *handleNonStreaming( + languageModel: LanguageModel, + systemPrompt: string | undefined, + messages: ReturnType, + temperature: number | undefined, + tools: ToolSet | undefined, + metadata: ApiHandlerCreateMessageMetadata | undefined, + providerOptions: Record, + modelInfo: ModelInfo, + ): ApiStream { + try { + const { text, toolCalls, usage } = await generateText({ + model: languageModel, + system: systemPrompt, + messages, + temperature, + maxOutputTokens: this.getMaxOutputTokens(), + tools, + toolChoice: mapToolChoice(metadata?.tool_choice), + providerOptions: Object.keys(providerOptions).length > 0 ? providerOptions : undefined, + }) - if (message?.tool_calls) { - for (const toolCall of message.tool_calls) { - if (toolCall.type === "function") { - yield { - type: "tool_call", - id: toolCall.id, - name: toolCall.function.name, - arguments: toolCall.function.arguments, - } + if (toolCalls && toolCalls.length > 0) { + for (const toolCall of toolCalls) { + yield { + type: "tool_call", + id: toolCall.toolCallId, + name: toolCall.toolName, + arguments: JSON.stringify((toolCall as any).args), } } } yield { type: "text", - text: message?.content || "", + text: text || "", } - yield this.processUsageMetrics(response.usage, modelInfo) + if (usage) { + yield this.processUsageMetrics(usage, modelInfo) + } + } catch (error) { + throw handleAiSdkError(error, this.providerName) } } - protected processUsageMetrics(usage: any, _modelInfo?: ModelInfo): ApiStreamUsageChunk { + protected processUsageMetrics( + usage: { + inputTokens?: number + outputTokens?: number + details?: { + cachedInputTokens?: number + reasoningTokens?: number + } + }, + _modelInfo?: ModelInfo, + ): ApiStreamUsageChunk { return { type: "usage", - inputTokens: usage?.prompt_tokens || 0, - outputTokens: usage?.completion_tokens || 0, - cacheWriteTokens: usage?.cache_creation_input_tokens || undefined, - cacheReadTokens: usage?.cache_read_input_tokens || undefined, + inputTokens: usage.inputTokens || 0, + outputTokens: usage.outputTokens || 0, + cacheReadTokens: usage.details?.cachedInputTokens, } } @@ -292,208 +301,37 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl return { id, info, ...params } } + protected getMaxOutputTokens(): number | undefined { + if (this.options.includeMaxTokens !== true) { + return undefined + } + const { info } = this.getModel() + return this.options.modelMaxTokens || info.maxTokens || undefined + } + async completePrompt(prompt: string): Promise { try { - const isAzureAiInference = this._isAzureAiInference(this.options.openAiBaseUrl) - const 
model = this.getModel() - const modelInfo = model.info - - const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = { - model: model.id, - messages: [{ role: "user", content: prompt }], - } - - // Add max_tokens if needed - this.addMaxTokensIfNeeded(requestOptions, modelInfo) - - let response - try { - response = await this.client.chat.completions.create( - requestOptions, - isAzureAiInference ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {}, - ) - } catch (error) { - throw handleOpenAIError(error, this.providerName) - } + const { temperature } = this.getModel() + const languageModel = this.getLanguageModel() + + const { text } = await generateText({ + model: languageModel, + prompt, + maxOutputTokens: this.getMaxOutputTokens(), + temperature: this.options.modelTemperature ?? temperature ?? 0, + }) - return response.choices?.[0]?.message.content || "" + return text } catch (error) { if (error instanceof Error) { throw new Error(`${this.providerName} completion error: ${error.message}`) } - throw error } } - private async *handleO3FamilyMessage( - modelId: string, - systemPrompt: string, - messages: Anthropic.Messages.MessageParam[], - metadata?: ApiHandlerCreateMessageMetadata, - ): ApiStream { - const modelInfo = this.getModel().info - const methodIsAzureAiInference = this._isAzureAiInference(this.options.openAiBaseUrl) - - if (this.options.openAiStreamingEnabled ?? true) { - const isGrokXAI = this._isGrokXAI(this.options.openAiBaseUrl) - - const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = { - model: modelId, - messages: [ - { - role: "developer", - content: `Formatting re-enabled\n${systemPrompt}`, - }, - ...convertToOpenAiMessages(messages), - ], - stream: true, - ...(isGrokXAI ? {} : { stream_options: { include_usage: true } }), - reasoning_effort: modelInfo.reasoningEffort as "low" | "medium" | "high" | undefined, - temperature: undefined, - // Tools are always present (minimum ALWAYS_AVAILABLE_TOOLS) - tools: this.convertToolsForOpenAI(metadata?.tools), - tool_choice: metadata?.tool_choice, - parallel_tool_calls: metadata?.parallelToolCalls ?? true, - } - - // O3 family models do not support the deprecated max_tokens parameter - // but they do support max_completion_tokens (the modern OpenAI parameter) - // This allows O3 models to limit response length when includeMaxTokens is enabled - this.addMaxTokensIfNeeded(requestOptions, modelInfo) - - let stream - try { - stream = await this.client.chat.completions.create( - requestOptions, - methodIsAzureAiInference ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {}, - ) - } catch (error) { - throw handleOpenAIError(error, this.providerName) - } - - yield* this.handleStreamResponse(stream) - } else { - const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = { - model: modelId, - messages: [ - { - role: "developer", - content: `Formatting re-enabled\n${systemPrompt}`, - }, - ...convertToOpenAiMessages(messages), - ], - reasoning_effort: modelInfo.reasoningEffort as "low" | "medium" | "high" | undefined, - temperature: undefined, - // Tools are always present (minimum ALWAYS_AVAILABLE_TOOLS) - tools: this.convertToolsForOpenAI(metadata?.tools), - tool_choice: metadata?.tool_choice, - parallel_tool_calls: metadata?.parallelToolCalls ?? 
true, - } - - // O3 family models do not support the deprecated max_tokens parameter - // but they do support max_completion_tokens (the modern OpenAI parameter) - // This allows O3 models to limit response length when includeMaxTokens is enabled - this.addMaxTokensIfNeeded(requestOptions, modelInfo) - - let response - try { - response = await this.client.chat.completions.create( - requestOptions, - methodIsAzureAiInference ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {}, - ) - } catch (error) { - throw handleOpenAIError(error, this.providerName) - } - - const message = response.choices?.[0]?.message - if (message?.tool_calls) { - for (const toolCall of message.tool_calls) { - if (toolCall.type === "function") { - yield { - type: "tool_call", - id: toolCall.id, - name: toolCall.function.name, - arguments: toolCall.function.arguments, - } - } - } - } - - yield { - type: "text", - text: message?.content || "", - } - yield this.processUsageMetrics(response.usage) - } - } - - private async *handleStreamResponse(stream: AsyncIterable): ApiStream { - const activeToolCallIds = new Set() - - for await (const chunk of stream) { - const delta = chunk.choices?.[0]?.delta - const finishReason = chunk.choices?.[0]?.finish_reason - - if (delta) { - if (delta.content) { - yield { - type: "text", - text: delta.content, - } - } - - yield* this.processToolCalls(delta, finishReason, activeToolCallIds) - } - - if (chunk.usage) { - yield { - type: "usage", - inputTokens: chunk.usage.prompt_tokens || 0, - outputTokens: chunk.usage.completion_tokens || 0, - } - } - } - } - - /** - * Helper generator to process tool calls from a stream chunk. - * Tracks active tool call IDs and yields tool_call_partial and tool_call_end events. - * @param delta - The delta object from the stream chunk - * @param finishReason - The finish_reason from the stream chunk - * @param activeToolCallIds - Set to track active tool call IDs (mutated in place) - */ - private *processToolCalls( - delta: OpenAI.Chat.Completions.ChatCompletionChunk.Choice.Delta | undefined, - finishReason: string | null | undefined, - activeToolCallIds: Set, - ): Generator< - | { type: "tool_call_partial"; index: number; id?: string; name?: string; arguments?: string } - | { type: "tool_call_end"; id: string } - > { - if (delta?.tool_calls) { - for (const toolCall of delta.tool_calls) { - if (toolCall.id) { - activeToolCallIds.add(toolCall.id) - } - yield { - type: "tool_call_partial", - index: toolCall.index, - id: toolCall.id, - name: toolCall.function?.name, - arguments: toolCall.function?.arguments, - } - } - } - - // Emit tool_call_end events when finish_reason is "tool_calls" - // This ensures tool calls are finalized even if the stream doesn't properly close - if (finishReason === "tool_calls" && activeToolCallIds.size > 0) { - for (const id of activeToolCallIds) { - yield { type: "tool_call_end", id } - } - activeToolCallIds.clear() - } + override isAiSdkProvider(): boolean { + return true } protected _getUrlHost(baseUrl?: string): string { @@ -504,34 +342,10 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl } } - private _isGrokXAI(baseUrl?: string): boolean { - const urlHost = this._getUrlHost(baseUrl) - return urlHost.includes("x.ai") - } - protected _isAzureAiInference(baseUrl?: string): boolean { const urlHost = this._getUrlHost(baseUrl) return urlHost.endsWith(".services.ai.azure.com") } - - /** - * Adds max_completion_tokens to the request body if needed based on provider configuration - * Note: max_tokens is 
deprecated in favor of max_completion_tokens as per OpenAI documentation - * O3 family models handle max_tokens separately in handleO3FamilyMessage - */ - protected addMaxTokensIfNeeded( - requestOptions: - | OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming - | OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming, - modelInfo: ModelInfo, - ): void { - // Only add max_completion_tokens if includeMaxTokens is true - if (this.options.includeMaxTokens === true) { - // Use user-configured modelMaxTokens if available, otherwise fall back to model's default maxTokens - // Using max_completion_tokens as max_tokens is deprecated - requestOptions.max_completion_tokens = this.options.modelMaxTokens || modelInfo.maxTokens - } - } } export async function getOpenAiModels(baseUrl?: string, apiKey?: string, openAiHeaders?: Record) { From 8372f41883472a29c26104aa41643dfc47310d34 Mon Sep 17 00:00:00 2001 From: Roo Code Date: Mon, 9 Feb 2026 23:33:33 +0000 Subject: [PATCH 2/2] fix: include reasoningTokens and providerMetadata in OpenAiHandler usage metrics --- .../__tests__/openai-usage-tracking.spec.ts | 73 +++++++++++++++++++ src/api/providers/openai.ts | 21 +++++- 2 files changed, 90 insertions(+), 4 deletions(-) diff --git a/src/api/providers/__tests__/openai-usage-tracking.spec.ts b/src/api/providers/__tests__/openai-usage-tracking.spec.ts index 151d71a595e..042c411f388 100644 --- a/src/api/providers/__tests__/openai-usage-tracking.spec.ts +++ b/src/api/providers/__tests__/openai-usage-tracking.spec.ts @@ -155,5 +155,78 @@ describe("OpenAiHandler with usage tracking fix", () => { const usageChunks = chunks.filter((chunk) => chunk.type === "usage") expect(usageChunks).toHaveLength(0) }) + + it("should include reasoningTokens from usage.details", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Test response" } + } + + mockStreamText.mockReturnValueOnce({ + fullStream: mockFullStream(), + usage: Promise.resolve({ + inputTokens: 10, + outputTokens: 5, + details: { + reasoningTokens: 3, + }, + }), + providerMetadata: Promise.resolve(undefined), + }) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + const usageChunks = chunks.filter((chunk) => chunk.type === "usage") + expect(usageChunks).toHaveLength(1) + expect(usageChunks[0]).toEqual( + expect.objectContaining({ + type: "usage", + inputTokens: 10, + outputTokens: 5, + reasoningTokens: 3, + }), + ) + }) + + it("should extract cache and reasoning tokens from providerMetadata", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Test response" } + } + + mockStreamText.mockReturnValueOnce({ + fullStream: mockFullStream(), + usage: Promise.resolve({ + inputTokens: 100, + outputTokens: 50, + }), + providerMetadata: Promise.resolve({ + openai: { + cachedPromptTokens: 80, + reasoningTokens: 20, + }, + }), + }) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + const usageChunks = chunks.filter((chunk) => chunk.type === "usage") + expect(usageChunks).toHaveLength(1) + expect(usageChunks[0]).toEqual( + expect.objectContaining({ + type: "usage", + inputTokens: 100, + outputTokens: 50, + cacheReadTokens: 80, + reasoningTokens: 20, + }), + ) + }) }) }) diff --git a/src/api/providers/openai.ts b/src/api/providers/openai.ts index ab3f7f33070..f4b44c519c9 100644 
--- a/src/api/providers/openai.ts +++ b/src/api/providers/openai.ts @@ -215,8 +215,9 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl } const usage = await result.usage + const providerMetadata = await result.providerMetadata if (usage) { - yield this.processUsageMetrics(usage, modelInfo) + yield this.processUsageMetrics(usage, modelInfo, providerMetadata as any) } } catch (error) { throw handleAiSdkError(error, this.providerName) @@ -234,7 +235,7 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl modelInfo: ModelInfo, ): ApiStream { try { - const { text, toolCalls, usage } = await generateText({ + const { text, toolCalls, usage, providerMetadata } = await generateText({ model: languageModel, system: systemPrompt, messages, @@ -262,7 +263,7 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl } if (usage) { - yield this.processUsageMetrics(usage, modelInfo) + yield this.processUsageMetrics(usage, modelInfo, providerMetadata as any) } } catch (error) { throw handleAiSdkError(error, this.providerName) @@ -279,12 +280,24 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl } }, _modelInfo?: ModelInfo, + providerMetadata?: { + openai?: { + cachedPromptTokens?: number + reasoningTokens?: number + } + }, ): ApiStreamUsageChunk { + // Extract cache and reasoning metrics from OpenAI's providerMetadata when available, + // falling back to usage.details for standard AI SDK fields. + const cacheReadTokens = providerMetadata?.openai?.cachedPromptTokens ?? usage.details?.cachedInputTokens + const reasoningTokens = providerMetadata?.openai?.reasoningTokens ?? usage.details?.reasoningTokens + return { type: "usage", inputTokens: usage.inputTokens || 0, outputTokens: usage.outputTokens || 0, - cacheReadTokens: usage.details?.cachedInputTokens, + cacheReadTokens, + reasoningTokens, } }
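
To summarize the precedence this follow-up commit introduces for usage metrics, here is a self-contained sketch that mirrors processUsageMetrics() above: provider-reported values under providerMetadata.openai win, with the generic usage.details fields as fallback. The field names (cachedPromptTokens, reasoningTokens, cachedInputTokens) are the ones this patch itself assumes; they are not independently verified against any particular @ai-sdk/openai release.

// Illustrative sketch only -- not part of the diff. Field shapes follow this patch's
// assumptions about what the AI SDK and @ai-sdk/openai report.
type OpenAiProviderMetadata = {
	openai?: { cachedPromptTokens?: number; reasoningTokens?: number }
}

type SdkUsage = {
	inputTokens?: number
	outputTokens?: number
	details?: { cachedInputTokens?: number; reasoningTokens?: number }
}

function toUsageChunk(usage: SdkUsage, meta?: OpenAiProviderMetadata) {
	return {
		type: "usage" as const,
		inputTokens: usage.inputTokens || 0,
		outputTokens: usage.outputTokens || 0,
		// Provider-specific metadata takes priority; usage.details is the fallback.
		cacheReadTokens: meta?.openai?.cachedPromptTokens ?? usage.details?.cachedInputTokens,
		reasoningTokens: meta?.openai?.reasoningTokens ?? usage.details?.reasoningTokens,
	}
}

// Matches the expectations in openai-usage-tracking.spec.ts:
// toUsageChunk({ inputTokens: 100, outputTokens: 50 }, { openai: { cachedPromptTokens: 80, reasoningTokens: 20 } })
//   -> { type: "usage", inputTokens: 100, outputTokens: 50, cacheReadTokens: 80, reasoningTokens: 20 }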