diff --git a/README.md b/README.md index f54c56f5c..45c960cc8 100644 --- a/README.md +++ b/README.md @@ -398,7 +398,7 @@ codegraph cycles --functions # Function-level cycles ### Semantic Search -Local embeddings for every function, method, and class — search by natural language. Everything runs locally using [@huggingface/transformers](https://huggingface.co/docs/transformers.js) — no API keys needed. +Local embeddings for every function, method, and class — search by natural language. Everything runs locally using [@huggingface/transformers](https://huggingface.co/docs/transformers.js) — no API keys needed. Prefer a remote or self-hosted model instead? Set `embeddings.provider: "openai"` and `llm.baseUrl` in your config to call any OpenAI-compatible `/embeddings` endpoint — see [configuration.md](docs/guides/configuration.md#embeddings-embeddings). ```bash codegraph embed # Build embeddings (default: nomic) diff --git a/docs/guides/configuration.md b/docs/guides/configuration.md index f3d34c17f..02e08c5d6 100644 --- a/docs/guides/configuration.md +++ b/docs/guides/configuration.md @@ -173,18 +173,40 @@ Defaults applied to graph queries when the CLI flag is omitted. ## Embeddings (`embeddings`) -Controls the local embedding model used by `codegraph embed` and `codegraph search`. +Controls the embedding backend used by `codegraph embed` and `codegraph search`. | Key | Type | Default | Purpose | |-----|------|---------|---------| -| `model` | `string \| null` | `null` | Model registry key (see `src/domain/search/models.ts`). When `null`, `codegraph embed` reuses the model already stored in the database, or falls back to the built-in default (`"nomic"`) for fresh graphs. Common options: `"nomic"`, `"nomic-v1.5"`, `"bge-large"`. | +| `model` | `string \| null` | `null` | When `provider` is `null` (local, default): a model registry key (see `src/domain/search/models.ts`). When `null`, `codegraph embed` reuses the model already stored in the database, or falls back to the built-in default (`"nomic"`) for fresh graphs. Common options: `"nomic"`, `"nomic-v1.5"`, `"bge-large"`. When `provider` is `"openai"`: the model identifier your endpoint expects (e.g. `"text-embedding-3-small"`, or whatever name your self-hosted server registers) — required in that case. | | `llmProvider` | `string \| null` | `null` | Optional LLM provider for query expansion. `null` disables it. | +| `provider` | `string \| null` | `null` | Embedding backend. `null` (default) uses the local bundled model via `@huggingface/transformers`. `"openai"` calls a remote OpenAI-compatible `/embeddings` endpoint configured via `llm.baseUrl` — this covers self-hosted servers (text-embeddings-inference, Ollama, LM Studio, vLLM, etc.), not just OpenAI itself. | + +### Remote embedding provider + +Point `codegraph embed` at a self-hosted or third-party embedding endpoint instead of downloading a local model: + +```json +{ + "embeddings": { + "provider": "openai", + "model": "my-embedding-model" + }, + "llm": { + "baseUrl": "http://my-tailnet-host:8080/v1", + "apiKeyCommand": "op read op://vault/embeddings/api-key" + } +} +``` + +The endpoint must accept `POST /embeddings` with `{ "model": "...", "input": ["text", ...] }` and return `{ "data": [{ "embedding": [...], "index": 0 }, ...] }` — the same shape OpenAI's API uses. `llm.apiKey`/`llm.apiKeyCommand` are optional; omit them for endpoints that don't require auth. Vector dimensionality is read from the response, so there's no model registry to keep in sync. + +`codegraph search` (semantic and hybrid modes) and the `semantic_search` MCP tool automatically embed the query through the same remote endpoint when the stored embeddings were built with `embeddings.provider: "openai"` — no extra configuration needed. --- ## LLM credentials (`llm`) -Used by features that call out to a chat-completion API (e.g. query expansion). Codegraph never hardcodes a provider — you pick one. +Used by features that call out to a chat-completion API (e.g. query expansion), and reused by the [remote embedding provider](#embeddings-embeddings) (`baseUrl`, `apiKey`, `apiKeyCommand`) so credentials aren't duplicated across features. Codegraph never hardcodes a provider — you pick one. | Key | Type | Default | Purpose | |-----|------|---------|---------| @@ -193,6 +215,7 @@ Used by features that call out to a chat-completion API (e.g. query expansion). | `baseUrl` | `string \| null` | `null` | Override the provider's base URL (for compatible proxies, local servers, etc.). | | `apiKey` | `string \| null` | `null` | Plaintext API key. Prefer `apiKeyCommand` or env vars over this. | | `apiKeyCommand` | `string \| null` | `null` | Shell-out command that prints the key to stdout. Split on whitespace and run via `execFileSync` (no shell — `$(...)`, pipes, globs, and variable expansion are not supported). 10s timeout, 64 KB max output. | +| `requestTimeoutMs` | `number` | `120000` | Per-request timeout for remote HTTP calls made against `baseUrl` (currently the [remote embedding provider](#embeddings-embeddings)). Aborts and throws if a self-hosted server hangs mid-request instead of blocking indefinitely. | Resolution order (first non-empty wins): `apiKeyCommand` output → `CODEGRAPH_LLM_API_KEY` env var → `apiKey` field. @@ -215,6 +238,7 @@ These env vars override the corresponding `llm.*` fields when set: - `CODEGRAPH_LLM_PROVIDER` → `llm.provider` - `CODEGRAPH_LLM_MODEL` → `llm.model` - `CODEGRAPH_LLM_API_KEY` → `llm.apiKey` +- `CODEGRAPH_LLM_BASE_URL` → `llm.baseUrl` --- diff --git a/src/cli/commands/embed.ts b/src/cli/commands/embed.ts index 547ef8b11..0dd94c101 100644 --- a/src/cli/commands/embed.ts +++ b/src/cli/commands/embed.ts @@ -6,6 +6,7 @@ import { DEFAULT_MODEL, EMBEDDING_STRATEGIES, MODELS, + resolveRemoteEmbeddingOptions, } from '../../domain/search/index.js'; import { info, warn } from '../../infrastructure/logger.js'; import type { CommandDefinition } from '../types.js'; @@ -48,15 +49,29 @@ export const command: CommandDefinition = { ], ['-d, --db ', 'Path to graph.db'], ], - validate([_dir], opts) { + validate([_dir], opts, ctx) { if (!(EMBEDDING_STRATEGIES as readonly string[]).includes(opts.strategy)) { return `Unknown strategy: ${opts.strategy}. Available: ${EMBEDDING_STRATEGIES.join(', ')}`; } + const provider = ctx.config.embeddings?.provider ?? null; + if (provider && provider !== 'openai') { + return ( + `Unsupported embeddings.provider "${provider}". Currently supported: "openai" ` + + '(any OpenAI-compatible /embeddings endpoint, including self-hosted servers).' + ); + } + if (provider && !opts.model && !ctx.config.embeddings?.model) { + return ( + `embeddings.provider is set to "${provider}" but no model is configured. ` + + 'Set embeddings.model to the model identifier your endpoint expects, or pass --model.' + ); + } }, async execute([dir], opts, ctx) { const root = path.resolve(dir || '.'); const dbPath = opts.db as string | undefined; const embeddingsConfig = ctx.config.embeddings; + const provider = embeddingsConfig?.provider ?? null; const flagModel = opts.model as string | undefined; const configModel = (embeddingsConfig?.model as string | null | undefined) ?? null; @@ -65,6 +80,10 @@ export const command: CommandDefinition = { model = flagModel; } else if (configModel) { model = configModel; + } else if (provider) { + // Unreachable in practice — validate() rejects a provider with no model + // before execute() runs — but keeps this branch type-safe. + model = DEFAULT_MODEL; } else { const sticky = resolveStickyModel(dbPath); if (sticky) { @@ -77,6 +96,8 @@ export const command: CommandDefinition = { } } - await buildEmbeddings(root, model, dbPath, { strategy: opts.strategy }); + const remote = + provider === 'openai' ? resolveRemoteEmbeddingOptions(ctx.config, model) : undefined; + await buildEmbeddings(root, model, dbPath, { strategy: opts.strategy, remote }); }, }; diff --git a/src/cli/commands/models.ts b/src/cli/commands/models.ts index e575e36ee..2bf2854b8 100644 --- a/src/cli/commands/models.ts +++ b/src/cli/commands/models.ts @@ -7,6 +7,14 @@ export const command: CommandDefinition = { execute(_args, _opts, ctx) { const embeddingsConfig = ctx.config.embeddings; const defaultModel = (embeddingsConfig?.model as string) || DEFAULT_MODEL; + + if (embeddingsConfig?.provider) { + const remoteModel = embeddingsConfig.model || '(not configured — set embeddings.model)'; + console.log( + `\nembeddings.provider is set to "${embeddingsConfig.provider}" — codegraph embed will call ` + + `model "${remoteModel}" at llm.baseUrl instead of a local model below.`, + ); + } console.log('\nAvailable embedding models:\n'); interface ModelEntry { diff --git a/src/domain/search/generator.ts b/src/domain/search/generator.ts index 02e43f1ca..d9b47cf0a 100644 --- a/src/domain/search/generator.ts +++ b/src/domain/search/generator.ts @@ -5,6 +5,11 @@ import { warn } from '../../infrastructure/logger.js'; import { DbError } from '../../shared/errors.js'; import type { BetterSqlite3Database, NodeRow } from '../../types.js'; import { embed, getModelConfig } from './models.js'; +import { + DEFAULT_REMOTE_CONTEXT_WINDOW, + embedRemote, + type RemoteEmbeddingOptions, +} from './providers/remote.js'; import { buildSourceText } from './strategies/source.js'; import { buildStructuredText } from './strategies/structured.js'; @@ -167,6 +172,7 @@ function persistEmbeddings( dim: number, modelName: string, strategy: EmbeddingStrategy, + provider: string | null, ): void { const { nodeIds, nodeNames, previews, texts, overflowCount } = prepared; const insert = db.prepare( @@ -189,12 +195,25 @@ function persistEmbeddings( if (overflowCount > 0) { insertMeta.run('truncated_count', String(overflowCount)); } + // Record which backend produced these vectors so search-time routing + // (`embedQuery` in `search/semantic.ts`) can key off embed-time truth + // instead of the live config, which may have drifted since `embed` ran. + if (provider) { + insertMeta.run('provider', provider); + } }); insertAll(); } export interface BuildEmbeddingsOptions { strategy?: EmbeddingStrategy; + /** + * When set, embeddings are generated via a remote OpenAI-compatible + * endpoint instead of the local bundled model. `modelKey` is then treated + * as an opaque model identifier passed straight to the endpoint, not a + * local registry key. + */ + remote?: RemoteEmbeddingOptions; } /** @@ -225,12 +244,21 @@ export async function buildEmbeddings( const nodeCount = [...byFile.values()].reduce((acc, list) => acc + list.length, 0); console.log(`Building embeddings for ${nodeCount} symbols (strategy: ${strategy})...`); - const config = getModelConfig(modelKey); - const prepared = prepareEmbeddingTexts(byFile, db, resolvedRoot, strategy, config.contextWindow); + let contextWindow: number; + let displayName: string; + if (options.remote) { + contextWindow = DEFAULT_REMOTE_CONTEXT_WINDOW; + displayName = options.remote.model; + } else { + const modelConfig = getModelConfig(modelKey); + contextWindow = modelConfig.contextWindow; + displayName = modelConfig.name; + } + const prepared = prepareEmbeddingTexts(byFile, db, resolvedRoot, strategy, contextWindow); if (prepared.overflowCount > 0) { warn( - `${prepared.overflowCount} symbol(s) exceeded model context window (${config.contextWindow} tokens) and were truncated`, + `${prepared.overflowCount} symbol(s) exceeded model context window (${contextWindow} tokens) and were truncated`, ); } @@ -247,13 +275,22 @@ export async function buildEmbeddings( ); } - console.log(`Embedding ${prepared.texts.length} symbols...`); - const { vectors, dim } = await embed(prepared.texts, modelKey); + console.log( + `Embedding ${prepared.texts.length} symbols${options.remote ? ` via remote provider (${displayName})` : ''}...`, + ); + const { vectors, dim } = options.remote + ? await embedRemote(prepared.texts, options.remote) + : await embed(prepared.texts, modelKey); - persistEmbeddings(db, prepared, vectors as Float32Array[], dim, config.name, strategy); + // Only "openai" (OpenAI-compatible /embeddings) is currently supported as a + // remote provider — `options.remote` being set implies it. Recorded so + // search-time routing doesn't have to trust the live config (see + // `embedQuery` in `search/semantic.ts`). + const provider = options.remote ? 'openai' : null; + persistEmbeddings(db, prepared, vectors as Float32Array[], dim, displayName, strategy, provider); console.log( - `\nStored ${vectors.length} embeddings (${dim}d, ${config.name}, strategy: ${strategy}) in graph.db`, + `\nStored ${vectors.length} embeddings (${dim}d, ${displayName}, strategy: ${strategy}) in graph.db`, ); closeDb(db); } diff --git a/src/domain/search/index.ts b/src/domain/search/index.ts index dc3ba85c1..6b7c355c6 100644 --- a/src/domain/search/index.ts +++ b/src/domain/search/index.ts @@ -8,6 +8,8 @@ export type { BuildEmbeddingsOptions } from './generator.js'; export { buildEmbeddings, estimateTokens } from './generator.js'; export type { ModelConfig } from './models.js'; export { DEFAULT_MODEL, disposeModel, EMBEDDING_STRATEGIES, embed, MODELS } from './models.js'; +export type { RemoteEmbeddingOptions } from './providers/remote.js'; +export { embedRemote, resolveRemoteEmbeddingOptions } from './providers/remote.js'; export { search } from './search/cli-formatter.js'; export { hybridSearchData } from './search/hybrid.js'; export { ftsSearchData } from './search/keyword.js'; diff --git a/src/domain/search/providers/remote.ts b/src/domain/search/providers/remote.ts new file mode 100644 index 000000000..545a1b1d9 --- /dev/null +++ b/src/domain/search/providers/remote.ts @@ -0,0 +1,165 @@ +import { ConfigError, EngineError } from '../../../shared/errors.js'; +import type { CodegraphConfig } from '../../../types.js'; + +/** Batch size for remote `/embeddings` requests. Conservative default — most + * OpenAI-compatible servers accept much larger batches, but this keeps + * individual request bodies and timeouts predictable across unknown hosts. */ +const REMOTE_BATCH_SIZE = 32; + +/** + * Context window assumed for remote models when truncating oversized symbols. + * Remote model context limits aren't known ahead of time (unlike the local + * registry in `models.ts`), so this is a conservative default matching most + * modern embedding models rather than a per-model lookup. + */ +export const DEFAULT_REMOTE_CONTEXT_WINDOW = 8192; + +export interface RemoteEmbeddingOptions { + baseUrl: string; + model: string; + apiKey?: string | null; + /** Per-request timeout in ms. Defaults to `DEFAULT_REQUEST_TIMEOUT_MS` when omitted. */ + timeoutMs?: number; +} + +/** + * Fallback per-request timeout when `RemoteEmbeddingOptions.timeoutMs` isn't + * supplied (e.g. direct `embedRemote` calls that bypass config resolution). + * Mirrors `DEFAULTS.llm.requestTimeoutMs` in `infrastructure/config.ts`. + */ +const DEFAULT_REQUEST_TIMEOUT_MS = 120_000; + +interface OpenAIEmbeddingItem { + embedding: number[]; + index: number; +} + +interface OpenAIEmbeddingResponse { + data: OpenAIEmbeddingItem[]; +} + +function embeddingsEndpoint(baseUrl: string): string { + const trimmed = baseUrl.replace(/\/+$/, ''); + return trimmed.endsWith('/embeddings') ? trimmed : `${trimmed}/embeddings`; +} + +/** + * Resolve the remote embedding endpoint config from `llm.*`, given the + * already-resolved model identifier (from `--model` / `embeddings.model`). + * Throws a ConfigError if `llm.baseUrl` isn't set — there's no sensible + * default host for a self-hosted endpoint. + */ +export function resolveRemoteEmbeddingOptions( + config: Pick, + model: string, +): RemoteEmbeddingOptions { + const baseUrl = config.llm.baseUrl; + if (!baseUrl) { + throw new ConfigError( + 'embeddings.provider is "openai" but llm.baseUrl is not set. ' + + 'Point it at your embeddings endpoint, e.g. "http://localhost:8080/v1" ' + + '(config key "llm.baseUrl" or env var CODEGRAPH_LLM_BASE_URL).', + ); + } + return { + baseUrl, + model, + apiKey: config.llm.apiKey, + timeoutMs: config.llm.requestTimeoutMs, + }; +} + +/** + * Generate embeddings via a remote OpenAI-compatible `/embeddings` endpoint. + * Works with OpenAI itself and any self-hosted server implementing the same + * request/response shape (text-embeddings-inference, Ollama, LM Studio, vLLM). + */ +export async function embedRemote( + texts: string[], + options: RemoteEmbeddingOptions, +): Promise<{ vectors: Float32Array[]; dim: number }> { + if (texts.length === 0) return { vectors: [], dim: 0 }; + + const url = embeddingsEndpoint(options.baseUrl); + const headers: Record = { 'Content-Type': 'application/json' }; + if (options.apiKey) headers.Authorization = `Bearer ${options.apiKey}`; + + const results: Float32Array[] = []; + const timeoutMs = options.timeoutMs ?? DEFAULT_REQUEST_TIMEOUT_MS; + let dim = 0; + + for (let i = 0; i < texts.length; i += REMOTE_BATCH_SIZE) { + const batch = texts.slice(i, i + REMOTE_BATCH_SIZE); + + const controller = new AbortController(); + const timeoutHandle = setTimeout(() => controller.abort(), timeoutMs); + + let response: Response; + try { + response = await fetch(url, { + method: 'POST', + headers, + body: JSON.stringify({ model: options.model, input: batch }), + signal: controller.signal, + }); + } catch (err: unknown) { + if (err instanceof Error && err.name === 'AbortError') { + throw new EngineError( + `Remote embedding endpoint ${url} did not respond within ${timeoutMs}ms ` + + `(batch ${Math.floor(i / REMOTE_BATCH_SIZE) + 1})`, + ); + } + throw new EngineError( + `Failed to reach remote embedding endpoint at ${url}: ${err instanceof Error ? err.message : String(err)}`, + { cause: err instanceof Error ? err : undefined }, + ); + } finally { + clearTimeout(timeoutHandle); + } + + if (!response.ok) { + const body = await response.text().catch(() => ''); + throw new EngineError( + `Remote embedding endpoint ${url} returned ${response.status} ${response.statusText}` + + (body ? `: ${body.slice(0, 500)}` : ''), + ); + } + + const json = (await response.json()) as OpenAIEmbeddingResponse; + if (!Array.isArray(json.data) || json.data.length !== batch.length) { + throw new EngineError( + `Remote embedding endpoint ${url} returned an unexpected response shape ` + + `(expected ${batch.length} embeddings, got ${json.data?.length ?? 0})`, + ); + } + + // OpenAI-compatible servers aren't guaranteed to preserve input order — sort by index. + const sorted = [...json.data].sort((a, b) => a.index - b.index); + for (const item of sorted) { + if (!Array.isArray(item.embedding)) { + throw new EngineError( + `Remote embedding endpoint ${url} returned an item with a missing or non-array ` + + `"embedding" field (index ${item.index})`, + ); + } + const vec = Float32Array.from(item.embedding); + if (dim === 0) { + dim = vec.length; + } else if (vec.length !== dim) { + throw new EngineError( + `Remote embedding endpoint ${url} returned inconsistent vector dimensions ` + + `(expected ${dim}, got ${vec.length} for response item index ${item.index})`, + ); + } + results.push(vec); + } + + if (texts.length > REMOTE_BATCH_SIZE) { + process.stderr.write( + ` Embedded ${Math.min(i + REMOTE_BATCH_SIZE, texts.length)}/${texts.length}\r`, + ); + } + } + + return { vectors: results, dim }; +} diff --git a/src/domain/search/search/prepare.ts b/src/domain/search/search/prepare.ts index 3907aa5b6..a28330ced 100644 --- a/src/domain/search/search/prepare.ts +++ b/src/domain/search/search/prepare.ts @@ -20,6 +20,15 @@ export interface PreparedSearch { }>; modelKey: string | null; storedDim: number | null; + /** Raw model identifier recorded at embed time — set even when it isn't a + * local registry key (e.g. a remote provider's model name). */ + storedModel: string | null; + /** + * Embedding backend recorded at embed time (e.g. `"openai"`), or `null` for + * the local bundled model. Search-time routing must key off this rather + * than the live config — the config may have changed since `embed` ran. + */ + storedProvider: string | null; } export interface PrepareSearchOpts { @@ -44,6 +53,7 @@ export function prepareSearch( } const storedModel = getEmbeddingMeta(db, 'model') || null; + const storedProvider = getEmbeddingMeta(db, 'provider') || null; const dimStr = getEmbeddingMeta(db, 'dim'); const storedDim = dimStr ? parseInt(dimStr, 10) : null; @@ -87,7 +97,7 @@ export function prepareSearch( let rows = db.prepare(sql).all(...params) as PreparedSearch['rows']; rows = applyFilters(rows, opts); - return { db, rows, modelKey, storedDim }; + return { db, rows, modelKey, storedDim, storedModel, storedProvider }; } catch (err) { db.close(); throw err; diff --git a/src/domain/search/search/semantic.ts b/src/domain/search/search/semantic.ts index 2c0b82616..07f8f4bb2 100644 --- a/src/domain/search/search/semantic.ts +++ b/src/domain/search/search/semantic.ts @@ -2,10 +2,42 @@ import { loadConfig } from '../../../infrastructure/config.js'; import { warn } from '../../../infrastructure/logger.js'; import type { BetterSqlite3Database, CodegraphConfig } from '../../../types.js'; import { normalizeSymbol } from '../../queries.js'; -import { embed } from '../models.js'; +import { embed, MODELS } from '../models.js'; +import { embedRemote, resolveRemoteEmbeddingOptions } from '../providers/remote.js'; import { cosineSim } from '../stores/sqlite-blob.js'; import { type PreparedSearch, prepareSearch } from './prepare.js'; +/** + * Embed query text with whichever backend produced the stored embeddings. + * `modelKey` is a resolved local registry key (from `--model` or matched + * against `MODELS`), or an arbitrary identifier (an explicit `--model` + * override, or unmatched) when embeddings were built via a remote provider. + * + * Routing is decided from `storedProvider` — the provider recorded in + * `embedding_meta` at embed time — rather than the live config. If the + * config drifted after `embed` ran (e.g. `embeddings.provider` unset later + * on a different machine), trusting live config here would silently route + * the query through the wrong backend instead of the one that actually + * produced the stored vectors, which can produce misleading similarity + * scores rather than an obvious error. + */ +async function embedQuery( + texts: string[], + config: CodegraphConfig, + modelKey: string | null, + storedModel: string | null, + storedProvider: string | null, +): Promise<{ vectors: Float32Array[]; dim: number }> { + const isKnownLocalModel = modelKey != null && modelKey in MODELS; + if (!isKnownLocalModel && storedProvider === 'openai') { + const remoteModel = modelKey || storedModel; + if (remoteModel) { + return embedRemote(texts, resolveRemoteEmbeddingOptions(config, remoteModel)); + } + } + return embed(texts, modelKey ?? undefined); +} + export interface SemanticSearchOpts { config?: CodegraphConfig; limit?: number; @@ -61,13 +93,13 @@ export async function searchData( const prepared = prepareSearch(customDbPath, opts); if (!prepared) return null; - const { db, rows, modelKey, storedDim } = prepared; + const { db, rows, modelKey, storedDim, storedModel, storedProvider } = prepared; try { const { vectors: [queryVec], dim, - } = await embed([query], modelKey ?? undefined); + } = await embedQuery([query], config, modelKey, storedModel, storedProvider); if (checkDimensionMismatch(storedDim, dim)) return null; @@ -192,10 +224,16 @@ export async function multiSearchData( const prepared = prepareSearch(customDbPath, opts); if (!prepared) return null; - const { db, rows, modelKey, storedDim } = prepared; + const { db, rows, modelKey, storedDim, storedModel, storedProvider } = prepared; try { - const { vectors: queryVecs, dim } = await embed(queries, modelKey ?? undefined); + const { vectors: queryVecs, dim } = await embedQuery( + queries, + config, + modelKey, + storedModel, + storedProvider, + ); warnOnSimilarQueries(queries, queryVecs as Float32Array[], similarityWarnThreshold); diff --git a/src/infrastructure/config.ts b/src/infrastructure/config.ts index 1e5492ee9..c8a146873 100644 --- a/src/infrastructure/config.ts +++ b/src/infrastructure/config.ts @@ -38,13 +38,18 @@ export const DEFAULTS = { defaultLimit: 20, excludeTests: false, }, - embeddings: { model: null as string | null, llmProvider: null as string | null }, + embeddings: { + model: null as string | null, + llmProvider: null as string | null, + provider: null as string | null, + }, llm: { provider: null as string | null, model: null as string | null, baseUrl: null as string | null, apiKey: null as string | null, apiKeyCommand: null as string | null, + requestTimeoutMs: 120_000, }, search: { defaultMinScore: 0.2, rrfK: 60, topK: 15, similarityWarnThreshold: 0.85 }, ci: { failOnCycles: false, impactThreshold: null as number | null }, @@ -676,7 +681,12 @@ export function loadConfigWithProvenance( } // Layer 3+: env overrides (LLM keys) - const ENV_LLM_KEYS = ['CODEGRAPH_LLM_PROVIDER', 'CODEGRAPH_LLM_API_KEY', 'CODEGRAPH_LLM_MODEL']; + const ENV_LLM_KEYS = [ + 'CODEGRAPH_LLM_PROVIDER', + 'CODEGRAPH_LLM_API_KEY', + 'CODEGRAPH_LLM_MODEL', + 'CODEGRAPH_LLM_BASE_URL', + ]; if (ENV_LLM_KEYS.some((k) => process.env[k] !== undefined)) { provenance.llm = 'env'; } @@ -688,6 +698,7 @@ const ENV_LLM_MAP: Record = { CODEGRAPH_LLM_PROVIDER: 'provider', CODEGRAPH_LLM_API_KEY: 'apiKey', CODEGRAPH_LLM_MODEL: 'model', + CODEGRAPH_LLM_BASE_URL: 'baseUrl', }; export function applyEnvOverrides(config: CodegraphConfig): CodegraphConfig { diff --git a/src/types.ts b/src/types.ts index 40bcb1b4e..9f85a9f33 100644 --- a/src/types.ts +++ b/src/types.ts @@ -1408,6 +1408,16 @@ export interface CodegraphConfig { embeddings: { model: string | null; llmProvider: string | null; + /** + * Embedding backend for `codegraph embed`. `null` (default) uses the + * local bundled model via `@huggingface/transformers`. `"openai"` calls + * a remote OpenAI-compatible `/embeddings` endpoint configured via + * `llm.baseUrl` — this covers self-hosted servers (text-embeddings-inference, + * Ollama, LM Studio, vLLM, etc.) that implement the same request/response + * shape, not just OpenAI itself. When set, `embeddings.model` must be the + * model identifier the endpoint expects. + */ + provider: string | null; }; llm: { @@ -1423,6 +1433,12 @@ export interface CodegraphConfig { * values are rejected with a `ConfigError` at load time. */ apiKeyCommand: string | null; + /** + * Per-request timeout (ms) for remote HTTP calls made against `llm.baseUrl` + * (currently the remote embedding provider). Prevents an unresponsive + * self-hosted server from hanging the process indefinitely. Default: 120000. + */ + requestTimeoutMs: number; }; search: { diff --git a/tests/search/embedding-provider-metadata.test.ts b/tests/search/embedding-provider-metadata.test.ts new file mode 100644 index 000000000..6fcc0696c --- /dev/null +++ b/tests/search/embedding-provider-metadata.test.ts @@ -0,0 +1,91 @@ +import fs from 'node:fs'; +import os from 'node:os'; +import path from 'node:path'; +import Database from 'better-sqlite3'; +import { afterAll, afterEach, beforeAll, beforeEach, describe, expect, test, vi } from 'vitest'; +import { initSchema } from '../../src/db/index.js'; + +// Local pipeline mock — needed because this suite switches back to the local +// model after a remote run, unlike the other embedding-remote-*.test.ts files +// which only ever exercise the remote path and mock transformers to throw. +vi.mock('@huggingface/transformers', () => ({ + pipeline: async () => async (batch) => { + const dim = 4; + const data = new Float32Array(dim * batch.length); + for (let t = 0; t < batch.length; t++) { + data[t * dim] = 0.5; + } + return { data }; + }, + cos_sim: () => 0, +})); + +import { buildEmbeddings } from '../../src/domain/search/index.js'; + +function insertNode(db, name, kind, file, line, endLine) { + return db + .prepare('INSERT INTO nodes (name, kind, file, line, end_line) VALUES (?, ?, ?, ?, ?)') + .run(name, kind, file, line, endLine).lastInsertRowid; +} + +function getProviderMeta(dbPath: string): string | undefined { + const db = new Database(dbPath, { readonly: true }); + const row = db.prepare("SELECT value FROM embedding_meta WHERE key = 'provider'").get() as + | { value: string } + | undefined; + db.close(); + return row?.value; +} + +describe('embedding_meta provider bookkeeping across provider switches', () => { + let tmpDir: string, dbPath: string; + const fetchMock = vi.fn(); + + beforeAll(() => { + tmpDir = fs.mkdtempSync(path.join(os.tmpdir(), 'codegraph-provider-meta-')); + fs.writeFileSync(path.join(tmpDir, 'math.js'), 'export function add(a, b) { return a + b; }\n'); + + const dbDir = path.join(tmpDir, '.codegraph'); + fs.mkdirSync(dbDir, { recursive: true }); + dbPath = path.join(dbDir, 'graph.db'); + + const db = new Database(dbPath); + db.pragma('journal_mode = WAL'); + initSchema(db); + insertNode(db, 'add', 'function', 'math.js', 1, 1); + db.close(); + }); + + afterAll(() => { + if (tmpDir) fs.rmSync(tmpDir, { recursive: true, force: true }); + }); + + beforeEach(() => { + vi.stubGlobal('fetch', fetchMock); + fetchMock.mockResolvedValue( + new Response(JSON.stringify({ data: [{ embedding: [0.1, 0.2, 0.3, 0.4], index: 0 }] }), { + status: 200, + }), + ); + }); + + afterEach(() => { + vi.unstubAllGlobals(); + fetchMock.mockReset(); + }); + + test('a full rebuild with the local model does not carry over a prior remote provider value', async () => { + // `buildEmbeddings` always deletes every embedding_meta row up front + // (loadNodesByFile) before persistEmbeddings writes fresh ones, so a + // later local-model build can never inherit a stale 'openai' marker from + // an earlier remote build — this test locks in that invariant. + await buildEmbeddings(tmpDir, 'my-remote-model', dbPath, { + remote: { baseUrl: 'http://localhost:9999/v1', model: 'my-remote-model', apiKey: 'sk-x' }, + }); + expect(getProviderMeta(dbPath)).toBe('openai'); + + await buildEmbeddings(tmpDir, 'minilm', dbPath, {}); + + expect(getProviderMeta(dbPath)).not.toBe('openai'); + }); +}); diff --git a/tests/search/embedding-remote-generator.test.ts b/tests/search/embedding-remote-generator.test.ts new file mode 100644 index 000000000..28862be9f --- /dev/null +++ b/tests/search/embedding-remote-generator.test.ts @@ -0,0 +1,77 @@ +import fs from 'node:fs'; +import os from 'node:os'; +import path from 'node:path'; +import Database from 'better-sqlite3'; +import { afterAll, afterEach, beforeAll, beforeEach, describe, expect, test, vi } from 'vitest'; +import { initSchema } from '../../src/db/index.js'; +import { buildEmbeddings } from '../../src/domain/search/index.js'; + +// buildEmbeddings must never touch @huggingface/transformers on the remote +// path — mocking it to throw proves the remote branch doesn't fall through +// to the local loader. +vi.mock('@huggingface/transformers', () => { + throw new Error('local transformers pipeline should not be loaded on the remote path'); +}); + +function insertNode(db, name, kind, file, line, endLine) { + return db + .prepare('INSERT INTO nodes (name, kind, file, line, end_line) VALUES (?, ?, ?, ?, ?)') + .run(name, kind, file, line, endLine).lastInsertRowid; +} + +describe('buildEmbeddings with a remote provider', () => { + let tmpDir: string, dbPath: string; + const fetchMock = vi.fn(); + + beforeAll(() => { + tmpDir = fs.mkdtempSync(path.join(os.tmpdir(), 'codegraph-remote-embed-')); + fs.writeFileSync(path.join(tmpDir, 'math.js'), 'export function add(a, b) { return a + b; }\n'); + + const dbDir = path.join(tmpDir, '.codegraph'); + fs.mkdirSync(dbDir, { recursive: true }); + dbPath = path.join(dbDir, 'graph.db'); + + const db = new Database(dbPath); + db.pragma('journal_mode = WAL'); + initSchema(db); + insertNode(db, 'add', 'function', 'math.js', 1, 1); + db.close(); + }); + + afterAll(() => { + if (tmpDir) fs.rmSync(tmpDir, { recursive: true, force: true }); + }); + + beforeEach(() => { + vi.stubGlobal('fetch', fetchMock); + fetchMock.mockResolvedValue( + new Response(JSON.stringify({ data: [{ embedding: [0.1, 0.2, 0.3, 0.4], index: 0 }] }), { + status: 200, + }), + ); + }); + + afterEach(() => { + vi.unstubAllGlobals(); + fetchMock.mockReset(); + }); + + test('dispatches to the remote endpoint and persists its response', async () => { + await buildEmbeddings(tmpDir, 'my-remote-model', dbPath, { + remote: { baseUrl: 'http://localhost:9999/v1', model: 'my-remote-model', apiKey: 'sk-x' }, + }); + + expect(fetchMock).toHaveBeenCalledTimes(1); + expect(fetchMock.mock.calls[0][0]).toBe('http://localhost:9999/v1/embeddings'); + + const db = new Database(dbPath, { readonly: true }); + const count = db.prepare('SELECT COUNT(*) as c FROM embeddings').get().c; + const modelMeta = db.prepare("SELECT value FROM embedding_meta WHERE key = 'model'").get(); + const dimMeta = db.prepare("SELECT value FROM embedding_meta WHERE key = 'dim'").get(); + db.close(); + + expect(count).toBe(1); + expect(modelMeta.value).toBe('my-remote-model'); + expect(dimMeta.value).toBe('4'); + }); +}); diff --git a/tests/search/embedding-remote-provider.test.ts b/tests/search/embedding-remote-provider.test.ts new file mode 100644 index 000000000..43591cad6 --- /dev/null +++ b/tests/search/embedding-remote-provider.test.ts @@ -0,0 +1,214 @@ +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; +import { + embedRemote, + resolveRemoteEmbeddingOptions, +} from '../../src/domain/search/providers/remote.js'; +import { ConfigError, EngineError } from '../../src/shared/errors.js'; + +describe('resolveRemoteEmbeddingOptions', () => { + it('builds options from llm config', () => { + const options = resolveRemoteEmbeddingOptions( + { + llm: { + provider: 'openai', + model: null, + baseUrl: 'http://localhost:8080/v1', + apiKey: 'sk-test', + apiKeyCommand: null, + requestTimeoutMs: 120_000, + }, + }, + 'my-embed-model', + ); + expect(options).toEqual({ + baseUrl: 'http://localhost:8080/v1', + model: 'my-embed-model', + apiKey: 'sk-test', + timeoutMs: 120_000, + }); + }); + + it('throws ConfigError when llm.baseUrl is not set', () => { + expect(() => + resolveRemoteEmbeddingOptions( + { + llm: { + provider: 'openai', + model: null, + baseUrl: null, + apiKey: null, + apiKeyCommand: null, + requestTimeoutMs: 120_000, + }, + }, + 'my-embed-model', + ), + ).toThrow(ConfigError); + }); +}); + +describe('embedRemote', () => { + const fetchMock = vi.fn(); + + beforeEach(() => { + vi.stubGlobal('fetch', fetchMock); + }); + + afterEach(() => { + vi.unstubAllGlobals(); + fetchMock.mockReset(); + }); + + it('returns an empty result without a network call for empty input', async () => { + const result = await embedRemote([], { baseUrl: 'http://localhost:8080/v1', model: 'm' }); + expect(result).toEqual({ vectors: [], dim: 0 }); + expect(fetchMock).not.toHaveBeenCalled(); + }); + + it('posts to /embeddings and parses an OpenAI-shaped response', async () => { + fetchMock.mockResolvedValueOnce( + new Response( + JSON.stringify({ + data: [ + { embedding: [0.1, 0.2, 0.3], index: 0 }, + { embedding: [0.4, 0.5, 0.6], index: 1 }, + ], + }), + { status: 200, headers: { 'Content-Type': 'application/json' } }, + ), + ); + + const result = await embedRemote(['a', 'b'], { + baseUrl: 'http://localhost:8080/v1', + model: 'my-model', + apiKey: 'sk-test', + }); + + expect(fetchMock).toHaveBeenCalledTimes(1); + const [url, init] = fetchMock.mock.calls[0]; + expect(url).toBe('http://localhost:8080/v1/embeddings'); + expect(init.method).toBe('POST'); + expect(init.headers.Authorization).toBe('Bearer sk-test'); + expect(JSON.parse(init.body)).toEqual({ model: 'my-model', input: ['a', 'b'] }); + + expect(result.dim).toBe(3); + expect(result.vectors).toHaveLength(2); + // Compare against Float32-rounded expectations — embedRemote stores vectors + // as Float32Array, which loses precision relative to the JSON doubles. + expect(Array.from(result.vectors[0])).toEqual(Array.from(Float32Array.from([0.1, 0.2, 0.3]))); + expect(Array.from(result.vectors[1])).toEqual(Array.from(Float32Array.from([0.4, 0.5, 0.6]))); + }); + + it('does not double up when baseUrl already ends with /embeddings', async () => { + fetchMock.mockResolvedValueOnce( + new Response(JSON.stringify({ data: [{ embedding: [1], index: 0 }] }), { status: 200 }), + ); + await embedRemote(['x'], { baseUrl: 'http://localhost:8080/v1/embeddings', model: 'm' }); + expect(fetchMock.mock.calls[0][0]).toBe('http://localhost:8080/v1/embeddings'); + }); + + it('sorts response items by index to restore input order', async () => { + fetchMock.mockResolvedValueOnce( + new Response( + JSON.stringify({ + data: [ + { embedding: [2], index: 1 }, + { embedding: [1], index: 0 }, + ], + }), + { status: 200 }, + ), + ); + const result = await embedRemote(['a', 'b'], { baseUrl: 'http://x', model: 'm' }); + expect(Array.from(result.vectors[0])).toEqual([1]); + expect(Array.from(result.vectors[1])).toEqual([2]); + }); + + it('omits the Authorization header when no apiKey is configured', async () => { + fetchMock.mockResolvedValueOnce( + new Response(JSON.stringify({ data: [{ embedding: [1], index: 0 }] }), { status: 200 }), + ); + await embedRemote(['a'], { baseUrl: 'http://x', model: 'm' }); + const [, init] = fetchMock.mock.calls[0]; + expect(init.headers.Authorization).toBeUndefined(); + }); + + it('batches requests larger than the batch size', async () => { + const texts = Array.from({ length: 40 }, (_, i) => `text-${i}`); + fetchMock.mockImplementation(async (_url, init) => { + const body = JSON.parse(init.body); + const data = body.input.map((_text: string, i: number) => ({ embedding: [1], index: i })); + return new Response(JSON.stringify({ data }), { status: 200 }); + }); + const result = await embedRemote(texts, { baseUrl: 'http://x', model: 'm' }); + expect(fetchMock).toHaveBeenCalledTimes(2); // 32 + 8 + expect(result.vectors).toHaveLength(40); + }); + + it('throws EngineError on a non-2xx response', async () => { + fetchMock.mockResolvedValueOnce( + new Response('bad request', { status: 400, statusText: 'Bad Request' }), + ); + await expect(embedRemote(['a'], { baseUrl: 'http://x', model: 'm' })).rejects.toThrow( + EngineError, + ); + }); + + it('throws EngineError when the response shape does not match the input length', async () => { + fetchMock.mockResolvedValueOnce( + new Response(JSON.stringify({ data: [{ embedding: [1], index: 0 }] }), { status: 200 }), + ); + await expect(embedRemote(['a', 'b'], { baseUrl: 'http://x', model: 'm' })).rejects.toThrow( + EngineError, + ); + }); + + it('throws EngineError when the network request itself fails', async () => { + fetchMock.mockRejectedValueOnce(new Error('ECONNREFUSED')); + await expect(embedRemote(['a'], { baseUrl: 'http://x', model: 'm' })).rejects.toThrow( + EngineError, + ); + }); + + it('aborts and throws EngineError when a request exceeds timeoutMs', async () => { + fetchMock.mockImplementation((_url, init: { signal: AbortSignal }) => { + return new Promise((_resolve, reject) => { + init.signal.addEventListener('abort', () => { + const err = new Error('This operation was aborted'); + err.name = 'AbortError'; + reject(err); + }); + }); + }); + await expect( + embedRemote(['a'], { baseUrl: 'http://x', model: 'm', timeoutMs: 10 }), + ).rejects.toThrow(/did not respond within 10ms/); + }); + + it('throws EngineError when a later item has a different vector dimension than earlier items', async () => { + fetchMock.mockResolvedValueOnce( + new Response( + JSON.stringify({ + data: [ + { embedding: [1, 2, 3], index: 0 }, + { embedding: [1, 2], index: 1 }, + ], + }), + { status: 200 }, + ), + ); + await expect(embedRemote(['a', 'b'], { baseUrl: 'http://x', model: 'm' })).rejects.toThrow( + /inconsistent vector dimensions/, + ); + }); + + it('throws EngineError instead of a raw TypeError when an item is missing the embedding field', async () => { + fetchMock.mockResolvedValueOnce( + new Response(JSON.stringify({ data: [{ index: 0 }] }), { status: 200 }), + ); + await expect(embedRemote(['a'], { baseUrl: 'http://x', model: 'm' })).rejects.toMatchObject({ + name: 'EngineError', + message: expect.stringContaining('missing or non-array "embedding" field'), + }); + }); +}); diff --git a/tests/search/embedding-remote-search.test.ts b/tests/search/embedding-remote-search.test.ts new file mode 100644 index 000000000..5c1c7850e --- /dev/null +++ b/tests/search/embedding-remote-search.test.ts @@ -0,0 +1,114 @@ +import fs from 'node:fs'; +import os from 'node:os'; +import path from 'node:path'; +import Database from 'better-sqlite3'; +import { afterAll, afterEach, beforeAll, beforeEach, describe, expect, test, vi } from 'vitest'; +import { initSchema } from '../../src/db/index.js'; +import { buildEmbeddings, searchData } from '../../src/domain/search/index.js'; + +// buildEmbeddings/searchData must never touch @huggingface/transformers when a +// remote provider is configured for both the index and query embedding steps. +vi.mock('@huggingface/transformers', () => { + throw new Error('local transformers pipeline should not be loaded on the remote path'); +}); + +function insertNode(db, name, kind, file, line, endLine) { + return db + .prepare('INSERT INTO nodes (name, kind, file, line, end_line) VALUES (?, ?, ?, ?, ?)') + .run(name, kind, file, line, endLine).lastInsertRowid; +} + +describe('semantic search against remotely-built embeddings', () => { + let tmpDir: string, dbPath: string; + const fetchMock = vi.fn(); + const config = { + embeddings: { model: 'my-remote-model', llmProvider: null, provider: 'openai' }, + llm: { + provider: null, + model: null, + baseUrl: 'http://localhost:9999/v1', + apiKey: 'sk-x', + apiKeyCommand: null, + }, + search: { defaultMinScore: 0, rrfK: 60, topK: 15, similarityWarnThreshold: 0.85 }, + } as never; + + beforeAll(() => { + tmpDir = fs.mkdtempSync(path.join(os.tmpdir(), 'codegraph-remote-search-')); + fs.writeFileSync(path.join(tmpDir, 'math.js'), 'export function add(a, b) { return a + b; }\n'); + + const dbDir = path.join(tmpDir, '.codegraph'); + fs.mkdirSync(dbDir, { recursive: true }); + dbPath = path.join(dbDir, 'graph.db'); + + const db = new Database(dbPath); + db.pragma('journal_mode = WAL'); + initSchema(db); + insertNode(db, 'add', 'function', 'math.js', 1, 1); + db.close(); + }); + + afterAll(() => { + if (tmpDir) fs.rmSync(tmpDir, { recursive: true, force: true }); + }); + + beforeEach(() => { + vi.stubGlobal('fetch', fetchMock); + // Every call (index or query) gets the same fixed vector, so the indexed + // symbol always scores a perfect match against any query. + fetchMock.mockImplementation(async (_url, init) => { + const body = JSON.parse(init.body); + const data = body.input.map((_text: string, i: number) => ({ + embedding: [1, 0, 0, 0], + index: i, + })); + return new Response(JSON.stringify({ data }), { status: 200 }); + }); + }); + + afterEach(() => { + vi.unstubAllGlobals(); + fetchMock.mockReset(); + }); + + test('query embedding is routed to the remote provider, not the local model', async () => { + await buildEmbeddings(tmpDir, 'my-remote-model', dbPath, { + remote: { baseUrl: 'http://localhost:9999/v1', model: 'my-remote-model', apiKey: 'sk-x' }, + }); + + const result = await searchData('addition helper', dbPath, { config }); + + expect(result).not.toBeNull(); + expect(result!.results.map((r) => r.name)).toContain('add'); + // One call to build the index embedding, one to embed the query. + expect(fetchMock).toHaveBeenCalledTimes(2); + for (const call of fetchMock.mock.calls) { + expect(call[0]).toBe('http://localhost:9999/v1/embeddings'); + } + }); + + test('query embedding still routes remotely when embeddings.provider config drifts after embed', async () => { + await buildEmbeddings(tmpDir, 'my-remote-model', dbPath, { + remote: { baseUrl: 'http://localhost:9999/v1', model: 'my-remote-model', apiKey: 'sk-x' }, + }); + + // Simulate config drift: whoever/whatever runs `search` no longer has + // embeddings.provider set to "openai" (e.g. cleared on a CI machine, or a + // different .codegraphrc.json applies). Routing must still honor the + // provider recorded in embedding_meta at embed time, not this live value + // — otherwise the query would silently fall back to the local model. + const driftedConfig = { + ...config, + embeddings: { model: 'my-remote-model', llmProvider: null, provider: null }, + } as never; + + const result = await searchData('addition helper', dbPath, { config: driftedConfig }); + + expect(result).not.toBeNull(); + expect(result!.results.map((r) => r.name)).toContain('add'); + expect(fetchMock).toHaveBeenCalledTimes(2); + for (const call of fetchMock.mock.calls) { + expect(call[0]).toBe('http://localhost:9999/v1/embeddings'); + } + }); +}); diff --git a/tests/unit/config.test.ts b/tests/unit/config.test.ts index b5e2211a5..7979a6313 100644 --- a/tests/unit/config.test.ts +++ b/tests/unit/config.test.ts @@ -58,7 +58,7 @@ describe('DEFAULTS', () => { }); it('has embeddings defaults', () => { - expect(DEFAULTS.embeddings).toEqual({ model: null, llmProvider: null }); + expect(DEFAULTS.embeddings).toEqual({ model: null, llmProvider: null, provider: null }); }); it('has llm defaults', () => { @@ -68,6 +68,7 @@ describe('DEFAULTS', () => { baseUrl: null, apiKey: null, apiKeyCommand: null, + requestTimeoutMs: 120_000, }); }); @@ -328,6 +329,7 @@ describe('applyEnvOverrides', () => { 'CODEGRAPH_LLM_PROVIDER', 'CODEGRAPH_LLM_API_KEY', 'CODEGRAPH_LLM_MODEL', + 'CODEGRAPH_LLM_BASE_URL', 'CODEGRAPH_ENGINE', 'CODEGRAPH_FAST_SKIP_DIAG', ]; @@ -362,6 +364,14 @@ describe('applyEnvOverrides', () => { expect(config.llm.model).toBe('gpt-4'); }); + it('overrides llm.baseUrl from env', () => { + process.env.CODEGRAPH_LLM_BASE_URL = 'http://localhost:8080/v1'; + const config = applyEnvOverrides({ + llm: { provider: null, model: null, baseUrl: null, apiKey: null }, + }); + expect(config.llm.baseUrl).toBe('http://localhost:8080/v1'); + }); + it('env vars take priority over file config', () => { process.env.CODEGRAPH_LLM_PROVIDER = 'anthropic'; const dir = fs.mkdtempSync(path.join(tmpDir, 'env-priority-')); diff --git a/tests/unit/embed-command.test.ts b/tests/unit/embed-command.test.ts new file mode 100644 index 000000000..8c85595aa --- /dev/null +++ b/tests/unit/embed-command.test.ts @@ -0,0 +1,118 @@ +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; + +vi.mock('../../src/domain/search/index.js', async (importOriginal) => { + const actual = await importOriginal(); + return { ...actual, buildEmbeddings: vi.fn() }; +}); +vi.mock('../../src/db/index.js', () => ({ + openReadonlyOrFail: vi.fn(() => { + throw new Error('no db in this test'); + }), +})); +vi.mock('../../src/db/repository/embeddings.js', () => ({ getEmbeddingMeta: vi.fn() })); + +const { command } = await import('../../src/cli/commands/embed.js'); +const { buildEmbeddings } = await import('../../src/domain/search/index.js'); + +function fakeCtx(embeddings: Record, llm: Record = {}) { + return { + config: { + embeddings: { model: null, llmProvider: null, provider: null, ...embeddings }, + llm: { + provider: null, + model: null, + baseUrl: null, + apiKey: null, + apiKeyCommand: null, + ...llm, + }, + }, + } as never; +} + +describe('embed command validate()', () => { + it('rejects an unknown strategy', () => { + const err = command.validate!([undefined], { strategy: 'bogus' } as never, fakeCtx({})); + expect(err).toMatch(/Unknown strategy/); + }); + + it('rejects an unsupported embeddings.provider', () => { + const err = command.validate!( + [undefined], + { strategy: 'structured' } as never, + fakeCtx({ provider: 'anthropic' }), + ); + expect(err).toMatch(/Unsupported embeddings.provider/); + }); + + it('rejects provider "openai" with no model configured', () => { + const err = command.validate!( + [undefined], + { strategy: 'structured' } as never, + fakeCtx({ provider: 'openai' }), + ); + expect(err).toMatch(/no model is configured/); + }); + + it('accepts provider "openai" with a config model', () => { + const err = command.validate!( + [undefined], + { strategy: 'structured' } as never, + fakeCtx({ provider: 'openai', model: 'text-embedding-3-small' }), + ); + expect(err).toBeUndefined(); + }); + + it('accepts provider "openai" with a --model flag', () => { + const err = command.validate!( + [undefined], + { strategy: 'structured', model: 'text-embedding-3-small' } as never, + fakeCtx({ provider: 'openai' }), + ); + expect(err).toBeUndefined(); + }); + + it('accepts no provider at all', () => { + const err = command.validate!([undefined], { strategy: 'structured' } as never, fakeCtx({})); + expect(err).toBeUndefined(); + }); +}); + +describe('embed command execute()', () => { + beforeEach(() => { + vi.mocked(buildEmbeddings).mockClear(); + }); + + afterEach(() => { + vi.mocked(buildEmbeddings).mockReset(); + }); + + it('passes a resolved remote config through to buildEmbeddings when provider is "openai"', async () => { + const ctx = fakeCtx( + { provider: 'openai', model: 'text-embedding-3-small' }, + { baseUrl: 'http://localhost:8080/v1', apiKey: 'sk-test', requestTimeoutMs: 5000 }, + ); + + await command.execute!([undefined], { strategy: 'structured' } as never, ctx); + + expect(buildEmbeddings).toHaveBeenCalledTimes(1); + const [, model, , options] = vi.mocked(buildEmbeddings).mock.calls[0]!; + expect(model).toBe('text-embedding-3-small'); + expect(options.remote).toEqual({ + baseUrl: 'http://localhost:8080/v1', + model: 'text-embedding-3-small', + apiKey: 'sk-test', + timeoutMs: 5000, + }); + }); + + it('does not build a remote config when no provider is set', async () => { + const ctx = fakeCtx({ model: 'minilm' }); + + await command.execute!([undefined], { strategy: 'structured' } as never, ctx); + + expect(buildEmbeddings).toHaveBeenCalledTimes(1); + const [, , , options] = vi.mocked(buildEmbeddings).mock.calls[0]!; + expect(options.remote).toBeUndefined(); + }); +}); diff --git a/tests/unit/models-command.test.ts b/tests/unit/models-command.test.ts new file mode 100644 index 000000000..9f642bbb7 --- /dev/null +++ b/tests/unit/models-command.test.ts @@ -0,0 +1,42 @@ +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; +import { command } from '../../src/cli/commands/models.js'; + +function fakeCtx(embeddings: Record) { + return { + config: { + embeddings: { model: null, llmProvider: null, provider: null, ...embeddings }, + }, + } as never; +} + +describe('models command', () => { + let logSpy: ReturnType; + + beforeEach(() => { + logSpy = vi.spyOn(console, 'log').mockImplementation(() => undefined); + }); + + afterEach(() => { + logSpy.mockRestore(); + }); + + it('does not print the literal string "null" when a remote provider has no model configured', () => { + command.execute!([], {} as never, fakeCtx({ provider: 'openai' })); + + const banner = logSpy.mock.calls.map((call) => call[0]).find((line) => /openai/.test(line)); + expect(banner).toBeDefined(); + expect(banner).not.toMatch(/model "null"/); + expect(banner).toMatch(/not configured/); + }); + + it('prints the configured model name when a remote provider has a model set', () => { + command.execute!( + [], + {} as never, + fakeCtx({ provider: 'openai', model: 'text-embedding-3-small' }), + ); + + const banner = logSpy.mock.calls.map((call) => call[0]).find((line) => /openai/.test(line)); + expect(banner).toMatch(/model "text-embedding-3-small"/); + }); +});