diff --git a/.ade/ade.yaml b/.ade/ade.yaml index 8679ba36a..9f155cd7e 100644 --- a/.ade/ade.yaml +++ b/.ade/ade.yaml @@ -23,3 +23,9 @@ ai: autoTitleEnabled: true autoTitleModelId: openai/gpt-5.3-codex-spark autoTitleRefreshOnComplete: true + localProviders: + lmstudio: + enabled: true + endpoint: http://127.0.0.1:1234 + autoDetect: true + preferredModelId: null diff --git a/.ade/cto/identity.yaml b/.ade/cto/identity.yaml index ad9417718..e7a73393e 100644 --- a/.ade/cto/identity.yaml +++ b/.ade/cto/identity.yaml @@ -1,7 +1,14 @@ name: CTO -version: 3 -persona: Persistent project CTO with collaborative personality. -personality: casual +version: 1 +persona: >- + You are the CTO for this project inside ADE. + + You are the persistent technical lead who owns architecture, execution + quality, engineering continuity, and team direction. + + Use ADE's tools and project context to help the team move forward with clear, + concrete decisions. +personality: strategic modelPreferences: provider: claude model: sonnet @@ -21,8 +28,4 @@ openclawContextPolicy: - secret - token - system_prompt -onboardingState: - completedSteps: - - identity - completedAt: 2026-04-02T14:20:19.124Z -updatedAt: 2026-04-02T14:20:19.127Z +updatedAt: 1970-01-01T00:00:00.000Z diff --git a/apps/desktop/src/main/main.ts b/apps/desktop/src/main/main.ts index 162752d69..fd24d7e6a 100644 --- a/apps/desktop/src/main/main.ts +++ b/apps/desktop/src/main/main.ts @@ -108,6 +108,7 @@ import { createExternalConnectionAuthService } from "./services/externalMcp/exte import { createComputerUseArtifactBrokerService } from "./services/computerUse/computerUseArtifactBrokerService"; import { createSyncService } from "./services/sync/syncService"; import { createAutoUpdateService } from "./services/updates/autoUpdateService"; +import { cleanupStaleTempArtifacts } from "./services/runtime/tempCleanupService"; import type { Logger } from "./services/logging/logger"; /** @@ -1454,6 +1455,15 @@ app.whenReady().then(async () => { }, }); agentChatServiceRef = agentChatService; + setImmediate(() => { + void Promise.resolve() + .then(() => agentChatService.cleanupStaleAttachments()) + .catch((err) => { + logger.warn("agent_chat.cleanup_stale_attachments_failed", { + error: err instanceof Error ? err.message : String(err), + }); + }); + }); // Wire agentChatService into prService for integration resolution prService.setAgentChatService(agentChatService); @@ -2772,6 +2782,10 @@ app.whenReady().then(async () => { // --- Auto-update service (global, not per-project) --- const updateLogger = createFileLogger(path.join(app.getPath("userData"), "ade-update.jsonl")); + cleanupStaleTempArtifacts({ + tempRoot: app.getPath("temp"), + logger: updateLogger, + }); const autoUpdateService = createAutoUpdateService(updateLogger); autoUpdateService.onUpdateAvailable((info) => { BrowserWindow.getAllWindows().forEach((win) => { diff --git a/apps/desktop/src/main/services/ai/aiIntegrationService.ts b/apps/desktop/src/main/services/ai/aiIntegrationService.ts index e512bee52..f02fb9801 100644 --- a/apps/desktop/src/main/services/ai/aiIntegrationService.ts +++ b/apps/desktop/src/main/services/ai/aiIntegrationService.ts @@ -3,18 +3,34 @@ import type { Logger } from "../logging/logger"; import type { AdeDb } from "../state/kvDb"; import type { createProjectConfigService } from "../config/projectConfigService"; import type { AgentModelDescriptor, AgentProvider, ExecutorOpts } from "./agentExecutor"; -import type { AiApiKeyVerificationResult, AiProviderConnections } from "../../../shared/types"; +import type { + AiApiKeyVerificationResult, + AiLocalProviderConfigs, + AiProviderConnections, + AiRuntimeConnections, + AiRuntimeConnectionStatus, +} from "../../../shared/types"; import { createDynamicLocalModelDescriptor, getDefaultModelDescriptor, getModelById, getAvailableModels, + getLocalProviderDefaultEndpoint, listModelDescriptorsForProvider, - MODEL_REGISTRY, + LOCAL_PROVIDER_LABELS, + replaceDynamicLocalModelDescriptors, resolveModelAlias, enrichModelRegistry, + type LocalProviderFamily, } from "../../../shared/modelRegistry"; -import { detectAllAuth, getCachedCliAuthStatuses, verifyProviderApiKey, type DetectedAuth, type CliAuthStatus } from "./authDetector"; +import { + detectAllAuth, + getCachedCliAuthStatuses, + resetLocalProviderDetectionCache, + verifyProviderApiKey, + type DetectedAuth, + type CliAuthStatus, +} from "./authDetector"; import { executeUnified, resumeUnified } from "./unifiedExecutor"; import { initialize as initModelsDevService } from "./modelsDevService"; import { updateModelPricing } from "../../../shared/modelProfiles"; @@ -22,7 +38,7 @@ import { isRecord } from "../shared/utils"; import { getApiKeyStoreStatus } from "./apiKeyStore"; import type { createMemoryService } from "../memory/memoryService"; import type { CompactionFlushService } from "../memory/compactionFlushService"; -import { discoverLocalModels } from "./localModelDiscovery"; +import { discoverLocalModels, inspectLocalProvider } from "./localModelDiscovery"; import { discoverCursorCliModelDescriptors, clearCursorCliModelsCache } from "../chat/cursorModelsDiscovery"; import { resolveCursorAgentExecutable } from "./cursorAgentExecutable"; import { buildProviderConnections } from "./providerConnectionStatus"; @@ -72,12 +88,15 @@ export type AiIntegrationStatus = { cli?: "claude" | "codex" | "cursor"; provider?: string; source?: "config" | "env" | "store"; + endpointSource?: "auto" | "config"; path?: string; endpoint?: string; + preferredModelId?: string | null; authenticated?: boolean; verified?: boolean; }>; providerConnections?: AiProviderConnections; + runtimeConnections?: AiRuntimeConnections; availableModelIds?: string[]; apiKeyStore?: { secureStorageAvailable: boolean; @@ -259,6 +278,33 @@ function extractConfiguredApiKeys(snapshot: ReturnType["get"]>, +): AiLocalProviderConfigs { + const aiConfig = extractAiConfig(snapshot); + const localProvidersRaw = isRecord(aiConfig.localProviders) ? aiConfig.localProviders : {}; + const out: AiLocalProviderConfigs = {}; + + for (const provider of ["ollama", "lmstudio", "vllm"] as const) { + const raw = isRecord(localProvidersRaw[provider]) ? localProvidersRaw[provider] : null; + if (!raw) continue; + const entry: NonNullable = {}; + if (typeof raw.enabled === "boolean") entry.enabled = raw.enabled; + if (typeof raw.autoDetect === "boolean") entry.autoDetect = raw.autoDetect; + if (typeof raw.endpoint === "string" && raw.endpoint.trim().length > 0) { + entry.endpoint = raw.endpoint.trim(); + } + if (raw.preferredModelId === null) { + entry.preferredModelId = null; + } else if (typeof raw.preferredModelId === "string" && raw.preferredModelId.trim().length > 0) { + entry.preferredModelId = raw.preferredModelId.trim(); + } + if (Object.keys(entry).length) out[provider] = entry; + } + + return out; +} + function toCliAvailability(auth: DetectedAuth[]): { claude: boolean; codex: boolean; cursor: boolean } { return { claude: auth.some((entry) => entry.type === "cli-subscription" && entry.cli === "claude"), @@ -308,6 +354,8 @@ function redactDetectedAuth( type: entry.type, provider: entry.provider, endpoint: entry.endpoint, + endpointSource: entry.endpointSource, + preferredModelId: entry.preferredModelId ?? null, }; }); @@ -333,6 +381,253 @@ function redactDetectedAuth( return redacted; } +function apiProviderLabel(provider: string): string { + const labels: Record = { + anthropic: "Anthropic", + openai: "OpenAI", + google: "Google AI", + mistral: "Mistral", + deepseek: "DeepSeek", + xai: "xAI", + groq: "Groq", + together: "Together AI", + openrouter: "OpenRouter", + ollama: "Ollama", + lmstudio: "LM Studio", + vllm: "vLLM", + }; + return labels[provider] ?? provider; +} + +function toCliRuntimeConnection(status: NonNullable[keyof AiProviderConnections]): AiRuntimeConnectionStatus { + const source = status.sources.find((entry) => entry.detected && entry.kind === "local-credentials")?.source; + return { + provider: status.provider, + label: apiProviderLabel(status.provider), + kind: "cli", + configured: status.authAvailable || status.runtimeDetected, + authAvailable: status.authAvailable, + runtimeDetected: status.runtimeDetected, + runtimeAvailable: status.runtimeAvailable, + health: status.runtimeAvailable ? "ready" : status.runtimeDetected ? "reachable" : "not_configured", + ...(source ? { source: source === "cursor-env" ? "env" : "store" as const } : {}), + path: status.path, + blocker: status.blocker, + lastCheckedAt: status.lastCheckedAt, + }; +} + +function normalizeConfiguredLocalProvider( + configs: AiLocalProviderConfigs, + provider: LocalProviderFamily, +): { + enabled: boolean; + endpoint?: string; + autoDetect: boolean; + preferredModelId?: string | null; +} { + const entry = configs[provider]; + return { + enabled: entry?.enabled ?? true, + ...(typeof entry?.endpoint === "string" && entry.endpoint.trim().length + ? { endpoint: entry.endpoint.trim() } + : {}), + autoDetect: entry?.autoDetect ?? true, + preferredModelId: entry?.preferredModelId ?? null, + }; +} + +function createLocalRuntimeConnectionFromInspection(args: { + provider: LocalProviderFamily; + endpoint: string; + source: "config" | "auto"; + inspection: Awaited>; + checkedAt: string; +}): AiRuntimeConnectionStatus { + const label = LOCAL_PROVIDER_LABELS[args.provider]; + const loadedModelIds = args.inspection.loadedModels.map((model) => `${args.provider}/${model.modelId}`); + let blocker: string | null = null; + if (args.inspection.health === "reachable_no_models") { + blocker = `${label} is reachable, but no models are currently loaded.`; + } else if (args.inspection.health === "unreachable") { + blocker = `${label} did not respond at ${args.endpoint}.`; + } + return { + provider: args.provider, + label, + kind: "local", + configured: true, + authAvailable: args.inspection.health === "ready", + runtimeDetected: args.inspection.reachable, + runtimeAvailable: args.inspection.health === "ready", + health: args.inspection.health, + source: args.source, + endpoint: args.endpoint, + blocker, + ...(loadedModelIds.length ? { loadedModelIds } : {}), + lastCheckedAt: args.checkedAt, + }; +} + +async function buildLocalRuntimeConnection(args: { + provider: LocalProviderFamily; + configuredLocalProviders: AiLocalProviderConfigs; + auth: DetectedAuth[]; + checkedAt: string; +}): Promise { + const providerConfig = normalizeConfiguredLocalProvider(args.configuredLocalProviders, args.provider); + const label = LOCAL_PROVIDER_LABELS[args.provider]; + if (!providerConfig.enabled) { + return { + provider: args.provider, + label, + kind: "local", + configured: false, + authAvailable: false, + runtimeDetected: false, + runtimeAvailable: false, + health: "not_configured", + blocker: `${label} is disabled in project AI settings.`, + lastCheckedAt: args.checkedAt, + }; + } + + const detected = args.auth.find( + (entry): entry is Extract => + entry.type === "local" && entry.provider === args.provider, + ); + if (detected) { + const inspection = await inspectLocalProvider(args.provider, detected.endpoint); + return createLocalRuntimeConnectionFromInspection({ + provider: args.provider, + endpoint: detected.endpoint, + source: detected.endpointSource === "config" ? "config" : "auto", + inspection, + checkedAt: args.checkedAt, + }); + } + + const configuredEndpoint = providerConfig.endpoint; + if (configuredEndpoint) { + const manualInspection = await inspectLocalProvider(args.provider, configuredEndpoint); + if (manualInspection.reachable || !providerConfig.autoDetect) { + const status = createLocalRuntimeConnectionFromInspection({ + provider: args.provider, + endpoint: configuredEndpoint, + source: "config", + inspection: manualInspection, + checkedAt: args.checkedAt, + }); + if (!manualInspection.reachable && !providerConfig.autoDetect) { + status.health = "unreachable"; + } + return status; + } + } + + if (providerConfig.autoDetect) { + const autoEndpoint = getLocalProviderDefaultEndpoint(args.provider); + if (!configuredEndpoint || autoEndpoint.replace(/\/+$/, "") !== configuredEndpoint.replace(/\/+$/, "")) { + const autoInspection = await inspectLocalProvider(args.provider, autoEndpoint); + if (autoInspection.reachable) { + return createLocalRuntimeConnectionFromInspection({ + provider: args.provider, + endpoint: autoEndpoint, + source: "auto", + inspection: autoInspection, + checkedAt: args.checkedAt, + }); + } + } + + return { + provider: args.provider, + label, + kind: "local", + configured: true, + authAvailable: false, + runtimeDetected: false, + runtimeAvailable: false, + health: "unreachable", + source: "config", + endpoint: configuredEndpoint, + blocker: `${label} is configured for ${configuredEndpoint}, but the runtime did not respond.`, + lastCheckedAt: args.checkedAt, + }; + } + + return { + provider: args.provider, + label, + kind: "local", + configured: false, + authAvailable: false, + runtimeDetected: false, + runtimeAvailable: false, + health: "not_configured", + blocker: `No ${label} runtime with loaded models was detected.`, + lastCheckedAt: args.checkedAt, + }; +} + +async function buildRuntimeConnections(args: { + configuredLocalProviders: AiLocalProviderConfigs; + auth: DetectedAuth[]; + providerConnections: AiProviderConnections; +}): Promise { + const checkedAt = new Date().toISOString(); + const runtimeConnections: AiRuntimeConnections = { + claude: toCliRuntimeConnection(args.providerConnections.claude), + codex: toCliRuntimeConnection(args.providerConnections.codex), + cursor: toCliRuntimeConnection(args.providerConnections.cursor), + }; + + for (const authEntry of args.auth) { + if (authEntry.type === "api-key") { + runtimeConnections[authEntry.provider] = { + provider: authEntry.provider, + label: apiProviderLabel(authEntry.provider), + kind: "api-key", + configured: true, + authAvailable: true, + runtimeDetected: true, + runtimeAvailable: true, + health: "ready", + source: authEntry.source, + blocker: null, + lastCheckedAt: checkedAt, + }; + continue; + } + if (authEntry.type === "openrouter") { + runtimeConnections.openrouter = { + provider: "openrouter", + label: "OpenRouter", + kind: "openrouter", + configured: true, + authAvailable: true, + runtimeDetected: true, + runtimeAvailable: true, + health: "ready", + source: authEntry.source, + blocker: null, + lastCheckedAt: checkedAt, + }; + } + } + + for (const provider of ["ollama", "lmstudio", "vllm"] as const) { + runtimeConnections[provider] = await buildLocalRuntimeConnection({ + provider, + configuredLocalProviders: args.configuredLocalProviders, + auth: args.auth, + checkedAt, + }); + } + + return runtimeConnections; +} + export function createAiIntegrationService(args: { db: AdeDb; logger: Logger; @@ -371,7 +666,10 @@ export function createAiIntegrationService(args: { const detectAuth = async (options?: { force?: boolean }): Promise => { const snapshot = projectConfigService.get(); - return await detectAllAuth(extractConfiguredApiKeys(snapshot), options); + return await detectAllAuth(extractConfiguredApiKeys(snapshot), { + ...options, + localProviders: extractConfiguredLocalProviders(snapshot), + }); }; const deriveMode = (args: { @@ -426,6 +724,21 @@ export function createAiIntegrationService(args: { }; const getResolvedAvailableModels = async (auth: DetectedAuth[]) => { + const discoveredLocalModels = await discoverLocalModels(auth); + replaceDynamicLocalModelDescriptors( + discoveredLocalModels.map((model) => + createDynamicLocalModelDescriptor(model.provider, model.modelId, { + ...(model.displayName ? { displayName: model.displayName } : {}), + ...(model.contextWindow ? { contextWindow: model.contextWindow } : {}), + ...(model.maxOutputTokens ? { maxOutputTokens: model.maxOutputTokens } : {}), + ...(model.capabilities ? { capabilities: model.capabilities } : {}), + ...(model.reasoningTiers?.length ? { reasoningTiers: model.reasoningTiers } : {}), + ...(model.harnessProfile ? { harnessProfile: model.harnessProfile } : {}), + ...(model.discoverySource ? { discoverySource: model.discoverySource } : {}), + }), + ), + ); + let available = getAvailableModels(auth); const hasCursorCliAuth = auth.some( @@ -447,23 +760,7 @@ export function createAiIntegrationService(args: { } } - const discoveredLocalModels = await discoverLocalModels(auth); - if (discoveredLocalModels.length === 0) { - return available; - } - - const providersWithDynamicModels = new Set( - discoveredLocalModels.map((model) => model.provider), - ); - const filteredStatic = available.filter((descriptor) => !( - descriptor.authTypes.includes("local") - && providersWithDynamicModels.has(descriptor.family as "ollama" | "lmstudio" | "vllm") - )); - - return [ - ...filteredStatic, - ...discoveredLocalModels.map((model) => createDynamicLocalModelDescriptor(model.provider, model.modelId)), - ]; + return available; }; const verifyApiKeyConnection = async (provider: string): Promise => { @@ -902,6 +1199,7 @@ export function createAiIntegrationService(args: { if (options?.force) { resetProviderRuntimeHealth(); resetClaudeRuntimeProbeCache(); + resetLocalProviderDetectionCache(); clearCursorCliModelsCache(); modelListCache.clear(); runtimeHealthVersion = getProviderRuntimeHealthVersion(); @@ -921,6 +1219,12 @@ export function createAiIntegrationService(args: { runtimeHealthVersion = getProviderRuntimeHealthVersion(); } const providerConnections = await buildProviderConnections(cliStatuses); + const configuredLocalProviders = extractConfiguredLocalProviders(projectConfigService.get()); + const runtimeConnections = await buildRuntimeConnections({ + configuredLocalProviders, + auth, + providerConnections, + }); const availability = { claude: providerConnections.claude.runtimeAvailable, codex: providerConnections.codex.runtimeAvailable, @@ -943,6 +1247,7 @@ export function createAiIntegrationService(args: { }, detectedAuth: redactDetectedAuth(auth, cliStatuses), providerConnections, + runtimeConnections, availableModelIds: runtimeFilteredAvailable.map((descriptor) => descriptor.id), apiKeyStore: getApiKeyStoreStatus(), }; diff --git a/apps/desktop/src/main/services/ai/authDetector.test.ts b/apps/desktop/src/main/services/ai/authDetector.test.ts index d18ecdc4a..d8437ab96 100644 --- a/apps/desktop/src/main/services/ai/authDetector.test.ts +++ b/apps/desktop/src/main/services/ai/authDetector.test.ts @@ -235,6 +235,47 @@ describe("authDetector", () => { expect(auth.some((entry) => entry.type === "local" && entry.provider === "lmstudio")).toBe(false); }); + it("falls back from an empty configured LM Studio endpoint to the auto-detected endpoint and preserves the preferred model", async () => { + vi.stubGlobal( + "fetch", + vi.fn(async (url: string) => { + if (url === "http://lmstudio.example:1234/api/v1/models") { + return new Response("{}", { status: 404 }); + } + if (url === "http://lmstudio.example:1234/v1/models") { + return new Response(JSON.stringify({ data: [] }), { status: 200 }); + } + if (url === "http://localhost:1234/api/v1/models") { + return new Response("{}", { status: 404 }); + } + if (url === "http://localhost:1234/v1/models") { + return new Response(JSON.stringify({ + data: [{ id: "gemma-4" }], + }), { status: 200 }); + } + return new Response("{}", { status: 503 }); + }), + ); + + const auth = await detectAllAuth({}, { + localProviders: { + lmstudio: { + endpoint: "http://lmstudio.example:1234", + autoDetect: true, + preferredModelId: "lmstudio/gemma-4", + }, + }, + }); + + expect(auth).toContainEqual(expect.objectContaining({ + type: "local", + provider: "lmstudio", + endpoint: "http://localhost:1234", + endpointSource: "auto", + preferredModelId: "lmstudio/gemma-4", + })); + }); + it("marks unsupported CLI auth checks as unverified", async () => { spawnMock.mockImplementation((command: string, args: string[] = []) => { if (args[0] === "--version") { diff --git a/apps/desktop/src/main/services/ai/authDetector.ts b/apps/desktop/src/main/services/ai/authDetector.ts index 177e4eaf6..cb1c35ed4 100644 --- a/apps/desktop/src/main/services/ai/authDetector.ts +++ b/apps/desktop/src/main/services/ai/authDetector.ts @@ -7,6 +7,9 @@ import { augmentProcessPathWithShellAndKnownCliDirs, resolveExecutableFromKnownLocations, } from "./cliExecutableResolver"; +import { getLocalProviderDefaultEndpoint, type LocalProviderFamily } from "../../../shared/modelRegistry"; +import type { AiLocalProviderConfigs } from "../../../shared/types"; +import { inspectLocalProvider, clearLocalProviderInspectionCache } from "./localModelDiscovery"; type CliName = "claude" | "codex" | "cursor"; @@ -43,7 +46,13 @@ export type DetectedAuth = } | { type: "api-key"; provider: string; key: string; source: ApiKeySource } | { type: "openrouter"; key: string; source: ApiKeySource } - | { type: "local"; provider: "ollama" | "lmstudio" | "vllm"; endpoint: string }; + | { + type: "local"; + provider: "ollama" | "lmstudio" | "vllm"; + endpoint: string; + endpointSource?: "auto" | "config"; + preferredModelId?: string | null; + }; // --------------------------------------------------------------------------- // Internals @@ -98,8 +107,6 @@ const WEAK_UNAUTH_INDICATORS = [ /run .*login/i, ]; -const UNAUTH_INDICATORS = [...STRONG_UNAUTH_INDICATORS, ...WEAK_UNAUTH_INDICATORS]; - const UNSUPPORTED_INDICATORS = [ /unknown command/i, /unrecognized/i, @@ -373,8 +380,14 @@ const API_KEY_VERIFY_TIMEOUT_MS = 8_000; let cachedLocalProviders: | { + key: string; checkedAtMs: number; - entries: Array<{ provider: "ollama" | "lmstudio" | "vllm"; endpoint: string }>; + entries: Array<{ + provider: LocalProviderFamily; + endpoint: string; + endpointSource: "auto" | "config"; + preferredModelId?: string | null; + }>; } | null = null; @@ -412,67 +425,124 @@ async function readStoredApiKeys(): Promise> { } } -async function checkLocalEndpointHasModels( - provider: "ollama" | "lmstudio" | "vllm", - url: string, - timeoutMs = 2_000, -): Promise { - try { - const controller = new AbortController(); - const timer = setTimeout(() => controller.abort(), timeoutMs); - const res = await fetch(url, { method: "GET", signal: controller.signal }); - clearTimeout(timer); - if (!res.ok) return false; - - const payload = await res.json() as unknown; - if (provider === "ollama") { - const models: Array<{ name?: unknown }> = Array.isArray((payload as { models?: unknown[] })?.models) - ? ((payload as { models?: Array<{ name?: unknown }> }).models ?? []) - : []; - return models.some((entry) => typeof entry?.name === "string" && entry.name.trim().length > 0); - } +type NormalizedLocalProviderConfig = { + enabled: boolean; + endpoint?: string; + autoDetect: boolean; + preferredModelId?: string | null; +}; - const models: Array<{ id?: unknown }> = Array.isArray((payload as { data?: unknown[] })?.data) - ? ((payload as { data?: Array<{ id?: unknown }> }).data ?? []) - : []; - return models.some((entry) => typeof entry?.id === "string" && entry.id.trim().length > 0); - } catch { - return false; - } +function normalizeLocalProviderConfig( + config?: AiLocalProviderConfigs, +): Record { + return { + ollama: { + enabled: config?.ollama?.enabled ?? true, + ...(typeof config?.ollama?.endpoint === "string" && config.ollama.endpoint.trim().length + ? { endpoint: config.ollama.endpoint.trim() } + : {}), + autoDetect: config?.ollama?.autoDetect ?? true, + preferredModelId: config?.ollama?.preferredModelId ?? null, + }, + lmstudio: { + enabled: config?.lmstudio?.enabled ?? true, + ...(typeof config?.lmstudio?.endpoint === "string" && config.lmstudio.endpoint.trim().length + ? { endpoint: config.lmstudio.endpoint.trim() } + : {}), + autoDetect: config?.lmstudio?.autoDetect ?? true, + preferredModelId: config?.lmstudio?.preferredModelId ?? null, + }, + vllm: { + enabled: config?.vllm?.enabled ?? true, + ...(typeof config?.vllm?.endpoint === "string" && config.vllm.endpoint.trim().length + ? { endpoint: config.vllm.endpoint.trim() } + : {}), + autoDetect: config?.vllm?.autoDetect ?? true, + preferredModelId: config?.vllm?.preferredModelId ?? null, + }, + }; +} + +function localProvidersCacheKey(config?: AiLocalProviderConfigs): string { + const normalized = normalizeLocalProviderConfig(config); + return (["ollama", "lmstudio", "vllm"] as const) + .map((provider) => { + const entry = normalized[provider]; + return [ + provider, + entry.enabled ? "1" : "0", + entry.autoDetect ? "1" : "0", + entry.endpoint ?? "", + entry.preferredModelId ?? "", + ].join(":"); + }) + .join("|"); } -async function detectLocalProviders(): Promise> { +async function detectLocalProviders( + config?: AiLocalProviderConfigs, +): Promise> { const now = Date.now(); - if (cachedLocalProviders && now - cachedLocalProviders.checkedAtMs < LOCAL_ENDPOINT_CACHE_TTL_MS) { + const cacheKey = localProvidersCacheKey(config); + if ( + cachedLocalProviders + && cachedLocalProviders.key === cacheKey + && now - cachedLocalProviders.checkedAtMs < LOCAL_ENDPOINT_CACHE_TTL_MS + ) { return cachedLocalProviders.entries; } - const localEndpoints: Array<{ - provider: "ollama" | "lmstudio" | "vllm"; - url: string; - }> = [ - { provider: "ollama", url: "http://localhost:11434/api/tags" }, - { provider: "lmstudio", url: "http://localhost:1234/v1/models" }, - { provider: "vllm", url: "http://localhost:8000/v1/models" }, - ]; - - const localChecks = await Promise.allSettled( - localEndpoints.map(async ({ provider, url }) => { - const alive = await checkLocalEndpointHasModels(provider, url, LOCAL_ENDPOINT_CHECK_TIMEOUT_MS); - if (!alive) return null; - const endpoint = url.replace(/\/api\/tags$|\/v1\/models$/, ""); - return { provider, endpoint } as const; - }), - ); + const normalized = normalizeLocalProviderConfig(config); + const entries: Array<{ + provider: LocalProviderFamily; + endpoint: string; + endpointSource: "auto" | "config"; + preferredModelId?: string | null; + }> = []; + + for (const provider of ["ollama", "lmstudio", "vllm"] as const) { + const providerConfig = normalized[provider]; + if (!providerConfig.enabled) continue; + + const configuredEndpoint = providerConfig.endpoint; + if (configuredEndpoint) { + const inspection = await inspectLocalProvider(provider, configuredEndpoint, LOCAL_ENDPOINT_CHECK_TIMEOUT_MS); + if (inspection.health === "ready") { + entries.push({ + provider, + endpoint: configuredEndpoint, + endpointSource: "config", + preferredModelId: providerConfig.preferredModelId ?? null, + }); + continue; + } + if (!providerConfig.autoDetect) { + continue; + } + } - const entries: Array<{ provider: "ollama" | "lmstudio" | "vllm"; endpoint: string }> = []; - for (const check of localChecks) { - if (check.status === "fulfilled" && check.value) { - entries.push(check.value); + if (!providerConfig.autoDetect) continue; + const autoEndpoint = getLocalProviderDefaultEndpoint(provider); + if (configuredEndpoint && autoEndpoint.replace(/\/+$/, "") === configuredEndpoint.replace(/\/+$/, "")) { + continue; + } + const autoInspection = await inspectLocalProvider(provider, autoEndpoint, LOCAL_ENDPOINT_CHECK_TIMEOUT_MS); + if (autoInspection.health === "ready") { + entries.push({ + provider, + endpoint: autoEndpoint, + endpointSource: "auto", + preferredModelId: providerConfig.preferredModelId ?? null, + }); } } - cachedLocalProviders = { checkedAtMs: now, entries }; + cachedLocalProviders = { key: cacheKey, checkedAtMs: now, entries }; return entries; } @@ -747,7 +817,7 @@ export async function detectCliAuthStatuses(options?: { force?: boolean }): Prom export async function detectAllAuth( configApiKeys?: Record, - options?: { force?: boolean }, + options?: { force?: boolean; localProviders?: AiLocalProviderConfigs }, ): Promise { const results: DetectedAuth[] = []; @@ -823,10 +893,15 @@ export async function detectAllAuth( } // 4. Local providers - const localProviders = await detectLocalProviders(); + const localProviders = await detectLocalProviders(options?.localProviders); for (const localProvider of localProviders) { results.push({ type: "local", ...localProvider }); } return results; } + +export function resetLocalProviderDetectionCache(): void { + cachedLocalProviders = null; + clearLocalProviderInspectionCache(); +} diff --git a/apps/desktop/src/main/services/ai/localModelDiscovery.ts b/apps/desktop/src/main/services/ai/localModelDiscovery.ts index 6686e1ef0..6f93ba5f7 100644 --- a/apps/desktop/src/main/services/ai/localModelDiscovery.ts +++ b/apps/desktop/src/main/services/ai/localModelDiscovery.ts @@ -1,19 +1,43 @@ import type { DetectedAuth } from "./authDetector"; -import type { LocalProviderFamily } from "../../../shared/modelRegistry"; +import type { LocalModelHarnessProfile, LocalProviderFamily, ModelCapabilities, ModelDescriptor } from "../../../shared/modelRegistry"; export type DiscoveredLocalModel = { provider: LocalProviderFamily; modelId: string; + displayName?: string; + contextWindow?: number; + maxOutputTokens?: number; + capabilities?: Partial; + reasoningTiers?: string[]; + harnessProfile?: LocalModelHarnessProfile; + discoverySource?: ModelDescriptor["discoverySource"]; +}; + +export type LocalProviderConnectionHealth = + | "ready" + | "reachable_no_models" + | "unreachable"; + +export type LocalProviderInspection = { + provider: LocalProviderFamily; + endpoint: string; + reachable: boolean; + health: LocalProviderConnectionHealth; + loadedModels: DiscoveredLocalModel[]; }; const CACHE_TTL_MS = 30_000; +let inspectionCacheGeneration = 0; -let cache: { +let discoverCache: { key: string; + generation: number; cachedAt: number; models: DiscoveredLocalModel[]; } | null = null; +let inspectionCache = new Map(); + function buildCacheKey(auth: DetectedAuth[]): string { return auth .filter((entry): entry is Extract => entry.type === "local") @@ -22,52 +46,337 @@ function buildCacheKey(auth: DetectedAuth[]): string { .join("|"); } -async function fetchLocalModelIds( - provider: LocalProviderFamily, - endpoint: string, -): Promise { +function buildInspectionKey(provider: LocalProviderFamily, endpoint: string): string { + return `${provider}:${endpoint.replace(/\/+$/, "")}`; +} + +function normalizeString(value: unknown): string | null { + return typeof value === "string" && value.trim().length > 0 ? value.trim() : null; +} + +function normalizeBoolean(value: unknown): boolean | undefined { + return typeof value === "boolean" ? value : undefined; +} + +function normalizeNumber(value: unknown): number | undefined { + return Number.isFinite(Number(value)) ? Number(value) : undefined; +} + +function normalizeReasoningTier(value: unknown): string | null { + if (typeof value !== "string") return null; + const normalized = value.trim().toLowerCase(); + if (!normalized.length) return null; + if (normalized === "max") return "xhigh"; + if (normalized === "none" || normalized === "low" || normalized === "medium" || normalized === "high" || normalized === "xhigh") { + return normalized; + } + return null; +} + +function dedupeReasoningTiers(values: Array): string[] | undefined { + const seen = new Set(); + const tiers: string[] = []; + for (const value of values) { + if (!value || seen.has(value)) continue; + seen.add(value); + tiers.push(value); + } + return tiers.length ? tiers : undefined; +} + +function normalizeReasoningConfig(value: unknown): { tiers?: string[]; supportsReasoning: boolean } { + if (typeof value === "boolean") { + return value ? { tiers: ["low", "medium", "high"], supportsReasoning: true } : { supportsReasoning: false }; + } + if (typeof value === "string") { + const tier = normalizeReasoningTier(value); + return { ...(tier ? { tiers: [tier] } : {}), supportsReasoning: Boolean(tier) }; + } + if (Array.isArray(value)) { + const tiers = dedupeReasoningTiers(value.map((entry) => normalizeReasoningTier(entry))); + return { ...(tiers ? { tiers } : {}), supportsReasoning: Boolean(tiers?.length) }; + } + if (value && typeof value === "object") { + const record = value as Record; + const enabled = normalizeBoolean(record.enabled) ?? normalizeBoolean(record.supported) ?? normalizeBoolean(record.available); + const tiers = dedupeReasoningTiers([ + ...((Array.isArray(record.supported_efforts) ? record.supported_efforts : []) as unknown[]).map((entry) => normalizeReasoningTier(entry)), + ...((Array.isArray(record.supportedEfforts) ? record.supportedEfforts : []) as unknown[]).map((entry) => normalizeReasoningTier(entry)), + ...((Array.isArray(record.efforts) ? record.efforts : []) as unknown[]).map((entry) => normalizeReasoningTier(entry)), + ...((Array.isArray(record.levels) ? record.levels : []) as unknown[]).map((entry) => normalizeReasoningTier(entry)), + normalizeReasoningTier(record.default_effort), + normalizeReasoningTier(record.defaultEffort), + ]); + return { + ...(tiers ? { tiers } : {}), + supportsReasoning: enabled ?? Boolean(tiers?.length), + }; + } + return { supportsReasoning: false }; +} + +function inferVisionFromModelId(modelId: string): boolean { + const lower = modelId.toLowerCase(); + return /(\bvl\b|vision|llava|gemma[-_ ]?(3|4)|qwen2\.?5[-_ ]?vl|llama[-_ ]?3\.2.*vision)/i.test(lower); +} + +function inferNativeToolSupport(modelId: string): boolean { + return /(qwen|llama[-_ ]?3\.(1|2)|ministral|mistral|gpt-oss)/i.test(modelId); +} + +function inferHarnessProfile(args: { + modelId: string; + type?: string | null; + trainedForToolUse?: boolean; +}): LocalModelHarnessProfile { + if (args.type === "embedding") return "read_only"; + if (args.trainedForToolUse) return "verified"; + return inferNativeToolSupport(args.modelId) ? "verified" : "guarded"; +} + +function inferFallbackCapabilities(modelId: string): { + capabilities: Partial; + harnessProfile: LocalModelHarnessProfile; +} { + const lower = modelId.toLowerCase(); + if (/embedding|embed|bge-|nomic-embed|gte-|e5-|rerank|reranker/.test(lower)) { + return { + capabilities: { tools: false, vision: false, reasoning: false, streaming: true }, + harnessProfile: "read_only", + }; + } + const harnessProfile = inferHarnessProfile({ modelId }); + const reasoning = /\breason(ing)?\b|qwq|r1|deepseek-r1|phi-4.*reasoning|nemotron/i.test(lower); + return { + capabilities: { + tools: true, + vision: inferVisionFromModelId(modelId), + reasoning, + streaming: true, + }, + harnessProfile, + }; +} + +function toLoadedInstanceDisplayName(baseDisplayName: string | null, instanceId: string, multiInstance: boolean): string | undefined { + const display = baseDisplayName?.trim() || instanceId; + return multiInstance ? `${display} (${instanceId})` : display; +} + +async function fetchJson(url: string, timeoutMs: number): Promise { const controller = new AbortController(); - const timeout = setTimeout(() => controller.abort(), 2_000); + const timer = setTimeout(() => controller.abort(), timeoutMs); try { - const url = provider === "ollama" - ? `${endpoint.replace(/\/+$/, "")}/api/tags` - : `${endpoint.replace(/\/+$/, "")}/v1/models`; const response = await fetch(url, { method: "GET", signal: controller.signal }); - if (!response.ok) return []; - const payload = await response.json() as unknown; - if (provider === "ollama") { - const models = Array.isArray((payload as { models?: unknown[] })?.models) - ? ((payload as { models?: Array<{ name?: unknown }> }).models ?? []) - : []; - return models - .map((entry) => (typeof entry?.name === "string" ? entry.name.trim() : "")) - .filter(Boolean); - } - const models = Array.isArray((payload as { data?: unknown[] })?.data) - ? ((payload as { data?: Array<{ id?: unknown }> }).data ?? []) - : []; - return models - .map((entry) => (typeof entry?.id === "string" ? entry.id.trim() : "")) - .filter(Boolean); + if (!response.ok) return null; + return await response.json() as unknown; } catch { - return []; + return null; } finally { - clearTimeout(timeout); + clearTimeout(timer); + } +} + +async function inspectLmStudioProvider(endpoint: string, timeoutMs: number): Promise { + const base = endpoint.replace(/\/+$/, ""); + const restPayload = await fetchJson(`${base}/api/v1/models`, timeoutMs); + + if (restPayload && typeof restPayload === "object") { + const models = Array.isArray((restPayload as { models?: unknown[] }).models) + ? (restPayload as { models: Array> }).models + : []; + + const discovered: DiscoveredLocalModel[] = []; + + for (const model of models) { + const loadedInstances = Array.isArray(model.loaded_instances) ? model.loaded_instances as Array> : []; + if (!loadedInstances.length) continue; + + const type = normalizeString(model.type); + const displayName = normalizeString(model.display_name) ?? normalizeString(model.key); + const maxContextLength = normalizeNumber(model.max_context_length); + const capabilitiesRecord = model.capabilities && typeof model.capabilities === "object" + ? model.capabilities as Record + : null; + const trainedForToolUse = normalizeBoolean(capabilitiesRecord?.trained_for_tool_use) ?? false; + const reasoning = normalizeReasoningConfig((model as Record).reasoning); + const multiInstance = loadedInstances.length > 1; + + for (const instance of loadedInstances) { + const instanceId = normalizeString(instance.id); + if (!instanceId) continue; + const config = instance.config && typeof instance.config === "object" + ? instance.config as Record + : null; + const contextWindow = normalizeNumber(config?.context_length) ?? maxContextLength; + const harnessProfile = inferHarnessProfile({ modelId: instanceId, type, trainedForToolUse }); + discovered.push({ + provider: "lmstudio", + modelId: instanceId, + displayName: toLoadedInstanceDisplayName(displayName, instanceId, multiInstance), + ...(contextWindow ? { contextWindow } : {}), + maxOutputTokens: 8_192, + capabilities: { + tools: type !== "embedding", + vision: normalizeBoolean(capabilitiesRecord?.vision) ?? inferVisionFromModelId(instanceId), + reasoning: reasoning.supportsReasoning, + streaming: true, + }, + ...(reasoning.tiers?.length ? { reasoningTiers: reasoning.tiers } : {}), + harnessProfile, + discoverySource: "lmstudio-rest", + }); + } + } + + return { + provider: "lmstudio", + endpoint, + reachable: true, + health: discovered.length ? "ready" : "reachable_no_models", + loadedModels: discovered, + }; + } + + const openAiPayload = await fetchJson(`${base}/v1/models`, timeoutMs); + const models = Array.isArray((openAiPayload as { data?: unknown[] } | null)?.data) + ? ((openAiPayload as { data: Array> }).data) + : []; + const discovered = models + .map((entry) => normalizeString(entry.id)) + .filter((modelId): modelId is string => Boolean(modelId)) + .map((modelId) => { + const fallback = inferFallbackCapabilities(modelId); + return { + provider: "lmstudio" as const, + modelId, + displayName: modelId, + maxOutputTokens: 8_192, + capabilities: fallback.capabilities, + harnessProfile: fallback.harnessProfile, + discoverySource: "lmstudio-openai" as const, + } satisfies DiscoveredLocalModel; + }); + + return { + provider: "lmstudio", + endpoint, + reachable: openAiPayload != null, + health: !openAiPayload ? "unreachable" : discovered.length ? "ready" : "reachable_no_models", + loadedModels: discovered, + }; +} + +async function inspectOpenAiCompatibleProvider( + provider: Exclude, + endpoint: string, + timeoutMs: number, +): Promise { + const base = endpoint.replace(/\/+$/, ""); + const path = provider === "ollama" ? "/api/tags" : "/v1/models"; + const payload = await fetchJson(`${base}${path}`, timeoutMs); + if (payload == null) { + return { + provider, + endpoint, + reachable: false, + health: "unreachable", + loadedModels: [], + }; + } + + const discovered = provider === "ollama" + ? (Array.isArray((payload as { models?: unknown[] }).models) + ? (payload as { models: Array> }).models + : []) + .map((entry) => normalizeString(entry.name)) + .filter((modelId): modelId is string => Boolean(modelId)) + .map((modelId) => { + const fallback = inferFallbackCapabilities(modelId); + return { + provider, + modelId, + displayName: modelId, + maxOutputTokens: 8_192, + capabilities: fallback.capabilities, + harnessProfile: fallback.harnessProfile, + discoverySource: provider, + } satisfies DiscoveredLocalModel; + }) + : (Array.isArray((payload as { data?: unknown[] }).data) + ? (payload as { data: Array> }).data + : []) + .map((entry) => normalizeString(entry.id)) + .filter((modelId): modelId is string => Boolean(modelId)) + .map((modelId) => { + const fallback = inferFallbackCapabilities(modelId); + return { + provider, + modelId, + displayName: modelId, + maxOutputTokens: 8_192, + capabilities: fallback.capabilities, + harnessProfile: fallback.harnessProfile, + discoverySource: provider, + } satisfies DiscoveredLocalModel; + }); + + return { + provider, + endpoint, + reachable: true, + health: discovered.length ? "ready" : "reachable_no_models", + loadedModels: discovered, + }; +} + +export async function inspectLocalProvider( + provider: LocalProviderFamily, + endpoint: string, + timeoutMs = 2_000, +): Promise { + const key = buildInspectionKey(provider, endpoint); + const generation = inspectionCacheGeneration; + const cached = inspectionCache.get(key); + const now = Date.now(); + if (cached && cached.generation === generation && now - cached.cachedAt < CACHE_TTL_MS) { + return cached.inspection; } + + const inspection = provider === "lmstudio" + ? await inspectLmStudioProvider(endpoint, timeoutMs) + : await inspectOpenAiCompatibleProvider(provider, endpoint, timeoutMs); + + if (generation === inspectionCacheGeneration) { + inspectionCache.set(key, { generation, cachedAt: now, inspection }); + } + return inspection; +} + +export function clearLocalProviderInspectionCache(): void { + inspectionCacheGeneration += 1; + inspectionCache = new Map(); + discoverCache = null; } export async function discoverLocalModels(auth: DetectedAuth[]): Promise { const key = buildCacheKey(auth); + const generation = inspectionCacheGeneration; const now = Date.now(); - if (cache && cache.key === key && now - cache.cachedAt < CACHE_TTL_MS) { - return cache.models; + if ( + discoverCache + && discoverCache.generation === generation + && discoverCache.key === key + && now - discoverCache.cachedAt < CACHE_TTL_MS + ) { + return discoverCache.models; } const providers = auth.filter((entry): entry is Extract => entry.type === "local"); const discovered = await Promise.all( providers.map(async (entry) => { - const modelIds = await fetchLocalModelIds(entry.provider, entry.endpoint); - return modelIds.map((modelId) => ({ provider: entry.provider, modelId })); + const inspection = await inspectLocalProvider(entry.provider, entry.endpoint); + return inspection.loadedModels; }), ); @@ -83,6 +392,8 @@ export async function discoverLocalModels(auth: DetectedAuth[]): Promise ({ createCodexCliMock: vi.fn(), @@ -36,6 +36,7 @@ describe("providerResolver codex CLI", () => { beforeEach(() => { createCodexCliMock.mockReset(); createClaudeCodeMock.mockReset(); + vi.unstubAllGlobals(); }); it("resolves Codex CLI models through the community provider with MCP settings", async () => { @@ -204,4 +205,33 @@ describe("providerResolver codex CLI", () => { }, }); }); + + it("uses the saved preferred local model without probing /v1/models", async () => { + const fetchSpy = vi.fn(); + vi.stubGlobal("fetch", fetchSpy); + + const resolved = await resolveAutoModelIdFromOpenAiCompatibleEndpoint( + "http://localhost:1234", + "lmstudio", + "lmstudio/qwen2.5-coder:32b", + ); + + expect(fetchSpy).not.toHaveBeenCalled(); + expect(resolved).toBe("qwen2.5-coder:32b"); + }); + + it("requires explicit selection when a local runtime reports multiple loaded models", async () => { + vi.stubGlobal("fetch", vi.fn(async () => + new Response(JSON.stringify({ + data: [ + { id: "meta-llama-3.1-70b-instruct" }, + { id: "qwen2.5-coder:32b" }, + ], + }), { status: 200 }), + )); + + await expect( + resolveAutoModelIdFromOpenAiCompatibleEndpoint("http://localhost:1234", "lmstudio"), + ).rejects.toThrow("Choose a specific model or save a preferred local model"); + }); }); diff --git a/apps/desktop/src/main/services/ai/providerResolver.ts b/apps/desktop/src/main/services/ai/providerResolver.ts index c61d79711..b363459e4 100644 --- a/apps/desktop/src/main/services/ai/providerResolver.ts +++ b/apps/desktop/src/main/services/ai/providerResolver.ts @@ -4,6 +4,7 @@ import type { LanguageModel } from "ai"; import { + getLocalProviderDefaultEndpoint, getModelById, resolveModelAlias, type ModelDescriptor, @@ -101,9 +102,12 @@ function findOpenRouterKey(auth: DetectedAuth[]): string | undefined { return undefined; } -function findLocalEndpoint(auth: DetectedAuth[], provider: string): string | undefined { +function findLocalProviderAuth( + auth: DetectedAuth[], + provider: string, +): Extract | undefined { for (const a of auth) { - if (a.type === "local" && a.provider === provider) return a.endpoint; + if (a.type === "local" && a.provider === provider) return a; } return undefined; } @@ -122,20 +126,32 @@ const COMPATIBLE_BASE_URLS: Record = { together: "https://api.together.xyz/v1", }; -const DEFAULT_LOCAL_ENDPOINTS: Record<"ollama" | "lmstudio" | "vllm", string> = { - ollama: "http://localhost:11434", - lmstudio: "http://localhost:1234", - vllm: "http://localhost:8000", -}; - function normalizeBaseUrl(url: string): string { return url.replace(/\/+$/, ""); } -async function resolveAutoModelIdFromOpenAiCompatibleEndpoint( +function normalizePreferredLocalModelId( + preferredModelId: string | null | undefined, + providerName: string, +): string | undefined { + const normalized = preferredModelId?.trim(); + if (!normalized) return undefined; + const prefix = `${providerName}/`; + if (normalized.toLowerCase().startsWith(prefix)) { + const trimmed = normalized.slice(prefix.length).trim(); + return trimmed || undefined; + } + return normalized; +} + +export async function resolveAutoModelIdFromOpenAiCompatibleEndpoint( endpoint: string, providerName: string, + preferredModelId?: string, ): Promise { + const preferred = normalizePreferredLocalModelId(preferredModelId ?? null, providerName); + if (preferred) return preferred; + const controller = new AbortController(); const timeout = setTimeout(() => controller.abort(), 5_000); try { @@ -147,11 +163,18 @@ async function resolveAutoModelIdFromOpenAiCompatibleEndpoint( throw new Error(`Failed to list models from ${providerName} (${response.status}).`); } const payload = await response.json() as { data?: Array<{ id?: unknown }> }; - const firstModelId = payload.data?.find((entry) => typeof entry?.id === "string" && entry.id.trim().length)?.id; - if (!firstModelId || typeof firstModelId !== "string") { + const modelIds = (payload.data ?? []) + .map((entry) => (typeof entry?.id === "string" ? entry.id.trim() : "")) + .filter(Boolean); + if (modelIds.length === 0) { throw new Error(`${providerName} did not return any usable model IDs.`); } - return firstModelId.trim(); + if (modelIds.length > 1) { + throw new Error( + `${providerName} has multiple loaded models (${modelIds.join(", ")}). Choose a specific model or save a preferred local model.`, + ); + } + return modelIds[0]!; } finally { clearTimeout(timeout); } @@ -161,9 +184,10 @@ async function resolveOpenAiCompatibleModelId( sdkModelId: string, endpoint: string, providerName: string, + preferredModelId?: string | null, ): Promise { if (sdkModelId !== "auto") return sdkModelId; - return resolveAutoModelIdFromOpenAiCompatibleEndpoint(endpoint, providerName); + return resolveAutoModelIdFromOpenAiCompatibleEndpoint(endpoint, providerName, preferredModelId ?? undefined); } // --------------------------------------------------------------------------- @@ -376,8 +400,14 @@ async function resolveDirectProvider( // Local providers (ollama, lmstudio, vllm) if (family === "ollama" || family === "lmstudio" || family === "vllm") { const localProvider = family; - const endpoint = findLocalEndpoint(auth, localProvider) ?? DEFAULT_LOCAL_ENDPOINTS[localProvider]; - const resolvedModelId = await resolveOpenAiCompatibleModelId(sdkModelId, endpoint, localProvider); + const localAuth = findLocalProviderAuth(auth, localProvider); + const endpoint = localAuth?.endpoint ?? getLocalProviderDefaultEndpoint(localProvider); + const resolvedModelId = await resolveOpenAiCompatibleModelId( + sdkModelId, + endpoint, + localProvider, + localAuth?.preferredModelId, + ); const createCompatible = await loadOpenAICompatibleProvider(); const provider = createCompatible({ name: localProvider, diff --git a/apps/desktop/src/main/services/ai/tools/universalTools.test.ts b/apps/desktop/src/main/services/ai/tools/universalTools.test.ts index cfe3e192a..3a3d94561 100644 --- a/apps/desktop/src/main/services/ai/tools/universalTools.test.ts +++ b/apps/desktop/src/main/services/ai/tools/universalTools.test.ts @@ -732,6 +732,63 @@ describe("createUniversalToolSet", () => { expect(result.answer).toBe("user answer"); }); + it("accepts structured askUser prompts and returns normalized answers", async () => { + const cwd = makeTmpDir("ade-tools-askuser-structured-"); + const onAskUser = vi.fn().mockResolvedValue({ + answer: "auth-refactor", + answers: { + roadmap: ["auth-refactor", "bug-fixes"], + }, + responseText: null, + decision: "accept", + }); + const tools = createUniversalToolSet(cwd, { + permissionMode: "plan", + onAskUser, + }); + + const result = await (tools.askUser as any).execute({ + title: "Mock plan question", + body: "Choose the most important sprint goals.", + questions: [ + { + id: "roadmap", + header: "Sprint scope", + question: "Which items should we prioritize?", + multiSelect: true, + options: [ + { label: "Auth refactor", value: "auth-refactor", recommended: true }, + { label: "Bug fixes", value: "bug-fixes" }, + ], + }, + ], + }); + + expect(onAskUser).toHaveBeenCalledWith({ + title: "Mock plan question", + body: "Choose the most important sprint goals.", + questions: [ + { + id: "roadmap", + header: "Sprint scope", + question: "Which items should we prioritize?", + multiSelect: true, + options: [ + { label: "Auth refactor", value: "auth-refactor", recommended: true }, + { label: "Bug fixes", value: "bug-fixes" }, + ], + }, + ], + }); + expect(result).toMatchObject({ + answer: "auth-refactor", + answers: { + roadmap: ["auth-refactor", "bug-fixes"], + }, + decision: "accept", + }); + }); + // ── exitPlanMode tool ─────────────────────────────────────────── it("does not expose exitPlanMode in non-plan permission modes", async () => { diff --git a/apps/desktop/src/main/services/ai/tools/universalTools.ts b/apps/desktop/src/main/services/ai/tools/universalTools.ts index cba2026fd..5700892db 100644 --- a/apps/desktop/src/main/services/ai/tools/universalTools.ts +++ b/apps/desktop/src/main/services/ai/tools/universalTools.ts @@ -20,6 +20,42 @@ const execFileAsync = promisify(execFile); export type PermissionMode = "plan" | "edit" | "full-auto"; +export type AskUserToolOption = { + label: string; + value?: string; + description?: string; + recommended?: boolean; + preview?: string; + previewFormat?: "markdown" | "html"; +}; + +export type AskUserToolQuestion = { + id?: string; + header?: string; + question: string; + options?: AskUserToolOption[]; + multiSelect?: boolean; + allowsFreeform?: boolean; + isSecret?: boolean; + defaultAssumption?: string | null; + impact?: string | null; +}; + +export type AskUserToolInput = { + question?: string; + title?: string; + body?: string; + questions?: AskUserToolQuestion[]; +}; + +export type AskUserToolResult = { + answer: string; + answers?: Record; + responseText?: string | null; + decision?: string; + error?: string; +}; + export interface UniversalToolSetOptions { permissionMode: PermissionMode; memoryService?: ReturnType; @@ -35,7 +71,7 @@ export interface UniversalToolSetOptions { updatedAt: string; }; /** Callback invoked when askUser tool is called; must return the user's response */ - onAskUser?: (question: string) => Promise; + onAskUser?: (input: AskUserToolInput) => Promise; /** Optional callback for ADE-managed tool approvals in interactive chat sessions. */ onApprovalRequest?: (request: ToolApprovalRequest) => Promise; /** Sandbox config for API-model workers. CLI models skip this check. */ @@ -907,21 +943,63 @@ function createGitLogTool(cwd: string) { }); } -function createAskUserTool(onAskUser?: (question: string) => Promise) { +const askUserToolOptionSchema = z.object({ + label: z.string().describe("User-facing option label"), + value: z.string().optional().describe("Optional stable value to return for this option"), + description: z.string().optional().describe("Optional short sentence explaining the option"), + recommended: z.boolean().optional().describe("Whether this option should be highlighted as recommended"), + preview: z.string().optional().describe("Optional preview content shown alongside the option"), + previewFormat: z.enum(["markdown", "html"]).optional().describe("Preview rendering format"), +}); + +const askUserToolQuestionSchema = z.object({ + id: z.string().optional().describe("Optional stable identifier for this question"), + header: z.string().optional().describe("Short question header shown in the UI"), + question: z.string().describe("The question to ask the user"), + options: z.array(askUserToolOptionSchema).optional().describe("Optional multiple-choice options"), + multiSelect: z.boolean().optional().describe("Allow selecting more than one option"), + allowsFreeform: z.boolean().optional().describe("Allow typing a custom answer"), + isSecret: z.boolean().optional().describe("Hide typed input while the user answers"), + defaultAssumption: z.string().nullable().optional().describe("Default assumption if the user skips this question"), + impact: z.string().nullable().optional().describe("Why this question matters"), +}); + +function createAskUserTool( + onAskUser?: (input: AskUserToolInput) => Promise, +) { return tool({ description: "Ask the user a clarifying question when you need more information to proceed. " + "Use sparingly — only when truly blocked.", inputSchema: z.object({ - question: z.string().describe("The question to ask the user"), - }), - execute: async ({ question }) => { + question: z.string().optional().describe("Simple text question to ask the user"), + title: z.string().optional().describe("Optional modal title for a richer prompt"), + body: z.string().optional().describe("Optional supporting context shown above the question list"), + questions: z.array(askUserToolQuestionSchema).optional().describe("Optional structured questions with choices"), + }).refine( + (value) => { + const question = typeof value.question === "string" ? value.question.trim() : ""; + const body = typeof value.body === "string" ? value.body.trim() : ""; + return question.length > 0 || body.length > 0 || Boolean(value.questions?.length); + }, + { message: "Provide question, body, or questions." }, + ), + execute: async (input) => { if (!onAskUser) { return { answer: "", error: "askUser callback not configured" }; } try { - const answer = await onAskUser(question); - return { answer }; + const response = await onAskUser(input); + if (typeof response === "string") { + return { answer: response }; + } + return { + answer: response.answer ?? "", + ...(response.answers ? { answers: response.answers } : {}), + ...(response.responseText !== undefined ? { responseText: response.responseText } : {}), + ...(response.decision !== undefined ? { decision: response.decision } : {}), + ...(response.error !== undefined ? { error: response.error } : {}), + }; } catch (err) { return { answer: "", diff --git a/apps/desktop/src/main/services/ai/tools/workflowTools.ts b/apps/desktop/src/main/services/ai/tools/workflowTools.ts index 4f2345eae..7263c1844 100644 --- a/apps/desktop/src/main/services/ai/tools/workflowTools.ts +++ b/apps/desktop/src/main/services/ai/tools/workflowTools.ts @@ -172,13 +172,12 @@ export function createWorkflowTools( error: "Local computer-use fallback is disabled for this chat session.", }; } + let tmpDir: string | null = null; try { - // Use macOS screencapture to grab the screen - const tmpPath = path.join( - fs.mkdtempSync(path.join(require("node:os").tmpdir(), "ade-screenshot-")), - `screenshot-${Date.now()}.png`, - ); + tmpDir = fs.mkdtempSync(path.join(require("node:os").tmpdir(), "ade-screenshot-")); + const tmpPath = path.join(tmpDir, `screenshot-${Date.now()}.png`); + // Use macOS screencapture to grab the screen await execFileAsync("screencapture", ["-x", tmpPath], { timeout: 15_000, }); @@ -211,11 +210,19 @@ export function createWorkflowTools( return { success: true, artifactId: artifact?.id ?? null, - uri: artifact?.uri ?? tmpPath, + uri: artifact?.uri ?? null, title: artifact?.title ?? title, }; } catch (err) { return formatToolError("Screenshot failed", err); + } finally { + try { + if (tmpDir) { + fs.rmSync(tmpDir, { recursive: true, force: true }); + } + } catch { + // Best-effort cleanup only. + } } }, }); diff --git a/apps/desktop/src/main/services/chat/agentChatService.test.ts b/apps/desktop/src/main/services/chat/agentChatService.test.ts index 7f038978f..5cfc53aeb 100644 --- a/apps/desktop/src/main/services/chat/agentChatService.test.ts +++ b/apps/desktop/src/main/services/chat/agentChatService.test.ts @@ -206,6 +206,10 @@ vi.mock("../ai/authDetector", () => ({ detectAllAuth: vi.fn(async () => []), })); +vi.mock("../ai/localModelDiscovery", () => ({ + discoverLocalModels: vi.fn(async () => []), +})); + vi.mock("../git/git", () => ({ runGit: vi.fn(async () => ({ stdout: "", stderr: "", exitCode: 0 })), })); @@ -285,11 +289,14 @@ vi.mock("./cursorAcpPool", () => ({ // Import system under test (after mocks) // --------------------------------------------------------------------------- import { + buildUnifiedStreamMessages, buildComputerUseDirective, createAgentChatService, + shouldAutoContinueUnifiedLocalTurn, } from "./agentChatService"; import { spawn } from "node:child_process"; import { detectAllAuth } from "../ai/authDetector"; +import { discoverLocalModels } from "../ai/localModelDiscovery"; import * as providerResolver from "../ai/providerResolver"; import { createUniversalToolSet } from "../ai/tools/universalTools"; import { createWorkflowTools } from "../ai/tools/workflowTools"; @@ -626,6 +633,8 @@ beforeEach(() => { vi.mocked(generateText).mockReset(); vi.mocked(unstable_v2_createSession).mockReset(); vi.mocked(detectAllAuth).mockResolvedValue([]); + vi.mocked(discoverLocalModels).mockReset(); + vi.mocked(discoverLocalModels).mockResolvedValue([]); vi.mocked(providerResolver.resolveModel).mockResolvedValue({} as any); vi.mocked(parseAgentChatTranscript).mockReturnValue([]); }); @@ -4504,6 +4513,89 @@ describe("createAgentChatService", () => { expect(updated?.permissionMode).toBe("edit"); expect(updated?.unifiedPermissionMode).toBe("edit"); }); + + it("detects narrated next-step early stops for local unified turns", () => { + expect(shouldAutoContinueUnifiedLocalTurn({ + modelDescriptor: { authTypes: ["local"] }, + permissionMode: "edit", + assistantText: "I will explore the src directory to identify where pages and routing are defined in the application.", + toolCallCount: 1, + toolResultCount: 1, + continuationCount: 0, + })).toBe(true); + + expect(shouldAutoContinueUnifiedLocalTurn({ + modelDescriptor: { authTypes: ["local"] }, + permissionMode: "plan", + assistantText: "I will explore the src directory to identify where pages and routing are defined in the application.", + toolCallCount: 1, + toolResultCount: 1, + continuationCount: 0, + })).toBe(false); + + expect(shouldAutoContinueUnifiedLocalTurn({ + modelDescriptor: { authTypes: ["local"] }, + permissionMode: "edit", + assistantText: "I found the routing file and can add the blank about page next.", + toolCallCount: 1, + toolResultCount: 1, + continuationCount: 0, + })).toBe(false); + }); + + it("preserves original attachments across local auto-continuation retries", () => { + const resolvedPath = path.join(tmpRoot, "note.txt"); + fs.writeFileSync(resolvedPath, "remember this", "utf8"); + + const streamMessages = buildUnifiedStreamMessages({ + messages: [ + { + role: "user", + content: "Add an about me page.\n\nAttached context:\n- file: note.txt", + }, + { + role: "assistant", + content: "I will explore the src directory to identify where pages and routing are defined in the application.", + }, + { + role: "user", + content: "Continue from your last step.", + }, + ], + persistedTurnUserMessageIndex: 0, + resolvedAttachments: [{ + path: "note.txt", + type: "file", + _rootPath: tmpRoot, + _resolvedPath: resolvedPath, + }], + modelDescriptor: { + id: "lmstudio/qwen2.5-coder:32b", + displayName: "qwen2.5-coder:32b", + family: "lmstudio", + authTypes: ["local"], + contextWindow: 0, + maxOutputTokens: 0, + capabilities: { tools: true, vision: false, reasoning: false, streaming: true }, + color: "#64748B", + sdkProvider: "@ai-sdk/openai-compatible", + sdkModelId: "qwen2.5-coder:32b", + isCliWrapped: false, + harnessProfile: "verified", + } as any, + logger: createLogger() as any, + }); + + expect(streamMessages).toHaveLength(3); + expect(streamMessages[0]?.content).toEqual(expect.arrayContaining([ + expect.objectContaining({ type: "text" }), + expect.objectContaining({ type: "file", filename: "note.txt" }), + ])); + expect(streamMessages[2]).toEqual({ + role: "user", + content: "Continue from your last step.", + }); + }); }); it("emits immediate startup activity before unified stream output arrives", async () => { diff --git a/apps/desktop/src/main/services/chat/agentChatService.ts b/apps/desktop/src/main/services/chat/agentChatService.ts index c18c3d12d..0e514c948 100644 --- a/apps/desktop/src/main/services/chat/agentChatService.ts +++ b/apps/desktop/src/main/services/chat/agentChatService.ts @@ -107,6 +107,7 @@ import type { AgentChatUnifiedPermissionMode, PendingInputQuestion, PendingInputRequest, + PendingInputSource, AgentChatUpdateSessionArgs, ComputerUseBackendStatus, ComputerUsePolicy, @@ -116,22 +117,26 @@ import type { CtoCapabilityMode, } from "../../../shared/types"; import { + createDynamicLocalModelDescriptor, getDefaultModelDescriptor, getModelById, getAvailableModels as getRegistryModels, listModelDescriptorsForProvider, + LOCAL_PROVIDER_LABELS, MODEL_REGISTRY, pickDefaultCursorDescriptorFromCliList, + replaceDynamicLocalModelDescriptors, resolveModelAlias, resolveModelDescriptorForProvider, resolveProviderGroupForModel, + type LocalProviderFamily, type ModelDescriptor, } from "../../../shared/modelRegistry"; import { canSwitchChatSessionModel } from "../../../shared/chatModelSwitching"; import { detectAllAuth } from "../ai/authDetector"; import * as providerResolver from "../ai/providerResolver"; import { buildCodexAppServerMcpConfigOverrides } from "../ai/codexAppServerConfig"; -import { createUniversalToolSet, type PermissionMode } from "../ai/tools/universalTools"; +import { createUniversalToolSet, type AskUserToolInput, type PermissionMode } from "../ai/tools/universalTools"; import { createWorkflowTools } from "../ai/tools/workflowTools"; import { createLinearTools } from "../ai/tools/linearTools"; import { createCtoOperatorTools, type CtoOperatorToolDeps } from "../ai/tools/ctoOperatorTools"; @@ -173,6 +178,7 @@ import { type CursorAcpPooled, } from "./cursorAcpPool"; import { discoverCursorCliModelDescriptors } from "./cursorModelsDiscovery"; +import { discoverLocalModels, type DiscoveredLocalModel } from "../ai/localModelDiscovery"; import { mapAcpSessionNotificationToChatEvents, mapStopReasonToTerminalEvents, @@ -964,6 +970,56 @@ function normalizeUsagePayload( return { inputTokens, outputTokens }; } +function mergeUsagePayloads( + left: { inputTokens?: number | null; outputTokens?: number | null } | undefined, + right: { inputTokens?: number | null; outputTokens?: number | null } | undefined, +): { inputTokens?: number | null; outputTokens?: number | null } | undefined { + if (!left) return right; + if (!right) return left; + const inputTokens = (left.inputTokens ?? 0) + (right.inputTokens ?? 0); + const outputTokens = (left.outputTokens ?? 0) + (right.outputTokens ?? 0); + if (!inputTokens && !outputTokens) return undefined; + return { inputTokens, outputTokens }; +} + +const MAX_UNIFIED_LOCAL_AUTO_CONTINUATIONS = 1; +const UNIFIED_LOCAL_AUTO_CONTINUE_PROMPT = + "Continue from your last step. Do not restate that you will inspect, explore, or review the codebase. " + + "Take the next concrete action now, using tools if needed. Only stop when you have completed the request or are truly blocked."; + +export function shouldAutoContinueUnifiedLocalTurn(args: { + modelDescriptor: Pick; + permissionMode: PermissionMode; + assistantText: string; + toolCallCount: number; + toolResultCount: number; + continuationCount: number; + pendingApprovalCount?: number; +}): boolean { + if (!args.modelDescriptor.authTypes.includes("local")) return false; + if (args.permissionMode === "plan") return false; + if (args.continuationCount >= MAX_UNIFIED_LOCAL_AUTO_CONTINUATIONS) return false; + if ((args.pendingApprovalCount ?? 0) > 0) return false; + if (args.toolCallCount < 1 || args.toolResultCount < 1) return false; + + const normalized = args.assistantText.trim().replace(/\s+/g, " "); + if (!normalized.length || normalized.length > 280) return false; + + const lower = normalized.toLowerCase(); + if (lower.includes("?")) return false; + if ( + /\b(done|completed|complete|implemented|created|updated|finished|fixed|here('| i)?s|i found|i added|i changed|i created|i updated|i implemented|blocked)\b/i.test(lower) + ) { + return false; + } + + const narratedIntentPattern = + /^(?:ok[, ]+|alright[, ]+|next[, ]+|now[, ]+)?(?:i(?:'ll| will| am going to|'m going to)|let me)\b/i; + const nextActionPattern = + /\b(explore|inspect|check|look at|search|review|read|analy[sz]e|identify|open|scan|trace|browse)\b/i; + return narratedIntentPattern.test(normalized) && nextActionPattern.test(normalized); +} + const KNOWN_CODEX_EFFORTS = new Set(CODEX_REASONING_EFFORTS.map((e) => e.effort)); const EFFORT_ALIASES: Record> = { @@ -1402,6 +1458,35 @@ function buildStreamingUserContent( return parts; } +export function buildUnifiedStreamMessages(args: { + messages: Array<{ role: string; content: string }>; + persistedTurnUserMessageIndex: number; + resolvedAttachments: ResolvedAgentChatFileRef[]; + modelDescriptor: ModelDescriptor; + logger?: Logger; +}): ModelMessage[] { + return args.messages.map((message, index): ModelMessage => { + const isPersistedTurnUserMessage = index === args.persistedTurnUserMessageIndex && message.role === "user"; + if (!isPersistedTurnUserMessage) { + return { + role: message.role === "user" ? "user" : "assistant", + content: message.content, + }; + } + + return { + role: "user", + content: buildStreamingUserContent({ + baseText: message.content, + attachments: args.resolvedAttachments, + runtimeKind: "unified", + modelDescriptor: args.modelDescriptor, + logger: args.logger, + }), + }; + }); +} + function buildExecutionModeDirective( mode: AgentChatExecutionMode | null | undefined, provider: AgentChatProvider, @@ -1985,6 +2070,58 @@ function resolveSessionUnifiedPermissionMode( ?? fallback; } +function applyLocalHarnessPermissionMode(args: { + descriptor?: ModelDescriptor; + requestedPermissionMode?: AgentChatSession["permissionMode"]; + requestedUnifiedPermissionMode?: AgentChatUnifiedPermissionMode; +}): { + requestedPermissionMode?: AgentChatSession["permissionMode"]; + requestedUnifiedPermissionMode?: AgentChatUnifiedPermissionMode; +} { + if (!args.descriptor?.authTypes.includes("local")) { + return { + requestedPermissionMode: args.requestedPermissionMode, + requestedUnifiedPermissionMode: args.requestedUnifiedPermissionMode, + }; + } + + if (args.descriptor.harnessProfile === "read_only") { + return { + requestedPermissionMode: "plan", + requestedUnifiedPermissionMode: "plan", + }; + } + + if ( + args.descriptor.harnessProfile === "guarded" + && args.requestedPermissionMode == null + && args.requestedUnifiedPermissionMode == null + ) { + return { + requestedPermissionMode: "plan", + requestedUnifiedPermissionMode: "plan", + }; + } + + return { + requestedPermissionMode: args.requestedPermissionMode, + requestedUnifiedPermissionMode: args.requestedUnifiedPermissionMode, + }; +} + +function enforceManagedLocalHarnessPermissionMode( + managed: ManagedChatSession, + descriptor?: ModelDescriptor | null, +): void { + const harnessPermissions = applyLocalHarnessPermissionMode({ + descriptor: descriptor ?? resolveSessionModelDescriptor(managed.session) ?? undefined, + requestedPermissionMode: managed.session.permissionMode, + requestedUnifiedPermissionMode: managed.session.unifiedPermissionMode, + }); + managed.session.permissionMode = harnessPermissions.requestedPermissionMode ?? managed.session.permissionMode; + managed.session.unifiedPermissionMode = harnessPermissions.requestedUnifiedPermissionMode ?? managed.session.unifiedPermissionMode; +} + function resolveCursorSessionModeId( session: Pick, ): string | null { @@ -3785,7 +3922,7 @@ export function createAgentChatService(args: { persistChatState(managed); const auth = await detectAuth().catch(() => []); - const availableModels = getRegistryModels(auth).filter((descriptor) => !descriptor.deprecated); + const availableModels = await getAvailableRegistryModels(auth); if (!availableModels.length) return; const preferredModelId = @@ -3928,7 +4065,10 @@ export function createAgentChatService(args: { configApiKeys[String(provider).trim().toLowerCase()] = key; } } - return detectAllAuth(configApiKeys); + const localProviders = snapshot.effective.ai?.localProviders; + return detectAllAuth(configApiKeys, { + localProviders, + }); }; const resolveHandoffBlockedReason = (managed: ManagedChatSession): string | null => { @@ -4051,7 +4191,7 @@ export function createAgentChatService(args: { }): Promise<{ brief: string; usedFallbackSummary: boolean }> => { const deterministicBrief = buildDeterministicHandoffBrief(args); const auth = await detectAuth(); - const availableModels = getRegistryModels(auth).filter((descriptor) => !descriptor.deprecated); + const availableModels = await getAvailableRegistryModels(auth); const preferredModelId = [ resolveChatConfig().summaryModelId, "openai/gpt-5.4-mini", @@ -4175,7 +4315,7 @@ export function createAgentChatService(args: { if (!seed) return; const auth = await detectAuth(); - const availableModels = getRegistryModels(auth).filter((descriptor) => !descriptor.deprecated); + const availableModels = await getAvailableRegistryModels(auth); if (!availableModels.length) return; const preferredModelId = @@ -4248,11 +4388,93 @@ export function createAgentChatService(args: { // Unified session support — for API-key / local models using streamText + universal tools. // CLI-wrapped models fall through to the existing Claude/Codex runtimes. + const discoveredLocalModelToDescriptor = (model: DiscoveredLocalModel): ModelDescriptor => + createDynamicLocalModelDescriptor(model.provider, model.modelId, { + ...(model.displayName ? { displayName: model.displayName } : {}), + ...(model.contextWindow ? { contextWindow: model.contextWindow } : {}), + ...(model.maxOutputTokens ? { maxOutputTokens: model.maxOutputTokens } : {}), + ...(model.capabilities ? { capabilities: model.capabilities } : {}), + ...(model.reasoningTiers?.length ? { reasoningTiers: model.reasoningTiers } : {}), + ...(model.harnessProfile ? { harnessProfile: model.harnessProfile } : {}), + ...(model.discoverySource ? { discoverySource: model.discoverySource } : {}), + }); + + const getAvailableRegistryModels = async ( + auth: Awaited>, + ): Promise => { + if (auth.some((entry) => entry.type === "local")) { + try { + const discovered = await discoverLocalModels(auth); + replaceDynamicLocalModelDescriptors(discovered.map(discoveredLocalModelToDescriptor)); + } catch (err) { + replaceDynamicLocalModelDescriptors([]); + logger.warn("agent_chat.local_model_discovery_failed", { + error: err instanceof Error ? err.message : String(err), + }); + } + } else { + replaceDynamicLocalModelDescriptors([]); + } + return getRegistryModels(auth).filter((descriptor) => !descriptor.deprecated); + }; + + const resolveUnifiedLocalDescriptor = async ( + managed: ManagedChatSession, + descriptor: ModelDescriptor, + auth: Awaited>, + ): Promise => { + if (!(descriptor.family === "ollama" || descriptor.family === "lmstudio" || descriptor.family === "vllm")) { + return descriptor; + } + if (descriptor.sdkModelId !== "auto") { + return descriptor; + } + + const localProvider = descriptor.family as LocalProviderFamily; + let discovered: DiscoveredLocalModel[] = []; + try { + discovered = (await discoverLocalModels(auth)).filter((model) => model.provider === localProvider); + replaceDynamicLocalModelDescriptors(discovered.map(discoveredLocalModelToDescriptor)); + } catch (err) { + replaceDynamicLocalModelDescriptors([]); + logger.warn("agent_chat.local_model_resolution_failed", { + sessionId: managed.session.id, + modelId: descriptor.id, + error: err instanceof Error ? err.message : String(err), + }); + return descriptor; + } + + const preferred = auth.find( + (entry): entry is Extract>[number], { type: "local" }> => + entry.type === "local" && entry.provider === localProvider, + )?.preferredModelId; + const preferredDescriptor = preferred ? getModelById(preferred) : undefined; + if (preferredDescriptor && preferredDescriptor.family === localProvider) { + managed.session.modelId = preferredDescriptor.id; + managed.session.model = preferredDescriptor.id; + return preferredDescriptor; + } + + if (discovered.length === 1) { + const onlyDescriptor = getModelById(`${localProvider}/${discovered[0]!.modelId}`) ?? discoveredLocalModelToDescriptor(discovered[0]!); + return onlyDescriptor; + } + + if (discovered.length > 1) { + throw new Error( + `${descriptor.displayName} has multiple loaded models. Choose a specific ${LOCAL_PROVIDER_LABELS[localProvider]} model or save a preferred local model first.`, + ); + } + + throw new Error(`${descriptor.displayName} is reachable, but no models are currently loaded.`); + }; + const startUnifiedSession = async (managed: ManagedChatSession): Promise<"handled" | "fallthrough"> => { const modelId = managed.session.modelId; if (!modelId) return "fallthrough"; - const descriptor = getModelById(modelId); + let descriptor = getModelById(modelId); if (!descriptor) return "fallthrough"; // CLI-wrapped models -> defer to CLI session runtimes. @@ -4265,7 +4487,9 @@ export function createAgentChatService(args: { }); const auth = await detectAuth(); - const resolvedModel = await providerResolver.resolveModel(modelId, auth, { + descriptor = await resolveUnifiedLocalDescriptor(managed, descriptor, auth); + enforceManagedLocalHarnessPermissionMode(managed, descriptor); + const resolvedModel = await providerResolver.resolveModel(descriptor.id, auth, { cwd: managed.laneWorktreePath, }); @@ -4324,6 +4548,7 @@ export function createAgentChatService(args: { managed.session.provider = "unified"; managed.session.unifiedPermissionMode = permMode; managed.session.permissionMode = syncLegacyPermissionMode(managed.session) ?? managed.session.permissionMode; + enforceManagedLocalHarnessPermissionMode(managed, descriptor); managed.session.capabilityMode = mcpClient ? "full_mcp" : "fallback"; return "handled"; }; @@ -5402,7 +5627,7 @@ export function createAgentChatService(args: { // Fire-and-forget AI summary enhancement const auth = await detectAuth(); - const availableModels = getRegistryModels(auth).filter((d) => !d.deprecated); + const availableModels = await getAvailableRegistryModels(auth); if (!availableModels.length) return; const preferredModelId = @@ -7035,6 +7260,8 @@ export function createAgentChatService(args: { let assistantText = ""; let usage: { inputTokens?: number | null; outputTokens?: number | null } | undefined; + let finalAssistantText = ""; + let autoContinuationCount = 0; let streamedStepCount = 0; const turnStartedAt = Date.now(); let firstStreamEventLogged = false; @@ -7085,30 +7312,11 @@ export function createAgentChatService(args: { applyReconstructionContextToStreamingRuntime(managed, runtime); runtime.messages.push({ role: "user", content: userContent }); + const persistedTurnUserMessageIndex = runtime.messages.length - 1; const abortController = new AbortController(); runtime.abortController = abortController; - const streamMessages = runtime.messages.map((message, index): ModelMessage => { - const isCurrentUserMessage = index === runtime.messages.length - 1 && message.role === "user"; - if (!isCurrentUserMessage) { - return { - role: message.role as "user" | "assistant", - content: message.content, - }; - } - - return { - role: "user", - content: buildStreamingUserContent({ - baseText: streamingBaseText, - attachments: resolvedAttachments, - runtimeKind: "unified", - modelDescriptor: runtime.modelDescriptor, - logger, - }), - }; - }); const lightweight = isLightweightSession(managed.session); const executionLaneId = resolveManagedExecutionLaneId(managed); const tools = lightweight @@ -7224,47 +7432,39 @@ export function createAgentChatService(args: { : "User denied the action.", }; }, - onAskUser: async (question) => { - const askItemId = randomUUID(); - const request: PendingInputRequest = { - requestId: askItemId, - itemId: askItemId, + onAskUser: async (input: AskUserToolInput) => { + const normalizedQuestion = typeof input.question === "string" ? input.question.trim() : ""; + const normalizedBody = typeof input.body === "string" ? input.body.trim() : ""; + const response = await requestChatInput({ + chatSessionId: managed.session.id, + title: typeof input.title === "string" && input.title.trim().length + ? input.title.trim() + : "Question from agent", + body: normalizedBody || normalizedQuestion, + questions: input.questions, source: "unified", - kind: "question", - description: question, - questions: [ - { - id: "response", - header: "Question", - question, - allowsFreeform: true, - }, - ], - allowsFreeform: true, - blocking: true, - canProceedWithoutAnswer: false, providerMetadata: { tool: "askUser", - inputType: "text", + inputType: input.questions?.length ? "structured" : "text", + }, + eventDescription: normalizedBody || normalizedQuestion, + eventDetail: { + tool: "askUser", + ...(normalizedQuestion ? { question: normalizedQuestion } : {}), + inputType: input.questions?.length ? "structured" : "text", }, - turnId, - }; - emitPendingInputRequest(managed, request, { - kind: "tool_call", - description: question, - detail: { tool: "askUser", question, inputType: "text" }, - }); - - const response = await new Promise<{ decision?: AgentChatApprovalDecision; responseText?: string | null; answers?: Record }>((resolve) => { - runtime.pendingApprovals.set(askItemId, { category: "askUser", request, resolve }); }); - runtime.pendingApprovals.delete(askItemId); - const normalizedAnswers = normalizePendingInputAnswers(request, response.answers, response.responseText); - const answer = normalizedAnswers.response?.[0] ?? ""; - if (answer.length) return answer; - if (response.decision === "accept") return "yes"; - if (response.decision === "decline") return "no"; - return String(response.decision); + const primaryQuestionId = input.questions?.[0]?.id ?? "response"; + const answer = response.answers[primaryQuestionId]?.[0] + ?? Object.values(response.answers)[0]?.[0] + ?? response.responseText + ?? ""; + return { + answer, + answers: response.answers, + responseText: response.responseText, + decision: response.decision, + }; }, }); @@ -7448,186 +7648,230 @@ export function createAgentChatService(args: { return baseHarnessPrompt; })(); - const stream = streamText({ - model: runtime.resolvedModel, - ...(harnessPrompt ? { system: harnessPrompt } : {}), - messages: streamMessages, - ...(Object.keys(tools).length ? { tools } : {}), - providerOptions: providerOptions as any, - ...(!lightweight ? { stopWhen: stepCountIs(20) } : {}), - abortSignal: abortController.signal, - onError({ error }) { - logger.warn("agent_chat.unified_stream_error", { - sessionId: managed.session.id, - error: error instanceof Error ? error.message : String(error), - }); - }, - }); + let shouldAutoContinue = false; + const baseTurnMessages = runtime.messages.slice(); + do { + assistantText = ""; + let iterationUsage: { inputTokens?: number | null; outputTokens?: number | null } | undefined; + let iterationToolCallCount = 0; + let iterationToolResultCount = 0; + shouldAutoContinue = false; + + const currentIterationMessages = baseTurnMessages; + const streamMessages = buildUnifiedStreamMessages({ + messages: currentIterationMessages, + persistedTurnUserMessageIndex, + resolvedAttachments, + modelDescriptor: runtime.modelDescriptor, + logger, + }); - // ── Stream processing loop ── - const streamSupportsReasoning = runtime.modelDescriptor.capabilities.reasoning; - for await (const part of stream.fullStream as AsyncIterable) { - if (runtime.interrupted) break; - if (!part || typeof part !== "object") continue; - markFirstStreamEvent(String(part.type ?? "unknown")); + const stream = streamText({ + model: runtime.resolvedModel, + ...(harnessPrompt ? { system: harnessPrompt } : {}), + messages: streamMessages, + ...(Object.keys(tools).length ? { tools } : {}), + providerOptions: providerOptions as any, + ...(!lightweight ? { stopWhen: stepCountIs(20) } : {}), + abortSignal: abortController.signal, + onError({ error }) { + logger.warn("agent_chat.unified_stream_error", { + sessionId: managed.session.id, + error: error instanceof Error ? error.message : String(error), + }); + }, + }); - if (part.type === "start-step") { - streamedStepCount += 1; - emitChatEvent(managed, { - type: "step_boundary", - stepNumber: typeof part.stepNumber === "number" ? part.stepNumber + 1 : streamedStepCount, - turnId, - }); - if (!streamSupportsReasoning && streamedStepCount === 1) { + // ── Stream processing loop ── + const streamSupportsReasoning = runtime.modelDescriptor.capabilities.reasoning; + for await (const part of stream.fullStream as AsyncIterable) { + if (runtime.interrupted) break; + if (!part || typeof part !== "object") continue; + markFirstStreamEvent(String(part.type ?? "unknown")); + + if (part.type === "start-step") { + streamedStepCount += 1; emitChatEvent(managed, { - type: "activity", - activity: "working", - detail: WORKING_ACTIVITY_DETAIL, + type: "step_boundary", + stepNumber: typeof part.stepNumber === "number" ? part.stepNumber + 1 : streamedStepCount, turnId, }); + if (!streamSupportsReasoning && streamedStepCount === 1) { + emitChatEvent(managed, { + type: "activity", + activity: "working", + detail: WORKING_ACTIVITY_DETAIL, + turnId, + }); + } + continue; } - continue; - } - if (part.type === "source") { - emitChatEvent(managed, { - type: "activity", - activity: "searching", - detail: - typeof part.title === "string" && part.title.trim().length - ? part.title - : typeof part.url === "string" && part.url.trim().length - ? part.url - : "Gathering sources", - turnId, - }); - continue; - } + if (part.type === "source") { + emitChatEvent(managed, { + type: "activity", + activity: "searching", + detail: + typeof part.title === "string" && part.title.trim().length + ? part.title + : typeof part.url === "string" && part.url.trim().length + ? part.url + : "Gathering sources", + turnId, + }); + continue; + } - if (part.type === "text-delta") { - const delta = String(part.text ?? part.textDelta ?? ""); - if (!delta.length) continue; - assistantText += delta; - emitChatEvent(managed, { - type: "text", - text: delta, - turnId, - itemId: typeof part.id === "string" ? part.id : undefined - }); - continue; - } + if (part.type === "text-delta") { + const delta = String(part.text ?? part.textDelta ?? ""); + if (!delta.length) continue; + assistantText += delta; + emitChatEvent(managed, { + type: "text", + text: delta, + turnId, + itemId: typeof part.id === "string" ? part.id : undefined + }); + continue; + } - if (part.type === "reasoning-start") { - emitChatEvent(managed, { - type: "activity", - activity: streamSupportsReasoning ? "thinking" : "working", - detail: streamSupportsReasoning ? REASONING_ACTIVITY_DETAIL : WORKING_ACTIVITY_DETAIL, - turnId, - }); - continue; - } + if (part.type === "reasoning-start") { + emitChatEvent(managed, { + type: "activity", + activity: streamSupportsReasoning ? "thinking" : "working", + detail: streamSupportsReasoning ? REASONING_ACTIVITY_DETAIL : WORKING_ACTIVITY_DETAIL, + turnId, + }); + continue; + } - if (part.type === "reasoning" || part.type === "reasoning-delta") { - const delta = String(part.text ?? part.textDelta ?? part.delta ?? ""); - if (!delta.length) continue; - if (!streamSupportsReasoning) { + if (part.type === "reasoning" || part.type === "reasoning-delta") { + const delta = String(part.text ?? part.textDelta ?? part.delta ?? ""); + if (!delta.length) continue; + if (!streamSupportsReasoning) { + emitChatEvent(managed, { + type: "activity", + activity: "working", + detail: WORKING_ACTIVITY_DETAIL, + turnId, + }); + continue; + } emitChatEvent(managed, { type: "activity", - activity: "working", - detail: WORKING_ACTIVITY_DETAIL, + activity: "thinking", + detail: REASONING_ACTIVITY_DETAIL, + turnId, + }); + emitChatEvent(managed, { + type: "reasoning", + text: delta, turnId, + itemId: typeof part.id === "string" ? part.id : undefined }); continue; } - emitChatEvent(managed, { - type: "activity", - activity: "thinking", - detail: REASONING_ACTIVITY_DETAIL, - turnId, - }); - emitChatEvent(managed, { - type: "reasoning", - text: delta, - turnId, - itemId: typeof part.id === "string" ? part.id : undefined - }); - continue; - } - if (part.type === "reasoning-end") { - flushBufferedReasoning(managed); - continue; - } + if (part.type === "reasoning-end") { + flushBufferedReasoning(managed); + continue; + } - if (part.type === "tool-call") { - const nextActivity = activityForToolName(String(part.toolName ?? "tool")); - const parentItemId = readProviderParentItemId((part as { providerMetadata?: unknown }).providerMetadata); - emitChatEvent(managed, { - type: "activity", - activity: nextActivity.activity, - detail: nextActivity.detail, - turnId, - }); - emitChatEvent(managed, { - type: "tool_call", - tool: String(part.toolName ?? "tool"), - args: part.input ?? part.args ?? part.arguments, - itemId: String(part.toolCallId ?? randomUUID()), - ...(parentItemId ? { parentItemId } : {}), - turnId - }); - continue; - } + if (part.type === "tool-call") { + iterationToolCallCount += 1; + const nextActivity = activityForToolName(String(part.toolName ?? "tool")); + const parentItemId = readProviderParentItemId((part as { providerMetadata?: unknown }).providerMetadata); + emitChatEvent(managed, { + type: "activity", + activity: nextActivity.activity, + detail: nextActivity.detail, + turnId, + }); + emitChatEvent(managed, { + type: "tool_call", + tool: String(part.toolName ?? "tool"), + args: part.input ?? part.args ?? part.arguments, + itemId: String(part.toolCallId ?? randomUUID()), + ...(parentItemId ? { parentItemId } : {}), + turnId + }); + continue; + } - if (part.type === "tool-result") { - const parentItemId = readProviderParentItemId((part as { providerMetadata?: unknown }).providerMetadata); - emitChatEvent(managed, { - type: "tool_result", - tool: String(part.toolName ?? "tool"), - result: part.output ?? part.result, - itemId: String(part.toolCallId ?? randomUUID()), - ...(parentItemId ? { parentItemId } : {}), - turnId, - status: part.preliminary ? "running" : "completed" - }); - continue; - } + if (part.type === "tool-result") { + iterationToolResultCount += 1; + const parentItemId = readProviderParentItemId((part as { providerMetadata?: unknown }).providerMetadata); + emitChatEvent(managed, { + type: "tool_result", + tool: String(part.toolName ?? "tool"), + result: part.output ?? part.result, + itemId: String(part.toolCallId ?? randomUUID()), + ...(parentItemId ? { parentItemId } : {}), + turnId, + status: part.preliminary ? "running" : "completed" + }); + continue; + } - if (part.type === "tool-error") { - emitChatEvent(managed, { - type: "error", - message: `Tool '${String(part.toolName ?? "tool")}' failed: ${String(part.error ?? "unknown error")}`, - turnId, - itemId: String(part.toolCallId ?? randomUUID()) - }); - continue; - } + if (part.type === "tool-error") { + emitChatEvent(managed, { + type: "error", + message: `Tool '${String(part.toolName ?? "tool")}' failed: ${String(part.error ?? "unknown error")}`, + turnId, + itemId: String(part.toolCallId ?? randomUUID()) + }); + continue; + } - if (part.type === "tool-approval-request") { - const toolName = String(part.toolCall?.toolName ?? "tool"); - emitChatEvent(managed, { - type: "error", - message: isPlanningApprovalGuarded(managed) - ? buildPlanningApprovalViolation(toolName) - : `Unexpected SDK approval request for '${toolName}'. This tool should use ADE-managed approvals instead.`, - turnId - }); - continue; - } + if (part.type === "tool-approval-request") { + const toolName = String(part.toolCall?.toolName ?? "tool"); + emitChatEvent(managed, { + type: "error", + message: isPlanningApprovalGuarded(managed) + ? buildPlanningApprovalViolation(toolName) + : `Unexpected SDK approval request for '${toolName}'. This tool should use ADE-managed approvals instead.`, + turnId + }); + continue; + } - if (part.type === "finish") { - usage = normalizeUsagePayload(part.totalUsage ?? part.usage); - continue; + if (part.type === "finish") { + iterationUsage = normalizeUsagePayload(part.totalUsage ?? part.usage); + continue; + } + + if (part.type === "error") { + emitChatEvent(managed, { + type: "error", + message: String(part.error ?? "Stream error."), + turnId + }); + } } - if (part.type === "error") { - emitChatEvent(managed, { - type: "error", - message: String(part.error ?? "Stream error."), - turnId + usage = mergeUsagePayloads(usage, iterationUsage); + finalAssistantText += assistantText; + shouldAutoContinue = shouldAutoContinueUnifiedLocalTurn({ + modelDescriptor: runtime.modelDescriptor, + permissionMode: runtime.permissionMode, + assistantText, + toolCallCount: iterationToolCallCount, + toolResultCount: iterationToolResultCount, + continuationCount: autoContinuationCount, + pendingApprovalCount: runtime.pendingApprovals.size, + }); + if (shouldAutoContinue && !runtime.interrupted) { + logger.info("agent_chat.unified_local_auto_continue", { + sessionId: managed.session.id, + turnId, + modelId: runtime.modelDescriptor.id, + continuationCount: autoContinuationCount + 1, }); + baseTurnMessages.push({ role: "assistant", content: assistantText }); + baseTurnMessages.push({ role: "user", content: UNIFIED_LOCAL_AUTO_CONTINUE_PROMPT }); + autoContinuationCount += 1; } - } + } while (shouldAutoContinue && !runtime.interrupted); // ── Shared turn completion ── persistDeliveredLaneDirectiveKey(managed, args.laneDirectiveKey); @@ -7646,8 +7890,8 @@ export function createAgentChatService(args: { }); persistChatState(managed); } else { - if (assistantText.trim().length) { - runtime.messages.push({ role: "assistant", content: assistantText }); + if (finalAssistantText.trim().length) { + runtime.messages.push({ role: "assistant", content: finalAssistantText }); } runtime.busy = false; @@ -7665,10 +7909,10 @@ export function createAgentChatService(args: { ...(usage ? { usage } : {}) }); - if (assistantText.trim().length > 0) { + if (finalAssistantText.trim().length > 0) { appendWorkerActivityToCto(managed, { activityType: "chat_turn", - summary: assistantText, + summary: finalAssistantText, }); } @@ -10135,7 +10379,7 @@ export function createAgentChatService(args: { codexApprovalPolicy: requestedCodexApprovalPolicy, codexSandbox: requestedCodexSandbox, codexConfigSource: requestedCodexConfigSource, - unifiedPermissionMode: requestedUnifiedPermissionMode, + unifiedPermissionMode: requestedUnifiedPermissionModeArg, cursorModeId: requestedCursorModeId, cursorConfigValues: requestedCursorConfigValues, permissionMode: requestedPermMode, @@ -10215,10 +10459,18 @@ export function createAgentChatService(args: { const normalizedCursorConfigValues = normalizeCursorConfigValueRecord(requestedCursorConfigValues); const capabilityMode = inferCapabilityMode(effectiveProvider); const computerUsePolicy = normalizeComputerUsePolicy(computerUse, createDefaultComputerUsePolicy()); - const effectivePermissionMode = identityKey + let effectivePermissionMode = identityKey ? normalizeIdentityPermissionMode(requestedPermMode, effectiveProvider) : requestedPermMode; const chatConfig = resolveChatConfig(); + let requestedUnifiedPermissionMode = requestedUnifiedPermissionModeArg; + const localHarnessPermissions = applyLocalHarnessPermissionMode({ + descriptor: resolvedDescriptor, + requestedPermissionMode: effectivePermissionMode, + requestedUnifiedPermissionMode, + }); + effectivePermissionMode = localHarnessPermissions.requestedPermissionMode; + requestedUnifiedPermissionMode = localHarnessPermissions.requestedUnifiedPermissionMode; const nativePermissionFields = (() => { if (effectiveProvider === "claude") { @@ -12097,6 +12349,7 @@ export function createAgentChatService(args: { await ensureCursorRuntime(managed); managed.session.unifiedPermissionMode = persisted?.unifiedPermissionMode ?? managed.session.unifiedPermissionMode; managed.session.permissionMode = syncLegacyPermissionMode(managed.session) ?? managed.session.permissionMode; + enforceManagedLocalHarnessPermissionMode(managed); sessionService.setResumeCommand(sessionId, `chat:cursor:${sessionId}`); } else if (managed.runtime?.kind === "unified" || (managed.session.modelId && !providerResolver.isModelCliWrapped(managed.session.modelId))) { // Unified runtime resume — re-resolve the model @@ -12109,6 +12362,7 @@ export function createAgentChatService(args: { } managed.session.unifiedPermissionMode = persisted?.unifiedPermissionMode ?? managed.session.unifiedPermissionMode; managed.session.permissionMode = syncLegacyPermissionMode(managed.session) ?? managed.session.permissionMode; + enforceManagedLocalHarnessPermissionMode(managed, managed.runtime.modelDescriptor); managed.runtime.permissionMode = resolveSessionUnifiedPermissionMode( managed.session, resolveChatConfig().unifiedPermissionMode, @@ -12303,6 +12557,7 @@ export function createAgentChatService(args: { managed.session.provider, ); applyLegacyPermissionModeToNativeControls(managed.session, managed.session.permissionMode); + enforceManagedLocalHarnessPermissionMode(managed); normalizeSessionNativePermissionControls(managed.session, resolveChatConfig()); managed.selectedExecutionLaneId = selectedExecutionLaneId ?? managed.selectedExecutionLaneId; refreshReconstructionContext(managed, { includeConversationTail: usesIdentityContinuity(managed) }); @@ -12574,6 +12829,11 @@ export function createAgentChatService(args: { managed.session.unifiedPermissionMode = "edit"; } managed.session.permissionMode = syncLegacyPermissionMode(managed.session) ?? managed.session.permissionMode; + enforceManagedLocalHarnessPermissionMode(managed, managed.runtime.modelDescriptor); + managed.runtime.permissionMode = resolveSessionUnifiedPermissionMode( + managed.session, + resolveChatConfig().unifiedPermissionMode, + ); managed.runtime.pendingApprovals.delete(itemId); pending.resolve({ decision: resolvedDecision, answers, responseText }); emitPendingInputResolved(managed, { @@ -12673,7 +12933,7 @@ export function createAgentChatService(args: { // For unified/non-CLI providers: return all models with valid auth. try { const auth = await detectAuth(); - const available = getRegistryModels(auth); + const available = await getAvailableRegistryModels(auth); const targetModels = provider === "unified" ? available : available.filter(m => m.family === provider); @@ -12845,6 +13105,7 @@ export function createAgentChatService(args: { ); applyLegacyPermissionModeToNativeControls(managed.session, managed.session.permissionMode); } + enforceManagedLocalHarnessPermissionMode(managed, descriptor); normalizeSessionNativePermissionControls(managed.session, chatConfig); // Apply reasoningEffort BEFORE pre-warming so the V2 session is created @@ -12951,6 +13212,7 @@ export function createAgentChatService(args: { || cursorModeId !== undefined || cursorConfigValues !== undefined ) { + enforceManagedLocalHarnessPermissionMode(managed); normalizeSessionNativePermissionControls(managed.session, chatConfig); if (managed.runtime?.kind === "unified") { managed.runtime.permissionMode = resolveSessionUnifiedPermissionMode( @@ -13303,6 +13565,10 @@ export function createAgentChatService(args: { chatSessionId: string; title: string; body: string; + source?: PendingInputSource; + providerMetadata?: Record; + eventDescription?: string; + eventDetail?: Record; questions?: Array<{ id?: string; header?: string; @@ -13425,7 +13691,7 @@ export function createAgentChatService(args: { const request: PendingInputRequest = { requestId: itemId, itemId, - source: "ade", + source: args.source ?? "ade", kind: questions.some((q) => q.options?.length) ? "structured_question" : "question", title: args.title, description: questions[0]?.question ?? args.body, @@ -13433,6 +13699,7 @@ export function createAgentChatService(args: { allowsFreeform: true, blocking: true, canProceedWithoutAnswer: false, + ...(args.providerMetadata ? { providerMetadata: args.providerMetadata } : {}), turnId: managed.runtime?.activeTurnId ?? null, }; @@ -13444,7 +13711,8 @@ export function createAgentChatService(args: { managed.localPendingInputs.set(itemId, { request, resolve }); emitPendingInputRequest(managed, request, { kind: "tool_call", - description: request.description ?? args.body, + description: args.eventDescription ?? request.description ?? args.body, + ...(args.eventDetail !== undefined ? { detail: args.eventDetail } : {}), }); }); @@ -13488,18 +13756,24 @@ export function createAgentChatService(args: { try { const projectRoot = args.projectRoot; if (!projectRoot) return; - const attachDir = path.join(projectRoot, ".ade", "attachments"); - if (!fs.existsSync(attachDir)) return; - const cutoff = Date.now() - 7 * 24 * 60 * 60 * 1000; // 7 days - for (const entry of fs.readdirSync(attachDir)) { - try { - const filePath = path.join(attachDir, entry); - const stat = fs.statSync(filePath); - if (stat.isFile() && stat.mtimeMs < cutoff) { - fs.unlinkSync(filePath); + const cleanupDir = (dirPath: string) => { + if (!fs.existsSync(dirPath)) return; + const cutoff = Date.now() - 7 * 24 * 60 * 60 * 1000; + for (const entry of fs.readdirSync(dirPath)) { + try { + const filePath = path.join(dirPath, entry); + const stat = fs.statSync(filePath); + if (stat.mtimeMs < cutoff) { + fs.rmSync(filePath, { recursive: true, force: true }); + } + } catch { + // Best-effort cleanup only. } - } catch { /* skip */ } - } + } + }; + + cleanupDir(path.join(projectRoot, ".ade", "attachments")); + cleanupDir(path.join(resolveAdeLayout(projectRoot).tmpDir, "agent-chat-attachments")); } catch { /* ignore */ } }, setComputerUseArtifactBrokerService(svc: ComputerUseArtifactBrokerService) { diff --git a/apps/desktop/src/main/services/config/projectConfigService.ts b/apps/desktop/src/main/services/config/projectConfigService.ts index 1e4a91cc5..2893ec8c7 100644 --- a/apps/desktop/src/main/services/config/projectConfigService.ts +++ b/apps/desktop/src/main/services/config/projectConfigService.ts @@ -1023,6 +1023,34 @@ function coerceAiTaskRoutingRule(value: unknown): AiTaskRoutingRule | null { return Object.keys(out).length ? out : null; } +function coerceAiLocalProviders(value: unknown): AiConfig["localProviders"] { + if (!isRecord(value)) return undefined; + + const providers: NonNullable = {}; + for (const provider of ["ollama", "lmstudio", "vllm"] as const) { + const raw = isRecord(value[provider]) ? value[provider] : null; + if (!raw) continue; + + const entry: NonNullable[typeof provider]> = {}; + const enabled = asBool(raw.enabled); + if (enabled != null) entry.enabled = enabled; + const endpoint = asString(raw.endpoint)?.trim(); + if (endpoint) entry.endpoint = endpoint; + const autoDetect = asBool(raw.autoDetect); + if (autoDetect != null) entry.autoDetect = autoDetect; + if (raw.preferredModelId === null) { + entry.preferredModelId = null; + } else { + const preferredModelId = asString(raw.preferredModelId)?.trim(); + if (preferredModelId) entry.preferredModelId = preferredModelId; + } + + if (Object.keys(entry).length) providers[provider] = entry; + } + + return Object.keys(providers).length ? providers : undefined; +} + function coerceAiConfig(value: unknown): AiConfig | undefined { if (!isRecord(value)) return undefined; @@ -1238,6 +1266,9 @@ function coerceAiConfig(value: unknown): AiConfig | undefined { const apiKeys = asStringMap(value.apiKeys); if (apiKeys && Object.keys(apiKeys).length) out.apiKeys = apiKeys; + const localProviders = coerceAiLocalProviders(value.localProviders); + if (localProviders) out.localProviders = localProviders; + const workerSafety = coerceWorkerSafetyPolicy(value.workerSafety); if (workerSafety) out.workerSafety = workerSafety; @@ -1554,6 +1585,16 @@ export function mergeAiConfig(sharedAi?: AiConfig, localAi?: Partial): ...(sharedAi?.apiKeys ?? {}), ...(localAi?.apiKeys ?? {}) }; + const localProvidersEntries = (["ollama", "lmstudio", "vllm"] as const) + .map((provider) => { + const mergedProvider = { + ...(sharedAi?.localProviders?.[provider] ?? {}), + ...(localAi?.localProviders?.[provider] ?? {}), + }; + return Object.keys(mergedProvider).length ? [provider, mergedProvider] as const : null; + }) + .filter((entry): entry is readonly ["ollama" | "lmstudio" | "vllm", Record] => entry != null); + const localProviders = Object.fromEntries(localProvidersEntries) as AiConfig["localProviders"]; const workerSafety = mergeWorkerSafetyPolicy(sharedAi?.workerSafety, localAi?.workerSafety); const mcpServers = { ...(sharedAi?.mcpServers ?? {}), @@ -1572,6 +1613,7 @@ export function mergeAiConfig(sharedAi?: AiConfig, localAi?: Partial): ...(Object.keys(chat).length ? { chat } : {}), ...(Object.keys(featureModelOverrides).length ? { featureModelOverrides } : {}), ...(Object.keys(apiKeys).length ? { apiKeys } : {}), + ...(localProvidersEntries.length ? { localProviders } : {}), ...(workerSafety ? { workerSafety } : {}), ...(Object.keys(mcpServers).length ? { mcpServers } : {}) }; diff --git a/apps/desktop/src/main/services/ipc/registerIpc.ts b/apps/desktop/src/main/services/ipc/registerIpc.ts index 29c99d9ae..9dcf7903f 100644 --- a/apps/desktop/src/main/services/ipc/registerIpc.ts +++ b/apps/desktop/src/main/services/ipc/registerIpc.ts @@ -762,6 +762,7 @@ function getUnavailableAiStatus(): AiSettingsStatus { dailyUsage: 0, dailyLimit: null, })), + runtimeConnections: {}, availableModelIds: [], }; } @@ -1978,6 +1979,7 @@ export function registerIpc({ models: status.models, detectedAuth: status.detectedAuth, providerConnections: status.providerConnections, + runtimeConnections: status.runtimeConnections, availableModelIds: status.availableModelIds, apiKeyStore: status.apiKeyStore, features: AI_USAGE_FEATURE_KEYS.map((feature) => ({ diff --git a/apps/desktop/src/main/services/orchestrator/unifiedOrchestratorAdapter.test.ts b/apps/desktop/src/main/services/orchestrator/unifiedOrchestratorAdapter.test.ts index 57841c4b0..4e1816aa8 100644 --- a/apps/desktop/src/main/services/orchestrator/unifiedOrchestratorAdapter.test.ts +++ b/apps/desktop/src/main/services/orchestrator/unifiedOrchestratorAdapter.test.ts @@ -264,6 +264,125 @@ describe("createUnifiedOrchestratorAdapter", () => { }); }); + it("routes local unified models through the managed chat path instead of the shell fallback", async () => { + const workspaceRoot = fs.mkdtempSync(path.join(os.tmpdir(), "ade-unified-local-managed-")); + const createSession = vi.fn(async () => ({ id: "session-managed-local-1" })); + const adapter = createUnifiedOrchestratorAdapter({ + workspaceRoot, + runtimeRoot: path.join(workspaceRoot, "runtime"), + agentChatService: { + createSession, + } as any, + }); + + const result = await adapter.start({ + run: { + id: "run-1", + missionId: "mission-1", + metadata: { + missionGoal: "Implement the worker step", + }, + } as any, + step: { + id: "step-1", + title: "Local implementation worker", + stepKey: "local-implementation-worker", + laneId: "lane-1", + metadata: { + modelId: "lmstudio/meta-llama-3.1-70b-instruct", + stepType: "implementation", + }, + dependencyStepIds: [], + joinPolicy: "all_success", + } as any, + attempt: { id: "attempt-1" } as any, + allSteps: [], + contextProfile: {} as any, + laneExport: null, + projectExport: { content: "" } as any, + docsRefs: [], + fullDocs: [], + permissionConfig: { + _providers: { + unified: "full-auto", + }, + } as any, + createTrackedSession: vi.fn(), + } as any); + + expect(createSession).toHaveBeenCalledWith(expect.objectContaining({ + laneId: "lane-1", + provider: "unified", + model: "lmstudio/meta-llama-3.1-70b-instruct", + modelId: "lmstudio/meta-llama-3.1-70b-instruct", + })); + expect(result).toMatchObject({ + status: "accepted", + sessionId: "session-managed-local-1", + metadata: expect.objectContaining({ + workerSessionKind: "managed_chat", + workerStreamSource: "agent_chat", + startupCommandPreview: "[managed chat session]", + }), + }); + }); + + it("starts guarded local worker sessions in plan mode by default", async () => { + const workspaceRoot = fs.mkdtempSync(path.join(os.tmpdir(), "ade-unified-local-guarded-")); + const createSession = vi.fn(async () => ({ id: "session-managed-local-plan" })); + const adapter = createUnifiedOrchestratorAdapter({ + workspaceRoot, + runtimeRoot: path.join(workspaceRoot, "runtime"), + agentChatService: { + createSession, + } as any, + }); + + const result = await adapter.start({ + run: { + id: "run-1", + missionId: "mission-1", + metadata: { missionGoal: "Investigate the local worker path" }, + } as any, + step: { + id: "step-1", + title: "Guarded local worker", + stepKey: "guarded-local-worker", + laneId: "lane-1", + metadata: { + modelId: "lmstudio/auto", + stepType: "implementation", + }, + dependencyStepIds: [], + joinPolicy: "all_success", + } as any, + attempt: { id: "attempt-1" } as any, + allSteps: [], + contextProfile: {} as any, + laneExport: null, + projectExport: { content: "" } as any, + docsRefs: [], + fullDocs: [], + permissionConfig: {} as any, + createTrackedSession: vi.fn(), + } as any); + + expect(createSession).toHaveBeenCalledWith(expect.objectContaining({ + provider: "unified", + modelId: "lmstudio/auto", + unifiedPermissionMode: "plan", + })); + expect(result).toMatchObject({ + status: "accepted", + launch: expect.objectContaining({ + permissionMode: "plan", + }), + metadata: expect.objectContaining({ + permissionMode: "plan", + }), + }); + }); + it("forces Codex planning steps into a read-only sandbox", async () => { const workspaceRoot = fs.mkdtempSync(path.join(os.tmpdir(), "ade-unified-codex-")); const adapter = createUnifiedOrchestratorAdapter({ @@ -528,6 +647,13 @@ describe("getUnifiedUnsupportedModelReason", () => { expect(getUnifiedUnsupportedModelReason("anthropic/claude-sonnet-4-6")).toBeNull(); }); + it("describes the shell fallback for API/local models instead of rejecting them globally", () => { + const reason = getUnifiedUnsupportedModelReason("lmstudio/meta-llama-3.1-70b-instruct"); + expect(reason).toContain("shell-startup fallback"); + expect(reason).toContain("in-process unified execution"); + expect(reason).toContain("managed unified chat path"); + }); + it("returns a not-registered message for unknown model refs", () => { const reason = getUnifiedUnsupportedModelReason("nonexistent/fantasy-model-99"); expect(reason).toBe("Model 'nonexistent/fantasy-model-99' is not registered."); diff --git a/apps/desktop/src/main/services/orchestrator/unifiedOrchestratorAdapter.ts b/apps/desktop/src/main/services/orchestrator/unifiedOrchestratorAdapter.ts index cf6b072ef..8b79f7659 100644 --- a/apps/desktop/src/main/services/orchestrator/unifiedOrchestratorAdapter.ts +++ b/apps/desktop/src/main/services/orchestrator/unifiedOrchestratorAdapter.ts @@ -9,6 +9,7 @@ import { resolveModelAlias, resolveModelDescriptor, resolveProviderGroupForModel, + type ModelDescriptor, } from "../../../shared/modelRegistry"; import type { AgentChatExecutionMode, @@ -81,7 +82,7 @@ export function getUnifiedUnsupportedModelReason(modelRef: string): string | nul const cliProvider = resolveCliProviderForModel(descriptor); if (cliProvider) return null; const executionPath = classifyWorkerExecutionPath(descriptor); - return `Model '${descriptor.id}' requires ${executionPath} execution (${descriptor.family}), but the unified worker adapter currently supports only Claude/Codex CLI models.`; + return `Model '${descriptor.id}' requires ${executionPath} in-process unified execution (${descriptor.family}), but the shell-startup fallback only supports Claude/Codex CLI models. Use the managed unified chat path for API and local models.`; } /** @@ -305,6 +306,7 @@ const VALID_PERMISSION_MODES = new Set(["default", "plan", "edit", "full function resolveManagedPermissionMode(args: { provider: "claude" | "codex" | "unified" | "cursor"; + descriptor?: ModelDescriptor; permissionConfig: LegacyPermissionConfig | undefined; readOnlyExecution: boolean; }): AgentChatPermissionMode | undefined { @@ -314,9 +316,14 @@ function resolveManagedPermissionMode(args: { args.provider === "cursor" ? ((providers?.cursor ?? providers?.unified) as string | undefined) : (providers?.[args.provider] as string | undefined); - return typeof candidate === "string" && VALID_PERMISSION_MODES.has(candidate) + const normalizedCandidate = typeof candidate === "string" && VALID_PERMISSION_MODES.has(candidate) ? candidate as AgentChatPermissionMode : undefined; + if (args.descriptor?.authTypes.includes("local")) { + if (args.descriptor.harnessProfile === "read_only") return "plan"; + if (args.descriptor.harnessProfile === "guarded") return "plan"; + } + return normalizedCandidate; } function mapPermissionModeToNativeFields( @@ -570,9 +577,10 @@ export function createUnifiedOrchestratorAdapter(options?: { return startup; } - // Non-CLI or unknown models cannot run via this shell-based adapter. + // Non-CLI or unknown models can still run via the managed chat path. + // This shell fallback only exists for CLI-wrapped workers. const unsupportedReason = getUnifiedUnsupportedModelReason(model) ?? `Model '${model}' is not supported by unified adapter.`; - const failureMessage = `[ADE] Unified orchestrator adapter currently supports CLI-wrapped Anthropic/OpenAI models only. ${unsupportedReason} Select a CLI model for this worker.`; + const failureMessage = `[ADE] Shell-startup fallback for the unified orchestrator adapter only supports CLI-wrapped Anthropic/OpenAI models. ${unsupportedReason}`; return `printf '%s\\n' ${shellEscapeArg(failureMessage)} >&2; exit 64`; }, @@ -657,6 +665,7 @@ export function createUnifiedOrchestratorAdapter(options?: { : undefined; const permissionMode = resolveManagedPermissionMode({ provider, + descriptor, permissionConfig: effectivePermissionConfig, readOnlyExecution, }); diff --git a/apps/desktop/src/main/services/runtime/tempCleanupService.test.ts b/apps/desktop/src/main/services/runtime/tempCleanupService.test.ts new file mode 100644 index 000000000..7c5ac9461 --- /dev/null +++ b/apps/desktop/src/main/services/runtime/tempCleanupService.test.ts @@ -0,0 +1,105 @@ +import fs from "node:fs"; +import os from "node:os"; +import path from "node:path"; +import { afterEach, describe, expect, it, vi } from "vitest"; +import { cleanupStaleTempArtifacts } from "./tempCleanupService"; + +const tempRoots: string[] = []; +const NOW_MS = Date.parse("2026-04-02T12:00:00.000Z"); +const DAY_MS = 24 * 60 * 60 * 1000; + +function createTempRoot(): string { + const root = fs.mkdtempSync(path.join(os.tmpdir(), "ade-temp-cleanup-test-")); + tempRoots.push(root); + return root; +} + +function touchMtime(targetPath: string, mtimeMs: number): void { + const stamp = new Date(mtimeMs); + fs.utimesSync(targetPath, stamp, stamp); +} + +afterEach(() => { + while (tempRoots.length > 0) { + const root = tempRoots.pop(); + if (root) { + fs.rmSync(root, { recursive: true, force: true }); + } + } +}); + +describe("cleanupStaleTempArtifacts", () => { + it("removes stale ADE ShipIt caches but keeps recent ones", () => { + const tempRoot = createTempRoot(); + const stalePath = path.join(tempRoot, "com.ade.desktop.ShipIt.stale"); + const freshPath = path.join(tempRoot, "com.ade.desktop.ShipIt.fresh"); + fs.mkdirSync(stalePath, { recursive: true }); + fs.mkdirSync(freshPath, { recursive: true }); + touchMtime(stalePath, NOW_MS - (8 * DAY_MS)); + touchMtime(freshPath, NOW_MS - (2 * DAY_MS)); + + cleanupStaleTempArtifacts({ + tempRoot, + nowMs: NOW_MS, + logger: { info: vi.fn(), warn: vi.fn() }, + }); + + expect(fs.existsSync(stalePath)).toBe(false); + expect(fs.existsSync(freshPath)).toBe(true); + }); + + it("removes stale screenshot temp directories", () => { + const tempRoot = createTempRoot(); + const stalePath = path.join(tempRoot, "ade-screenshot-old"); + const freshPath = path.join(tempRoot, "ade-screenshot-new"); + fs.mkdirSync(stalePath, { recursive: true }); + fs.mkdirSync(freshPath, { recursive: true }); + touchMtime(stalePath, NOW_MS - (2 * DAY_MS)); + touchMtime(freshPath, NOW_MS - (6 * 60 * 60 * 1000)); + + cleanupStaleTempArtifacts({ + tempRoot, + nowMs: NOW_MS, + logger: { info: vi.fn(), warn: vi.fn() }, + }); + + expect(fs.existsSync(stalePath)).toBe(false); + expect(fs.existsSync(freshPath)).toBe(true); + }); + + it("prunes stale fallback attachments and removes the directory when it becomes empty", () => { + const tempRoot = createTempRoot(); + const attachmentsDir = path.join(tempRoot, "ade-attachments"); + const staleFile = path.join(attachmentsDir, "stale.png"); + fs.mkdirSync(attachmentsDir, { recursive: true }); + fs.writeFileSync(staleFile, "stale", "utf8"); + touchMtime(staleFile, NOW_MS - (8 * DAY_MS)); + + cleanupStaleTempArtifacts({ + tempRoot, + nowMs: NOW_MS, + logger: { info: vi.fn(), warn: vi.fn() }, + }); + + expect(fs.existsSync(staleFile)).toBe(false); + expect(fs.existsSync(attachmentsDir)).toBe(false); + }); + + it("keeps recent fallback attachments", () => { + const tempRoot = createTempRoot(); + const attachmentsDir = path.join(tempRoot, "ade-attachments"); + const freshFile = path.join(attachmentsDir, "fresh.png"); + fs.mkdirSync(attachmentsDir, { recursive: true }); + fs.writeFileSync(freshFile, "fresh", "utf8"); + touchMtime(freshFile, NOW_MS - (2 * DAY_MS)); + + cleanupStaleTempArtifacts({ + tempRoot, + nowMs: NOW_MS, + logger: { info: vi.fn(), warn: vi.fn() }, + }); + + expect(fs.existsSync(freshFile)).toBe(true); + expect(fs.existsSync(attachmentsDir)).toBe(true); + }); +}); diff --git a/apps/desktop/src/main/services/runtime/tempCleanupService.ts b/apps/desktop/src/main/services/runtime/tempCleanupService.ts new file mode 100644 index 000000000..5923b623f --- /dev/null +++ b/apps/desktop/src/main/services/runtime/tempCleanupService.ts @@ -0,0 +1,129 @@ +import fs from "node:fs"; +import path from "node:path"; +import type { Logger } from "../logging/logger"; + +const DAY_MS = 24 * 60 * 60 * 1000; +const SHIPIT_RETENTION_MS = 7 * DAY_MS; +const SCREENSHOT_RETENTION_MS = DAY_MS; +const ATTACHMENTS_RETENTION_MS = 7 * DAY_MS; + +type CleanupSummary = { + shipItEntriesRemoved: number; + screenshotEntriesRemoved: number; + attachmentEntriesRemoved: number; + attachmentDirRemoved: boolean; +}; + +function isOlderThan(targetPath: string, cutoffMs: number): boolean { + try { + return fs.statSync(targetPath).mtimeMs < cutoffMs; + } catch { + return false; + } +} + +function removePath(targetPath: string): boolean { + try { + fs.rmSync(targetPath, { recursive: true, force: true }); + return true; + } catch { + return false; + } +} + +function cleanupStaleAttachmentEntries(attachmentsDir: string, cutoffMs: number): { removedEntries: number; removedDir: boolean } { + let removedEntries = 0; + + let entries: fs.Dirent[] = []; + try { + entries = fs.readdirSync(attachmentsDir, { withFileTypes: true }); + } catch { + return { removedEntries, removedDir: false }; + } + + for (const entry of entries) { + const entryPath = path.join(attachmentsDir, entry.name); + if (!isOlderThan(entryPath, cutoffMs)) continue; + if (removePath(entryPath)) { + removedEntries += 1; + } + } + + let removedDir = false; + try { + if (fs.readdirSync(attachmentsDir).length === 0) { + removedDir = removePath(attachmentsDir); + } + } catch { + // Directory may already be gone. + } + + return { removedEntries, removedDir }; +} + +export function cleanupStaleTempArtifacts(args: { + tempRoot: string; + logger: Pick; + nowMs?: number; +}): void { + const nowMs = args.nowMs ?? Date.now(); + const summary: CleanupSummary = { + shipItEntriesRemoved: 0, + screenshotEntriesRemoved: 0, + attachmentEntriesRemoved: 0, + attachmentDirRemoved: false, + }; + + let entries: fs.Dirent[] = []; + try { + entries = fs.readdirSync(args.tempRoot, { withFileTypes: true }); + } catch (error) { + const code = (error as NodeJS.ErrnoException)?.code; + if (code !== "ENOENT") { + args.logger.warn("tempCleanup.scan_failed", { + tempRoot: args.tempRoot, + message: error instanceof Error ? error.message : String(error), + }); + } + return; + } + + for (const entry of entries) { + const entryPath = path.join(args.tempRoot, entry.name); + + if (entry.name.startsWith("com.ade.desktop.ShipIt.")) { + if (isOlderThan(entryPath, nowMs - SHIPIT_RETENTION_MS) && removePath(entryPath)) { + summary.shipItEntriesRemoved += 1; + } + continue; + } + + if (entry.name.startsWith("ade-screenshot-")) { + if (isOlderThan(entryPath, nowMs - SCREENSHOT_RETENTION_MS) && removePath(entryPath)) { + summary.screenshotEntriesRemoved += 1; + } + continue; + } + + if (entry.name === "ade-attachments" && entry.isDirectory()) { + const attachmentCleanup = cleanupStaleAttachmentEntries(entryPath, nowMs - ATTACHMENTS_RETENTION_MS); + summary.attachmentEntriesRemoved += attachmentCleanup.removedEntries; + summary.attachmentDirRemoved = summary.attachmentDirRemoved || attachmentCleanup.removedDir; + } + } + + if ( + summary.shipItEntriesRemoved > 0 + || summary.screenshotEntriesRemoved > 0 + || summary.attachmentEntriesRemoved > 0 + || summary.attachmentDirRemoved + ) { + args.logger.info("tempCleanup.removed_stale_entries", { + tempRoot: args.tempRoot, + shipItEntriesRemoved: summary.shipItEntriesRemoved, + screenshotEntriesRemoved: summary.screenshotEntriesRemoved, + attachmentEntriesRemoved: summary.attachmentEntriesRemoved, + attachmentDirRemoved: summary.attachmentDirRemoved, + }); + } +} diff --git a/apps/desktop/src/renderer/components/chat/AgentChatPane.tsx b/apps/desktop/src/renderer/components/chat/AgentChatPane.tsx index 8340d840f..b16c0daaa 100644 --- a/apps/desktop/src/renderer/components/chat/AgentChatPane.tsx +++ b/apps/desktop/src/renderer/components/chat/AgentChatPane.tsx @@ -15,6 +15,7 @@ import { type AgentChatFileRef, type AgentChatInteractionMode, type AiProviderConnectionStatus, + type AiRuntimeConnectionStatus, type AgentChatSession, type AgentChatUnifiedPermissionMode, type AgentChatSessionProfile, @@ -24,13 +25,20 @@ import { type AgentChatSessionSummary, type ComputerUseOwnerSnapshot, type ComputerUsePolicy, + type AiSettingsStatus, type TerminalToolType, } from "../../../shared/types"; import { parseAgentChatTranscript } from "../../../shared/chatTranscript"; import { + LOCAL_PROVIDER_LABELS, MODEL_REGISTRY, + getLocalModelIdTail, + getLocalProviderDefaultEndpoint, getModelById, + getModelDescriptorForPermissionMode, + parseLocalProviderFromModelId, resolveModelDescriptorForProvider, + type LocalProviderFamily, type ModelDescriptor, } from "../../../shared/modelRegistry"; import { filterChatModelIdsForSession } from "../../../shared/chatModelSwitching"; @@ -64,6 +72,82 @@ const LEGACY_MODEL_KEY_PREFIX = "ade.chat.lastModel"; const COMPUTER_USE_SNAPSHOT_COOLDOWN_MS = 750; +type AiStatusSnapshot = AiSettingsStatus & { + runtimeConnections?: Record; +}; + +function formatLocalModelLabel(modelId: string): string { + const provider = parseLocalProviderFromModelId(modelId); + if (!provider) { + return getModelById(modelId)?.displayName ?? modelId; + } + const tail = getLocalModelIdTail(modelId, provider); + return tail.length ? tail : modelId; +} + +function recommendedUnifiedPermissionModeForModel( + descriptor: ModelDescriptor | null | undefined, +): AgentChatUnifiedPermissionMode | null { + if (!descriptor?.authTypes.includes("local")) return null; + return descriptor.harnessProfile === "guarded" || descriptor.harnessProfile === "read_only" + ? "plan" + : null; +} + +function shouldResetUnifiedPermissionForModelSwitch( + previous: ModelDescriptor | null | undefined, + next: ModelDescriptor | null | undefined, +): boolean { + const prevRec = recommendedUnifiedPermissionModeForModel(previous); + const nextRec = recommendedUnifiedPermissionModeForModel(next); + if (prevRec == null && nextRec == null) return false; + return prevRec !== nextRec; +} + +type LocalRuntimeNoticeShape = { + tone: "success" | "warning"; + title: string; + message: string; +}; + +function LocalRuntimeNoticeBlock(props: { + notice: LocalRuntimeNoticeShape; + endpoint?: string | null; + /** `inline` = text only (inside a parent runtime card). */ + variant?: "card" | "inline"; +}) { + const { notice, endpoint, variant = "card" } = props; + const isCard = variant === "card"; + return ( +
+
+ {notice.title} +
+
+ {notice.message} +
+ {endpoint ? ( + + {endpoint} + + ) : null} +
+ ); +} + export function resolveChatSessionProfile(_computerUsePolicy: ComputerUsePolicy): AgentChatSessionProfile { return "workflow"; } @@ -590,9 +674,11 @@ export function AgentChatPane({ const [codexSandbox, setCodexSandbox] = useState(initialNativeControls.codexSandbox); const [codexConfigSource, setCodexConfigSource] = useState(initialNativeControls.codexConfigSource); const [unifiedPermissionMode, setUnifiedPermissionMode] = useState(initialNativeControls.unifiedPermissionMode); + const prevModelDescRef = useRef(undefined); const [cursorModeId, setCursorModeId] = useState(initialNativeControls.cursorModeId); const [cursorConfigValues, setCursorConfigValues] = useState>(initialNativeControls.cursorConfigValues); const [computerUsePolicy, setComputerUsePolicy] = useState(createDefaultComputerUsePolicy()); + const [aiStatus, setAiStatus] = useState(null); const [providerConnections, setProviderConnections] = useState<{ claude: AiProviderConnectionStatus | null; codex: AiProviderConnectionStatus | null; @@ -685,6 +771,126 @@ export function AgentChatPane({ const pendingSteers = selectedSessionId ? (pendingSteersBySession[selectedSessionId] ?? []) : []; const selectedModelDesc = getModelById(modelId); const reasoningTiers = selectedModelDesc?.reasoningTiers ?? []; + const localRuntimeState = useMemo(() => { + const provider = selectedModelDesc?.authTypes.includes("local") + ? (selectedModelDesc.family as LocalProviderFamily) + : parseLocalProviderFromModelId(modelId); + if (!provider) return null; + const runtimeConnection = aiStatus?.runtimeConnections?.[provider] ?? null; + const detectedEntry = aiStatus?.detectedAuth?.find( + (entry): entry is { type: "local"; provider: LocalProviderFamily; endpoint: string } => + entry.type === "local" && entry.provider === provider, + ) ?? null; + const modelIds = runtimeConnection?.loadedModelIds !== undefined && runtimeConnection.loadedModelIds !== null + ? runtimeConnection.loadedModelIds.filter((id): id is string => String(id ?? "").startsWith(`${provider}/`)) + : availableModelIds.filter((id) => id.startsWith(`${provider}/`)); + return { + provider, + label: LOCAL_PROVIDER_LABELS[provider], + endpoint: runtimeConnection?.endpoint ?? detectedEntry?.endpoint ?? getLocalProviderDefaultEndpoint(provider), + detected: Boolean(runtimeConnection?.runtimeDetected ?? detectedEntry), + runtimeAvailable: runtimeConnection?.runtimeAvailable ?? false, + health: runtimeConnection?.health ?? null, + blocker: runtimeConnection?.blocker ?? null, + modelIds, + statusKnown: Boolean(aiStatus), + }; + }, [aiStatus, availableModelIds, modelId, selectedModelDesc]); + const localRuntimeNotice = useMemo(() => { + if (!localRuntimeState) return null; + if (!localRuntimeState.statusKnown) { + return { + tone: "warning" as const, + title: `${localRuntimeState.label} runtime`, + message: `ADE could not read ${localRuntimeState.label} status right now. It will still try the unified local-model path, but refresh settings if the runtime changed.`, + }; + } + if (localRuntimeState.blocker) { + return { + tone: "warning" as const, + title: `${localRuntimeState.label} runtime`, + message: localRuntimeState.blocker, + }; + } + if (!localRuntimeState.detected) { + return { + tone: "warning" as const, + title: `${localRuntimeState.label} runtime`, + message: `${localRuntimeState.label} is not detected at ${localRuntimeState.endpoint}. Start it, load a model, then refresh so ADE can use the local runtime.`, + }; + } + if (!localRuntimeState.modelIds.length) { + return { + tone: "warning" as const, + title: `${localRuntimeState.label} runtime`, + message: `${localRuntimeState.label} responded, but no loaded models were reported yet. Load a model in ${localRuntimeState.label} and refresh.`, + }; + } + if (!localRuntimeState.modelIds.includes(modelId)) { + return { + tone: "warning" as const, + title: `${localRuntimeState.label} runtime`, + message: `${localRuntimeState.label} is running, but ${selectedModelDesc?.displayName ?? formatLocalModelLabel(modelId)} is not in the loaded model list. Choose one of the loaded models or load this model in ${localRuntimeState.label}.`, + }; + } + return { + tone: "success" as const, + title: `${localRuntimeState.label} runtime`, + message: `${localRuntimeState.label} is connected with ${localRuntimeState.modelIds.length} loaded model${localRuntimeState.modelIds.length === 1 ? "" : "s"}${localRuntimeState.health ? ` (${localRuntimeState.health})` : ""}.`, + }; + }, [localRuntimeState, modelId, selectedModelDesc?.displayName]); + + const cliRuntimeBlocked = Boolean( + selectedSessionId + && activeProviderConnection + && !activeProviderConnection.runtimeAvailable + && (activeProviderConnection.blocker || activeProviderConnection.provider === "cursor"), + ); + const cliRuntimeTitle = activeProviderConnection?.provider === "claude" + ? "Claude runtime" + : activeProviderConnection?.provider === "cursor" + ? "Cursor runtime" + : "Codex runtime"; + const cliRuntimeBody = activeProviderConnection?.blocker + ?? (activeProviderConnection?.provider === "cursor" + ? "Cursor agent is not available. Ensure Cursor is installed and the agent is enabled." + : null); + + const mergedRuntimeBanner = useMemo(() => { + if (!cliRuntimeBlocked && !localRuntimeNotice) return null; + if (cliRuntimeBlocked && localRuntimeNotice) { + return { + kind: "merged" as const, + cliTitle: cliRuntimeTitle, + cliBody: cliRuntimeBody ?? "", + localNotice: localRuntimeNotice, + localEndpoint: localRuntimeState?.endpoint, + }; + } + if (cliRuntimeBlocked) { + return { + kind: "cli-only" as const, + cliTitle: cliRuntimeTitle, + cliBody: cliRuntimeBody ?? "", + }; + } + return { + kind: "local-only" as const, + localNotice: localRuntimeNotice!, + localEndpoint: localRuntimeState?.endpoint, + }; + }, [ + cliRuntimeBlocked, + cliRuntimeBody, + cliRuntimeTitle, + localRuntimeNotice, + localRuntimeState?.endpoint, + ]); + + useEffect(() => { + prevModelDescRef.current = getModelDescriptorForPermissionMode(modelId); + }, [modelId]); + const surfaceMode = presentation?.mode ?? "standard"; const identitySessionSettingsBusy = isPersistentIdentitySurface && sessionMutationKind !== null; @@ -812,10 +1018,18 @@ export function AgentChatPane({ const refreshAvailableModels = useCallback(async () => { try { const status = await window.ade.ai.getStatus(); + setAiStatus(status); + setProviderConnections({ + claude: status.providerConnections?.claude ?? null, + codex: status.providerConnections?.codex ?? null, + cursor: status.providerConnections?.cursor ?? null, + }); const available = deriveConfiguredModelIds(status, { includeCursor: true }); setAvailableModelIds(available); return available; } catch { + setAiStatus(null); + setProviderConnections(null); // Fall back to direct model discovery probes below. } @@ -842,7 +1056,11 @@ export function AgentChatPane({ } for (const model of unifiedModels) { const resolved = resolveRegistryModelId(model.id); - if (resolved) available.add(resolved); + if (resolved) { + available.add(resolved); + } else { + available.add(model.id); + } } const ordered = MODEL_REGISTRY @@ -863,19 +1081,6 @@ export function AgentChatPane({ } }, []); - const refreshProviderConnections = useCallback(async () => { - try { - const status = await window.ade.ai.getStatus(); - setProviderConnections({ - claude: status.providerConnections?.claude ?? null, - codex: status.providerConnections?.codex ?? null, - cursor: status.providerConnections?.cursor ?? null, - }); - } catch { - setProviderConnections(null); - } - }, []); - const touchSession = useCallback((sessionId: string | null | undefined, touchedAt = new Date().toISOString()) => { if (!sessionId) return; const previousTouch = localTouchBySessionRef.current.get(sessionId); @@ -963,16 +1168,16 @@ export function AgentChatPane({ }, [selectedSessionId]); useEffect(() => { - void refreshProviderConnections(); - }, [refreshProviderConnections, selectedSession?.provider]); + void refreshAvailableModels(); + }, [refreshAvailableModels, selectedSession?.provider]); useEffect(() => { if (!turnActive || !selectedSession?.provider) return; const timer = window.setInterval(() => { - void refreshProviderConnections(); + void refreshAvailableModels(); }, 5000); return () => window.clearInterval(timer); - }, [refreshProviderConnections, selectedSession?.provider, turnActive]); + }, [refreshAvailableModels, selectedSession?.provider, turnActive]); const refreshComputerUseSnapshot = useCallback(async ( sessionId: string | null, @@ -1595,27 +1800,41 @@ export function AgentChatPane({ }; }, [currentNativeControls]); const buildModelSelectionSnapshot = useCallback((nextModelId: string) => { + const previousDesc = prevModelDescRef.current; const nextDesc = getModelById(nextModelId); + const nextPermissionDesc = getModelDescriptorForPermissionMode(nextModelId); const nextProvider = resolveChatRuntimeProvider(nextDesc); const nextModel = nextProvider === "unified" ? nextModelId : runtimeFacingModelId(nextDesc, nextModelId); const tiers = nextDesc?.reasoningTiers ?? []; const preferred = readLastUsedReasoningEffort({ laneId, modelId: nextModelId }); const nextReasoningEffort = selectReasoningEffort({ tiers, preferred }); + const nextRec = recommendedUnifiedPermissionModeForModel(nextPermissionDesc); return { nextDesc, nextModelId, nextModel, nextProvider, nextReasoningEffort, + nextUnifiedPermissionMode: nextRec, + resetUnifiedPermissionToDefault: shouldResetUnifiedPermissionForModelSwitch(previousDesc, nextPermissionDesc), }; }, [laneId]); const applyModelSelectionSnapshot = useCallback((snapshot: { nextModelId: string; nextReasoningEffort: string | null; + nextUnifiedPermissionMode?: AgentChatUnifiedPermissionMode | null; + resetUnifiedPermissionToDefault?: boolean; }) => { setModelId(snapshot.nextModelId); setReasoningEffort(snapshot.nextReasoningEffort); - }, []); + const nextUnified = snapshot.nextUnifiedPermissionMode ?? null; + const targetUnified = snapshot.resetUnifiedPermissionToDefault + ? (nextUnified ?? initialNativeControls.unifiedPermissionMode) + : nextUnified; + if (targetUnified != null) { + setUnifiedPermissionMode(targetUnified); + } + }, [initialNativeControls.unifiedPermissionMode]); const notifySessionCreated = useCallback((session: AgentChatSession) => { if (!onSessionCreated) return; void Promise.resolve(onSessionCreated(session)).catch((err) => { console.error("notifySessionCreated failed:", err); }); @@ -1628,9 +1847,22 @@ export function AgentChatPane({ if (!laneId) return null; const createPromise = (async () => { const desc = getModelById(modelId); + const permissionDesc = getModelDescriptorForPermissionMode(modelId); const provider = resolveChatRuntimeProvider(desc); const model = provider === "unified" ? modelId : runtimeFacingModelId(desc, modelId); const sessionProfile = resolveChatSessionProfile(computerUsePolicy); + const harnessPermissionMode = provider === "unified" + ? recommendedUnifiedPermissionModeForModel(permissionDesc) + : null; + const nativeControlPayload = harnessPermissionMode + ? { + ...summarizeNativeControls(provider, { + ...currentNativeControls, + unifiedPermissionMode: harnessPermissionMode, + }), + ...(provider === "cursor" ? { cursorConfigValues: currentNativeControls.cursorConfigValues } : {}), + } + : buildNativeControlPayload(provider); const created = await window.ade.agentChat.create({ laneId, provider, @@ -1638,7 +1870,7 @@ export function AgentChatPane({ modelId, sessionProfile, reasoningEffort, - ...buildNativeControlPayload(provider), + ...nativeControlPayload, computerUse: computerUsePolicy, }); loadedHistoryRef.current.delete(created.id); @@ -1665,7 +1897,7 @@ export function AgentChatPane({ createSessionPromiseRef.current = null; } } - }, [buildNativeControlPayload, computerUsePolicy, laneId, modelId, notifySessionCreated, reasoningEffort, refreshSessions, touchSession]); + }, [buildNativeControlPayload, computerUsePolicy, currentNativeControls, laneId, modelId, notifySessionCreated, reasoningEffort, refreshSessions, touchSession]); const handoffSession = useCallback(async () => { if (!canShowHandoff || !selectedSessionId || !handoffModelId || handoffBlocked) return; @@ -2313,11 +2545,22 @@ export function AgentChatPane({ } setSessionMutationKind("model"); + const nextUnifiedForPayload = snapshot.resetUnifiedPermissionToDefault + ? (snapshot.nextUnifiedPermissionMode ?? initialNativeControls.unifiedPermissionMode) + : snapshot.nextUnifiedPermissionMode; + const nextNativeControlPayload = snapshot.nextProvider === "unified" && nextUnifiedForPayload != null + ? { + ...summarizeNativeControls("unified", { + ...currentNativeControls, + unifiedPermissionMode: nextUnifiedForPayload, + }), + } + : buildNativeControlPayload(snapshot.nextProvider); void window.ade.agentChat.updateSession({ sessionId: selectedSessionId, modelId: nextModelId, reasoningEffort: snapshot.nextReasoningEffort, - ...buildNativeControlPayload(snapshot.nextProvider), + ...nextNativeControlPayload, computerUse: computerUsePolicy, }).then((updatedSession) => { applyModelSelectionSnapshot(snapshot); @@ -2419,17 +2662,43 @@ export function AgentChatPane({ {error} ) : null} - {selectedSessionId && !activeProviderConnection?.runtimeAvailable && (activeProviderConnection?.blocker || activeProviderConnection?.provider === "cursor") ? ( + {mergedRuntimeBanner?.kind === "cli-only" ? (
- {activeProviderConnection.provider === "claude" - ? "Claude runtime" - : activeProviderConnection.provider === "cursor" - ? "Cursor runtime" - : "Codex runtime"} + {mergedRuntimeBanner.cliTitle}
- {activeProviderConnection.blocker || "Cursor agent is not available. Ensure Cursor is installed and the agent is enabled."} + {mergedRuntimeBanner.cliBody} +
+
+ ) : null} + {mergedRuntimeBanner?.kind === "local-only" ? ( + + ) : null} + {mergedRuntimeBanner?.kind === "merged" ? ( +
+
+ Runtime status +
+
+
+
+ {mergedRuntimeBanner.cliTitle} +
+
+ {mergedRuntimeBanner.cliBody} +
+
+
+ +
) : null} diff --git a/apps/desktop/src/renderer/components/onboarding/ProjectSetupPage.tsx b/apps/desktop/src/renderer/components/onboarding/ProjectSetupPage.tsx index 128bfd28d..7ba8ffea7 100644 --- a/apps/desktop/src/renderer/components/onboarding/ProjectSetupPage.tsx +++ b/apps/desktop/src/renderer/components/onboarding/ProjectSetupPage.tsx @@ -27,7 +27,7 @@ const STEP_META: Record = { }, ai: { title: "AI connections", - subtitle: "Connect your AI providers", + subtitle: "Connect your AI providers and local runtimes", }, helpers: { title: "Background helpers", @@ -56,7 +56,7 @@ const STEP_HEADERS: Record = { tools: { heading: "Developer Tools", sub: "ADE needs git for version control. GitHub CLI unlocks PR creation, review requests, and CI checks." }, ai: { heading: "Connect AI providers", - sub: "Link API keys and CLIs (Claude Code, Codex, Cursor agent) so ADE can power chat, codegen, and background automations. After the CLI is installed and signed in, Cursor models appear in work chat automatically.", + sub: "Link API keys, CLIs, and local runtimes (LM Studio, Ollama, vLLM) so ADE can power chat, codegen, and background helpers. After the CLI is installed and signed in, Cursor models appear in work chat automatically.", }, helpers: { heading: "Background Helpers", sub: "These lightweight AI automations run in the background while you work. All are optional and can be changed anytime in Settings." }, github: { heading: "GitHub Integration", sub: "A personal access token lets ADE create PRs, request reviews, and monitor CI on your behalf." }, diff --git a/apps/desktop/src/renderer/components/settings/ProvidersSection.test.tsx b/apps/desktop/src/renderer/components/settings/ProvidersSection.test.tsx index cd6e4b458..3788bb112 100644 --- a/apps/desktop/src/renderer/components/settings/ProvidersSection.test.tsx +++ b/apps/desktop/src/renderer/components/settings/ProvidersSection.test.tsx @@ -6,7 +6,7 @@ import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; import { ProvidersSection } from "./ProvidersSection"; import type { AgentChatEventEnvelope, AiSettingsStatus } from "../../../shared/types"; -function buildStatus(claudeRuntimeAvailable: boolean): AiSettingsStatus { +function buildStatus(claudeRuntimeAvailable: boolean, localModels: string[] = []): AiSettingsStatus { return { mode: "subscription", availableProviders: { @@ -20,7 +20,32 @@ function buildStatus(claudeRuntimeAvailable: boolean): AiSettingsStatus { cursor: [], }, features: [], - detectedAuth: [], + detectedAuth: localModels.length > 0 + ? [ + { + type: "local", + provider: "lmstudio", + endpoint: "http://localhost:1234", + }, + ] + : [], + availableModelIds: localModels, + runtimeConnections: { + lmstudio: { + provider: "lmstudio", + label: "LM Studio", + kind: "local", + endpoint: "http://localhost:1234", + configured: true, + authAvailable: false, + runtimeDetected: localModels.length > 0, + runtimeAvailable: localModels.length > 0, + health: localModels.length > 0 ? "ready" : "unreachable", + blocker: localModels.length > 0 ? null : "No lmstudio runtime with loaded models was detected.", + loadedModelIds: localModels, + lastCheckedAt: "2026-03-17T19:00:00.000Z", + }, + }, providerConnections: { claude: { provider: "claude", @@ -81,9 +106,17 @@ describe("ProvidersSection", () => { globalThis.window.ade = { ai: { getStatus: vi.fn() - .mockResolvedValueOnce(buildStatus(true)) - .mockResolvedValueOnce(buildStatus(false)), + .mockResolvedValueOnce(buildStatus(true, ["lmstudio/meta-llama-3.1-70b-instruct", "lmstudio/qwen2.5-coder:32b"])) + .mockResolvedValueOnce(buildStatus(false, ["lmstudio/meta-llama-3.1-70b-instruct", "lmstudio/qwen2.5-coder:32b"])), listApiKeys: vi.fn().mockResolvedValue([]), + updateConfig: vi.fn().mockResolvedValue(undefined), + }, + projectConfig: { + get: vi.fn().mockResolvedValue({ + effective: { + ai: {}, + }, + }), }, agentChat: { onEvent: vi.fn((listener: (envelope: AgentChatEventEnvelope) => void) => { @@ -144,4 +177,19 @@ describe("ProvidersSection", () => { expect((await screen.findAllByText("Connected")).length).toBeGreaterThan(0); expect(screen.getAllByText("/Users/arul/.local/bin/claude").length).toBeGreaterThan(0); }); + + it("renders local runtime details and loaded local models", async () => { + render(); + + await waitFor(() => { + expect(window.ade.ai.getStatus).toHaveBeenCalledTimes(1); + expect(window.ade.ai.listApiKeys).toHaveBeenCalledTimes(1); + }); + + expect((await screen.findAllByText("LM Studio")).length).toBeGreaterThan(0); + expect(screen.getAllByText("Ready").length).toBeGreaterThan(0); + expect(screen.getAllByText("LM Studio is reachable at http://localhost:1234. ADE can use 2 loaded models from this runtime (ready).").length).toBeGreaterThan(0); + expect(screen.getAllByText("meta-llama-3.1-70b-instruct (LM Studio)").length).toBeGreaterThan(0); + expect(screen.getAllByText("qwen2.5-coder:32b (LM Studio)").length).toBeGreaterThan(0); + }); }); diff --git a/apps/desktop/src/renderer/components/settings/ProvidersSection.tsx b/apps/desktop/src/renderer/components/settings/ProvidersSection.tsx index 5471dca23..d9c6efc8a 100644 --- a/apps/desktop/src/renderer/components/settings/ProvidersSection.tsx +++ b/apps/desktop/src/renderer/components/settings/ProvidersSection.tsx @@ -1,10 +1,21 @@ import React, { useCallback, useEffect, useMemo, useRef, useState } from "react"; import type { AgentChatEventEnvelope, + AiConfig, AiApiKeyVerificationResult, AiProviderConnectionStatus, + AiRuntimeConnectionStatus, AiSettingsStatus, + ProjectConfigSnapshot, } from "../../../shared/types"; +import { + getLocalModelIdTail, + getLocalProviderDefaultEndpoint, + getModelById, + LOCAL_PROVIDER_LABELS, + parseLocalProviderFromModelId, + type LocalProviderFamily, +} from "../../../shared/modelRegistry"; import { ArrowsClockwise, CheckCircle, @@ -58,6 +69,16 @@ const CLI_TOOLS: Array<{ }, ]; +const LOCAL_PROVIDER_SPECS: Array<{ + provider: LocalProviderFamily; + label: string; + description: string; +}> = [ + { provider: "lmstudio", label: "LM Studio", description: "OpenAI-compatible local server" }, + { provider: "ollama", label: "Ollama", description: "OpenAI-compatible local server" }, + { provider: "vllm", label: "vLLM", description: "OpenAI-compatible local server" }, +]; + const API_KEY_PROVIDERS: Array<{ provider: string; label: string; @@ -76,6 +97,13 @@ const API_KEY_PROVIDERS: Array<{ { provider: "openrouter", label: "OpenRouter", envVar: "OPENROUTER_API_KEY", placeholder: "sk-or-...", accent: "#A78BFA" }, ]; +type LocalProviderDraft = { + enabled: boolean; + endpoint: string; + autoDetect: boolean; + preferredModelId: string; +}; + const groupLabelStyle: React.CSSProperties = { ...LABEL_STYLE, fontSize: 11, @@ -156,6 +184,18 @@ function buildCliMessage(tool: (typeof CLI_TOOLS)[number], connection: AiProvide return `CLI not found in PATH. Install: ${tool.installHint}. If already installed, ensure it is on your shell PATH and use Refresh.`; } +function formatLocalModelLabel(modelId: string): string { + const descriptor = getModelById(modelId); + if (descriptor) return descriptor.displayName; + const provider = parseLocalProviderFromModelId(modelId); + if (provider) { + const tail = getLocalModelIdTail(modelId, provider); + const brand = LOCAL_PROVIDER_LABELS[provider]; + return tail.length ? `${tail} (${brand})` : String(modelId ?? "").trim(); + } + return String(modelId ?? "").trim(); +} + const AUTH_ERROR_SIGNALS = [ "invalid authentication credentials", "authentication error", @@ -190,11 +230,40 @@ function shouldRefreshProvidersForChatEvent(envelope: AgentChatEventEnvelope): b return false; } +function buildLocalProviderDrafts( + snapshot: ProjectConfigSnapshot | null | undefined, + status: (AiSettingsStatus & { runtimeConnections?: Record }) | null | undefined, +): Record { + const configured = snapshot?.effective.ai?.localProviders ?? {}; + return Object.fromEntries( + LOCAL_PROVIDER_SPECS.map((spec) => { + const runtimeConnection = status?.runtimeConnections?.[spec.provider]; + const providerConfig = configured[spec.provider]; + return [spec.provider, { + enabled: providerConfig?.enabled ?? true, + endpoint: + (typeof providerConfig?.endpoint === "string" && providerConfig.endpoint.trim().length + ? providerConfig.endpoint.trim() + : runtimeConnection?.endpoint?.trim()) + ?? getLocalProviderDefaultEndpoint(spec.provider), + autoDetect: providerConfig?.autoDetect ?? true, + preferredModelId: typeof providerConfig?.preferredModelId === "string" ? providerConfig.preferredModelId : "", + }]; + }), + ) as Record; +} + export function ProvidersSection({ forceRefreshOnMount = false }: { forceRefreshOnMount?: boolean }) { - const [status, setStatus] = useState(null); + const [status, setStatus] = useState<(AiSettingsStatus & { runtimeConnections?: Record }) | null>(null); + const [projectConfigSnapshot, setProjectConfigSnapshot] = useState(null); const [storedProviders, setStoredProviders] = useState([]); const [loading, setLoading] = useState(true); const [editingProvider, setEditingProvider] = useState(null); + const [editingLocalProvider, setEditingLocalProvider] = useState(null); + const [savingLocalProvider, setSavingLocalProvider] = useState(null); + const [localProviderDrafts, setLocalProviderDrafts] = useState>(() => + buildLocalProviderDrafts(null, null), + ); const [editValue, setEditValue] = useState(""); const [notice, setNotice] = useState(null); const [error, setError] = useState(null); @@ -208,11 +277,16 @@ export function ProvidersSection({ forceRefreshOnMount = false }: { forceRefresh } setError(null); try { - const [nextStatus, nextStoredProviders] = await Promise.all([ + const [nextStatus, nextStoredProviders, nextProjectConfig] = await Promise.all([ window.ade.ai.getStatus(options?.force ? { force: true } : undefined), window.ade.ai.listApiKeys(), + window.ade.projectConfig.get(), ]); setStatus(nextStatus); + setProjectConfigSnapshot(nextProjectConfig); + if (editingLocalProvider == null && savingLocalProvider == null) { + setLocalProviderDrafts(buildLocalProviderDrafts(nextProjectConfig, nextStatus)); + } setStoredProviders(nextStoredProviders.map((entry) => entry.trim().toLowerCase()).filter(Boolean)); } catch (err) { setError(err instanceof Error ? err.message : String(err)); @@ -221,7 +295,7 @@ export function ProvidersSection({ forceRefreshOnMount = false }: { forceRefresh setLoading(false); } } - }, []); + }, [editingLocalProvider, savingLocalProvider]); useEffect(() => { void refreshStatus(forceRefreshOnMount ? { force: true } : undefined); @@ -245,7 +319,7 @@ export function ProvidersSection({ forceRefreshOnMount = false }: { forceRefresh }; }, [refreshStatus]); - const detectedAuth = status?.detectedAuth ?? []; + const detectedAuth = useMemo(() => status?.detectedAuth ?? [], [status?.detectedAuth]); const providerConnections = status?.providerConnections; const isInitialCheckInFlight = loading && status == null; @@ -261,14 +335,30 @@ export function ProvidersSection({ forceRefreshOnMount = false }: { forceRefresh return map; }, [detectedAuth]); - const localEndpoints = useMemo(() => { - const entries: Array<{ provider: string; endpoint: string }> = []; - for (const entry of detectedAuth) { - if (entry.type !== "local" || !entry.provider || !entry.endpoint) continue; - entries.push({ provider: entry.provider, endpoint: entry.endpoint }); - } - return entries; - }, [detectedAuth]); + const localRuntimes = useMemo(() => { + const availableModelIds = status?.availableModelIds ?? []; + const runtimeConnections = status?.runtimeConnections ?? {}; + return LOCAL_PROVIDER_SPECS.map((spec) => { + const runtimeConnection = runtimeConnections[spec.provider] ?? null; + const detected = detectedAuth.find( + (entry): entry is { type: "local"; provider: LocalProviderFamily; endpoint: string } => + entry.type === "local" && entry.provider === spec.provider, + ) ?? null; + const modelIds = runtimeConnection?.loadedModelIds?.length + ? runtimeConnection.loadedModelIds.filter((rawId) => String(rawId ?? "").trim().startsWith(`${spec.provider}/`)) + : availableModelIds.filter((rawId) => String(rawId ?? "").trim().startsWith(`${spec.provider}/`)); + return { + ...spec, + endpoint: runtimeConnection?.endpoint ?? detected?.endpoint ?? getLocalProviderDefaultEndpoint(spec.provider), + health: runtimeConnection?.health ?? null, + blocker: runtimeConnection?.blocker ?? null, + runtimeAvailable: runtimeConnection?.runtimeAvailable ?? false, + detected, + modelIds, + hasModels: modelIds.length > 0, + }; + }); + }, [detectedAuth, status?.availableModelIds, status?.runtimeConnections]); const apiKeyStoreWarning = useMemo(() => { if (status?.apiKeyStore?.legacyPlaintextDetected) { @@ -345,6 +435,57 @@ export function ProvidersSection({ forceRefreshOnMount = false }: { forceRefresh } }; + const updateLocalProviderDraft = useCallback(( + provider: LocalProviderFamily, + patch: Partial, + ) => { + setLocalProviderDrafts((prev) => ({ + ...prev, + [provider]: { + ...prev[provider], + ...patch, + }, + })); + }, []); + + const beginEditingLocalRuntime = useCallback((provider: LocalProviderFamily) => { + setEditingLocalProvider(provider); + setError(null); + setNotice(null); + }, []); + + const cancelEditingLocalRuntime = useCallback(() => { + setEditingLocalProvider(null); + setLocalProviderDrafts(buildLocalProviderDrafts(projectConfigSnapshot, status)); + }, [projectConfigSnapshot, status]); + + const saveLocalProvider = useCallback(async (provider: LocalProviderFamily) => { + const draft = localProviderDrafts[provider]; + if (!draft) return; + setSavingLocalProvider(provider); + setError(null); + setNotice(null); + try { + await window.ade.ai.updateConfig({ + localProviders: { + [provider]: { + enabled: draft.enabled, + endpoint: draft.endpoint.trim(), + autoDetect: draft.autoDetect, + preferredModelId: draft.preferredModelId.trim() || null, + }, + } as AiConfig["localProviders"], + }); + setNotice(`${LOCAL_PROVIDER_LABELS[provider]} settings saved.`); + setEditingLocalProvider(null); + await refreshStatus({ force: true }); + } catch (err) { + setError(err instanceof Error ? err.message : String(err)); + } finally { + setSavingLocalProvider(null); + } + }, [localProviderDrafts, refreshStatus]); + return (
{notice && ( @@ -645,56 +786,303 @@ export function ProvidersSection({ forceRefreshOnMount = false }: { forceRefresh
-
Local runtimes
- {localEndpoints.length > 0 ? ( -
- {localEndpoints.map((entry) => ( +
+
+
Local runtimes
+
+ LM Studio, Ollama, and vLLM become ready once at least one model is loaded and the server exposes its OpenAI-compatible `/v1/models` list. +
+
+ +
+ +
+ {localRuntimes.map((entry) => { + const isEditing = editingLocalProvider === entry.provider; + const isSaving = savingLocalProvider === entry.provider; + const draft = localProviderDrafts[entry.provider]; + const tone = entry.blocker + ? { color: COLORS.warning, label: "Blocked" } + : entry.runtimeAvailable || (entry.detected && entry.hasModels) + ? { color: COLORS.success, label: entry.hasModels ? "Ready" : "Connected" } + : { color: COLORS.warning, label: "Not detected" }; + const loadedModels = entry.modelIds.slice(0, 4); + const extraModelCount = Math.max(0, entry.modelIds.length - loadedModels.length); + const message = entry.blocker + ? entry.blocker + : entry.detected + ? entry.hasModels + ? `${entry.label} is reachable at ${entry.endpoint}. ADE can use ${entry.modelIds.length} loaded model${entry.modelIds.length === 1 ? "" : "s"} from this runtime${entry.health ? ` (${entry.health})` : ""}.` + : `${entry.label} responded, but no loaded models were reported yet. Load a model in ${entry.label} and refresh.` + : `${entry.label} was not detected. Start it, load at least one model, then refresh so ADE can discover its OpenAI-compatible server.`; + + return (
-
-
- {entry.provider} +
+
+ +
+
+ {entry.label} +
+
+ {entry.description} +
+
+
+
+ {entry.detected ? : } + + {tone.label} +
- - {entry.endpoint} -
-
- - - Reachable + +
+ {message} +
+ +
+ + {draft?.enabled === false ? "Disabled" : "Enabled"} + + + {draft?.autoDetect === false ? "Manual only" : "Auto-detect fallback"}
+ + + {draft?.endpoint?.trim() || entry.endpoint} + + +
+ {loadedModels.length > 0 ? ( + <> + {loadedModels.map((modelId) => ( + + + {formatLocalModelLabel(modelId)} + + ))} + {extraModelCount > 0 ? ( + + +{extraModelCount} more + + ) : null} + + ) : ( + + No loaded models reported yet. + + )} +
+ + {isEditing && draft ? ( +
+ + + + +
+ ) : null} + +
+ {isEditing ? ( + <> + + + + ) : ( + <> + + + + )} +
- ))} -
- ) : ( -
- - No local model endpoints detected (Ollama, LM Studio, vLLM). -
- )} + ); + })} +
+
+ + If LM Studio is running but ADE does not show it, load at least one model in LM Studio, then use Refresh. ADE only marks a local runtime as ready after `/v1/models` returns loaded models. +
); diff --git a/apps/desktop/src/renderer/components/shared/UnifiedModelSelector.tsx b/apps/desktop/src/renderer/components/shared/UnifiedModelSelector.tsx index 40a943aa7..a2c65fd6a 100644 --- a/apps/desktop/src/renderer/components/shared/UnifiedModelSelector.tsx +++ b/apps/desktop/src/renderer/components/shared/UnifiedModelSelector.tsx @@ -2,7 +2,10 @@ import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { createPortal } from "react-dom"; import { AnimatePresence, motion } from "motion/react"; import { + LOCAL_PROVIDER_LABELS, MODEL_REGISTRY, + getLocalModelIdTail, + parseLocalProviderFromModelId, resolveModelDescriptor, type ModelDescriptor, } from "../../../shared/modelRegistry"; @@ -16,6 +19,7 @@ import { createModelOrderMap, matchesQuery, PROVIDER_BADGE_COLORS, + providerLabel, subsectionKeyForModel, sourceSectionLabel, type ModelProviderBlock, @@ -69,7 +73,7 @@ function modelAvailabilityLabel(model: ModelDescriptor, isAvailable: boolean): s if (isAvailable) { if (model.family === "cursor" && model.isCliWrapped) return "Cursor CLI ready"; if (model.isCliWrapped) return "Subscription ready"; - if (model.authTypes.includes("local")) return "Local ready"; + if (model.authTypes.includes("local")) return `${providerLabel(model.family)} ready`; if (model.authTypes.includes("api-key")) return "API ready"; if (model.authTypes.includes("oauth")) return "OAuth ready"; if (model.authTypes.includes("openrouter")) return "OpenRouter ready"; @@ -79,7 +83,7 @@ function modelAvailabilityLabel(model: ModelDescriptor, isAvailable: boolean): s return "Cursor CLI · run `agent login` or set CURSOR_API_KEY / CURSOR_AUTH_TOKEN"; } if (model.isCliWrapped) return "Subscription · not configured"; - if (model.authTypes.includes("local")) return "Local · not configured"; + if (model.authTypes.includes("local")) return `${providerLabel(model.family)} · not configured`; if (model.authTypes.includes("api-key")) return "API · not configured"; if (model.authTypes.includes("oauth")) return "OAuth · not configured"; if (model.authTypes.includes("openrouter")) return "OpenRouter · not configured"; @@ -112,6 +116,28 @@ function createUnknownModelPlaceholder(modelId: string): ModelDescriptor { isCliWrapped: true, }; } + const localProvider = parseLocalProviderFromModelId(modelId); + if (localProvider) { + const shortId = getLocalModelIdTail(modelId, localProvider) || modelId; + const brand = LOCAL_PROVIDER_LABELS[localProvider]; + return { + id: modelId, + shortId, + displayName: shortId, + family: localProvider, + authTypes: ["local"], + contextWindow: 0, + maxOutputTokens: 0, + capabilities: { tools: false, vision: false, reasoning: false, streaming: true }, + color: PROVIDER_BADGE_COLORS[localProvider] ?? "#64748B", + sdkProvider: "@ai-sdk/openai-compatible", + sdkModelId: shortId, + isCliWrapped: false, + discoverySource: localProvider === "lmstudio" ? "lmstudio-openai" : localProvider, + harnessProfile: "guarded", + aliases: brand ? [brand] : [], + }; + } return { id: modelId, shortId: modelId, diff --git a/apps/desktop/src/renderer/lib/modelOptions.test.ts b/apps/desktop/src/renderer/lib/modelOptions.test.ts index 6662b3128..1dd915efa 100644 --- a/apps/desktop/src/renderer/lib/modelOptions.test.ts +++ b/apps/desktop/src/renderer/lib/modelOptions.test.ts @@ -12,7 +12,7 @@ import { // Helpers // --------------------------------------------------------------------------- -function makeStatus(overrides: Partial = {}): AiSettingsStatus { +function makeStatus(overrides: Partial & Record = {}): AiSettingsStatus { return { mode: "guest", availableProviders: { claude: false, codex: false, cursor: false }, @@ -206,13 +206,48 @@ describe("deriveConfiguredModelIds", () => { it("includes local models for local auth", () => { const status = makeStatus({ - detectedAuth: [{ type: "local", provider: "ollama" }], + detectedAuth: [{ type: "local", provider: "lmstudio" }], + availableModelIds: ["lmstudio/meta-llama-3.1-70b-instruct", "lmstudio/qwen2.5-coder:32b", "ollama/llama3.2"], }); const ids = deriveConfiguredModelIds(status); - const ollamaModels = MODEL_REGISTRY.filter( - (m) => m.family === "ollama" && !m.deprecated && !m.isCliWrapped, - ); - expect(ids.length).toBe(ollamaModels.length); + expect(ids).toContain("lmstudio/meta-llama-3.1-70b-instruct"); + expect(ids).toContain("lmstudio/qwen2.5-coder:32b"); + expect(ids).not.toContain("ollama/llama3.2"); + }); + + it("includes local models from runtimeConnections when availableModelIds is empty", () => { + const status = makeStatus({ + runtimeConnections: { + lmstudio: { + provider: "lmstudio", + label: "LM Studio", + kind: "local", + configured: true, + authAvailable: true, + runtimeDetected: true, + runtimeAvailable: true, + health: "ready", + endpoint: "http://localhost:1234", + blocker: null, + loadedModelIds: ["lmstudio/meta-llama-3.1-70b-instruct"], + lastCheckedAt: "2026-03-17T19:00:00.000Z", + }, + }, + }); + const ids = deriveConfiguredModelIds(status); + expect(ids).toContain("lmstudio/meta-llama-3.1-70b-instruct"); + }); + + it("prefers discovered local model ids over static local placeholders", () => { + const status = makeStatus({ + detectedAuth: [{ type: "local", provider: "lmstudio" }], + availableModelIds: ["lmstudio/meta-llama-3.1-70b-instruct"], + }); + + const ids = deriveConfiguredModelIds(status); + + expect(ids).toContain("lmstudio/meta-llama-3.1-70b-instruct"); + expect(ids).not.toContain("lmstudio/auto"); }); it("merges models from multiple auth sources without duplicates", () => { @@ -335,6 +370,18 @@ describe("deriveConfiguredModelOptions", () => { expect(sonnet).toBeTruthy(); expect(sonnet!.label).toBe("Claude Sonnet 4.6"); }); + + it("preserves dynamic local model ids as options", () => { + const status = makeStatus({ + detectedAuth: [{ type: "local", provider: "lmstudio" }], + availableModelIds: ["lmstudio/local-test-model-123"], + }); + const options = deriveConfiguredModelOptions(status); + const local = options.find((o) => o.id === "lmstudio/local-test-model-123"); + expect(local).toBeTruthy(); + expect(local!.label).toBe("local-test-model-123 (LM Studio)"); + expect(local!.description).toBe("lmstudio (local)"); + }); }); // --------------------------------------------------------------------------- diff --git a/apps/desktop/src/renderer/lib/modelOptions.ts b/apps/desktop/src/renderer/lib/modelOptions.ts index 49b3d4fa5..b059a6fdb 100644 --- a/apps/desktop/src/renderer/lib/modelOptions.ts +++ b/apps/desktop/src/renderer/lib/modelOptions.ts @@ -1,10 +1,42 @@ -import type { AiModelDescriptor, AiSettingsStatus, ModelId } from "../../shared/types"; -import { MODEL_REGISTRY, getModelById, type ModelDescriptor } from "../../shared/modelRegistry"; +import type { AiModelDescriptor, AiRuntimeConnectionStatus, AiSettingsStatus, ModelId } from "../../shared/types"; +import { + LOCAL_PROVIDER_LABELS, + MODEL_REGISTRY, + getLocalModelIdTail, + getModelById, + isLocalProviderFamily, + parseLocalProviderFromModelId, + type ModelDescriptor, +} from "../../shared/modelRegistry"; function normalizeAuthProvider(provider: string | undefined): string { return String(provider ?? "").trim().toLowerCase(); } +function getLocalModelLabel(modelId: string): string { + const provider = parseLocalProviderFromModelId(modelId); + if (!provider) return modelId; + const tail = getLocalModelIdTail(modelId, provider); + return tail.length ? tail : modelId; +} + +function buildFallbackModelOption(modelId: string): AiModelDescriptor { + const provider = parseLocalProviderFromModelId(modelId); + if (provider) { + const pLabel = LOCAL_PROVIDER_LABELS[provider]; + return { + id: modelId, + label: getLocalModelLabel(modelId), + description: `${pLabel} local model`, + }; + } + return { + id: modelId, + label: modelId, + description: "Unknown model", + }; +} + export function describeModelSource(descriptor: ModelDescriptor): string { if (descriptor.authTypes.includes("local")) { return "local"; @@ -41,6 +73,39 @@ function addKnownModelIds(ids: Set, family: string, includeCliWrapped: } } +function addAvailableModelIdsByPrefix( + ids: Set, + availableModelIds: readonly ModelId[] | undefined, + prefix: string, +) { + if (!availableModelIds?.length) return; + const normalizedPrefix = prefix.trim(); + if (!normalizedPrefix.length) return; + for (const rawId of availableModelIds) { + const id = String(rawId ?? "").trim(); + if (id.startsWith(normalizedPrefix)) { + ids.add(id as ModelId); + } + } +} + +function hasDynamicLocalModelIdsForProvider( + provider: string, + availableModelIds: readonly ModelId[] | undefined, + runtimeConnections: Record | undefined, +): boolean { + const normalizedProvider = normalizeAuthProvider(provider); + if (!isLocalProviderFamily(normalizedProvider)) { + return false; + } + const prefix = `${normalizedProvider}/`; + const loadedModelIds = runtimeConnections?.[normalizedProvider]?.loadedModelIds; + return Boolean( + availableModelIds?.some((rawId) => String(rawId ?? "").trim().startsWith(prefix)) + || loadedModelIds?.some((rawId) => String(rawId ?? "").trim().startsWith(prefix)), + ); +} + export interface DeriveModelOptions { /** Include cursor/* models in the result. Defaults to `false`. */ includeCursor?: boolean; @@ -53,10 +118,11 @@ export function deriveConfiguredModelIds( if (!status) return []; const { includeCursor = false } = options ?? {}; + const runtimeConnections = (status as { runtimeConnections?: Record } | null | undefined)?.runtimeConnections; // Derive available models from detectedAuth. For Cursor CLI, merge in // `status.availableModelIds` entries under `cursor/*` (main lists them after - // `agent models`); other providers still use registry + auth only. + // `agent models`); local runtimes also merge in discovered loaded models. const ids = new Set(); for (const auth of status.detectedAuth ?? []) { @@ -81,8 +147,29 @@ export function deriveConfiguredModelIds( if (auth.type === "local") { const provider = normalizeAuthProvider(auth.provider); - if (provider.length) addKnownModelIds(ids, provider, false); + if (provider.length) { + if (!hasDynamicLocalModelIdsForProvider(provider, status.availableModelIds, runtimeConnections)) { + addKnownModelIds(ids, provider, false); + } + addAvailableModelIdsByPrefix(ids, status.availableModelIds, `${provider}/`); + addAvailableModelIdsByPrefix(ids, runtimeConnections?.[provider]?.loadedModelIds, `${provider}/`); + } + } + } + + for (const [provider, connection] of Object.entries(runtimeConnections ?? {})) { + const normalizedProvider = normalizeAuthProvider(provider); + if (!isLocalProviderFamily(normalizedProvider)) { + continue; + } + if (connection == null || connection.runtimeAvailable !== true) { + continue; + } + if (!hasDynamicLocalModelIdsForProvider(normalizedProvider, status.availableModelIds, runtimeConnections)) { + addKnownModelIds(ids, normalizedProvider, false); } + addAvailableModelIdsByPrefix(ids, connection.loadedModelIds, `${normalizedProvider}/`); + addAvailableModelIdsByPrefix(ids, status.availableModelIds, `${normalizedProvider}/`); } if (includeCursor) { @@ -115,7 +202,11 @@ export function deriveConfiguredModelOptions( ): AiModelDescriptor[] { return deriveConfiguredModelIds(status, options).flatMap((modelId) => { const descriptor = getModelById(modelId); - return descriptor ? [descriptorToModelOption(descriptor)] : []; + return descriptor + ? [descriptorToModelOption(descriptor)] + : parseLocalProviderFromModelId(modelId) + ? [buildFallbackModelOption(modelId)] + : []; }); } @@ -126,6 +217,7 @@ export function includeSelectedModelOption( const modelId = String(selectedModelId ?? "").trim(); if (!modelId.length || options.some((option) => option.id === modelId)) return options; const descriptor = getModelById(modelId); - if (!descriptor) return options; - return [descriptorToModelOption(descriptor), ...options]; + if (descriptor) return [descriptorToModelOption(descriptor), ...options]; + if (parseLocalProviderFromModelId(modelId)) return [buildFallbackModelOption(modelId), ...options]; + return options; } diff --git a/apps/desktop/src/shared/modelRegistry.test.ts b/apps/desktop/src/shared/modelRegistry.test.ts index 4fdbf920a..fa446ba97 100644 --- a/apps/desktop/src/shared/modelRegistry.test.ts +++ b/apps/desktop/src/shared/modelRegistry.test.ts @@ -4,6 +4,7 @@ import { getAvailableModels, getDefaultModelDescriptor, getModelById, + getModelDescriptorForPermissionMode, getRuntimeModelRefForDescriptor, listModelDescriptorsForProvider, MODEL_REGISTRY, @@ -76,6 +77,19 @@ describe("modelRegistry", () => { expect(resolveModelDescriptor("nonexistent/model-id")).toBeUndefined(); }); + it("getModelDescriptorForPermissionMode matches getModelById for known locals", () => { + const id = "ollama/qwen2.5-coder:32b"; + expect(getModelDescriptorForPermissionMode(id)).toEqual(getModelById(id)); + }); + + it("getModelDescriptorForPermissionMode yields guarded local for ollama/auto when getModelById is undefined", () => { + expect(getModelById("ollama/auto")).toBeUndefined(); + const perm = getModelDescriptorForPermissionMode("ollama/auto"); + expect(perm?.family).toBe("ollama"); + expect(perm?.harnessProfile).toBe("guarded"); + expect(perm?.authTypes).toContain("local"); + }); + it("resolves gpt-5.4 shortId to the API-key variant, not the codex variant", () => { const resolved = resolveModelAlias("gpt-5.4"); expect(resolved).toBeTruthy(); diff --git a/apps/desktop/src/shared/modelRegistry.ts b/apps/desktop/src/shared/modelRegistry.ts index 18f5666c7..508d4bf44 100644 --- a/apps/desktop/src/shared/modelRegistry.ts +++ b/apps/desktop/src/shared/modelRegistry.ts @@ -26,6 +26,8 @@ export type ModelCapabilities = { streaming: boolean; }; +export type LocalModelHarnessProfile = "verified" | "guarded" | "read_only"; + export type ModelDescriptor = { id: string; shortId: string; @@ -49,6 +51,21 @@ export type ModelDescriptor = { outputPricePer1M?: number; /** Curated cost tier for UI display (missions model selector) */ costTier?: "low" | "medium" | "high" | "very_high"; + /** ADE-owned safety/tooling profile for local and experimental models. */ + harnessProfile?: LocalModelHarnessProfile; + /** Source of runtime-discovered descriptors for debugging and UI hints. */ + discoverySource?: "lmstudio-rest" | "lmstudio-openai" | "ollama" | "vllm"; +}; + +export type DynamicLocalModelDescriptorOptions = { + displayName?: string; + contextWindow?: number; + maxOutputTokens?: number; + capabilities?: Partial; + reasoningTiers?: string[]; + aliases?: string[]; + harnessProfile?: LocalModelHarnessProfile; + discoverySource?: ModelDescriptor["discoverySource"]; }; export type WorkerExecutionPath = "cli" | "api" | "local"; @@ -65,7 +82,8 @@ export function isModelProviderGroup(value: string | null | undefined): value is const ALL_CAPS: ModelCapabilities = { tools: true, vision: true, reasoning: true, streaming: true }; const NO_REASONING: ModelCapabilities = { tools: true, vision: true, reasoning: false, streaming: true }; const BASIC_CAPS: ModelCapabilities = { tools: true, vision: false, reasoning: false, streaming: true }; -const LOCAL_PROVIDER_LABELS: Record = { +/** Human-readable names for Ollama / LM Studio / vLLM (shared across main, renderer, and MCP). */ +export const LOCAL_PROVIDER_LABELS: Record = { ollama: "Ollama", lmstudio: "LM Studio", vllm: "vLLM", @@ -638,6 +656,7 @@ export const MODEL_REGISTRY: ModelDescriptor[] = [ color: "#71717A", sdkProvider: "@ai-sdk/openai-compatible", sdkModelId: "auto", + harnessProfile: "guarded", isCliWrapped: false, }, { @@ -652,6 +671,7 @@ export const MODEL_REGISTRY: ModelDescriptor[] = [ color: "#64748B", sdkProvider: "@ai-sdk/openai-compatible", sdkModelId: "auto", + harnessProfile: "guarded", isCliWrapped: false, }, { @@ -666,6 +686,7 @@ export const MODEL_REGISTRY: ModelDescriptor[] = [ color: "#475569", sdkProvider: "@ai-sdk/openai-compatible", sdkModelId: "auto", + harnessProfile: "guarded", isCliWrapped: false, }, ]; @@ -678,6 +699,8 @@ let byId = new Map(); let byShortId = new Map(); let byAlias = new Map(); let bySdkModelId = new Map(); +let dynamicLocalById = new Map(); +let dynamicLocalByAlias = new Map(); function rebuildIndexes() { byId = new Map(); @@ -731,10 +754,38 @@ export function validateModelRegistry(models: ModelDescriptor[] = MODEL_REGISTRY validateModelRegistry(); rebuildIndexes(); -function isLocalProviderFamily(value: string): value is LocalProviderFamily { +export function isLocalProviderFamily(value: string): value is LocalProviderFamily { return value === "ollama" || value === "lmstudio" || value === "vllm"; } +/** First path segment of `provider/modelId` when it is a known local provider. */ +export function parseLocalProviderFromModelId(modelId: string): LocalProviderFamily | null { + const provider = String(modelId ?? "").trim().split("/", 1)[0]?.toLowerCase() ?? ""; + return isLocalProviderFamily(provider) ? provider : null; +} + +/** Model name segment after `provider/` for local refs; empty string if missing. */ +export function getLocalModelIdTail(modelId: string, provider: LocalProviderFamily): string { + return String(modelId ?? "").trim().slice(provider.length + 1).trim(); +} + +/** + * Descriptor for unified permission / harness decisions when the registry has no row yet. + * `getModelById` returns undefined for refs such as `ollama/auto`; this still returns a + * guarded local descriptor so the UI matches main-process harness behavior. + */ +export function getModelDescriptorForPermissionMode(modelId: string): ModelDescriptor | undefined { + const resolved = getModelById(modelId); + if (resolved) return resolved; + const provider = parseLocalProviderFromModelId(modelId); + if (!provider) return undefined; + const tail = getLocalModelIdTail(modelId, provider); + if (!tail.length || tail === "auto") { + return createDynamicLocalModelDescriptor(provider, "auto", { harnessProfile: "guarded" }); + } + return createDynamicLocalModelDescriptor(provider, tail); +} + function parseDynamicLocalModelRef(modelRef: string): { provider: LocalProviderFamily; modelId: string } | null { const normalized = modelRef.trim(); if (!normalized.length) return null; @@ -754,25 +805,63 @@ function toDynamicLocalDisplayName(provider: LocalProviderFamily, modelId: strin export function createDynamicLocalModelDescriptor( provider: LocalProviderFamily, modelId: string, + options?: DynamicLocalModelDescriptorOptions, ): ModelDescriptor { const normalizedModelId = modelId.trim(); + const displayName = options?.displayName?.trim() || toDynamicLocalDisplayName(provider, normalizedModelId); + const capabilities: ModelCapabilities = { + ...BASIC_CAPS, + ...(options?.capabilities ?? {}), + }; + const aliases = [ + `${provider}:${normalizedModelId}`, + ...(options?.aliases ?? []), + ].filter((value, index, list) => { + const normalized = value.trim(); + return normalized.length > 0 && list.findIndex((entry) => entry.trim().toLowerCase() === normalized.toLowerCase()) === index; + }); return { id: `${provider}/${normalizedModelId}`, shortId: normalizedModelId, - displayName: toDynamicLocalDisplayName(provider, normalizedModelId), + displayName, family: provider, authTypes: ["local"], - contextWindow: 128_000, - maxOutputTokens: 8_192, - capabilities: { ...BASIC_CAPS }, + contextWindow: options?.contextWindow ?? 128_000, + maxOutputTokens: options?.maxOutputTokens ?? 8_192, + capabilities, color: LOCAL_PROVIDER_COLORS[provider], sdkProvider: "@ai-sdk/openai-compatible", sdkModelId: normalizedModelId, - aliases: [`${provider}:${normalizedModelId}`], + ...(options?.reasoningTiers?.length ? { reasoningTiers: [...options.reasoningTiers] } : {}), + aliases, isCliWrapped: false, + harnessProfile: options?.harnessProfile ?? "guarded", + ...(options?.discoverySource ? { discoverySource: options.discoverySource } : {}), }; } +function isDynamicLocalDescriptor(descriptor: ModelDescriptor): boolean { + return descriptor.authTypes.includes("local") && !byId.has(descriptor.id); +} + +export function replaceDynamicLocalModelDescriptors(descriptors: ModelDescriptor[]): void { + dynamicLocalById = new Map(); + dynamicLocalByAlias = new Map(); + + for (const descriptor of descriptors) { + if (!isDynamicLocalDescriptor(descriptor)) continue; + dynamicLocalById.set(descriptor.id, descriptor); + for (const alias of descriptor.aliases ?? []) { + const normalized = alias.trim().toLowerCase(); + if (normalized.length) dynamicLocalByAlias.set(normalized, descriptor); + } + } +} + +export function getDynamicLocalModelDescriptors(): ModelDescriptor[] { + return [...dynamicLocalById.values()]; +} + export function getLocalProviderDefaultEndpoint(provider: LocalProviderFamily): string { return LOCAL_PROVIDER_ENDPOINTS[provider]; } @@ -896,6 +985,8 @@ export function sortCursorCliDescriptorsForPicker(descriptors: ModelDescriptor[] export function getModelById(id: string): ModelDescriptor | undefined { const cached = byId.get(id); if (cached) return cached; + const dynamic = dynamicLocalById.get(id); + if (dynamic) return dynamic; const local = parseDynamicLocalModelRef(id); if (local) return createDynamicLocalModelDescriptor(local.provider, local.modelId); const cursor = parseDynamicCursorModelRef(id); @@ -947,12 +1038,20 @@ export function getAvailableModels( return false; }); - return MODEL_REGISTRY.filter((model) => !model.deprecated && hasAuthForModel(model)); + const staticModels = MODEL_REGISTRY.filter((model) => !model.deprecated && hasAuthForModel(model)); + const dynamicLocalModels = getDynamicLocalModelDescriptors().filter((model) => hasAuthForModel(model)); + if (!dynamicLocalModels.length) return staticModels; + + const providersWithDynamicLocals = new Set(dynamicLocalModels.map((model) => model.family)); + const filteredStatic = staticModels.filter( + (model) => !(model.authTypes.includes("local") && providersWithDynamicLocals.has(model.family)), + ); + return [...filteredStatic, ...dynamicLocalModels]; } export function resolveModelAlias(alias: string): ModelDescriptor | undefined { const normalized = alias.trim().toLowerCase(); - return byId.get(normalized) ?? byShortId.get(normalized) ?? byAlias.get(normalized) ?? undefined; + return byId.get(normalized) ?? byShortId.get(normalized) ?? byAlias.get(normalized) ?? dynamicLocalByAlias.get(normalized) ?? undefined; } export function resolveModelDescriptor(modelRef: string): ModelDescriptor | undefined { diff --git a/apps/desktop/src/shared/types/config.ts b/apps/desktop/src/shared/types/config.ts index c7e9667b2..3f00e461a 100644 --- a/apps/desktop/src/shared/types/config.ts +++ b/apps/desktop/src/shared/types/config.ts @@ -9,6 +9,7 @@ import type { MissionExecutionPolicy, MissionPermissionConfig, MissionProviderPe import type { ExternalMcpMissionSelection } from "./externalMcp"; import type { MissionModelConfig, ModelConfig } from "./models"; import type { LinearSyncConfig } from "./linearSync"; +import type { LocalProviderFamily } from "../modelRegistry"; // Backward compatible with earlier configs that used `on_crash`. export type ProcessRestartPolicy = "never" | "on-failure" | "always" | "on_crash"; @@ -827,8 +828,10 @@ export type AiDetectedAuth = { cli?: "claude" | "codex" | "cursor"; provider?: string; source?: "config" | "env" | "store"; + endpointSource?: "auto" | "config"; path?: string; endpoint?: string; + preferredModelId?: ModelId | null; authenticated?: boolean; verified?: boolean; }; @@ -873,6 +876,43 @@ export type AiApiKeyVerificationResult = { verifiedAt: string; }; +export type AiLocalProviderConfig = { + enabled?: boolean; + endpoint?: string; + autoDetect?: boolean; + preferredModelId?: ModelId | null; +}; + +export type AiLocalProviderConfigs = Partial>; + +export type AiRuntimeConnectionHealth = + | "ready" + | "reachable" + | "reachable_no_models" + | "not_configured" + | "unreachable"; + +export type AiRuntimeConnectionKind = "cli" | "api-key" | "openrouter" | "local"; + +export type AiRuntimeConnectionStatus = { + provider: string; + label: string; + kind: AiRuntimeConnectionKind; + configured: boolean; + authAvailable: boolean; + runtimeDetected: boolean; + runtimeAvailable: boolean; + health: AiRuntimeConnectionHealth; + source?: "config" | "env" | "store" | "auto"; + path?: string | null; + endpoint?: string | null; + blocker: string | null; + loadedModelIds?: ModelId[]; + lastCheckedAt: string; +}; + +export type AiRuntimeConnections = Record; + export type AiSettingsStatus = { mode: "guest" | "subscription"; availableProviders: { @@ -888,6 +928,7 @@ export type AiSettingsStatus = { features: AiFeatureUsageRow[]; detectedAuth?: AiDetectedAuth[]; providerConnections?: AiProviderConnections; + runtimeConnections?: AiRuntimeConnections; availableModelIds?: ModelId[]; apiKeyStore?: { secureStorageAvailable: boolean; @@ -1034,6 +1075,7 @@ export type AiConfig = { // New unified fields defaultModel?: ModelId; apiKeys?: Record; + localProviders?: AiLocalProviderConfigs; workerSafety?: WorkerSafetyPolicy; mcpServers?: Record; /** Per-feature model overrides, e.g. { mission_planning: "claude-sonnet-4-6" } */ @@ -1059,6 +1101,7 @@ export type AiIntegrationStatus = { // New unified fields detectedAuth?: AiDetectedAuth[]; providerConnections?: AiProviderConnections; + runtimeConnections?: AiRuntimeConnections; availableModelIds?: ModelId[]; }; diff --git a/apps/desktop/src/test/setup.ts b/apps/desktop/src/test/setup.ts index beb1957e7..5f3b472be 100644 --- a/apps/desktop/src/test/setup.ts +++ b/apps/desktop/src/test/setup.ts @@ -1,6 +1,78 @@ import fs from "node:fs"; import os from "node:os"; import path from "node:path"; +import { afterAll } from "vitest"; + +type TestTempTrackerState = { + installed: boolean; + trackedDirs: Set; + originalMkdtempSync?: typeof fs.mkdtempSync; + originalPromisesMkdtemp?: typeof fs.promises.mkdtemp; +}; + +const TEST_TEMP_TRACKER_KEY = Symbol.for("ade.desktop.testTempTracker"); +const testTempRoot = path.resolve(os.tmpdir()); + +function getTestTempTrackerState(): TestTempTrackerState { + const existing = (globalThis as Record)[TEST_TEMP_TRACKER_KEY]; + if (existing) return existing as TestTempTrackerState; + const created: TestTempTrackerState = { + installed: false, + trackedDirs: new Set(), + }; + (globalThis as Record)[TEST_TEMP_TRACKER_KEY] = created; + return created; +} + +function shouldTrackTempDir(dirPath: string): boolean { + const resolved = path.resolve(dirPath); + const baseName = path.basename(resolved); + return resolved.startsWith(`${testTempRoot}${path.sep}`) && baseName.startsWith("ade-"); +} + +function cleanupTrackedTempDirs(): void { + const state = getTestTempTrackerState(); + const targets = [...state.trackedDirs].sort((left, right) => right.length - left.length); + for (const target of targets) { + try { + fs.rmSync(target, { recursive: true, force: true }); + state.trackedDirs.delete(target); + } catch { + // Best-effort cleanup only for test temp roots. + } + } +} + +function installTrackedTempCleanup(): void { + const state = getTestTempTrackerState(); + if (state.installed) return; + state.installed = true; + state.originalMkdtempSync = fs.mkdtempSync.bind(fs); + state.originalPromisesMkdtemp = fs.promises.mkdtemp.bind(fs.promises); + + fs.mkdtempSync = ((prefix: string, options?: Parameters[1]) => { + const created = state.originalMkdtempSync!(prefix, options); + if (shouldTrackTempDir(created)) { + state.trackedDirs.add(path.resolve(created)); + } + return created; + }) as typeof fs.mkdtempSync; + + fs.promises.mkdtemp = (async (prefix: string, options?: Parameters[1]) => { + const created = await state.originalPromisesMkdtemp!(prefix, options); + if (shouldTrackTempDir(created)) { + state.trackedDirs.add(path.resolve(created)); + } + return created; + }) as typeof fs.promises.mkdtemp; + + process.once("exit", cleanupTrackedTempDirs); +} + +installTrackedTempCleanup(); +afterAll(() => { + cleanupTrackedTempDirs(); +}); const claudeConfigDir = path.join(os.tmpdir(), "ade-vitest-claude-config"); diff --git a/apps/mcp-server/src/test/setup.ts b/apps/mcp-server/src/test/setup.ts new file mode 100644 index 000000000..2df0e44b9 --- /dev/null +++ b/apps/mcp-server/src/test/setup.ts @@ -0,0 +1,75 @@ +import fs from "node:fs"; +import os from "node:os"; +import path from "node:path"; +import { afterAll } from "vitest"; + +type TestTempTrackerState = { + installed: boolean; + trackedDirs: Set; + originalMkdtempSync?: typeof fs.mkdtempSync; + originalPromisesMkdtemp?: typeof fs.promises.mkdtemp; +}; + +const TEST_TEMP_TRACKER_KEY = Symbol.for("ade.mcp.testTempTracker"); +const testTempRoot = path.resolve(os.tmpdir()); + +function getTestTempTrackerState(): TestTempTrackerState { + const existing = (globalThis as Record)[TEST_TEMP_TRACKER_KEY]; + if (existing) return existing as TestTempTrackerState; + const created: TestTempTrackerState = { + installed: false, + trackedDirs: new Set(), + }; + (globalThis as Record)[TEST_TEMP_TRACKER_KEY] = created; + return created; +} + +function shouldTrackTempDir(dirPath: string): boolean { + const resolved = path.resolve(dirPath); + const baseName = path.basename(resolved); + return resolved.startsWith(`${testTempRoot}${path.sep}`) && baseName.startsWith("ade-"); +} + +function cleanupTrackedTempDirs(): void { + const state = getTestTempTrackerState(); + const targets = [...state.trackedDirs].sort((left, right) => right.length - left.length); + for (const target of targets) { + try { + fs.rmSync(target, { recursive: true, force: true }); + state.trackedDirs.delete(target); + } catch { + // Best-effort cleanup only for test temp roots. + } + } +} + +function installTrackedTempCleanup(): void { + const state = getTestTempTrackerState(); + if (state.installed) return; + state.installed = true; + state.originalMkdtempSync = fs.mkdtempSync.bind(fs); + state.originalPromisesMkdtemp = fs.promises.mkdtemp.bind(fs.promises); + + fs.mkdtempSync = ((prefix: string, options?: Parameters[1]) => { + const created = state.originalMkdtempSync!(prefix, options); + if (shouldTrackTempDir(created)) { + state.trackedDirs.add(path.resolve(created)); + } + return created; + }) as typeof fs.mkdtempSync; + + fs.promises.mkdtemp = (async (prefix: string, options?: Parameters[1]) => { + const created = await state.originalPromisesMkdtemp!(prefix, options); + if (shouldTrackTempDir(created)) { + state.trackedDirs.add(path.resolve(created)); + } + return created; + }) as typeof fs.promises.mkdtemp; + + process.once("exit", cleanupTrackedTempDirs); +} + +installTrackedTempCleanup(); +afterAll(() => { + cleanupTrackedTempDirs(); +}); diff --git a/apps/mcp-server/vitest.config.ts b/apps/mcp-server/vitest.config.ts index 7ad107513..8242fa108 100644 --- a/apps/mcp-server/vitest.config.ts +++ b/apps/mcp-server/vitest.config.ts @@ -4,6 +4,7 @@ export default defineConfig({ test: { environment: "node", include: ["src/**/*.test.ts"], + setupFiles: ["src/test/setup.ts"], coverage: { provider: "v8", reporter: ["text", "lcov"],