diff --git a/src/api/providers/__tests__/qwen-code-native-tools.spec.ts b/src/api/providers/__tests__/qwen-code-native-tools.spec.ts index 3b470ce461..852df0c140 100644 --- a/src/api/providers/__tests__/qwen-code-native-tools.spec.ts +++ b/src/api/providers/__tests__/qwen-code-native-tools.spec.ts @@ -1,373 +1,379 @@ // npx vitest run api/providers/__tests__/qwen-code-native-tools.spec.ts -// Mock filesystem - must come before other imports +const { + mockStreamText, + mockGenerateText, + mockWrapLanguageModel, + mockExtractReasoningMiddleware, + mockCreateOpenAICompatible, + mockSafeWriteJson, +} = vi.hoisted(() => ({ + mockStreamText: vi.fn(), + mockGenerateText: vi.fn(), + mockWrapLanguageModel: vi.fn((opts: { model: unknown }) => opts.model), + mockExtractReasoningMiddleware: vi.fn(() => ({})), + mockCreateOpenAICompatible: vi.fn(), + mockSafeWriteJson: vi.fn(), +})) + +vi.mock("ai", async (importOriginal) => { + const actual = await importOriginal() + return { + ...actual, + streamText: mockStreamText, + generateText: mockGenerateText, + wrapLanguageModel: mockWrapLanguageModel, + extractReasoningMiddleware: mockExtractReasoningMiddleware, + } +}) + +vi.mock("@ai-sdk/openai-compatible", () => ({ + createOpenAICompatible: mockCreateOpenAICompatible, +})) + vi.mock("node:fs", () => ({ promises: { readFile: vi.fn(), - writeFile: vi.fn(), }, })) -const mockCreate = vi.fn() -vi.mock("openai", () => { +vi.mock("../../../utils/safeWriteJson", () => ({ + safeWriteJson: mockSafeWriteJson, +})) + +import type { Anthropic } from "@anthropic-ai/sdk" +import { qwenCodeDefaultModelId, qwenCodeModels, type QwenCodeModelId } from "@roo-code/types" + +import { promises as fs } from "node:fs" +import * as path from "node:path" + +import { QwenCodeHandler, type QwenCodeHandlerOptions } from "../qwen-code" +import { safeWriteJson } from "../../../utils/safeWriteJson" + +type QwenCredentials = { + access_token: string + refresh_token: string + token_type: string + expiry_date: number + resource_url?: string +} + +type MutableQwenHandler = { + credentials: QwenCredentials | null +} + +class TestableQwenCodeHandler extends QwenCodeHandler { + public getLanguageModelForTest() { + return this.getLanguageModel() + } +} + +function buildCredentials(overrides: Partial = {}): QwenCredentials { return { - __esModule: true, - default: vi.fn().mockImplementation(() => ({ - apiKey: "test-key", - baseURL: "https://dashscope.aliyuncs.com/compatible-mode/v1", - chat: { - completions: { - create: mockCreate, - }, - }, - })), + access_token: "test-access-token", + refresh_token: "test-refresh-token", + token_type: "Bearer", + expiry_date: Date.now() + 60 * 60 * 1000, + resource_url: "https://dashscope.aliyuncs.com/compatible-mode/v1", + ...overrides, } -}) +} + +function createFetchResponse(params: { + ok: boolean + status?: number + statusText?: string + jsonBody?: unknown + textBody?: string +}): Response { + return { + ok: params.ok, + status: params.status ?? (params.ok ? 200 : 500), + statusText: params.statusText ?? "", + json: async () => params.jsonBody, + text: async () => params.textBody ?? "", + } as unknown as Response +} + +function createFullStream(parts: unknown[]): AsyncGenerator { + return (async function* () { + for (const part of parts) { + yield part + } + })() +} -import { promises as fs } from "node:fs" -import { QwenCodeHandler } from "../qwen-code" -import { NativeToolCallParser } from "../../../core/assistant-message/NativeToolCallParser" -import type { ApiHandlerOptions } from "../../../shared/api" - -describe("QwenCodeHandler Native Tools", () => { - let handler: QwenCodeHandler - let mockOptions: ApiHandlerOptions & { qwenCodeOauthPath?: string } - - const testTools = [ - { - type: "function" as const, - function: { - name: "test_tool", - description: "A test tool", - parameters: { - type: "object", - properties: { - arg1: { type: "string", description: "First argument" }, - }, - required: ["arg1"], - }, - }, - }, - ] +async function collectStreamChunks(stream: AsyncGenerator): Promise { + const chunks: unknown[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + return chunks +} + +describe("QwenCodeHandler (AI SDK)", () => { + let fetchMock: ReturnType beforeEach(() => { vi.clearAllMocks() - // Mock credentials file - const mockCredentials = { - access_token: "test-access-token", - refresh_token: "test-refresh-token", - token_type: "Bearer", - expiry_date: Date.now() + 3600000, // 1 hour from now - resource_url: "https://dashscope.aliyuncs.com/compatible-mode/v1", - } - ;(fs.readFile as any).mockResolvedValue(JSON.stringify(mockCredentials)) - ;(fs.writeFile as any).mockResolvedValue(undefined) + fetchMock = vi.fn() + vi.stubGlobal("fetch", fetchMock) - mockOptions = { - apiModelId: "qwen3-coder-plus", - } - handler = new QwenCodeHandler(mockOptions) + mockCreateOpenAICompatible.mockImplementation((config: { name: string; baseURL: string; apiKey: string }) => { + return vi.fn((modelId: string) => ({ + modelId, + provider: config.name, + baseURL: config.baseURL, + apiKey: config.apiKey, + })) + }) - // Clear NativeToolCallParser state before each test - NativeToolCallParser.clearRawChunkState() + vi.mocked(fs.readFile).mockResolvedValue(JSON.stringify(buildCredentials())) + vi.mocked(safeWriteJson).mockResolvedValue(undefined) + mockGenerateText.mockResolvedValue({ text: "ok" }) }) - describe("Native Tool Calling Support", () => { - it("should include tools in request when model supports native tools and tools are provided", async () => { - mockCreate.mockImplementationOnce(() => ({ - [Symbol.asyncIterator]: async function* () { - yield { - choices: [{ delta: { content: "Test response" } }], - } - }, - })) + afterEach(() => { + vi.unstubAllGlobals() + }) - const stream = handler.createMessage("test prompt", [], { - taskId: "test-task-id", - tools: testTools, - }) - await stream.next() - - expect(mockCreate).toHaveBeenCalledWith( - expect.objectContaining({ - tools: expect.arrayContaining([ - expect.objectContaining({ - type: "function", - function: expect.objectContaining({ - name: "test_tool", - }), - }), - ]), - parallel_tool_calls: true, - }), - ) - }) + it("constructs successfully with valid options", () => { + const handler = new QwenCodeHandler({ apiModelId: "qwen3-coder-plus" }) + expect(handler).toBeInstanceOf(QwenCodeHandler) + }) - it("should include tool_choice when provided", async () => { - mockCreate.mockImplementationOnce(() => ({ - [Symbol.asyncIterator]: async function* () { - yield { - choices: [{ delta: { content: "Test response" } }], - } - }, - })) + it("getModel returns default model when apiModelId is not provided", () => { + const handler = new QwenCodeHandler({}) + const model = handler.getModel() - const stream = handler.createMessage("test prompt", [], { - taskId: "test-task-id", - tools: testTools, - tool_choice: "auto", - }) - await stream.next() + expect(model.id).toBe(qwenCodeDefaultModelId) + expect(model.info).toEqual(qwenCodeModels[qwenCodeDefaultModelId as QwenCodeModelId]) + }) - expect(mockCreate).toHaveBeenCalledWith( - expect.objectContaining({ - tool_choice: "auto", - }), - ) - }) + it("getModel returns custom model ID and info", () => { + const customModelId: QwenCodeModelId = "qwen3-coder-plus" + const handler = new QwenCodeHandler({ apiModelId: customModelId }) + const model = handler.getModel() - it("should always include tools and tool_choice (tools are guaranteed to be present after ALWAYS_AVAILABLE_TOOLS)", async () => { - mockCreate.mockImplementationOnce(() => ({ - [Symbol.asyncIterator]: async function* () { - yield { - choices: [{ delta: { content: "Test response" } }], - } - }, - })) + expect(model.id).toBe(customModelId) + expect(model.info).toEqual(qwenCodeModels[customModelId]) + }) - const stream = handler.createMessage("test prompt", [], { - taskId: "test-task-id", - }) - await stream.next() + it("isAiSdkProvider returns true", () => { + const handler = new QwenCodeHandler({ apiModelId: "qwen3-coder-plus" }) + expect(handler.isAiSdkProvider()).toBe(true) + }) - // Tools are now always present (minimum 6 from ALWAYS_AVAILABLE_TOOLS) - const callArgs = mockCreate.mock.calls[mockCreate.mock.calls.length - 1][0] - expect(callArgs).toHaveProperty("tools") - expect(callArgs).toHaveProperty("tool_choice") - expect(callArgs).toHaveProperty("parallel_tool_calls", true) + it("loads OAuth credentials from file before completing prompt", async () => { + const oauthPath = "/tmp/qwen/oauth_creds.json" + const handler = new QwenCodeHandler({ + apiModelId: "qwen3-coder-plus", + qwenCodeOauthPath: oauthPath, }) - it("should yield tool_call_partial chunks during streaming", async () => { - mockCreate.mockImplementationOnce(() => ({ - [Symbol.asyncIterator]: async function* () { - yield { - choices: [ - { - delta: { - tool_calls: [ - { - index: 0, - id: "call_qwen_123", - function: { - name: "test_tool", - arguments: '{"arg1":', - }, - }, - ], - }, - }, - ], - } - yield { - choices: [ - { - delta: { - tool_calls: [ - { - index: 0, - function: { - arguments: '"value"}', - }, - }, - ], - }, - }, - ], - } + const result = await handler.completePrompt("Hello") + + expect(result).toBe("ok") + expect(fs.readFile).toHaveBeenCalledWith(path.resolve(oauthPath), "utf-8") + expect(fetchMock).not.toHaveBeenCalled() + }) + + it("refreshes access token when credentials are expired", async () => { + const oauthPath = "/tmp/qwen/expired_creds.json" + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify( + buildCredentials({ + expiry_date: Date.now() - 60_000, + resource_url: "dashscope.aliyuncs.com/compatible-mode", + }), + ), + ) + + fetchMock.mockResolvedValue( + createFetchResponse({ + ok: true, + jsonBody: { + access_token: "refreshed-access-token", + refresh_token: "refreshed-refresh-token", + token_type: "Bearer", + expires_in: 3600, }, - })) + }), + ) + + const handler = new QwenCodeHandler({ apiModelId: "qwen3-coder-plus", qwenCodeOauthPath: oauthPath }) + + const result = await handler.completePrompt("Refresh now") + expect(result).toBe("ok") + + expect(fetchMock).toHaveBeenCalledTimes(1) + const [refreshUrl, refreshInit] = fetchMock.mock.calls[0] as [string, RequestInit] + expect(refreshUrl).toBe("https://chat.qwen.ai/api/v1/oauth2/token") + expect(refreshInit.method).toBe("POST") + expect(String(refreshInit.body)).toContain("grant_type=refresh_token") + expect(String(refreshInit.body)).toContain("refresh_token=test-refresh-token") + expect(String(refreshInit.body)).toContain("client_id=f0304373b74a44d2b584a3fb70ca9e56") + + expect(safeWriteJson).toHaveBeenCalledWith( + path.resolve(oauthPath), + expect.objectContaining({ + access_token: "refreshed-access-token", + refresh_token: "refreshed-refresh-token", + token_type: "Bearer", + }), + ) + + expect(mockCreateOpenAICompatible).toHaveBeenCalledWith( + expect.objectContaining({ + baseURL: "https://dashscope.aliyuncs.com/compatible-mode/v1", + apiKey: "refreshed-access-token", + }), + ) + }) - const stream = handler.createMessage("test prompt", [], { - taskId: "test-task-id", - tools: testTools, - }) + it("retries createMessage once after 401 by refreshing token", async () => { + const oauthPath = "/tmp/qwen/retry_creds.json" + vi.mocked(fs.readFile).mockResolvedValue(JSON.stringify(buildCredentials({ access_token: "stale-token" }))) + + fetchMock.mockResolvedValue( + createFetchResponse({ + ok: true, + jsonBody: { + access_token: "fresh-token", + refresh_token: "fresh-refresh-token", + token_type: "Bearer", + expires_in: 3600, + }, + }), + ) - const chunks = [] - for await (const chunk of stream) { - chunks.push(chunk) - } - - expect(chunks).toContainEqual({ - type: "tool_call_partial", - index: 0, - id: "call_qwen_123", - name: "test_tool", - arguments: '{"arg1":', - }) + const unauthorizedError = new Error("Unauthorized") + Object.assign(unauthorizedError, { status: 401 }) - expect(chunks).toContainEqual({ - type: "tool_call_partial", - index: 0, - id: undefined, - name: undefined, - arguments: '"value"}', + mockStreamText + .mockImplementationOnce(() => { + throw unauthorizedError + }) + .mockReturnValueOnce({ + fullStream: createFullStream([{ type: "text-delta", text: "Recovered response" }]), + usage: Promise.resolve({ inputTokens: 11, outputTokens: 7 }), }) - }) - it("should set parallel_tool_calls based on metadata", async () => { - mockCreate.mockImplementationOnce(() => ({ - [Symbol.asyncIterator]: async function* () { - yield { - choices: [{ delta: { content: "Test response" } }], - } - }, - })) + const handler = new QwenCodeHandler({ apiModelId: "qwen3-coder-plus", qwenCodeOauthPath: oauthPath }) + const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hi" }] - const stream = handler.createMessage("test prompt", [], { - taskId: "test-task-id", - tools: testTools, - parallelToolCalls: true, - }) - await stream.next() + const chunks = await collectStreamChunks(handler.createMessage("System", messages)) - expect(mockCreate).toHaveBeenCalledWith( - expect.objectContaining({ - parallel_tool_calls: true, - }), - ) + expect(mockStreamText).toHaveBeenCalledTimes(2) + expect(fetchMock).toHaveBeenCalledTimes(1) + expect(chunks).toEqual( + expect.arrayContaining([ + expect.objectContaining({ type: "text", text: "Recovered response" }), + expect.objectContaining({ type: "usage", inputTokens: 11, outputTokens: 7 }), + ]), + ) + }) + + it("createMessage yields expected stream chunk types from AI SDK streamText", async () => { + mockStreamText.mockReturnValue({ + fullStream: createFullStream([ + { type: "reasoning-delta", text: "Thinking..." }, + { type: "text-delta", text: "Answer" }, + { type: "tool-input-start", id: "tool-1", toolName: "read_file" }, + { type: "tool-input-delta", id: "tool-1", delta: '{"path":"a.ts"}' }, + { type: "tool-input-end", id: "tool-1" }, + { type: "finish", finishReason: "stop" }, + ]), + usage: Promise.resolve({ inputTokens: 10, outputTokens: 5 }), }) - it("should yield tool_call_end events when finish_reason is tool_calls", async () => { - mockCreate.mockImplementationOnce(() => ({ - [Symbol.asyncIterator]: async function* () { - yield { - choices: [ - { - delta: { - tool_calls: [ - { - index: 0, - id: "call_qwen_test", - function: { - name: "test_tool", - arguments: '{"arg1":"value"}', - }, - }, - ], - }, - }, - ], - } - yield { - choices: [ - { - delta: {}, - finish_reason: "tool_calls", - }, - ], - usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 }, - } - }, - })) + const handler = new QwenCodeHandler({ apiModelId: "qwen3-coder-plus" }) + const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hello" }] + + const chunks = await collectStreamChunks(handler.createMessage("System", messages)) + + expect(chunks).toEqual( + expect.arrayContaining([ + expect.objectContaining({ type: "reasoning", text: "Thinking..." }), + expect.objectContaining({ type: "text", text: "Answer" }), + expect.objectContaining({ type: "tool_call_start", id: "tool-1", name: "read_file" }), + expect.objectContaining({ type: "tool_call_delta", id: "tool-1", delta: '{"path":"a.ts"}' }), + expect.objectContaining({ type: "tool_call_end", id: "tool-1" }), + expect.objectContaining({ type: "usage", inputTokens: 10, outputTokens: 5 }), + ]), + ) + }) - const stream = handler.createMessage("test prompt", [], { - taskId: "test-task-id", - tools: testTools, - }) + it("completePrompt returns generated text", async () => { + mockGenerateText.mockResolvedValue({ text: "Completion text" }) + const handler = new QwenCodeHandler({ apiModelId: "qwen3-coder-plus" }) - const chunks = [] - for await (const chunk of stream) { - // Simulate what Task.ts does: when we receive tool_call_partial, - // process it through NativeToolCallParser to populate rawChunkTracker - if (chunk.type === "tool_call_partial") { - NativeToolCallParser.processRawChunk({ - index: chunk.index, - id: chunk.id, - name: chunk.name, - arguments: chunk.arguments, - }) - } - chunks.push(chunk) - } - - // Should have tool_call_partial and tool_call_end - const partialChunks = chunks.filter((chunk) => chunk.type === "tool_call_partial") - const endChunks = chunks.filter((chunk) => chunk.type === "tool_call_end") - - expect(partialChunks).toHaveLength(1) - expect(endChunks).toHaveLength(1) - expect(endChunks[0].id).toBe("call_qwen_test") - }) + const result = await handler.completePrompt("Complete this") - it("should preserve thinking block handling alongside tool calls", async () => { - mockCreate.mockImplementationOnce(() => ({ - [Symbol.asyncIterator]: async function* () { - yield { - choices: [ - { - delta: { - reasoning_content: "Thinking about this...", - }, - }, - ], - } - yield { - choices: [ - { - delta: { - tool_calls: [ - { - index: 0, - id: "call_after_think", - function: { - name: "test_tool", - arguments: '{"arg1":"result"}', - }, - }, - ], - }, - }, - ], - } - yield { - choices: [ - { - delta: {}, - finish_reason: "tool_calls", - }, - ], - } - }, - })) + expect(result).toBe("Completion text") + expect(mockGenerateText).toHaveBeenCalledWith( + expect.objectContaining({ + prompt: "Complete this", + }), + ) + }) - const stream = handler.createMessage("test prompt", [], { - taskId: "test-task-id", - tools: testTools, - }) + it("getLanguageModel creates a fresh provider using current OAuth credentials", () => { + const handler = new TestableQwenCodeHandler({ apiModelId: "qwen3-coder-plus" }) + const mutable = handler as unknown as MutableQwenHandler + + mutable.credentials = buildCredentials({ + access_token: "token-1", + resource_url: "https://dashscope.aliyuncs.com/compatible-mode/v1", + }) + handler.getLanguageModelForTest() + + mutable.credentials = buildCredentials({ + access_token: "token-2", + resource_url: "dashscope.aliyuncs.com/compatible-mode", + }) + handler.getLanguageModelForTest() + + expect(mockCreateOpenAICompatible).toHaveBeenNthCalledWith( + 1, + expect.objectContaining({ + apiKey: "pending-oauth", + baseURL: "https://dashscope.aliyuncs.com/compatible-mode/v1", + }), + ) + expect(mockCreateOpenAICompatible).toHaveBeenNthCalledWith( + 2, + expect.objectContaining({ + apiKey: "token-1", + baseURL: "https://dashscope.aliyuncs.com/compatible-mode/v1", + }), + ) + expect(mockCreateOpenAICompatible).toHaveBeenNthCalledWith( + 3, + expect.objectContaining({ + apiKey: "token-2", + baseURL: "https://dashscope.aliyuncs.com/compatible-mode/v1", + }), + ) + }) - const chunks = [] - for await (const chunk of stream) { - if (chunk.type === "tool_call_partial") { - NativeToolCallParser.processRawChunk({ - index: chunk.index, - id: chunk.id, - name: chunk.name, - arguments: chunk.arguments, - }) - } - chunks.push(chunk) - } - - // Should have reasoning, tool_call_partial, and tool_call_end - const reasoningChunks = chunks.filter((chunk) => chunk.type === "reasoning") - const partialChunks = chunks.filter((chunk) => chunk.type === "tool_call_partial") - const endChunks = chunks.filter((chunk) => chunk.type === "tool_call_end") - - expect(reasoningChunks).toHaveLength(1) - expect(reasoningChunks[0].text).toBe("Thinking about this...") - expect(partialChunks).toHaveLength(1) - expect(endChunks).toHaveLength(1) + it("uses wrapLanguageModel with reasoning middleware", async () => { + mockStreamText.mockReturnValue({ + fullStream: createFullStream([{ type: "text-delta", text: "ok" }]), + usage: Promise.resolve({ inputTokens: 1, outputTokens: 1 }), }) + + const handler = new QwenCodeHandler({ apiModelId: "qwen3-coder-plus" }) + const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hello" }] + + await collectStreamChunks(handler.createMessage("System", messages)) + + expect(mockExtractReasoningMiddleware).toHaveBeenCalledWith({ tagName: "think" }) + expect(mockWrapLanguageModel).toHaveBeenCalledWith( + expect.objectContaining({ + middleware: expect.any(Object), + }), + ) }) }) diff --git a/src/api/providers/qwen-code.ts b/src/api/providers/qwen-code.ts index 18d09a59f3..23919506c1 100644 --- a/src/api/providers/qwen-code.ts +++ b/src/api/providers/qwen-code.ts @@ -1,19 +1,20 @@ import { promises as fs } from "node:fs" import { Anthropic } from "@anthropic-ai/sdk" -import OpenAI from "openai" +import { createOpenAICompatible } from "@ai-sdk/openai-compatible" +import { wrapLanguageModel, extractReasoningMiddleware, type LanguageModel } from "ai" import * as os from "os" import * as path from "path" import { type ModelInfo, type QwenCodeModelId, qwenCodeModels, qwenCodeDefaultModelId } from "@roo-code/types" import type { ApiHandlerOptions } from "../../shared/api" +import { safeWriteJson } from "../../utils/safeWriteJson" -import { NativeToolCallParser } from "../../core/assistant-message/NativeToolCallParser" - -import { convertToOpenAiMessages } from "../transform/openai-format" +import { getModelParams } from "../transform/model-params" import { ApiStream } from "../transform/stream" -import { BaseProvider } from "./base-provider" +import { DEFAULT_HEADERS } from "./constants" +import { OpenAICompatibleHandler, OpenAICompatibleConfig } from "./openai-compatible" import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" const QWEN_OAUTH_BASE_URL = "https://chat.qwen.ai" @@ -30,7 +31,16 @@ interface QwenOAuthCredentials { resource_url?: string } -interface QwenCodeHandlerOptions extends ApiHandlerOptions { +interface QwenTokenRefreshResponse { + access_token?: string + refresh_token?: string + token_type?: string + expires_in?: number + error?: string + error_description?: string +} + +export interface QwenCodeHandlerOptions extends ApiHandlerOptions { qwenCodeOauthPath?: string } @@ -51,37 +61,57 @@ function objectToUrlEncoded(data: Record): string { .join("&") } -export class QwenCodeHandler extends BaseProvider implements SingleCompletionHandler { - protected options: QwenCodeHandlerOptions +export class QwenCodeHandler extends OpenAICompatibleHandler implements SingleCompletionHandler { + private qwenOptions: QwenCodeHandlerOptions private credentials: QwenOAuthCredentials | null = null - private client: OpenAI | undefined private refreshPromise: Promise | null = null constructor(options: QwenCodeHandlerOptions) { - super() - this.options = options + const modelId = options.apiModelId ?? qwenCodeDefaultModelId + const modelInfo = + qwenCodeModels[modelId as QwenCodeModelId] || qwenCodeModels[qwenCodeDefaultModelId as QwenCodeModelId] + + const config: OpenAICompatibleConfig = { + providerName: "qwen-code", + baseURL: "https://dashscope.aliyuncs.com/compatible-mode/v1", + apiKey: "pending-oauth", + modelId, + modelInfo, + modelMaxTokens: options.modelMaxTokens ?? undefined, + temperature: options.modelTemperature ?? undefined, + } + + super(options, config) + this.qwenOptions = options } - private ensureClient(): OpenAI { - if (!this.client) { - // Create the client instance with dummy key initially - // The API key will be updated dynamically via ensureAuthenticated - this.client = new OpenAI({ - apiKey: "dummy-key-will-be-replaced", - baseURL: "https://dashscope.aliyuncs.com/compatible-mode/v1", - }) - } - return this.client + protected override getLanguageModel(): LanguageModel { + const apiKey = this.credentials?.access_token ?? "pending-oauth" + const baseURL = this.credentials + ? this.getBaseUrl(this.credentials) + : "https://dashscope.aliyuncs.com/compatible-mode/v1" + const provider = createOpenAICompatible({ + name: "qwen-code", + baseURL, + apiKey, + headers: { ...DEFAULT_HEADERS }, + }) + const model = this.getModel() + const baseModel = provider(model.id) + return wrapLanguageModel({ + model: baseModel, + middleware: extractReasoningMiddleware({ tagName: "think" }), + }) } private async loadCachedQwenCredentials(): Promise { try { - const keyFile = getQwenCachedCredentialPath(this.options.qwenCodeOauthPath) + const keyFile = getQwenCachedCredentialPath(this.qwenOptions.qwenCodeOauthPath) const credsStr = await fs.readFile(keyFile, "utf-8") return JSON.parse(credsStr) } catch (error) { console.error( - `Error reading or parsing credentials file at ${getQwenCachedCredentialPath(this.options.qwenCodeOauthPath)}`, + `Error reading or parsing credentials file at ${getQwenCachedCredentialPath(this.qwenOptions.qwenCodeOauthPath)}`, ) throw new Error(`Failed to load Qwen OAuth credentials: ${error}`) } @@ -130,13 +160,17 @@ export class QwenCodeHandler extends BaseProvider implements SingleCompletionHan throw new Error(`Token refresh failed: ${response.status} ${response.statusText}. Response: ${errorText}`) } - const tokenData = await response.json() + const tokenData = (await response.json()) as QwenTokenRefreshResponse if (tokenData.error) { throw new Error(`Token refresh failed: ${tokenData.error} - ${tokenData.error_description}`) } - const newCredentials = { + if (!tokenData.access_token || !tokenData.token_type || typeof tokenData.expires_in !== "number") { + throw new Error("Token refresh failed: invalid token response") + } + + const newCredentials: QwenOAuthCredentials = { ...credentials, access_token: tokenData.access_token, token_type: tokenData.token_type, @@ -144,9 +178,9 @@ export class QwenCodeHandler extends BaseProvider implements SingleCompletionHan expiry_date: Date.now() + tokenData.expires_in * 1000, } - const filePath = getQwenCachedCredentialPath(this.options.qwenCodeOauthPath) + const filePath = getQwenCachedCredentialPath(this.qwenOptions.qwenCodeOauthPath) try { - await fs.writeFile(filePath, JSON.stringify(newCredentials, null, 2)) + await safeWriteJson(filePath, newCredentials) } catch (error) { console.error("Failed to save refreshed credentials:", error) // Continue with the refreshed token in memory even if file write fails @@ -171,11 +205,6 @@ export class QwenCodeHandler extends BaseProvider implements SingleCompletionHan if (!this.isTokenValid(this.credentials)) { this.credentials = await this.refreshAccessToken(this.credentials) } - - // After authentication, update the apiKey and baseURL on the existing client - const client = this.ensureClient() - client.apiKey = this.credentials.access_token - client.baseURL = this.getBaseUrl(this.credentials) } private getBaseUrl(creds: QwenOAuthCredentials): string { @@ -186,154 +215,106 @@ export class QwenCodeHandler extends BaseProvider implements SingleCompletionHan return baseUrl.endsWith("/v1") ? baseUrl : `${baseUrl}/v1` } - private async callApiWithRetry(apiCall: () => Promise): Promise { - try { - return await apiCall() - } catch (error: any) { - if (error.status === 401) { - // Token expired, refresh and retry - this.credentials = await this.refreshAccessToken(this.credentials!) - const client = this.ensureClient() - client.apiKey = this.credentials.access_token - client.baseURL = this.getBaseUrl(this.credentials) - return await apiCall() - } else { - throw error - } + private async forceRefreshAndAuthenticate(): Promise { + if (!this.credentials) { + this.credentials = await this.loadCachedQwenCredentials() } - } - override async *createMessage( - systemPrompt: string, - messages: Anthropic.Messages.MessageParam[], - metadata?: ApiHandlerCreateMessageMetadata, - ): ApiStream { - await this.ensureAuthenticated() - const client = this.ensureClient() - const model = this.getModel() + this.credentials = await this.refreshAccessToken(this.credentials) + } - const systemMessage: OpenAI.Chat.ChatCompletionSystemMessageParam = { - role: "system", - content: systemPrompt, + private getStatusCode(error: unknown): number | undefined { + if (!error || typeof error !== "object") { + return undefined } - const convertedMessages = [systemMessage, ...convertToOpenAiMessages(messages)] - - const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = { - model: model.id, - temperature: 0, - messages: convertedMessages, - stream: true, - stream_options: { include_usage: true }, - max_completion_tokens: model.info.maxTokens, - tools: this.convertToolsForOpenAI(metadata?.tools), - tool_choice: metadata?.tool_choice, - parallel_tool_calls: metadata?.parallelToolCalls ?? true, + const obj = error as Record + + const parseStatusCode = (value: unknown): number | undefined => { + if (typeof value === "number") { + return value + } + if (typeof value === "string") { + const parsed = Number.parseInt(value, 10) + if (!Number.isNaN(parsed)) { + return parsed + } + } + return undefined } - const stream = await this.callApiWithRetry(() => client.chat.completions.create(requestOptions)) + const directStatus = parseStatusCode(obj.status) ?? parseStatusCode(obj.statusCode) + if (directStatus !== undefined) { + return directStatus + } - let fullContent = "" + const nestedError = obj.lastError ?? obj.cause + if (nestedError) { + return this.getStatusCode(nestedError) + } - for await (const apiChunk of stream) { - const delta = apiChunk.choices[0]?.delta ?? {} - const finishReason = apiChunk.choices[0]?.finish_reason + return undefined + } - if (delta.content) { - let newText = delta.content - if (newText.startsWith(fullContent)) { - newText = newText.substring(fullContent.length) - } - fullContent = delta.content - - if (newText) { - // Check for thinking blocks - if (newText.includes("") || newText.includes("")) { - // Simple parsing for thinking blocks - const parts = newText.split(/<\/?think>/g) - for (let i = 0; i < parts.length; i++) { - if (parts[i]) { - if (i % 2 === 0) { - // Outside thinking block - yield { - type: "text", - text: parts[i], - } - } else { - // Inside thinking block - yield { - type: "reasoning", - text: parts[i], - } - } - } - } - } else { - yield { - type: "text", - text: newText, - } - } - } - } + private isAuthError(error: unknown): boolean { + const statusCode = this.getStatusCode(error) + if (statusCode === 401) { + return true + } - if ("reasoning_content" in delta && delta.reasoning_content) { - yield { - type: "reasoning", - text: (delta.reasoning_content as string | undefined) || "", - } + if (error instanceof Error) { + const message = error.message || "" + if (message.includes("401") || message.includes("Unauthorized")) { + return true } + } - // Handle tool calls in stream - emit partial chunks for NativeToolCallParser - if (delta.tool_calls) { - for (const toolCall of delta.tool_calls) { - yield { - type: "tool_call_partial", - index: toolCall.index, - id: toolCall.id, - name: toolCall.function?.name, - arguments: toolCall.function?.arguments, - } - } - } + return false + } - // Process finish_reason to emit tool_call_end events - if (finishReason) { - const endEvents = NativeToolCallParser.processFinishReason(finishReason) - for (const event of endEvents) { - yield event - } - } + override async *createMessage( + systemPrompt: string, + messages: Anthropic.Messages.MessageParam[], + metadata?: ApiHandlerCreateMessageMetadata, + ): ApiStream { + await this.ensureAuthenticated() - if (apiChunk.usage) { - yield { - type: "usage", - inputTokens: apiChunk.usage.prompt_tokens || 0, - outputTokens: apiChunk.usage.completion_tokens || 0, - } + try { + yield* super.createMessage(systemPrompt, messages, metadata) + } catch (error) { + if (this.isAuthError(error)) { + await this.forceRefreshAndAuthenticate() + yield* super.createMessage(systemPrompt, messages, metadata) + } else { + throw error } } } - override getModel(): { id: string; info: ModelInfo } { + override getModel(): { id: string; info: ModelInfo; maxTokens?: number; temperature?: number } { const id = this.options.apiModelId ?? qwenCodeDefaultModelId - const info = qwenCodeModels[id as keyof typeof qwenCodeModels] || qwenCodeModels[qwenCodeDefaultModelId] - return { id, info } + const info = qwenCodeModels[id as QwenCodeModelId] || qwenCodeModels[qwenCodeDefaultModelId as QwenCodeModelId] + const params = getModelParams({ + format: "openai", + modelId: id, + model: info, + settings: this.options, + defaultTemperature: 0, + }) + return { id, info, ...params } } - async completePrompt(prompt: string): Promise { + override async completePrompt(prompt: string): Promise { await this.ensureAuthenticated() - const client = this.ensureClient() - const model = this.getModel() - const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = { - model: model.id, - messages: [{ role: "user", content: prompt }], - max_completion_tokens: model.info.maxTokens, + try { + return await super.completePrompt(prompt) + } catch (error) { + if (this.isAuthError(error)) { + await this.forceRefreshAndAuthenticate() + return await super.completePrompt(prompt) + } + throw error } - - const response = await this.callApiWithRetry(() => client.chat.completions.create(requestOptions)) - - return response.choices[0]?.message.content || "" } }