diff --git a/src/api/providers/__tests__/base-openai-compatible-provider-timeout.spec.ts b/src/api/providers/__tests__/base-openai-compatible-provider-timeout.spec.ts
deleted file mode 100644
index baa7ae953bc..00000000000
--- a/src/api/providers/__tests__/base-openai-compatible-provider-timeout.spec.ts
+++ /dev/null
@@ -1,119 +0,0 @@
-// npx vitest run api/providers/__tests__/base-openai-compatible-provider-timeout.spec.ts
-
-import type { ModelInfo } from "@roo-code/types"
-
-import { BaseOpenAiCompatibleProvider } from "../base-openai-compatible-provider"
-
-// Mock the timeout config utility
-vitest.mock("../utils/timeout-config", () => ({
-	getApiRequestTimeout: vitest.fn(),
-}))
-
-import { getApiRequestTimeout } from "../utils/timeout-config"
-
-// Mock OpenAI and capture constructor calls
-const mockOpenAIConstructor = vitest.fn()
-
-vitest.mock("openai", () => {
-	return {
-		__esModule: true,
-		default: vitest.fn().mockImplementation((config) => {
-			mockOpenAIConstructor(config)
-			return {
-				chat: {
-					completions: {
-						create: vitest.fn(),
-					},
-				},
-			}
-		}),
-	}
-})
-
-// Create a concrete test implementation of the abstract base class
-class TestOpenAiCompatibleProvider extends BaseOpenAiCompatibleProvider<"test-model"> {
-	constructor(apiKey: string) {
-		const testModels: Record<"test-model", ModelInfo> = {
-			"test-model": {
-				maxTokens: 4096,
-				contextWindow: 128000,
-				supportsImages: false,
-				supportsPromptCache: false,
-				inputPrice: 0.5,
-				outputPrice: 1.5,
-			},
-		}
-
-		super({
-			providerName: "TestProvider",
-			baseURL: "https://test.example.com/v1",
-			defaultProviderModelId: "test-model",
-			providerModels: testModels,
-			apiKey,
-		})
-	}
-}
-
-describe("BaseOpenAiCompatibleProvider Timeout Configuration", () => {
-	beforeEach(() => {
-		vitest.clearAllMocks()
-	})
-
-	it("should call getApiRequestTimeout when creating the provider", () => {
-		;(getApiRequestTimeout as any).mockReturnValue(600000)
-
-		new TestOpenAiCompatibleProvider("test-api-key")
-
-		expect(getApiRequestTimeout).toHaveBeenCalled()
-	})
-
-	it("should pass the default timeout to the OpenAI client constructor", () => {
-		;(getApiRequestTimeout as any).mockReturnValue(600000) // 600 seconds in ms
-
-		new TestOpenAiCompatibleProvider("test-api-key")
-
-		expect(mockOpenAIConstructor).toHaveBeenCalledWith(
-			expect.objectContaining({
-				baseURL: "https://test.example.com/v1",
-				apiKey: "test-api-key",
-				timeout: 600000,
-			}),
-		)
-	})
-
-	it("should use custom timeout value from getApiRequestTimeout", () => {
-		;(getApiRequestTimeout as any).mockReturnValue(1800000) // 30 minutes in ms
-
-		new TestOpenAiCompatibleProvider("test-api-key")
-
-		expect(mockOpenAIConstructor).toHaveBeenCalledWith(
-			expect.objectContaining({
-				timeout: 1800000,
-			}),
-		)
-	})
-
-	it("should handle zero timeout (no timeout)", () => {
-		;(getApiRequestTimeout as any).mockReturnValue(0)
-
-		new TestOpenAiCompatibleProvider("test-api-key")
-
-		expect(mockOpenAIConstructor).toHaveBeenCalledWith(
-			expect.objectContaining({
-				timeout: 0,
-			}),
-		)
-	})
-
-	it("should pass DEFAULT_HEADERS to the OpenAI client constructor", () => {
-		;(getApiRequestTimeout as any).mockReturnValue(600000)
-
-		new TestOpenAiCompatibleProvider("test-api-key")
-
-		expect(mockOpenAIConstructor).toHaveBeenCalledWith(
-			expect.objectContaining({
-				defaultHeaders: expect.any(Object),
-			}),
-		)
-	})
-})
diff --git a/src/api/providers/__tests__/base-openai-compatible-provider.spec.ts b/src/api/providers/__tests__/base-openai-compatible-provider.spec.ts
deleted file mode 100644
index 6f8d121e69e..00000000000
--- a/src/api/providers/__tests__/base-openai-compatible-provider.spec.ts
+++ /dev/null
@@ -1,548 +0,0 @@
-// npx vitest run api/providers/__tests__/base-openai-compatible-provider.spec.ts
-
-import { Anthropic } from "@anthropic-ai/sdk"
-import OpenAI from "openai"
-
-import type { ModelInfo } from "@roo-code/types"
-
-import { BaseOpenAiCompatibleProvider } from "../base-openai-compatible-provider"
-
-// Create mock functions
-const mockCreate = vi.fn()
-
-// Mock OpenAI module
-vi.mock("openai", () => ({
-	default: vi.fn(() => ({
-		chat: {
-			completions: {
-				create: mockCreate,
-			},
-		},
-	})),
-}))
-
-// Create a concrete test implementation of the abstract base class
-class TestOpenAiCompatibleProvider extends BaseOpenAiCompatibleProvider<"test-model"> {
-	constructor(apiKey: string) {
-		const testModels: Record<"test-model", ModelInfo> = {
-			"test-model": {
-				maxTokens: 4096,
-				contextWindow: 128000,
-				supportsImages: false,
-				supportsPromptCache: false,
-				inputPrice: 0.5,
-				outputPrice: 1.5,
-			},
-		}
-
-		super({
-			providerName: "TestProvider",
-			baseURL: "https://test.example.com/v1",
-			defaultProviderModelId: "test-model",
-			providerModels: testModels,
-			apiKey,
-		})
-	}
-}
-
-describe("BaseOpenAiCompatibleProvider", () => {
-	let handler: TestOpenAiCompatibleProvider
-
-	beforeEach(() => {
-		vi.clearAllMocks()
-		handler = new TestOpenAiCompatibleProvider("test-api-key")
-	})
-
-	afterEach(() => {
-		vi.restoreAllMocks()
-	})
-
-	describe("TagMatcher reasoning tags", () => {
-		it("should handle reasoning tags (<think>) from stream", async () => {
-			mockCreate.mockImplementationOnce(() => {
-				return {
-					[Symbol.asyncIterator]: () => ({
-						next: vi
-							.fn()
-							.mockResolvedValueOnce({
-								done: false,
-								value: { choices: [{ delta: { content: "<think>Let me think" } }] },
-							})
-							.mockResolvedValueOnce({
-								done: false,
-								value: { choices: [{ delta: { content: " about this" } }] },
-							})
-							.mockResolvedValueOnce({
-								done: false,
-								value: { choices: [{ delta: { content: "</think>The answer is 42" } }] },
-							})
-							.mockResolvedValueOnce({ done: true }),
-					}),
-				}
-			})
-
-			const stream = handler.createMessage("system prompt", [])
-			const chunks = []
-			for await (const chunk of stream) {
-				chunks.push(chunk)
-			}
-
-			// TagMatcher yields chunks as they're processed
-			expect(chunks).toEqual([
-				{ type: "reasoning", text: "Let me think" },
-				{ type: "reasoning", text: " about this" },
-				{ type: "text", text: "The answer is 42" },
-			])
-		})
-
-		it("should handle complete tag in a single chunk", async () => {
-			mockCreate.mockImplementationOnce(() => {
-				return {
-					[Symbol.asyncIterator]: () => ({
-						next: vi
-							.fn()
-							.mockResolvedValueOnce({
-								done: false,
-								value: { choices: [{ delta: { content: "Regular text before " } }] },
-							})
-							.mockResolvedValueOnce({
-								done: false,
-								value: { choices: [{ delta: { content: "<think>Complete thought</think>" } }] },
-							})
-							.mockResolvedValueOnce({
-								done: false,
-								value: { choices: [{ delta: { content: " regular text after" } }] },
-							})
-							.mockResolvedValueOnce({ done: true }),
-					}),
-				}
-			})
-
-			const stream = handler.createMessage("system prompt", [])
-			const chunks = []
-			for await (const chunk of stream) {
-				chunks.push(chunk)
-			}
-
-			// When a complete tag arrives in one chunk, TagMatcher may not parse it
-			// This test documents the actual behavior
-			expect(chunks.length).toBeGreaterThan(0)
-			expect(chunks[0]).toEqual({ type: "text", text: "Regular text before " })
-		})
-
-		it("should handle incomplete tag at end of stream", async () => {
-			mockCreate.mockImplementationOnce(() => {
-				return {
-					[Symbol.asyncIterator]: () => ({
-						next: vi
-							.fn()
-							.mockResolvedValueOnce({
-								done: false,
-								value: { choices: [{ delta: { content: "<think>Incomplete thought" } }] },
-							})
-							.mockResolvedValueOnce({ done: true }),
-					}),
-				}
-			})
-
-			const stream = handler.createMessage("system prompt", [])
-			const chunks = []
-			for await (const chunk of stream) {
-				chunks.push(chunk)
-			}
-
-			// TagMatcher should handle incomplete tags and flush remaining content
-			expect(chunks.length).toBeGreaterThan(0)
-			expect(
-				chunks.some(
-					(c) => (c.type === "text" || c.type === "reasoning") && c.text.includes("Incomplete thought"),
-				),
-			).toBe(true)
-		})
-
-		it("should handle text without any tags", async () => {
-			mockCreate.mockImplementationOnce(() => {
-				return {
-					[Symbol.asyncIterator]: () => ({
-						next: vi
-							.fn()
-							.mockResolvedValueOnce({
-								done: false,
-								value: { choices: [{ delta: { content: "Just regular text" } }] },
-							})
-							.mockResolvedValueOnce({
-								done: false,
-								value: { choices: [{ delta: { content: " without reasoning" } }] },
-							})
-							.mockResolvedValueOnce({ done: true }),
-					}),
-				}
-			})
-
-			const stream = handler.createMessage("system prompt", [])
-			const chunks = []
-			for await (const chunk of stream) {
-				chunks.push(chunk)
-			}
-
-			expect(chunks).toEqual([
-				{ type: "text", text: "Just regular text" },
-				{ type: "text", text: " without reasoning" },
-			])
-		})
-
-		it("should handle tags that start at beginning of stream", async () => {
-			mockCreate.mockImplementationOnce(() => {
-				return {
-					[Symbol.asyncIterator]: () => ({
-						next: vi
-							.fn()
-							.mockResolvedValueOnce({
-								done: false,
-								value: { choices: [{ delta: { content: "<think>reasoning" } }] },
-							})
-							.mockResolvedValueOnce({
-								done: false,
-								value: { choices: [{ delta: { content: " content" } }] },
-							})
-							.mockResolvedValueOnce({
-								done: false,
-								value: { choices: [{ delta: { content: "</think> normal text" } }] },
-							})
-							.mockResolvedValueOnce({ done: true }),
-					}),
-				}
-			})
-
-			const stream = handler.createMessage("system prompt", [])
-			const chunks = []
-			for await (const chunk of stream) {
-				chunks.push(chunk)
-			}
-
-			expect(chunks).toEqual([
-				{ type: "reasoning", text: "reasoning" },
-				{ type: "reasoning", text: " content" },
-				{ type: "text", text: " normal text" },
-			])
-		})
-	})
-
-	describe("reasoning_content field", () => {
-		it("should filter out whitespace-only reasoning_content", async () => {
-			mockCreate.mockImplementationOnce(() => {
-				return {
-					[Symbol.asyncIterator]: () => ({
-						next: vi
-							.fn()
-							.mockResolvedValueOnce({
-								done: false,
-								value: { choices: [{ delta: { reasoning_content: "\n" } }] },
-							})
-							.mockResolvedValueOnce({
-								done: false,
-								value: { choices: [{ delta: { reasoning_content: " " } }] },
-							})
-							.mockResolvedValueOnce({
-								done: false,
-								value: { choices: [{ delta: { reasoning_content: "\t\n " } }] },
-							})
-							.mockResolvedValueOnce({
-								done: false,
-								value: { choices: [{ delta: { content: "Regular content" } }] },
-							})
-							.mockResolvedValueOnce({ done: true }),
-					}),
-				}
-			})
-
-			const stream = handler.createMessage("system prompt", [])
-			const chunks = []
-			for await (const chunk of stream) {
-				chunks.push(chunk)
-			}
-
-			// Should only have the regular content, not the whitespace-only reasoning
-			expect(chunks).toEqual([{ type: "text", text: "Regular content" }])
-		})
-
-		it("should yield non-empty reasoning_content", async () => {
-			mockCreate.mockImplementationOnce(() => {
-				return {
-					[Symbol.asyncIterator]: () => ({
-						next: vi
-							.fn()
-							.mockResolvedValueOnce({
-								done: false,
-								value: { choices: [{ delta: { reasoning_content: "Thinking step 1" } }] },
-							})
-							.mockResolvedValueOnce({
-								done: false,
-								value: { choices: [{ delta: { reasoning_content: "\n" } }] },
-							})
-							.mockResolvedValueOnce({
-								done: false,
-								value: { choices: [{ delta: { reasoning_content: "Thinking step 2" } }] },
-							})
-							.mockResolvedValueOnce({ done: true }),
-					}),
-				}
-			})
-
-			const stream = handler.createMessage("system prompt", [])
-			const chunks = []
-			for await (const chunk of stream) {
-				chunks.push(chunk)
-			}
-
-			// Should only yield the non-empty reasoning content
-			expect(chunks).toEqual([
-				{ type: "reasoning", text: "Thinking step 1" },
-				{ type: "reasoning", text: "Thinking step 2" },
-			])
-		})
-
-		it("should handle reasoning_content with leading/trailing whitespace", async () => {
-			mockCreate.mockImplementationOnce(() => {
-				return {
-					[Symbol.asyncIterator]: () => ({
-						next: vi
-							.fn()
-							.mockResolvedValueOnce({
-								done: false,
-								value: { choices: [{ delta: { reasoning_content: " content with spaces " } }] },
-							})
-							.mockResolvedValueOnce({ done: true }),
-					}),
-				}
-			})
-
-			const stream = handler.createMessage("system prompt", [])
-			const chunks = []
-			for await (const chunk of stream) {
-				chunks.push(chunk)
-			}
-
-			// Should yield reasoning with spaces (only pure whitespace is filtered)
-			expect(chunks).toEqual([{ type: "reasoning", text: " content with spaces " }])
-		})
-	})
-
-	describe("Basic functionality", () => {
-		it("should create stream with correct parameters", async () => {
-			mockCreate.mockImplementationOnce(() => {
-				return {
-					[Symbol.asyncIterator]: () => ({
-						async next() {
-							return { done: true }
-						},
-					}),
-				}
-			})
-
-			const systemPrompt = "Test system prompt"
-			const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Test message" }]
-
-			const messageGenerator = handler.createMessage(systemPrompt, messages)
-			await messageGenerator.next()
-
-			expect(mockCreate).toHaveBeenCalledWith(
-				expect.objectContaining({
-					model: "test-model",
-					temperature: 0,
-					messages: expect.arrayContaining([{ role: "system", content: systemPrompt }]),
-					stream: true,
-					stream_options: { include_usage: true },
-				}),
-				undefined,
-			)
-		})
-
-		it("should yield usage data from stream", async () => {
-			mockCreate.mockImplementationOnce(() => {
-				return {
-					[Symbol.asyncIterator]: () => ({
-						next: vi
-							.fn()
-							.mockResolvedValueOnce({
-								done: false,
-								value: {
-									choices: [{ delta: {} }],
-									usage: { prompt_tokens: 100, completion_tokens: 50 },
-								},
-							})
-							.mockResolvedValueOnce({ done: true }),
-					}),
-				}
-			})
-
-			const stream = handler.createMessage("system prompt", [])
-			const firstChunk = await stream.next()
-
-			expect(firstChunk.done).toBe(false)
-			expect(firstChunk.value).toMatchObject({ type: "usage", inputTokens: 100, outputTokens: 50 })
-		})
-	})
-
-	describe("Tool call handling", () => {
-		it("should yield tool_call_end events when finish_reason is tool_calls", async () => {
-			mockCreate.mockImplementationOnce(() => {
-				return {
-					[Symbol.asyncIterator]: () => ({
-						next: vi
-							.fn()
-							.mockResolvedValueOnce({
-								done: false,
-								value: {
-									choices: [
-										{
-											delta: {
-												tool_calls: [
-													{
-														index: 0,
-														id: "call_123",
-														function: { name: "test_tool", arguments: '{"arg":' },
-													},
-												],
-											},
-										},
-									],
-								},
-							})
-							.mockResolvedValueOnce({
-								done: false,
-								value: {
-									choices: [
-										{
-											delta: {
-												tool_calls: [
-													{
-														index: 0,
-														function: { arguments: '"value"}' },
-													},
-												],
-											},
-										},
-									],
-								},
-							})
-							.mockResolvedValueOnce({
-								done: false,
-								value: {
-									choices: [
-										{
-											delta: {},
-											finish_reason: "tool_calls",
-										},
-									],
-								},
-							})
-							.mockResolvedValueOnce({ done: true }),
-					}),
-				}
-			})
-
-			const stream = handler.createMessage("system prompt", [])
-			const chunks = []
-			for await (const chunk of stream) {
-				chunks.push(chunk)
-			}
-
-			// Should have tool_call_partial and tool_call_end
-			const partialChunks = chunks.filter((chunk) => chunk.type === "tool_call_partial")
-			const endChunks = chunks.filter((chunk) => chunk.type === "tool_call_end")
-
-			expect(partialChunks).toHaveLength(2)
-			expect(endChunks).toHaveLength(1)
-			expect(endChunks[0]).toEqual({ type: "tool_call_end", id: "call_123" })
-		})
-
-		it("should yield multiple tool_call_end events for parallel tool calls", async () => {
-			mockCreate.mockImplementationOnce(() => {
-				return {
-					[Symbol.asyncIterator]: () => ({
-						next: vi
-							.fn()
-							.mockResolvedValueOnce({
-								done: false,
-								value: {
-									choices: [
-										{
-											delta: {
-												tool_calls: [
-													{
-														index: 0,
-														id: "call_001",
-														function: { name: "tool_a", arguments: "{}" },
-													},
-													{
-														index: 1,
-														id: "call_002",
-														function: { name: "tool_b", arguments: "{}" },
-													},
-												],
-											},
-										},
-									],
-								},
-							})
-							.mockResolvedValueOnce({
-								done: false,
-								value: {
-									choices: [
-										{
-											delta: {},
-											finish_reason: "tool_calls",
-										},
-									],
-								},
-							})
-							.mockResolvedValueOnce({ done: true }),
-					}),
-				}
-			})
-
-			const stream = handler.createMessage("system prompt", [])
-			const chunks = []
-			for await (const chunk of stream) {
-				chunks.push(chunk)
-			}
-
-			const endChunks = chunks.filter((chunk) => chunk.type === "tool_call_end")
-			expect(endChunks).toHaveLength(2)
-			expect(endChunks.map((c: any) => c.id).sort()).toEqual(["call_001", "call_002"])
-		})
-
-		it("should not yield tool_call_end when finish_reason is not tool_calls", async () => {
-			mockCreate.mockImplementationOnce(() => {
-				return {
-					[Symbol.asyncIterator]: () => ({
-						next: vi
-							.fn()
-							.mockResolvedValueOnce({
-								done: false,
-								value: {
-									choices: [
-										{
-											delta: { content: "Some text response" },
-											finish_reason: "stop",
-										},
-									],
-								},
-							})
-							.mockResolvedValueOnce({ done: true }),
-					}),
-				}
-			})
-
-			const stream = handler.createMessage("system prompt", [])
-			const chunks = []
-			for await (const chunk of stream) {
-				chunks.push(chunk)
-			}
-
-			const endChunks = chunks.filter((chunk) => chunk.type === "tool_call_end")
-			expect(endChunks).toHaveLength(0)
-		})
-	})
-})
diff --git a/src/api/providers/__tests__/roo.spec.ts b/src/api/providers/__tests__/roo.spec.ts
index a6a76fe100d..9bac9b459c2 100644
--- a/src/api/providers/__tests__/roo.spec.ts
+++ b/src/api/providers/__tests__/roo.spec.ts
@@ -5,61 +5,28 @@ import { rooDefaultModelId } from "@roo-code/types"
 
 import { ApiHandlerOptions } from "../../../shared/api"
 
-// Mock OpenAI client
-const mockCreate = vitest.fn()
-
-vitest.mock("openai", () => {
-	return {
-		__esModule: true,
-		default: vitest.fn().mockImplementation(() => ({
-			chat: {
-				completions: {
-					create: mockCreate.mockImplementation(async (options) => {
-						if (!options.stream) {
-							return {
-								id: "test-completion",
-								choices: [
-									{
-										message: { role: "assistant", content: "Test response" },
-										finish_reason: "stop",
-										index: 0,
-									},
-								],
-								usage: {
-									prompt_tokens: 10,
-									completion_tokens: 5,
-									total_tokens: 15,
-								},
-							}
-						}
+// Mock the AI SDK
+const mockStreamText = vitest.fn()
+const mockGenerateText = vitest.fn()
+const mockCreateOpenAICompatible = vitest.fn()
+
+vitest.mock("ai", () => ({
+	streamText: (...args: unknown[]) => mockStreamText(...args),
+	generateText: (...args: unknown[]) => mockGenerateText(...args),
+	tool: vitest.fn((t) => t),
+	jsonSchema: vitest.fn((s) => s),
+}))
-						return {
-							[Symbol.asyncIterator]: async function* () {
-								yield {
-									choices: [{ delta: { content: "Test response" }, index: 0 }],
-									usage: null,
-								}
-								yield {
-									choices: [{ delta: {}, index: 0 }],
-									usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 },
-								}
-							},
-						}
-					}),
-				},
-			},
-		})),
-	}
-})
+vitest.mock("@ai-sdk/openai-compatible", () => ({
+	createOpenAICompatible: (...args: unknown[]) => {
+		mockCreateOpenAICompatible(...args)
+		return vitest.fn((modelId: string) => ({ modelId, provider: "roo" }))
+	},
+}))
 
 // Mock CloudService - Define functions outside to avoid initialization issues
-const mockGetSessionToken = vitest.fn()
-const mockHasInstance = vitest.fn()
-
-// Create mock functions that we can control
 const mockGetSessionTokenFn = vitest.fn()
 const mockHasInstanceFn = vitest.fn()
-const mockOnFn = vitest.fn()
 
 vitest.mock("@roo-code/cloud", () => ({
 	CloudService: {
@@ -128,6 +95,45 @@ vitest.mock("../../providers/fetchers/modelCache", () => ({
 import { RooHandler } from "../roo"
 import { CloudService } from "@roo-code/cloud"
 
+/**
+ * Helper to create a mock stream result for streamText.
+ */
+function createMockStreamResult(options?: {
+	textChunks?: string[]
+	reasoningChunks?: string[]
+	toolCallParts?: Array<{ type: string; id?: string; toolName?: string; delta?: string }>
+	inputTokens?: number
+	outputTokens?: number
+	providerMetadata?: Record<string, unknown>
+}) {
+	const {
+		textChunks = ["Test response"],
+		reasoningChunks = [],
+		toolCallParts = [],
+		inputTokens = 10,
+		outputTokens = 5,
+		providerMetadata = undefined,
+	} = options ?? {}
+
+	const fullStream = (async function* () {
+		for (const text of reasoningChunks) {
+			yield { type: "reasoning-delta", text }
+		}
+		for (const text of textChunks) {
+			yield { type: "text-delta", text, id: "1" }
+		}
+		for (const part of toolCallParts) {
+			yield part
+		}
+	})()
+
+	return {
+		fullStream,
+		usage: Promise.resolve({ inputTokens, outputTokens }),
+		providerMetadata: Promise.resolve(providerMetadata),
+	}
+}
+
 describe("RooHandler", () => {
 	let handler: RooHandler
 	let mockOptions: ApiHandlerOptions
@@ -146,7 +152,9 @@ describe("RooHandler", () => {
 		// Set up CloudService mocks for successful authentication
 		mockHasInstanceFn.mockReturnValue(true)
 		mockGetSessionTokenFn.mockReturnValue("test-session-token")
-		mockCreate.mockClear()
+		mockStreamText.mockClear()
+		mockGenerateText.mockClear()
+		mockCreateOpenAICompatible.mockClear()
 		vitest.clearAllMocks()
 	})
 
@@ -187,8 +195,6 @@
 	it("should pass correct configuration to base class", () => {
 		handler = new RooHandler(mockOptions)
 		expect(handler).toBeInstanceOf(RooHandler)
-		// The handler should be initialized with correct base URL and API key
-		// We can't directly test the parent class constructor, but we can verify the handler works
 		expect(handler).toBeDefined()
 	})
 })
@@ -199,21 +205,26 @@
 	})
 
 	it("should update API key before making request", async () => {
-		// Set up a fresh token that will be returned when createMessage is called
 		const freshToken = "fresh-session-token"
 		mockGetSessionTokenFn.mockReturnValue(freshToken)
+		mockStreamText.mockReturnValue(createMockStreamResult())
 
 		const stream = handler.createMessage(systemPrompt, messages)
-		// Consume the stream to trigger the API call
 		for await (const _chunk of stream) {
 			// Just consume
 		}
 
-		// Verify getSessionToken was called to get the fresh token
-		expect(mockGetSessionTokenFn).toHaveBeenCalled()
+		// Verify createOpenAICompatible was called (per-request provider creates fresh one)
+		expect(mockCreateOpenAICompatible).toHaveBeenCalledWith(
+			expect.objectContaining({
+				apiKey: freshToken,
+			}),
+		)
 	})
 
 	it("should handle streaming responses", async () => {
+		mockStreamText.mockReturnValue(createMockStreamResult())
+
 		const stream = handler.createMessage(systemPrompt, messages)
 		const chunks: any[] = []
 		for await (const chunk of stream) {
@@ -227,6 +238,8 @@
 	})
 
 	it("should include usage information", async () => {
+		mockStreamText.mockReturnValue(createMockStreamResult())
+
 		const stream = handler.createMessage(systemPrompt, messages)
 		const chunks: any[] = []
 		for await (const chunk of stream) {
@@ -240,33 +253,34 @@
 	})
 
 	it("should handle API errors", async () => {
-		mockCreate.mockRejectedValueOnce(new Error("API Error"))
+		mockStreamText.mockReturnValue({
+			fullStream: {
+				[Symbol.asyncIterator]() {
+					return {
+						next: () => Promise.reject(new Error("API Error")),
+					}
+				},
+			},
+			usage: new Promise(() => {}), // never resolves; stream throws before usage is awaited
+			providerMetadata: Promise.resolve(undefined),
+		})
+
 		const stream = handler.createMessage(systemPrompt, messages)
 		await expect(async () => {
 			for await (const _chunk of stream) {
 				// Should not reach here
 			}
-		}).rejects.toThrow("API Error")
+		}).rejects.toThrow()
 	})
 
 	it("should handle empty response content", async () => {
-		mockCreate.mockResolvedValueOnce({
-			[Symbol.asyncIterator]: async function* () {
-				yield {
-					choices: [
-						{
-							delta: { content: null },
-							index: 0,
-						},
-					],
-					usage: {
-						prompt_tokens: 10,
-						completion_tokens: 0,
-						total_tokens: 10,
-					},
-				}
-			},
-		})
+		mockStreamText.mockReturnValue(
+			createMockStreamResult({
+				textChunks: [],
+				inputTokens: 10,
+				outputTokens: 0,
+			}),
+		)
 
 		const stream = handler.createMessage(systemPrompt, messages)
 		const chunks: any[] = []
@@ -281,6 +295,8 @@
 	})
 
 	it("should handle multiple messages in conversation", async () => {
+		mockStreamText.mockReturnValue(createMockStreamResult())
+
 		const multipleMessages: Anthropic.Messages.MessageParam[] = [
 			{ role: "user", content: "First message" },
 			{ role: "assistant", content: "First response" },
@@ -293,18 +309,45 @@
 			chunks.push(chunk)
 		}
 
-		expect(mockCreate).toHaveBeenCalledWith(
+		// Verify streamText was called with system prompt and converted messages
+		expect(mockStreamText).toHaveBeenCalledWith(
 			expect.objectContaining({
-				messages: expect.arrayContaining([
-					expect.objectContaining({ role: "system", content: systemPrompt }),
-					expect.objectContaining({ role: "user", content: "First message" }),
-					expect.objectContaining({ role: "assistant", content: "First response" }),
-					expect.objectContaining({ role: "user", content: "Second message" }),
-				]),
+				system: systemPrompt,
+				messages: expect.any(Array),
 			}),
+		)
+	})
+
+	it("should pass X-Roo-App-Version header via createOpenAICompatible", async () => {
+		mockStreamText.mockReturnValue(createMockStreamResult())
+
+		const stream = handler.createMessage(systemPrompt, messages)
+		for await (const _chunk of stream) {
+			// consume
+		}
+
+		expect(mockCreateOpenAICompatible).toHaveBeenCalledWith(
+			expect.objectContaining({
+				headers: expect.objectContaining({
+					"X-Roo-App-Version": expect.any(String),
+				}),
+			}),
+		)
+	})
+
+	it("should pass X-Roo-Task-ID header when taskId is provided", async () => {
+		mockStreamText.mockReturnValue(createMockStreamResult())
+
+		const stream = handler.createMessage(systemPrompt, messages, { taskId: "test-task-123" })
+		for await (const _chunk of stream) {
+			// consume
+		}
+
+		expect(mockCreateOpenAICompatible).toHaveBeenCalledWith(
 			expect.objectContaining({
 				headers: expect.objectContaining({
 					"X-Roo-App-Version": expect.any(String),
+					"X-Roo-Task-ID": "test-task-123",
 				}),
 			}),
 		)
@@ -317,52 +360,39 @@
 	})
 
 	it("should complete prompt successfully", async () => {
+		mockGenerateText.mockResolvedValue({ text: "Test response" })
+
 		const result = await handler.completePrompt("Test prompt")
 		expect(result).toBe("Test response")
-		expect(mockCreate).toHaveBeenCalledWith({
-			model: mockOptions.apiModelId,
-			messages: [{ role: "user", content: "Test prompt" }],
-		})
+		expect(mockGenerateText).toHaveBeenCalledWith(
+			expect.objectContaining({
+				prompt: "Test prompt",
+			}),
+		)
 	})
 
 	it("should update API key before making request", async () => {
-		// Set up a fresh token that will be returned when completePrompt is called
 		const freshToken = "fresh-session-token"
 		mockGetSessionTokenFn.mockReturnValue(freshToken)
-
-		// Access the client's apiKey property to verify it gets updated
-		const clientApiKeyGetter = vitest.fn()
-		Object.defineProperty(handler["client"], "apiKey", {
-			get: clientApiKeyGetter,
-			set: vitest.fn(),
-			configurable: true,
-		})
+		mockGenerateText.mockResolvedValue({ text: "Test response" })
 
 		await handler.completePrompt("Test prompt")
 
-		// Verify getSessionToken was called to get the fresh token
-		expect(mockGetSessionTokenFn).toHaveBeenCalled()
+		// Verify createOpenAICompatible was called with fresh token
+		expect(mockCreateOpenAICompatible).toHaveBeenCalledWith(
+			expect.objectContaining({
+				apiKey: freshToken,
+			}),
+		)
 	})
 
 	it("should handle API errors", async () => {
-		mockCreate.mockRejectedValueOnce(new Error("API Error"))
-		await expect(handler.completePrompt("Test prompt")).rejects.toThrow(
-			"Roo Code Cloud completion error: API Error",
-		)
+		mockGenerateText.mockRejectedValue(new Error("API Error"))
+		await expect(handler.completePrompt("Test prompt")).rejects.toThrow("Roo Code Cloud")
 	})
 
 	it("should handle empty response", async () => {
-		mockCreate.mockResolvedValueOnce({
-			choices: [{ message: { content: "" } }],
-		})
-		const result = await handler.completePrompt("Test prompt")
-		expect(result).toBe("")
-	})
-
-	it("should handle missing response content", async () => {
-		mockCreate.mockResolvedValueOnce({
-			choices: [{ message: {} }],
-		})
+		mockGenerateText.mockResolvedValue({ text: "" })
 		const result = await handler.completePrompt("Test prompt")
 		expect(result).toBe("")
 	})
@@ -377,7 +407,6 @@
 		const modelInfo = handler.getModel()
 		expect(modelInfo.id).toBe(mockOptions.apiModelId)
 		expect(modelInfo.info).toBeDefined()
-		// Models are loaded dynamically, so we just verify the structure
 		expect(modelInfo.info.maxTokens).toBeDefined()
 		expect(modelInfo.info.contextWindow).toBeDefined()
 	})
@@ -387,7 +416,6 @@
 		const modelInfo = handlerWithoutModel.getModel()
 		expect(modelInfo.id).toBe(rooDefaultModelId)
 		expect(modelInfo.info).toBeDefined()
-		// Models are loaded dynamically
 		expect(modelInfo.info.maxTokens).toBeDefined()
 		expect(modelInfo.info.contextWindow).toBeDefined()
 	})
@@ -399,7 +427,6 @@
 		const modelInfo = handlerWithUnknownModel.getModel()
 		expect(modelInfo.id).toBe("unknown-model-id")
 		expect(modelInfo.info).toBeDefined()
-		// Should return fallback info for unknown models (dynamic models will be merged in real usage)
 		expect(modelInfo.info.maxTokens).toBeDefined()
 		expect(modelInfo.info.contextWindow).toBeDefined()
 		expect(modelInfo.info.supportsImages).toBeDefined()
 	})
 
 	it("should handle any model ID since models are loaded dynamically", () => {
-		// Test with various model IDs - they should all work since models are loaded dynamically
 		const testModelIds = ["xai/grok-code-fast-1", "roo/sonic", "deepseek/deepseek-chat-v3.1"]
 
 		for (const modelId of testModelIds) {
@@ -417,7 +443,6 @@
 			const modelInfo = handlerWithModel.getModel()
 			expect(modelInfo.id).toBe(modelId)
 			expect(modelInfo.info).toBeDefined()
-			// Verify the structure has required fields
 			expect(modelInfo.info.maxTokens).toBeDefined()
 			expect(modelInfo.info.contextWindow).toBeDefined()
 		}
@@ -428,7 +453,6 @@
 			apiModelId: "minimax/minimax-m2:free",
 		})
 		const modelInfo = handlerWithMinimax.getModel()
-		// The settings from API should already be applied in the cached model info
 		expect(modelInfo.info.inputPrice).toBe(0.15)
 		expect(modelInfo.info.outputPrice).toBe(0.6)
 	})
@@ -437,20 +461,17 @@
 describe("temperature and model configuration", () => {
 	it("should use default temperature of 0", async () => {
 		handler = new RooHandler(mockOptions)
+		mockStreamText.mockReturnValue(createMockStreamResult())
+
 		const stream = handler.createMessage(systemPrompt, messages)
 		for await (const _chunk of stream) {
 			// Consume stream
 		}
 
-		expect(mockCreate).toHaveBeenCalledWith(
+		expect(mockStreamText).toHaveBeenCalledWith(
 			expect.objectContaining({
 				temperature: 0,
 			}),
-			expect.objectContaining({
-				headers: expect.objectContaining({
-					"X-Roo-App-Version": expect.any(String),
-				}),
-			}),
 		)
 	})
 
 	it("should use custom temperature from options", async () => {
 		handler = new RooHandler({
 			...mockOptions,
 			modelTemperature: 0.9,
 		})
+		mockStreamText.mockReturnValue(createMockStreamResult())
+
 		const stream = handler.createMessage(systemPrompt, messages)
 		for await (const _chunk of stream) {
 			// Consume stream
 		}
 
-		expect(mockCreate).toHaveBeenCalledWith(
+		expect(mockStreamText).toHaveBeenCalledWith(
 			expect.objectContaining({
 				temperature: 0.9,
 			}),
-			expect.objectContaining({
-				headers: expect.objectContaining({
-					"X-Roo-App-Version": expect.any(String),
-				}),
-			}),
 		)
 	})
 
 	it("should use correct API endpoint", () => {
-		// The base URL should be set to Roo's API endpoint
-		// We can't directly test the OpenAI client configuration, but we can verify the handler initializes
 		handler = new RooHandler(mockOptions)
 		expect(handler).toBeInstanceOf(RooHandler)
-		// The handler should work with the Roo API endpoint
 	})
 })
 
@@ -497,10 +512,8 @@
 	it("should handle undefined auth service gracefully", () => {
 		mockHasInstanceFn.mockReturnValue(true)
 
-		// Mock CloudService with undefined authService
 		const originalGetSessionToken = mockGetSessionTokenFn.getMockImplementation()
 
-		// Temporarily make authService return undefined
 		mockGetSessionTokenFn.mockImplementation(() => undefined)
 
 		try {
@@ -516,11 +529,9 @@
 			expect(() => {
 				new RooHandler(mockOptions)
 			}).not.toThrow()
-			// Constructor should succeed even with undefined auth service
 			const handler = new RooHandler(mockOptions)
 			expect(handler).toBeInstanceOf(RooHandler)
 		} finally {
-			// Restore original mock implementation
 			if (originalGetSessionToken) {
 				mockGetSessionTokenFn.mockImplementation(originalGetSessionToken)
 			} else {
@@ -535,137 +546,126 @@
 		expect(() => {
 			new RooHandler(mockOptions)
 		}).not.toThrow()
-		// Constructor should succeed even with empty session token
 		const handler = new RooHandler(mockOptions)
 		expect(handler).toBeInstanceOf(RooHandler)
 	})
 })
 
 describe("reasoning effort support", () => {
-	it("should include reasoning with enabled: false when not enabled", async () => {
+	/**
+	 * Helper: extracts the `transformRequestBody` function from the most recent
+	 * `createOpenAICompatible` call and invokes it with a sample body to return
+	 * the transformed result. Returns `undefined` when no transform was provided.
+	 */
+	function getTransformedBody(): Record<string, unknown> | undefined {
+		const callArgs = mockCreateOpenAICompatible.mock.calls[0]?.[0]
+		if (!callArgs?.transformRequestBody) {
+			return undefined
+		}
+		const sampleBody = { model: "test-model", messages: [] }
+		return callArgs.transformRequestBody(sampleBody)
+	}
+
+	it("should inject reasoning { enabled: false } via transformRequestBody when not enabled", async () => {
 		handler = new RooHandler(mockOptions)
+		mockStreamText.mockReturnValue(createMockStreamResult())
+
 		const stream = handler.createMessage(systemPrompt, messages)
 		for await (const _chunk of stream) {
 			// Consume stream
 		}
 
-		expect(mockCreate).toHaveBeenCalledWith(
-			expect.objectContaining({
-				model: mockOptions.apiModelId,
-				messages: expect.any(Array),
-				stream: true,
-				stream_options: { include_usage: true },
-				reasoning: { enabled: false },
-			}),
-			expect.objectContaining({
-				headers: expect.objectContaining({
-					"X-Roo-App-Version": expect.any(String),
-				}),
-			}),
-		)
+		// Reasoning is injected via transformRequestBody when creating the provider
+		const transformed = getTransformedBody()
+		expect(transformed).toBeDefined()
+		expect(transformed!.reasoning).toEqual({ enabled: false })
+		// Original body fields are preserved
+		expect(transformed!.model).toBe("test-model")
 	})
 
-	it("should include reasoning with enabled: false when explicitly disabled", async () => {
+	it("should inject reasoning { enabled: false } via transformRequestBody when explicitly disabled", async () => {
 		handler = new RooHandler({
 			...mockOptions,
 			enableReasoningEffort: false,
 		})
+		mockStreamText.mockReturnValue(createMockStreamResult())
+
 		const stream = handler.createMessage(systemPrompt, messages)
 		for await (const _chunk of stream) {
 			// Consume stream
 		}
 
-		expect(mockCreate).toHaveBeenCalledWith(
-			expect.objectContaining({
-				reasoning: { enabled: false },
-			}),
-			expect.objectContaining({
-				headers: expect.objectContaining({
-					"X-Roo-App-Version": expect.any(String),
-				}),
-			}),
-		)
+		const transformed = getTransformedBody()
+		expect(transformed).toBeDefined()
+		expect(transformed!.reasoning).toEqual({ enabled: false })
 	})
 
-	it("should include reasoning with enabled: true and effort: low", async () => {
+	it("should inject reasoning { enabled: true, effort: 'low' } via transformRequestBody", async () => {
 		handler = new RooHandler({
 			...mockOptions,
 			reasoningEffort: "low",
 		})
+		mockStreamText.mockReturnValue(createMockStreamResult())
+
 		const stream = handler.createMessage(systemPrompt, messages)
 		for await (const _chunk of stream) {
 			// Consume stream
 		}
 
-		expect(mockCreate).toHaveBeenCalledWith(
-			expect.objectContaining({
-				reasoning: { enabled: true, effort: "low" },
-			}),
-			expect.objectContaining({
-				headers: expect.objectContaining({
-					"X-Roo-App-Version": expect.any(String),
-				}),
-			}),
-		)
+		const transformed = getTransformedBody()
+		expect(transformed).toBeDefined()
+		expect(transformed!.reasoning).toEqual({ enabled: true, effort: "low" })
 	})
 
-	it("should include reasoning with enabled: true and effort: medium", async () => {
+	it("should inject reasoning { enabled: true, effort: 'medium' } via transformRequestBody", async () => {
 		handler = new RooHandler({
 			...mockOptions,
 			reasoningEffort: "medium",
 		})
+		mockStreamText.mockReturnValue(createMockStreamResult())
+
 		const stream = handler.createMessage(systemPrompt, messages)
 		for await (const _chunk of stream) {
 			// Consume stream
 		}
 
-		expect(mockCreate).toHaveBeenCalledWith(
-			expect.objectContaining({
-				reasoning: { enabled: true, effort: "medium" },
-			}),
-			expect.objectContaining({
-				headers: expect.objectContaining({
-					"X-Roo-App-Version": expect.any(String),
-				}),
-			}),
-		)
+		const transformed = getTransformedBody()
+		expect(transformed).toBeDefined()
+		expect(transformed!.reasoning).toEqual({ enabled: true, effort: "medium" })
 	})
 
-	it("should include reasoning with enabled: true and effort: high", async () => {
+	it("should inject reasoning { enabled: true, effort: 'high' } via transformRequestBody", async () => {
 		handler = new RooHandler({
 			...mockOptions,
 			reasoningEffort: "high",
 		})
+		mockStreamText.mockReturnValue(createMockStreamResult())
+
 		const stream = handler.createMessage(systemPrompt, messages)
 		for await (const _chunk of stream) {
 			// Consume stream
 		}
 
-		expect(mockCreate).toHaveBeenCalledWith(
-			expect.objectContaining({
-				reasoning: { enabled: true, effort: "high" },
-			}),
-			expect.objectContaining({
-				headers: expect.objectContaining({
-					"X-Roo-App-Version": expect.any(String),
-				}),
-			}),
-		)
+		const transformed = getTransformedBody()
+		expect(transformed).toBeDefined()
+		expect(transformed!.reasoning).toEqual({ enabled: true, effort: "high" })
 	})
 
-	it("should not include reasoning for minimal (treated as none)", async () => {
+	it("should not provide transformRequestBody for minimal (treated as none)", async () => {
 		handler = new RooHandler({
 			...mockOptions,
 			reasoningEffort: "minimal",
 		})
+		mockStreamText.mockReturnValue(createMockStreamResult())
+
 		const stream = handler.createMessage(systemPrompt, messages)
 		for await (const _chunk of stream) {
 			// Consume stream
 		}
 
-		// minimal should result in no reasoning parameter
-		const callArgs = mockCreate.mock.calls[0][0]
-		expect(callArgs.reasoning).toBeUndefined()
+		// minimal should result in no reasoning parameter, thus no transformRequestBody
+		const callArgs = mockCreateOpenAICompatible.mock.calls[0][0]
+		expect(callArgs.transformRequestBody).toBeUndefined()
 	})
 
 	it("should handle enableReasoningEffort: false overriding reasoningEffort setting", async () => {
 		handler = new RooHandler({
 			...mockOptions,
 			enableReasoningEffort: false,
 			reasoningEffort: "high",
 		})
+		mockStreamText.mockReturnValue(createMockStreamResult())
+
 		const stream = handler.createMessage(systemPrompt, messages)
 		for await (const _chunk of stream) {
 			// Consume stream
 		}
 
 		// When explicitly disabled, should send enabled: false regardless of effort setting
-		expect(mockCreate).toHaveBeenCalledWith(
-			expect.objectContaining({
-				reasoning: { enabled: false },
-			}),
-			expect.objectContaining({
-				headers: expect.objectContaining({
-					"X-Roo-App-Version": expect.any(String),
-				}),
-			}),
-		)
+		const transformed = getTransformedBody()
+		expect(transformed).toBeDefined()
+		expect(transformed!.reasoning).toEqual({ enabled: false })
 	})
 })
 
-	describe("tool calls handling", () => {
+	describe("reasoning details accumulation", () => {
 		beforeEach(() => {
 			handler = new RooHandler(mockOptions)
 		})
 
-		it("should yield raw tool call chunks when tool_calls present", async () => {
-			mockCreate.mockResolvedValueOnce({
-				[Symbol.asyncIterator]: async function* () {
-					yield {
-						choices: [
-							{
-								delta: {
-									tool_calls: [
-										{
-											index: 0,
-											id: "call_123",
-											function: { name: "read_file", arguments: '{"path":"' },
-										},
-									],
-								},
-								index: 0,
-							},
-						],
-					}
-					yield {
-						choices: [
-							{
-								delta: {
-									tool_calls: [
-										{
-											index: 0,
-											function: { arguments: 'test.ts"}' },
-										},
-									],
-								},
-								index: 0,
-							},
-						],
-					}
-					yield {
-						choices: [
-							{
-								delta: {},
-								finish_reason: "tool_calls",
-								index: 0,
-							},
-						],
-						usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 },
-					}
-				},
-			})
+		it("should accumulate reasoning text from reasoning-delta parts", async () => {
+			mockStreamText.mockReturnValue(
+				createMockStreamResult({
+					reasoningChunks: ["thinking ", "about ", "this"],
+					textChunks: ["answer"],
+				}),
+			)
 
 			const stream = handler.createMessage(systemPrompt, messages)
 			const chunks: any[] = []
 			for await (const chunk of stream) {
 				chunks.push(chunk)
 			}
 
-			// Verify we get raw tool call chunks
-			const rawChunks = chunks.filter((chunk) => chunk.type === "tool_call_partial")
+			const reasoningChunks = chunks.filter((c) => c.type === "reasoning")
+			expect(reasoningChunks).toHaveLength(3)
+			expect(reasoningChunks[0].text).toBe("thinking ")
+			expect(reasoningChunks[1].text).toBe("about ")
+			expect(reasoningChunks[2].text).toBe("this")
 
-			expect(rawChunks).toHaveLength(2)
-			expect(rawChunks[0]).toEqual({
-				type: "tool_call_partial",
-				index: 0,
-				id: "call_123",
-				name: "read_file",
-				arguments: '{"path":"',
-			})
-			expect(rawChunks[1]).toEqual({
-				type: "tool_call_partial",
-				index: 0,
-				id: undefined,
-				name: undefined,
-				arguments: 'test.ts"}',
-			})
+			const details = handler.getReasoningDetails()
+			expect(details).toBeDefined()
+			expect(details![0].type).toBe("reasoning.text")
+			expect(details![0].text).toBe("thinking about this")
 		})
 
-		it("should yield raw tool call chunks even when finish_reason is not tool_calls", async () => {
-			mockCreate.mockResolvedValueOnce({
-				[Symbol.asyncIterator]: async function* () {
-					yield {
-						choices: [
-							{
-								delta: {
-									tool_calls: [
-										{
-											index: 0,
-											id: "call_456",
-											function: {
-												name: "write_to_file",
-												arguments: '{"path":"test.ts","content":"hello"}',
-											},
-										},
-									],
-								},
-								index: 0,
-							},
-						],
-					}
-					yield {
-						choices: [
-							{
-								delta: {},
-								finish_reason: "stop",
-								index: 0,
-							},
-						],
-						usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 },
-					}
-				},
-			})
+		it("should override reasoning details from providerMetadata", async () => {
+			const providerReasoningDetails = [{ type: "reasoning.summary", summary: "Server summary", index: 0 }]
+
+			mockStreamText.mockReturnValue(
+				createMockStreamResult({
+					reasoningChunks: ["local thinking"],
+					textChunks: ["answer"],
+					providerMetadata: {
+						roo: { reasoning_details: providerReasoningDetails },
+					},
+				}),
+			)
+
+			const stream = handler.createMessage(systemPrompt, messages)
+			for await (const _chunk of stream) {
+				// consume
+			}
+
+			const details = handler.getReasoningDetails()
+			expect(details).toBeDefined()
+			expect(details).toEqual(providerReasoningDetails)
+		})
+
+		it("should return undefined when no reasoning details", async () => {
+			mockStreamText.mockReturnValue(
+				createMockStreamResult({
+					reasoningChunks: [],
+					textChunks: ["just text"],
+				}),
+			)
+
+			const stream = handler.createMessage(systemPrompt, messages)
+			for await (const _chunk of stream) {
+				// consume
+			}
+
+			expect(handler.getReasoningDetails()).toBeUndefined()
+		})
+	})
+
+	describe("usage and cost processing", () => {
+		beforeEach(() => {
+			handler = new RooHandler(mockOptions)
+		})
+
+		it("should use server-side cost from providerMetadata when available", async () => {
+			mockStreamText.mockReturnValue(
+				createMockStreamResult({
+					inputTokens: 100,
+					outputTokens: 50,
+					providerMetadata: {
+						roo: { cost: 0.005 },
+					},
+				}),
+			)
 
 			const stream = handler.createMessage(systemPrompt, messages)
 			const chunks: any[] = []
 			for await (const chunk of stream) {
 				chunks.push(chunk)
 			}
 
-			const rawChunks = chunks.filter((chunk) => chunk.type === "tool_call_partial")
-
-			expect(rawChunks).toHaveLength(1)
-			expect(rawChunks[0]).toEqual({
-				type: "tool_call_partial",
-				index: 0,
-				id: "call_456",
-				name: "write_to_file",
-				arguments: '{"path":"test.ts","content":"hello"}',
-			})
+			const usageChunk = chunks.find((c) => c.type === "usage")
+			expect(usageChunk).toBeDefined()
+			expect(usageChunk.totalCost).toBe(0.005)
 		})
 
-		it("should handle multiple tool calls with different indices", async () => {
-			mockCreate.mockResolvedValueOnce({
-				[Symbol.asyncIterator]: async function* () {
-					yield {
-						choices: [
-							{
-								delta: {
-									tool_calls: [
-										{
-											index: 0,
-											id: "call_1",
-											function: { name: "read_file", arguments: '{"path":"file1.ts"}' },
-										},
-									],
-								},
-								index: 0,
-							},
-						],
-					}
-					yield {
-						choices: [
-							{
-								delta: {
-									tool_calls: [
-										{
-											index: 1,
-											id: "call_2",
-											function: { name: "read_file", arguments: '{"path":"file2.ts"}' },
-										},
-									],
-								},
-								index: 0,
-							},
-						],
-					}
-					yield {
-						choices: [
-							{
-								delta: {},
-								finish_reason: "tool_calls",
-								index: 0,
-							},
-						],
-						usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 },
-					}
-				},
+		it("should report 0 cost for free models", async () => {
+			const freeHandler = new RooHandler({
+				apiModelId: "xai/grok-code-fast-1", // has isFree: false but inputPrice/outputPrice = 0
 			})
+			mockStreamText.mockReturnValue(
+				createMockStreamResult({
+					inputTokens: 100,
+					outputTokens: 50,
+					providerMetadata: {
+						roo: { cost: 0.005 },
+					},
+				}),
+			)
+
+			const stream = freeHandler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			const usageChunk = chunks.find((c) => c.type === "usage")
+			expect(usageChunk).toBeDefined()
+			// Model is not marked as isFree, so cost should be from server
+			expect(usageChunk.totalCost).toBe(0.005)
+		})
+
+		it("should include cache tokens from providerMetadata", async () => {
+			mockStreamText.mockReturnValue(
+				createMockStreamResult({
+					inputTokens: 100,
+					outputTokens: 50,
+					providerMetadata: {
+						roo: {
+							cache_creation_input_tokens: 20,
+							cache_read_input_tokens: 30,
+						},
+					},
+				}),
+			)
+
 			const stream = handler.createMessage(systemPrompt, messages)
 			const chunks: any[] = []
 			for await (const chunk of stream) {
 				chunks.push(chunk)
 			}
 
-			const rawChunks = chunks.filter((chunk) => chunk.type === "tool_call_partial")
+			const usageChunk = chunks.find((c) => c.type === "usage")
+			expect(usageChunk).toBeDefined()
+			expect(usageChunk.cacheWriteTokens).toBe(20)
+			expect(usageChunk.cacheReadTokens).toBe(30)
+		})
+	})
 
-			expect(rawChunks).toHaveLength(2)
-			expect(rawChunks[0].index).toBe(0)
-			expect(rawChunks[0].id).toBe("call_1")
-			expect(rawChunks[1].index).toBe(1)
-			expect(rawChunks[1].id).toBe("call_2")
+	describe("isAiSdkProvider", () => {
+		it("should return true", () => {
+			handler = new RooHandler(mockOptions)
+			expect(handler.isAiSdkProvider()).toBe(true)
 		})
+	})
 
-		it("should emit raw chunks for streaming arguments", async () => {
-			mockCreate.mockResolvedValueOnce({
-				[Symbol.asyncIterator]: async function* () {
-					yield {
-						choices: [
-							{
-								delta: {
-									tool_calls: [
-										{
-											index: 0,
-											id: "call_789",
-											function: { name: "execute_command", arguments: '{"command":"' },
-										},
-									],
-								},
-								index: 0,
-							},
-						],
-					}
-					yield {
-						choices: [
-							{
-								delta: {
-									tool_calls: [
-										{
-											index: 0,
-											function: { arguments: "npm install" },
-										},
-									],
-								},
-								index: 0,
-							},
-						],
-					}
-					yield {
-						choices: [
-							{
-								delta: {
-									tool_calls: [
-										{
-											index: 0,
-											function: { arguments: '"}' },
-										},
-									],
-								},
-								index: 0,
-							},
-						],
-					}
-					yield {
-						choices: [
-							{
-								delta: {},
-								finish_reason: "tool_calls",
-								index: 0,
-							},
-						],
-						usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 },
-					}
-				},
-			})
+	describe("tool calls handling", () => {
+		beforeEach(() => {
+			handler = new RooHandler(mockOptions)
+		})
+
+		it("should yield tool call events from AI SDK stream", async () => {
+			mockStreamText.mockReturnValue(
+				createMockStreamResult({
+					textChunks: [],
+					toolCallParts: [
+						{ type: "tool-input-start", id: "call_123", toolName: "read_file" },
+						{ type: "tool-input-delta", id: "call_123", delta: '{"path":"test.ts"}' },
+						{ type: "tool-input-end", id: "call_123" },
+					],
+				}),
+			)
 
 			const stream = handler.createMessage(systemPrompt, messages)
 			const chunks: any[] = []
 			for await (const chunk of stream) {
 				chunks.push(chunk)
 			}
 
-			const rawChunks = chunks.filter((chunk) => chunk.type === "tool_call_partial")
+			const startChunks = chunks.filter((c) => c.type === "tool_call_start")
+			const deltaChunks = chunks.filter((c) => c.type === "tool_call_delta")
+			const endChunks = chunks.filter((c) => c.type === "tool_call_end")
 
-			expect(rawChunks).toHaveLength(3)
-			expect(rawChunks[0].arguments).toBe('{"command":"')
-			expect(rawChunks[1].arguments).toBe("npm install")
-			expect(rawChunks[2].arguments).toBe('"}')
-		})
+			expect(startChunks).toHaveLength(1)
+			expect(startChunks[0].id).toBe("call_123")
+			expect(startChunks[0].name).toBe("read_file")
 
-		it("should not yield tool call chunks when no tool calls present", async () => {
-			mockCreate.mockResolvedValueOnce({
-				[Symbol.asyncIterator]: async function* () {
-					yield {
-						choices: [{ delta: { content: "Regular text response" }, index: 0 }],
-					}
-					yield {
-						choices: [{ delta: {}, finish_reason: "stop", index: 0 }],
-						usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 },
-					}
-				},
-			})
+			expect(deltaChunks).toHaveLength(1)
+			expect(deltaChunks[0].id).toBe("call_123")
+			expect(deltaChunks[0].delta).toBe('{"path":"test.ts"}')
+
+			expect(endChunks).toHaveLength(1)
+			expect(endChunks[0].id).toBe("call_123")
+		})
+
+		it("should handle multiple tool calls", async () => {
+			mockStreamText.mockReturnValue(
+				createMockStreamResult({
+					textChunks: [],
+					toolCallParts: [
+						{ type: "tool-input-start", id: "call_1", toolName: "read_file" },
+						{ type: "tool-input-delta", id: "call_1", delta: '{"path":"file1.ts"}' },
+						{ type: "tool-input-end", id: "call_1" },
+						{ type: "tool-input-start", id: "call_2", toolName: "read_file" },
+						{ type: "tool-input-delta", id: "call_2", delta: '{"path":"file2.ts"}' },
+						{ type: "tool-input-end", id: "call_2" },
+					],
+				}),
+			)
 
 			const stream = handler.createMessage(systemPrompt, messages)
 			const chunks: any[] = []
 			for await (const chunk of stream) {
 				chunks.push(chunk)
 			}
 
"tool_call_partial") - expect(rawChunks).toHaveLength(0) + const startChunks = chunks.filter((c) => c.type === "tool_call_start") + const endChunks = chunks.filter((c) => c.type === "tool_call_end") + + expect(startChunks).toHaveLength(2) + expect(startChunks[0].id).toBe("call_1") + expect(startChunks[1].id).toBe("call_2") + + expect(endChunks).toHaveLength(2) + expect(endChunks[0].id).toBe("call_1") + expect(endChunks[1].id).toBe("call_2") }) - it("should yield tool_call_end events when finish_reason is tool_calls", async () => { - // Import NativeToolCallParser to set up state - const { NativeToolCallParser } = await import("../../../core/assistant-message/NativeToolCallParser") - - // Clear any previous state - NativeToolCallParser.clearRawChunkState() - - mockCreate.mockResolvedValueOnce({ - [Symbol.asyncIterator]: async function* () { - yield { - choices: [ - { - delta: { - tool_calls: [ - { - index: 0, - id: "call_finish_test", - function: { name: "read_file", arguments: '{"path":"test.ts"}' }, - }, - ], - }, - index: 0, - }, - ], - } - yield { - choices: [ - { - delta: {}, - finish_reason: "tool_calls", - index: 0, - }, - ], - usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 }, - } - }, - }) + it("should not yield tool call chunks when no tool calls present", async () => { + mockStreamText.mockReturnValue( + createMockStreamResult({ + textChunks: ["Regular text response"], + }), + ) const stream = handler.createMessage(systemPrompt, messages) const chunks: any[] = [] for await (const chunk of stream) { - // Simulate what Task.ts does: when we receive tool_call_partial, - // process it through NativeToolCallParser to populate rawChunkTracker - if (chunk.type === "tool_call_partial") { - NativeToolCallParser.processRawChunk({ - index: chunk.index, - id: chunk.id, - name: chunk.name, - arguments: chunk.arguments, - }) - } chunks.push(chunk) } - // Should have tool_call_partial and tool_call_end - const partialChunks = chunks.filter((chunk) => chunk.type === "tool_call_partial") - const endChunks = chunks.filter((chunk) => chunk.type === "tool_call_end") - - expect(partialChunks).toHaveLength(1) - expect(endChunks).toHaveLength(1) - expect(endChunks[0].id).toBe("call_finish_test") + const toolChunks = chunks.filter( + (c) => c.type === "tool_call_start" || c.type === "tool_call_delta" || c.type === "tool_call_end", + ) + expect(toolChunks).toHaveLength(0) }) }) }) diff --git a/src/api/providers/base-openai-compatible-provider.ts b/src/api/providers/base-openai-compatible-provider.ts deleted file mode 100644 index fc3d769ae2a..00000000000 --- a/src/api/providers/base-openai-compatible-provider.ts +++ /dev/null @@ -1,260 +0,0 @@ -import { Anthropic } from "@anthropic-ai/sdk" -import OpenAI from "openai" - -import type { ModelInfo } from "@roo-code/types" - -import { type ApiHandlerOptions, getModelMaxOutputTokens } from "../../shared/api" -import { TagMatcher } from "../../utils/tag-matcher" -import { ApiStream, ApiStreamUsageChunk } from "../transform/stream" -import { convertToOpenAiMessages } from "../transform/openai-format" - -import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" -import { DEFAULT_HEADERS } from "./constants" -import { BaseProvider } from "./base-provider" -import { handleOpenAIError } from "./utils/openai-error-handler" -import { calculateApiCostOpenAI } from "../../shared/cost" -import { getApiRequestTimeout } from "./utils/timeout-config" - -type BaseOpenAiCompatibleProviderOptions = ApiHandlerOptions & { - 
-	providerName: string
-	baseURL: string
-	defaultProviderModelId: ModelName
-	providerModels: Record<ModelName, ModelInfo>
-	defaultTemperature?: number
-}
-
-export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
-	extends BaseProvider
-	implements SingleCompletionHandler
-{
-	protected readonly providerName: string
-	protected readonly baseURL: string
-	protected readonly defaultTemperature: number
-	protected readonly defaultProviderModelId: ModelName
-	protected readonly providerModels: Record<ModelName, ModelInfo>
-
-	protected readonly options: ApiHandlerOptions
-
-	protected client: OpenAI
-
-	constructor({
-		providerName,
-		baseURL,
-		defaultProviderModelId,
-		providerModels,
-		defaultTemperature,
-		...options
-	}: BaseOpenAiCompatibleProviderOptions<ModelName>) {
-		super()
-
-		this.providerName = providerName
-		this.baseURL = baseURL
-		this.defaultProviderModelId = defaultProviderModelId
-		this.providerModels = providerModels
-		this.defaultTemperature = defaultTemperature ?? 0
-
-		this.options = options
-
-		if (!this.options.apiKey) {
-			throw new Error("API key is required")
-		}
-
-		this.client = new OpenAI({
-			baseURL,
-			apiKey: this.options.apiKey,
-			defaultHeaders: DEFAULT_HEADERS,
-			timeout: getApiRequestTimeout(),
-		})
-	}
-
-	protected createStream(
-		systemPrompt: string,
-		messages: Anthropic.Messages.MessageParam[],
-		metadata?: ApiHandlerCreateMessageMetadata,
-		requestOptions?: OpenAI.RequestOptions,
-	) {
-		const { id: model, info } = this.getModel()
-
-		// Centralized cap: clamp to 20% of the context window (unless provider-specific exceptions apply)
-		const max_tokens =
-			getModelMaxOutputTokens({
-				modelId: model,
-				model: info,
-				settings: this.options,
-				format: "openai",
-			}) ?? undefined
-
-		const temperature = this.options.modelTemperature ?? info.defaultTemperature ?? this.defaultTemperature
-
-		const params: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
-			model,
-			max_tokens,
-			temperature,
-			messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
-			stream: true,
-			stream_options: { include_usage: true },
-			tools: this.convertToolsForOpenAI(metadata?.tools),
-			tool_choice: metadata?.tool_choice,
-			parallel_tool_calls: metadata?.parallelToolCalls ?? true,
-		}
-
-		// Add thinking parameter if reasoning is enabled and model supports it
-		if (this.options.enableReasoningEffort && info.supportsReasoningBinary) {
-			;(params as any).thinking = { type: "enabled" }
-		}
-
-		try {
-			return this.client.chat.completions.create(params, requestOptions)
-		} catch (error) {
-			throw handleOpenAIError(error, this.providerName)
-		}
-	}
-
-	override async *createMessage(
-		systemPrompt: string,
-		messages: Anthropic.Messages.MessageParam[],
-		metadata?: ApiHandlerCreateMessageMetadata,
-	): ApiStream {
-		const stream = await this.createStream(systemPrompt, messages, metadata)
-
-		const matcher = new TagMatcher(
-			"think",
-			(chunk) =>
-				({
"reasoning" : "text", - text: chunk.data, - }) as const, - ) - - let lastUsage: OpenAI.CompletionUsage | undefined - const activeToolCallIds = new Set() - - for await (const chunk of stream) { - // Check for provider-specific error responses (e.g., MiniMax base_resp) - const chunkAny = chunk as any - if (chunkAny.base_resp?.status_code && chunkAny.base_resp.status_code !== 0) { - throw new Error( - `${this.providerName} API Error (${chunkAny.base_resp.status_code}): ${chunkAny.base_resp.status_msg || "Unknown error"}`, - ) - } - - const delta = chunk.choices?.[0]?.delta - const finishReason = chunk.choices?.[0]?.finish_reason - - if (delta?.content) { - for (const processedChunk of matcher.update(delta.content)) { - yield processedChunk - } - } - - if (delta) { - for (const key of ["reasoning_content", "reasoning"] as const) { - if (key in delta) { - const reasoning_content = ((delta as any)[key] as string | undefined) || "" - if (reasoning_content?.trim()) { - yield { type: "reasoning", text: reasoning_content } - } - break - } - } - } - - // Emit raw tool call chunks - NativeToolCallParser handles state management - if (delta?.tool_calls) { - for (const toolCall of delta.tool_calls) { - if (toolCall.id) { - activeToolCallIds.add(toolCall.id) - } - yield { - type: "tool_call_partial", - index: toolCall.index, - id: toolCall.id, - name: toolCall.function?.name, - arguments: toolCall.function?.arguments, - } - } - } - - // Emit tool_call_end events when finish_reason is "tool_calls" - // This ensures tool calls are finalized even if the stream doesn't properly close - if (finishReason === "tool_calls" && activeToolCallIds.size > 0) { - for (const id of activeToolCallIds) { - yield { type: "tool_call_end", id } - } - activeToolCallIds.clear() - } - - if (chunk.usage) { - lastUsage = chunk.usage - } - } - - if (lastUsage) { - yield this.processUsageMetrics(lastUsage, this.getModel().info) - } - - // Process any remaining content - for (const processedChunk of matcher.final()) { - yield processedChunk - } - } - - protected processUsageMetrics(usage: any, modelInfo?: any): ApiStreamUsageChunk { - const inputTokens = usage?.prompt_tokens || 0 - const outputTokens = usage?.completion_tokens || 0 - const cacheWriteTokens = usage?.prompt_tokens_details?.cache_write_tokens || 0 - const cacheReadTokens = usage?.prompt_tokens_details?.cached_tokens || 0 - - const { totalCost } = modelInfo - ? 
-
-	async completePrompt(prompt: string): Promise<string> {
-		const { id: modelId, info: modelInfo } = this.getModel()
-
-		const params: OpenAI.Chat.Completions.ChatCompletionCreateParams = {
-			model: modelId,
-			messages: [{ role: "user", content: prompt }],
-		}
-
-		// Add thinking parameter if reasoning is enabled and model supports it
-		if (this.options.enableReasoningEffort && modelInfo.supportsReasoningBinary) {
-			;(params as any).thinking = { type: "enabled" }
-		}
-
-		try {
-			const response = await this.client.chat.completions.create(params)
-
-			// Check for provider-specific error responses (e.g., MiniMax base_resp)
-			const responseAny = response as any
-			if (responseAny.base_resp?.status_code && responseAny.base_resp.status_code !== 0) {
-				throw new Error(
-					`${this.providerName} API Error (${responseAny.base_resp.status_code}): ${responseAny.base_resp.status_msg || "Unknown error"}`,
-				)
-			}
-
-			return response.choices?.[0]?.message.content || ""
-		} catch (error) {
-			throw handleOpenAIError(error, this.providerName)
-		}
-	}
-
-	override getModel() {
-		const id =
-			this.options.apiModelId && this.options.apiModelId in this.providerModels
-				? (this.options.apiModelId as ModelName)
-				: this.defaultProviderModelId
-
-		return { id, info: this.providerModels[id] }
-	}
-}
diff --git a/src/api/providers/roo.ts b/src/api/providers/roo.ts
index b455a1885ed..3b57b9cd745 100644
--- a/src/api/providers/roo.ts
+++ b/src/api/providers/roo.ts
@@ -1,90 +1,116 @@
 import { Anthropic } from "@anthropic-ai/sdk"
-import OpenAI from "openai"
+import { createOpenAICompatible } from "@ai-sdk/openai-compatible"
+import { streamText, generateText } from "ai"
 
 import { rooDefaultModelId, getApiProtocol, type ImageGenerationApiMethod } from "@roo-code/types"
 import { CloudService } from "@roo-code/cloud"
 
-import { NativeToolCallParser } from "../../core/assistant-message/NativeToolCallParser"
-
 import { Package } from "../../shared/package"
 import type { ApiHandlerOptions } from "../../shared/api"
+import { calculateApiCostOpenAI } from "../../shared/cost"
 
 import { ApiStream } from "../transform/stream"
 import { getModelParams } from "../transform/model-params"
-import { convertToOpenAiMessages } from "../transform/openai-format"
+import {
+	convertToAiSdkMessages,
+	convertToolsForAiSdk,
+	processAiSdkStreamPart,
+	handleAiSdkError,
+	mapToolChoice,
+} from "../transform/ai-sdk"
+import { type ReasoningDetail } from "../transform/openai-format"
 import type { RooReasoningParams } from "../transform/reasoning"
 import { getRooReasoning } from "../transform/reasoning"
 
-import type { ApiHandlerCreateMessageMetadata } from "../index"
-import { BaseOpenAiCompatibleProvider } from "./base-openai-compatible-provider"
-import { getModels, getModelsFromCache } from "../providers/fetchers/modelCache"
-import { handleOpenAIError } from "./utils/openai-error-handler"
+import type { ApiHandlerCreateMessageMetadata, SingleCompletionHandler } from "../index"
+import { BaseProvider } from "./base-provider"
+import { getModels, getModelsFromCache } from "./fetchers/modelCache"
 import { generateImageWithProvider, generateImageWithImagesApi, ImageGenerationResult } from "./utils/image-generation"
 import { t } from "../../i18n"
 
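The imports above swap the raw OpenAI client for the Vercel AI SDK. A minimal sketch of that pattern as this rewrite uses it — the URL, key, and model id are placeholders, not values from this patch:

```ts
import { createOpenAICompatible } from "@ai-sdk/openai-compatible"
import { streamText } from "ai"

// A provider factory bound to an OpenAI-compatible endpoint.
const provider = createOpenAICompatible({
	name: "example",
	apiKey: "sk-placeholder",
	baseURL: "https://api.example.com/v1",
})

// streamText returns a result whose fullStream yields typed parts
// ("text-delta", "reasoning-delta", "tool-call", "finish", ...).
const result = streamText({
	model: provider("example-model"),
	system: "You are a helpful assistant.",
	messages: [{ role: "user", content: "Hello" }],
})

for await (const part of result.fullStream) {
	if (part.type === "text-delta") process.stdout.write(part.text)
}
```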
-// Extend OpenAI's CompletionUsage to include Roo specific fields
-interface RooUsage extends OpenAI.CompletionUsage {
-	cache_creation_input_tokens?: number
-	cost?: number
-}
-
-// Add custom interface for Roo params to support reasoning
-type RooChatCompletionParams = OpenAI.Chat.ChatCompletionCreateParamsStreaming & {
-	reasoning?: RooReasoningParams
-}
-
 function getSessionToken(): string {
 	const token = CloudService.hasInstance() ? CloudService.instance.authService?.getSessionToken() : undefined
 	return token ?? "unauthenticated"
 }
 
-export class RooHandler extends BaseOpenAiCompatibleProvider<string> {
+export class RooHandler extends BaseProvider implements SingleCompletionHandler {
+	protected options: ApiHandlerOptions
 	private fetcherBaseURL: string
-	private currentReasoningDetails: any[] = []
+	private currentReasoningDetails: ReasoningDetail[] = []
 
 	constructor(options: ApiHandlerOptions) {
-		const sessionToken = options.rooApiKey ?? getSessionToken()
+		super()
+		this.options = options
 
 		let baseURL = process.env.ROO_CODE_PROVIDER_URL ?? "https://api.roocode.com/proxy"
 
-		// Ensure baseURL ends with /v1 for OpenAI client, but don't duplicate it
+		// Ensure baseURL ends with /v1 for API calls, but don't duplicate it
 		if (!baseURL.endsWith("/v1")) {
 			baseURL = `${baseURL}/v1`
 		}
 
-		// Always construct the handler, even without a valid token.
-		// The provider-proxy server will return 401 if authentication fails.
-		super({
-			...options,
-			providerName: "Roo Code Cloud",
-			baseURL, // Already has /v1 suffix
-			apiKey: sessionToken,
-			defaultProviderModelId: rooDefaultModelId,
-			providerModels: {},
-		})
-
-		// Load dynamic models asynchronously - strip /v1 from baseURL for fetcher
+		// Strip /v1 from baseURL for fetcher
 		this.fetcherBaseURL = baseURL.endsWith("/v1") ? baseURL.slice(0, -3) : baseURL
 
+		const sessionToken = options.rooApiKey ?? getSessionToken()
+
 		this.loadDynamicModels(this.fetcherBaseURL, sessionToken).catch((error) => {
 			console.error("[RooHandler] Failed to load dynamic models:", error)
 		})
 	}
 
-	protected override createStream(
+	/**
+	 * Per-request provider factory. Creates a fresh provider instance
+	 * to ensure the latest session token is used for each request.
+	 */
+	private createRooProvider(options?: { reasoning?: RooReasoningParams; taskId?: string }) {
+		const token = this.options.rooApiKey ?? getSessionToken()
+		const headers: Record<string, string> = {
+			"X-Roo-App-Version": Package.version,
+		}
+		if (options?.taskId) {
+			headers["X-Roo-Task-ID"] = options.taskId
+		}
+		const reasoning = options?.reasoning
+		return createOpenAICompatible({
+			name: "roo",
+			apiKey: token || "not-provided",
+			baseURL: `${this.fetcherBaseURL}/v1`,
+			headers,
+			...(reasoning && {
+				transformRequestBody: (body: Record<string, unknown>) => ({
+					...body,
+					reasoning,
+				}),
+			}),
+		})
+	}
+
+	override isAiSdkProvider() {
+		return true as const
+	}
+
+	getReasoningDetails(): ReasoningDetail[] | undefined {
+		return this.currentReasoningDetails.length > 0 ? this.currentReasoningDetails : undefined
+	}
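The transformRequestBody hook above merges the Roo reasoning params into each outgoing request body. An illustrative before/after, with a placeholder reasoning shape (the real RooReasoningParams type is not shown in this patch):

```ts
// What the hook does to a request body, conceptually.
const body: Record<string, unknown> = { model: "some-model", stream: true }
const reasoning = { effort: "medium" } // placeholder shape, not RooReasoningParams

const transformed = { ...body, reasoning }
console.log(transformed)
// => { model: "some-model", stream: true, reasoning: { effort: "medium" } }
```

Creating the provider per request, rather than caching a client, means a rotated session token is picked up on the very next call.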
+
+	override async *createMessage(
 		systemPrompt: string,
 		messages: Anthropic.Messages.MessageParam[],
 		metadata?: ApiHandlerCreateMessageMetadata,
-		requestOptions?: OpenAI.RequestOptions,
-	) {
-		const { id: model, info } = this.getModel()
+	): ApiStream {
+		// Reset reasoning_details accumulator for this request
+		this.currentReasoningDetails = []
 
-		// Get model parameters including reasoning
+		const model = this.getModel()
+		const { id: modelId, info } = model
+
+		// Get model parameters including reasoning budget/effort
 		const params = getModelParams({
 			format: "openai",
-			modelId: model,
+			modelId,
 			model: info,
 			settings: this.options,
-			defaultTemperature: this.defaultTemperature,
+			defaultTemperature: 0,
 		})
 
 		// Get Roo-specific reasoning parameters
@@ -95,231 +121,102 @@ export class RooHandler extends BaseOpenAiCompatibleProvider<string> {
 			settings: this.options,
 		})
 
-		const max_tokens = params.maxTokens ?? undefined
-		const temperature = params.temperature ?? this.defaultTemperature
+		const maxTokens = params.maxTokens ?? undefined
+		const temperature = params.temperature ?? 0
 
-		const rooParams: RooChatCompletionParams = {
-			model,
-			max_tokens,
-			temperature,
-			messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
-			stream: true,
-			stream_options: { include_usage: true },
-			...(reasoning && { reasoning }),
-			tools: this.convertToolsForOpenAI(metadata?.tools),
-			tool_choice: metadata?.tool_choice,
-		}
+		// Create per-request provider with fresh session token
+		const provider = this.createRooProvider({ reasoning, taskId: metadata?.taskId })
 
-		try {
-			this.client.apiKey = this.options.rooApiKey ?? getSessionToken()
-			return this.client.chat.completions.create(rooParams, requestOptions)
-		} catch (error) {
-			throw handleOpenAIError(error, this.providerName)
-		}
-	}
+		// Convert messages and tools to AI SDK format
+		const aiSdkMessages = convertToAiSdkMessages(messages)
+		const tools = convertToolsForAiSdk(this.convertToolsForOpenAI(metadata?.tools))
 
-	getReasoningDetails(): any[] | undefined {
-		return this.currentReasoningDetails.length > 0 ? this.currentReasoningDetails : undefined
-	}
+		let accumulatedReasoningText = ""
+		let lastStreamError: string | undefined
 
-	override async *createMessage(
-		systemPrompt: string,
-		messages: Anthropic.Messages.MessageParam[],
-		metadata?: ApiHandlerCreateMessageMetadata,
-	): ApiStream {
 		try {
-			// Reset reasoning_details accumulator for this request
-			this.currentReasoningDetails = []
-
-			const headers: Record<string, string> = {
-				"X-Roo-App-Version": Package.version,
-			}
-
-			if (metadata?.taskId) {
-				headers["X-Roo-Task-ID"] = metadata.taskId
-			}
-
-			const stream = await this.createStream(systemPrompt, messages, metadata, { headers })
-
-			let lastUsage: RooUsage | undefined = undefined
-			// Accumulator for reasoning_details FROM the API.
-			// We preserve the original shape of reasoning_details to prevent malformed responses.
-			const reasoningDetailsAccumulator = new Map<
-				string,
-				{
-					type: string
-					text?: string
-					summary?: string
-					data?: string
-					id?: string | null
-					format?: string
-					signature?: string
-					index: number
-				}
-			>()
-
-			// Track whether we've yielded displayable text from reasoning_details.
-			// When reasoning_details has displayable content (reasoning.text or reasoning.summary),
-			// we skip yielding the top-level reasoning field to avoid duplicate display.
-			let hasYieldedReasoningFromDetails = false
-
-			for await (const chunk of stream) {
-				const delta = chunk.choices[0]?.delta
-				const finishReason = chunk.choices[0]?.finish_reason
-
-				if (delta) {
-					// Handle reasoning_details array format (used by Gemini 3, Claude, OpenAI o-series, etc.)
-					// See: https://openrouter.ai/docs/use-cases/reasoning-tokens#preserving-reasoning-blocks
-					// Priority: Check for reasoning_details first, as it's the newer format
-					const deltaWithReasoning = delta as typeof delta & {
-						reasoning_details?: Array<{
-							type: string
-							text?: string
-							summary?: string
-							data?: string
-							id?: string | null
-							format?: string
-							signature?: string
-							index?: number
-						}>
-					}
-
-					if (deltaWithReasoning.reasoning_details && Array.isArray(deltaWithReasoning.reasoning_details)) {
-						for (const detail of deltaWithReasoning.reasoning_details) {
-							const index = detail.index ?? 0
-							// Use id as key when available to merge chunks that share the same reasoning block id
-							// This ensures that reasoning.summary and reasoning.encrypted chunks with the same id
-							// are merged into a single object, matching the provider's expected format
-							const key = detail.id ?? `${detail.type}-${index}`
-							const existing = reasoningDetailsAccumulator.get(key)
-
-							if (existing) {
-								// Accumulate text/summary/data for existing reasoning detail
-								if (detail.text !== undefined) {
-									existing.text = (existing.text || "") + detail.text
-								}
-								if (detail.summary !== undefined) {
-									existing.summary = (existing.summary || "") + detail.summary
-								}
-								if (detail.data !== undefined) {
-									existing.data = (existing.data || "") + detail.data
-								}
-								// Update other fields if provided
-								// Note: Don't update type - keep original type (e.g., reasoning.summary)
-								// even when encrypted data chunks arrive with type reasoning.encrypted
-								if (detail.id !== undefined) existing.id = detail.id
-								if (detail.format !== undefined) existing.format = detail.format
-								if (detail.signature !== undefined) existing.signature = detail.signature
-							} else {
-								// Start new reasoning detail accumulation
-								reasoningDetailsAccumulator.set(key, {
-									type: detail.type,
-									text: detail.text,
-									summary: detail.summary,
-									data: detail.data,
-									id: detail.id,
-									format: detail.format,
-									signature: detail.signature,
-									index,
-								})
-							}
-
-							// Yield text for display (still fragmented for live streaming)
-							// Only reasoning.text and reasoning.summary have displayable content
-							// reasoning.encrypted is intentionally skipped as it contains redacted content
-							let reasoningText: string | undefined
-							if (detail.type === "reasoning.text" && typeof detail.text === "string") {
-								reasoningText = detail.text
-							} else if (detail.type === "reasoning.summary" && typeof detail.summary === "string") {
-								reasoningText = detail.summary
-							}
-
-							if (reasoningText) {
-								hasYieldedReasoningFromDetails = true
-								yield { type: "reasoning", text: reasoningText }
-							}
-						}
-					}
-
-					// Handle top-level reasoning field for UI display.
-					// Skip if we've already yielded from reasoning_details to avoid duplicate display.
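A worked example of the merge-by-id rule implemented above, with made-up values — two deltas sharing an id collapse into one detail, and the original type wins:

```ts
// Two streamed reasoning_details chunks with the same block id "rb_1":
const first = { type: "reasoning.summary", summary: "The user wants ", id: "rb_1", index: 0 }
const second = { type: "reasoning.encrypted", data: "opaque-blob", id: "rb_1", index: 0 }

// After both chunks are processed, the accumulator holds a single entry keyed "rb_1";
// note the type stays "reasoning.summary" even after the encrypted chunk arrives:
const merged = { type: "reasoning.summary", summary: "The user wants ", data: "opaque-blob", id: "rb_1", index: 0 }
console.log(merged)
```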
- if ("reasoning" in delta && delta.reasoning && typeof delta.reasoning === "string") { - if (!hasYieldedReasoningFromDetails) { - yield { type: "reasoning", text: delta.reasoning } - } - } else if ("reasoning_content" in delta && typeof delta.reasoning_content === "string") { - // Also check for reasoning_content for backward compatibility - if (!hasYieldedReasoningFromDetails) { - yield { type: "reasoning", text: delta.reasoning_content } - } - } - - // Emit raw tool call chunks - NativeToolCallParser handles state management - if ("tool_calls" in delta && Array.isArray(delta.tool_calls)) { - for (const toolCall of delta.tool_calls) { - yield { - type: "tool_call_partial", - index: toolCall.index, - id: toolCall.id, - name: toolCall.function?.name, - arguments: toolCall.function?.arguments, - } - } - } + const result = streamText({ + model: provider(modelId), + system: systemPrompt, + messages: aiSdkMessages, + maxOutputTokens: maxTokens && maxTokens > 0 ? maxTokens : undefined, + temperature, + tools, + toolChoice: mapToolChoice(metadata?.tool_choice), + }) - if (delta.content) { - yield { - type: "text", - text: delta.content, - } - } + for await (const part of result.fullStream) { + if (part.type === "reasoning-delta" && part.text !== "[REDACTED]") { + accumulatedReasoningText += part.text } - - if (finishReason) { - const endEvents = NativeToolCallParser.processFinishReason(finishReason) - for (const event of endEvents) { - yield event + for (const chunk of processAiSdkStreamPart(part)) { + if (chunk.type === "error") { + lastStreamError = chunk.message } + yield chunk } + } - if (chunk.usage) { - lastUsage = chunk.usage as RooUsage - } + // Build reasoning details from accumulated text + if (accumulatedReasoningText) { + this.currentReasoningDetails.push({ + type: "reasoning.text", + text: accumulatedReasoningText, + index: 0, + }) } - // After streaming completes, store ONLY the reasoning_details we received from the API. - if (reasoningDetailsAccumulator.size > 0) { - this.currentReasoningDetails = Array.from(reasoningDetailsAccumulator.values()) + // Check provider metadata for reasoning_details (override if present) + const providerMetadata = + (await result.providerMetadata) ?? (await (result as any).experimental_providerMetadata) + const rooMeta = providerMetadata?.roo as Record | undefined + + const providerReasoningDetails = rooMeta?.reasoning_details as ReasoningDetail[] | undefined + if (providerReasoningDetails && providerReasoningDetails.length > 0) { + this.currentReasoningDetails = providerReasoningDetails } - if (lastUsage) { - // Check if the current model is marked as free - const model = this.getModel() - const isFreeModel = model.info.isFree ?? false - - // Normalize input tokens based on protocol expectations: - // - OpenAI protocol expects TOTAL input tokens (cached + non-cached) - // - Anthropic protocol expects NON-CACHED input tokens (caches passed separately) - const modelId = model.id - const apiProtocol = getApiProtocol("roo", modelId) - - const promptTokens = lastUsage.prompt_tokens || 0 - const cacheWrite = lastUsage.cache_creation_input_tokens || 0 - const cacheRead = lastUsage.prompt_tokens_details?.cached_tokens || 0 - const nonCached = Math.max(0, promptTokens - cacheWrite - cacheRead) - - const inputTokensForDownstream = apiProtocol === "anthropic" ? 
+			// Process usage with protocol-aware normalization
+			const usage = await result.usage
+			const promptTokens = usage.inputTokens ?? 0
+			const completionTokens = usage.outputTokens ?? 0
+
+			// Extract cache tokens from provider metadata
+			const cacheCreation = (rooMeta?.cache_creation_input_tokens as number) ?? 0
+			const cacheRead = (rooMeta?.cache_read_input_tokens as number) ?? (rooMeta?.cached_tokens as number) ?? 0
+
+			// Protocol-aware token normalization:
+			// - OpenAI protocol expects TOTAL input tokens (cached + non-cached)
+			// - Anthropic protocol expects NON-CACHED input tokens (caches passed separately)
+			const apiProtocol = getApiProtocol("roo", modelId)
+			const nonCached = Math.max(0, promptTokens - cacheCreation - cacheRead)
+			const inputTokens = apiProtocol === "anthropic" ? nonCached : promptTokens
+
+			// Cost: prefer server-side cost, fall back to client-side calculation
+			const isFreeModel = info.isFree === true
+			const serverCost = rooMeta?.cost as number | undefined
+			const { totalCost: calculatedCost } = calculateApiCostOpenAI(
+				info,
+				promptTokens,
+				completionTokens,
+				cacheCreation,
+				cacheRead,
+			)
+			const totalCost = isFreeModel ? 0 : (serverCost ?? calculatedCost)
+
+			yield {
+				type: "usage" as const,
+				inputTokens,
+				outputTokens: completionTokens,
+				cacheWriteTokens: cacheCreation,
+				cacheReadTokens: cacheRead,
+				totalCost,
 			}
 		} catch (error) {
+			if (lastStreamError) {
+				throw new Error(lastStreamError)
+			}
+
 			const errorContext = {
 				error: error instanceof Error ? error.message : String(error),
 				stack: error instanceof Error ? error.stack : undefined,
@@ -329,13 +226,24 @@ export class RooHandler extends BaseOpenAiCompatibleProvider<string> {
 
 			console.error(`[RooHandler] Error during message streaming: ${JSON.stringify(errorContext)}`)
 
-			throw error
+			throw handleAiSdkError(error, "Roo Code Cloud")
 		}
 	}
 
-	override async completePrompt(prompt: string): Promise<string> {
-		// Update API key before making request to ensure we use the latest session token
-		this.client.apiKey = this.options.rooApiKey ?? getSessionToken()
-		return super.completePrompt(prompt)
+	}
+
+	async completePrompt(prompt: string): Promise<string> {
+		const { id: modelId } = this.getModel()
+		const provider = this.createRooProvider()
+
+		try {
+			const result = await generateText({
+				model: provider(modelId),
+				prompt,
+				temperature: this.options.modelTemperature ?? 0,
+			})
+			return result.text
+		} catch (error) {
+			throw handleAiSdkError(error, "Roo Code Cloud")
+		}
 	}
 
 	private async loadDynamicModels(baseURL: string, apiKey?: string): Promise<void> {