diff --git a/src/api/providers/__tests__/minimax.spec.ts b/src/api/providers/__tests__/minimax.spec.ts
index 86cb5e0194..683b931652 100644
--- a/src/api/providers/__tests__/minimax.spec.ts
+++ b/src/api/providers/__tests__/minimax.spec.ts
@@ -1,407 +1,427 @@
-// npx vitest run src/api/providers/__tests__/minimax.spec.ts
+import { describe, it, expect, beforeEach, vi } from "vitest"
 
-vitest.mock("vscode", () => ({
-	workspace: {
-		getConfiguration: vitest.fn().mockReturnValue({
-			get: vitest.fn().mockReturnValue(600), // Default timeout in seconds
+import type { Anthropic } from "@anthropic-ai/sdk"
+
+import { minimaxDefaultModelId } from "@roo-code/types"
+
+import type { ApiHandlerOptions } from "../../../shared/api"
+import type { ApiStream, ApiStreamChunk } from "../../transform/stream"
+import { MiniMaxHandler } from "../minimax"
+
+const {
+	mockStreamText,
+	mockGenerateText,
+	mockCreateAnthropic,
+	mockModel,
+	mockMergeEnvironmentDetailsForMiniMax,
+	mockHandleAiSdkError,
+} = vi.hoisted(() => {
+	const mockModel = vi.fn().mockReturnValue("mock-model-instance")
+	return {
+		mockStreamText: vi.fn(),
+		mockGenerateText: vi.fn(),
+		mockCreateAnthropic: vi.fn().mockReturnValue(mockModel),
+		mockModel,
+		mockMergeEnvironmentDetailsForMiniMax: vi.fn((messages: Anthropic.Messages.MessageParam[]) => messages),
+		mockHandleAiSdkError: vi.fn((error: unknown, providerName: string) => {
+			const message = error instanceof Error ? error.message : String(error)
+			return new Error(`${providerName}: ${message}`)
 		}),
-	},
-}))
+	}
+})
 
-import { Anthropic } from "@anthropic-ai/sdk"
+vi.mock("ai", () => ({
+	streamText: mockStreamText,
+	generateText: mockGenerateText,
+}))
 
-import { type MinimaxModelId, minimaxDefaultModelId, minimaxModels } from "@roo-code/types"
+vi.mock("@ai-sdk/anthropic", () => ({
+	createAnthropic: mockCreateAnthropic,
+}))
 
-import { MiniMaxHandler } from "../minimax"
+vi.mock("../../transform/minimax-format", () => ({
+	mergeEnvironmentDetailsForMiniMax: mockMergeEnvironmentDetailsForMiniMax,
+}))
 
-vitest.mock("@anthropic-ai/sdk", () => {
-	const mockCreate = vitest.fn()
+vi.mock("../../transform/ai-sdk", async (importOriginal) => {
+	const actual = await importOriginal<typeof import("../../transform/ai-sdk")>()
 	return {
-		Anthropic: vitest.fn(() => ({
-			messages: {
-				create: mockCreate,
-			},
-		})),
+		...actual,
+		handleAiSdkError: mockHandleAiSdkError,
 	}
 })
 
+type HandlerOptions = Omit<ApiHandlerOptions, "minimaxBaseUrl"> & {
+	minimaxBaseUrl?: string
+}
+
+function createHandler(options: HandlerOptions = {}) {
+	return new MiniMaxHandler({
+		minimaxApiKey: "test-api-key",
+		...options,
+	} as ApiHandlerOptions)
+}
+
+function createMockStream(
+	chunks: Array<Record<string, unknown>>,
+	usage: { inputTokens?: number; outputTokens?: number } = { inputTokens: 10, outputTokens: 5 },
+	providerMetadata: Record<string, Record<string, unknown>> = {
+		anthropic: {
+			cacheReadInputTokens: 0,
+			cacheCreationInputTokens: 0,
+		},
+	},
+) {
+	const stream = (async function* () {
+		for (const chunk of chunks) {
+			yield chunk
+		}
+	})()
+
+	return {
+		fullStream: stream,
+		usage: Promise.resolve(usage),
+		providerMetadata: Promise.resolve(providerMetadata),
+		response: Promise.resolve({ headers: new Headers() }),
+	}
+}
+
+async function collectChunks(stream: ApiStream): Promise<ApiStreamChunk[]> {
+	const chunks: ApiStreamChunk[] = []
+	for await (const chunk of stream) {
+		chunks.push(chunk)
+	}
+	return chunks
+}
+
 describe("MiniMaxHandler", () => {
-	let handler: MiniMaxHandler
-	let mockCreate: any
+	const systemPrompt = "You are a helpful assistant."
+	const messages: Anthropic.Messages.MessageParam[] = [
+		{
+			role: "user",
+			content: [{ type: "text", text: "Hello" }],
+		},
+	]
 
 	beforeEach(() => {
-		vitest.clearAllMocks()
-		const anthropicInstance = (Anthropic as unknown as any)()
-		mockCreate = anthropicInstance.messages.create
-	})
-
-	describe("International MiniMax (default)", () => {
-		beforeEach(() => {
-			handler = new MiniMaxHandler({
-				minimaxApiKey: "test-minimax-api-key",
-				minimaxBaseUrl: "https://api.minimax.io/v1",
-			})
+		vi.clearAllMocks()
+		mockCreateAnthropic.mockReturnValue(mockModel)
+		mockMergeEnvironmentDetailsForMiniMax.mockImplementation(
+			(inputMessages: Anthropic.Messages.MessageParam[]) => inputMessages,
+		)
+		mockHandleAiSdkError.mockImplementation((error: unknown, providerName: string) => {
+			const message = error instanceof Error ? error.message : String(error)
+			return new Error(`${providerName}: ${message}`)
 		})
+	})
 
-		it("should use the correct international MiniMax base URL by default", () => {
-			new MiniMaxHandler({ minimaxApiKey: "test-minimax-api-key" })
-			expect(Anthropic).toHaveBeenCalledWith(
+	describe("constructor", () => {
+		it("uses default base URL when no baseUrl is provided", () => {
+			createHandler()
+			expect(mockCreateAnthropic).toHaveBeenCalledWith(
 				expect.objectContaining({
-					baseURL: "https://api.minimax.io/anthropic",
+					baseURL: "https://api.minimax.io/anthropic/v1",
 				}),
 			)
 		})
 
-		it("should convert /v1 endpoint to /anthropic endpoint", () => {
-			new MiniMaxHandler({
-				minimaxApiKey: "test-minimax-api-key",
+		it("converts /v1 base URL to /anthropic/v1", () => {
+			createHandler({
 				minimaxBaseUrl: "https://api.minimax.io/v1",
 			})
-			expect(Anthropic).toHaveBeenCalledWith(
+
+			expect(mockCreateAnthropic).toHaveBeenCalledWith(
 				expect.objectContaining({
-					baseURL: "https://api.minimax.io/anthropic",
+					baseURL: "https://api.minimax.io/anthropic/v1",
 				}),
 			)
 		})
 
-		it("should use the provided API key", () => {
-			const minimaxApiKey = "test-minimax-api-key"
-			new MiniMaxHandler({ minimaxApiKey })
-			expect(Anthropic).toHaveBeenCalledWith(expect.objectContaining({ apiKey: minimaxApiKey }))
-		})
+		it("appends /v1 for base URL already ending with /anthropic", () => {
+			createHandler({
+				minimaxBaseUrl: "https://api.minimax.io/anthropic",
+			})
 
-		it("should return default model when no model is specified", () => {
-			const model = handler.getModel()
-			expect(model.id).toBe(minimaxDefaultModelId)
-			expect(model.info).toEqual(minimaxModels[minimaxDefaultModelId])
+			expect(mockCreateAnthropic).toHaveBeenCalledWith(
+				expect.objectContaining({
+					baseURL: "https://api.minimax.io/anthropic/v1",
+				}),
+			)
 		})
 
-		it("should return specified model when valid model is provided", () => {
-			const testModelId: MinimaxModelId = "MiniMax-M2"
-			const handlerWithModel = new MiniMaxHandler({
-				apiModelId: testModelId,
-				minimaxApiKey: "test-minimax-api-key",
+		it("appends /anthropic/v1 when base URL has no suffix", () => {
+			createHandler({
+				minimaxBaseUrl: "https://api.minimax.io/custom",
 			})
-			const model = handlerWithModel.getModel()
-			expect(model.id).toBe(testModelId)
-			expect(model.info).toEqual(minimaxModels[testModelId])
-		})
 
-		it("should return MiniMax-M2 model with correct configuration", () => {
-			const testModelId: MinimaxModelId = "MiniMax-M2"
-			const handlerWithModel = new MiniMaxHandler({
-				apiModelId: testModelId,
-				minimaxApiKey: "test-minimax-api-key",
-			})
-			const model = handlerWithModel.getModel()
-			expect(model.id).toBe(testModelId)
-			expect(model.info).toEqual(minimaxModels[testModelId])
-			expect(model.info.contextWindow).toBe(192_000)
-			expect(model.info.maxTokens).toBe(16_384)
-			expect(model.info.supportsPromptCache).toBe(true)
-			expect(model.info.cacheWritesPrice).toBe(0.375)
-			expect(model.info.cacheReadsPrice).toBe(0.03)
+			expect(mockCreateAnthropic).toHaveBeenCalledWith(
+				expect.objectContaining({
+					baseURL: "https://api.minimax.io/custom/anthropic/v1",
+				}),
+			)
 		})
 
-		it("should return MiniMax-M2-Stable model with correct configuration", () => {
-			const testModelId: MinimaxModelId = "MiniMax-M2-Stable"
-			const handlerWithModel = new MiniMaxHandler({
-				apiModelId: testModelId,
-				minimaxApiKey: "test-minimax-api-key",
+		it("supports the China endpoint", () => {
+			createHandler({
+				minimaxBaseUrl: "https://api.minimaxi.com/anthropic",
 			})
-			const model = handlerWithModel.getModel()
-			expect(model.id).toBe(testModelId)
-			expect(model.info).toEqual(minimaxModels[testModelId])
-			expect(model.info.contextWindow).toBe(192_000)
-			expect(model.info.maxTokens).toBe(16_384)
-			expect(model.info.supportsPromptCache).toBe(true)
-			expect(model.info.cacheWritesPrice).toBe(0.375)
-			expect(model.info.cacheReadsPrice).toBe(0.03)
-		})
-	})
 
-	describe("China MiniMax", () => {
-		beforeEach(() => {
-			handler = new MiniMaxHandler({
-				minimaxApiKey: "test-minimax-api-key",
-				minimaxBaseUrl: "https://api.minimaxi.com/v1",
-			})
+			expect(mockCreateAnthropic).toHaveBeenCalledWith(
+				expect.objectContaining({
+					baseURL: "https://api.minimaxi.com/anthropic/v1",
+				}),
+			)
 		})
 
-		it("should use the correct China MiniMax base URL", () => {
-			new MiniMaxHandler({
-				minimaxApiKey: "test-minimax-api-key",
-				minimaxBaseUrl: "https://api.minimaxi.com/v1",
+		it("treats empty baseUrl as falsy and falls back to default", () => {
+			createHandler({
+				minimaxBaseUrl: "",
 			})
-			expect(Anthropic).toHaveBeenCalledWith(
-				expect.objectContaining({ baseURL: "https://api.minimaxi.com/anthropic" }),
+
+			expect(mockCreateAnthropic).toHaveBeenCalledWith(
+				expect.objectContaining({
+					baseURL: "https://api.minimax.io/anthropic/v1",
+				}),
 			)
 		})
 
-		it("should convert China /v1 endpoint to /anthropic endpoint", () => {
-			new MiniMaxHandler({
-				minimaxApiKey: "test-minimax-api-key",
-				minimaxBaseUrl: "https://api.minimaxi.com/v1",
+		it("passes API key through to createAnthropic", () => {
+			createHandler({
+				minimaxApiKey: "minimax-key-123",
 			})
-			expect(Anthropic).toHaveBeenCalledWith(
-				expect.objectContaining({ baseURL: "https://api.minimaxi.com/anthropic" }),
+
+			expect(mockCreateAnthropic).toHaveBeenCalledWith(
+				expect.objectContaining({
+					apiKey: "minimax-key-123",
+				}),
 			)
 		})
+	})
 
-		it("should use the provided API key for China", () => {
-			const minimaxApiKey = "test-minimax-api-key"
-			new MiniMaxHandler({ minimaxApiKey, minimaxBaseUrl: "https://api.minimaxi.com/v1" })
-			expect(Anthropic).toHaveBeenCalledWith(expect.objectContaining({ apiKey: minimaxApiKey }))
+	describe("getModel", () => {
+		it("returns default model when no model ID is specified", () => {
+			const handler = createHandler()
+			const model = handler.getModel()
+			expect(model.id).toBe("MiniMax-M2")
+			expect(model.temperature).toBe(1)
 		})
 
-		it("should return default model when no model is specified", () => {
+		it("returns specified model when valid model ID is provided", () => {
+			const handler = createHandler({
+				apiModelId: "MiniMax-M2-Stable",
+			})
+			const model = handler.getModel()
+			expect(model.id).toBe("MiniMax-M2-Stable")
+		})
+
+		it("falls back to default model when unknown model ID is provided", () => {
+			const handler = createHandler({
+				apiModelId: "unknown-model",
+			})
 			const model = handler.getModel()
 			expect(model.id).toBe(minimaxDefaultModelId)
-			expect(model.info).toEqual(minimaxModels[minimaxDefaultModelId])
 		})
 	})
 
-	describe("Default behavior", () => {
-		it("should default to international base URL when none is specified", () => {
-			const handlerDefault = new MiniMaxHandler({ minimaxApiKey: "test-minimax-api-key" })
-			expect(Anthropic).toHaveBeenCalledWith(
+	describe("createMessage", () => {
+		it("streams text chunks and calls streamText with expected params", async () => {
+			mockStreamText.mockReturnValue(
+				createMockStream([
+					{ type: "text-delta", text: "Hello" },
+					{ type: "text-delta", text: " world" },
+				]),
+			)
+
+			const handler = createHandler()
+			const chunks = await collectChunks(handler.createMessage(systemPrompt, messages))
+
+			expect(mockModel).toHaveBeenCalledWith("MiniMax-M2")
+			expect(mockStreamText).toHaveBeenCalledWith(
 				expect.objectContaining({
-					baseURL: "https://api.minimax.io/anthropic",
+					model: "mock-model-instance",
+					system: systemPrompt,
+					temperature: 1,
+					messages: expect.any(Array),
 				}),
 			)
-			const model = handlerDefault.getModel()
-			expect(model.id).toBe(minimaxDefaultModelId)
-			expect(model.info).toEqual(minimaxModels[minimaxDefaultModelId])
-		})
-
-		it("should default to MiniMax-M2 model", () => {
-			const handlerDefault = new MiniMaxHandler({ minimaxApiKey: "test-minimax-api-key" })
-			const model = handlerDefault.getModel()
-			expect(model.id).toBe("MiniMax-M2")
+			const textChunks = chunks.filter((chunk) => chunk.type === "text")
+			expect(textChunks).toHaveLength(2)
+			expect(textChunks[0]).toEqual({ type: "text", text: "Hello" })
+			expect(textChunks[1]).toEqual({ type: "text", text: " world" })
 		})
-	})
 
-	describe("API Methods", () => {
-		beforeEach(() => {
-			handler = new MiniMaxHandler({ minimaxApiKey: "test-minimax-api-key" })
-		})
+		it("streams reasoning chunks", async () => {
+			mockStreamText.mockReturnValue(
+				createMockStream([
+					{ type: "reasoning", text: "thinking..." },
+					{ type: "reasoning", text: " step 2" },
+				]),
+			)
 
-		it("completePrompt method should return text from MiniMax API", async () => {
-			const expectedResponse = "This is a test response from MiniMax"
-			mockCreate.mockResolvedValueOnce({
-				content: [{ type: "text", text: expectedResponse }],
-			})
-			const result = await handler.completePrompt("test prompt")
-			expect(result).toBe(expectedResponse)
-		})
+			const handler = createHandler()
+			const chunks = await collectChunks(handler.createMessage(systemPrompt, messages))
 
-		it("should handle errors in completePrompt", async () => {
-			const errorMessage = "MiniMax API error"
-			mockCreate.mockRejectedValueOnce(new Error(errorMessage))
-			await expect(handler.completePrompt("test prompt")).rejects.toThrow()
+			const reasoningChunks = chunks.filter((chunk) => chunk.type === "reasoning")
+			expect(reasoningChunks).toHaveLength(2)
+			expect(reasoningChunks[0]).toEqual({ type: "reasoning", text: "thinking..." })
+			expect(reasoningChunks[1]).toEqual({ type: "reasoning", text: " step 2" })
 		})
 
-		it("createMessage should yield text content from stream", async () => {
-			const testContent = "This is test content from MiniMax stream"
-
-			mockCreate.mockResolvedValueOnce({
-				[Symbol.asyncIterator]: () => ({
-					next: vitest
-						.fn()
-						.mockResolvedValueOnce({
-							done: false,
-							value: {
-								type: "content_block_start",
-								index: 0,
-								content_block: { type: "text", text: testContent },
-							},
-						})
-						.mockResolvedValueOnce({ done: true }),
-				}),
-			})
+		it("streams tool call chunks", async () => {
+			mockStreamText.mockReturnValue(
+				createMockStream([
+					{ type: "tool-input-start", id: "call_1", toolName: "read_file" },
+					{ type: "tool-input-delta", id: "call_1", delta: '{"path":"a.ts"}' },
+					{ type: "tool-input-end", id: "call_1" },
+				]),
+			)
 
-			const stream = handler.createMessage("system prompt", [])
-			const firstChunk = await stream.next()
+			const handler = createHandler()
+			const chunks = await collectChunks(handler.createMessage(systemPrompt, messages))
 
-			expect(firstChunk.done).toBe(false)
-			expect(firstChunk.value).toEqual({ type: "text", text: testContent })
+			expect(chunks).toContainEqual({
+				type: "tool_call_start",
+				id: "call_1",
+				name: "read_file",
+			})
+			expect(chunks).toContainEqual({
+				type: "tool_call_delta",
+				id: "call_1",
+				delta: '{"path":"a.ts"}',
+			})
+			expect(chunks).toContainEqual({
+				type: "tool_call_end",
+				id: "call_1",
+			})
 		})
 
-		it("createMessage should yield usage data from stream", async () => {
-			mockCreate.mockResolvedValueOnce({
-				[Symbol.asyncIterator]: () => ({
-					next: vitest
-						.fn()
-						.mockResolvedValueOnce({
-							done: false,
-							value: {
-								type: "message_start",
-								message: {
-									usage: {
-										input_tokens: 10,
-										output_tokens: 20,
-									},
-								},
-							},
-						})
-						.mockResolvedValueOnce({ done: true }),
-				}),
-			})
+		it("yields usage chunk with token and cost information", async () => {
+			mockStreamText.mockReturnValue(
+				createMockStream(
+					[{ type: "text-delta", text: "Done" }],
+					{ inputTokens: 10, outputTokens: 5 },
+					{
+						anthropic: {
+							cacheCreationInputTokens: 3,
+							cacheReadInputTokens: 2,
+						},
+					},
+				),
+			)
 
-			const stream = handler.createMessage("system prompt", [])
-			const firstChunk = await stream.next()
+			const handler = createHandler()
+			const chunks = await collectChunks(handler.createMessage(systemPrompt, messages))
+			const usageChunk = chunks.find((chunk) => chunk.type === "usage")
 
-			expect(firstChunk.done).toBe(false)
-			expect(firstChunk.value).toEqual({ type: "usage", inputTokens: 10, outputTokens: 20 })
+			expect(usageChunk).toMatchObject({
+				type: "usage",
+				inputTokens: 10,
+				outputTokens: 5,
+				cacheWriteTokens: 3,
+				cacheReadTokens: 2,
+			})
+			expect(typeof usageChunk?.totalCost).toBe("number")
 		})
 
-		it("createMessage should pass correct parameters to MiniMax client", async () => {
-			const modelId: MinimaxModelId = "MiniMax-M2"
-			const modelInfo = minimaxModels[modelId]
-			const handlerWithModel = new MiniMaxHandler({
-				apiModelId: modelId,
-				minimaxApiKey: "test-minimax-api-key",
-			})
+		it("calls mergeEnvironmentDetailsForMiniMax before conversion", async () => {
+			const mergedMessages: Anthropic.Messages.MessageParam[] = [
+				{
+					role: "user",
+					content: [{ type: "text", text: "Merged message" }],
+				},
+			]
+			mockMergeEnvironmentDetailsForMiniMax.mockReturnValueOnce(mergedMessages)
+			mockStreamText.mockReturnValue(createMockStream([{ type: "text-delta", text: "OK" }]))
+
+			const handler = createHandler()
+			await collectChunks(handler.createMessage(systemPrompt, messages))
+
+			expect(mockMergeEnvironmentDetailsForMiniMax).toHaveBeenCalledWith(messages)
+			const callArgs = mockStreamText.mock.calls[0]?.[0]
+			expect(callArgs.messages).toEqual(
+				expect.arrayContaining([
+					expect.objectContaining({
+						role: "user",
+						content: [{ type: "text", text: "Merged message" }],
+						providerOptions: {
+							anthropic: {
+								cacheControl: { type: "ephemeral" },
+							},
+						},
+					}),
+				]),
+			)
+		})
 
-			mockCreate.mockResolvedValueOnce({
-				[Symbol.asyncIterator]: () => ({
-					async next() {
-						return { done: true }
-					},
-				}),
+		it("handles errors via handleAiSdkError", async () => {
+			mockStreamText.mockImplementation(() => {
+				throw new Error("API Error")
 			})
 
-			const systemPrompt = "Test system prompt for MiniMax"
-			const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Test message for MiniMax" }]
+			const handler = createHandler()
+			const stream = handler.createMessage(systemPrompt, messages)
 
-			const messageGenerator = handlerWithModel.createMessage(systemPrompt, messages)
-			await messageGenerator.next()
+			await expect(async () => {
+				await collectChunks(stream)
+			}).rejects.toThrow("MiniMax: API Error")
+			expect(mockHandleAiSdkError).toHaveBeenCalledWith(expect.any(Error), "MiniMax")
+		})
+	})
 
-			expect(mockCreate).toHaveBeenCalledWith(
-				expect.objectContaining({
-					model: modelId,
-					max_tokens: Math.min(modelInfo.maxTokens, Math.ceil(modelInfo.contextWindow * 0.2)),
-					temperature: 1,
-					system: expect.any(Array),
-					messages: expect.any(Array),
-					stream: true,
-				}),
-			)
+	describe("thinking signature", () => {
+		it("returns undefined thought signature before any request", () => {
+			const handler = createHandler()
+			expect(handler.getThoughtSignature()).toBeUndefined()
 		})
 
-		it("should use temperature 1 by default", async () => {
-			mockCreate.mockResolvedValueOnce({
-				[Symbol.asyncIterator]: () => ({
-					async next() {
-						return { done: true }
+		it("captures thought signature from stream providerMetadata", async () => {
+			const signature = "test-thinking-signature"
+			mockStreamText.mockReturnValue(
+				createMockStream([
+					{
+						type: "reasoning-delta",
+						text: "thinking...",
+						providerMetadata: { anthropic: { signature } },
 					},
-				}),
-			})
+					{ type: "text-delta", text: "Answer" },
+				]),
+			)
 
-			const messageGenerator = handler.createMessage("test", [])
-			await messageGenerator.next()
+			const handler = createHandler()
+			await collectChunks(handler.createMessage(systemPrompt, messages))
 
-			expect(mockCreate).toHaveBeenCalledWith(
-				expect.objectContaining({
-					temperature: 1,
-				}),
-			)
+			expect(handler.getThoughtSignature()).toBe(signature)
 		})
 
-		it("should handle thinking blocks in stream", async () => {
-			const thinkingContent = "Let me think about this..."
-
-			mockCreate.mockResolvedValueOnce({
-				[Symbol.asyncIterator]: () => ({
-					next: vitest
-						.fn()
-						.mockResolvedValueOnce({
-							done: false,
-							value: {
-								type: "content_block_start",
-								index: 0,
-								content_block: { type: "thinking", thinking: thinkingContent },
-							},
-						})
-						.mockResolvedValueOnce({ done: true }),
-				}),
-			})
+		it("returns undefined redacted thinking blocks before any request", () => {
+			const handler = createHandler()
+			expect(handler.getRedactedThinkingBlocks()).toBeUndefined()
+		})
+	})
 
-			const stream = handler.createMessage("system prompt", [])
-			const firstChunk = await stream.next()
+	describe("completePrompt", () => {
+		it("calls generateText with model and prompt and returns text", async () => {
+			mockGenerateText.mockResolvedValue({ text: "response" })
 
-			expect(firstChunk.done).toBe(false)
-			expect(firstChunk.value).toEqual({ type: "reasoning", text: thinkingContent })
-		})
+			const handler = createHandler()
+			const result = await handler.completePrompt("test prompt")
 
-		it("should handle tool calls in stream", async () => {
-			mockCreate.mockResolvedValueOnce({
-				[Symbol.asyncIterator]: () => ({
-					next: vitest
-						.fn()
-						.mockResolvedValueOnce({
-							done: false,
-							value: {
-								type: "content_block_start",
-								index: 0,
-								content_block: {
-									type: "tool_use",
-									id: "tool-123",
-									name: "get_weather",
-									input: { city: "London" },
-								},
-							},
-						})
-						.mockResolvedValueOnce({
-							done: false,
-							value: {
-								type: "content_block_stop",
-								index: 0,
-							},
-						})
-						.mockResolvedValueOnce({ done: true }),
+			expect(result).toBe("response")
+			expect(mockModel).toHaveBeenCalledWith("MiniMax-M2")
+			expect(mockGenerateText).toHaveBeenCalledWith(
+				expect.objectContaining({
+					model: "mock-model-instance",
+					prompt: "test prompt",
 				}),
-			})
-
-			const stream = handler.createMessage("system prompt", [])
-			const firstChunk = await stream.next()
-
-			expect(firstChunk.done).toBe(false)
-			// Provider now yields tool_call_partial chunks, NativeToolCallParser handles reassembly
-			expect(firstChunk.value).toEqual({
-				type: "tool_call_partial",
-				index: 0,
-				id: "tool-123",
-				name: "get_weather",
-				arguments: undefined,
-			})
+			)
 		})
 	})
 
-	describe("Model Configuration", () => {
-		it("should correctly configure MiniMax-M2 model properties", () => {
-			const model = minimaxModels["MiniMax-M2"]
-			expect(model.maxTokens).toBe(16_384)
-			expect(model.contextWindow).toBe(192_000)
-			expect(model.supportsImages).toBe(false)
-			expect(model.supportsPromptCache).toBe(true)
-			expect(model.inputPrice).toBe(0.3)
-			expect(model.outputPrice).toBe(1.2)
-			expect(model.cacheWritesPrice).toBe(0.375)
-			expect(model.cacheReadsPrice).toBe(0.03)
-		})
-
-		it("should correctly configure MiniMax-M2-Stable model properties", () => {
-			const model = minimaxModels["MiniMax-M2-Stable"]
-			expect(model.maxTokens).toBe(16_384)
-			expect(model.contextWindow).toBe(192_000)
-			expect(model.supportsImages).toBe(false)
-			expect(model.supportsPromptCache).toBe(true)
-			expect(model.inputPrice).toBe(0.3)
-			expect(model.outputPrice).toBe(1.2)
-			expect(model.cacheWritesPrice).toBe(0.375)
-			expect(model.cacheReadsPrice).toBe(0.03)
+	describe("isAiSdkProvider", () => {
+		it("returns true", () => {
+			const handler = createHandler()
+			expect(handler.isAiSdkProvider()).toBe(true)
 		})
 	})
 })
diff --git a/src/api/providers/minimax.ts b/src/api/providers/minimax.ts
index bfcf4e3be4..07a4978af9 100644
--- a/src/api/providers/minimax.ts
+++ b/src/api/providers/minimax.ts
@@ -1,277 +1,264 @@
-import { Anthropic } from "@anthropic-ai/sdk"
-import { Stream as AnthropicStream } from "@anthropic-ai/sdk/streaming"
-import { CacheControlEphemeral } from "@anthropic-ai/sdk/resources"
-import OpenAI from "openai"
+import type { Anthropic } from "@anthropic-ai/sdk"
+import { createAnthropic } from "@ai-sdk/anthropic"
+import { streamText, generateText, ToolSet } from "ai"
 
-import { type MinimaxModelId, minimaxDefaultModelId, minimaxModels } from "@roo-code/types"
+import { type ModelInfo, minimaxDefaultModelId, minimaxModels } from "@roo-code/types"
 
 import type { ApiHandlerOptions } from "../../shared/api"
-
-import { ApiStream } from "../transform/stream"
+import type { ApiStream, ApiStreamUsageChunk } from "../transform/stream"
 import { getModelParams } from "../transform/model-params"
 import { mergeEnvironmentDetailsForMiniMax } from "../transform/minimax-format"
+import {
+	convertToAiSdkMessages,
+	convertToolsForAiSdk,
+	processAiSdkStreamPart,
+	mapToolChoice,
+	handleAiSdkError,
+} from "../transform/ai-sdk"
+import { calculateApiCostAnthropic } from "../../shared/cost"
+import { DEFAULT_HEADERS } from "./constants"
 import { BaseProvider } from "./base-provider"
 import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
-import { calculateApiCostAnthropic } from "../../shared/cost"
-import { convertOpenAIToolsToAnthropic } from "../../core/prompts/tools/native-tools/converters"
-
-/**
- * Converts OpenAI tool_choice to Anthropic ToolChoice format
- */
-function convertOpenAIToolChoice(
-	toolChoice: OpenAI.Chat.ChatCompletionCreateParams["tool_choice"],
-): Anthropic.Messages.MessageCreateParams["tool_choice"] | undefined {
-	if (!toolChoice) {
-		return undefined
-	}
-
-	if (typeof toolChoice === "string") {
-		switch (toolChoice) {
-			case "none":
-				return undefined // Anthropic doesn't have "none", just omit tools
-			case "auto":
-				return { type: "auto" }
-			case "required":
-				return { type: "any" }
-			default:
-				return { type: "auto" }
-		}
-	}
-
-	// Handle object form { type: "function", function: { name: string } }
-	if (typeof toolChoice === "object" && "function" in toolChoice) {
-		return {
-			type: "tool",
-			name: toolChoice.function.name,
-		}
-	}
-
-	return { type: "auto" }
-}
 
 export class MiniMaxHandler extends BaseProvider implements SingleCompletionHandler {
+	private client: ReturnType<typeof createAnthropic>
 	private options: ApiHandlerOptions
-	private client: Anthropic
+	private readonly providerName = "MiniMax"
+	private lastThoughtSignature: string | undefined
+	private lastRedactedThinkingBlocks: Array<{ type: "redacted_thinking"; data: string }> = []
 
 	constructor(options: ApiHandlerOptions) {
 		super()
 		this.options = options
 
-		// Use Anthropic-compatible endpoint
-		// Default to international endpoint: https://api.minimax.io/anthropic
-		// China endpoint: https://api.minimaxi.com/anthropic
-		let baseURL = options.minimaxBaseUrl || "https://api.minimax.io/anthropic"
-
-		// If user provided a /v1 endpoint, convert to /anthropic
-		if (baseURL.endsWith("/v1")) {
-			baseURL = baseURL.replace(/\/v1$/, "/anthropic")
-		} else if (!baseURL.endsWith("/anthropic")) {
-			baseURL = `${baseURL.replace(/\/$/, "")}/anthropic`
+		const rawBaseUrl = this.options.minimaxBaseUrl
+		let resolvedBaseUrl: string | undefined
+
+		if (rawBaseUrl) {
+			if (rawBaseUrl.endsWith("/anthropic/v1")) {
+				resolvedBaseUrl = rawBaseUrl
+			} else if (rawBaseUrl.endsWith("/v1")) {
+				resolvedBaseUrl = rawBaseUrl.slice(0, -3) + "/anthropic/v1"
+			} else if (rawBaseUrl.endsWith("/anthropic")) {
+				resolvedBaseUrl = rawBaseUrl + "/v1"
+			} else {
+				resolvedBaseUrl = rawBaseUrl + "/anthropic/v1"
+			}
+		} else {
+			resolvedBaseUrl = "https://api.minimax.io/anthropic/v1"
 		}
 
-		this.client = new Anthropic({
-			baseURL,
-			apiKey: options.minimaxApiKey,
+		this.client = createAnthropic({
+			baseURL: resolvedBaseUrl,
+			apiKey: this.options.minimaxApiKey ?? "",
+			headers: DEFAULT_HEADERS,
 		})
 	}
 
-	async *createMessage(
+	override async *createMessage(
 		systemPrompt: string,
 		messages: Anthropic.Messages.MessageParam[],
 		metadata?: ApiHandlerCreateMessageMetadata,
 	): ApiStream {
-		let stream: AnthropicStream<Anthropic.Messages.RawMessageStreamEvent>
-		const cacheControl: CacheControlEphemeral = { type: "ephemeral" }
-		const { id: modelId, info, maxTokens, temperature } = this.getModel()
-
-		// MiniMax M2 models support prompt caching
-		const supportsPromptCache = info.supportsPromptCache ?? false
-
-		// Merge environment_details from messages that follow tool_result blocks
-		// into the tool_result content. This preserves reasoning continuity for
-		// thinking models by preventing user messages from interrupting the
-		// reasoning context after tool use (similar to r1-format's mergeToolResultText).
-		const processedMessages = mergeEnvironmentDetailsForMiniMax(messages)
-
-		// Build the system blocks array
-		const systemBlocks: Anthropic.Messages.TextBlockParam[] = [
-			supportsPromptCache
-				? { text: systemPrompt, type: "text", cache_control: cacheControl }
-				: { text: systemPrompt, type: "text" },
-		]
-
-		// Prepare request parameters
-		const requestParams: Anthropic.Messages.MessageCreateParams = {
-			model: modelId,
-			max_tokens: maxTokens ?? 16_384,
-			temperature: temperature ?? 1.0,
-			system: systemBlocks,
-			messages: supportsPromptCache ? this.addCacheControl(processedMessages, cacheControl) : processedMessages,
-			stream: true,
-			tools: convertOpenAIToolsToAnthropic(metadata?.tools ?? []),
-			tool_choice: convertOpenAIToolChoice(metadata?.tool_choice),
+		const modelConfig = this.getModel()
+
+		// Reset thinking state for this request
+		this.lastThoughtSignature = undefined
+		this.lastRedactedThinkingBlocks = []
+
+		const modelParams = getModelParams({
+			format: "anthropic",
+			modelId: modelConfig.id,
+			model: modelConfig.info,
+			settings: this.options,
+			defaultTemperature: 1.0,
+		})
+
+		const mergedMessages = mergeEnvironmentDetailsForMiniMax(messages)
+		const aiSdkMessages = convertToAiSdkMessages(mergedMessages)
+		const openAiTools = this.convertToolsForOpenAI(metadata?.tools)
+		const aiSdkTools = convertToolsForAiSdk(openAiTools) as ToolSet | undefined
+
+		const anthropicProviderOptions: Record<string, unknown> = {}
+
+		if (modelParams.reasoning && modelParams.reasoningBudget) {
+			anthropicProviderOptions.thinking = {
+				type: "enabled",
+				budgetTokens: modelParams.reasoningBudget,
+			}
 		}
 
-		stream = await this.client.messages.create(requestParams)
-
-		let inputTokens = 0
-		let outputTokens = 0
-		let cacheWriteTokens = 0
-		let cacheReadTokens = 0
-
-		for await (const chunk of stream) {
-			switch (chunk.type) {
-				case "message_start": {
-					// Tells us cache reads/writes/input/output.
-					const {
-						input_tokens = 0,
-						output_tokens = 0,
-						cache_creation_input_tokens,
-						cache_read_input_tokens,
-					} = chunk.message.usage
-
-					yield {
-						type: "usage",
-						inputTokens: input_tokens,
-						outputTokens: output_tokens,
-						cacheWriteTokens: cache_creation_input_tokens || undefined,
-						cacheReadTokens: cache_read_input_tokens || undefined,
-					}
+		if (metadata?.parallelToolCalls === false) {
+			anthropicProviderOptions.disableParallelToolUse = true
+		}
 
-					inputTokens += input_tokens
-					outputTokens += output_tokens
-					cacheWriteTokens += cache_creation_input_tokens || 0
-					cacheReadTokens += cache_read_input_tokens || 0
+		const cacheProviderOption = { anthropic: { cacheControl: { type: "ephemeral" as const } } }
+		const userMsgIndices = mergedMessages.reduce(
+			(acc, msg, index) => (msg.role === "user" ? [...acc, index] : acc),
+			[] as number[],
+		)
 
-					break
-				}
-				case "message_delta":
-					// Tells us stop_reason, stop_sequence, and output tokens
-					yield {
-						type: "usage",
-						inputTokens: 0,
-						outputTokens: chunk.usage.output_tokens || 0,
-					}
+		const targetIndices = new Set<number>()
+		const lastUserMsgIndex = userMsgIndices[userMsgIndices.length - 1] ?? -1
+		const secondLastUserMsgIndex = userMsgIndices[userMsgIndices.length - 2] ?? -1
 
-					break
-				case "message_stop":
-					// No usage data, just an indicator that the message is done.
-					break
-				case "content_block_start":
-					switch (chunk.content_block.type) {
-						case "thinking":
-							// Yield thinking/reasoning content
-							if (chunk.index > 0) {
-								yield { type: "reasoning", text: "\n" }
-							}
+		if (lastUserMsgIndex >= 0) targetIndices.add(lastUserMsgIndex)
+		if (secondLastUserMsgIndex >= 0) targetIndices.add(secondLastUserMsgIndex)
 
-							yield { type: "reasoning", text: chunk.content_block.thinking }
-							break
-						case "text":
-							// We may receive multiple text blocks
-							if (chunk.index > 0) {
-								yield { type: "text", text: "\n" }
-							}
+		if (targetIndices.size > 0) {
+			this.applyCacheControlToAiSdkMessages(mergedMessages, aiSdkMessages, targetIndices, cacheProviderOption)
+		}
 
-							yield { type: "text", text: chunk.content_block.text }
-							break
-						case "tool_use": {
-							// Emit initial tool call partial with id and name
-							yield {
-								type: "tool_call_partial",
-								index: chunk.index,
-								id: chunk.content_block.id,
-								name: chunk.content_block.name,
-								arguments: undefined,
-							}
-							break
-						}
-					}
-					break
-				case "content_block_delta":
-					switch (chunk.delta.type) {
-						case "thinking_delta":
-							yield { type: "reasoning", text: chunk.delta.thinking }
-							break
-						case "text_delta":
-							yield { type: "text", text: chunk.delta.text }
-							break
-						case "input_json_delta": {
-							// Emit tool call partial chunks as arguments stream in
-							yield {
-								type: "tool_call_partial",
-								index: chunk.index,
-								id: undefined,
-								name: undefined,
-								arguments: chunk.delta.partial_json,
+		const requestOptions = {
+			model: this.client(modelConfig.id),
+			system: systemPrompt,
+			...({
+				systemProviderOptions: { anthropic: { cacheControl: { type: "ephemeral" } } },
+			} as Record<string, unknown>),
+			messages: aiSdkMessages,
+			temperature: modelParams.temperature,
+			maxOutputTokens: modelParams.maxTokens ?? modelConfig.info.maxTokens,
+			tools: aiSdkTools,
+			toolChoice: mapToolChoice(metadata?.tool_choice),
+			...(Object.keys(anthropicProviderOptions).length > 0 && {
+				providerOptions: { anthropic: anthropicProviderOptions } as Record<string, Record<string, unknown>>,
+			}),
+		}
+
+		try {
+			const result = streamText(requestOptions as Parameters<typeof streamText>[0])
+
+			for await (const part of result.fullStream) {
+				const anthropicMetadata = (
+					part as {
+						providerMetadata?: {
+							anthropic?: {
+								signature?: string
+								redactedData?: string
 							}
-							break
 						}
 					}
+				).providerMetadata?.anthropic
+
+				if (anthropicMetadata?.signature) {
+					this.lastThoughtSignature = anthropicMetadata.signature
+				}
+
+				if (anthropicMetadata?.redactedData) {
+					this.lastRedactedThinkingBlocks.push({
+						type: "redacted_thinking",
+						data: anthropicMetadata.redactedData,
+					})
+				}
 
-					break
-				case "content_block_stop":
-					// Block is complete - no action needed, NativeToolCallParser handles completion
-					break
+				for (const chunk of processAiSdkStreamPart(part)) {
+					yield chunk
+				}
 			}
-		}
 
-		// Calculate and yield final cost
-		if (inputTokens > 0 || outputTokens > 0 || cacheWriteTokens > 0 || cacheReadTokens > 0) {
-			const { totalCost } = calculateApiCostAnthropic(
-				this.getModel().info,
-				inputTokens,
-				outputTokens,
-				cacheWriteTokens,
-				cacheReadTokens,
-			)
-
-			yield {
-				type: "usage",
-				inputTokens: 0,
-				outputTokens: 0,
-				totalCost,
+			const usage = await result.usage
+			const providerMetadata = await result.providerMetadata
+			if (usage) {
+				yield this.processUsageMetrics(usage, modelConfig.info, providerMetadata)
 			}
+		} catch (error) {
+			throw handleAiSdkError(error, this.providerName)
 		}
 	}
 
-	/**
-	 * Add cache control to the last two user messages for prompt caching
-	 */
-	private addCacheControl(
-		messages: Anthropic.Messages.MessageParam[],
-		cacheControl: CacheControlEphemeral,
-	): Anthropic.Messages.MessageParam[] {
-		const userMsgIndices = messages.reduce(
-			(acc, msg, index) => (msg.role === "user" ? [...acc, index] : acc),
-			[] as number[],
+	private processUsageMetrics(
+		usage: { inputTokens?: number; outputTokens?: number },
+		info: ModelInfo,
+		providerMetadata?: Record<string, Record<string, unknown>>,
+	): ApiStreamUsageChunk {
+		const inputTokens = usage.inputTokens ?? 0
+		const outputTokens = usage.outputTokens ?? 0
+
+		const anthropicMeta = providerMetadata?.anthropic as
+			| { cacheCreationInputTokens?: number; cacheReadInputTokens?: number }
+			| undefined
+		const cacheWriteTokens = anthropicMeta?.cacheCreationInputTokens ?? 0
+		const cacheReadTokens = anthropicMeta?.cacheReadInputTokens ?? 0
+
+		const { totalCost } = calculateApiCostAnthropic(
+			info,
+			inputTokens,
+			outputTokens,
+			cacheWriteTokens,
+			cacheReadTokens,
 		)
-		const lastUserMsgIndex = userMsgIndices[userMsgIndices.length - 1] ?? -1
-		const secondLastMsgUserIndex = userMsgIndices[userMsgIndices.length - 2] ?? -1
-
-		return messages.map((message, index) => {
-			if (index === lastUserMsgIndex || index === secondLastMsgUserIndex) {
-				return {
-					...message,
-					content:
-						typeof message.content === "string"
-							? [{ type: "text", text: message.content, cache_control: cacheControl }]
-							: message.content.map((content, contentIndex) =>
-									contentIndex === message.content.length - 1
-										? { ...content, cache_control: cacheControl }
-										: content,
-								),
+
+		return {
+			type: "usage",
+			inputTokens,
+			outputTokens,
+			cacheWriteTokens: cacheWriteTokens > 0 ? cacheWriteTokens : undefined,
+			cacheReadTokens: cacheReadTokens > 0 ? cacheReadTokens : undefined,
+			totalCost,
+		}
+	}
+
+	private applyCacheControlToAiSdkMessages(
+		originalMessages: Anthropic.Messages.MessageParam[],
+		aiSdkMessages: { role: string; providerOptions?: Record<string, Record<string, unknown>> }[],
+		targetOriginalIndices: Set<number>,
+		cacheProviderOption: Record<string, Record<string, unknown>>,
+	): void {
+		let aiSdkIdx = 0
+		for (let origIdx = 0; origIdx < originalMessages.length; origIdx++) {
+			const origMsg = originalMessages[origIdx]
+
+			if (typeof origMsg.content === "string") {
+				if (targetOriginalIndices.has(origIdx) && aiSdkIdx < aiSdkMessages.length) {
+					aiSdkMessages[aiSdkIdx].providerOptions = {
+						...aiSdkMessages[aiSdkIdx].providerOptions,
+						...cacheProviderOption,
+					}
+				}
+				aiSdkIdx++
+			} else if (origMsg.role === "user") {
+				const hasToolResults = origMsg.content.some((part) => (part as { type: string }).type === "tool_result")
+				const hasNonToolContent = origMsg.content.some(
+					(part) => (part as { type: string }).type === "text" || (part as { type: string }).type === "image",
+				)
+
+				if (hasToolResults && hasNonToolContent) {
+					const userMsgIdx = aiSdkIdx + 1
+					if (targetOriginalIndices.has(origIdx) && userMsgIdx < aiSdkMessages.length) {
+						aiSdkMessages[userMsgIdx].providerOptions = {
+							...aiSdkMessages[userMsgIdx].providerOptions,
+							...cacheProviderOption,
+						}
+					}
+					aiSdkIdx += 2
+				} else if (hasToolResults) {
+					if (targetOriginalIndices.has(origIdx) && aiSdkIdx < aiSdkMessages.length) {
+						aiSdkMessages[aiSdkIdx].providerOptions = {
+							...aiSdkMessages[aiSdkIdx].providerOptions,
+							...cacheProviderOption,
+						}
+					}
+					aiSdkIdx++
+				} else {
+					if (targetOriginalIndices.has(origIdx) && aiSdkIdx < aiSdkMessages.length) {
+						aiSdkMessages[aiSdkIdx].providerOptions = {
+							...aiSdkMessages[aiSdkIdx].providerOptions,
+							...cacheProviderOption,
+						}
+					}
+					aiSdkIdx++
+				}
+			} else {
+				aiSdkIdx++
 			}
-			return message
-		})
+		}
 	}
 
 	getModel() {
 		const modelId = this.options.apiModelId
-		const id = modelId && modelId in minimaxModels ? (modelId as MinimaxModelId) : minimaxDefaultModelId
+
+		const id = modelId && modelId in minimaxModels ? (modelId as keyof typeof minimaxModels) : minimaxDefaultModelId
 		const info = minimaxModels[id]
 
 		const params = getModelParams({
@@ -289,18 +276,32 @@ export class MiniMaxHandler extends BaseProvider implements SingleCompletionHand
 		}
 	}
 
-	async completePrompt(prompt: string) {
-		const { id: model, temperature } = this.getModel()
+	async completePrompt(prompt: string): Promise<string> {
+		const { id, maxTokens, temperature } = this.getModel()
 
-		const message = await this.client.messages.create({
-			model,
-			max_tokens: 16_384,
-			temperature: temperature ?? 1.0,
-			messages: [{ role: "user", content: prompt }],
-			stream: false,
-		})
+		try {
+			const { text } = await generateText({
+				model: this.client(id),
+				prompt,
+				maxOutputTokens: maxTokens ?? minimaxModels[minimaxDefaultModelId].maxTokens,
+				temperature,
+			})
+
+			return text
+		} catch (error) {
+			throw handleAiSdkError(error, this.providerName)
+		}
+	}
+
+	getThoughtSignature(): string | undefined {
+		return this.lastThoughtSignature
+	}
+
+	getRedactedThinkingBlocks(): Array<{ type: "redacted_thinking"; data: string }> | undefined {
+		return this.lastRedactedThinkingBlocks.length > 0 ? this.lastRedactedThinkingBlocks : undefined
+	}
 
-		const content = message.content.find(({ type }) => type === "text")
-		return content?.type === "text" ? content.text : ""
+	override isAiSdkProvider(): boolean {
+		return true
 	}
 }
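
Reviewer note, not part of the patch: the constructor's base-URL handling above reduces to the following pure function. This is a sketch for review only; resolveMiniMaxBaseUrl is an invented name, and the PR inlines this logic in the constructor.

function resolveMiniMaxBaseUrl(rawBaseUrl?: string): string {
	// No URL (or an empty string) falls back to the international endpoint.
	if (!rawBaseUrl) return "https://api.minimax.io/anthropic/v1"
	// Already normalized.
	if (rawBaseUrl.endsWith("/anthropic/v1")) return rawBaseUrl
	// Legacy "/v1" endpoints, e.g. https://api.minimax.io/v1
	if (rawBaseUrl.endsWith("/v1")) return rawBaseUrl.slice(0, -3) + "/anthropic/v1"
	// Bare "/anthropic" endpoints, e.g. the China endpoint https://api.minimaxi.com/anthropic
	if (rawBaseUrl.endsWith("/anthropic")) return rawBaseUrl + "/v1"
	// Any other root, e.g. https://api.minimax.io/custom
	return rawBaseUrl + "/anthropic/v1"
}

Each branch, including the empty-string fallback, is pinned by one of the constructor tests in the spec file above.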
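Reviewer note, not part of the patch: applyCacheControlToAiSdkMessages walks the Anthropic and AI SDK message lists in parallel, mirroring the removed addCacheControl (which tagged the last two user messages for prompt caching). Its index arithmetic rests on an assumed fan-out of convertToAiSdkMessages; the table below is inferred from the walker's branches, not from the converter itself.

// Assumed Anthropic -> AI SDK message fan-out the walker compensates for:
//
//   original user message              converted message(s)    index step
//   string content                  -> user                    aiSdkIdx += 1
//   [tool_result] only              -> tool                    aiSdkIdx += 1
//   [tool_result, text/image]       -> tool, user              aiSdkIdx += 2
//   [text/image] only               -> user                    aiSdkIdx += 1
//   assistant message               -> assistant               aiSdkIdx += 1
//
// In the mixed case the cache-control providerOptions land on the trailing
// user message (userMsgIdx = aiSdkIdx + 1), presumably because that is the
// message the Anthropic provider can attach cache_control to.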
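Reviewer note, not part of the patch: a hypothetical consumer, only to show the chunk shapes the tests above assert. printCompletion and its wiring are invented for illustration, and the import path assumes a sibling of the provider file.

import type { Anthropic } from "@anthropic-ai/sdk"
import { MiniMaxHandler } from "../minimax"

async function printCompletion(handler: MiniMaxHandler): Promise<void> {
	const messages: Anthropic.Messages.MessageParam[] = [
		{ role: "user", content: [{ type: "text", text: "Hello" }] },
	]

	for await (const chunk of handler.createMessage("You are a helpful assistant.", messages)) {
		// Chunk types match those asserted in minimax.spec.ts above.
		if (chunk.type === "reasoning") process.stdout.write(`(thinking) ${chunk.text}`)
		if (chunk.type === "text") process.stdout.write(chunk.text)
		if (chunk.type === "usage") console.log(`\nin/out tokens: ${chunk.inputTokens}/${chunk.outputTokens}`)
	}
}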