From 9f5bcd36723334bd768b59ffa47865a0557c86d6 Mon Sep 17 00:00:00 2001 From: Arul Sharma <31745423+arul28@users.noreply.github.com> Date: Thu, 2 Apr 2026 18:13:07 -0400 Subject: [PATCH 1/3] Add local AI runtime detection & cleanup Implement local provider discovery and inspection for ollama, lmstudio and vllm: normalize project config for local providers, probe endpoints, inspect loaded models, infer capabilities and harness profiles, and cache results with TTL and explicit cache reset. Integrate discovered local models into AI integration: replace dynamic local descriptors, build runtimeConnections (cli/api-key/local/openrouter), expose runtimeConnections in the AiIntegrationStatus, and reset local detection when forcing refresh. Improve provider resolution to honor saved preferred local model IDs and to require explicit selection when multiple models are loaded. Add tests covering fallback endpoint detection and preferred-model behavior. Add temporary artifact cleanup: call cleanupStaleTempArtifacts on startup and ensure screenshot tool removes temp dirs after use. Misc: refactor authDetector and localModelDiscovery APIs, add inspection cache clearing, and wire changes through services and tests. --- apps/desktop/src/main/main.ts | 6 + .../main/services/ai/aiIntegrationService.ts | 339 +++++++++++- .../src/main/services/ai/authDetector.test.ts | 41 ++ .../src/main/services/ai/authDetector.ts | 187 +++++-- .../main/services/ai/localModelDiscovery.ts | 363 +++++++++++-- .../main/services/ai/providerResolver.test.ts | 32 +- .../src/main/services/ai/providerResolver.ts | 60 ++- .../main/services/ai/tools/workflowTools.ts | 18 +- .../main/services/chat/agentChatService.ts | 180 ++++++- .../services/config/projectConfigService.ts | 42 ++ .../src/main/services/ipc/registerIpc.ts | 2 + .../unifiedOrchestratorAdapter.test.ts | 126 +++++ .../unifiedOrchestratorAdapter.ts | 17 +- .../runtime/tempCleanupService.test.ts | 105 ++++ .../services/runtime/tempCleanupService.ts | 129 +++++ .../components/chat/AgentChatPane.tsx | 200 +++++++- .../onboarding/ProjectSetupPage.tsx | 4 +- .../settings/ProvidersSection.test.tsx | 40 +- .../components/settings/ProvidersSection.tsx | 484 ++++++++++++++++-- .../shared/UnifiedModelSelector.tsx | 47 +- .../src/renderer/lib/modelOptions.test.ts | 59 ++- apps/desktop/src/renderer/lib/modelOptions.ts | 112 +++- apps/desktop/src/shared/modelRegistry.ts | 84 ++- apps/desktop/src/shared/types/config.ts | 43 ++ apps/desktop/src/test/setup.ts | 72 +++ apps/mcp-server/src/test/setup.ts | 75 +++ apps/mcp-server/vitest.config.ts | 1 + 27 files changed, 2612 insertions(+), 256 deletions(-) create mode 100644 apps/desktop/src/main/services/runtime/tempCleanupService.test.ts create mode 100644 apps/desktop/src/main/services/runtime/tempCleanupService.ts create mode 100644 apps/mcp-server/src/test/setup.ts diff --git a/apps/desktop/src/main/main.ts b/apps/desktop/src/main/main.ts index 162752d69..784312c21 100644 --- a/apps/desktop/src/main/main.ts +++ b/apps/desktop/src/main/main.ts @@ -108,6 +108,7 @@ import { createExternalConnectionAuthService } from "./services/externalMcp/exte import { createComputerUseArtifactBrokerService } from "./services/computerUse/computerUseArtifactBrokerService"; import { createSyncService } from "./services/sync/syncService"; import { createAutoUpdateService } from "./services/updates/autoUpdateService"; +import { cleanupStaleTempArtifacts } from "./services/runtime/tempCleanupService"; import type { Logger } from "./services/logging/logger"; /** @@ -1454,6 +1455,7 @@ app.whenReady().then(async () => { }, }); agentChatServiceRef = agentChatService; + agentChatService.cleanupStaleAttachments(); // Wire agentChatService into prService for integration resolution prService.setAgentChatService(agentChatService); @@ -2772,6 +2774,10 @@ app.whenReady().then(async () => { // --- Auto-update service (global, not per-project) --- const updateLogger = createFileLogger(path.join(app.getPath("userData"), "ade-update.jsonl")); + cleanupStaleTempArtifacts({ + tempRoot: app.getPath("temp"), + logger: updateLogger, + }); const autoUpdateService = createAutoUpdateService(updateLogger); autoUpdateService.onUpdateAvailable((info) => { BrowserWindow.getAllWindows().forEach((win) => { diff --git a/apps/desktop/src/main/services/ai/aiIntegrationService.ts b/apps/desktop/src/main/services/ai/aiIntegrationService.ts index e512bee52..f1efbfb8e 100644 --- a/apps/desktop/src/main/services/ai/aiIntegrationService.ts +++ b/apps/desktop/src/main/services/ai/aiIntegrationService.ts @@ -3,18 +3,33 @@ import type { Logger } from "../logging/logger"; import type { AdeDb } from "../state/kvDb"; import type { createProjectConfigService } from "../config/projectConfigService"; import type { AgentModelDescriptor, AgentProvider, ExecutorOpts } from "./agentExecutor"; -import type { AiApiKeyVerificationResult, AiProviderConnections } from "../../../shared/types"; +import type { + AiApiKeyVerificationResult, + AiLocalProviderConfigs, + AiProviderConnections, + AiRuntimeConnections, + AiRuntimeConnectionStatus, +} from "../../../shared/types"; import { createDynamicLocalModelDescriptor, getDefaultModelDescriptor, getModelById, getAvailableModels, + getLocalProviderDefaultEndpoint, listModelDescriptorsForProvider, - MODEL_REGISTRY, + replaceDynamicLocalModelDescriptors, resolveModelAlias, enrichModelRegistry, + type LocalProviderFamily, } from "../../../shared/modelRegistry"; -import { detectAllAuth, getCachedCliAuthStatuses, verifyProviderApiKey, type DetectedAuth, type CliAuthStatus } from "./authDetector"; +import { + detectAllAuth, + getCachedCliAuthStatuses, + resetLocalProviderDetectionCache, + verifyProviderApiKey, + type DetectedAuth, + type CliAuthStatus, +} from "./authDetector"; import { executeUnified, resumeUnified } from "./unifiedExecutor"; import { initialize as initModelsDevService } from "./modelsDevService"; import { updateModelPricing } from "../../../shared/modelProfiles"; @@ -22,7 +37,7 @@ import { isRecord } from "../shared/utils"; import { getApiKeyStoreStatus } from "./apiKeyStore"; import type { createMemoryService } from "../memory/memoryService"; import type { CompactionFlushService } from "../memory/compactionFlushService"; -import { discoverLocalModels } from "./localModelDiscovery"; +import { discoverLocalModels, inspectLocalProvider } from "./localModelDiscovery"; import { discoverCursorCliModelDescriptors, clearCursorCliModelsCache } from "../chat/cursorModelsDiscovery"; import { resolveCursorAgentExecutable } from "./cursorAgentExecutable"; import { buildProviderConnections } from "./providerConnectionStatus"; @@ -72,12 +87,15 @@ export type AiIntegrationStatus = { cli?: "claude" | "codex" | "cursor"; provider?: string; source?: "config" | "env" | "store"; + endpointSource?: "auto" | "config"; path?: string; endpoint?: string; + preferredModelId?: string | null; authenticated?: boolean; verified?: boolean; }>; providerConnections?: AiProviderConnections; + runtimeConnections?: AiRuntimeConnections; availableModelIds?: string[]; apiKeyStore?: { secureStorageAvailable: boolean; @@ -259,6 +277,33 @@ function extractConfiguredApiKeys(snapshot: ReturnType["get"]>, +): AiLocalProviderConfigs { + const aiConfig = extractAiConfig(snapshot); + const localProvidersRaw = isRecord(aiConfig.localProviders) ? aiConfig.localProviders : {}; + const out: AiLocalProviderConfigs = {}; + + for (const provider of ["ollama", "lmstudio", "vllm"] as const) { + const raw = isRecord(localProvidersRaw[provider]) ? localProvidersRaw[provider] : null; + if (!raw) continue; + const entry: NonNullable = {}; + if (typeof raw.enabled === "boolean") entry.enabled = raw.enabled; + if (typeof raw.autoDetect === "boolean") entry.autoDetect = raw.autoDetect; + if (typeof raw.endpoint === "string" && raw.endpoint.trim().length > 0) { + entry.endpoint = raw.endpoint.trim(); + } + if (raw.preferredModelId === null) { + entry.preferredModelId = null; + } else if (typeof raw.preferredModelId === "string" && raw.preferredModelId.trim().length > 0) { + entry.preferredModelId = raw.preferredModelId.trim(); + } + if (Object.keys(entry).length) out[provider] = entry; + } + + return out; +} + function toCliAvailability(auth: DetectedAuth[]): { claude: boolean; codex: boolean; cursor: boolean } { return { claude: auth.some((entry) => entry.type === "cli-subscription" && entry.cli === "claude"), @@ -308,6 +353,8 @@ function redactDetectedAuth( type: entry.type, provider: entry.provider, endpoint: entry.endpoint, + endpointSource: entry.endpointSource, + preferredModelId: entry.preferredModelId ?? null, }; }); @@ -333,6 +380,244 @@ function redactDetectedAuth( return redacted; } +const LOCAL_PROVIDER_LABELS: Record = { + ollama: "Ollama", + lmstudio: "LM Studio", + vllm: "vLLM", +}; + +function apiProviderLabel(provider: string): string { + const labels: Record = { + anthropic: "Anthropic", + openai: "OpenAI", + google: "Google AI", + mistral: "Mistral", + deepseek: "DeepSeek", + xai: "xAI", + groq: "Groq", + together: "Together AI", + openrouter: "OpenRouter", + ollama: "Ollama", + lmstudio: "LM Studio", + vllm: "vLLM", + }; + return labels[provider] ?? provider; +} + +function toCliRuntimeConnection(status: NonNullable[keyof AiProviderConnections]): AiRuntimeConnectionStatus { + const source = status.sources.find((entry) => entry.detected && entry.kind === "local-credentials")?.source; + return { + provider: status.provider, + label: apiProviderLabel(status.provider), + kind: "cli", + configured: status.authAvailable || status.runtimeDetected, + authAvailable: status.authAvailable, + runtimeDetected: status.runtimeDetected, + runtimeAvailable: status.runtimeAvailable, + health: status.runtimeAvailable ? "ready" : status.runtimeDetected ? "reachable" : "not_configured", + ...(source ? { source: source === "cursor-env" ? "env" : "store" as const } : {}), + path: status.path, + blocker: status.blocker, + lastCheckedAt: status.lastCheckedAt, + }; +} + +function normalizeConfiguredLocalProvider( + configs: AiLocalProviderConfigs, + provider: LocalProviderFamily, +): { + enabled: boolean; + endpoint?: string; + autoDetect: boolean; + preferredModelId?: string | null; +} { + const entry = configs[provider]; + return { + enabled: entry?.enabled ?? true, + ...(typeof entry?.endpoint === "string" && entry.endpoint.trim().length + ? { endpoint: entry.endpoint.trim() } + : {}), + autoDetect: entry?.autoDetect ?? true, + preferredModelId: entry?.preferredModelId ?? null, + }; +} + +function createLocalRuntimeConnectionFromInspection(args: { + provider: LocalProviderFamily; + endpoint: string; + source: "config" | "auto"; + inspection: Awaited>; + checkedAt: string; +}): AiRuntimeConnectionStatus { + const label = LOCAL_PROVIDER_LABELS[args.provider]; + const loadedModelIds = args.inspection.loadedModels.map((model) => `${args.provider}/${model.modelId}`); + let blocker: string | null = null; + if (args.inspection.health === "reachable_no_models") { + blocker = `${label} is reachable, but no models are currently loaded.`; + } else if (args.inspection.health === "unreachable") { + blocker = `${label} did not respond at ${args.endpoint}.`; + } + return { + provider: args.provider, + label, + kind: "local", + configured: true, + authAvailable: args.inspection.health === "ready", + runtimeDetected: args.inspection.reachable, + runtimeAvailable: args.inspection.health === "ready", + health: args.inspection.health, + source: args.source, + endpoint: args.endpoint, + blocker, + ...(loadedModelIds.length ? { loadedModelIds } : {}), + lastCheckedAt: args.checkedAt, + }; +} + +async function buildLocalRuntimeConnection(args: { + provider: LocalProviderFamily; + configuredLocalProviders: AiLocalProviderConfigs; + auth: DetectedAuth[]; + checkedAt: string; +}): Promise { + const providerConfig = normalizeConfiguredLocalProvider(args.configuredLocalProviders, args.provider); + const label = LOCAL_PROVIDER_LABELS[args.provider]; + if (!providerConfig.enabled) { + return { + provider: args.provider, + label, + kind: "local", + configured: false, + authAvailable: false, + runtimeDetected: false, + runtimeAvailable: false, + health: "not_configured", + blocker: `${label} is disabled in project AI settings.`, + lastCheckedAt: args.checkedAt, + }; + } + + const detected = args.auth.find( + (entry): entry is Extract => + entry.type === "local" && entry.provider === args.provider, + ); + if (detected) { + const inspection = await inspectLocalProvider(args.provider, detected.endpoint); + return createLocalRuntimeConnectionFromInspection({ + provider: args.provider, + endpoint: detected.endpoint, + source: detected.endpointSource === "config" ? "config" : "auto", + inspection, + checkedAt: args.checkedAt, + }); + } + + const configuredEndpoint = providerConfig.endpoint; + if (configuredEndpoint) { + const manualInspection = await inspectLocalProvider(args.provider, configuredEndpoint); + if (manualInspection.reachable || !providerConfig.autoDetect) { + const status = createLocalRuntimeConnectionFromInspection({ + provider: args.provider, + endpoint: configuredEndpoint, + source: "config", + inspection: manualInspection, + checkedAt: args.checkedAt, + }); + if (!manualInspection.reachable && !providerConfig.autoDetect) { + status.health = "unreachable"; + } + return status; + } + } + + if (providerConfig.autoDetect) { + const autoEndpoint = getLocalProviderDefaultEndpoint(args.provider); + if (!configuredEndpoint || autoEndpoint.replace(/\/+$/, "") !== configuredEndpoint.replace(/\/+$/, "")) { + const autoInspection = await inspectLocalProvider(args.provider, autoEndpoint); + if (autoInspection.reachable) { + return createLocalRuntimeConnectionFromInspection({ + provider: args.provider, + endpoint: autoEndpoint, + source: "auto", + inspection: autoInspection, + checkedAt: args.checkedAt, + }); + } + } + } + + return { + provider: args.provider, + label, + kind: "local", + configured: false, + authAvailable: false, + runtimeDetected: false, + runtimeAvailable: false, + health: "not_configured", + blocker: `No ${label} runtime with loaded models was detected.`, + lastCheckedAt: args.checkedAt, + }; +} + +async function buildRuntimeConnections(args: { + configuredLocalProviders: AiLocalProviderConfigs; + auth: DetectedAuth[]; + providerConnections: AiProviderConnections; +}): Promise { + const checkedAt = new Date().toISOString(); + const runtimeConnections: AiRuntimeConnections = { + claude: toCliRuntimeConnection(args.providerConnections.claude), + codex: toCliRuntimeConnection(args.providerConnections.codex), + cursor: toCliRuntimeConnection(args.providerConnections.cursor), + }; + + for (const authEntry of args.auth) { + if (authEntry.type === "api-key") { + runtimeConnections[authEntry.provider] = { + provider: authEntry.provider, + label: apiProviderLabel(authEntry.provider), + kind: "api-key", + configured: true, + authAvailable: true, + runtimeDetected: true, + runtimeAvailable: true, + health: "ready", + source: authEntry.source, + blocker: null, + lastCheckedAt: checkedAt, + }; + continue; + } + if (authEntry.type === "openrouter") { + runtimeConnections.openrouter = { + provider: "openrouter", + label: "OpenRouter", + kind: "openrouter", + configured: true, + authAvailable: true, + runtimeDetected: true, + runtimeAvailable: true, + health: "ready", + source: authEntry.source, + blocker: null, + lastCheckedAt: checkedAt, + }; + } + } + + for (const provider of ["ollama", "lmstudio", "vllm"] as const) { + runtimeConnections[provider] = await buildLocalRuntimeConnection({ + provider, + configuredLocalProviders: args.configuredLocalProviders, + auth: args.auth, + checkedAt, + }); + } + + return runtimeConnections; +} + export function createAiIntegrationService(args: { db: AdeDb; logger: Logger; @@ -371,7 +656,10 @@ export function createAiIntegrationService(args: { const detectAuth = async (options?: { force?: boolean }): Promise => { const snapshot = projectConfigService.get(); - return await detectAllAuth(extractConfiguredApiKeys(snapshot), options); + return await detectAllAuth(extractConfiguredApiKeys(snapshot), { + ...options, + localProviders: extractConfiguredLocalProviders(snapshot), + }); }; const deriveMode = (args: { @@ -426,6 +714,21 @@ export function createAiIntegrationService(args: { }; const getResolvedAvailableModels = async (auth: DetectedAuth[]) => { + const discoveredLocalModels = await discoverLocalModels(auth); + replaceDynamicLocalModelDescriptors( + discoveredLocalModels.map((model) => + createDynamicLocalModelDescriptor(model.provider, model.modelId, { + ...(model.displayName ? { displayName: model.displayName } : {}), + ...(model.contextWindow ? { contextWindow: model.contextWindow } : {}), + ...(model.maxOutputTokens ? { maxOutputTokens: model.maxOutputTokens } : {}), + ...(model.capabilities ? { capabilities: model.capabilities } : {}), + ...(model.reasoningTiers?.length ? { reasoningTiers: model.reasoningTiers } : {}), + ...(model.harnessProfile ? { harnessProfile: model.harnessProfile } : {}), + ...(model.discoverySource ? { discoverySource: model.discoverySource } : {}), + }), + ), + ); + let available = getAvailableModels(auth); const hasCursorCliAuth = auth.some( @@ -447,23 +750,7 @@ export function createAiIntegrationService(args: { } } - const discoveredLocalModels = await discoverLocalModels(auth); - if (discoveredLocalModels.length === 0) { - return available; - } - - const providersWithDynamicModels = new Set( - discoveredLocalModels.map((model) => model.provider), - ); - const filteredStatic = available.filter((descriptor) => !( - descriptor.authTypes.includes("local") - && providersWithDynamicModels.has(descriptor.family as "ollama" | "lmstudio" | "vllm") - )); - - return [ - ...filteredStatic, - ...discoveredLocalModels.map((model) => createDynamicLocalModelDescriptor(model.provider, model.modelId)), - ]; + return available; }; const verifyApiKeyConnection = async (provider: string): Promise => { @@ -902,6 +1189,7 @@ export function createAiIntegrationService(args: { if (options?.force) { resetProviderRuntimeHealth(); resetClaudeRuntimeProbeCache(); + resetLocalProviderDetectionCache(); clearCursorCliModelsCache(); modelListCache.clear(); runtimeHealthVersion = getProviderRuntimeHealthVersion(); @@ -921,6 +1209,12 @@ export function createAiIntegrationService(args: { runtimeHealthVersion = getProviderRuntimeHealthVersion(); } const providerConnections = await buildProviderConnections(cliStatuses); + const configuredLocalProviders = extractConfiguredLocalProviders(projectConfigService.get()); + const runtimeConnections = await buildRuntimeConnections({ + configuredLocalProviders, + auth, + providerConnections, + }); const availability = { claude: providerConnections.claude.runtimeAvailable, codex: providerConnections.codex.runtimeAvailable, @@ -943,6 +1237,7 @@ export function createAiIntegrationService(args: { }, detectedAuth: redactDetectedAuth(auth, cliStatuses), providerConnections, + runtimeConnections, availableModelIds: runtimeFilteredAvailable.map((descriptor) => descriptor.id), apiKeyStore: getApiKeyStoreStatus(), }; diff --git a/apps/desktop/src/main/services/ai/authDetector.test.ts b/apps/desktop/src/main/services/ai/authDetector.test.ts index d18ecdc4a..d8437ab96 100644 --- a/apps/desktop/src/main/services/ai/authDetector.test.ts +++ b/apps/desktop/src/main/services/ai/authDetector.test.ts @@ -235,6 +235,47 @@ describe("authDetector", () => { expect(auth.some((entry) => entry.type === "local" && entry.provider === "lmstudio")).toBe(false); }); + it("falls back from an empty configured LM Studio endpoint to the auto-detected endpoint and preserves the preferred model", async () => { + vi.stubGlobal( + "fetch", + vi.fn(async (url: string) => { + if (url === "http://lmstudio.example:1234/api/v1/models") { + return new Response("{}", { status: 404 }); + } + if (url === "http://lmstudio.example:1234/v1/models") { + return new Response(JSON.stringify({ data: [] }), { status: 200 }); + } + if (url === "http://localhost:1234/api/v1/models") { + return new Response("{}", { status: 404 }); + } + if (url === "http://localhost:1234/v1/models") { + return new Response(JSON.stringify({ + data: [{ id: "gemma-4" }], + }), { status: 200 }); + } + return new Response("{}", { status: 503 }); + }), + ); + + const auth = await detectAllAuth({}, { + localProviders: { + lmstudio: { + endpoint: "http://lmstudio.example:1234", + autoDetect: true, + preferredModelId: "lmstudio/gemma-4", + }, + }, + }); + + expect(auth).toContainEqual(expect.objectContaining({ + type: "local", + provider: "lmstudio", + endpoint: "http://localhost:1234", + endpointSource: "auto", + preferredModelId: "lmstudio/gemma-4", + })); + }); + it("marks unsupported CLI auth checks as unverified", async () => { spawnMock.mockImplementation((command: string, args: string[] = []) => { if (args[0] === "--version") { diff --git a/apps/desktop/src/main/services/ai/authDetector.ts b/apps/desktop/src/main/services/ai/authDetector.ts index 177e4eaf6..cb1c35ed4 100644 --- a/apps/desktop/src/main/services/ai/authDetector.ts +++ b/apps/desktop/src/main/services/ai/authDetector.ts @@ -7,6 +7,9 @@ import { augmentProcessPathWithShellAndKnownCliDirs, resolveExecutableFromKnownLocations, } from "./cliExecutableResolver"; +import { getLocalProviderDefaultEndpoint, type LocalProviderFamily } from "../../../shared/modelRegistry"; +import type { AiLocalProviderConfigs } from "../../../shared/types"; +import { inspectLocalProvider, clearLocalProviderInspectionCache } from "./localModelDiscovery"; type CliName = "claude" | "codex" | "cursor"; @@ -43,7 +46,13 @@ export type DetectedAuth = } | { type: "api-key"; provider: string; key: string; source: ApiKeySource } | { type: "openrouter"; key: string; source: ApiKeySource } - | { type: "local"; provider: "ollama" | "lmstudio" | "vllm"; endpoint: string }; + | { + type: "local"; + provider: "ollama" | "lmstudio" | "vllm"; + endpoint: string; + endpointSource?: "auto" | "config"; + preferredModelId?: string | null; + }; // --------------------------------------------------------------------------- // Internals @@ -98,8 +107,6 @@ const WEAK_UNAUTH_INDICATORS = [ /run .*login/i, ]; -const UNAUTH_INDICATORS = [...STRONG_UNAUTH_INDICATORS, ...WEAK_UNAUTH_INDICATORS]; - const UNSUPPORTED_INDICATORS = [ /unknown command/i, /unrecognized/i, @@ -373,8 +380,14 @@ const API_KEY_VERIFY_TIMEOUT_MS = 8_000; let cachedLocalProviders: | { + key: string; checkedAtMs: number; - entries: Array<{ provider: "ollama" | "lmstudio" | "vllm"; endpoint: string }>; + entries: Array<{ + provider: LocalProviderFamily; + endpoint: string; + endpointSource: "auto" | "config"; + preferredModelId?: string | null; + }>; } | null = null; @@ -412,67 +425,124 @@ async function readStoredApiKeys(): Promise> { } } -async function checkLocalEndpointHasModels( - provider: "ollama" | "lmstudio" | "vllm", - url: string, - timeoutMs = 2_000, -): Promise { - try { - const controller = new AbortController(); - const timer = setTimeout(() => controller.abort(), timeoutMs); - const res = await fetch(url, { method: "GET", signal: controller.signal }); - clearTimeout(timer); - if (!res.ok) return false; - - const payload = await res.json() as unknown; - if (provider === "ollama") { - const models: Array<{ name?: unknown }> = Array.isArray((payload as { models?: unknown[] })?.models) - ? ((payload as { models?: Array<{ name?: unknown }> }).models ?? []) - : []; - return models.some((entry) => typeof entry?.name === "string" && entry.name.trim().length > 0); - } +type NormalizedLocalProviderConfig = { + enabled: boolean; + endpoint?: string; + autoDetect: boolean; + preferredModelId?: string | null; +}; - const models: Array<{ id?: unknown }> = Array.isArray((payload as { data?: unknown[] })?.data) - ? ((payload as { data?: Array<{ id?: unknown }> }).data ?? []) - : []; - return models.some((entry) => typeof entry?.id === "string" && entry.id.trim().length > 0); - } catch { - return false; - } +function normalizeLocalProviderConfig( + config?: AiLocalProviderConfigs, +): Record { + return { + ollama: { + enabled: config?.ollama?.enabled ?? true, + ...(typeof config?.ollama?.endpoint === "string" && config.ollama.endpoint.trim().length + ? { endpoint: config.ollama.endpoint.trim() } + : {}), + autoDetect: config?.ollama?.autoDetect ?? true, + preferredModelId: config?.ollama?.preferredModelId ?? null, + }, + lmstudio: { + enabled: config?.lmstudio?.enabled ?? true, + ...(typeof config?.lmstudio?.endpoint === "string" && config.lmstudio.endpoint.trim().length + ? { endpoint: config.lmstudio.endpoint.trim() } + : {}), + autoDetect: config?.lmstudio?.autoDetect ?? true, + preferredModelId: config?.lmstudio?.preferredModelId ?? null, + }, + vllm: { + enabled: config?.vllm?.enabled ?? true, + ...(typeof config?.vllm?.endpoint === "string" && config.vllm.endpoint.trim().length + ? { endpoint: config.vllm.endpoint.trim() } + : {}), + autoDetect: config?.vllm?.autoDetect ?? true, + preferredModelId: config?.vllm?.preferredModelId ?? null, + }, + }; +} + +function localProvidersCacheKey(config?: AiLocalProviderConfigs): string { + const normalized = normalizeLocalProviderConfig(config); + return (["ollama", "lmstudio", "vllm"] as const) + .map((provider) => { + const entry = normalized[provider]; + return [ + provider, + entry.enabled ? "1" : "0", + entry.autoDetect ? "1" : "0", + entry.endpoint ?? "", + entry.preferredModelId ?? "", + ].join(":"); + }) + .join("|"); } -async function detectLocalProviders(): Promise> { +async function detectLocalProviders( + config?: AiLocalProviderConfigs, +): Promise> { const now = Date.now(); - if (cachedLocalProviders && now - cachedLocalProviders.checkedAtMs < LOCAL_ENDPOINT_CACHE_TTL_MS) { + const cacheKey = localProvidersCacheKey(config); + if ( + cachedLocalProviders + && cachedLocalProviders.key === cacheKey + && now - cachedLocalProviders.checkedAtMs < LOCAL_ENDPOINT_CACHE_TTL_MS + ) { return cachedLocalProviders.entries; } - const localEndpoints: Array<{ - provider: "ollama" | "lmstudio" | "vllm"; - url: string; - }> = [ - { provider: "ollama", url: "http://localhost:11434/api/tags" }, - { provider: "lmstudio", url: "http://localhost:1234/v1/models" }, - { provider: "vllm", url: "http://localhost:8000/v1/models" }, - ]; - - const localChecks = await Promise.allSettled( - localEndpoints.map(async ({ provider, url }) => { - const alive = await checkLocalEndpointHasModels(provider, url, LOCAL_ENDPOINT_CHECK_TIMEOUT_MS); - if (!alive) return null; - const endpoint = url.replace(/\/api\/tags$|\/v1\/models$/, ""); - return { provider, endpoint } as const; - }), - ); + const normalized = normalizeLocalProviderConfig(config); + const entries: Array<{ + provider: LocalProviderFamily; + endpoint: string; + endpointSource: "auto" | "config"; + preferredModelId?: string | null; + }> = []; + + for (const provider of ["ollama", "lmstudio", "vllm"] as const) { + const providerConfig = normalized[provider]; + if (!providerConfig.enabled) continue; + + const configuredEndpoint = providerConfig.endpoint; + if (configuredEndpoint) { + const inspection = await inspectLocalProvider(provider, configuredEndpoint, LOCAL_ENDPOINT_CHECK_TIMEOUT_MS); + if (inspection.health === "ready") { + entries.push({ + provider, + endpoint: configuredEndpoint, + endpointSource: "config", + preferredModelId: providerConfig.preferredModelId ?? null, + }); + continue; + } + if (!providerConfig.autoDetect) { + continue; + } + } - const entries: Array<{ provider: "ollama" | "lmstudio" | "vllm"; endpoint: string }> = []; - for (const check of localChecks) { - if (check.status === "fulfilled" && check.value) { - entries.push(check.value); + if (!providerConfig.autoDetect) continue; + const autoEndpoint = getLocalProviderDefaultEndpoint(provider); + if (configuredEndpoint && autoEndpoint.replace(/\/+$/, "") === configuredEndpoint.replace(/\/+$/, "")) { + continue; + } + const autoInspection = await inspectLocalProvider(provider, autoEndpoint, LOCAL_ENDPOINT_CHECK_TIMEOUT_MS); + if (autoInspection.health === "ready") { + entries.push({ + provider, + endpoint: autoEndpoint, + endpointSource: "auto", + preferredModelId: providerConfig.preferredModelId ?? null, + }); } } - cachedLocalProviders = { checkedAtMs: now, entries }; + cachedLocalProviders = { key: cacheKey, checkedAtMs: now, entries }; return entries; } @@ -747,7 +817,7 @@ export async function detectCliAuthStatuses(options?: { force?: boolean }): Prom export async function detectAllAuth( configApiKeys?: Record, - options?: { force?: boolean }, + options?: { force?: boolean; localProviders?: AiLocalProviderConfigs }, ): Promise { const results: DetectedAuth[] = []; @@ -823,10 +893,15 @@ export async function detectAllAuth( } // 4. Local providers - const localProviders = await detectLocalProviders(); + const localProviders = await detectLocalProviders(options?.localProviders); for (const localProvider of localProviders) { results.push({ type: "local", ...localProvider }); } return results; } + +export function resetLocalProviderDetectionCache(): void { + cachedLocalProviders = null; + clearLocalProviderInspectionCache(); +} diff --git a/apps/desktop/src/main/services/ai/localModelDiscovery.ts b/apps/desktop/src/main/services/ai/localModelDiscovery.ts index 6686e1ef0..68551b0e5 100644 --- a/apps/desktop/src/main/services/ai/localModelDiscovery.ts +++ b/apps/desktop/src/main/services/ai/localModelDiscovery.ts @@ -1,19 +1,41 @@ import type { DetectedAuth } from "./authDetector"; -import type { LocalProviderFamily } from "../../../shared/modelRegistry"; +import type { LocalModelHarnessProfile, LocalProviderFamily, ModelCapabilities, ModelDescriptor } from "../../../shared/modelRegistry"; export type DiscoveredLocalModel = { provider: LocalProviderFamily; modelId: string; + displayName?: string; + contextWindow?: number; + maxOutputTokens?: number; + capabilities?: Partial; + reasoningTiers?: string[]; + harnessProfile?: LocalModelHarnessProfile; + discoverySource?: ModelDescriptor["discoverySource"]; +}; + +export type LocalProviderConnectionHealth = + | "ready" + | "reachable_no_models" + | "unreachable"; + +export type LocalProviderInspection = { + provider: LocalProviderFamily; + endpoint: string; + reachable: boolean; + health: LocalProviderConnectionHealth; + loadedModels: DiscoveredLocalModel[]; }; const CACHE_TTL_MS = 30_000; -let cache: { +let discoverCache: { key: string; cachedAt: number; models: DiscoveredLocalModel[]; } | null = null; +let inspectionCache = new Map(); + function buildCacheKey(auth: DetectedAuth[]): string { return auth .filter((entry): entry is Extract => entry.type === "local") @@ -22,52 +44,327 @@ function buildCacheKey(auth: DetectedAuth[]): string { .join("|"); } -async function fetchLocalModelIds( - provider: LocalProviderFamily, - endpoint: string, -): Promise { +function buildInspectionKey(provider: LocalProviderFamily, endpoint: string): string { + return `${provider}:${endpoint.replace(/\/+$/, "")}`; +} + +function normalizeString(value: unknown): string | null { + return typeof value === "string" && value.trim().length > 0 ? value.trim() : null; +} + +function normalizeBoolean(value: unknown): boolean | undefined { + return typeof value === "boolean" ? value : undefined; +} + +function normalizeNumber(value: unknown): number | undefined { + return Number.isFinite(Number(value)) ? Number(value) : undefined; +} + +function normalizeReasoningTier(value: unknown): string | null { + if (typeof value !== "string") return null; + const normalized = value.trim().toLowerCase(); + if (!normalized.length) return null; + if (normalized === "max") return "xhigh"; + if (normalized === "none" || normalized === "low" || normalized === "medium" || normalized === "high" || normalized === "xhigh") { + return normalized; + } + return null; +} + +function dedupeReasoningTiers(values: Array): string[] | undefined { + const seen = new Set(); + const tiers: string[] = []; + for (const value of values) { + if (!value || seen.has(value)) continue; + seen.add(value); + tiers.push(value); + } + return tiers.length ? tiers : undefined; +} + +function normalizeReasoningConfig(value: unknown): { tiers?: string[]; supportsReasoning: boolean } { + if (typeof value === "boolean") { + return value ? { tiers: ["low", "medium", "high"], supportsReasoning: true } : { supportsReasoning: false }; + } + if (typeof value === "string") { + const tier = normalizeReasoningTier(value); + return { ...(tier ? { tiers: [tier] } : {}), supportsReasoning: Boolean(tier) }; + } + if (Array.isArray(value)) { + const tiers = dedupeReasoningTiers(value.map((entry) => normalizeReasoningTier(entry))); + return { ...(tiers ? { tiers } : {}), supportsReasoning: Boolean(tiers?.length) }; + } + if (value && typeof value === "object") { + const record = value as Record; + const enabled = normalizeBoolean(record.enabled) ?? normalizeBoolean(record.supported) ?? normalizeBoolean(record.available); + const tiers = dedupeReasoningTiers([ + ...((Array.isArray(record.supported_efforts) ? record.supported_efforts : []) as unknown[]).map((entry) => normalizeReasoningTier(entry)), + ...((Array.isArray(record.supportedEfforts) ? record.supportedEfforts : []) as unknown[]).map((entry) => normalizeReasoningTier(entry)), + ...((Array.isArray(record.efforts) ? record.efforts : []) as unknown[]).map((entry) => normalizeReasoningTier(entry)), + ...((Array.isArray(record.levels) ? record.levels : []) as unknown[]).map((entry) => normalizeReasoningTier(entry)), + normalizeReasoningTier(record.default_effort), + normalizeReasoningTier(record.defaultEffort), + ]); + return { + ...(tiers ? { tiers } : {}), + supportsReasoning: enabled ?? Boolean(tiers?.length), + }; + } + return { supportsReasoning: false }; +} + +function inferVisionFromModelId(modelId: string): boolean { + const lower = modelId.toLowerCase(); + return /(\bvl\b|vision|llava|gemma[-_ ]?(3|4)|qwen2\.?5[-_ ]?vl|llama[-_ ]?3\.2.*vision)/i.test(lower); +} + +function inferNativeToolSupport(modelId: string): boolean { + return /(qwen|llama[-_ ]?3\.(1|2)|ministral|mistral|gpt-oss)/i.test(modelId); +} + +function inferHarnessProfile(args: { + modelId: string; + type?: string | null; + trainedForToolUse?: boolean; +}): LocalModelHarnessProfile { + if (args.type === "embedding") return "read_only"; + if (args.trainedForToolUse) return "verified"; + return inferNativeToolSupport(args.modelId) ? "verified" : "guarded"; +} + +function inferFallbackCapabilities(modelId: string): { + capabilities: Partial; + harnessProfile: LocalModelHarnessProfile; +} { + const lower = modelId.toLowerCase(); + if (/embedding|embed|bge-|nomic-embed|gte-|e5-|rerank|reranker/.test(lower)) { + return { + capabilities: { tools: false, vision: false, reasoning: false, streaming: true }, + harnessProfile: "read_only", + }; + } + const harnessProfile = inferHarnessProfile({ modelId }); + const reasoning = /\breason(ing)?\b|qwq|r1|deepseek-r1|phi-4.*reasoning|nemotron/i.test(lower); + return { + capabilities: { + tools: true, + vision: inferVisionFromModelId(modelId), + reasoning, + streaming: true, + }, + harnessProfile, + }; +} + +function toLoadedInstanceDisplayName(baseDisplayName: string | null, instanceId: string, multiInstance: boolean): string | undefined { + const display = baseDisplayName?.trim() || instanceId; + return multiInstance ? `${display} (${instanceId})` : display; +} + +async function fetchJson(url: string, timeoutMs: number): Promise { const controller = new AbortController(); - const timeout = setTimeout(() => controller.abort(), 2_000); + const timer = setTimeout(() => controller.abort(), timeoutMs); try { - const url = provider === "ollama" - ? `${endpoint.replace(/\/+$/, "")}/api/tags` - : `${endpoint.replace(/\/+$/, "")}/v1/models`; const response = await fetch(url, { method: "GET", signal: controller.signal }); - if (!response.ok) return []; - const payload = await response.json() as unknown; - if (provider === "ollama") { - const models = Array.isArray((payload as { models?: unknown[] })?.models) - ? ((payload as { models?: Array<{ name?: unknown }> }).models ?? []) - : []; - return models - .map((entry) => (typeof entry?.name === "string" ? entry.name.trim() : "")) - .filter(Boolean); - } - const models = Array.isArray((payload as { data?: unknown[] })?.data) - ? ((payload as { data?: Array<{ id?: unknown }> }).data ?? []) - : []; - return models - .map((entry) => (typeof entry?.id === "string" ? entry.id.trim() : "")) - .filter(Boolean); + if (!response.ok) return null; + return await response.json() as unknown; } catch { - return []; + return null; } finally { - clearTimeout(timeout); + clearTimeout(timer); + } +} + +async function inspectLmStudioProvider(endpoint: string, timeoutMs: number): Promise { + const base = endpoint.replace(/\/+$/, ""); + const restPayload = await fetchJson(`${base}/api/v1/models`, timeoutMs); + + if (restPayload && typeof restPayload === "object") { + const models = Array.isArray((restPayload as { models?: unknown[] }).models) + ? (restPayload as { models: Array> }).models + : []; + + const discovered: DiscoveredLocalModel[] = []; + + for (const model of models) { + const loadedInstances = Array.isArray(model.loaded_instances) ? model.loaded_instances as Array> : []; + if (!loadedInstances.length) continue; + + const type = normalizeString(model.type); + const displayName = normalizeString(model.display_name) ?? normalizeString(model.key); + const maxContextLength = normalizeNumber(model.max_context_length); + const capabilitiesRecord = model.capabilities && typeof model.capabilities === "object" + ? model.capabilities as Record + : null; + const trainedForToolUse = normalizeBoolean(capabilitiesRecord?.trained_for_tool_use) ?? false; + const reasoning = normalizeReasoningConfig((model as Record).reasoning); + const multiInstance = loadedInstances.length > 1; + + for (const instance of loadedInstances) { + const instanceId = normalizeString(instance.id); + if (!instanceId) continue; + const config = instance.config && typeof instance.config === "object" + ? instance.config as Record + : null; + const contextWindow = normalizeNumber(config?.context_length) ?? maxContextLength; + const harnessProfile = inferHarnessProfile({ modelId: instanceId, type, trainedForToolUse }); + discovered.push({ + provider: "lmstudio", + modelId: instanceId, + displayName: toLoadedInstanceDisplayName(displayName, instanceId, multiInstance), + ...(contextWindow ? { contextWindow } : {}), + maxOutputTokens: 8_192, + capabilities: { + tools: type !== "embedding", + vision: normalizeBoolean(capabilitiesRecord?.vision) ?? inferVisionFromModelId(instanceId), + reasoning: reasoning.supportsReasoning, + streaming: true, + }, + ...(reasoning.tiers?.length ? { reasoningTiers: reasoning.tiers } : {}), + harnessProfile, + discoverySource: "lmstudio-rest", + }); + } + } + + return { + provider: "lmstudio", + endpoint, + reachable: true, + health: discovered.length ? "ready" : "reachable_no_models", + loadedModels: discovered, + }; } + + const openAiPayload = await fetchJson(`${base}/v1/models`, timeoutMs); + const models = Array.isArray((openAiPayload as { data?: unknown[] } | null)?.data) + ? ((openAiPayload as { data: Array> }).data) + : []; + const discovered = models + .map((entry) => normalizeString(entry.id)) + .filter((modelId): modelId is string => Boolean(modelId)) + .map((modelId) => { + const fallback = inferFallbackCapabilities(modelId); + return { + provider: "lmstudio" as const, + modelId, + displayName: modelId, + maxOutputTokens: 8_192, + capabilities: fallback.capabilities, + harnessProfile: fallback.harnessProfile, + discoverySource: "lmstudio-openai" as const, + } satisfies DiscoveredLocalModel; + }); + + return { + provider: "lmstudio", + endpoint, + reachable: openAiPayload != null, + health: !openAiPayload ? "unreachable" : discovered.length ? "ready" : "reachable_no_models", + loadedModels: discovered, + }; +} + +async function inspectOpenAiCompatibleProvider( + provider: Exclude, + endpoint: string, + timeoutMs: number, +): Promise { + const base = endpoint.replace(/\/+$/, ""); + const path = provider === "ollama" ? "/api/tags" : "/v1/models"; + const payload = await fetchJson(`${base}${path}`, timeoutMs); + if (payload == null) { + return { + provider, + endpoint, + reachable: false, + health: "unreachable", + loadedModels: [], + }; + } + + const discovered = provider === "ollama" + ? (Array.isArray((payload as { models?: unknown[] }).models) + ? (payload as { models: Array> }).models + : []) + .map((entry) => normalizeString(entry.name)) + .filter((modelId): modelId is string => Boolean(modelId)) + .map((modelId) => { + const fallback = inferFallbackCapabilities(modelId); + return { + provider, + modelId, + displayName: modelId, + maxOutputTokens: 8_192, + capabilities: fallback.capabilities, + harnessProfile: fallback.harnessProfile, + discoverySource: provider, + } satisfies DiscoveredLocalModel; + }) + : (Array.isArray((payload as { data?: unknown[] }).data) + ? (payload as { data: Array> }).data + : []) + .map((entry) => normalizeString(entry.id)) + .filter((modelId): modelId is string => Boolean(modelId)) + .map((modelId) => { + const fallback = inferFallbackCapabilities(modelId); + return { + provider, + modelId, + displayName: modelId, + maxOutputTokens: 8_192, + capabilities: fallback.capabilities, + harnessProfile: fallback.harnessProfile, + discoverySource: provider, + } satisfies DiscoveredLocalModel; + }); + + return { + provider, + endpoint, + reachable: true, + health: discovered.length ? "ready" : "reachable_no_models", + loadedModels: discovered, + }; +} + +export async function inspectLocalProvider( + provider: LocalProviderFamily, + endpoint: string, + timeoutMs = 2_000, +): Promise { + const key = buildInspectionKey(provider, endpoint); + const cached = inspectionCache.get(key); + const now = Date.now(); + if (cached && now - cached.cachedAt < CACHE_TTL_MS) { + return cached.inspection; + } + + const inspection = provider === "lmstudio" + ? await inspectLmStudioProvider(endpoint, timeoutMs) + : await inspectOpenAiCompatibleProvider(provider, endpoint, timeoutMs); + + inspectionCache.set(key, { cachedAt: now, inspection }); + return inspection; +} + +export function clearLocalProviderInspectionCache(): void { + inspectionCache = new Map(); + discoverCache = null; } export async function discoverLocalModels(auth: DetectedAuth[]): Promise { const key = buildCacheKey(auth); const now = Date.now(); - if (cache && cache.key === key && now - cache.cachedAt < CACHE_TTL_MS) { - return cache.models; + if (discoverCache && discoverCache.key === key && now - discoverCache.cachedAt < CACHE_TTL_MS) { + return discoverCache.models; } const providers = auth.filter((entry): entry is Extract => entry.type === "local"); const discovered = await Promise.all( providers.map(async (entry) => { - const modelIds = await fetchLocalModelIds(entry.provider, entry.endpoint); - return modelIds.map((modelId) => ({ provider: entry.provider, modelId })); + const inspection = await inspectLocalProvider(entry.provider, entry.endpoint); + return inspection.loadedModels; }), ); @@ -83,6 +380,6 @@ export async function discoverLocalModels(auth: DetectedAuth[]): Promise ({ createCodexCliMock: vi.fn(), @@ -36,6 +36,7 @@ describe("providerResolver codex CLI", () => { beforeEach(() => { createCodexCliMock.mockReset(); createClaudeCodeMock.mockReset(); + vi.unstubAllGlobals(); }); it("resolves Codex CLI models through the community provider with MCP settings", async () => { @@ -204,4 +205,33 @@ describe("providerResolver codex CLI", () => { }, }); }); + + it("uses the saved preferred local model without probing /v1/models", async () => { + const fetchSpy = vi.fn(); + vi.stubGlobal("fetch", fetchSpy); + + const resolved = await resolveAutoModelIdFromOpenAiCompatibleEndpoint( + "http://localhost:1234", + "lmstudio", + "lmstudio/qwen2.5-coder:32b", + ); + + expect(fetchSpy).not.toHaveBeenCalled(); + expect(resolved).toBe("qwen2.5-coder:32b"); + }); + + it("requires explicit selection when a local runtime reports multiple loaded models", async () => { + vi.stubGlobal("fetch", vi.fn(async () => + new Response(JSON.stringify({ + data: [ + { id: "meta-llama-3.1-70b-instruct" }, + { id: "qwen2.5-coder:32b" }, + ], + }), { status: 200 }), + )); + + await expect( + resolveAutoModelIdFromOpenAiCompatibleEndpoint("http://localhost:1234", "lmstudio"), + ).rejects.toThrow("Choose a specific model or save a preferred local model"); + }); }); diff --git a/apps/desktop/src/main/services/ai/providerResolver.ts b/apps/desktop/src/main/services/ai/providerResolver.ts index c61d79711..b363459e4 100644 --- a/apps/desktop/src/main/services/ai/providerResolver.ts +++ b/apps/desktop/src/main/services/ai/providerResolver.ts @@ -4,6 +4,7 @@ import type { LanguageModel } from "ai"; import { + getLocalProviderDefaultEndpoint, getModelById, resolveModelAlias, type ModelDescriptor, @@ -101,9 +102,12 @@ function findOpenRouterKey(auth: DetectedAuth[]): string | undefined { return undefined; } -function findLocalEndpoint(auth: DetectedAuth[], provider: string): string | undefined { +function findLocalProviderAuth( + auth: DetectedAuth[], + provider: string, +): Extract | undefined { for (const a of auth) { - if (a.type === "local" && a.provider === provider) return a.endpoint; + if (a.type === "local" && a.provider === provider) return a; } return undefined; } @@ -122,20 +126,32 @@ const COMPATIBLE_BASE_URLS: Record = { together: "https://api.together.xyz/v1", }; -const DEFAULT_LOCAL_ENDPOINTS: Record<"ollama" | "lmstudio" | "vllm", string> = { - ollama: "http://localhost:11434", - lmstudio: "http://localhost:1234", - vllm: "http://localhost:8000", -}; - function normalizeBaseUrl(url: string): string { return url.replace(/\/+$/, ""); } -async function resolveAutoModelIdFromOpenAiCompatibleEndpoint( +function normalizePreferredLocalModelId( + preferredModelId: string | null | undefined, + providerName: string, +): string | undefined { + const normalized = preferredModelId?.trim(); + if (!normalized) return undefined; + const prefix = `${providerName}/`; + if (normalized.toLowerCase().startsWith(prefix)) { + const trimmed = normalized.slice(prefix.length).trim(); + return trimmed || undefined; + } + return normalized; +} + +export async function resolveAutoModelIdFromOpenAiCompatibleEndpoint( endpoint: string, providerName: string, + preferredModelId?: string, ): Promise { + const preferred = normalizePreferredLocalModelId(preferredModelId ?? null, providerName); + if (preferred) return preferred; + const controller = new AbortController(); const timeout = setTimeout(() => controller.abort(), 5_000); try { @@ -147,11 +163,18 @@ async function resolveAutoModelIdFromOpenAiCompatibleEndpoint( throw new Error(`Failed to list models from ${providerName} (${response.status}).`); } const payload = await response.json() as { data?: Array<{ id?: unknown }> }; - const firstModelId = payload.data?.find((entry) => typeof entry?.id === "string" && entry.id.trim().length)?.id; - if (!firstModelId || typeof firstModelId !== "string") { + const modelIds = (payload.data ?? []) + .map((entry) => (typeof entry?.id === "string" ? entry.id.trim() : "")) + .filter(Boolean); + if (modelIds.length === 0) { throw new Error(`${providerName} did not return any usable model IDs.`); } - return firstModelId.trim(); + if (modelIds.length > 1) { + throw new Error( + `${providerName} has multiple loaded models (${modelIds.join(", ")}). Choose a specific model or save a preferred local model.`, + ); + } + return modelIds[0]!; } finally { clearTimeout(timeout); } @@ -161,9 +184,10 @@ async function resolveOpenAiCompatibleModelId( sdkModelId: string, endpoint: string, providerName: string, + preferredModelId?: string | null, ): Promise { if (sdkModelId !== "auto") return sdkModelId; - return resolveAutoModelIdFromOpenAiCompatibleEndpoint(endpoint, providerName); + return resolveAutoModelIdFromOpenAiCompatibleEndpoint(endpoint, providerName, preferredModelId ?? undefined); } // --------------------------------------------------------------------------- @@ -376,8 +400,14 @@ async function resolveDirectProvider( // Local providers (ollama, lmstudio, vllm) if (family === "ollama" || family === "lmstudio" || family === "vllm") { const localProvider = family; - const endpoint = findLocalEndpoint(auth, localProvider) ?? DEFAULT_LOCAL_ENDPOINTS[localProvider]; - const resolvedModelId = await resolveOpenAiCompatibleModelId(sdkModelId, endpoint, localProvider); + const localAuth = findLocalProviderAuth(auth, localProvider); + const endpoint = localAuth?.endpoint ?? getLocalProviderDefaultEndpoint(localProvider); + const resolvedModelId = await resolveOpenAiCompatibleModelId( + sdkModelId, + endpoint, + localProvider, + localAuth?.preferredModelId, + ); const createCompatible = await loadOpenAICompatibleProvider(); const provider = createCompatible({ name: localProvider, diff --git a/apps/desktop/src/main/services/ai/tools/workflowTools.ts b/apps/desktop/src/main/services/ai/tools/workflowTools.ts index 4f2345eae..fffeb0ac2 100644 --- a/apps/desktop/src/main/services/ai/tools/workflowTools.ts +++ b/apps/desktop/src/main/services/ai/tools/workflowTools.ts @@ -172,13 +172,13 @@ export function createWorkflowTools( error: "Local computer-use fallback is disabled for this chat session.", }; } + const tmpDir = fs.mkdtempSync(path.join(require("node:os").tmpdir(), "ade-screenshot-")); + const tmpPath = path.join( + tmpDir, + `screenshot-${Date.now()}.png`, + ); try { // Use macOS screencapture to grab the screen - const tmpPath = path.join( - fs.mkdtempSync(path.join(require("node:os").tmpdir(), "ade-screenshot-")), - `screenshot-${Date.now()}.png`, - ); - await execFileAsync("screencapture", ["-x", tmpPath], { timeout: 15_000, }); @@ -211,11 +211,17 @@ export function createWorkflowTools( return { success: true, artifactId: artifact?.id ?? null, - uri: artifact?.uri ?? tmpPath, + uri: artifact?.uri ?? null, title: artifact?.title ?? title, }; } catch (err) { return formatToolError("Screenshot failed", err); + } finally { + try { + fs.rmSync(tmpDir, { recursive: true, force: true }); + } catch { + // Best-effort cleanup only. + } } }, }); diff --git a/apps/desktop/src/main/services/chat/agentChatService.ts b/apps/desktop/src/main/services/chat/agentChatService.ts index c18c3d12d..a37fd2baa 100644 --- a/apps/desktop/src/main/services/chat/agentChatService.ts +++ b/apps/desktop/src/main/services/chat/agentChatService.ts @@ -116,15 +116,18 @@ import type { CtoCapabilityMode, } from "../../../shared/types"; import { + createDynamicLocalModelDescriptor, getDefaultModelDescriptor, getModelById, getAvailableModels as getRegistryModels, listModelDescriptorsForProvider, MODEL_REGISTRY, pickDefaultCursorDescriptorFromCliList, + replaceDynamicLocalModelDescriptors, resolveModelAlias, resolveModelDescriptorForProvider, resolveProviderGroupForModel, + type LocalProviderFamily, type ModelDescriptor, } from "../../../shared/modelRegistry"; import { canSwitchChatSessionModel } from "../../../shared/chatModelSwitching"; @@ -173,6 +176,7 @@ import { type CursorAcpPooled, } from "./cursorAcpPool"; import { discoverCursorCliModelDescriptors } from "./cursorModelsDiscovery"; +import { discoverLocalModels, type DiscoveredLocalModel } from "../ai/localModelDiscovery"; import { mapAcpSessionNotificationToChatEvents, mapStopReasonToTerminalEvents, @@ -1985,6 +1989,45 @@ function resolveSessionUnifiedPermissionMode( ?? fallback; } +function applyLocalHarnessPermissionMode(args: { + descriptor?: ModelDescriptor; + requestedPermissionMode?: AgentChatSession["permissionMode"]; + requestedUnifiedPermissionMode?: AgentChatUnifiedPermissionMode; +}): { + requestedPermissionMode?: AgentChatSession["permissionMode"]; + requestedUnifiedPermissionMode?: AgentChatUnifiedPermissionMode; +} { + if (!args.descriptor?.authTypes.includes("local")) { + return { + requestedPermissionMode: args.requestedPermissionMode, + requestedUnifiedPermissionMode: args.requestedUnifiedPermissionMode, + }; + } + + if (args.descriptor.harnessProfile === "read_only") { + return { + requestedPermissionMode: "plan", + requestedUnifiedPermissionMode: "plan", + }; + } + + if ( + args.descriptor.harnessProfile === "guarded" + && args.requestedPermissionMode == null + && args.requestedUnifiedPermissionMode == null + ) { + return { + requestedPermissionMode: "plan", + requestedUnifiedPermissionMode: "plan", + }; + } + + return { + requestedPermissionMode: args.requestedPermissionMode, + requestedUnifiedPermissionMode: args.requestedUnifiedPermissionMode, + }; +} + function resolveCursorSessionModeId( session: Pick, ): string | null { @@ -3785,7 +3828,7 @@ export function createAgentChatService(args: { persistChatState(managed); const auth = await detectAuth().catch(() => []); - const availableModels = getRegistryModels(auth).filter((descriptor) => !descriptor.deprecated); + const availableModels = await getAvailableRegistryModels(auth); if (!availableModels.length) return; const preferredModelId = @@ -3928,7 +3971,10 @@ export function createAgentChatService(args: { configApiKeys[String(provider).trim().toLowerCase()] = key; } } - return detectAllAuth(configApiKeys); + const localProviders = snapshot.effective.ai?.localProviders; + return detectAllAuth(configApiKeys, { + localProviders, + }); }; const resolveHandoffBlockedReason = (managed: ManagedChatSession): string | null => { @@ -4051,7 +4097,7 @@ export function createAgentChatService(args: { }): Promise<{ brief: string; usedFallbackSummary: boolean }> => { const deterministicBrief = buildDeterministicHandoffBrief(args); const auth = await detectAuth(); - const availableModels = getRegistryModels(auth).filter((descriptor) => !descriptor.deprecated); + const availableModels = await getAvailableRegistryModels(auth); const preferredModelId = [ resolveChatConfig().summaryModelId, "openai/gpt-5.4-mini", @@ -4175,7 +4221,7 @@ export function createAgentChatService(args: { if (!seed) return; const auth = await detectAuth(); - const availableModels = getRegistryModels(auth).filter((descriptor) => !descriptor.deprecated); + const availableModels = await getAvailableRegistryModels(auth); if (!availableModels.length) return; const preferredModelId = @@ -4248,11 +4294,81 @@ export function createAgentChatService(args: { // Unified session support — for API-key / local models using streamText + universal tools. // CLI-wrapped models fall through to the existing Claude/Codex runtimes. + const LOCAL_PROVIDER_LABELS: Record = { + ollama: "Ollama", + lmstudio: "LM Studio", + vllm: "vLLM", + }; + + const discoveredLocalModelToDescriptor = (model: DiscoveredLocalModel): ModelDescriptor => + createDynamicLocalModelDescriptor(model.provider, model.modelId, { + ...(model.displayName ? { displayName: model.displayName } : {}), + ...(model.contextWindow ? { contextWindow: model.contextWindow } : {}), + ...(model.maxOutputTokens ? { maxOutputTokens: model.maxOutputTokens } : {}), + ...(model.capabilities ? { capabilities: model.capabilities } : {}), + ...(model.reasoningTiers?.length ? { reasoningTiers: model.reasoningTiers } : {}), + ...(model.harnessProfile ? { harnessProfile: model.harnessProfile } : {}), + ...(model.discoverySource ? { discoverySource: model.discoverySource } : {}), + }); + + const getAvailableRegistryModels = async ( + auth: Awaited>, + ): Promise => { + if (auth.some((entry) => entry.type === "local")) { + const discovered = await discoverLocalModels(auth); + replaceDynamicLocalModelDescriptors(discovered.map(discoveredLocalModelToDescriptor)); + } + return getRegistryModels(auth).filter((descriptor) => !descriptor.deprecated); + }; + + const resolveUnifiedLocalDescriptor = async ( + managed: ManagedChatSession, + descriptor: ModelDescriptor, + auth: Awaited>, + ): Promise => { + if (!(descriptor.family === "ollama" || descriptor.family === "lmstudio" || descriptor.family === "vllm")) { + return descriptor; + } + if (descriptor.sdkModelId !== "auto") { + return descriptor; + } + + const localProvider = descriptor.family as LocalProviderFamily; + const discovered = (await discoverLocalModels(auth)).filter((model) => model.provider === localProvider); + replaceDynamicLocalModelDescriptors(discovered.map(discoveredLocalModelToDescriptor)); + + const preferred = auth.find( + (entry): entry is Extract>[number], { type: "local" }> => + entry.type === "local" && entry.provider === localProvider, + )?.preferredModelId; + const preferredDescriptor = preferred ? getModelById(preferred) : undefined; + if (preferredDescriptor && preferredDescriptor.family === localProvider) { + managed.session.modelId = preferredDescriptor.id; + managed.session.model = preferredDescriptor.id; + return preferredDescriptor; + } + + if (discovered.length === 1) { + const onlyDescriptor = getModelById(`${localProvider}/${discovered[0]!.modelId}`) ?? discoveredLocalModelToDescriptor(discovered[0]!); + managed.session.modelId = onlyDescriptor.id; + managed.session.model = onlyDescriptor.id; + return onlyDescriptor; + } + + if (discovered.length > 1) { + throw new Error( + `${descriptor.displayName} has multiple loaded models. Choose a specific ${LOCAL_PROVIDER_LABELS[localProvider]} model or save a preferred local model first.`, + ); + } + + throw new Error(`${descriptor.displayName} is reachable, but no models are currently loaded.`); + }; + const startUnifiedSession = async (managed: ManagedChatSession): Promise<"handled" | "fallthrough"> => { const modelId = managed.session.modelId; if (!modelId) return "fallthrough"; - const descriptor = getModelById(modelId); + let descriptor = getModelById(modelId); if (!descriptor) return "fallthrough"; // CLI-wrapped models -> defer to CLI session runtimes. @@ -4265,7 +4381,15 @@ export function createAgentChatService(args: { }); const auth = await detectAuth(); - const resolvedModel = await providerResolver.resolveModel(modelId, auth, { + descriptor = await resolveUnifiedLocalDescriptor(managed, descriptor, auth); + const harnessPermissions = applyLocalHarnessPermissionMode({ + descriptor, + requestedPermissionMode: managed.session.permissionMode, + requestedUnifiedPermissionMode: managed.session.unifiedPermissionMode, + }); + managed.session.permissionMode = harnessPermissions.requestedPermissionMode ?? managed.session.permissionMode; + managed.session.unifiedPermissionMode = harnessPermissions.requestedUnifiedPermissionMode ?? managed.session.unifiedPermissionMode; + const resolvedModel = await providerResolver.resolveModel(descriptor.id, auth, { cwd: managed.laneWorktreePath, }); @@ -5402,7 +5526,7 @@ export function createAgentChatService(args: { // Fire-and-forget AI summary enhancement const auth = await detectAuth(); - const availableModels = getRegistryModels(auth).filter((d) => !d.deprecated); + const availableModels = await getAvailableRegistryModels(auth); if (!availableModels.length) return; const preferredModelId = @@ -10135,7 +10259,7 @@ export function createAgentChatService(args: { codexApprovalPolicy: requestedCodexApprovalPolicy, codexSandbox: requestedCodexSandbox, codexConfigSource: requestedCodexConfigSource, - unifiedPermissionMode: requestedUnifiedPermissionMode, + unifiedPermissionMode: requestedUnifiedPermissionModeArg, cursorModeId: requestedCursorModeId, cursorConfigValues: requestedCursorConfigValues, permissionMode: requestedPermMode, @@ -10215,10 +10339,18 @@ export function createAgentChatService(args: { const normalizedCursorConfigValues = normalizeCursorConfigValueRecord(requestedCursorConfigValues); const capabilityMode = inferCapabilityMode(effectiveProvider); const computerUsePolicy = normalizeComputerUsePolicy(computerUse, createDefaultComputerUsePolicy()); - const effectivePermissionMode = identityKey + let effectivePermissionMode = identityKey ? normalizeIdentityPermissionMode(requestedPermMode, effectiveProvider) : requestedPermMode; const chatConfig = resolveChatConfig(); + let requestedUnifiedPermissionMode = requestedUnifiedPermissionModeArg; + const localHarnessPermissions = applyLocalHarnessPermissionMode({ + descriptor: resolvedDescriptor, + requestedPermissionMode: effectivePermissionMode, + requestedUnifiedPermissionMode, + }); + effectivePermissionMode = localHarnessPermissions.requestedPermissionMode; + requestedUnifiedPermissionMode = localHarnessPermissions.requestedUnifiedPermissionMode; const nativePermissionFields = (() => { if (effectiveProvider === "claude") { @@ -12673,7 +12805,7 @@ export function createAgentChatService(args: { // For unified/non-CLI providers: return all models with valid auth. try { const auth = await detectAuth(); - const available = getRegistryModels(auth); + const available = await getAvailableRegistryModels(auth); const targetModels = provider === "unified" ? available : available.filter(m => m.family === provider); @@ -13488,18 +13620,24 @@ export function createAgentChatService(args: { try { const projectRoot = args.projectRoot; if (!projectRoot) return; - const attachDir = path.join(projectRoot, ".ade", "attachments"); - if (!fs.existsSync(attachDir)) return; - const cutoff = Date.now() - 7 * 24 * 60 * 60 * 1000; // 7 days - for (const entry of fs.readdirSync(attachDir)) { - try { - const filePath = path.join(attachDir, entry); - const stat = fs.statSync(filePath); - if (stat.isFile() && stat.mtimeMs < cutoff) { - fs.unlinkSync(filePath); + const cleanupDir = (dirPath: string) => { + if (!fs.existsSync(dirPath)) return; + const cutoff = Date.now() - 7 * 24 * 60 * 60 * 1000; + for (const entry of fs.readdirSync(dirPath)) { + try { + const filePath = path.join(dirPath, entry); + const stat = fs.statSync(filePath); + if (stat.mtimeMs < cutoff) { + fs.rmSync(filePath, { recursive: true, force: true }); + } + } catch { + // Best-effort cleanup only. } - } catch { /* skip */ } - } + } + }; + + cleanupDir(path.join(projectRoot, ".ade", "attachments")); + cleanupDir(path.join(resolveAdeLayout(projectRoot).tmpDir, "agent-chat-attachments")); } catch { /* ignore */ } }, setComputerUseArtifactBrokerService(svc: ComputerUseArtifactBrokerService) { diff --git a/apps/desktop/src/main/services/config/projectConfigService.ts b/apps/desktop/src/main/services/config/projectConfigService.ts index 1e4a91cc5..2893ec8c7 100644 --- a/apps/desktop/src/main/services/config/projectConfigService.ts +++ b/apps/desktop/src/main/services/config/projectConfigService.ts @@ -1023,6 +1023,34 @@ function coerceAiTaskRoutingRule(value: unknown): AiTaskRoutingRule | null { return Object.keys(out).length ? out : null; } +function coerceAiLocalProviders(value: unknown): AiConfig["localProviders"] { + if (!isRecord(value)) return undefined; + + const providers: NonNullable = {}; + for (const provider of ["ollama", "lmstudio", "vllm"] as const) { + const raw = isRecord(value[provider]) ? value[provider] : null; + if (!raw) continue; + + const entry: NonNullable[typeof provider]> = {}; + const enabled = asBool(raw.enabled); + if (enabled != null) entry.enabled = enabled; + const endpoint = asString(raw.endpoint)?.trim(); + if (endpoint) entry.endpoint = endpoint; + const autoDetect = asBool(raw.autoDetect); + if (autoDetect != null) entry.autoDetect = autoDetect; + if (raw.preferredModelId === null) { + entry.preferredModelId = null; + } else { + const preferredModelId = asString(raw.preferredModelId)?.trim(); + if (preferredModelId) entry.preferredModelId = preferredModelId; + } + + if (Object.keys(entry).length) providers[provider] = entry; + } + + return Object.keys(providers).length ? providers : undefined; +} + function coerceAiConfig(value: unknown): AiConfig | undefined { if (!isRecord(value)) return undefined; @@ -1238,6 +1266,9 @@ function coerceAiConfig(value: unknown): AiConfig | undefined { const apiKeys = asStringMap(value.apiKeys); if (apiKeys && Object.keys(apiKeys).length) out.apiKeys = apiKeys; + const localProviders = coerceAiLocalProviders(value.localProviders); + if (localProviders) out.localProviders = localProviders; + const workerSafety = coerceWorkerSafetyPolicy(value.workerSafety); if (workerSafety) out.workerSafety = workerSafety; @@ -1554,6 +1585,16 @@ export function mergeAiConfig(sharedAi?: AiConfig, localAi?: Partial): ...(sharedAi?.apiKeys ?? {}), ...(localAi?.apiKeys ?? {}) }; + const localProvidersEntries = (["ollama", "lmstudio", "vllm"] as const) + .map((provider) => { + const mergedProvider = { + ...(sharedAi?.localProviders?.[provider] ?? {}), + ...(localAi?.localProviders?.[provider] ?? {}), + }; + return Object.keys(mergedProvider).length ? [provider, mergedProvider] as const : null; + }) + .filter((entry): entry is readonly ["ollama" | "lmstudio" | "vllm", Record] => entry != null); + const localProviders = Object.fromEntries(localProvidersEntries) as AiConfig["localProviders"]; const workerSafety = mergeWorkerSafetyPolicy(sharedAi?.workerSafety, localAi?.workerSafety); const mcpServers = { ...(sharedAi?.mcpServers ?? {}), @@ -1572,6 +1613,7 @@ export function mergeAiConfig(sharedAi?: AiConfig, localAi?: Partial): ...(Object.keys(chat).length ? { chat } : {}), ...(Object.keys(featureModelOverrides).length ? { featureModelOverrides } : {}), ...(Object.keys(apiKeys).length ? { apiKeys } : {}), + ...(localProvidersEntries.length ? { localProviders } : {}), ...(workerSafety ? { workerSafety } : {}), ...(Object.keys(mcpServers).length ? { mcpServers } : {}) }; diff --git a/apps/desktop/src/main/services/ipc/registerIpc.ts b/apps/desktop/src/main/services/ipc/registerIpc.ts index 29c99d9ae..9dcf7903f 100644 --- a/apps/desktop/src/main/services/ipc/registerIpc.ts +++ b/apps/desktop/src/main/services/ipc/registerIpc.ts @@ -762,6 +762,7 @@ function getUnavailableAiStatus(): AiSettingsStatus { dailyUsage: 0, dailyLimit: null, })), + runtimeConnections: {}, availableModelIds: [], }; } @@ -1978,6 +1979,7 @@ export function registerIpc({ models: status.models, detectedAuth: status.detectedAuth, providerConnections: status.providerConnections, + runtimeConnections: status.runtimeConnections, availableModelIds: status.availableModelIds, apiKeyStore: status.apiKeyStore, features: AI_USAGE_FEATURE_KEYS.map((feature) => ({ diff --git a/apps/desktop/src/main/services/orchestrator/unifiedOrchestratorAdapter.test.ts b/apps/desktop/src/main/services/orchestrator/unifiedOrchestratorAdapter.test.ts index 57841c4b0..4e1816aa8 100644 --- a/apps/desktop/src/main/services/orchestrator/unifiedOrchestratorAdapter.test.ts +++ b/apps/desktop/src/main/services/orchestrator/unifiedOrchestratorAdapter.test.ts @@ -264,6 +264,125 @@ describe("createUnifiedOrchestratorAdapter", () => { }); }); + it("routes local unified models through the managed chat path instead of the shell fallback", async () => { + const workspaceRoot = fs.mkdtempSync(path.join(os.tmpdir(), "ade-unified-local-managed-")); + const createSession = vi.fn(async () => ({ id: "session-managed-local-1" })); + const adapter = createUnifiedOrchestratorAdapter({ + workspaceRoot, + runtimeRoot: path.join(workspaceRoot, "runtime"), + agentChatService: { + createSession, + } as any, + }); + + const result = await adapter.start({ + run: { + id: "run-1", + missionId: "mission-1", + metadata: { + missionGoal: "Implement the worker step", + }, + } as any, + step: { + id: "step-1", + title: "Local implementation worker", + stepKey: "local-implementation-worker", + laneId: "lane-1", + metadata: { + modelId: "lmstudio/meta-llama-3.1-70b-instruct", + stepType: "implementation", + }, + dependencyStepIds: [], + joinPolicy: "all_success", + } as any, + attempt: { id: "attempt-1" } as any, + allSteps: [], + contextProfile: {} as any, + laneExport: null, + projectExport: { content: "" } as any, + docsRefs: [], + fullDocs: [], + permissionConfig: { + _providers: { + unified: "full-auto", + }, + } as any, + createTrackedSession: vi.fn(), + } as any); + + expect(createSession).toHaveBeenCalledWith(expect.objectContaining({ + laneId: "lane-1", + provider: "unified", + model: "lmstudio/meta-llama-3.1-70b-instruct", + modelId: "lmstudio/meta-llama-3.1-70b-instruct", + })); + expect(result).toMatchObject({ + status: "accepted", + sessionId: "session-managed-local-1", + metadata: expect.objectContaining({ + workerSessionKind: "managed_chat", + workerStreamSource: "agent_chat", + startupCommandPreview: "[managed chat session]", + }), + }); + }); + + it("starts guarded local worker sessions in plan mode by default", async () => { + const workspaceRoot = fs.mkdtempSync(path.join(os.tmpdir(), "ade-unified-local-guarded-")); + const createSession = vi.fn(async () => ({ id: "session-managed-local-plan" })); + const adapter = createUnifiedOrchestratorAdapter({ + workspaceRoot, + runtimeRoot: path.join(workspaceRoot, "runtime"), + agentChatService: { + createSession, + } as any, + }); + + const result = await adapter.start({ + run: { + id: "run-1", + missionId: "mission-1", + metadata: { missionGoal: "Investigate the local worker path" }, + } as any, + step: { + id: "step-1", + title: "Guarded local worker", + stepKey: "guarded-local-worker", + laneId: "lane-1", + metadata: { + modelId: "lmstudio/auto", + stepType: "implementation", + }, + dependencyStepIds: [], + joinPolicy: "all_success", + } as any, + attempt: { id: "attempt-1" } as any, + allSteps: [], + contextProfile: {} as any, + laneExport: null, + projectExport: { content: "" } as any, + docsRefs: [], + fullDocs: [], + permissionConfig: {} as any, + createTrackedSession: vi.fn(), + } as any); + + expect(createSession).toHaveBeenCalledWith(expect.objectContaining({ + provider: "unified", + modelId: "lmstudio/auto", + unifiedPermissionMode: "plan", + })); + expect(result).toMatchObject({ + status: "accepted", + launch: expect.objectContaining({ + permissionMode: "plan", + }), + metadata: expect.objectContaining({ + permissionMode: "plan", + }), + }); + }); + it("forces Codex planning steps into a read-only sandbox", async () => { const workspaceRoot = fs.mkdtempSync(path.join(os.tmpdir(), "ade-unified-codex-")); const adapter = createUnifiedOrchestratorAdapter({ @@ -528,6 +647,13 @@ describe("getUnifiedUnsupportedModelReason", () => { expect(getUnifiedUnsupportedModelReason("anthropic/claude-sonnet-4-6")).toBeNull(); }); + it("describes the shell fallback for API/local models instead of rejecting them globally", () => { + const reason = getUnifiedUnsupportedModelReason("lmstudio/meta-llama-3.1-70b-instruct"); + expect(reason).toContain("shell-startup fallback"); + expect(reason).toContain("in-process unified execution"); + expect(reason).toContain("managed unified chat path"); + }); + it("returns a not-registered message for unknown model refs", () => { const reason = getUnifiedUnsupportedModelReason("nonexistent/fantasy-model-99"); expect(reason).toBe("Model 'nonexistent/fantasy-model-99' is not registered."); diff --git a/apps/desktop/src/main/services/orchestrator/unifiedOrchestratorAdapter.ts b/apps/desktop/src/main/services/orchestrator/unifiedOrchestratorAdapter.ts index cf6b072ef..dd7b8457f 100644 --- a/apps/desktop/src/main/services/orchestrator/unifiedOrchestratorAdapter.ts +++ b/apps/desktop/src/main/services/orchestrator/unifiedOrchestratorAdapter.ts @@ -9,6 +9,7 @@ import { resolveModelAlias, resolveModelDescriptor, resolveProviderGroupForModel, + type ModelDescriptor, } from "../../../shared/modelRegistry"; import type { AgentChatExecutionMode, @@ -81,7 +82,7 @@ export function getUnifiedUnsupportedModelReason(modelRef: string): string | nul const cliProvider = resolveCliProviderForModel(descriptor); if (cliProvider) return null; const executionPath = classifyWorkerExecutionPath(descriptor); - return `Model '${descriptor.id}' requires ${executionPath} execution (${descriptor.family}), but the unified worker adapter currently supports only Claude/Codex CLI models.`; + return `Model '${descriptor.id}' requires ${executionPath} in-process unified execution (${descriptor.family}), but the shell-startup fallback only supports Claude/Codex CLI models. Use the managed unified chat path for API and local models.`; } /** @@ -305,6 +306,7 @@ const VALID_PERMISSION_MODES = new Set(["default", "plan", "edit", "full function resolveManagedPermissionMode(args: { provider: "claude" | "codex" | "unified" | "cursor"; + descriptor?: ModelDescriptor; permissionConfig: LegacyPermissionConfig | undefined; readOnlyExecution: boolean; }): AgentChatPermissionMode | undefined { @@ -314,9 +316,14 @@ function resolveManagedPermissionMode(args: { args.provider === "cursor" ? ((providers?.cursor ?? providers?.unified) as string | undefined) : (providers?.[args.provider] as string | undefined); - return typeof candidate === "string" && VALID_PERMISSION_MODES.has(candidate) + const normalizedCandidate = typeof candidate === "string" && VALID_PERMISSION_MODES.has(candidate) ? candidate as AgentChatPermissionMode : undefined; + if (args.descriptor?.authTypes.includes("local")) { + if (args.descriptor.harnessProfile === "read_only") return "plan"; + if (args.descriptor.harnessProfile === "guarded" && normalizedCandidate == null) return "plan"; + } + return normalizedCandidate; } function mapPermissionModeToNativeFields( @@ -570,9 +577,10 @@ export function createUnifiedOrchestratorAdapter(options?: { return startup; } - // Non-CLI or unknown models cannot run via this shell-based adapter. + // Non-CLI or unknown models can still run via the managed chat path. + // This shell fallback only exists for CLI-wrapped workers. const unsupportedReason = getUnifiedUnsupportedModelReason(model) ?? `Model '${model}' is not supported by unified adapter.`; - const failureMessage = `[ADE] Unified orchestrator adapter currently supports CLI-wrapped Anthropic/OpenAI models only. ${unsupportedReason} Select a CLI model for this worker.`; + const failureMessage = `[ADE] Shell-startup fallback for the unified orchestrator adapter only supports CLI-wrapped Anthropic/OpenAI models. ${unsupportedReason}`; return `printf '%s\\n' ${shellEscapeArg(failureMessage)} >&2; exit 64`; }, @@ -657,6 +665,7 @@ export function createUnifiedOrchestratorAdapter(options?: { : undefined; const permissionMode = resolveManagedPermissionMode({ provider, + descriptor, permissionConfig: effectivePermissionConfig, readOnlyExecution, }); diff --git a/apps/desktop/src/main/services/runtime/tempCleanupService.test.ts b/apps/desktop/src/main/services/runtime/tempCleanupService.test.ts new file mode 100644 index 000000000..7c5ac9461 --- /dev/null +++ b/apps/desktop/src/main/services/runtime/tempCleanupService.test.ts @@ -0,0 +1,105 @@ +import fs from "node:fs"; +import os from "node:os"; +import path from "node:path"; +import { afterEach, describe, expect, it, vi } from "vitest"; +import { cleanupStaleTempArtifacts } from "./tempCleanupService"; + +const tempRoots: string[] = []; +const NOW_MS = Date.parse("2026-04-02T12:00:00.000Z"); +const DAY_MS = 24 * 60 * 60 * 1000; + +function createTempRoot(): string { + const root = fs.mkdtempSync(path.join(os.tmpdir(), "ade-temp-cleanup-test-")); + tempRoots.push(root); + return root; +} + +function touchMtime(targetPath: string, mtimeMs: number): void { + const stamp = new Date(mtimeMs); + fs.utimesSync(targetPath, stamp, stamp); +} + +afterEach(() => { + while (tempRoots.length > 0) { + const root = tempRoots.pop(); + if (root) { + fs.rmSync(root, { recursive: true, force: true }); + } + } +}); + +describe("cleanupStaleTempArtifacts", () => { + it("removes stale ADE ShipIt caches but keeps recent ones", () => { + const tempRoot = createTempRoot(); + const stalePath = path.join(tempRoot, "com.ade.desktop.ShipIt.stale"); + const freshPath = path.join(tempRoot, "com.ade.desktop.ShipIt.fresh"); + fs.mkdirSync(stalePath, { recursive: true }); + fs.mkdirSync(freshPath, { recursive: true }); + touchMtime(stalePath, NOW_MS - (8 * DAY_MS)); + touchMtime(freshPath, NOW_MS - (2 * DAY_MS)); + + cleanupStaleTempArtifacts({ + tempRoot, + nowMs: NOW_MS, + logger: { info: vi.fn(), warn: vi.fn() }, + }); + + expect(fs.existsSync(stalePath)).toBe(false); + expect(fs.existsSync(freshPath)).toBe(true); + }); + + it("removes stale screenshot temp directories", () => { + const tempRoot = createTempRoot(); + const stalePath = path.join(tempRoot, "ade-screenshot-old"); + const freshPath = path.join(tempRoot, "ade-screenshot-new"); + fs.mkdirSync(stalePath, { recursive: true }); + fs.mkdirSync(freshPath, { recursive: true }); + touchMtime(stalePath, NOW_MS - (2 * DAY_MS)); + touchMtime(freshPath, NOW_MS - (6 * 60 * 60 * 1000)); + + cleanupStaleTempArtifacts({ + tempRoot, + nowMs: NOW_MS, + logger: { info: vi.fn(), warn: vi.fn() }, + }); + + expect(fs.existsSync(stalePath)).toBe(false); + expect(fs.existsSync(freshPath)).toBe(true); + }); + + it("prunes stale fallback attachments and removes the directory when it becomes empty", () => { + const tempRoot = createTempRoot(); + const attachmentsDir = path.join(tempRoot, "ade-attachments"); + const staleFile = path.join(attachmentsDir, "stale.png"); + fs.mkdirSync(attachmentsDir, { recursive: true }); + fs.writeFileSync(staleFile, "stale", "utf8"); + touchMtime(staleFile, NOW_MS - (8 * DAY_MS)); + + cleanupStaleTempArtifacts({ + tempRoot, + nowMs: NOW_MS, + logger: { info: vi.fn(), warn: vi.fn() }, + }); + + expect(fs.existsSync(staleFile)).toBe(false); + expect(fs.existsSync(attachmentsDir)).toBe(false); + }); + + it("keeps recent fallback attachments", () => { + const tempRoot = createTempRoot(); + const attachmentsDir = path.join(tempRoot, "ade-attachments"); + const freshFile = path.join(attachmentsDir, "fresh.png"); + fs.mkdirSync(attachmentsDir, { recursive: true }); + fs.writeFileSync(freshFile, "fresh", "utf8"); + touchMtime(freshFile, NOW_MS - (2 * DAY_MS)); + + cleanupStaleTempArtifacts({ + tempRoot, + nowMs: NOW_MS, + logger: { info: vi.fn(), warn: vi.fn() }, + }); + + expect(fs.existsSync(freshFile)).toBe(true); + expect(fs.existsSync(attachmentsDir)).toBe(true); + }); +}); diff --git a/apps/desktop/src/main/services/runtime/tempCleanupService.ts b/apps/desktop/src/main/services/runtime/tempCleanupService.ts new file mode 100644 index 000000000..5923b623f --- /dev/null +++ b/apps/desktop/src/main/services/runtime/tempCleanupService.ts @@ -0,0 +1,129 @@ +import fs from "node:fs"; +import path from "node:path"; +import type { Logger } from "../logging/logger"; + +const DAY_MS = 24 * 60 * 60 * 1000; +const SHIPIT_RETENTION_MS = 7 * DAY_MS; +const SCREENSHOT_RETENTION_MS = DAY_MS; +const ATTACHMENTS_RETENTION_MS = 7 * DAY_MS; + +type CleanupSummary = { + shipItEntriesRemoved: number; + screenshotEntriesRemoved: number; + attachmentEntriesRemoved: number; + attachmentDirRemoved: boolean; +}; + +function isOlderThan(targetPath: string, cutoffMs: number): boolean { + try { + return fs.statSync(targetPath).mtimeMs < cutoffMs; + } catch { + return false; + } +} + +function removePath(targetPath: string): boolean { + try { + fs.rmSync(targetPath, { recursive: true, force: true }); + return true; + } catch { + return false; + } +} + +function cleanupStaleAttachmentEntries(attachmentsDir: string, cutoffMs: number): { removedEntries: number; removedDir: boolean } { + let removedEntries = 0; + + let entries: fs.Dirent[] = []; + try { + entries = fs.readdirSync(attachmentsDir, { withFileTypes: true }); + } catch { + return { removedEntries, removedDir: false }; + } + + for (const entry of entries) { + const entryPath = path.join(attachmentsDir, entry.name); + if (!isOlderThan(entryPath, cutoffMs)) continue; + if (removePath(entryPath)) { + removedEntries += 1; + } + } + + let removedDir = false; + try { + if (fs.readdirSync(attachmentsDir).length === 0) { + removedDir = removePath(attachmentsDir); + } + } catch { + // Directory may already be gone. + } + + return { removedEntries, removedDir }; +} + +export function cleanupStaleTempArtifacts(args: { + tempRoot: string; + logger: Pick; + nowMs?: number; +}): void { + const nowMs = args.nowMs ?? Date.now(); + const summary: CleanupSummary = { + shipItEntriesRemoved: 0, + screenshotEntriesRemoved: 0, + attachmentEntriesRemoved: 0, + attachmentDirRemoved: false, + }; + + let entries: fs.Dirent[] = []; + try { + entries = fs.readdirSync(args.tempRoot, { withFileTypes: true }); + } catch (error) { + const code = (error as NodeJS.ErrnoException)?.code; + if (code !== "ENOENT") { + args.logger.warn("tempCleanup.scan_failed", { + tempRoot: args.tempRoot, + message: error instanceof Error ? error.message : String(error), + }); + } + return; + } + + for (const entry of entries) { + const entryPath = path.join(args.tempRoot, entry.name); + + if (entry.name.startsWith("com.ade.desktop.ShipIt.")) { + if (isOlderThan(entryPath, nowMs - SHIPIT_RETENTION_MS) && removePath(entryPath)) { + summary.shipItEntriesRemoved += 1; + } + continue; + } + + if (entry.name.startsWith("ade-screenshot-")) { + if (isOlderThan(entryPath, nowMs - SCREENSHOT_RETENTION_MS) && removePath(entryPath)) { + summary.screenshotEntriesRemoved += 1; + } + continue; + } + + if (entry.name === "ade-attachments" && entry.isDirectory()) { + const attachmentCleanup = cleanupStaleAttachmentEntries(entryPath, nowMs - ATTACHMENTS_RETENTION_MS); + summary.attachmentEntriesRemoved += attachmentCleanup.removedEntries; + summary.attachmentDirRemoved = summary.attachmentDirRemoved || attachmentCleanup.removedDir; + } + } + + if ( + summary.shipItEntriesRemoved > 0 + || summary.screenshotEntriesRemoved > 0 + || summary.attachmentEntriesRemoved > 0 + || summary.attachmentDirRemoved + ) { + args.logger.info("tempCleanup.removed_stale_entries", { + tempRoot: args.tempRoot, + shipItEntriesRemoved: summary.shipItEntriesRemoved, + screenshotEntriesRemoved: summary.screenshotEntriesRemoved, + attachmentEntriesRemoved: summary.attachmentEntriesRemoved, + attachmentDirRemoved: summary.attachmentDirRemoved, + }); + } +} diff --git a/apps/desktop/src/renderer/components/chat/AgentChatPane.tsx b/apps/desktop/src/renderer/components/chat/AgentChatPane.tsx index 8340d840f..1ff5bb6ad 100644 --- a/apps/desktop/src/renderer/components/chat/AgentChatPane.tsx +++ b/apps/desktop/src/renderer/components/chat/AgentChatPane.tsx @@ -15,6 +15,7 @@ import { type AgentChatFileRef, type AgentChatInteractionMode, type AiProviderConnectionStatus, + type AiRuntimeConnectionStatus, type AgentChatSession, type AgentChatUnifiedPermissionMode, type AgentChatSessionProfile, @@ -24,13 +25,16 @@ import { type AgentChatSessionSummary, type ComputerUseOwnerSnapshot, type ComputerUsePolicy, + type AiSettingsStatus, type TerminalToolType, } from "../../../shared/types"; import { parseAgentChatTranscript } from "../../../shared/chatTranscript"; import { MODEL_REGISTRY, + getLocalProviderDefaultEndpoint, getModelById, resolveModelDescriptorForProvider, + type LocalProviderFamily, type ModelDescriptor, } from "../../../shared/modelRegistry"; import { filterChatModelIdsForSession } from "../../../shared/chatModelSwitching"; @@ -63,6 +67,41 @@ const LEGACY_PROVIDER_KEY = "ade.chat.lastProvider"; const LEGACY_MODEL_KEY_PREFIX = "ade.chat.lastModel"; const COMPUTER_USE_SNAPSHOT_COOLDOWN_MS = 750; +const LOCAL_PROVIDER_LABELS: Record = { + ollama: "Ollama", + lmstudio: "LM Studio", + vllm: "vLLM", +}; + +type AiStatusSnapshot = AiSettingsStatus & { + runtimeConnections?: Record; +}; + +function getLocalProviderFromModelId(modelId: string): LocalProviderFamily | null { + const provider = String(modelId ?? "").trim().split("/", 1)[0]?.toLowerCase(); + if (provider === "ollama" || provider === "lmstudio" || provider === "vllm") { + return provider; + } + return null; +} + +function formatLocalModelLabel(modelId: string): string { + const provider = getLocalProviderFromModelId(modelId); + if (!provider) { + return getModelById(modelId)?.displayName ?? modelId; + } + const tail = String(modelId ?? "").trim().slice(provider.length + 1).trim(); + return tail.length ? tail : modelId; +} + +function recommendedUnifiedPermissionModeForModel( + descriptor: ModelDescriptor | null | undefined, +): AgentChatUnifiedPermissionMode | null { + if (!descriptor?.authTypes.includes("local")) return null; + return descriptor.harnessProfile === "guarded" || descriptor.harnessProfile === "read_only" + ? "plan" + : null; +} export function resolveChatSessionProfile(_computerUsePolicy: ComputerUsePolicy): AgentChatSessionProfile { return "workflow"; @@ -593,6 +632,7 @@ export function AgentChatPane({ const [cursorModeId, setCursorModeId] = useState(initialNativeControls.cursorModeId); const [cursorConfigValues, setCursorConfigValues] = useState>(initialNativeControls.cursorConfigValues); const [computerUsePolicy, setComputerUsePolicy] = useState(createDefaultComputerUsePolicy()); + const [aiStatus, setAiStatus] = useState(null); const [providerConnections, setProviderConnections] = useState<{ claude: AiProviderConnectionStatus | null; codex: AiProviderConnectionStatus | null; @@ -685,6 +725,74 @@ export function AgentChatPane({ const pendingSteers = selectedSessionId ? (pendingSteersBySession[selectedSessionId] ?? []) : []; const selectedModelDesc = getModelById(modelId); const reasoningTiers = selectedModelDesc?.reasoningTiers ?? []; + const localRuntimeState = useMemo(() => { + const provider = selectedModelDesc?.authTypes.includes("local") + ? (selectedModelDesc.family as LocalProviderFamily) + : getLocalProviderFromModelId(modelId); + if (!provider) return null; + const runtimeConnection = aiStatus?.runtimeConnections?.[provider] ?? null; + const detectedEntry = aiStatus?.detectedAuth?.find( + (entry): entry is { type: "local"; provider: LocalProviderFamily; endpoint: string } => + entry.type === "local" && entry.provider === provider, + ) ?? null; + const modelIds = runtimeConnection?.loadedModelIds?.length + ? runtimeConnection.loadedModelIds.filter((id): id is string => String(id ?? "").startsWith(`${provider}/`)) + : availableModelIds.filter((id) => id.startsWith(`${provider}/`)); + return { + provider, + label: LOCAL_PROVIDER_LABELS[provider], + endpoint: runtimeConnection?.endpoint ?? detectedEntry?.endpoint ?? getLocalProviderDefaultEndpoint(provider), + detected: Boolean(runtimeConnection?.runtimeDetected ?? detectedEntry), + runtimeAvailable: runtimeConnection?.runtimeAvailable ?? false, + health: runtimeConnection?.health ?? null, + blocker: runtimeConnection?.blocker ?? null, + modelIds, + statusKnown: Boolean(aiStatus), + }; + }, [aiStatus, availableModelIds, modelId, selectedModelDesc]); + const localRuntimeNotice = useMemo(() => { + if (!localRuntimeState) return null; + if (!localRuntimeState.statusKnown) { + return { + tone: "warning" as const, + title: `${localRuntimeState.label} runtime`, + message: `ADE could not read ${localRuntimeState.label} status right now. It will still try the unified local-model path, but refresh settings if the runtime changed.`, + }; + } + if (localRuntimeState.blocker) { + return { + tone: "warning" as const, + title: `${localRuntimeState.label} runtime`, + message: localRuntimeState.blocker, + }; + } + if (!localRuntimeState.detected) { + return { + tone: "warning" as const, + title: `${localRuntimeState.label} runtime`, + message: `${localRuntimeState.label} is not detected at ${localRuntimeState.endpoint}. Start it, load a model, then refresh so ADE can use the local runtime.`, + }; + } + if (!localRuntimeState.modelIds.length) { + return { + tone: "warning" as const, + title: `${localRuntimeState.label} runtime`, + message: `${localRuntimeState.label} responded, but no loaded models were reported yet. Load a model in ${localRuntimeState.label} and refresh.`, + }; + } + if (!localRuntimeState.modelIds.includes(modelId)) { + return { + tone: "warning" as const, + title: `${localRuntimeState.label} runtime`, + message: `${localRuntimeState.label} is running, but ${selectedModelDesc?.displayName ?? formatLocalModelLabel(modelId)} is not in the loaded model list. Choose one of the loaded models or load this model in ${localRuntimeState.label}.`, + }; + } + return { + tone: "success" as const, + title: `${localRuntimeState.label} runtime`, + message: `${localRuntimeState.label} is connected at ${localRuntimeState.endpoint} with ${localRuntimeState.modelIds.length} loaded model${localRuntimeState.modelIds.length === 1 ? "" : "s"}${localRuntimeState.health ? ` (${localRuntimeState.health})` : ""}.`, + }; + }, [localRuntimeState, modelId, selectedModelDesc?.displayName]); const surfaceMode = presentation?.mode ?? "standard"; const identitySessionSettingsBusy = isPersistentIdentitySurface && sessionMutationKind !== null; @@ -812,10 +920,18 @@ export function AgentChatPane({ const refreshAvailableModels = useCallback(async () => { try { const status = await window.ade.ai.getStatus(); + setAiStatus(status); + setProviderConnections({ + claude: status.providerConnections?.claude ?? null, + codex: status.providerConnections?.codex ?? null, + cursor: status.providerConnections?.cursor ?? null, + }); const available = deriveConfiguredModelIds(status, { includeCursor: true }); setAvailableModelIds(available); return available; } catch { + setAiStatus(null); + setProviderConnections(null); // Fall back to direct model discovery probes below. } @@ -863,19 +979,6 @@ export function AgentChatPane({ } }, []); - const refreshProviderConnections = useCallback(async () => { - try { - const status = await window.ade.ai.getStatus(); - setProviderConnections({ - claude: status.providerConnections?.claude ?? null, - codex: status.providerConnections?.codex ?? null, - cursor: status.providerConnections?.cursor ?? null, - }); - } catch { - setProviderConnections(null); - } - }, []); - const touchSession = useCallback((sessionId: string | null | undefined, touchedAt = new Date().toISOString()) => { if (!sessionId) return; const previousTouch = localTouchBySessionRef.current.get(sessionId); @@ -963,16 +1066,16 @@ export function AgentChatPane({ }, [selectedSessionId]); useEffect(() => { - void refreshProviderConnections(); - }, [refreshProviderConnections, selectedSession?.provider]); + void refreshAvailableModels(); + }, [refreshAvailableModels, selectedSession?.provider]); useEffect(() => { if (!turnActive || !selectedSession?.provider) return; const timer = window.setInterval(() => { - void refreshProviderConnections(); + void refreshAvailableModels(); }, 5000); return () => window.clearInterval(timer); - }, [refreshProviderConnections, selectedSession?.provider, turnActive]); + }, [refreshAvailableModels, selectedSession?.provider, turnActive]); const refreshComputerUseSnapshot = useCallback(async ( sessionId: string | null, @@ -1607,14 +1710,19 @@ export function AgentChatPane({ nextModel, nextProvider, nextReasoningEffort, + nextUnifiedPermissionMode: recommendedUnifiedPermissionModeForModel(nextDesc), }; }, [laneId]); const applyModelSelectionSnapshot = useCallback((snapshot: { nextModelId: string; nextReasoningEffort: string | null; + nextUnifiedPermissionMode?: AgentChatUnifiedPermissionMode | null; }) => { setModelId(snapshot.nextModelId); setReasoningEffort(snapshot.nextReasoningEffort); + if (snapshot.nextUnifiedPermissionMode) { + setUnifiedPermissionMode(snapshot.nextUnifiedPermissionMode); + } }, []); const notifySessionCreated = useCallback((session: AgentChatSession) => { if (!onSessionCreated) return; @@ -1631,6 +1739,18 @@ export function AgentChatPane({ const provider = resolveChatRuntimeProvider(desc); const model = provider === "unified" ? modelId : runtimeFacingModelId(desc, modelId); const sessionProfile = resolveChatSessionProfile(computerUsePolicy); + const harnessPermissionMode = provider === "unified" + ? recommendedUnifiedPermissionModeForModel(desc) + : null; + const nativeControlPayload = harnessPermissionMode + ? { + ...summarizeNativeControls(provider, { + ...currentNativeControls, + unifiedPermissionMode: harnessPermissionMode, + }), + ...(provider === "cursor" ? { cursorConfigValues: currentNativeControls.cursorConfigValues } : {}), + } + : buildNativeControlPayload(provider); const created = await window.ade.agentChat.create({ laneId, provider, @@ -1638,7 +1758,7 @@ export function AgentChatPane({ modelId, sessionProfile, reasoningEffort, - ...buildNativeControlPayload(provider), + ...nativeControlPayload, computerUse: computerUsePolicy, }); loadedHistoryRef.current.delete(created.id); @@ -1665,7 +1785,7 @@ export function AgentChatPane({ createSessionPromiseRef.current = null; } } - }, [buildNativeControlPayload, computerUsePolicy, laneId, modelId, notifySessionCreated, reasoningEffort, refreshSessions, touchSession]); + }, [buildNativeControlPayload, computerUsePolicy, currentNativeControls, laneId, modelId, notifySessionCreated, reasoningEffort, refreshSessions, touchSession]); const handoffSession = useCallback(async () => { if (!canShowHandoff || !selectedSessionId || !handoffModelId || handoffBlocked) return; @@ -2313,11 +2433,22 @@ export function AgentChatPane({ } setSessionMutationKind("model"); + const nextNativeControlPayload = snapshot.nextUnifiedPermissionMode + ? { + ...summarizeNativeControls(snapshot.nextProvider, { + ...currentNativeControls, + unifiedPermissionMode: snapshot.nextUnifiedPermissionMode, + }), + ...(snapshot.nextProvider === "cursor" + ? { cursorConfigValues: currentNativeControls.cursorConfigValues } + : {}), + } + : buildNativeControlPayload(snapshot.nextProvider); void window.ade.agentChat.updateSession({ sessionId: selectedSessionId, modelId: nextModelId, reasoningEffort: snapshot.nextReasoningEffort, - ...buildNativeControlPayload(snapshot.nextProvider), + ...nextNativeControlPayload, computerUse: computerUsePolicy, }).then((updatedSession) => { applyModelSelectionSnapshot(snapshot); @@ -2434,6 +2565,35 @@ export function AgentChatPane({ ) : null} + {localRuntimeNotice ? ( +
+
+ {localRuntimeNotice.title} +
+
+ {localRuntimeNotice.message} +
+ {localRuntimeState ? ( + + {localRuntimeState.endpoint} + + ) : null} +
+ ) : null} +
{loading && !embedDraft && !selectedSessionId ? (
diff --git a/apps/desktop/src/renderer/components/onboarding/ProjectSetupPage.tsx b/apps/desktop/src/renderer/components/onboarding/ProjectSetupPage.tsx index 128bfd28d..2c7c607fe 100644 --- a/apps/desktop/src/renderer/components/onboarding/ProjectSetupPage.tsx +++ b/apps/desktop/src/renderer/components/onboarding/ProjectSetupPage.tsx @@ -27,7 +27,7 @@ const STEP_META: Record = { }, ai: { title: "AI connections", - subtitle: "Connect your AI providers", + subtitle: "Connect your AI providers and local runtimes", }, helpers: { title: "Background helpers", @@ -56,7 +56,7 @@ const STEP_HEADERS: Record = { tools: { heading: "Developer Tools", sub: "ADE needs git for version control. GitHub CLI unlocks PR creation, review requests, and CI checks." }, ai: { heading: "Connect AI providers", - sub: "Link API keys and CLIs (Claude Code, Codex, Cursor agent) so ADE can power chat, codegen, and background automations. After the CLI is installed and signed in, Cursor models appear in work chat automatically.", + sub: "Link API keys, CLIs, and local runtimes (LM Studio, Ollama, vLLM) so ADE can power chat, codegen, and background automations. After the CLI is installed and signed in, Cursor models appear in work chat automatically.", }, helpers: { heading: "Background Helpers", sub: "These lightweight AI automations run in the background while you work. All are optional and can be changed anytime in Settings." }, github: { heading: "GitHub Integration", sub: "A personal access token lets ADE create PRs, request reviews, and monitor CI on your behalf." }, diff --git a/apps/desktop/src/renderer/components/settings/ProvidersSection.test.tsx b/apps/desktop/src/renderer/components/settings/ProvidersSection.test.tsx index cd6e4b458..f3340bc7f 100644 --- a/apps/desktop/src/renderer/components/settings/ProvidersSection.test.tsx +++ b/apps/desktop/src/renderer/components/settings/ProvidersSection.test.tsx @@ -6,7 +6,7 @@ import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; import { ProvidersSection } from "./ProvidersSection"; import type { AgentChatEventEnvelope, AiSettingsStatus } from "../../../shared/types"; -function buildStatus(claudeRuntimeAvailable: boolean): AiSettingsStatus { +function buildStatus(claudeRuntimeAvailable: boolean, localModels: string[] = []): AiSettingsStatus { return { mode: "subscription", availableProviders: { @@ -20,7 +20,16 @@ function buildStatus(claudeRuntimeAvailable: boolean): AiSettingsStatus { cursor: [], }, features: [], - detectedAuth: [], + detectedAuth: localModels.length > 0 + ? [ + { + type: "local", + provider: "lmstudio", + endpoint: "http://localhost:1234", + }, + ] + : [], + availableModelIds: localModels, providerConnections: { claude: { provider: "claude", @@ -81,9 +90,17 @@ describe("ProvidersSection", () => { globalThis.window.ade = { ai: { getStatus: vi.fn() - .mockResolvedValueOnce(buildStatus(true)) - .mockResolvedValueOnce(buildStatus(false)), + .mockResolvedValueOnce(buildStatus(true, ["lmstudio/meta-llama-3.1-70b-instruct", "lmstudio/qwen2.5-coder:32b"])) + .mockResolvedValueOnce(buildStatus(false, ["lmstudio/meta-llama-3.1-70b-instruct", "lmstudio/qwen2.5-coder:32b"])), listApiKeys: vi.fn().mockResolvedValue([]), + updateConfig: vi.fn().mockResolvedValue(undefined), + }, + projectConfig: { + get: vi.fn().mockResolvedValue({ + effective: { + ai: {}, + }, + }), }, agentChat: { onEvent: vi.fn((listener: (envelope: AgentChatEventEnvelope) => void) => { @@ -144,4 +161,19 @@ describe("ProvidersSection", () => { expect((await screen.findAllByText("Connected")).length).toBeGreaterThan(0); expect(screen.getAllByText("/Users/arul/.local/bin/claude").length).toBeGreaterThan(0); }); + + it("renders local runtime details and loaded local models", async () => { + render(); + + await waitFor(() => { + expect(window.ade.ai.getStatus).toHaveBeenCalledTimes(1); + expect(window.ade.ai.listApiKeys).toHaveBeenCalledTimes(1); + }); + + expect((await screen.findAllByText("LM Studio")).length).toBeGreaterThan(0); + expect(screen.getAllByText("Ready").length).toBeGreaterThan(0); + expect(screen.getAllByText("LM Studio is reachable at http://localhost:1234. ADE can use 2 loaded models from this runtime.").length).toBeGreaterThan(0); + expect(screen.getAllByText("meta-llama-3.1-70b-instruct (LM Studio)").length).toBeGreaterThan(0); + expect(screen.getAllByText("qwen2.5-coder:32b (LM Studio)").length).toBeGreaterThan(0); + }); }); diff --git a/apps/desktop/src/renderer/components/settings/ProvidersSection.tsx b/apps/desktop/src/renderer/components/settings/ProvidersSection.tsx index 5471dca23..41734d982 100644 --- a/apps/desktop/src/renderer/components/settings/ProvidersSection.tsx +++ b/apps/desktop/src/renderer/components/settings/ProvidersSection.tsx @@ -1,10 +1,14 @@ import React, { useCallback, useEffect, useMemo, useRef, useState } from "react"; import type { AgentChatEventEnvelope, + AiConfig, AiApiKeyVerificationResult, AiProviderConnectionStatus, + AiRuntimeConnectionStatus, AiSettingsStatus, + ProjectConfigSnapshot, } from "../../../shared/types"; +import { getLocalProviderDefaultEndpoint, getModelById, type LocalProviderFamily } from "../../../shared/modelRegistry"; import { ArrowsClockwise, CheckCircle, @@ -58,6 +62,20 @@ const CLI_TOOLS: Array<{ }, ]; +const LOCAL_PROVIDER_SPECS: Array<{ + provider: LocalProviderFamily; + label: string; + description: string; +}> = [ + { provider: "lmstudio", label: "LM Studio", description: "OpenAI-compatible local server" }, + { provider: "ollama", label: "Ollama", description: "OpenAI-compatible local server" }, + { provider: "vllm", label: "vLLM", description: "OpenAI-compatible local server" }, +]; + +const LOCAL_PROVIDER_LABELS: Record = Object.fromEntries( + LOCAL_PROVIDER_SPECS.map((entry) => [entry.provider, entry.label]), +) as Record; + const API_KEY_PROVIDERS: Array<{ provider: string; label: string; @@ -76,6 +94,13 @@ const API_KEY_PROVIDERS: Array<{ { provider: "openrouter", label: "OpenRouter", envVar: "OPENROUTER_API_KEY", placeholder: "sk-or-...", accent: "#A78BFA" }, ]; +type LocalProviderDraft = { + enabled: boolean; + endpoint: string; + autoDetect: boolean; + preferredModelId: string; +}; + const groupLabelStyle: React.CSSProperties = { ...LABEL_STYLE, fontSize: 11, @@ -156,6 +181,19 @@ function buildCliMessage(tool: (typeof CLI_TOOLS)[number], connection: AiProvide return `CLI not found in PATH. Install: ${tool.installHint}. If already installed, ensure it is on your shell PATH and use Refresh.`; } +function formatLocalModelLabel(modelId: string): string { + const descriptor = getModelById(modelId); + if (descriptor) return descriptor.displayName; + const raw = String(modelId ?? "").trim(); + const prefix = raw.split("/", 1)[0]?.toLowerCase(); + if (prefix === "ollama" || prefix === "lmstudio" || prefix === "vllm") { + const tail = raw.slice(prefix.length + 1).trim(); + const providerLabel = LOCAL_PROVIDER_SPECS.find((entry) => entry.provider === prefix)?.label ?? prefix; + return tail.length ? `${tail} (${providerLabel})` : raw; + } + return raw; +} + const AUTH_ERROR_SIGNALS = [ "invalid authentication credentials", "authentication error", @@ -190,11 +228,40 @@ function shouldRefreshProvidersForChatEvent(envelope: AgentChatEventEnvelope): b return false; } +function buildLocalProviderDrafts( + snapshot: ProjectConfigSnapshot | null | undefined, + status: (AiSettingsStatus & { runtimeConnections?: Record }) | null | undefined, +): Record { + const configured = snapshot?.effective.ai?.localProviders ?? {}; + return Object.fromEntries( + LOCAL_PROVIDER_SPECS.map((spec) => { + const runtimeConnection = status?.runtimeConnections?.[spec.provider]; + const providerConfig = configured[spec.provider]; + return [spec.provider, { + enabled: providerConfig?.enabled ?? true, + endpoint: + (typeof providerConfig?.endpoint === "string" && providerConfig.endpoint.trim().length + ? providerConfig.endpoint.trim() + : runtimeConnection?.endpoint?.trim()) + ?? getLocalProviderDefaultEndpoint(spec.provider), + autoDetect: providerConfig?.autoDetect ?? true, + preferredModelId: typeof providerConfig?.preferredModelId === "string" ? providerConfig.preferredModelId : "", + }]; + }), + ) as Record; +} + export function ProvidersSection({ forceRefreshOnMount = false }: { forceRefreshOnMount?: boolean }) { - const [status, setStatus] = useState(null); + const [status, setStatus] = useState<(AiSettingsStatus & { runtimeConnections?: Record }) | null>(null); + const [projectConfigSnapshot, setProjectConfigSnapshot] = useState(null); const [storedProviders, setStoredProviders] = useState([]); const [loading, setLoading] = useState(true); const [editingProvider, setEditingProvider] = useState(null); + const [editingLocalProvider, setEditingLocalProvider] = useState(null); + const [savingLocalProvider, setSavingLocalProvider] = useState(null); + const [localProviderDrafts, setLocalProviderDrafts] = useState>(() => + buildLocalProviderDrafts(null, null), + ); const [editValue, setEditValue] = useState(""); const [notice, setNotice] = useState(null); const [error, setError] = useState(null); @@ -208,11 +275,14 @@ export function ProvidersSection({ forceRefreshOnMount = false }: { forceRefresh } setError(null); try { - const [nextStatus, nextStoredProviders] = await Promise.all([ + const [nextStatus, nextStoredProviders, nextProjectConfig] = await Promise.all([ window.ade.ai.getStatus(options?.force ? { force: true } : undefined), window.ade.ai.listApiKeys(), + window.ade.projectConfig.get(), ]); setStatus(nextStatus); + setProjectConfigSnapshot(nextProjectConfig); + setLocalProviderDrafts(buildLocalProviderDrafts(nextProjectConfig, nextStatus)); setStoredProviders(nextStoredProviders.map((entry) => entry.trim().toLowerCase()).filter(Boolean)); } catch (err) { setError(err instanceof Error ? err.message : String(err)); @@ -245,7 +315,7 @@ export function ProvidersSection({ forceRefreshOnMount = false }: { forceRefresh }; }, [refreshStatus]); - const detectedAuth = status?.detectedAuth ?? []; + const detectedAuth = useMemo(() => status?.detectedAuth ?? [], [status?.detectedAuth]); const providerConnections = status?.providerConnections; const isInitialCheckInFlight = loading && status == null; @@ -261,14 +331,30 @@ export function ProvidersSection({ forceRefreshOnMount = false }: { forceRefresh return map; }, [detectedAuth]); - const localEndpoints = useMemo(() => { - const entries: Array<{ provider: string; endpoint: string }> = []; - for (const entry of detectedAuth) { - if (entry.type !== "local" || !entry.provider || !entry.endpoint) continue; - entries.push({ provider: entry.provider, endpoint: entry.endpoint }); - } - return entries; - }, [detectedAuth]); + const localRuntimes = useMemo(() => { + const availableModelIds = status?.availableModelIds ?? []; + const runtimeConnections = status?.runtimeConnections ?? {}; + return LOCAL_PROVIDER_SPECS.map((spec) => { + const runtimeConnection = runtimeConnections[spec.provider] ?? null; + const detected = detectedAuth.find( + (entry): entry is { type: "local"; provider: LocalProviderFamily; endpoint: string } => + entry.type === "local" && entry.provider === spec.provider, + ) ?? null; + const modelIds = runtimeConnection?.loadedModelIds?.length + ? runtimeConnection.loadedModelIds.filter((rawId) => String(rawId ?? "").trim().startsWith(`${spec.provider}/`)) + : availableModelIds.filter((rawId) => String(rawId ?? "").trim().startsWith(`${spec.provider}/`)); + return { + ...spec, + endpoint: runtimeConnection?.endpoint ?? detected?.endpoint ?? getLocalProviderDefaultEndpoint(spec.provider), + health: runtimeConnection?.health ?? null, + blocker: runtimeConnection?.blocker ?? null, + runtimeAvailable: runtimeConnection?.runtimeAvailable ?? false, + detected, + modelIds, + hasModels: modelIds.length > 0, + }; + }); + }, [detectedAuth, status?.availableModelIds, status?.runtimeConnections]); const apiKeyStoreWarning = useMemo(() => { if (status?.apiKeyStore?.legacyPlaintextDetected) { @@ -345,6 +431,57 @@ export function ProvidersSection({ forceRefreshOnMount = false }: { forceRefresh } }; + const updateLocalProviderDraft = useCallback(( + provider: LocalProviderFamily, + patch: Partial, + ) => { + setLocalProviderDrafts((prev) => ({ + ...prev, + [provider]: { + ...prev[provider], + ...patch, + }, + })); + }, []); + + const beginEditingLocalRuntime = useCallback((provider: LocalProviderFamily) => { + setEditingLocalProvider(provider); + setError(null); + setNotice(null); + }, []); + + const cancelEditingLocalRuntime = useCallback(() => { + setEditingLocalProvider(null); + setLocalProviderDrafts(buildLocalProviderDrafts(projectConfigSnapshot, status)); + }, [projectConfigSnapshot, status]); + + const saveLocalProvider = useCallback(async (provider: LocalProviderFamily) => { + const draft = localProviderDrafts[provider]; + if (!draft) return; + setSavingLocalProvider(provider); + setError(null); + setNotice(null); + try { + await window.ade.ai.updateConfig({ + localProviders: { + [provider]: { + enabled: draft.enabled, + endpoint: draft.endpoint.trim(), + autoDetect: draft.autoDetect, + preferredModelId: draft.preferredModelId.trim() || null, + }, + } as AiConfig["localProviders"], + }); + setNotice(`${LOCAL_PROVIDER_LABELS[provider]} settings saved.`); + setEditingLocalProvider(null); + await refreshStatus({ force: true }); + } catch (err) { + setError(err instanceof Error ? err.message : String(err)); + } finally { + setSavingLocalProvider(null); + } + }, [localProviderDrafts, refreshStatus]); + return (
{notice && ( @@ -645,56 +782,303 @@ export function ProvidersSection({ forceRefreshOnMount = false }: { forceRefresh
-
Local runtimes
- {localEndpoints.length > 0 ? ( -
- {localEndpoints.map((entry) => ( +
+
+
Local runtimes
+
+ LM Studio, Ollama, and vLLM become ready once at least one model is loaded and the server exposes its OpenAI-compatible `/v1/models` list. +
+
+ +
+ +
+ {localRuntimes.map((entry) => { + const isEditing = editingLocalProvider === entry.provider; + const isSaving = savingLocalProvider === entry.provider; + const draft = localProviderDrafts[entry.provider]; + const tone = entry.blocker + ? { color: COLORS.warning, label: "Blocked" } + : entry.runtimeAvailable || (entry.detected && entry.hasModels) + ? { color: COLORS.success, label: entry.hasModels ? "Ready" : "Connected" } + : { color: COLORS.warning, label: "Not detected" }; + const loadedModels = entry.modelIds.slice(0, 4); + const extraModelCount = Math.max(0, entry.modelIds.length - loadedModels.length); + const message = entry.blocker + ? entry.blocker + : entry.detected + ? entry.hasModels + ? `${entry.label} is reachable at ${entry.endpoint}. ADE can use ${entry.modelIds.length} loaded model${entry.modelIds.length === 1 ? "" : "s"} from this runtime${entry.health ? ` (${entry.health})` : ""}.` + : `${entry.label} responded, but no loaded models were reported yet. Load a model in ${entry.label} and refresh.` + : `${entry.label} was not detected. Start it, load at least one model, then refresh so ADE can discover its OpenAI-compatible server.`; + + return (
-
-
- {entry.provider} +
+
+ +
+
+ {entry.label} +
+
+ {entry.description} +
+
+
+
+ {entry.detected ? : } + + {tone.label} +
- - {entry.endpoint} -
-
- - - Reachable + +
+ {message} +
+ +
+ + {draft?.enabled === false ? "Disabled" : "Enabled"} + + + {draft?.autoDetect === false ? "Manual only" : "Auto-detect fallback"}
+ + + {draft?.endpoint?.trim() || entry.endpoint} + + +
+ {loadedModels.length > 0 ? ( + <> + {loadedModels.map((modelId) => ( + + + {formatLocalModelLabel(modelId)} + + ))} + {extraModelCount > 0 ? ( + + +{extraModelCount} more + + ) : null} + + ) : ( + + No loaded models reported yet. + + )} +
+ + {isEditing && draft ? ( +
+ + + + +
+ ) : null} + +
+ {isEditing ? ( + <> + + + + ) : ( + <> + + + + )} +
- ))} -
- ) : ( -
- - No local model endpoints detected (Ollama, LM Studio, vLLM). -
- )} + ); + })} +
+
+ + If LM Studio is running but ADE does not show it, load at least one model in LM Studio, then use Refresh. ADE only marks a local runtime as ready after `/v1/models` returns loaded models. +
); diff --git a/apps/desktop/src/renderer/components/shared/UnifiedModelSelector.tsx b/apps/desktop/src/renderer/components/shared/UnifiedModelSelector.tsx index 40a943aa7..8b289ae13 100644 --- a/apps/desktop/src/renderer/components/shared/UnifiedModelSelector.tsx +++ b/apps/desktop/src/renderer/components/shared/UnifiedModelSelector.tsx @@ -16,6 +16,7 @@ import { createModelOrderMap, matchesQuery, PROVIDER_BADGE_COLORS, + providerLabel, subsectionKeyForModel, sourceSectionLabel, type ModelProviderBlock, @@ -42,6 +43,11 @@ type UnifiedModelSelectorProps = { }; const SOURCE_KEYS: SourceSectionKey[] = ["subscription", "api", "local"]; +const LOCAL_PROVIDER_LABELS: Record = { + ollama: "Ollama", + lmstudio: "LM Studio", + vllm: "vLLM", +}; const selectCls = cn( "h-8 rounded-lg border border-white/[0.08] bg-white/[0.04] px-2 font-sans text-[11px] text-fg/70", @@ -61,6 +67,21 @@ function providerAccent(family: string, fallback?: string): string { return PROVIDER_BADGE_COLORS[family] ?? fallback ?? "#A78BFA"; } +function getLocalProviderFromModelId(modelId: string): "ollama" | "lmstudio" | "vllm" | null { + const provider = String(modelId ?? "").trim().split("/", 1)[0]?.toLowerCase(); + if (provider === "ollama" || provider === "lmstudio" || provider === "vllm") { + return provider; + } + return null; +} + +function getLocalModelShortLabel(modelId: string): string { + const provider = getLocalProviderFromModelId(modelId); + if (!provider) return modelId; + const tail = String(modelId ?? "").trim().slice(provider.length + 1).trim(); + return tail.length ? tail : modelId; +} + function subsectionTabTitle(sub: ModelSubsection): string { return sub.label.trim() || "Models"; } @@ -69,7 +90,7 @@ function modelAvailabilityLabel(model: ModelDescriptor, isAvailable: boolean): s if (isAvailable) { if (model.family === "cursor" && model.isCliWrapped) return "Cursor CLI ready"; if (model.isCliWrapped) return "Subscription ready"; - if (model.authTypes.includes("local")) return "Local ready"; + if (model.authTypes.includes("local")) return `${providerLabel(model.family)} ready`; if (model.authTypes.includes("api-key")) return "API ready"; if (model.authTypes.includes("oauth")) return "OAuth ready"; if (model.authTypes.includes("openrouter")) return "OpenRouter ready"; @@ -79,7 +100,7 @@ function modelAvailabilityLabel(model: ModelDescriptor, isAvailable: boolean): s return "Cursor CLI · run `agent login` or set CURSOR_API_KEY / CURSOR_AUTH_TOKEN"; } if (model.isCliWrapped) return "Subscription · not configured"; - if (model.authTypes.includes("local")) return "Local · not configured"; + if (model.authTypes.includes("local")) return `${providerLabel(model.family)} · not configured`; if (model.authTypes.includes("api-key")) return "API · not configured"; if (model.authTypes.includes("oauth")) return "OAuth · not configured"; if (model.authTypes.includes("openrouter")) return "OpenRouter · not configured"; @@ -112,6 +133,28 @@ function createUnknownModelPlaceholder(modelId: string): ModelDescriptor { isCliWrapped: true, }; } + const localProvider = getLocalProviderFromModelId(modelId); + if (localProvider) { + const providerLabel = LOCAL_PROVIDER_LABELS[localProvider]; + const shortId = getLocalModelShortLabel(modelId); + return { + id: modelId, + shortId, + displayName: shortId, + family: localProvider, + authTypes: ["local"], + contextWindow: 0, + maxOutputTokens: 0, + capabilities: { tools: false, vision: false, reasoning: false, streaming: true }, + color: PROVIDER_BADGE_COLORS[localProvider] ?? "#64748B", + sdkProvider: "@ai-sdk/openai-compatible", + sdkModelId: modelId, + isCliWrapped: false, + discoverySource: localProvider === "lmstudio" ? "lmstudio-openai" : localProvider, + harnessProfile: "guarded", + aliases: [providerLabel], + }; + } return { id: modelId, shortId: modelId, diff --git a/apps/desktop/src/renderer/lib/modelOptions.test.ts b/apps/desktop/src/renderer/lib/modelOptions.test.ts index 6662b3128..1dd915efa 100644 --- a/apps/desktop/src/renderer/lib/modelOptions.test.ts +++ b/apps/desktop/src/renderer/lib/modelOptions.test.ts @@ -12,7 +12,7 @@ import { // Helpers // --------------------------------------------------------------------------- -function makeStatus(overrides: Partial = {}): AiSettingsStatus { +function makeStatus(overrides: Partial & Record = {}): AiSettingsStatus { return { mode: "guest", availableProviders: { claude: false, codex: false, cursor: false }, @@ -206,13 +206,48 @@ describe("deriveConfiguredModelIds", () => { it("includes local models for local auth", () => { const status = makeStatus({ - detectedAuth: [{ type: "local", provider: "ollama" }], + detectedAuth: [{ type: "local", provider: "lmstudio" }], + availableModelIds: ["lmstudio/meta-llama-3.1-70b-instruct", "lmstudio/qwen2.5-coder:32b", "ollama/llama3.2"], }); const ids = deriveConfiguredModelIds(status); - const ollamaModels = MODEL_REGISTRY.filter( - (m) => m.family === "ollama" && !m.deprecated && !m.isCliWrapped, - ); - expect(ids.length).toBe(ollamaModels.length); + expect(ids).toContain("lmstudio/meta-llama-3.1-70b-instruct"); + expect(ids).toContain("lmstudio/qwen2.5-coder:32b"); + expect(ids).not.toContain("ollama/llama3.2"); + }); + + it("includes local models from runtimeConnections when availableModelIds is empty", () => { + const status = makeStatus({ + runtimeConnections: { + lmstudio: { + provider: "lmstudio", + label: "LM Studio", + kind: "local", + configured: true, + authAvailable: true, + runtimeDetected: true, + runtimeAvailable: true, + health: "ready", + endpoint: "http://localhost:1234", + blocker: null, + loadedModelIds: ["lmstudio/meta-llama-3.1-70b-instruct"], + lastCheckedAt: "2026-03-17T19:00:00.000Z", + }, + }, + }); + const ids = deriveConfiguredModelIds(status); + expect(ids).toContain("lmstudio/meta-llama-3.1-70b-instruct"); + }); + + it("prefers discovered local model ids over static local placeholders", () => { + const status = makeStatus({ + detectedAuth: [{ type: "local", provider: "lmstudio" }], + availableModelIds: ["lmstudio/meta-llama-3.1-70b-instruct"], + }); + + const ids = deriveConfiguredModelIds(status); + + expect(ids).toContain("lmstudio/meta-llama-3.1-70b-instruct"); + expect(ids).not.toContain("lmstudio/auto"); }); it("merges models from multiple auth sources without duplicates", () => { @@ -335,6 +370,18 @@ describe("deriveConfiguredModelOptions", () => { expect(sonnet).toBeTruthy(); expect(sonnet!.label).toBe("Claude Sonnet 4.6"); }); + + it("preserves dynamic local model ids as options", () => { + const status = makeStatus({ + detectedAuth: [{ type: "local", provider: "lmstudio" }], + availableModelIds: ["lmstudio/local-test-model-123"], + }); + const options = deriveConfiguredModelOptions(status); + const local = options.find((o) => o.id === "lmstudio/local-test-model-123"); + expect(local).toBeTruthy(); + expect(local!.label).toBe("local-test-model-123 (LM Studio)"); + expect(local!.description).toBe("lmstudio (local)"); + }); }); // --------------------------------------------------------------------------- diff --git a/apps/desktop/src/renderer/lib/modelOptions.ts b/apps/desktop/src/renderer/lib/modelOptions.ts index 49b3d4fa5..d52e08681 100644 --- a/apps/desktop/src/renderer/lib/modelOptions.ts +++ b/apps/desktop/src/renderer/lib/modelOptions.ts @@ -1,10 +1,51 @@ -import type { AiModelDescriptor, AiSettingsStatus, ModelId } from "../../shared/types"; -import { MODEL_REGISTRY, getModelById, type ModelDescriptor } from "../../shared/modelRegistry"; +import type { AiModelDescriptor, AiRuntimeConnectionStatus, AiSettingsStatus, ModelId } from "../../shared/types"; +import { MODEL_REGISTRY, getModelById, type LocalProviderFamily, type ModelDescriptor } from "../../shared/modelRegistry"; function normalizeAuthProvider(provider: string | undefined): string { return String(provider ?? "").trim().toLowerCase(); } +const LOCAL_PROVIDER_LABELS: Record = { + ollama: "Ollama", + lmstudio: "LM Studio", + vllm: "vLLM", +}; + +function getLocalProviderFromModelId(modelId: string): LocalProviderFamily | null { + const provider = String(modelId ?? "") + .trim() + .split("/", 1)[0] + ?.toLowerCase(); + if (provider === "ollama" || provider === "lmstudio" || provider === "vllm") { + return provider; + } + return null; +} + +function getLocalModelLabel(modelId: string): string { + const provider = getLocalProviderFromModelId(modelId); + if (!provider) return modelId; + const tail = String(modelId ?? "").trim().slice(provider.length + 1).trim(); + return tail.length ? tail : modelId; +} + +function buildFallbackModelOption(modelId: string): AiModelDescriptor { + const provider = getLocalProviderFromModelId(modelId); + if (provider) { + const providerLabel = LOCAL_PROVIDER_LABELS[provider]; + return { + id: modelId, + label: getLocalModelLabel(modelId), + description: `${providerLabel} local model`, + }; + } + return { + id: modelId, + label: modelId, + description: "Unknown model", + }; +} + export function describeModelSource(descriptor: ModelDescriptor): string { if (descriptor.authTypes.includes("local")) { return "local"; @@ -41,6 +82,39 @@ function addKnownModelIds(ids: Set, family: string, includeCliWrapped: } } +function addAvailableModelIdsByPrefix( + ids: Set, + availableModelIds: readonly ModelId[] | undefined, + prefix: string, +) { + if (!availableModelIds?.length) return; + const normalizedPrefix = prefix.trim(); + if (!normalizedPrefix.length) return; + for (const rawId of availableModelIds) { + const id = String(rawId ?? "").trim(); + if (id.startsWith(normalizedPrefix)) { + ids.add(id as ModelId); + } + } +} + +function hasDynamicLocalModelIdsForProvider( + provider: string, + availableModelIds: readonly ModelId[] | undefined, + runtimeConnections: Record | undefined, +): boolean { + const normalizedProvider = normalizeAuthProvider(provider); + if (normalizedProvider !== "ollama" && normalizedProvider !== "lmstudio" && normalizedProvider !== "vllm") { + return false; + } + const prefix = `${normalizedProvider}/`; + const loadedModelIds = runtimeConnections?.[normalizedProvider]?.loadedModelIds; + return Boolean( + availableModelIds?.some((rawId) => String(rawId ?? "").trim().startsWith(prefix)) + || loadedModelIds?.some((rawId) => String(rawId ?? "").trim().startsWith(prefix)), + ); +} + export interface DeriveModelOptions { /** Include cursor/* models in the result. Defaults to `false`. */ includeCursor?: boolean; @@ -53,10 +127,11 @@ export function deriveConfiguredModelIds( if (!status) return []; const { includeCursor = false } = options ?? {}; + const runtimeConnections = (status as { runtimeConnections?: Record } | null | undefined)?.runtimeConnections; // Derive available models from detectedAuth. For Cursor CLI, merge in // `status.availableModelIds` entries under `cursor/*` (main lists them after - // `agent models`); other providers still use registry + auth only. + // `agent models`); local runtimes also merge in discovered loaded models. const ids = new Set(); for (const auth of status.detectedAuth ?? []) { @@ -81,8 +156,26 @@ export function deriveConfiguredModelIds( if (auth.type === "local") { const provider = normalizeAuthProvider(auth.provider); - if (provider.length) addKnownModelIds(ids, provider, false); + if (provider.length) { + if (!hasDynamicLocalModelIdsForProvider(provider, status.availableModelIds, runtimeConnections)) { + addKnownModelIds(ids, provider, false); + } + addAvailableModelIdsByPrefix(ids, status.availableModelIds, `${provider}/`); + addAvailableModelIdsByPrefix(ids, runtimeConnections?.[provider]?.loadedModelIds, `${provider}/`); + } + } + } + + for (const [provider, connection] of Object.entries(runtimeConnections ?? {})) { + const normalizedProvider = normalizeAuthProvider(provider); + if (normalizedProvider !== "ollama" && normalizedProvider !== "lmstudio" && normalizedProvider !== "vllm") { + continue; + } + if (!hasDynamicLocalModelIdsForProvider(normalizedProvider, status.availableModelIds, runtimeConnections)) { + addKnownModelIds(ids, normalizedProvider, false); } + addAvailableModelIdsByPrefix(ids, connection.loadedModelIds, `${normalizedProvider}/`); + addAvailableModelIdsByPrefix(ids, status.availableModelIds, `${normalizedProvider}/`); } if (includeCursor) { @@ -115,7 +208,11 @@ export function deriveConfiguredModelOptions( ): AiModelDescriptor[] { return deriveConfiguredModelIds(status, options).flatMap((modelId) => { const descriptor = getModelById(modelId); - return descriptor ? [descriptorToModelOption(descriptor)] : []; + return descriptor + ? [descriptorToModelOption(descriptor)] + : getLocalProviderFromModelId(modelId) + ? [buildFallbackModelOption(modelId)] + : []; }); } @@ -126,6 +223,7 @@ export function includeSelectedModelOption( const modelId = String(selectedModelId ?? "").trim(); if (!modelId.length || options.some((option) => option.id === modelId)) return options; const descriptor = getModelById(modelId); - if (!descriptor) return options; - return [descriptorToModelOption(descriptor), ...options]; + if (descriptor) return [descriptorToModelOption(descriptor), ...options]; + if (getLocalProviderFromModelId(modelId)) return [buildFallbackModelOption(modelId), ...options]; + return options; } diff --git a/apps/desktop/src/shared/modelRegistry.ts b/apps/desktop/src/shared/modelRegistry.ts index 18f5666c7..b7c407cdd 100644 --- a/apps/desktop/src/shared/modelRegistry.ts +++ b/apps/desktop/src/shared/modelRegistry.ts @@ -26,6 +26,8 @@ export type ModelCapabilities = { streaming: boolean; }; +export type LocalModelHarnessProfile = "verified" | "guarded" | "read_only"; + export type ModelDescriptor = { id: string; shortId: string; @@ -49,6 +51,21 @@ export type ModelDescriptor = { outputPricePer1M?: number; /** Curated cost tier for UI display (missions model selector) */ costTier?: "low" | "medium" | "high" | "very_high"; + /** ADE-owned safety/tooling profile for local and experimental models. */ + harnessProfile?: LocalModelHarnessProfile; + /** Source of runtime-discovered descriptors for debugging and UI hints. */ + discoverySource?: "lmstudio-rest" | "lmstudio-openai" | "ollama" | "vllm"; +}; + +export type DynamicLocalModelDescriptorOptions = { + displayName?: string; + contextWindow?: number; + maxOutputTokens?: number; + capabilities?: Partial; + reasoningTiers?: string[]; + aliases?: string[]; + harnessProfile?: LocalModelHarnessProfile; + discoverySource?: ModelDescriptor["discoverySource"]; }; export type WorkerExecutionPath = "cli" | "api" | "local"; @@ -638,6 +655,7 @@ export const MODEL_REGISTRY: ModelDescriptor[] = [ color: "#71717A", sdkProvider: "@ai-sdk/openai-compatible", sdkModelId: "auto", + harnessProfile: "guarded", isCliWrapped: false, }, { @@ -652,6 +670,7 @@ export const MODEL_REGISTRY: ModelDescriptor[] = [ color: "#64748B", sdkProvider: "@ai-sdk/openai-compatible", sdkModelId: "auto", + harnessProfile: "guarded", isCliWrapped: false, }, { @@ -666,6 +685,7 @@ export const MODEL_REGISTRY: ModelDescriptor[] = [ color: "#475569", sdkProvider: "@ai-sdk/openai-compatible", sdkModelId: "auto", + harnessProfile: "guarded", isCliWrapped: false, }, ]; @@ -678,6 +698,8 @@ let byId = new Map(); let byShortId = new Map(); let byAlias = new Map(); let bySdkModelId = new Map(); +let dynamicLocalById = new Map(); +let dynamicLocalByAlias = new Map(); function rebuildIndexes() { byId = new Map(); @@ -754,25 +776,63 @@ function toDynamicLocalDisplayName(provider: LocalProviderFamily, modelId: strin export function createDynamicLocalModelDescriptor( provider: LocalProviderFamily, modelId: string, + options?: DynamicLocalModelDescriptorOptions, ): ModelDescriptor { const normalizedModelId = modelId.trim(); + const displayName = options?.displayName?.trim() || toDynamicLocalDisplayName(provider, normalizedModelId); + const capabilities: ModelCapabilities = { + ...BASIC_CAPS, + ...(options?.capabilities ?? {}), + }; + const aliases = [ + `${provider}:${normalizedModelId}`, + ...(options?.aliases ?? []), + ].filter((value, index, list) => { + const normalized = value.trim(); + return normalized.length > 0 && list.findIndex((entry) => entry.trim().toLowerCase() === normalized.toLowerCase()) === index; + }); return { id: `${provider}/${normalizedModelId}`, shortId: normalizedModelId, - displayName: toDynamicLocalDisplayName(provider, normalizedModelId), + displayName, family: provider, authTypes: ["local"], - contextWindow: 128_000, - maxOutputTokens: 8_192, - capabilities: { ...BASIC_CAPS }, + contextWindow: options?.contextWindow ?? 128_000, + maxOutputTokens: options?.maxOutputTokens ?? 8_192, + capabilities, color: LOCAL_PROVIDER_COLORS[provider], sdkProvider: "@ai-sdk/openai-compatible", sdkModelId: normalizedModelId, - aliases: [`${provider}:${normalizedModelId}`], + ...(options?.reasoningTiers?.length ? { reasoningTiers: [...options.reasoningTiers] } : {}), + aliases, isCliWrapped: false, + harnessProfile: options?.harnessProfile ?? "guarded", + ...(options?.discoverySource ? { discoverySource: options.discoverySource } : {}), }; } +function isDynamicLocalDescriptor(descriptor: ModelDescriptor): boolean { + return descriptor.authTypes.includes("local") && !byId.has(descriptor.id); +} + +export function replaceDynamicLocalModelDescriptors(descriptors: ModelDescriptor[]): void { + dynamicLocalById = new Map(); + dynamicLocalByAlias = new Map(); + + for (const descriptor of descriptors) { + if (!isDynamicLocalDescriptor(descriptor)) continue; + dynamicLocalById.set(descriptor.id, descriptor); + for (const alias of descriptor.aliases ?? []) { + const normalized = alias.trim().toLowerCase(); + if (normalized.length) dynamicLocalByAlias.set(normalized, descriptor); + } + } +} + +export function getDynamicLocalModelDescriptors(): ModelDescriptor[] { + return [...dynamicLocalById.values()]; +} + export function getLocalProviderDefaultEndpoint(provider: LocalProviderFamily): string { return LOCAL_PROVIDER_ENDPOINTS[provider]; } @@ -896,6 +956,8 @@ export function sortCursorCliDescriptorsForPicker(descriptors: ModelDescriptor[] export function getModelById(id: string): ModelDescriptor | undefined { const cached = byId.get(id); if (cached) return cached; + const dynamic = dynamicLocalById.get(id); + if (dynamic) return dynamic; const local = parseDynamicLocalModelRef(id); if (local) return createDynamicLocalModelDescriptor(local.provider, local.modelId); const cursor = parseDynamicCursorModelRef(id); @@ -947,12 +1009,20 @@ export function getAvailableModels( return false; }); - return MODEL_REGISTRY.filter((model) => !model.deprecated && hasAuthForModel(model)); + const staticModels = MODEL_REGISTRY.filter((model) => !model.deprecated && hasAuthForModel(model)); + const dynamicLocalModels = getDynamicLocalModelDescriptors().filter((model) => hasAuthForModel(model)); + if (!dynamicLocalModels.length) return staticModels; + + const providersWithDynamicLocals = new Set(dynamicLocalModels.map((model) => model.family)); + const filteredStatic = staticModels.filter( + (model) => !(model.authTypes.includes("local") && providersWithDynamicLocals.has(model.family)), + ); + return [...filteredStatic, ...dynamicLocalModels]; } export function resolveModelAlias(alias: string): ModelDescriptor | undefined { const normalized = alias.trim().toLowerCase(); - return byId.get(normalized) ?? byShortId.get(normalized) ?? byAlias.get(normalized) ?? undefined; + return byId.get(normalized) ?? byShortId.get(normalized) ?? byAlias.get(normalized) ?? dynamicLocalByAlias.get(normalized) ?? undefined; } export function resolveModelDescriptor(modelRef: string): ModelDescriptor | undefined { diff --git a/apps/desktop/src/shared/types/config.ts b/apps/desktop/src/shared/types/config.ts index c7e9667b2..3f00e461a 100644 --- a/apps/desktop/src/shared/types/config.ts +++ b/apps/desktop/src/shared/types/config.ts @@ -9,6 +9,7 @@ import type { MissionExecutionPolicy, MissionPermissionConfig, MissionProviderPe import type { ExternalMcpMissionSelection } from "./externalMcp"; import type { MissionModelConfig, ModelConfig } from "./models"; import type { LinearSyncConfig } from "./linearSync"; +import type { LocalProviderFamily } from "../modelRegistry"; // Backward compatible with earlier configs that used `on_crash`. export type ProcessRestartPolicy = "never" | "on-failure" | "always" | "on_crash"; @@ -827,8 +828,10 @@ export type AiDetectedAuth = { cli?: "claude" | "codex" | "cursor"; provider?: string; source?: "config" | "env" | "store"; + endpointSource?: "auto" | "config"; path?: string; endpoint?: string; + preferredModelId?: ModelId | null; authenticated?: boolean; verified?: boolean; }; @@ -873,6 +876,43 @@ export type AiApiKeyVerificationResult = { verifiedAt: string; }; +export type AiLocalProviderConfig = { + enabled?: boolean; + endpoint?: string; + autoDetect?: boolean; + preferredModelId?: ModelId | null; +}; + +export type AiLocalProviderConfigs = Partial>; + +export type AiRuntimeConnectionHealth = + | "ready" + | "reachable" + | "reachable_no_models" + | "not_configured" + | "unreachable"; + +export type AiRuntimeConnectionKind = "cli" | "api-key" | "openrouter" | "local"; + +export type AiRuntimeConnectionStatus = { + provider: string; + label: string; + kind: AiRuntimeConnectionKind; + configured: boolean; + authAvailable: boolean; + runtimeDetected: boolean; + runtimeAvailable: boolean; + health: AiRuntimeConnectionHealth; + source?: "config" | "env" | "store" | "auto"; + path?: string | null; + endpoint?: string | null; + blocker: string | null; + loadedModelIds?: ModelId[]; + lastCheckedAt: string; +}; + +export type AiRuntimeConnections = Record; + export type AiSettingsStatus = { mode: "guest" | "subscription"; availableProviders: { @@ -888,6 +928,7 @@ export type AiSettingsStatus = { features: AiFeatureUsageRow[]; detectedAuth?: AiDetectedAuth[]; providerConnections?: AiProviderConnections; + runtimeConnections?: AiRuntimeConnections; availableModelIds?: ModelId[]; apiKeyStore?: { secureStorageAvailable: boolean; @@ -1034,6 +1075,7 @@ export type AiConfig = { // New unified fields defaultModel?: ModelId; apiKeys?: Record; + localProviders?: AiLocalProviderConfigs; workerSafety?: WorkerSafetyPolicy; mcpServers?: Record; /** Per-feature model overrides, e.g. { mission_planning: "claude-sonnet-4-6" } */ @@ -1059,6 +1101,7 @@ export type AiIntegrationStatus = { // New unified fields detectedAuth?: AiDetectedAuth[]; providerConnections?: AiProviderConnections; + runtimeConnections?: AiRuntimeConnections; availableModelIds?: ModelId[]; }; diff --git a/apps/desktop/src/test/setup.ts b/apps/desktop/src/test/setup.ts index beb1957e7..06654363c 100644 --- a/apps/desktop/src/test/setup.ts +++ b/apps/desktop/src/test/setup.ts @@ -1,6 +1,78 @@ import fs from "node:fs"; import os from "node:os"; import path from "node:path"; +import { afterAll } from "vitest"; + +type TestTempTrackerState = { + installed: boolean; + trackedDirs: Set; + originalMkdtempSync?: typeof fs.mkdtempSync; + originalPromisesMkdtemp?: typeof fs.promises.mkdtemp; +}; + +const TEST_TEMP_TRACKER_KEY = Symbol.for("ade.desktop.testTempTracker"); +const testTempRoot = path.resolve(os.tmpdir()); + +function getTestTempTrackerState(): TestTempTrackerState { + const existing = (globalThis as Record)[TEST_TEMP_TRACKER_KEY]; + if (existing) return existing as TestTempTrackerState; + const created: TestTempTrackerState = { + installed: false, + trackedDirs: new Set(), + }; + (globalThis as Record)[TEST_TEMP_TRACKER_KEY] = created; + return created; +} + +function shouldTrackTempDir(dirPath: string): boolean { + const resolved = path.resolve(dirPath); + const baseName = path.basename(resolved); + return resolved.startsWith(`${testTempRoot}${path.sep}`) && baseName.startsWith("ade-"); +} + +function cleanupTrackedTempDirs(): void { + const state = getTestTempTrackerState(); + const targets = [...state.trackedDirs].sort((left, right) => right.length - left.length); + state.trackedDirs.clear(); + for (const target of targets) { + try { + fs.rmSync(target, { recursive: true, force: true }); + } catch { + // Best-effort cleanup only for test temp roots. + } + } +} + +function installTrackedTempCleanup(): void { + const state = getTestTempTrackerState(); + if (state.installed) return; + state.installed = true; + state.originalMkdtempSync = fs.mkdtempSync.bind(fs); + state.originalPromisesMkdtemp = fs.promises.mkdtemp.bind(fs.promises); + + fs.mkdtempSync = ((prefix: string, options?: Parameters[1]) => { + const created = state.originalMkdtempSync!(prefix, options); + if (shouldTrackTempDir(created)) { + state.trackedDirs.add(path.resolve(created)); + } + return created; + }) as typeof fs.mkdtempSync; + + fs.promises.mkdtemp = (async (prefix: string, options?: Parameters[1]) => { + const created = await state.originalPromisesMkdtemp!(prefix, options); + if (shouldTrackTempDir(created)) { + state.trackedDirs.add(path.resolve(created)); + } + return created; + }) as typeof fs.promises.mkdtemp; + + process.once("exit", cleanupTrackedTempDirs); +} + +installTrackedTempCleanup(); +afterAll(() => { + cleanupTrackedTempDirs(); +}); const claudeConfigDir = path.join(os.tmpdir(), "ade-vitest-claude-config"); diff --git a/apps/mcp-server/src/test/setup.ts b/apps/mcp-server/src/test/setup.ts new file mode 100644 index 000000000..b27c9d28f --- /dev/null +++ b/apps/mcp-server/src/test/setup.ts @@ -0,0 +1,75 @@ +import fs from "node:fs"; +import os from "node:os"; +import path from "node:path"; +import { afterAll } from "vitest"; + +type TestTempTrackerState = { + installed: boolean; + trackedDirs: Set; + originalMkdtempSync?: typeof fs.mkdtempSync; + originalPromisesMkdtemp?: typeof fs.promises.mkdtemp; +}; + +const TEST_TEMP_TRACKER_KEY = Symbol.for("ade.mcp.testTempTracker"); +const testTempRoot = path.resolve(os.tmpdir()); + +function getTestTempTrackerState(): TestTempTrackerState { + const existing = (globalThis as Record)[TEST_TEMP_TRACKER_KEY]; + if (existing) return existing as TestTempTrackerState; + const created: TestTempTrackerState = { + installed: false, + trackedDirs: new Set(), + }; + (globalThis as Record)[TEST_TEMP_TRACKER_KEY] = created; + return created; +} + +function shouldTrackTempDir(dirPath: string): boolean { + const resolved = path.resolve(dirPath); + const baseName = path.basename(resolved); + return resolved.startsWith(`${testTempRoot}${path.sep}`) && baseName.startsWith("ade-"); +} + +function cleanupTrackedTempDirs(): void { + const state = getTestTempTrackerState(); + const targets = [...state.trackedDirs].sort((left, right) => right.length - left.length); + state.trackedDirs.clear(); + for (const target of targets) { + try { + fs.rmSync(target, { recursive: true, force: true }); + } catch { + // Best-effort cleanup only for test temp roots. + } + } +} + +function installTrackedTempCleanup(): void { + const state = getTestTempTrackerState(); + if (state.installed) return; + state.installed = true; + state.originalMkdtempSync = fs.mkdtempSync.bind(fs); + state.originalPromisesMkdtemp = fs.promises.mkdtemp.bind(fs.promises); + + fs.mkdtempSync = ((prefix: string, options?: Parameters[1]) => { + const created = state.originalMkdtempSync!(prefix, options); + if (shouldTrackTempDir(created)) { + state.trackedDirs.add(path.resolve(created)); + } + return created; + }) as typeof fs.mkdtempSync; + + fs.promises.mkdtemp = (async (prefix: string, options?: Parameters[1]) => { + const created = await state.originalPromisesMkdtemp!(prefix, options); + if (shouldTrackTempDir(created)) { + state.trackedDirs.add(path.resolve(created)); + } + return created; + }) as typeof fs.promises.mkdtemp; + + process.once("exit", cleanupTrackedTempDirs); +} + +installTrackedTempCleanup(); +afterAll(() => { + cleanupTrackedTempDirs(); +}); diff --git a/apps/mcp-server/vitest.config.ts b/apps/mcp-server/vitest.config.ts index 7ad107513..8242fa108 100644 --- a/apps/mcp-server/vitest.config.ts +++ b/apps/mcp-server/vitest.config.ts @@ -4,6 +4,7 @@ export default defineConfig({ test: { environment: "node", include: ["src/**/*.test.ts"], + setupFiles: ["src/test/setup.ts"], coverage: { provider: "v8", reporter: ["text", "lcov"], From e01c26f1f9c48916a4d87db2178e3b6eae62ef74 Mon Sep 17 00:00:00 2001 From: Arul Sharma <31745423+arul28@users.noreply.github.com> Date: Sat, 4 Apr 2026 22:35:10 -0400 Subject: [PATCH 2/3] Local model chat: shared helpers, banner merge, permission reset (#132) * Polish local model chat UI and share provider helpers Export LOCAL_PROVIDER_LABELS and local model id parsers from modelRegistry; dedupe labels in agentChatService and aiIntegrationService. Merge CLI and local runtime banners in AgentChatPane, reset unified permission mode when harness recommendations change, and fix local placeholder sdkModelId. Settings and model options use the shared helpers. Co-authored-by: Arul Sharma * Address PR review: local permission descriptors and unified mode setter Add getModelDescriptorForPermissionMode for harness decisions when getModelById is undefined (e.g. ollama/auto). Use it in AgentChatPane for prevModelRef, model switch snapshots, and session create. Consolidate applyModelSelectionSnapshot to a single setUnifiedPermissionMode call. Co-authored-by: Arul Sharma --------- Co-authored-by: Cursor Agent Co-authored-by: Arul Sharma --- .../main/services/ai/aiIntegrationService.ts | 7 +- .../main/services/chat/agentChatService.ts | 7 +- .../components/chat/AgentChatPane.tsx | 225 +++++++++++++----- .../components/settings/ProvidersSection.tsx | 26 +- .../shared/UnifiedModelSelector.tsx | 33 +-- apps/desktop/src/renderer/lib/modelOptions.ts | 45 ++-- apps/desktop/src/shared/modelRegistry.test.ts | 14 ++ apps/desktop/src/shared/modelRegistry.ts | 33 ++- 8 files changed, 252 insertions(+), 138 deletions(-) diff --git a/apps/desktop/src/main/services/ai/aiIntegrationService.ts b/apps/desktop/src/main/services/ai/aiIntegrationService.ts index f1efbfb8e..7836c2188 100644 --- a/apps/desktop/src/main/services/ai/aiIntegrationService.ts +++ b/apps/desktop/src/main/services/ai/aiIntegrationService.ts @@ -17,6 +17,7 @@ import { getAvailableModels, getLocalProviderDefaultEndpoint, listModelDescriptorsForProvider, + LOCAL_PROVIDER_LABELS, replaceDynamicLocalModelDescriptors, resolveModelAlias, enrichModelRegistry, @@ -380,12 +381,6 @@ function redactDetectedAuth( return redacted; } -const LOCAL_PROVIDER_LABELS: Record = { - ollama: "Ollama", - lmstudio: "LM Studio", - vllm: "vLLM", -}; - function apiProviderLabel(provider: string): string { const labels: Record = { anthropic: "Anthropic", diff --git a/apps/desktop/src/main/services/chat/agentChatService.ts b/apps/desktop/src/main/services/chat/agentChatService.ts index a37fd2baa..d4e407ffc 100644 --- a/apps/desktop/src/main/services/chat/agentChatService.ts +++ b/apps/desktop/src/main/services/chat/agentChatService.ts @@ -121,6 +121,7 @@ import { getModelById, getAvailableModels as getRegistryModels, listModelDescriptorsForProvider, + LOCAL_PROVIDER_LABELS, MODEL_REGISTRY, pickDefaultCursorDescriptorFromCliList, replaceDynamicLocalModelDescriptors, @@ -4294,12 +4295,6 @@ export function createAgentChatService(args: { // Unified session support — for API-key / local models using streamText + universal tools. // CLI-wrapped models fall through to the existing Claude/Codex runtimes. - const LOCAL_PROVIDER_LABELS: Record = { - ollama: "Ollama", - lmstudio: "LM Studio", - vllm: "vLLM", - }; - const discoveredLocalModelToDescriptor = (model: DiscoveredLocalModel): ModelDescriptor => createDynamicLocalModelDescriptor(model.provider, model.modelId, { ...(model.displayName ? { displayName: model.displayName } : {}), diff --git a/apps/desktop/src/renderer/components/chat/AgentChatPane.tsx b/apps/desktop/src/renderer/components/chat/AgentChatPane.tsx index 1ff5bb6ad..86eb1d5cc 100644 --- a/apps/desktop/src/renderer/components/chat/AgentChatPane.tsx +++ b/apps/desktop/src/renderer/components/chat/AgentChatPane.tsx @@ -30,9 +30,13 @@ import { } from "../../../shared/types"; import { parseAgentChatTranscript } from "../../../shared/chatTranscript"; import { + LOCAL_PROVIDER_LABELS, MODEL_REGISTRY, + getLocalModelIdTail, getLocalProviderDefaultEndpoint, getModelById, + getModelDescriptorForPermissionMode, + parseLocalProviderFromModelId, resolveModelDescriptorForProvider, type LocalProviderFamily, type ModelDescriptor, @@ -67,30 +71,17 @@ const LEGACY_PROVIDER_KEY = "ade.chat.lastProvider"; const LEGACY_MODEL_KEY_PREFIX = "ade.chat.lastModel"; const COMPUTER_USE_SNAPSHOT_COOLDOWN_MS = 750; -const LOCAL_PROVIDER_LABELS: Record = { - ollama: "Ollama", - lmstudio: "LM Studio", - vllm: "vLLM", -}; type AiStatusSnapshot = AiSettingsStatus & { runtimeConnections?: Record; }; -function getLocalProviderFromModelId(modelId: string): LocalProviderFamily | null { - const provider = String(modelId ?? "").trim().split("/", 1)[0]?.toLowerCase(); - if (provider === "ollama" || provider === "lmstudio" || provider === "vllm") { - return provider; - } - return null; -} - function formatLocalModelLabel(modelId: string): string { - const provider = getLocalProviderFromModelId(modelId); + const provider = parseLocalProviderFromModelId(modelId); if (!provider) { return getModelById(modelId)?.displayName ?? modelId; } - const tail = String(modelId ?? "").trim().slice(provider.length + 1).trim(); + const tail = getLocalModelIdTail(modelId, provider); return tail.length ? tail : modelId; } @@ -103,6 +94,60 @@ function recommendedUnifiedPermissionModeForModel( : null; } +function shouldResetUnifiedPermissionForModelSwitch( + previous: ModelDescriptor | null | undefined, + next: ModelDescriptor | null | undefined, +): boolean { + const prevRec = recommendedUnifiedPermissionModeForModel(previous); + const nextRec = recommendedUnifiedPermissionModeForModel(next); + if (prevRec == null && nextRec == null) return false; + return prevRec !== nextRec; +} + +type LocalRuntimeNoticeShape = { + tone: "success" | "warning"; + title: string; + message: string; +}; + +function LocalRuntimeNoticeBlock(props: { + notice: LocalRuntimeNoticeShape; + endpoint?: string | null; + /** `inline` = text only (inside a parent runtime card). */ + variant?: "card" | "inline"; +}) { + const { notice, endpoint, variant = "card" } = props; + const isCard = variant === "card"; + return ( +
+
+ {notice.title} +
+
+ {notice.message} +
+ {endpoint ? ( + + {endpoint} + + ) : null} +
+ ); +} + export function resolveChatSessionProfile(_computerUsePolicy: ComputerUsePolicy): AgentChatSessionProfile { return "workflow"; } @@ -629,6 +674,7 @@ export function AgentChatPane({ const [codexSandbox, setCodexSandbox] = useState(initialNativeControls.codexSandbox); const [codexConfigSource, setCodexConfigSource] = useState(initialNativeControls.codexConfigSource); const [unifiedPermissionMode, setUnifiedPermissionMode] = useState(initialNativeControls.unifiedPermissionMode); + const prevModelDescRef = useRef(undefined); const [cursorModeId, setCursorModeId] = useState(initialNativeControls.cursorModeId); const [cursorConfigValues, setCursorConfigValues] = useState>(initialNativeControls.cursorConfigValues); const [computerUsePolicy, setComputerUsePolicy] = useState(createDefaultComputerUsePolicy()); @@ -728,7 +774,7 @@ export function AgentChatPane({ const localRuntimeState = useMemo(() => { const provider = selectedModelDesc?.authTypes.includes("local") ? (selectedModelDesc.family as LocalProviderFamily) - : getLocalProviderFromModelId(modelId); + : parseLocalProviderFromModelId(modelId); if (!provider) return null; const runtimeConnection = aiStatus?.runtimeConnections?.[provider] ?? null; const detectedEntry = aiStatus?.detectedAuth?.find( @@ -790,9 +836,61 @@ export function AgentChatPane({ return { tone: "success" as const, title: `${localRuntimeState.label} runtime`, - message: `${localRuntimeState.label} is connected at ${localRuntimeState.endpoint} with ${localRuntimeState.modelIds.length} loaded model${localRuntimeState.modelIds.length === 1 ? "" : "s"}${localRuntimeState.health ? ` (${localRuntimeState.health})` : ""}.`, + message: `${localRuntimeState.label} is connected with ${localRuntimeState.modelIds.length} loaded model${localRuntimeState.modelIds.length === 1 ? "" : "s"}${localRuntimeState.health ? ` (${localRuntimeState.health})` : ""}.`, }; }, [localRuntimeState, modelId, selectedModelDesc?.displayName]); + + const cliRuntimeBlocked = Boolean( + selectedSessionId + && activeProviderConnection + && !activeProviderConnection.runtimeAvailable + && (activeProviderConnection.blocker || activeProviderConnection.provider === "cursor"), + ); + const cliRuntimeTitle = activeProviderConnection?.provider === "claude" + ? "Claude runtime" + : activeProviderConnection?.provider === "cursor" + ? "Cursor runtime" + : "Codex runtime"; + const cliRuntimeBody = activeProviderConnection?.blocker + ?? (activeProviderConnection?.provider === "cursor" + ? "Cursor agent is not available. Ensure Cursor is installed and the agent is enabled." + : null); + + const mergedRuntimeBanner = useMemo(() => { + if (!cliRuntimeBlocked && !localRuntimeNotice) return null; + if (cliRuntimeBlocked && localRuntimeNotice) { + return { + kind: "merged" as const, + cliTitle: cliRuntimeTitle, + cliBody: cliRuntimeBody ?? "", + localNotice: localRuntimeNotice, + localEndpoint: localRuntimeState?.endpoint, + }; + } + if (cliRuntimeBlocked) { + return { + kind: "cli-only" as const, + cliTitle: cliRuntimeTitle, + cliBody: cliRuntimeBody ?? "", + }; + } + return { + kind: "local-only" as const, + localNotice: localRuntimeNotice!, + localEndpoint: localRuntimeState?.endpoint, + }; + }, [ + cliRuntimeBlocked, + cliRuntimeBody, + cliRuntimeTitle, + localRuntimeNotice, + localRuntimeState?.endpoint, + ]); + + useEffect(() => { + prevModelDescRef.current = getModelDescriptorForPermissionMode(modelId); + }, [modelId]); + const surfaceMode = presentation?.mode ?? "standard"; const identitySessionSettingsBusy = isPersistentIdentitySurface && sessionMutationKind !== null; @@ -1698,32 +1796,41 @@ export function AgentChatPane({ }; }, [currentNativeControls]); const buildModelSelectionSnapshot = useCallback((nextModelId: string) => { + const previousDesc = prevModelDescRef.current; const nextDesc = getModelById(nextModelId); + const nextPermissionDesc = getModelDescriptorForPermissionMode(nextModelId); const nextProvider = resolveChatRuntimeProvider(nextDesc); const nextModel = nextProvider === "unified" ? nextModelId : runtimeFacingModelId(nextDesc, nextModelId); const tiers = nextDesc?.reasoningTiers ?? []; const preferred = readLastUsedReasoningEffort({ laneId, modelId: nextModelId }); const nextReasoningEffort = selectReasoningEffort({ tiers, preferred }); + const nextRec = recommendedUnifiedPermissionModeForModel(nextPermissionDesc); return { nextDesc, nextModelId, nextModel, nextProvider, nextReasoningEffort, - nextUnifiedPermissionMode: recommendedUnifiedPermissionModeForModel(nextDesc), + nextUnifiedPermissionMode: nextRec, + resetUnifiedPermissionToDefault: shouldResetUnifiedPermissionForModelSwitch(previousDesc, nextPermissionDesc), }; }, [laneId]); const applyModelSelectionSnapshot = useCallback((snapshot: { nextModelId: string; nextReasoningEffort: string | null; nextUnifiedPermissionMode?: AgentChatUnifiedPermissionMode | null; + resetUnifiedPermissionToDefault?: boolean; }) => { setModelId(snapshot.nextModelId); setReasoningEffort(snapshot.nextReasoningEffort); - if (snapshot.nextUnifiedPermissionMode) { - setUnifiedPermissionMode(snapshot.nextUnifiedPermissionMode); + const nextUnified = snapshot.nextUnifiedPermissionMode ?? null; + const targetUnified = snapshot.resetUnifiedPermissionToDefault + ? (nextUnified ?? initialNativeControls.unifiedPermissionMode) + : nextUnified; + if (targetUnified != null) { + setUnifiedPermissionMode(targetUnified); } - }, []); + }, [initialNativeControls.unifiedPermissionMode]); const notifySessionCreated = useCallback((session: AgentChatSession) => { if (!onSessionCreated) return; void Promise.resolve(onSessionCreated(session)).catch((err) => { console.error("notifySessionCreated failed:", err); }); @@ -1736,11 +1843,12 @@ export function AgentChatPane({ if (!laneId) return null; const createPromise = (async () => { const desc = getModelById(modelId); + const permissionDesc = getModelDescriptorForPermissionMode(modelId); const provider = resolveChatRuntimeProvider(desc); const model = provider === "unified" ? modelId : runtimeFacingModelId(desc, modelId); const sessionProfile = resolveChatSessionProfile(computerUsePolicy); const harnessPermissionMode = provider === "unified" - ? recommendedUnifiedPermissionModeForModel(desc) + ? recommendedUnifiedPermissionModeForModel(permissionDesc) : null; const nativeControlPayload = harnessPermissionMode ? { @@ -2433,15 +2541,15 @@ export function AgentChatPane({ } setSessionMutationKind("model"); - const nextNativeControlPayload = snapshot.nextUnifiedPermissionMode + const nextUnifiedForPayload = snapshot.resetUnifiedPermissionToDefault + ? (snapshot.nextUnifiedPermissionMode ?? initialNativeControls.unifiedPermissionMode) + : snapshot.nextUnifiedPermissionMode; + const nextNativeControlPayload = snapshot.nextProvider === "unified" && nextUnifiedForPayload != null ? { - ...summarizeNativeControls(snapshot.nextProvider, { + ...summarizeNativeControls("unified", { ...currentNativeControls, - unifiedPermissionMode: snapshot.nextUnifiedPermissionMode, + unifiedPermissionMode: nextUnifiedForPayload, }), - ...(snapshot.nextProvider === "cursor" - ? { cursorConfigValues: currentNativeControls.cursorConfigValues } - : {}), } : buildNativeControlPayload(snapshot.nextProvider); void window.ade.agentChat.updateSession({ @@ -2550,47 +2658,44 @@ export function AgentChatPane({ {error}
) : null} - {selectedSessionId && !activeProviderConnection?.runtimeAvailable && (activeProviderConnection?.blocker || activeProviderConnection?.provider === "cursor") ? ( + {mergedRuntimeBanner?.kind === "cli-only" ? (
- {activeProviderConnection.provider === "claude" - ? "Claude runtime" - : activeProviderConnection.provider === "cursor" - ? "Cursor runtime" - : "Codex runtime"} + {mergedRuntimeBanner.cliTitle}
- {activeProviderConnection.blocker || "Cursor agent is not available. Ensure Cursor is installed and the agent is enabled."} + {mergedRuntimeBanner.cliBody}
) : null} - - {localRuntimeNotice ? ( -
-
- {localRuntimeNotice.title} + {mergedRuntimeBanner?.kind === "local-only" ? ( + + ) : null} + {mergedRuntimeBanner?.kind === "merged" ? ( +
+
+ Runtime status
-
- {localRuntimeNotice.message} +
+
+
+ {mergedRuntimeBanner.cliTitle} +
+
+ {mergedRuntimeBanner.cliBody} +
+
+
+ +
- {localRuntimeState ? ( - - {localRuntimeState.endpoint} - - ) : null}
) : null} diff --git a/apps/desktop/src/renderer/components/settings/ProvidersSection.tsx b/apps/desktop/src/renderer/components/settings/ProvidersSection.tsx index 41734d982..7b934afec 100644 --- a/apps/desktop/src/renderer/components/settings/ProvidersSection.tsx +++ b/apps/desktop/src/renderer/components/settings/ProvidersSection.tsx @@ -8,7 +8,14 @@ import type { AiSettingsStatus, ProjectConfigSnapshot, } from "../../../shared/types"; -import { getLocalProviderDefaultEndpoint, getModelById, type LocalProviderFamily } from "../../../shared/modelRegistry"; +import { + getLocalModelIdTail, + getLocalProviderDefaultEndpoint, + getModelById, + LOCAL_PROVIDER_LABELS, + parseLocalProviderFromModelId, + type LocalProviderFamily, +} from "../../../shared/modelRegistry"; import { ArrowsClockwise, CheckCircle, @@ -72,10 +79,6 @@ const LOCAL_PROVIDER_SPECS: Array<{ { provider: "vllm", label: "vLLM", description: "OpenAI-compatible local server" }, ]; -const LOCAL_PROVIDER_LABELS: Record = Object.fromEntries( - LOCAL_PROVIDER_SPECS.map((entry) => [entry.provider, entry.label]), -) as Record; - const API_KEY_PROVIDERS: Array<{ provider: string; label: string; @@ -184,14 +187,13 @@ function buildCliMessage(tool: (typeof CLI_TOOLS)[number], connection: AiProvide function formatLocalModelLabel(modelId: string): string { const descriptor = getModelById(modelId); if (descriptor) return descriptor.displayName; - const raw = String(modelId ?? "").trim(); - const prefix = raw.split("/", 1)[0]?.toLowerCase(); - if (prefix === "ollama" || prefix === "lmstudio" || prefix === "vllm") { - const tail = raw.slice(prefix.length + 1).trim(); - const providerLabel = LOCAL_PROVIDER_SPECS.find((entry) => entry.provider === prefix)?.label ?? prefix; - return tail.length ? `${tail} (${providerLabel})` : raw; + const provider = parseLocalProviderFromModelId(modelId); + if (provider) { + const tail = getLocalModelIdTail(modelId, provider); + const brand = LOCAL_PROVIDER_LABELS[provider]; + return tail.length ? `${tail} (${brand})` : String(modelId ?? "").trim(); } - return raw; + return String(modelId ?? "").trim(); } const AUTH_ERROR_SIGNALS = [ diff --git a/apps/desktop/src/renderer/components/shared/UnifiedModelSelector.tsx b/apps/desktop/src/renderer/components/shared/UnifiedModelSelector.tsx index 8b289ae13..062ee112a 100644 --- a/apps/desktop/src/renderer/components/shared/UnifiedModelSelector.tsx +++ b/apps/desktop/src/renderer/components/shared/UnifiedModelSelector.tsx @@ -2,7 +2,10 @@ import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { createPortal } from "react-dom"; import { AnimatePresence, motion } from "motion/react"; import { + LOCAL_PROVIDER_LABELS, MODEL_REGISTRY, + getLocalModelIdTail, + parseLocalProviderFromModelId, resolveModelDescriptor, type ModelDescriptor, } from "../../../shared/modelRegistry"; @@ -43,11 +46,6 @@ type UnifiedModelSelectorProps = { }; const SOURCE_KEYS: SourceSectionKey[] = ["subscription", "api", "local"]; -const LOCAL_PROVIDER_LABELS: Record = { - ollama: "Ollama", - lmstudio: "LM Studio", - vllm: "vLLM", -}; const selectCls = cn( "h-8 rounded-lg border border-white/[0.08] bg-white/[0.04] px-2 font-sans text-[11px] text-fg/70", @@ -67,21 +65,6 @@ function providerAccent(family: string, fallback?: string): string { return PROVIDER_BADGE_COLORS[family] ?? fallback ?? "#A78BFA"; } -function getLocalProviderFromModelId(modelId: string): "ollama" | "lmstudio" | "vllm" | null { - const provider = String(modelId ?? "").trim().split("/", 1)[0]?.toLowerCase(); - if (provider === "ollama" || provider === "lmstudio" || provider === "vllm") { - return provider; - } - return null; -} - -function getLocalModelShortLabel(modelId: string): string { - const provider = getLocalProviderFromModelId(modelId); - if (!provider) return modelId; - const tail = String(modelId ?? "").trim().slice(provider.length + 1).trim(); - return tail.length ? tail : modelId; -} - function subsectionTabTitle(sub: ModelSubsection): string { return sub.label.trim() || "Models"; } @@ -133,10 +116,10 @@ function createUnknownModelPlaceholder(modelId: string): ModelDescriptor { isCliWrapped: true, }; } - const localProvider = getLocalProviderFromModelId(modelId); + const localProvider = parseLocalProviderFromModelId(modelId); if (localProvider) { - const providerLabel = LOCAL_PROVIDER_LABELS[localProvider]; - const shortId = getLocalModelShortLabel(modelId); + const shortId = getLocalModelIdTail(modelId, localProvider) || modelId; + const brand = LOCAL_PROVIDER_LABELS[localProvider]; return { id: modelId, shortId, @@ -148,11 +131,11 @@ function createUnknownModelPlaceholder(modelId: string): ModelDescriptor { capabilities: { tools: false, vision: false, reasoning: false, streaming: true }, color: PROVIDER_BADGE_COLORS[localProvider] ?? "#64748B", sdkProvider: "@ai-sdk/openai-compatible", - sdkModelId: modelId, + sdkModelId: shortId, isCliWrapped: false, discoverySource: localProvider === "lmstudio" ? "lmstudio-openai" : localProvider, harnessProfile: "guarded", - aliases: [providerLabel], + aliases: [brand], }; } return { diff --git a/apps/desktop/src/renderer/lib/modelOptions.ts b/apps/desktop/src/renderer/lib/modelOptions.ts index d52e08681..c6e5fc1f3 100644 --- a/apps/desktop/src/renderer/lib/modelOptions.ts +++ b/apps/desktop/src/renderer/lib/modelOptions.ts @@ -1,42 +1,33 @@ import type { AiModelDescriptor, AiRuntimeConnectionStatus, AiSettingsStatus, ModelId } from "../../shared/types"; -import { MODEL_REGISTRY, getModelById, type LocalProviderFamily, type ModelDescriptor } from "../../shared/modelRegistry"; +import { + LOCAL_PROVIDER_LABELS, + MODEL_REGISTRY, + getLocalModelIdTail, + getModelById, + isLocalProviderFamily, + parseLocalProviderFromModelId, + type ModelDescriptor, +} from "../../shared/modelRegistry"; function normalizeAuthProvider(provider: string | undefined): string { return String(provider ?? "").trim().toLowerCase(); } -const LOCAL_PROVIDER_LABELS: Record = { - ollama: "Ollama", - lmstudio: "LM Studio", - vllm: "vLLM", -}; - -function getLocalProviderFromModelId(modelId: string): LocalProviderFamily | null { - const provider = String(modelId ?? "") - .trim() - .split("/", 1)[0] - ?.toLowerCase(); - if (provider === "ollama" || provider === "lmstudio" || provider === "vllm") { - return provider; - } - return null; -} - function getLocalModelLabel(modelId: string): string { - const provider = getLocalProviderFromModelId(modelId); + const provider = parseLocalProviderFromModelId(modelId); if (!provider) return modelId; - const tail = String(modelId ?? "").trim().slice(provider.length + 1).trim(); + const tail = getLocalModelIdTail(modelId, provider); return tail.length ? tail : modelId; } function buildFallbackModelOption(modelId: string): AiModelDescriptor { - const provider = getLocalProviderFromModelId(modelId); + const provider = parseLocalProviderFromModelId(modelId); if (provider) { - const providerLabel = LOCAL_PROVIDER_LABELS[provider]; + const pLabel = LOCAL_PROVIDER_LABELS[provider]; return { id: modelId, label: getLocalModelLabel(modelId), - description: `${providerLabel} local model`, + description: `${pLabel} local model`, }; } return { @@ -104,7 +95,7 @@ function hasDynamicLocalModelIdsForProvider( runtimeConnections: Record | undefined, ): boolean { const normalizedProvider = normalizeAuthProvider(provider); - if (normalizedProvider !== "ollama" && normalizedProvider !== "lmstudio" && normalizedProvider !== "vllm") { + if (!isLocalProviderFamily(normalizedProvider)) { return false; } const prefix = `${normalizedProvider}/`; @@ -168,7 +159,7 @@ export function deriveConfiguredModelIds( for (const [provider, connection] of Object.entries(runtimeConnections ?? {})) { const normalizedProvider = normalizeAuthProvider(provider); - if (normalizedProvider !== "ollama" && normalizedProvider !== "lmstudio" && normalizedProvider !== "vllm") { + if (!isLocalProviderFamily(normalizedProvider)) { continue; } if (!hasDynamicLocalModelIdsForProvider(normalizedProvider, status.availableModelIds, runtimeConnections)) { @@ -210,7 +201,7 @@ export function deriveConfiguredModelOptions( const descriptor = getModelById(modelId); return descriptor ? [descriptorToModelOption(descriptor)] - : getLocalProviderFromModelId(modelId) + : parseLocalProviderFromModelId(modelId) ? [buildFallbackModelOption(modelId)] : []; }); @@ -224,6 +215,6 @@ export function includeSelectedModelOption( if (!modelId.length || options.some((option) => option.id === modelId)) return options; const descriptor = getModelById(modelId); if (descriptor) return [descriptorToModelOption(descriptor), ...options]; - if (getLocalProviderFromModelId(modelId)) return [buildFallbackModelOption(modelId), ...options]; + if (parseLocalProviderFromModelId(modelId)) return [buildFallbackModelOption(modelId), ...options]; return options; } diff --git a/apps/desktop/src/shared/modelRegistry.test.ts b/apps/desktop/src/shared/modelRegistry.test.ts index 4fdbf920a..fa446ba97 100644 --- a/apps/desktop/src/shared/modelRegistry.test.ts +++ b/apps/desktop/src/shared/modelRegistry.test.ts @@ -4,6 +4,7 @@ import { getAvailableModels, getDefaultModelDescriptor, getModelById, + getModelDescriptorForPermissionMode, getRuntimeModelRefForDescriptor, listModelDescriptorsForProvider, MODEL_REGISTRY, @@ -76,6 +77,19 @@ describe("modelRegistry", () => { expect(resolveModelDescriptor("nonexistent/model-id")).toBeUndefined(); }); + it("getModelDescriptorForPermissionMode matches getModelById for known locals", () => { + const id = "ollama/qwen2.5-coder:32b"; + expect(getModelDescriptorForPermissionMode(id)).toEqual(getModelById(id)); + }); + + it("getModelDescriptorForPermissionMode yields guarded local for ollama/auto when getModelById is undefined", () => { + expect(getModelById("ollama/auto")).toBeUndefined(); + const perm = getModelDescriptorForPermissionMode("ollama/auto"); + expect(perm?.family).toBe("ollama"); + expect(perm?.harnessProfile).toBe("guarded"); + expect(perm?.authTypes).toContain("local"); + }); + it("resolves gpt-5.4 shortId to the API-key variant, not the codex variant", () => { const resolved = resolveModelAlias("gpt-5.4"); expect(resolved).toBeTruthy(); diff --git a/apps/desktop/src/shared/modelRegistry.ts b/apps/desktop/src/shared/modelRegistry.ts index b7c407cdd..508d4bf44 100644 --- a/apps/desktop/src/shared/modelRegistry.ts +++ b/apps/desktop/src/shared/modelRegistry.ts @@ -82,7 +82,8 @@ export function isModelProviderGroup(value: string | null | undefined): value is const ALL_CAPS: ModelCapabilities = { tools: true, vision: true, reasoning: true, streaming: true }; const NO_REASONING: ModelCapabilities = { tools: true, vision: true, reasoning: false, streaming: true }; const BASIC_CAPS: ModelCapabilities = { tools: true, vision: false, reasoning: false, streaming: true }; -const LOCAL_PROVIDER_LABELS: Record = { +/** Human-readable names for Ollama / LM Studio / vLLM (shared across main, renderer, and MCP). */ +export const LOCAL_PROVIDER_LABELS: Record = { ollama: "Ollama", lmstudio: "LM Studio", vllm: "vLLM", @@ -753,10 +754,38 @@ export function validateModelRegistry(models: ModelDescriptor[] = MODEL_REGISTRY validateModelRegistry(); rebuildIndexes(); -function isLocalProviderFamily(value: string): value is LocalProviderFamily { +export function isLocalProviderFamily(value: string): value is LocalProviderFamily { return value === "ollama" || value === "lmstudio" || value === "vllm"; } +/** First path segment of `provider/modelId` when it is a known local provider. */ +export function parseLocalProviderFromModelId(modelId: string): LocalProviderFamily | null { + const provider = String(modelId ?? "").trim().split("/", 1)[0]?.toLowerCase() ?? ""; + return isLocalProviderFamily(provider) ? provider : null; +} + +/** Model name segment after `provider/` for local refs; empty string if missing. */ +export function getLocalModelIdTail(modelId: string, provider: LocalProviderFamily): string { + return String(modelId ?? "").trim().slice(provider.length + 1).trim(); +} + +/** + * Descriptor for unified permission / harness decisions when the registry has no row yet. + * `getModelById` returns undefined for refs such as `ollama/auto`; this still returns a + * guarded local descriptor so the UI matches main-process harness behavior. + */ +export function getModelDescriptorForPermissionMode(modelId: string): ModelDescriptor | undefined { + const resolved = getModelById(modelId); + if (resolved) return resolved; + const provider = parseLocalProviderFromModelId(modelId); + if (!provider) return undefined; + const tail = getLocalModelIdTail(modelId, provider); + if (!tail.length || tail === "auto") { + return createDynamicLocalModelDescriptor(provider, "auto", { harnessProfile: "guarded" }); + } + return createDynamicLocalModelDescriptor(provider, tail); +} + function parseDynamicLocalModelRef(modelRef: string): { provider: LocalProviderFamily; modelId: string } | null { const normalized = modelRef.trim(); if (!normalized.length) return null; From 0e4785b57c8a4c97a9cfede7cad139dc208d34e5 Mon Sep 17 00:00:00 2001 From: Arul Sharma <31745423+arul28@users.noreply.github.com> Date: Mon, 6 Apr 2026 04:02:23 -0400 Subject: [PATCH 3/3] Harden local chat runtime and structured prompts --- .ade/ade.yaml | 6 + .ade/cto/identity.yaml | 19 +- apps/desktop/src/main/main.ts | 10 +- .../main/services/ai/aiIntegrationService.ts | 15 + .../main/services/ai/localModelDiscovery.ts | 26 +- .../services/ai/tools/universalTools.test.ts | 57 ++ .../main/services/ai/tools/universalTools.ts | 92 ++- .../main/services/ai/tools/workflowTools.ts | 13 +- .../services/chat/agentChatService.test.ts | 92 +++ .../main/services/chat/agentChatService.ts | 601 +++++++++++------- .../unifiedOrchestratorAdapter.ts | 2 +- .../components/chat/AgentChatPane.tsx | 8 +- .../onboarding/ProjectSetupPage.tsx | 2 +- .../settings/ProvidersSection.test.tsx | 18 +- .../components/settings/ProvidersSection.tsx | 6 +- .../shared/UnifiedModelSelector.tsx | 2 +- apps/desktop/src/renderer/lib/modelOptions.ts | 3 + apps/desktop/src/test/setup.ts | 2 +- apps/mcp-server/src/test/setup.ts | 2 +- 19 files changed, 708 insertions(+), 268 deletions(-) diff --git a/.ade/ade.yaml b/.ade/ade.yaml index 8679ba36a..9f155cd7e 100644 --- a/.ade/ade.yaml +++ b/.ade/ade.yaml @@ -23,3 +23,9 @@ ai: autoTitleEnabled: true autoTitleModelId: openai/gpt-5.3-codex-spark autoTitleRefreshOnComplete: true + localProviders: + lmstudio: + enabled: true + endpoint: http://127.0.0.1:1234 + autoDetect: true + preferredModelId: null diff --git a/.ade/cto/identity.yaml b/.ade/cto/identity.yaml index ad9417718..e7a73393e 100644 --- a/.ade/cto/identity.yaml +++ b/.ade/cto/identity.yaml @@ -1,7 +1,14 @@ name: CTO -version: 3 -persona: Persistent project CTO with collaborative personality. -personality: casual +version: 1 +persona: >- + You are the CTO for this project inside ADE. + + You are the persistent technical lead who owns architecture, execution + quality, engineering continuity, and team direction. + + Use ADE's tools and project context to help the team move forward with clear, + concrete decisions. +personality: strategic modelPreferences: provider: claude model: sonnet @@ -21,8 +28,4 @@ openclawContextPolicy: - secret - token - system_prompt -onboardingState: - completedSteps: - - identity - completedAt: 2026-04-02T14:20:19.124Z -updatedAt: 2026-04-02T14:20:19.127Z +updatedAt: 1970-01-01T00:00:00.000Z diff --git a/apps/desktop/src/main/main.ts b/apps/desktop/src/main/main.ts index 784312c21..fd24d7e6a 100644 --- a/apps/desktop/src/main/main.ts +++ b/apps/desktop/src/main/main.ts @@ -1455,7 +1455,15 @@ app.whenReady().then(async () => { }, }); agentChatServiceRef = agentChatService; - agentChatService.cleanupStaleAttachments(); + setImmediate(() => { + void Promise.resolve() + .then(() => agentChatService.cleanupStaleAttachments()) + .catch((err) => { + logger.warn("agent_chat.cleanup_stale_attachments_failed", { + error: err instanceof Error ? err.message : String(err), + }); + }); + }); // Wire agentChatService into prService for integration resolution prService.setAgentChatService(agentChatService); diff --git a/apps/desktop/src/main/services/ai/aiIntegrationService.ts b/apps/desktop/src/main/services/ai/aiIntegrationService.ts index 7836c2188..f02fb9801 100644 --- a/apps/desktop/src/main/services/ai/aiIntegrationService.ts +++ b/apps/desktop/src/main/services/ai/aiIntegrationService.ts @@ -539,6 +539,21 @@ async function buildLocalRuntimeConnection(args: { }); } } + + return { + provider: args.provider, + label, + kind: "local", + configured: true, + authAvailable: false, + runtimeDetected: false, + runtimeAvailable: false, + health: "unreachable", + source: "config", + endpoint: configuredEndpoint, + blocker: `${label} is configured for ${configuredEndpoint}, but the runtime did not respond.`, + lastCheckedAt: args.checkedAt, + }; } return { diff --git a/apps/desktop/src/main/services/ai/localModelDiscovery.ts b/apps/desktop/src/main/services/ai/localModelDiscovery.ts index 68551b0e5..6f93ba5f7 100644 --- a/apps/desktop/src/main/services/ai/localModelDiscovery.ts +++ b/apps/desktop/src/main/services/ai/localModelDiscovery.ts @@ -27,14 +27,16 @@ export type LocalProviderInspection = { }; const CACHE_TTL_MS = 30_000; +let inspectionCacheGeneration = 0; let discoverCache: { key: string; + generation: number; cachedAt: number; models: DiscoveredLocalModel[]; } | null = null; -let inspectionCache = new Map(); +let inspectionCache = new Map(); function buildCacheKey(auth: DetectedAuth[]): string { return auth @@ -334,9 +336,10 @@ export async function inspectLocalProvider( timeoutMs = 2_000, ): Promise { const key = buildInspectionKey(provider, endpoint); + const generation = inspectionCacheGeneration; const cached = inspectionCache.get(key); const now = Date.now(); - if (cached && now - cached.cachedAt < CACHE_TTL_MS) { + if (cached && cached.generation === generation && now - cached.cachedAt < CACHE_TTL_MS) { return cached.inspection; } @@ -344,19 +347,28 @@ export async function inspectLocalProvider( ? await inspectLmStudioProvider(endpoint, timeoutMs) : await inspectOpenAiCompatibleProvider(provider, endpoint, timeoutMs); - inspectionCache.set(key, { cachedAt: now, inspection }); + if (generation === inspectionCacheGeneration) { + inspectionCache.set(key, { generation, cachedAt: now, inspection }); + } return inspection; } export function clearLocalProviderInspectionCache(): void { - inspectionCache = new Map(); + inspectionCacheGeneration += 1; + inspectionCache = new Map(); discoverCache = null; } export async function discoverLocalModels(auth: DetectedAuth[]): Promise { const key = buildCacheKey(auth); + const generation = inspectionCacheGeneration; const now = Date.now(); - if (discoverCache && discoverCache.key === key && now - discoverCache.cachedAt < CACHE_TTL_MS) { + if ( + discoverCache + && discoverCache.generation === generation + && discoverCache.key === key + && now - discoverCache.cachedAt < CACHE_TTL_MS + ) { return discoverCache.models; } @@ -380,6 +392,8 @@ export async function discoverLocalModels(auth: DetectedAuth[]): Promise { expect(result.answer).toBe("user answer"); }); + it("accepts structured askUser prompts and returns normalized answers", async () => { + const cwd = makeTmpDir("ade-tools-askuser-structured-"); + const onAskUser = vi.fn().mockResolvedValue({ + answer: "auth-refactor", + answers: { + roadmap: ["auth-refactor", "bug-fixes"], + }, + responseText: null, + decision: "accept", + }); + const tools = createUniversalToolSet(cwd, { + permissionMode: "plan", + onAskUser, + }); + + const result = await (tools.askUser as any).execute({ + title: "Mock plan question", + body: "Choose the most important sprint goals.", + questions: [ + { + id: "roadmap", + header: "Sprint scope", + question: "Which items should we prioritize?", + multiSelect: true, + options: [ + { label: "Auth refactor", value: "auth-refactor", recommended: true }, + { label: "Bug fixes", value: "bug-fixes" }, + ], + }, + ], + }); + + expect(onAskUser).toHaveBeenCalledWith({ + title: "Mock plan question", + body: "Choose the most important sprint goals.", + questions: [ + { + id: "roadmap", + header: "Sprint scope", + question: "Which items should we prioritize?", + multiSelect: true, + options: [ + { label: "Auth refactor", value: "auth-refactor", recommended: true }, + { label: "Bug fixes", value: "bug-fixes" }, + ], + }, + ], + }); + expect(result).toMatchObject({ + answer: "auth-refactor", + answers: { + roadmap: ["auth-refactor", "bug-fixes"], + }, + decision: "accept", + }); + }); + // ── exitPlanMode tool ─────────────────────────────────────────── it("does not expose exitPlanMode in non-plan permission modes", async () => { diff --git a/apps/desktop/src/main/services/ai/tools/universalTools.ts b/apps/desktop/src/main/services/ai/tools/universalTools.ts index cba2026fd..5700892db 100644 --- a/apps/desktop/src/main/services/ai/tools/universalTools.ts +++ b/apps/desktop/src/main/services/ai/tools/universalTools.ts @@ -20,6 +20,42 @@ const execFileAsync = promisify(execFile); export type PermissionMode = "plan" | "edit" | "full-auto"; +export type AskUserToolOption = { + label: string; + value?: string; + description?: string; + recommended?: boolean; + preview?: string; + previewFormat?: "markdown" | "html"; +}; + +export type AskUserToolQuestion = { + id?: string; + header?: string; + question: string; + options?: AskUserToolOption[]; + multiSelect?: boolean; + allowsFreeform?: boolean; + isSecret?: boolean; + defaultAssumption?: string | null; + impact?: string | null; +}; + +export type AskUserToolInput = { + question?: string; + title?: string; + body?: string; + questions?: AskUserToolQuestion[]; +}; + +export type AskUserToolResult = { + answer: string; + answers?: Record; + responseText?: string | null; + decision?: string; + error?: string; +}; + export interface UniversalToolSetOptions { permissionMode: PermissionMode; memoryService?: ReturnType; @@ -35,7 +71,7 @@ export interface UniversalToolSetOptions { updatedAt: string; }; /** Callback invoked when askUser tool is called; must return the user's response */ - onAskUser?: (question: string) => Promise; + onAskUser?: (input: AskUserToolInput) => Promise; /** Optional callback for ADE-managed tool approvals in interactive chat sessions. */ onApprovalRequest?: (request: ToolApprovalRequest) => Promise; /** Sandbox config for API-model workers. CLI models skip this check. */ @@ -907,21 +943,63 @@ function createGitLogTool(cwd: string) { }); } -function createAskUserTool(onAskUser?: (question: string) => Promise) { +const askUserToolOptionSchema = z.object({ + label: z.string().describe("User-facing option label"), + value: z.string().optional().describe("Optional stable value to return for this option"), + description: z.string().optional().describe("Optional short sentence explaining the option"), + recommended: z.boolean().optional().describe("Whether this option should be highlighted as recommended"), + preview: z.string().optional().describe("Optional preview content shown alongside the option"), + previewFormat: z.enum(["markdown", "html"]).optional().describe("Preview rendering format"), +}); + +const askUserToolQuestionSchema = z.object({ + id: z.string().optional().describe("Optional stable identifier for this question"), + header: z.string().optional().describe("Short question header shown in the UI"), + question: z.string().describe("The question to ask the user"), + options: z.array(askUserToolOptionSchema).optional().describe("Optional multiple-choice options"), + multiSelect: z.boolean().optional().describe("Allow selecting more than one option"), + allowsFreeform: z.boolean().optional().describe("Allow typing a custom answer"), + isSecret: z.boolean().optional().describe("Hide typed input while the user answers"), + defaultAssumption: z.string().nullable().optional().describe("Default assumption if the user skips this question"), + impact: z.string().nullable().optional().describe("Why this question matters"), +}); + +function createAskUserTool( + onAskUser?: (input: AskUserToolInput) => Promise, +) { return tool({ description: "Ask the user a clarifying question when you need more information to proceed. " + "Use sparingly — only when truly blocked.", inputSchema: z.object({ - question: z.string().describe("The question to ask the user"), - }), - execute: async ({ question }) => { + question: z.string().optional().describe("Simple text question to ask the user"), + title: z.string().optional().describe("Optional modal title for a richer prompt"), + body: z.string().optional().describe("Optional supporting context shown above the question list"), + questions: z.array(askUserToolQuestionSchema).optional().describe("Optional structured questions with choices"), + }).refine( + (value) => { + const question = typeof value.question === "string" ? value.question.trim() : ""; + const body = typeof value.body === "string" ? value.body.trim() : ""; + return question.length > 0 || body.length > 0 || Boolean(value.questions?.length); + }, + { message: "Provide question, body, or questions." }, + ), + execute: async (input) => { if (!onAskUser) { return { answer: "", error: "askUser callback not configured" }; } try { - const answer = await onAskUser(question); - return { answer }; + const response = await onAskUser(input); + if (typeof response === "string") { + return { answer: response }; + } + return { + answer: response.answer ?? "", + ...(response.answers ? { answers: response.answers } : {}), + ...(response.responseText !== undefined ? { responseText: response.responseText } : {}), + ...(response.decision !== undefined ? { decision: response.decision } : {}), + ...(response.error !== undefined ? { error: response.error } : {}), + }; } catch (err) { return { answer: "", diff --git a/apps/desktop/src/main/services/ai/tools/workflowTools.ts b/apps/desktop/src/main/services/ai/tools/workflowTools.ts index fffeb0ac2..7263c1844 100644 --- a/apps/desktop/src/main/services/ai/tools/workflowTools.ts +++ b/apps/desktop/src/main/services/ai/tools/workflowTools.ts @@ -172,12 +172,11 @@ export function createWorkflowTools( error: "Local computer-use fallback is disabled for this chat session.", }; } - const tmpDir = fs.mkdtempSync(path.join(require("node:os").tmpdir(), "ade-screenshot-")); - const tmpPath = path.join( - tmpDir, - `screenshot-${Date.now()}.png`, - ); + let tmpDir: string | null = null; try { + tmpDir = fs.mkdtempSync(path.join(require("node:os").tmpdir(), "ade-screenshot-")); + const tmpPath = path.join(tmpDir, `screenshot-${Date.now()}.png`); + // Use macOS screencapture to grab the screen await execFileAsync("screencapture", ["-x", tmpPath], { timeout: 15_000, @@ -218,7 +217,9 @@ export function createWorkflowTools( return formatToolError("Screenshot failed", err); } finally { try { - fs.rmSync(tmpDir, { recursive: true, force: true }); + if (tmpDir) { + fs.rmSync(tmpDir, { recursive: true, force: true }); + } } catch { // Best-effort cleanup only. } diff --git a/apps/desktop/src/main/services/chat/agentChatService.test.ts b/apps/desktop/src/main/services/chat/agentChatService.test.ts index 7f038978f..5cfc53aeb 100644 --- a/apps/desktop/src/main/services/chat/agentChatService.test.ts +++ b/apps/desktop/src/main/services/chat/agentChatService.test.ts @@ -206,6 +206,10 @@ vi.mock("../ai/authDetector", () => ({ detectAllAuth: vi.fn(async () => []), })); +vi.mock("../ai/localModelDiscovery", () => ({ + discoverLocalModels: vi.fn(async () => []), +})); + vi.mock("../git/git", () => ({ runGit: vi.fn(async () => ({ stdout: "", stderr: "", exitCode: 0 })), })); @@ -285,11 +289,14 @@ vi.mock("./cursorAcpPool", () => ({ // Import system under test (after mocks) // --------------------------------------------------------------------------- import { + buildUnifiedStreamMessages, buildComputerUseDirective, createAgentChatService, + shouldAutoContinueUnifiedLocalTurn, } from "./agentChatService"; import { spawn } from "node:child_process"; import { detectAllAuth } from "../ai/authDetector"; +import { discoverLocalModels } from "../ai/localModelDiscovery"; import * as providerResolver from "../ai/providerResolver"; import { createUniversalToolSet } from "../ai/tools/universalTools"; import { createWorkflowTools } from "../ai/tools/workflowTools"; @@ -626,6 +633,8 @@ beforeEach(() => { vi.mocked(generateText).mockReset(); vi.mocked(unstable_v2_createSession).mockReset(); vi.mocked(detectAllAuth).mockResolvedValue([]); + vi.mocked(discoverLocalModels).mockReset(); + vi.mocked(discoverLocalModels).mockResolvedValue([]); vi.mocked(providerResolver.resolveModel).mockResolvedValue({} as any); vi.mocked(parseAgentChatTranscript).mockReturnValue([]); }); @@ -4504,6 +4513,89 @@ describe("createAgentChatService", () => { expect(updated?.permissionMode).toBe("edit"); expect(updated?.unifiedPermissionMode).toBe("edit"); }); + + it("detects narrated next-step early stops for local unified turns", () => { + expect(shouldAutoContinueUnifiedLocalTurn({ + modelDescriptor: { authTypes: ["local"] }, + permissionMode: "edit", + assistantText: "I will explore the src directory to identify where pages and routing are defined in the application.", + toolCallCount: 1, + toolResultCount: 1, + continuationCount: 0, + })).toBe(true); + + expect(shouldAutoContinueUnifiedLocalTurn({ + modelDescriptor: { authTypes: ["local"] }, + permissionMode: "plan", + assistantText: "I will explore the src directory to identify where pages and routing are defined in the application.", + toolCallCount: 1, + toolResultCount: 1, + continuationCount: 0, + })).toBe(false); + + expect(shouldAutoContinueUnifiedLocalTurn({ + modelDescriptor: { authTypes: ["local"] }, + permissionMode: "edit", + assistantText: "I found the routing file and can add the blank about page next.", + toolCallCount: 1, + toolResultCount: 1, + continuationCount: 0, + })).toBe(false); + }); + + it("preserves original attachments across local auto-continuation retries", () => { + const resolvedPath = path.join(tmpRoot, "note.txt"); + fs.writeFileSync(resolvedPath, "remember this", "utf8"); + + const streamMessages = buildUnifiedStreamMessages({ + messages: [ + { + role: "user", + content: "Add an about me page.\n\nAttached context:\n- file: note.txt", + }, + { + role: "assistant", + content: "I will explore the src directory to identify where pages and routing are defined in the application.", + }, + { + role: "user", + content: "Continue from your last step.", + }, + ], + persistedTurnUserMessageIndex: 0, + resolvedAttachments: [{ + path: "note.txt", + type: "file", + _rootPath: tmpRoot, + _resolvedPath: resolvedPath, + }], + modelDescriptor: { + id: "lmstudio/qwen2.5-coder:32b", + displayName: "qwen2.5-coder:32b", + family: "lmstudio", + authTypes: ["local"], + contextWindow: 0, + maxOutputTokens: 0, + capabilities: { tools: true, vision: false, reasoning: false, streaming: true }, + color: "#64748B", + sdkProvider: "@ai-sdk/openai-compatible", + sdkModelId: "qwen2.5-coder:32b", + isCliWrapped: false, + harnessProfile: "verified", + } as any, + logger: createLogger() as any, + }); + + expect(streamMessages).toHaveLength(3); + expect(streamMessages[0]?.content).toEqual(expect.arrayContaining([ + expect.objectContaining({ type: "text" }), + expect.objectContaining({ type: "file", filename: "note.txt" }), + ])); + expect(streamMessages[2]).toEqual({ + role: "user", + content: "Continue from your last step.", + }); + }); }); it("emits immediate startup activity before unified stream output arrives", async () => { diff --git a/apps/desktop/src/main/services/chat/agentChatService.ts b/apps/desktop/src/main/services/chat/agentChatService.ts index d4e407ffc..0e514c948 100644 --- a/apps/desktop/src/main/services/chat/agentChatService.ts +++ b/apps/desktop/src/main/services/chat/agentChatService.ts @@ -107,6 +107,7 @@ import type { AgentChatUnifiedPermissionMode, PendingInputQuestion, PendingInputRequest, + PendingInputSource, AgentChatUpdateSessionArgs, ComputerUseBackendStatus, ComputerUsePolicy, @@ -135,7 +136,7 @@ import { canSwitchChatSessionModel } from "../../../shared/chatModelSwitching"; import { detectAllAuth } from "../ai/authDetector"; import * as providerResolver from "../ai/providerResolver"; import { buildCodexAppServerMcpConfigOverrides } from "../ai/codexAppServerConfig"; -import { createUniversalToolSet, type PermissionMode } from "../ai/tools/universalTools"; +import { createUniversalToolSet, type AskUserToolInput, type PermissionMode } from "../ai/tools/universalTools"; import { createWorkflowTools } from "../ai/tools/workflowTools"; import { createLinearTools } from "../ai/tools/linearTools"; import { createCtoOperatorTools, type CtoOperatorToolDeps } from "../ai/tools/ctoOperatorTools"; @@ -969,6 +970,56 @@ function normalizeUsagePayload( return { inputTokens, outputTokens }; } +function mergeUsagePayloads( + left: { inputTokens?: number | null; outputTokens?: number | null } | undefined, + right: { inputTokens?: number | null; outputTokens?: number | null } | undefined, +): { inputTokens?: number | null; outputTokens?: number | null } | undefined { + if (!left) return right; + if (!right) return left; + const inputTokens = (left.inputTokens ?? 0) + (right.inputTokens ?? 0); + const outputTokens = (left.outputTokens ?? 0) + (right.outputTokens ?? 0); + if (!inputTokens && !outputTokens) return undefined; + return { inputTokens, outputTokens }; +} + +const MAX_UNIFIED_LOCAL_AUTO_CONTINUATIONS = 1; +const UNIFIED_LOCAL_AUTO_CONTINUE_PROMPT = + "Continue from your last step. Do not restate that you will inspect, explore, or review the codebase. " + + "Take the next concrete action now, using tools if needed. Only stop when you have completed the request or are truly blocked."; + +export function shouldAutoContinueUnifiedLocalTurn(args: { + modelDescriptor: Pick; + permissionMode: PermissionMode; + assistantText: string; + toolCallCount: number; + toolResultCount: number; + continuationCount: number; + pendingApprovalCount?: number; +}): boolean { + if (!args.modelDescriptor.authTypes.includes("local")) return false; + if (args.permissionMode === "plan") return false; + if (args.continuationCount >= MAX_UNIFIED_LOCAL_AUTO_CONTINUATIONS) return false; + if ((args.pendingApprovalCount ?? 0) > 0) return false; + if (args.toolCallCount < 1 || args.toolResultCount < 1) return false; + + const normalized = args.assistantText.trim().replace(/\s+/g, " "); + if (!normalized.length || normalized.length > 280) return false; + + const lower = normalized.toLowerCase(); + if (lower.includes("?")) return false; + if ( + /\b(done|completed|complete|implemented|created|updated|finished|fixed|here('| i)?s|i found|i added|i changed|i created|i updated|i implemented|blocked)\b/i.test(lower) + ) { + return false; + } + + const narratedIntentPattern = + /^(?:ok[, ]+|alright[, ]+|next[, ]+|now[, ]+)?(?:i(?:'ll| will| am going to|'m going to)|let me)\b/i; + const nextActionPattern = + /\b(explore|inspect|check|look at|search|review|read|analy[sz]e|identify|open|scan|trace|browse)\b/i; + return narratedIntentPattern.test(normalized) && nextActionPattern.test(normalized); +} + const KNOWN_CODEX_EFFORTS = new Set(CODEX_REASONING_EFFORTS.map((e) => e.effort)); const EFFORT_ALIASES: Record> = { @@ -1407,6 +1458,35 @@ function buildStreamingUserContent( return parts; } +export function buildUnifiedStreamMessages(args: { + messages: Array<{ role: string; content: string }>; + persistedTurnUserMessageIndex: number; + resolvedAttachments: ResolvedAgentChatFileRef[]; + modelDescriptor: ModelDescriptor; + logger?: Logger; +}): ModelMessage[] { + return args.messages.map((message, index): ModelMessage => { + const isPersistedTurnUserMessage = index === args.persistedTurnUserMessageIndex && message.role === "user"; + if (!isPersistedTurnUserMessage) { + return { + role: message.role === "user" ? "user" : "assistant", + content: message.content, + }; + } + + return { + role: "user", + content: buildStreamingUserContent({ + baseText: message.content, + attachments: args.resolvedAttachments, + runtimeKind: "unified", + modelDescriptor: args.modelDescriptor, + logger: args.logger, + }), + }; + }); +} + function buildExecutionModeDirective( mode: AgentChatExecutionMode | null | undefined, provider: AgentChatProvider, @@ -2029,6 +2109,19 @@ function applyLocalHarnessPermissionMode(args: { }; } +function enforceManagedLocalHarnessPermissionMode( + managed: ManagedChatSession, + descriptor?: ModelDescriptor | null, +): void { + const harnessPermissions = applyLocalHarnessPermissionMode({ + descriptor: descriptor ?? resolveSessionModelDescriptor(managed.session) ?? undefined, + requestedPermissionMode: managed.session.permissionMode, + requestedUnifiedPermissionMode: managed.session.unifiedPermissionMode, + }); + managed.session.permissionMode = harnessPermissions.requestedPermissionMode ?? managed.session.permissionMode; + managed.session.unifiedPermissionMode = harnessPermissions.requestedUnifiedPermissionMode ?? managed.session.unifiedPermissionMode; +} + function resolveCursorSessionModeId( session: Pick, ): string | null { @@ -4310,8 +4403,17 @@ export function createAgentChatService(args: { auth: Awaited>, ): Promise => { if (auth.some((entry) => entry.type === "local")) { - const discovered = await discoverLocalModels(auth); - replaceDynamicLocalModelDescriptors(discovered.map(discoveredLocalModelToDescriptor)); + try { + const discovered = await discoverLocalModels(auth); + replaceDynamicLocalModelDescriptors(discovered.map(discoveredLocalModelToDescriptor)); + } catch (err) { + replaceDynamicLocalModelDescriptors([]); + logger.warn("agent_chat.local_model_discovery_failed", { + error: err instanceof Error ? err.message : String(err), + }); + } + } else { + replaceDynamicLocalModelDescriptors([]); } return getRegistryModels(auth).filter((descriptor) => !descriptor.deprecated); }; @@ -4329,8 +4431,19 @@ export function createAgentChatService(args: { } const localProvider = descriptor.family as LocalProviderFamily; - const discovered = (await discoverLocalModels(auth)).filter((model) => model.provider === localProvider); - replaceDynamicLocalModelDescriptors(discovered.map(discoveredLocalModelToDescriptor)); + let discovered: DiscoveredLocalModel[] = []; + try { + discovered = (await discoverLocalModels(auth)).filter((model) => model.provider === localProvider); + replaceDynamicLocalModelDescriptors(discovered.map(discoveredLocalModelToDescriptor)); + } catch (err) { + replaceDynamicLocalModelDescriptors([]); + logger.warn("agent_chat.local_model_resolution_failed", { + sessionId: managed.session.id, + modelId: descriptor.id, + error: err instanceof Error ? err.message : String(err), + }); + return descriptor; + } const preferred = auth.find( (entry): entry is Extract>[number], { type: "local" }> => @@ -4345,8 +4458,6 @@ export function createAgentChatService(args: { if (discovered.length === 1) { const onlyDescriptor = getModelById(`${localProvider}/${discovered[0]!.modelId}`) ?? discoveredLocalModelToDescriptor(discovered[0]!); - managed.session.modelId = onlyDescriptor.id; - managed.session.model = onlyDescriptor.id; return onlyDescriptor; } @@ -4377,13 +4488,7 @@ export function createAgentChatService(args: { const auth = await detectAuth(); descriptor = await resolveUnifiedLocalDescriptor(managed, descriptor, auth); - const harnessPermissions = applyLocalHarnessPermissionMode({ - descriptor, - requestedPermissionMode: managed.session.permissionMode, - requestedUnifiedPermissionMode: managed.session.unifiedPermissionMode, - }); - managed.session.permissionMode = harnessPermissions.requestedPermissionMode ?? managed.session.permissionMode; - managed.session.unifiedPermissionMode = harnessPermissions.requestedUnifiedPermissionMode ?? managed.session.unifiedPermissionMode; + enforceManagedLocalHarnessPermissionMode(managed, descriptor); const resolvedModel = await providerResolver.resolveModel(descriptor.id, auth, { cwd: managed.laneWorktreePath, }); @@ -4443,6 +4548,7 @@ export function createAgentChatService(args: { managed.session.provider = "unified"; managed.session.unifiedPermissionMode = permMode; managed.session.permissionMode = syncLegacyPermissionMode(managed.session) ?? managed.session.permissionMode; + enforceManagedLocalHarnessPermissionMode(managed, descriptor); managed.session.capabilityMode = mcpClient ? "full_mcp" : "fallback"; return "handled"; }; @@ -7154,6 +7260,8 @@ export function createAgentChatService(args: { let assistantText = ""; let usage: { inputTokens?: number | null; outputTokens?: number | null } | undefined; + let finalAssistantText = ""; + let autoContinuationCount = 0; let streamedStepCount = 0; const turnStartedAt = Date.now(); let firstStreamEventLogged = false; @@ -7204,30 +7312,11 @@ export function createAgentChatService(args: { applyReconstructionContextToStreamingRuntime(managed, runtime); runtime.messages.push({ role: "user", content: userContent }); + const persistedTurnUserMessageIndex = runtime.messages.length - 1; const abortController = new AbortController(); runtime.abortController = abortController; - const streamMessages = runtime.messages.map((message, index): ModelMessage => { - const isCurrentUserMessage = index === runtime.messages.length - 1 && message.role === "user"; - if (!isCurrentUserMessage) { - return { - role: message.role as "user" | "assistant", - content: message.content, - }; - } - - return { - role: "user", - content: buildStreamingUserContent({ - baseText: streamingBaseText, - attachments: resolvedAttachments, - runtimeKind: "unified", - modelDescriptor: runtime.modelDescriptor, - logger, - }), - }; - }); const lightweight = isLightweightSession(managed.session); const executionLaneId = resolveManagedExecutionLaneId(managed); const tools = lightweight @@ -7343,47 +7432,39 @@ export function createAgentChatService(args: { : "User denied the action.", }; }, - onAskUser: async (question) => { - const askItemId = randomUUID(); - const request: PendingInputRequest = { - requestId: askItemId, - itemId: askItemId, + onAskUser: async (input: AskUserToolInput) => { + const normalizedQuestion = typeof input.question === "string" ? input.question.trim() : ""; + const normalizedBody = typeof input.body === "string" ? input.body.trim() : ""; + const response = await requestChatInput({ + chatSessionId: managed.session.id, + title: typeof input.title === "string" && input.title.trim().length + ? input.title.trim() + : "Question from agent", + body: normalizedBody || normalizedQuestion, + questions: input.questions, source: "unified", - kind: "question", - description: question, - questions: [ - { - id: "response", - header: "Question", - question, - allowsFreeform: true, - }, - ], - allowsFreeform: true, - blocking: true, - canProceedWithoutAnswer: false, providerMetadata: { tool: "askUser", - inputType: "text", + inputType: input.questions?.length ? "structured" : "text", + }, + eventDescription: normalizedBody || normalizedQuestion, + eventDetail: { + tool: "askUser", + ...(normalizedQuestion ? { question: normalizedQuestion } : {}), + inputType: input.questions?.length ? "structured" : "text", }, - turnId, - }; - emitPendingInputRequest(managed, request, { - kind: "tool_call", - description: question, - detail: { tool: "askUser", question, inputType: "text" }, - }); - - const response = await new Promise<{ decision?: AgentChatApprovalDecision; responseText?: string | null; answers?: Record }>((resolve) => { - runtime.pendingApprovals.set(askItemId, { category: "askUser", request, resolve }); }); - runtime.pendingApprovals.delete(askItemId); - const normalizedAnswers = normalizePendingInputAnswers(request, response.answers, response.responseText); - const answer = normalizedAnswers.response?.[0] ?? ""; - if (answer.length) return answer; - if (response.decision === "accept") return "yes"; - if (response.decision === "decline") return "no"; - return String(response.decision); + const primaryQuestionId = input.questions?.[0]?.id ?? "response"; + const answer = response.answers[primaryQuestionId]?.[0] + ?? Object.values(response.answers)[0]?.[0] + ?? response.responseText + ?? ""; + return { + answer, + answers: response.answers, + responseText: response.responseText, + decision: response.decision, + }; }, }); @@ -7567,186 +7648,230 @@ export function createAgentChatService(args: { return baseHarnessPrompt; })(); - const stream = streamText({ - model: runtime.resolvedModel, - ...(harnessPrompt ? { system: harnessPrompt } : {}), - messages: streamMessages, - ...(Object.keys(tools).length ? { tools } : {}), - providerOptions: providerOptions as any, - ...(!lightweight ? { stopWhen: stepCountIs(20) } : {}), - abortSignal: abortController.signal, - onError({ error }) { - logger.warn("agent_chat.unified_stream_error", { - sessionId: managed.session.id, - error: error instanceof Error ? error.message : String(error), - }); - }, - }); + let shouldAutoContinue = false; + const baseTurnMessages = runtime.messages.slice(); + do { + assistantText = ""; + let iterationUsage: { inputTokens?: number | null; outputTokens?: number | null } | undefined; + let iterationToolCallCount = 0; + let iterationToolResultCount = 0; + shouldAutoContinue = false; + + const currentIterationMessages = baseTurnMessages; + const streamMessages = buildUnifiedStreamMessages({ + messages: currentIterationMessages, + persistedTurnUserMessageIndex, + resolvedAttachments, + modelDescriptor: runtime.modelDescriptor, + logger, + }); - // ── Stream processing loop ── - const streamSupportsReasoning = runtime.modelDescriptor.capabilities.reasoning; - for await (const part of stream.fullStream as AsyncIterable) { - if (runtime.interrupted) break; - if (!part || typeof part !== "object") continue; - markFirstStreamEvent(String(part.type ?? "unknown")); + const stream = streamText({ + model: runtime.resolvedModel, + ...(harnessPrompt ? { system: harnessPrompt } : {}), + messages: streamMessages, + ...(Object.keys(tools).length ? { tools } : {}), + providerOptions: providerOptions as any, + ...(!lightweight ? { stopWhen: stepCountIs(20) } : {}), + abortSignal: abortController.signal, + onError({ error }) { + logger.warn("agent_chat.unified_stream_error", { + sessionId: managed.session.id, + error: error instanceof Error ? error.message : String(error), + }); + }, + }); - if (part.type === "start-step") { - streamedStepCount += 1; - emitChatEvent(managed, { - type: "step_boundary", - stepNumber: typeof part.stepNumber === "number" ? part.stepNumber + 1 : streamedStepCount, - turnId, - }); - if (!streamSupportsReasoning && streamedStepCount === 1) { + // ── Stream processing loop ── + const streamSupportsReasoning = runtime.modelDescriptor.capabilities.reasoning; + for await (const part of stream.fullStream as AsyncIterable) { + if (runtime.interrupted) break; + if (!part || typeof part !== "object") continue; + markFirstStreamEvent(String(part.type ?? "unknown")); + + if (part.type === "start-step") { + streamedStepCount += 1; emitChatEvent(managed, { - type: "activity", - activity: "working", - detail: WORKING_ACTIVITY_DETAIL, + type: "step_boundary", + stepNumber: typeof part.stepNumber === "number" ? part.stepNumber + 1 : streamedStepCount, turnId, }); + if (!streamSupportsReasoning && streamedStepCount === 1) { + emitChatEvent(managed, { + type: "activity", + activity: "working", + detail: WORKING_ACTIVITY_DETAIL, + turnId, + }); + } + continue; } - continue; - } - if (part.type === "source") { - emitChatEvent(managed, { - type: "activity", - activity: "searching", - detail: - typeof part.title === "string" && part.title.trim().length - ? part.title - : typeof part.url === "string" && part.url.trim().length - ? part.url - : "Gathering sources", - turnId, - }); - continue; - } + if (part.type === "source") { + emitChatEvent(managed, { + type: "activity", + activity: "searching", + detail: + typeof part.title === "string" && part.title.trim().length + ? part.title + : typeof part.url === "string" && part.url.trim().length + ? part.url + : "Gathering sources", + turnId, + }); + continue; + } - if (part.type === "text-delta") { - const delta = String(part.text ?? part.textDelta ?? ""); - if (!delta.length) continue; - assistantText += delta; - emitChatEvent(managed, { - type: "text", - text: delta, - turnId, - itemId: typeof part.id === "string" ? part.id : undefined - }); - continue; - } + if (part.type === "text-delta") { + const delta = String(part.text ?? part.textDelta ?? ""); + if (!delta.length) continue; + assistantText += delta; + emitChatEvent(managed, { + type: "text", + text: delta, + turnId, + itemId: typeof part.id === "string" ? part.id : undefined + }); + continue; + } - if (part.type === "reasoning-start") { - emitChatEvent(managed, { - type: "activity", - activity: streamSupportsReasoning ? "thinking" : "working", - detail: streamSupportsReasoning ? REASONING_ACTIVITY_DETAIL : WORKING_ACTIVITY_DETAIL, - turnId, - }); - continue; - } + if (part.type === "reasoning-start") { + emitChatEvent(managed, { + type: "activity", + activity: streamSupportsReasoning ? "thinking" : "working", + detail: streamSupportsReasoning ? REASONING_ACTIVITY_DETAIL : WORKING_ACTIVITY_DETAIL, + turnId, + }); + continue; + } - if (part.type === "reasoning" || part.type === "reasoning-delta") { - const delta = String(part.text ?? part.textDelta ?? part.delta ?? ""); - if (!delta.length) continue; - if (!streamSupportsReasoning) { + if (part.type === "reasoning" || part.type === "reasoning-delta") { + const delta = String(part.text ?? part.textDelta ?? part.delta ?? ""); + if (!delta.length) continue; + if (!streamSupportsReasoning) { + emitChatEvent(managed, { + type: "activity", + activity: "working", + detail: WORKING_ACTIVITY_DETAIL, + turnId, + }); + continue; + } emitChatEvent(managed, { type: "activity", - activity: "working", - detail: WORKING_ACTIVITY_DETAIL, + activity: "thinking", + detail: REASONING_ACTIVITY_DETAIL, turnId, }); + emitChatEvent(managed, { + type: "reasoning", + text: delta, + turnId, + itemId: typeof part.id === "string" ? part.id : undefined + }); continue; } - emitChatEvent(managed, { - type: "activity", - activity: "thinking", - detail: REASONING_ACTIVITY_DETAIL, - turnId, - }); - emitChatEvent(managed, { - type: "reasoning", - text: delta, - turnId, - itemId: typeof part.id === "string" ? part.id : undefined - }); - continue; - } - if (part.type === "reasoning-end") { - flushBufferedReasoning(managed); - continue; - } + if (part.type === "reasoning-end") { + flushBufferedReasoning(managed); + continue; + } - if (part.type === "tool-call") { - const nextActivity = activityForToolName(String(part.toolName ?? "tool")); - const parentItemId = readProviderParentItemId((part as { providerMetadata?: unknown }).providerMetadata); - emitChatEvent(managed, { - type: "activity", - activity: nextActivity.activity, - detail: nextActivity.detail, - turnId, - }); - emitChatEvent(managed, { - type: "tool_call", - tool: String(part.toolName ?? "tool"), - args: part.input ?? part.args ?? part.arguments, - itemId: String(part.toolCallId ?? randomUUID()), - ...(parentItemId ? { parentItemId } : {}), - turnId - }); - continue; - } + if (part.type === "tool-call") { + iterationToolCallCount += 1; + const nextActivity = activityForToolName(String(part.toolName ?? "tool")); + const parentItemId = readProviderParentItemId((part as { providerMetadata?: unknown }).providerMetadata); + emitChatEvent(managed, { + type: "activity", + activity: nextActivity.activity, + detail: nextActivity.detail, + turnId, + }); + emitChatEvent(managed, { + type: "tool_call", + tool: String(part.toolName ?? "tool"), + args: part.input ?? part.args ?? part.arguments, + itemId: String(part.toolCallId ?? randomUUID()), + ...(parentItemId ? { parentItemId } : {}), + turnId + }); + continue; + } - if (part.type === "tool-result") { - const parentItemId = readProviderParentItemId((part as { providerMetadata?: unknown }).providerMetadata); - emitChatEvent(managed, { - type: "tool_result", - tool: String(part.toolName ?? "tool"), - result: part.output ?? part.result, - itemId: String(part.toolCallId ?? randomUUID()), - ...(parentItemId ? { parentItemId } : {}), - turnId, - status: part.preliminary ? "running" : "completed" - }); - continue; - } + if (part.type === "tool-result") { + iterationToolResultCount += 1; + const parentItemId = readProviderParentItemId((part as { providerMetadata?: unknown }).providerMetadata); + emitChatEvent(managed, { + type: "tool_result", + tool: String(part.toolName ?? "tool"), + result: part.output ?? part.result, + itemId: String(part.toolCallId ?? randomUUID()), + ...(parentItemId ? { parentItemId } : {}), + turnId, + status: part.preliminary ? "running" : "completed" + }); + continue; + } - if (part.type === "tool-error") { - emitChatEvent(managed, { - type: "error", - message: `Tool '${String(part.toolName ?? "tool")}' failed: ${String(part.error ?? "unknown error")}`, - turnId, - itemId: String(part.toolCallId ?? randomUUID()) - }); - continue; - } + if (part.type === "tool-error") { + emitChatEvent(managed, { + type: "error", + message: `Tool '${String(part.toolName ?? "tool")}' failed: ${String(part.error ?? "unknown error")}`, + turnId, + itemId: String(part.toolCallId ?? randomUUID()) + }); + continue; + } - if (part.type === "tool-approval-request") { - const toolName = String(part.toolCall?.toolName ?? "tool"); - emitChatEvent(managed, { - type: "error", - message: isPlanningApprovalGuarded(managed) - ? buildPlanningApprovalViolation(toolName) - : `Unexpected SDK approval request for '${toolName}'. This tool should use ADE-managed approvals instead.`, - turnId - }); - continue; - } + if (part.type === "tool-approval-request") { + const toolName = String(part.toolCall?.toolName ?? "tool"); + emitChatEvent(managed, { + type: "error", + message: isPlanningApprovalGuarded(managed) + ? buildPlanningApprovalViolation(toolName) + : `Unexpected SDK approval request for '${toolName}'. This tool should use ADE-managed approvals instead.`, + turnId + }); + continue; + } - if (part.type === "finish") { - usage = normalizeUsagePayload(part.totalUsage ?? part.usage); - continue; + if (part.type === "finish") { + iterationUsage = normalizeUsagePayload(part.totalUsage ?? part.usage); + continue; + } + + if (part.type === "error") { + emitChatEvent(managed, { + type: "error", + message: String(part.error ?? "Stream error."), + turnId + }); + } } - if (part.type === "error") { - emitChatEvent(managed, { - type: "error", - message: String(part.error ?? "Stream error."), - turnId + usage = mergeUsagePayloads(usage, iterationUsage); + finalAssistantText += assistantText; + shouldAutoContinue = shouldAutoContinueUnifiedLocalTurn({ + modelDescriptor: runtime.modelDescriptor, + permissionMode: runtime.permissionMode, + assistantText, + toolCallCount: iterationToolCallCount, + toolResultCount: iterationToolResultCount, + continuationCount: autoContinuationCount, + pendingApprovalCount: runtime.pendingApprovals.size, + }); + if (shouldAutoContinue && !runtime.interrupted) { + logger.info("agent_chat.unified_local_auto_continue", { + sessionId: managed.session.id, + turnId, + modelId: runtime.modelDescriptor.id, + continuationCount: autoContinuationCount + 1, }); + baseTurnMessages.push({ role: "assistant", content: assistantText }); + baseTurnMessages.push({ role: "user", content: UNIFIED_LOCAL_AUTO_CONTINUE_PROMPT }); + autoContinuationCount += 1; } - } + } while (shouldAutoContinue && !runtime.interrupted); // ── Shared turn completion ── persistDeliveredLaneDirectiveKey(managed, args.laneDirectiveKey); @@ -7765,8 +7890,8 @@ export function createAgentChatService(args: { }); persistChatState(managed); } else { - if (assistantText.trim().length) { - runtime.messages.push({ role: "assistant", content: assistantText }); + if (finalAssistantText.trim().length) { + runtime.messages.push({ role: "assistant", content: finalAssistantText }); } runtime.busy = false; @@ -7784,10 +7909,10 @@ export function createAgentChatService(args: { ...(usage ? { usage } : {}) }); - if (assistantText.trim().length > 0) { + if (finalAssistantText.trim().length > 0) { appendWorkerActivityToCto(managed, { activityType: "chat_turn", - summary: assistantText, + summary: finalAssistantText, }); } @@ -12224,6 +12349,7 @@ export function createAgentChatService(args: { await ensureCursorRuntime(managed); managed.session.unifiedPermissionMode = persisted?.unifiedPermissionMode ?? managed.session.unifiedPermissionMode; managed.session.permissionMode = syncLegacyPermissionMode(managed.session) ?? managed.session.permissionMode; + enforceManagedLocalHarnessPermissionMode(managed); sessionService.setResumeCommand(sessionId, `chat:cursor:${sessionId}`); } else if (managed.runtime?.kind === "unified" || (managed.session.modelId && !providerResolver.isModelCliWrapped(managed.session.modelId))) { // Unified runtime resume — re-resolve the model @@ -12236,6 +12362,7 @@ export function createAgentChatService(args: { } managed.session.unifiedPermissionMode = persisted?.unifiedPermissionMode ?? managed.session.unifiedPermissionMode; managed.session.permissionMode = syncLegacyPermissionMode(managed.session) ?? managed.session.permissionMode; + enforceManagedLocalHarnessPermissionMode(managed, managed.runtime.modelDescriptor); managed.runtime.permissionMode = resolveSessionUnifiedPermissionMode( managed.session, resolveChatConfig().unifiedPermissionMode, @@ -12430,6 +12557,7 @@ export function createAgentChatService(args: { managed.session.provider, ); applyLegacyPermissionModeToNativeControls(managed.session, managed.session.permissionMode); + enforceManagedLocalHarnessPermissionMode(managed); normalizeSessionNativePermissionControls(managed.session, resolveChatConfig()); managed.selectedExecutionLaneId = selectedExecutionLaneId ?? managed.selectedExecutionLaneId; refreshReconstructionContext(managed, { includeConversationTail: usesIdentityContinuity(managed) }); @@ -12701,6 +12829,11 @@ export function createAgentChatService(args: { managed.session.unifiedPermissionMode = "edit"; } managed.session.permissionMode = syncLegacyPermissionMode(managed.session) ?? managed.session.permissionMode; + enforceManagedLocalHarnessPermissionMode(managed, managed.runtime.modelDescriptor); + managed.runtime.permissionMode = resolveSessionUnifiedPermissionMode( + managed.session, + resolveChatConfig().unifiedPermissionMode, + ); managed.runtime.pendingApprovals.delete(itemId); pending.resolve({ decision: resolvedDecision, answers, responseText }); emitPendingInputResolved(managed, { @@ -12972,6 +13105,7 @@ export function createAgentChatService(args: { ); applyLegacyPermissionModeToNativeControls(managed.session, managed.session.permissionMode); } + enforceManagedLocalHarnessPermissionMode(managed, descriptor); normalizeSessionNativePermissionControls(managed.session, chatConfig); // Apply reasoningEffort BEFORE pre-warming so the V2 session is created @@ -13078,6 +13212,7 @@ export function createAgentChatService(args: { || cursorModeId !== undefined || cursorConfigValues !== undefined ) { + enforceManagedLocalHarnessPermissionMode(managed); normalizeSessionNativePermissionControls(managed.session, chatConfig); if (managed.runtime?.kind === "unified") { managed.runtime.permissionMode = resolveSessionUnifiedPermissionMode( @@ -13430,6 +13565,10 @@ export function createAgentChatService(args: { chatSessionId: string; title: string; body: string; + source?: PendingInputSource; + providerMetadata?: Record; + eventDescription?: string; + eventDetail?: Record; questions?: Array<{ id?: string; header?: string; @@ -13552,7 +13691,7 @@ export function createAgentChatService(args: { const request: PendingInputRequest = { requestId: itemId, itemId, - source: "ade", + source: args.source ?? "ade", kind: questions.some((q) => q.options?.length) ? "structured_question" : "question", title: args.title, description: questions[0]?.question ?? args.body, @@ -13560,6 +13699,7 @@ export function createAgentChatService(args: { allowsFreeform: true, blocking: true, canProceedWithoutAnswer: false, + ...(args.providerMetadata ? { providerMetadata: args.providerMetadata } : {}), turnId: managed.runtime?.activeTurnId ?? null, }; @@ -13571,7 +13711,8 @@ export function createAgentChatService(args: { managed.localPendingInputs.set(itemId, { request, resolve }); emitPendingInputRequest(managed, request, { kind: "tool_call", - description: request.description ?? args.body, + description: args.eventDescription ?? request.description ?? args.body, + ...(args.eventDetail !== undefined ? { detail: args.eventDetail } : {}), }); }); diff --git a/apps/desktop/src/main/services/orchestrator/unifiedOrchestratorAdapter.ts b/apps/desktop/src/main/services/orchestrator/unifiedOrchestratorAdapter.ts index dd7b8457f..8b79f7659 100644 --- a/apps/desktop/src/main/services/orchestrator/unifiedOrchestratorAdapter.ts +++ b/apps/desktop/src/main/services/orchestrator/unifiedOrchestratorAdapter.ts @@ -321,7 +321,7 @@ function resolveManagedPermissionMode(args: { : undefined; if (args.descriptor?.authTypes.includes("local")) { if (args.descriptor.harnessProfile === "read_only") return "plan"; - if (args.descriptor.harnessProfile === "guarded" && normalizedCandidate == null) return "plan"; + if (args.descriptor.harnessProfile === "guarded") return "plan"; } return normalizedCandidate; } diff --git a/apps/desktop/src/renderer/components/chat/AgentChatPane.tsx b/apps/desktop/src/renderer/components/chat/AgentChatPane.tsx index 86eb1d5cc..b16c0daaa 100644 --- a/apps/desktop/src/renderer/components/chat/AgentChatPane.tsx +++ b/apps/desktop/src/renderer/components/chat/AgentChatPane.tsx @@ -781,7 +781,7 @@ export function AgentChatPane({ (entry): entry is { type: "local"; provider: LocalProviderFamily; endpoint: string } => entry.type === "local" && entry.provider === provider, ) ?? null; - const modelIds = runtimeConnection?.loadedModelIds?.length + const modelIds = runtimeConnection?.loadedModelIds !== undefined && runtimeConnection.loadedModelIds !== null ? runtimeConnection.loadedModelIds.filter((id): id is string => String(id ?? "").startsWith(`${provider}/`)) : availableModelIds.filter((id) => id.startsWith(`${provider}/`)); return { @@ -1056,7 +1056,11 @@ export function AgentChatPane({ } for (const model of unifiedModels) { const resolved = resolveRegistryModelId(model.id); - if (resolved) available.add(resolved); + if (resolved) { + available.add(resolved); + } else { + available.add(model.id); + } } const ordered = MODEL_REGISTRY diff --git a/apps/desktop/src/renderer/components/onboarding/ProjectSetupPage.tsx b/apps/desktop/src/renderer/components/onboarding/ProjectSetupPage.tsx index 2c7c607fe..7ba8ffea7 100644 --- a/apps/desktop/src/renderer/components/onboarding/ProjectSetupPage.tsx +++ b/apps/desktop/src/renderer/components/onboarding/ProjectSetupPage.tsx @@ -56,7 +56,7 @@ const STEP_HEADERS: Record = { tools: { heading: "Developer Tools", sub: "ADE needs git for version control. GitHub CLI unlocks PR creation, review requests, and CI checks." }, ai: { heading: "Connect AI providers", - sub: "Link API keys, CLIs, and local runtimes (LM Studio, Ollama, vLLM) so ADE can power chat, codegen, and background automations. After the CLI is installed and signed in, Cursor models appear in work chat automatically.", + sub: "Link API keys, CLIs, and local runtimes (LM Studio, Ollama, vLLM) so ADE can power chat, codegen, and background helpers. After the CLI is installed and signed in, Cursor models appear in work chat automatically.", }, helpers: { heading: "Background Helpers", sub: "These lightweight AI automations run in the background while you work. All are optional and can be changed anytime in Settings." }, github: { heading: "GitHub Integration", sub: "A personal access token lets ADE create PRs, request reviews, and monitor CI on your behalf." }, diff --git a/apps/desktop/src/renderer/components/settings/ProvidersSection.test.tsx b/apps/desktop/src/renderer/components/settings/ProvidersSection.test.tsx index f3340bc7f..3788bb112 100644 --- a/apps/desktop/src/renderer/components/settings/ProvidersSection.test.tsx +++ b/apps/desktop/src/renderer/components/settings/ProvidersSection.test.tsx @@ -30,6 +30,22 @@ function buildStatus(claudeRuntimeAvailable: boolean, localModels: string[] = [] ] : [], availableModelIds: localModels, + runtimeConnections: { + lmstudio: { + provider: "lmstudio", + label: "LM Studio", + kind: "local", + endpoint: "http://localhost:1234", + configured: true, + authAvailable: false, + runtimeDetected: localModels.length > 0, + runtimeAvailable: localModels.length > 0, + health: localModels.length > 0 ? "ready" : "unreachable", + blocker: localModels.length > 0 ? null : "No lmstudio runtime with loaded models was detected.", + loadedModelIds: localModels, + lastCheckedAt: "2026-03-17T19:00:00.000Z", + }, + }, providerConnections: { claude: { provider: "claude", @@ -172,7 +188,7 @@ describe("ProvidersSection", () => { expect((await screen.findAllByText("LM Studio")).length).toBeGreaterThan(0); expect(screen.getAllByText("Ready").length).toBeGreaterThan(0); - expect(screen.getAllByText("LM Studio is reachable at http://localhost:1234. ADE can use 2 loaded models from this runtime.").length).toBeGreaterThan(0); + expect(screen.getAllByText("LM Studio is reachable at http://localhost:1234. ADE can use 2 loaded models from this runtime (ready).").length).toBeGreaterThan(0); expect(screen.getAllByText("meta-llama-3.1-70b-instruct (LM Studio)").length).toBeGreaterThan(0); expect(screen.getAllByText("qwen2.5-coder:32b (LM Studio)").length).toBeGreaterThan(0); }); diff --git a/apps/desktop/src/renderer/components/settings/ProvidersSection.tsx b/apps/desktop/src/renderer/components/settings/ProvidersSection.tsx index 7b934afec..d9c6efc8a 100644 --- a/apps/desktop/src/renderer/components/settings/ProvidersSection.tsx +++ b/apps/desktop/src/renderer/components/settings/ProvidersSection.tsx @@ -284,7 +284,9 @@ export function ProvidersSection({ forceRefreshOnMount = false }: { forceRefresh ]); setStatus(nextStatus); setProjectConfigSnapshot(nextProjectConfig); - setLocalProviderDrafts(buildLocalProviderDrafts(nextProjectConfig, nextStatus)); + if (editingLocalProvider == null && savingLocalProvider == null) { + setLocalProviderDrafts(buildLocalProviderDrafts(nextProjectConfig, nextStatus)); + } setStoredProviders(nextStoredProviders.map((entry) => entry.trim().toLowerCase()).filter(Boolean)); } catch (err) { setError(err instanceof Error ? err.message : String(err)); @@ -293,7 +295,7 @@ export function ProvidersSection({ forceRefreshOnMount = false }: { forceRefresh setLoading(false); } } - }, []); + }, [editingLocalProvider, savingLocalProvider]); useEffect(() => { void refreshStatus(forceRefreshOnMount ? { force: true } : undefined); diff --git a/apps/desktop/src/renderer/components/shared/UnifiedModelSelector.tsx b/apps/desktop/src/renderer/components/shared/UnifiedModelSelector.tsx index 062ee112a..a2c65fd6a 100644 --- a/apps/desktop/src/renderer/components/shared/UnifiedModelSelector.tsx +++ b/apps/desktop/src/renderer/components/shared/UnifiedModelSelector.tsx @@ -135,7 +135,7 @@ function createUnknownModelPlaceholder(modelId: string): ModelDescriptor { isCliWrapped: false, discoverySource: localProvider === "lmstudio" ? "lmstudio-openai" : localProvider, harnessProfile: "guarded", - aliases: [brand], + aliases: brand ? [brand] : [], }; } return { diff --git a/apps/desktop/src/renderer/lib/modelOptions.ts b/apps/desktop/src/renderer/lib/modelOptions.ts index c6e5fc1f3..b059a6fdb 100644 --- a/apps/desktop/src/renderer/lib/modelOptions.ts +++ b/apps/desktop/src/renderer/lib/modelOptions.ts @@ -162,6 +162,9 @@ export function deriveConfiguredModelIds( if (!isLocalProviderFamily(normalizedProvider)) { continue; } + if (connection == null || connection.runtimeAvailable !== true) { + continue; + } if (!hasDynamicLocalModelIdsForProvider(normalizedProvider, status.availableModelIds, runtimeConnections)) { addKnownModelIds(ids, normalizedProvider, false); } diff --git a/apps/desktop/src/test/setup.ts b/apps/desktop/src/test/setup.ts index 06654363c..5f3b472be 100644 --- a/apps/desktop/src/test/setup.ts +++ b/apps/desktop/src/test/setup.ts @@ -33,10 +33,10 @@ function shouldTrackTempDir(dirPath: string): boolean { function cleanupTrackedTempDirs(): void { const state = getTestTempTrackerState(); const targets = [...state.trackedDirs].sort((left, right) => right.length - left.length); - state.trackedDirs.clear(); for (const target of targets) { try { fs.rmSync(target, { recursive: true, force: true }); + state.trackedDirs.delete(target); } catch { // Best-effort cleanup only for test temp roots. } diff --git a/apps/mcp-server/src/test/setup.ts b/apps/mcp-server/src/test/setup.ts index b27c9d28f..2df0e44b9 100644 --- a/apps/mcp-server/src/test/setup.ts +++ b/apps/mcp-server/src/test/setup.ts @@ -33,10 +33,10 @@ function shouldTrackTempDir(dirPath: string): boolean { function cleanupTrackedTempDirs(): void { const state = getTestTempTrackerState(); const targets = [...state.trackedDirs].sort((left, right) => right.length - left.length); - state.trackedDirs.clear(); for (const target of targets) { try { fs.rmSync(target, { recursive: true, force: true }); + state.trackedDirs.delete(target); } catch { // Best-effort cleanup only for test temp roots. }