From 6987904dbbc1984fdaf205fca33778760c7bf8fc Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 13 Mar 2026 17:16:07 +0000 Subject: [PATCH 1/8] Initial plan From 11275c8746e3258310eb102909dcc64ed43df0b7 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 13 Mar 2026 17:29:22 +0000 Subject: [PATCH 2/8] Add continuous graph updates via webhook and poll watcher - Add api/git_utils/incremental_update.py with incremental_update(), fetch_remote(), get_remote_head(), and repo_local_path() helpers - Export new functions from api/git_utils/__init__.py - Add POST /api/webhook endpoint with HMAC-SHA256 validation, branch filtering, and repo URL matching - Add background poll watcher via FastAPI lifespan (_poll_loop, _poll_all_repos, _poll_repo) - Add WEBHOOK_SECRET, TRACKED_BRANCH, POLL_INTERVAL env vars - Document new env vars in .env.template - Add tests/test_webhook.py with unit tests" Co-authored-by: gkorland <753206+gkorland@users.noreply.github.com> --- .env.template | 18 ++ api/git_utils/__init__.py | 6 + api/git_utils/incremental_update.py | 193 +++++++++++++++++++ api/index.py | 227 +++++++++++++++++++++- tests/test_webhook.py | 280 ++++++++++++++++++++++++++++ 5 files changed, 721 insertions(+), 3 deletions(-) create mode 100644 api/git_utils/incremental_update.py create mode 100644 tests/test_webhook.py diff --git a/.env.template b/.env.template index 76067778..2c4beb8a 100644 --- a/.env.template +++ b/.env.template @@ -14,3 +14,21 @@ FLASK_RUN_PORT=5000 # Set to 1 to enable public access for analyze_repo/switch_commit endpoints CODE_GRAPH_PUBLIC=0 + +# --------------------------------------------------------------------------- +# Continuous graph updates (webhook / poll-watcher) +# --------------------------------------------------------------------------- + +# HMAC-SHA256 secret shared with GitHub/GitLab for webhook signature +# validation. 
Leave empty to disable signature checking (not recommended +# for production deployments). +WEBHOOK_SECRET= + +# Name of the branch to track for automatic incremental updates. +# Only push events targeting this branch trigger a graph update. +TRACKED_BRANCH=main + +# Seconds between automatic poll-watcher checks (0 = disable poll-watcher). +# The poll-watcher runs as a background task and checks every tracked +# repository for new commits on TRACKED_BRANCH. +POLL_INTERVAL=60 diff --git a/api/git_utils/__init__.py b/api/git_utils/__init__.py index 4fd3af98..56cfb8dd 100644 --- a/api/git_utils/__init__.py +++ b/api/git_utils/__init__.py @@ -1 +1,7 @@ from .git_utils import * +from .incremental_update import ( + fetch_remote as fetch_remote, + get_remote_head as get_remote_head, + incremental_update as incremental_update, + repo_local_path as repo_local_path, +) diff --git a/api/git_utils/incremental_update.py b/api/git_utils/incremental_update.py new file mode 100644 index 00000000..71e56b4b --- /dev/null +++ b/api/git_utils/incremental_update.py @@ -0,0 +1,193 @@ +"""Incremental graph update engine. + +Given a before/after commit SHA pair, computes the file-level diff, +applies additions/deletions/modifications to the FalkorDB code graph, +and bookmarks the new commit SHA in Redis so the system can resume +correctly after restarts or failures. +""" + +import logging +import os +import subprocess +from pathlib import Path +from typing import Optional + +from pygit2.enums import CheckoutStrategy +from pygit2.repository import Repository + +from ..analyzers.source_analyzer import SourceAnalyzer +from ..graph import Graph +from ..info import set_repo_commit +from .git_utils import classify_changes + +logger = logging.getLogger(__name__) + + +def repo_local_path(repo_name: str) -> Path: + """Return the local filesystem path for a cloned repository. 
+ + Respects the ``REPOSITORIES_DIR`` environment variable; falls back to + ``/repositories/`` which matches the convention used by + :func:`api.project._clone_source`. + """ + base = os.getenv("REPOSITORIES_DIR", str(Path.cwd() / "repositories")) + return Path(base) / repo_name + + +def fetch_remote(repo_path: Path) -> None: + """Fetch latest changes from the remote *origin*. + + Args: + repo_path: Absolute path to the local git clone. + + Raises: + subprocess.CalledProcessError: If the git fetch command fails. + """ + logger.info("Fetching remote changes for %s", repo_path) + subprocess.run( + ["git", "fetch", "origin"], + cwd=str(repo_path), + check=True, + capture_output=True, + text=True, + ) + + +def get_remote_head(repo_path: Path, branch: str) -> Optional[str]: + """Return the full SHA of the remote tracking branch HEAD. + + Args: + repo_path: Absolute path to the local git clone. + branch: Branch name (e.g. ``"main"``). + + Returns: + The 40-character commit SHA, or ``None`` if the branch does not exist + on the remote or the command fails. + """ + try: + result = subprocess.run( + ["git", "rev-parse", f"origin/{branch}"], + cwd=str(repo_path), + capture_output=True, + text=True, + check=True, + ) + return result.stdout.strip() or None + except subprocess.CalledProcessError: + logger.warning("Could not resolve origin/%s in %s", branch, repo_path) + return None + + +def incremental_update( + repo_name: str, + from_sha: str, + to_sha: str, + ignore: Optional[list[str]] = None, +) -> dict: + """Incrementally update the code graph from ``from_sha`` to ``to_sha``. + + Deleted files are removed from the graph. Modified files are removed + and then re-analysed. Added files are analysed and inserted. The + commit bookmark stored in Redis is updated to the short ID of ``to_sha`` + on success, matching the convention used by the rest of the system. 
+ + This function is idempotent: if ``from_sha == to_sha`` it returns + immediately without touching the graph or the bookmark. + + Args: + repo_name: Graph name in FalkorDB (and repository directory name). + from_sha: Commit SHA the graph is currently at (old state). + Accepts both abbreviated and full 40-char SHAs. + to_sha: Target commit SHA to advance the graph to (new state). + Accepts both abbreviated and full 40-char SHAs. + ignore: Optional list of path prefixes to skip during analysis. + + Returns: + A :class:`dict` with keys: + + * ``files_added`` – number of newly added source files processed. + * ``files_modified`` – number of modified source files re-processed. + * ``files_deleted`` – number of deleted source files removed. + * ``commit`` – the short SHA bookmark now stored in Redis. + + Raises: + ValueError: If the local repository clone cannot be found, or if + either SHA cannot be resolved. + """ + if ignore is None: + ignore = [] + + if from_sha == to_sha: + logger.info( + "incremental_update: from_sha == to_sha (%s); nothing to do", from_sha + ) + return { + "files_added": 0, + "files_modified": 0, + "files_deleted": 0, + "commit": to_sha, + } + + repo_path = repo_local_path(repo_name) + if not repo_path.exists(): + raise ValueError(f"Local repository not found at {repo_path}") + + logger.info( + "Incremental update for '%s': %s -> %s", repo_name, from_sha, to_sha + ) + + repo = Repository(str(repo_path)) + + # Resolve commits – accepts both abbreviated and full SHAs + try: + from_commit = repo.revparse_single(from_sha) + except Exception as exc: + raise ValueError(f"Cannot resolve from_sha '{from_sha}': {exc}") from exc + try: + to_commit = repo.revparse_single(to_sha) + except Exception as exc: + raise ValueError(f"Cannot resolve to_sha '{to_sha}': {exc}") from exc + + # Compute the file-level diff between the two commits + analyzer = SourceAnalyzer() + supported_types = analyzer.supported_types() + diff = repo.diff(from_commit, to_commit) + 
added, deleted, modified = classify_changes(diff, repo, supported_types, ignore) + + logger.info( + "Diff for '%s': %d added, %d modified, %d deleted", + repo_name, + len(added), + len(modified), + len(deleted), + ) + + # Checkout target commit so files on disk reflect to_sha + repo.checkout_tree(to_commit.tree, strategy=CheckoutStrategy.FORCE) + repo.set_head_detached(to_commit.id) + + # Apply graph changes + g = Graph(repo_name) + + files_to_remove = deleted + modified + if files_to_remove: + logger.info("Removing %d file(s) from graph", len(files_to_remove)) + g.delete_files(files_to_remove) + + files_to_add = added + modified + if files_to_add: + logger.info("Inserting/updating %d file(s) in graph", len(files_to_add)) + analyzer.analyze_files(files_to_add, repo_path, g) + + # Persist the new commit bookmark using the short ID for consistency + # with the rest of the system (build_commit_graph, analyze_sources …) + new_commit_short = to_commit.short_id + set_repo_commit(repo_name, new_commit_short) + logger.info("Graph for '%s' updated to commit %s", repo_name, new_commit_short) + + return { + "files_added": len(added), + "files_modified": len(modified), + "files_deleted": len(deleted), + "commit": new_commit_short, + } diff --git a/api/index.py b/api/index.py index 38dfb61d..dfb03a4e 100644 --- a/api/index.py +++ b/api/index.py @@ -1,19 +1,28 @@ """ Main API module for CodeGraph. 
""" +import hashlib +import hmac import os import asyncio +import contextlib import logging from pathlib import Path from dotenv import load_dotenv -from fastapi import Depends, FastAPI, Header, HTTPException, Query +from fastapi import Depends, FastAPI, Header, HTTPException, Query, Request from fastapi.responses import FileResponse, JSONResponse from pydantic import BaseModel from api.analyzers.source_analyzer import SourceAnalyzer from api.git_utils import git_utils from api.git_utils.git_graph import AsyncGitGraph +from api.git_utils.incremental_update import ( + fetch_remote, + get_remote_head, + incremental_update, + repo_local_path, +) from api.graph import Graph, AsyncGraphQuery, async_get_repos -from api.info import async_get_repo_info +from api.info import async_get_repo_info, get_repo_commit from api.llm import ask from api.project import Project @@ -98,7 +107,140 @@ class SwitchCommitRequest(BaseModel): str(Path(__file__).resolve().parent.parent)) ).resolve() -app = FastAPI() +# --------------------------------------------------------------------------- +# Webhook / poll-watcher configuration +# --------------------------------------------------------------------------- + +# HMAC-SHA256 secret shared with GitHub/GitLab. Leave unset to skip +# signature validation (not recommended for production). +WEBHOOK_SECRET: str = os.getenv("WEBHOOK_SECRET", "") + +# Branch whose pushes trigger incremental graph updates. +TRACKED_BRANCH: str = os.getenv("TRACKED_BRANCH", "main") + +# Seconds between automatic poll checks (0 = disabled). +POLL_INTERVAL: int = int(os.getenv("POLL_INTERVAL", "60")) + +# --------------------------------------------------------------------------- +# Webhook helpers +# --------------------------------------------------------------------------- + +def _urls_match(stored_url: str, incoming_url: str) -> bool: + """Return True when two repository URLs refer to the same repo. 
+ + Normalises both URLs by stripping a trailing ``.git`` suffix and + converting to lower-case so that, for example, + ``https://github.com/Org/Repo`` and + ``https://github.com/org/repo.git`` are treated as identical. + """ + def _normalise(u: str) -> str: + return u.rstrip("/").removesuffix(".git").lower() + + return _normalise(stored_url) == _normalise(incoming_url) + + +async def _find_repo_by_url(url: str) -> str | None: + """Return the graph name for a repository that matches *url*, or ``None``.""" + repos = await async_get_repos() + for repo_name in repos: + info = await async_get_repo_info(repo_name) + if info and _urls_match(info.get("repo_url", ""), url): + return repo_name + return None + +# --------------------------------------------------------------------------- +# Background poll-watcher helpers (synchronous, run in thread-pool executor) +# --------------------------------------------------------------------------- + +def _poll_repo(repo_name: str) -> None: + """Fetch remote and apply incremental updates for *repo_name* if behind. + + This function is intentionally synchronous so it can be safely offloaded + to ``asyncio``'s default ``ThreadPoolExecutor``. + """ + path = repo_local_path(repo_name) + if not path.exists(): + logger.debug("Poll: local clone not found for '%s', skipping", repo_name) + return + + try: + fetch_remote(path) + except Exception as exc: + logger.warning("Poll: git fetch failed for '%s': %s", repo_name, exc) + return + + remote_head = get_remote_head(path, TRACKED_BRANCH) + if not remote_head: + return + + current_sha = get_repo_commit(repo_name) + if not current_sha: + logger.debug("Poll: no stored commit for '%s', skipping", repo_name) + return + + # Handle comparison between short (7-char) and full (40-char) SHAs: a short + # stored SHA is a valid prefix of a full remote SHA for the same commit. + # We only apply prefix matching when the stored SHA is shorter. 
+ if len(current_sha) < len(remote_head): + up_to_date = remote_head.startswith(current_sha) + elif len(current_sha) > len(remote_head): + up_to_date = current_sha.startswith(remote_head) + else: + up_to_date = current_sha == remote_head + if up_to_date: + logger.debug("Poll: '%s' is up-to-date at %s", repo_name, current_sha) + return + + logger.info( + "Poll: new commits detected for '%s' (%s -> %s), updating …", + repo_name, current_sha, remote_head, + ) + try: + result = incremental_update(repo_name, current_sha, remote_head) + logger.info("Poll: '%s' updated — %s", repo_name, result) + except Exception as exc: + logger.exception( + "Poll: incremental update failed for '%s': %s", repo_name, exc + ) + + +async def _poll_all_repos() -> None: + """Check every indexed repository for new commits on the tracked branch.""" + repos = await async_get_repos() + loop = asyncio.get_running_loop() + for repo_name in repos: + await loop.run_in_executor(None, _poll_repo, repo_name) + + +async def _poll_loop() -> None: + """Continuously poll all repositories at the configured interval.""" + logger.info( + "Poll-watcher started (interval=%ds, branch='%s')", + POLL_INTERVAL, TRACKED_BRANCH, + ) + while True: + try: + await _poll_all_repos() + except Exception as exc: + logger.exception("Poll loop error: %s", exc) + await asyncio.sleep(POLL_INTERVAL) + +# --------------------------------------------------------------------------- +# Application lifespan (starts/stops the background poll task) +# --------------------------------------------------------------------------- + +@contextlib.asynccontextmanager +async def _lifespan(application: FastAPI): + poll_task = None + if POLL_INTERVAL > 0: + poll_task = asyncio.create_task(_poll_loop()) + yield + if poll_task is not None: + poll_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await poll_task + +app = FastAPI(lifespan=_lifespan) # --------------------------------------------------------------------------- # API 
routes @@ -290,6 +432,85 @@ async def list_commits(data: RepoRequest, _=Depends(public_or_auth)): await git_graph.close() return {"status": "success", "commits": commits} + +@app.post('/api/webhook') +async def webhook(request: Request): + """Receive a GitHub/GitLab push event and trigger an incremental graph update. + + When ``WEBHOOK_SECRET`` is set the endpoint validates the + ``X-Hub-Signature-256`` header using HMAC-SHA256; requests with a missing + or invalid signature are rejected with **401 Unauthorized**. + + Only pushes to the branch configured via ``TRACKED_BRANCH`` (default + ``main``) trigger an update; pushes to other branches are acknowledged + with a ``200 ignored`` response so that GitHub does not retry them. + + The repository is identified by matching the ``repository.clone_url`` + field in the payload against the URLs stored for already-indexed + repositories. + """ + body = await request.body() + + # Validate HMAC-SHA256 signature when a secret is configured + if WEBHOOK_SECRET: + sig_header = request.headers.get("X-Hub-Signature-256", "") + mac = hmac.new(WEBHOOK_SECRET.encode(), body, hashlib.sha256) + expected_sig = "sha256=" + mac.hexdigest() + if not hmac.compare_digest(sig_header, expected_sig): + raise HTTPException(status_code=401, detail="Invalid webhook signature") + + try: + payload = await request.json() + except Exception: + raise HTTPException(status_code=400, detail="Invalid JSON payload") + + ref = payload.get("ref", "") + before = payload.get("before", "") + after = payload.get("after", "") + repo_url = payload.get("repository", {}).get("clone_url", "") + + # Only process pushes to the configured tracked branch + expected_ref = f"refs/heads/{TRACKED_BRANCH}" + if ref != expected_ref: + logger.debug("Webhook: ignoring push to '%s' (tracking '%s')", ref, expected_ref) + return {"status": "ignored", "reason": f"Branch not tracked: {ref}"} + + if not before or not after or not repo_url: + raise HTTPException( + status_code=400, + 
detail="Payload missing required fields: ref, before, after, repository.clone_url", + ) + + # Resolve the repository name from the stored index + repo_name = await _find_repo_by_url(repo_url) + if repo_name is None: + logger.warning("Webhook: received push for unknown repo '%s'", repo_url) + return JSONResponse( + {"status": "error", "detail": "Repository not indexed"}, + status_code=404, + ) + + logger.info( + "Webhook: updating '%s' from %s to %s", repo_name, before[:8], after[:8] + ) + + def _update() -> dict: + path = repo_local_path(repo_name) + if path.exists(): + fetch_remote(path) + return incremental_update(repo_name, before, after) + + loop = asyncio.get_running_loop() + try: + result = await loop.run_in_executor(None, _update) + except Exception as exc: + logger.exception( + "Webhook: incremental update failed for '%s': %s", repo_name, exc + ) + return JSONResponse({"status": "error", "detail": str(exc)}, status_code=500) + + return {"status": "success", **result} + # --------------------------------------------------------------------------- # SPA static file serving (must come after API routes) # --------------------------------------------------------------------------- diff --git a/tests/test_webhook.py b/tests/test_webhook.py new file mode 100644 index 00000000..5cca2d31 --- /dev/null +++ b/tests/test_webhook.py @@ -0,0 +1,280 @@ +"""Unit tests for the webhook endpoint and incremental update helpers. + +These tests use ``monkeypatch`` to mock out external collaborators (FalkorDB, +Redis, git) so they run without a live database or network connection. +""" + +import hashlib +import hmac +import json + +import pytest +from starlette.testclient import TestClient + +import api.index + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +# A full Git SHA-1 hash is 40 hexadecimal characters. 
+_FULL_SHA_BEFORE = "aaaa1111" * 5 # 40-char SHA simulating the "before" commit +_FULL_SHA_AFTER = "bbbb2222" * 5 # 40-char SHA simulating the "after" commit + + +class _FakePath: + """Minimal Path-like object for use with monkeypatch.""" + + def __init__(self, *, exists: bool): + self._exists = exists + + def exists(self) -> bool: + return self._exists + + +def _make_push_payload( + ref: str = "refs/heads/main", + before: str = _FULL_SHA_BEFORE, + after: str = _FULL_SHA_AFTER, + clone_url: str = "https://github.com/example/myrepo.git", +) -> dict: + return { + "ref": ref, + "before": before, + "after": after, + "repository": {"clone_url": clone_url}, + } + + +def _sign(body: bytes, secret: str) -> str: + mac = hmac.new(secret.encode(), body, hashlib.sha256) + return "sha256=" + mac.hexdigest() + + +# --------------------------------------------------------------------------- +# _urls_match +# --------------------------------------------------------------------------- + +def test_urls_match_identical(): + assert api.index._urls_match( + "https://github.com/org/repo.git", + "https://github.com/org/repo.git", + ) + + +def test_urls_match_git_suffix(): + assert api.index._urls_match( + "https://github.com/org/repo", + "https://github.com/org/repo.git", + ) + + +def test_urls_match_case_insensitive(): + assert api.index._urls_match( + "https://github.com/Org/Repo.git", + "https://github.com/org/repo.git", + ) + + +def test_urls_match_trailing_slash(): + assert api.index._urls_match( + "https://github.com/org/repo/", + "https://github.com/org/repo.git", + ) + + +def test_urls_no_match_different_repo(): + assert not api.index._urls_match( + "https://github.com/org/repo-a.git", + "https://github.com/org/repo-b.git", + ) + + +# --------------------------------------------------------------------------- +# Webhook endpoint – no secret configured (open mode) +# --------------------------------------------------------------------------- + +@pytest.fixture() +def 
client_open(monkeypatch): + """Test client with no webhook secret and no poll-watcher.""" + monkeypatch.setattr(api.index, "WEBHOOK_SECRET", "") + monkeypatch.setattr(api.index, "POLL_INTERVAL", 0) + return TestClient(api.index.app, raise_server_exceptions=False) + + +def test_webhook_ignored_wrong_branch(client_open, monkeypatch): + """Pushes to non-tracked branches return 200 with status='ignored'.""" + monkeypatch.setattr(api.index, "TRACKED_BRANCH", "main") + payload = _make_push_payload(ref="refs/heads/feature/x") + resp = client_open.post("/api/webhook", json=payload) + assert resp.status_code == 200 + data = resp.json() + assert data["status"] == "ignored" + + +def test_webhook_unknown_repo(client_open, monkeypatch): + """Webhook for a repo URL that is not indexed returns 404.""" + monkeypatch.setattr(api.index, "TRACKED_BRANCH", "main") + + # No repos indexed → _find_repo_by_url returns None + async def _fake_get_repos(): + return [] + + monkeypatch.setattr(api.index, "async_get_repos", _fake_get_repos) + + payload = _make_push_payload() + resp = client_open.post("/api/webhook", json=payload) + assert resp.status_code == 404 + + +def test_webhook_success(client_open, monkeypatch): + """Valid push to tracked branch triggers incremental_update and returns stats.""" + monkeypatch.setattr(api.index, "TRACKED_BRANCH", "main") + + async def _fake_get_repos(): + return ["myrepo"] + + async def _fake_get_repo_info(repo_name): + return {"repo_url": "https://github.com/example/myrepo.git"} + + update_calls = [] + + def _fake_update(repo_name, from_sha, to_sha, ignore=None): + update_calls.append((repo_name, from_sha, to_sha)) + return { + "files_added": 1, + "files_modified": 0, + "files_deleted": 0, + "commit": to_sha[:7], + } + + monkeypatch.setattr(api.index, "async_get_repos", _fake_get_repos) + monkeypatch.setattr(api.index, "async_get_repo_info", _fake_get_repo_info) + monkeypatch.setattr(api.index, "incremental_update", _fake_update) + # Skip git fetch (no 
real clone) + monkeypatch.setattr(api.index, "fetch_remote", lambda path: None) + monkeypatch.setattr(api.index, "repo_local_path", lambda name: _FakePath(exists=False)) + + payload = _make_push_payload() + resp = client_open.post("/api/webhook", json=payload) + + assert resp.status_code == 200 + data = resp.json() + assert data["status"] == "success" + assert data["files_added"] == 1 + assert len(update_calls) == 1 + assert update_calls[0] == ("myrepo", _FULL_SHA_BEFORE, _FULL_SHA_AFTER) + + +# --------------------------------------------------------------------------- +# Webhook endpoint – HMAC-SHA256 signature validation +# --------------------------------------------------------------------------- + +@pytest.fixture() +def client_secured(monkeypatch): + """Test client with WEBHOOK_SECRET='mysecret' and poll disabled.""" + monkeypatch.setattr(api.index, "WEBHOOK_SECRET", "mysecret") + monkeypatch.setattr(api.index, "POLL_INTERVAL", 0) + return TestClient(api.index.app, raise_server_exceptions=False) + + +def test_webhook_missing_signature_rejected(client_secured, monkeypatch): + """Requests without X-Hub-Signature-256 header are rejected with 401.""" + monkeypatch.setattr(api.index, "TRACKED_BRANCH", "main") + payload = _make_push_payload() + resp = client_secured.post("/api/webhook", json=payload) + assert resp.status_code == 401 + + +def test_webhook_wrong_signature_rejected(client_secured, monkeypatch): + """Requests with an incorrect signature are rejected with 401.""" + monkeypatch.setattr(api.index, "TRACKED_BRANCH", "main") + payload = _make_push_payload() + body = json.dumps(payload).encode() + bad_sig = _sign(body, "wrongsecret") + resp = client_secured.post( + "/api/webhook", + content=body, + headers={"Content-Type": "application/json", "X-Hub-Signature-256": bad_sig}, + ) + assert resp.status_code == 401 + + +def test_webhook_valid_signature_accepted(client_secured, monkeypatch): + """Requests with a correct HMAC-SHA256 signature are accepted.""" + 
monkeypatch.setattr(api.index, "TRACKED_BRANCH", "main") + monkeypatch.setattr(api.index, "WEBHOOK_SECRET", "mysecret") + + async def _fake_get_repos(): + return ["myrepo"] + + async def _fake_get_repo_info(repo_name): + return {"repo_url": "https://github.com/example/myrepo.git"} + + monkeypatch.setattr(api.index, "async_get_repos", _fake_get_repos) + monkeypatch.setattr(api.index, "async_get_repo_info", _fake_get_repo_info) + monkeypatch.setattr(api.index, "incremental_update", lambda *a, **kw: { + "files_added": 0, "files_modified": 0, "files_deleted": 0, "commit": "abc1234", + }) + monkeypatch.setattr(api.index, "fetch_remote", lambda path: None) + monkeypatch.setattr(api.index, "repo_local_path", lambda name: _FakePath(exists=False)) + + payload = _make_push_payload() + body = json.dumps(payload).encode() + sig = _sign(body, "mysecret") + + resp = client_secured.post( + "/api/webhook", + content=body, + headers={"Content-Type": "application/json", "X-Hub-Signature-256": sig}, + ) + assert resp.status_code == 200 + assert resp.json()["status"] == "success" + + +def test_webhook_invalid_json(client_open, monkeypatch): + """Non-JSON bodies are rejected with 400.""" + monkeypatch.setattr(api.index, "TRACKED_BRANCH", "main") + resp = client_open.post( + "/api/webhook", + content=b"not-json", + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400 + + +# --------------------------------------------------------------------------- +# incremental_update – unit tests (no live DB/git) +# --------------------------------------------------------------------------- + +def test_incremental_update_idempotent(monkeypatch, tmp_path): + """Calling incremental_update with the same SHA twice is a no-op.""" + from api.git_utils.incremental_update import incremental_update as _iu + + # Patch set_repo_commit to detect unexpected writes + writes = [] + monkeypatch.setattr("api.git_utils.incremental_update.set_repo_commit", + lambda *a: writes.append(a)) + 
+ sha = "abc1234" + result = _iu("some-repo", sha, sha) + + assert result["files_added"] == 0 + assert result["files_modified"] == 0 + assert result["files_deleted"] == 0 + assert result["commit"] == sha + assert writes == [], "set_repo_commit must not be called for no-op update" + + +def test_incremental_update_missing_repo(monkeypatch, tmp_path): + """incremental_update raises ValueError when local clone does not exist.""" + from api.git_utils.incremental_update import incremental_update as _iu + + monkeypatch.setattr( + "api.git_utils.incremental_update.repo_local_path", + lambda name: tmp_path / "nonexistent", + ) + + with pytest.raises(ValueError, match="Local repository not found"): + _iu("some-repo", "aaa1111", "bbb2222") From f281513fa4191ac47eddf9c5240f84de0d1ea3d9 Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Sat, 14 Mar 2026 23:24:35 +0200 Subject: [PATCH 3/8] Verify and fix webhook/incremental update findings Reprocess dependent files during incremental updates, add repo-scoped update locking, harden webhook auth and provider handling, and fall back to full reindex when the stored bookmark no longer matches incoming history. 
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .env.template | 6 +- README.md | 6 +- api/analyzers/analyzer.py | 75 ++++++++- api/analyzers/csharp/analyzer.py | 45 +++--- api/analyzers/java/analyzer.py | 48 +++--- api/analyzers/python/analyzer.py | 47 +++--- api/analyzers/source_analyzer.py | 3 +- api/git_utils/incremental_update.py | 156 +++++++++++++++--- api/graph.py | 61 ++++++- api/index.py | 237 ++++++++++++++++++++++++---- tests/test_incremental_update.py | 151 ++++++++++++++++++ tests/test_webhook.py | 196 +++++++++++++++++++++-- 12 files changed, 882 insertions(+), 149 deletions(-) create mode 100644 tests/test_incremental_update.py diff --git a/.env.template b/.env.template index 3d23d9a4..d738a910 100644 --- a/.env.template +++ b/.env.template @@ -32,9 +32,9 @@ PORT=5000 # Continuous graph updates (webhook / poll-watcher) # --------------------------------------------------------------------------- -# HMAC-SHA256 secret shared with GitHub/GitLab for webhook signature -# validation. Leave empty to disable signature checking (not recommended -# for production deployments). +# Shared secret used for GitHub HMAC verification or GitLab's +# X-Gitlab-Token verification. Leave empty to require +# Authorization: Bearer on /api/webhook instead. WEBHOOK_SECRET= # Name of the branch to track for automatic incremental updates. 
diff --git a/README.md b/README.md index b83218d3..48dbcaa6 100644 --- a/README.md +++ b/README.md @@ -86,7 +86,7 @@ cp .env.template .env | `MODEL_NAME` | LiteLLM model used by `/api/chat` | No | `gemini/gemini-flash-lite-latest` | | `HOST` | Optional Uvicorn bind host for `start.sh`/`make run-*` | No | `0.0.0.0` or `127.0.0.1` depending on command | | `PORT` | Optional Uvicorn bind port for `start.sh`/`make run-*` | No | `5000` | -| `WEBHOOK_SECRET` | Optional HMAC secret for `/api/webhook` signature validation | No | empty | +| `WEBHOOK_SECRET` | Shared secret for GitHub HMAC or GitLab `X-Gitlab-Token` verification on `/api/webhook` | No | empty | | `TRACKED_BRANCH` | Branch watched by the webhook and poll-watcher | No | `main` | | `POLL_INTERVAL` | Seconds between background poll checks (`0` disables polling) | No | `60` | @@ -100,7 +100,7 @@ The chat endpoint also needs the provider credential expected by your chosen `MO - If `SECRET_TOKEN` is unset, the current implementation accepts requests without an `Authorization` header. - Setting `CODE_GRAPH_PUBLIC=1` makes the read-only endpoints public even when `SECRET_TOKEN` is configured. -Continuous graph updates can be triggered either by posting a GitHub/GitLab push payload to `/api/webhook` or by enabling the background poll-watcher with `POLL_INTERVAL > 0`. +Continuous graph updates can be triggered either by posting a GitHub/GitLab push payload to `/api/webhook` or by enabling the background poll-watcher with `POLL_INTERVAL > 0`. When `WEBHOOK_SECRET` is unset, `/api/webhook` falls back to the same bearer-token auth used by the other mutating endpoints. ### 3. 
Install dependencies @@ -246,7 +246,7 @@ A C analyzer exists in the source tree, but it is commented out and is not curre | POST | `/api/analyze_folder` | Analyze a local source folder | | POST | `/api/analyze_repo` | Clone and analyze a git repository | | POST | `/api/switch_commit` | Switch the indexed repository to a specific commit | -| POST | `/api/webhook` | Receive a push event and apply an incremental graph update | +| POST | `/api/webhook` | Receive a GitHub/GitLab push event and apply an incremental graph update | ## License diff --git a/api/analyzers/analyzer.py b/api/analyzers/analyzer.py index 64d49004..a02ff151 100644 --- a/api/analyzers/analyzer.py +++ b/api/analyzers/analyzer.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass from pathlib import Path from typing import Optional @@ -7,6 +8,14 @@ from abc import ABC, abstractmethod from multilspy import SyncLanguageServer +from ..graph import Graph + + +@dataclass(frozen=True) +class ResolvedEntityRef: + id: int + + class AbstractAnalyzer(ABC): def __init__(self, language: Language) -> None: self.language = language @@ -56,8 +65,69 @@ def resolve(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: P try: locations = lsp.request_definition(str(file_path), node.start_point.row, node.start_point.column) return [(files[Path(self.resolve_path(location['absolutePath'], path))], files[Path(self.resolve_path(location['absolutePath'], path))].tree.root_node.descendant_for_point_range(Point(location['range']['start']['line'], location['range']['start']['character']), Point(location['range']['end']['line'], location['range']['end']['character']))) for location in locations if location and Path(self.resolve_path(location['absolutePath'], path)) in files] - except Exception as e: + except Exception: return [] + + def resolve_entities( + self, + files: dict[Path, File], + lsp: SyncLanguageServer, + file_path: Path, + path: Path, + node: Node, + graph: Graph, + parent_types: list[str], + 
graph_labels: list[str], + reject_parent_types: Optional[set[str]] = None, + ) -> list[Entity | ResolvedEntityRef]: + try: + locations = lsp.request_definition( + str(file_path), node.start_point.row, node.start_point.column + ) + except Exception: + return [] + + resolved_entities: list[Entity | ResolvedEntityRef] = [] + for location in locations: + if not location or 'absolutePath' not in location: + continue + + resolved_path = Path(self.resolve_path(location['absolutePath'], path)) + if resolved_path in files: + file = files[resolved_path] + resolved_node = file.tree.root_node.descendant_for_point_range( + Point( + location['range']['start']['line'], + location['range']['start']['character'], + ), + Point( + location['range']['end']['line'], + location['range']['end']['character'], + ), + ) + entity_node = self.find_parent(resolved_node, parent_types) + if entity_node is None: + continue + if reject_parent_types and entity_node.type in reject_parent_types: + continue + + entity = file.entities.get(entity_node) + if entity is not None: + resolved_entities.append(entity) + continue + + if graph is None: + continue + + graph_entity = graph.get_entity_at_position( + str(resolved_path), + location['range']['start']['line'], + graph_labels, + ) + if graph_entity is not None: + resolved_entities.append(ResolvedEntityRef(graph_entity.id)) + + return resolved_entities @abstractmethod def add_dependencies(self, path: Path, files: list[Path]): @@ -133,7 +203,7 @@ def add_symbols(self, entity: Entity) -> None: pass @abstractmethod - def resolve_symbol(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, key: str, symbol: Node) -> list[Entity]: + def resolve_symbol(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, graph: Graph, key: str, symbol: Node) -> list[Entity | ResolvedEntityRef]: """ Resolve a symbol to an entity. 
@@ -148,4 +218,3 @@ def resolve_symbol(self, files: dict[Path, File], lsp: SyncLanguageServer, file_ """ pass - diff --git a/api/analyzers/csharp/analyzer.py b/api/analyzers/csharp/analyzer.py index 74c3906e..61e8c8a6 100644 --- a/api/analyzers/csharp/analyzer.py +++ b/api/analyzers/csharp/analyzer.py @@ -105,34 +105,41 @@ def is_dependency(self, file_path: str) -> bool: def resolve_path(self, file_path: str, path: Path) -> str: return file_path - def resolve_type(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, node: Node) -> list[Entity]: - res = [] - for file, resolved_node in self.resolve(files, lsp, file_path, path, node): - type_dec = self.find_parent(resolved_node, ['class_declaration', 'interface_declaration', 'enum_declaration', 'struct_declaration']) - if type_dec in file.entities: - res.append(file.entities[type_dec]) - return res + def resolve_type(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, graph, node: Node) -> list[Entity]: + return self.resolve_entities( + files, + lsp, + file_path, + path, + node, + graph, + ['class_declaration', 'interface_declaration', 'enum_declaration', 'struct_declaration'], + ['Class', 'Interface', 'Enum', 'Struct'], + ) - def resolve_method(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, node: Node) -> list[Entity]: - res = [] + def resolve_method(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, graph, node: Node) -> list[Entity]: if node.type == 'invocation_expression': func_node = node.child_by_field_name('function') if func_node and func_node.type == 'member_access_expression': func_node = func_node.child_by_field_name('name') if func_node: node = func_node - for file, resolved_node in self.resolve(files, lsp, file_path, path, node): - method_dec = self.find_parent(resolved_node, ['method_declaration', 'constructor_declaration', 'class_declaration', 'interface_declaration', 
'enum_declaration', 'struct_declaration']) - if method_dec and method_dec.type in ['class_declaration', 'interface_declaration', 'enum_declaration', 'struct_declaration']: - continue - if method_dec in file.entities: - res.append(file.entities[method_dec]) - return res + return self.resolve_entities( + files, + lsp, + file_path, + path, + node, + graph, + ['method_declaration', 'constructor_declaration', 'class_declaration', 'interface_declaration', 'enum_declaration', 'struct_declaration'], + ['Method', 'Constructor'], + {'class_declaration', 'interface_declaration', 'enum_declaration', 'struct_declaration'}, + ) - def resolve_symbol(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, key: str, symbol: Node) -> list[Entity]: + def resolve_symbol(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, graph, key: str, symbol: Node) -> list[Entity]: if key in ["implement_interface", "base_class", "extend_interface", "parameters", "return_type"]: - return self.resolve_type(files, lsp, file_path, path, symbol) + return self.resolve_type(files, lsp, file_path, path, graph, symbol) elif key in ["call"]: - return self.resolve_method(files, lsp, file_path, path, symbol) + return self.resolve_method(files, lsp, file_path, path, graph, symbol) else: raise ValueError(f"Unknown key {key}") diff --git a/api/analyzers/java/analyzer.py b/api/analyzers/java/analyzer.py index 5269d698..1ce80f80 100644 --- a/api/analyzers/java/analyzer.py +++ b/api/analyzers/java/analyzer.py @@ -1,7 +1,8 @@ import os from pathlib import Path import subprocess -from ...entities import * +from ...entities.entity import Entity +from ...entities.file import File from typing import Optional from ..analyzer import AbstractAnalyzer @@ -102,28 +103,35 @@ def resolve_path(self, file_path: str, path: Path) -> str: return f"{path}/temp_deps/{args[1]}/{targs}/{args[-1]}" return file_path - def resolve_type(self, files: dict[Path, File], lsp: 
SyncLanguageServer, file_path: Path, path: Path, node: Node) -> list[Entity]: - res = [] - for file, resolved_node in self.resolve(files, lsp, file_path, path, node): - type_dec = self.find_parent(resolved_node, ['class_declaration', 'interface_declaration', 'enum_declaration']) - if type_dec in file.entities: - res.append(file.entities[type_dec]) - return res + def resolve_type(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, graph, node: Node) -> list[Entity]: + return self.resolve_entities( + files, + lsp, + file_path, + path, + node, + graph, + ['class_declaration', 'interface_declaration', 'enum_declaration'], + ['Class', 'Interface', 'Enum'], + ) - def resolve_method(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, node: Node) -> list[Entity]: - res = [] - for file, resolved_node in self.resolve(files, lsp, file_path, path, node.child_by_field_name('name')): - method_dec = self.find_parent(resolved_node, ['method_declaration', 'constructor_declaration', 'class_declaration', 'interface_declaration', 'enum_declaration']) - if method_dec and method_dec.type in ['class_declaration', 'interface_declaration', 'enum_declaration']: - continue - if method_dec in file.entities: - res.append(file.entities[method_dec]) - return res + def resolve_method(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, graph, node: Node) -> list[Entity]: + return self.resolve_entities( + files, + lsp, + file_path, + path, + node.child_by_field_name('name'), + graph, + ['method_declaration', 'constructor_declaration', 'class_declaration', 'interface_declaration', 'enum_declaration'], + ['Method', 'Constructor'], + {'class_declaration', 'interface_declaration', 'enum_declaration'}, + ) - def resolve_symbol(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, key: str, symbol: Node) -> list[Entity]: + def resolve_symbol(self, files: dict[Path, File], lsp: 
SyncLanguageServer, file_path: Path, path: Path, graph, key: str, symbol: Node) -> list[Entity]: if key in ["implement_interface", "base_class", "extend_interface", "parameters", "return_type"]: - return self.resolve_type(files, lsp, file_path, path, symbol) + return self.resolve_type(files, lsp, file_path, path, graph, symbol) elif key in ["call"]: - return self.resolve_method(files, lsp, file_path, path, symbol) + return self.resolve_method(files, lsp, file_path, path, graph, symbol) else: raise ValueError(f"Unknown key {key}") diff --git a/api/analyzers/python/analyzer.py b/api/analyzers/python/analyzer.py index 7a991202..a63d0b4b 100644 --- a/api/analyzers/python/analyzer.py +++ b/api/analyzers/python/analyzer.py @@ -4,7 +4,8 @@ from pathlib import Path import tomllib -from ...entities import * +from ...entities.entity import Entity +from ...entities.file import File from typing import Optional from ..analyzer import AbstractAnalyzer @@ -91,34 +92,40 @@ def is_dependency(self, file_path: str) -> bool: def resolve_path(self, file_path: str, path: Path) -> str: return file_path - def resolve_type(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path, node: Node) -> list[Entity]: - res = [] + def resolve_type(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path, graph, node: Node) -> list[Entity]: if node.type == 'attribute': node = node.child_by_field_name('attribute') - for file, resolved_node in self.resolve(files, lsp, file_path, path, node): - type_dec = self.find_parent(resolved_node, ['class_definition']) - if type_dec in file.entities: - res.append(file.entities[type_dec]) - return res + return self.resolve_entities( + files, + lsp, + file_path, + path, + node, + graph, + ['class_definition'], + ['Class'], + ) - def resolve_method(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, node: Node) -> list[Entity]: - res = [] + def resolve_method(self, files: dict[Path, File], 
lsp: SyncLanguageServer, file_path: Path, path: Path, graph, node: Node) -> list[Entity]: if node.type == 'call': node = node.child_by_field_name('function') if node.type == 'attribute': node = node.child_by_field_name('attribute') - for file, resolved_node in self.resolve(files, lsp, file_path, path, node): - method_dec = self.find_parent(resolved_node, ['function_definition', 'class_definition']) - if not method_dec: - continue - if method_dec in file.entities: - res.append(file.entities[method_dec]) - return res + return self.resolve_entities( + files, + lsp, + file_path, + path, + node, + graph, + ['function_definition', 'class_definition'], + ['Function', 'Class'], + ) - def resolve_symbol(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, key: str, symbol: Node) -> list[Entity]: + def resolve_symbol(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, graph, key: str, symbol: Node) -> list[Entity]: if key in ["base_class", "parameters", "return_type"]: - return self.resolve_type(files, lsp, file_path, path, symbol) + return self.resolve_type(files, lsp, file_path, path, graph, symbol) elif key in ["call"]: - return self.resolve_method(files, lsp, file_path, path, symbol) + return self.resolve_method(files, lsp, file_path, path, graph, symbol) else: raise ValueError(f"Unknown key {key}") diff --git a/api/analyzers/source_analyzer.py b/api/analyzers/source_analyzer.py index 4186f358..73e2cc1c 100644 --- a/api/analyzers/source_analyzer.py +++ b/api/analyzers/source_analyzer.py @@ -149,7 +149,7 @@ def second_pass(self, graph: Graph, files: list[Path], path: Path) -> None: file = self.files[file_path] logging.info(f'Processing file ({i + 1}/{files_len}): {file_path}') for _, entity in file.entities.items(): - entity.resolved_symbol(lambda key, symbol, fp=file_path: analyzers[fp.suffix].resolve_symbol(self.files, lsps[fp.suffix], fp, path, key, symbol)) + entity.resolved_symbol(lambda key, symbol, 
fp=file_path: analyzers[fp.suffix].resolve_symbol(self.files, lsps[fp.suffix], fp, path, graph, key, symbol)) for key, symbols in entity.symbols.items(): for symbol in symbols: if len(symbol.resolved_symbol) == 0: @@ -220,4 +220,3 @@ def analyze_local_repository(self, path: str, ignore: Optional[list[str]] = None graph.set_graph_commit(current_commit.short_id) return graph - diff --git a/api/git_utils/incremental_update.py b/api/git_utils/incremental_update.py index 71e56b4b..e6514037 100644 --- a/api/git_utils/incremental_update.py +++ b/api/git_utils/incremental_update.py @@ -6,6 +6,7 @@ correctly after restarts or failures. """ +from contextlib import contextmanager import logging import os import subprocess @@ -17,10 +18,12 @@ from ..analyzers.source_analyzer import SourceAnalyzer from ..graph import Graph -from ..info import set_repo_commit +from ..info import get_redis_connection, set_repo_commit from .git_utils import classify_changes logger = logging.getLogger(__name__) +REPO_UPDATE_LOCK_TIMEOUT = int(os.getenv("REPO_UPDATE_LOCK_TIMEOUT", "300")) +REPO_UPDATE_LOCK_WAIT = int(os.getenv("REPO_UPDATE_LOCK_WAIT", "30")) def repo_local_path(repo_name: str) -> Path: @@ -78,6 +81,94 @@ def get_remote_head(repo_path: Path, branch: str) -> Optional[str]: return None +@contextmanager +def repo_update_lock(repo_name: str): + """Acquire a repo-scoped distributed lock for graph mutations.""" + redis_connection = get_redis_connection() + lock = redis_connection.lock( + f"code-graph:repo-update:{repo_name}", + timeout=REPO_UPDATE_LOCK_TIMEOUT, + blocking_timeout=REPO_UPDATE_LOCK_WAIT, + thread_local=False, + ) + + logger.debug("Acquiring repo update lock for '%s'", repo_name) + if not lock.acquire(blocking=True): + raise TimeoutError(f"Timed out waiting for update lock for '{repo_name}'") + + try: + yield + finally: + if lock.owned(): + lock.release() + logger.debug("Released repo update lock for '%s'", repo_name) + + +def _resolve_commit(repo: Repository, sha: str): + 
return repo.revparse_single(sha) + + +def _is_ancestor(repo: Repository, ancestor_sha: str, descendant_sha: str) -> bool: + ancestor = _resolve_commit(repo, ancestor_sha) + descendant = _resolve_commit(repo, descendant_sha) + return ancestor.id == descendant.id or repo.merge_base(ancestor.id, descendant.id) == ancestor.id + + +def can_incrementally_update( + repo_path: Path, + from_sha: str, + to_sha: str, + before_sha: Optional[str] = None, +) -> bool: + """Return True when the stored bookmark can be safely advanced incrementally.""" + try: + repo = Repository(str(repo_path)) + if before_sha is not None and not _is_ancestor(repo, from_sha, before_sha): + return False + + anchor_sha = before_sha or from_sha + return _is_ancestor(repo, anchor_sha, to_sha) + except Exception as exc: + logger.warning( + "Cannot validate incremental update range for '%s' -> '%s' (before=%s): %s", + from_sha, + to_sha, + before_sha, + exc, + ) + return False + + +def _dedupe_paths(paths: list[Path]) -> list[Path]: + seen: set[Path] = set() + deduped: list[Path] = [] + for path in paths: + if path in seen: + continue + seen.add(path) + deduped.append(path) + return deduped + + +def _collect_transitive_dependents(g: Graph, changed_files: list[Path]) -> list[Path]: + seen = set(changed_files) + dependents: list[Path] = [] + frontier = _dedupe_paths(changed_files) + + while frontier: + direct_dependents = g.get_direct_dependent_files(frontier) + next_frontier: list[Path] = [] + for dependent in direct_dependents: + if dependent in seen: + continue + seen.add(dependent) + dependents.append(dependent) + next_frontier.append(dependent) + frontier = next_frontier + + return dependents + + def incremental_update( repo_name: str, from_sha: str, @@ -162,28 +253,47 @@ def incremental_update( len(deleted), ) - # Checkout target commit so files on disk reflect to_sha - repo.checkout_tree(to_commit.tree, strategy=CheckoutStrategy.FORCE) - repo.set_head_detached(to_commit.id) - - # Apply graph changes 
- g = Graph(repo_name) - - files_to_remove = deleted + modified - if files_to_remove: - logger.info("Removing %d file(s) from graph", len(files_to_remove)) - g.delete_files(files_to_remove) - - files_to_add = added + modified - if files_to_add: - logger.info("Inserting/updating %d file(s) in graph", len(files_to_add)) - analyzer.analyze_files(files_to_add, repo_path, g) - - # Persist the new commit bookmark using the short ID for consistency - # with the rest of the system (build_commit_graph, analyze_sources …) - new_commit_short = to_commit.short_id - set_repo_commit(repo_name, new_commit_short) - logger.info("Graph for '%s' updated to commit %s", repo_name, new_commit_short) + files_to_remove = _dedupe_paths(deleted + modified) + + with repo_update_lock(repo_name): + try: + # Checkout target commit so files on disk reflect to_sha + repo.checkout_tree(to_commit.tree, strategy=CheckoutStrategy.FORCE) + repo.set_head_detached(to_commit.id) + + # Apply graph changes + g = Graph(repo_name) + dependent_files = _collect_transitive_dependents(g, files_to_remove) + + if dependent_files: + logger.info( + "Reprocessing %d dependent file(s) for '%s'", + len(dependent_files), + repo_name, + ) + + if files_to_remove: + logger.info("Removing %d file(s) from graph", len(files_to_remove)) + g.delete_files(files_to_remove) + + deleted_files = set(deleted) + files_to_add = [ + file_path + for file_path in _dedupe_paths(added + modified + dependent_files) + if file_path not in deleted_files + ] + if files_to_add: + logger.info("Inserting/updating %d file(s) in graph", len(files_to_add)) + analyzer.analyze_files(files_to_add, repo_path, g) + + # Persist the new commit bookmark using the short ID for consistency + # with the rest of the system (build_commit_graph, analyze_sources …) + new_commit_short = to_commit.short_id + set_repo_commit(repo_name, new_commit_short) + logger.info("Graph for '%s' updated to commit %s", repo_name, new_commit_short) + except Exception: + 
logger.exception("Incremental update failed for '%s'", repo_name) + raise return { "files_added": len(added), diff --git a/api/graph.py b/api/graph.py index 085dfde1..83f959d0 100644 --- a/api/graph.py +++ b/api/graph.py @@ -1,6 +1,6 @@ import os import time -from .entities import * +from .entities import File, encode_edge, encode_node from typing import Optional from falkordb import FalkorDB, Path, Node, QueryResult from falkordb.asyncio import FalkorDB as AsyncFalkorDB @@ -32,6 +32,20 @@ def get_repos() -> list[str]: graphs = [g for g in graphs if not (g.endswith('_git') or g.endswith('_schema'))] return graphs + +def delete_graph_if_exists(name: str) -> bool: + """Delete *name* when it already exists in FalkorDB.""" + db = FalkorDB(host=os.getenv('FALKORDB_HOST', 'localhost'), + port=os.getenv('FALKORDB_PORT', 6379), + username=os.getenv('FALKORDB_USERNAME', None), + password=os.getenv('FALKORDB_PASSWORD', None)) + + if name not in db.list_graphs(): + return False + + db.select_graph(name).delete() + return True + class Graph(): """ Represents a connection to a graph database using FalkorDB. 
@@ -171,7 +185,7 @@ def _query(self, q: str, params: Optional[dict] = None) -> QueryResult: return result_set - def get_sub_graph(self, l: int) -> dict: + def get_sub_graph(self, limit: int) -> dict: q = """MATCH (src) OPTIONAL MATCH (src)-[e]->(dest) @@ -180,7 +194,7 @@ def get_sub_graph(self, l: int) -> dict: sub_graph = {'nodes': [], 'edges': [] } - result_set = self._query(q, {'limit': l}).result_set + result_set = self._query(q, {'limit': limit}).result_set for row in result_set: src = row[0] e = row[1] @@ -466,6 +480,44 @@ def get_file(self, path: str, name: str, ext: str) -> Optional[File]: return file + def get_entity_at_position(self, path: str, line: int, labels: Optional[list[str]] = None) -> Optional[Node]: + """Return the smallest entity spanning *line* within *path*.""" + label_filter = ":" + ":".join(labels) if labels else "" + q = f"""MATCH (e{label_filter}) + WHERE e.path = $path + AND e.src_start <= $line + AND e.src_end >= $line + RETURN e + ORDER BY (e.src_end - e.src_start) ASC + LIMIT 1""" + + res = self._query(q, {'path': path, 'line': line}).result_set + if len(res) == 0: + return None + + return res[0][0] + + def get_direct_dependent_files(self, files: list[Path]) -> list[Path]: + """Return files that directly depend on entities defined in *files*.""" + if len(files) == 0: + return [] + + q = """UNWIND $files AS file + MATCH (changed_file:File {path: file['path'], name: file['name'], ext: file['ext']}) + MATCH (changed_file)-[:DEFINES*]->(changed_entity) + MATCH (dependent_entity)-[:CALLS|EXTENDS|IMPLEMENTS|RETURNS|PARAMETERS]->(changed_entity) + MATCH (dependent_file:File)-[:DEFINES*]->(dependent_entity) + RETURN DISTINCT dependent_file.path, dependent_file.name, dependent_file.ext""" + + params = { + 'files': [ + {'path': str(file_path), 'name': file_path.name, 'ext': file_path.suffix} + for file_path in files + ] + } + result_set = self._query(q, params).result_set + return [Path(row[0]) for row in result_set] + # set file code coverage 
# if file coverage is 100% set every defined function coverage to 100% aswell def set_file_coverage(self, path: str, name: str, ext: str, coverage: float) -> None: @@ -478,7 +530,7 @@ def set_file_coverage(self, path: str, name: str, ext: str, coverage: float) -> params = {'path': path, 'name': name, 'ext': ext, 'coverage': coverage} - res = self._query(q, params) + self._query(q, params) def connect_entities(self, relation: str, src_id: int, dest_id: int, properties: dict = {}) -> None: """ @@ -768,4 +820,3 @@ async def stats(self) -> dict: async def close(self) -> None: await self.db.aclose() - diff --git a/api/index.py b/api/index.py index dfb03a4e..c566f6d3 100644 --- a/api/index.py +++ b/api/index.py @@ -16,12 +16,14 @@ from api.git_utils import git_utils from api.git_utils.git_graph import AsyncGitGraph from api.git_utils.incremental_update import ( + can_incrementally_update, fetch_remote, get_remote_head, incremental_update, repo_local_path, + repo_update_lock, ) -from api.graph import Graph, AsyncGraphQuery, async_get_repos +from api.graph import Graph, AsyncGraphQuery, async_get_repos, delete_graph_if_exists from api.info import async_get_repo_info, get_repo_commit from api.llm import ask from api.project import Project @@ -148,6 +150,170 @@ async def _find_repo_by_url(url: str) -> str | None: return repo_name return None + +def _webhook_auth_mode() -> str: + if WEBHOOK_SECRET: + return "shared-secret" + if SECRET_TOKEN: + return "token" + return "disabled" + + +def _log_webhook_auth_mode() -> None: + mode = _webhook_auth_mode() + if mode == "shared-secret": + logger.info( + "Webhook auth mode: shared secret (GitHub HMAC or GitLab X-Gitlab-Token)" + ) + elif mode == "token": + logger.info("Webhook auth mode: Authorization bearer token fallback") + else: + logger.warning( + "Webhook auth is not configured; /api/webhook will reject requests until " + "WEBHOOK_SECRET or SECRET_TOKEN is set" + ) + + +def _authenticate_webhook_request(request: Request, body: 
bytes) -> None: + """Authenticate a webhook request using the configured webhook auth mode.""" + if WEBHOOK_SECRET: + github_signature = request.headers.get("X-Hub-Signature-256") + gitlab_token = request.headers.get("X-Gitlab-Token") + gitlab_event = request.headers.get("X-Gitlab-Event") + gitlab_signature = request.headers.get("X-Gitlab-Signature") + + if github_signature: + mac = hmac.new(WEBHOOK_SECRET.encode(), body, hashlib.sha256) + expected_signature = "sha256=" + mac.hexdigest() + if not hmac.compare_digest(github_signature, expected_signature): + raise HTTPException(status_code=401, detail="Invalid GitHub webhook signature") + return + + if gitlab_token or gitlab_event or gitlab_signature: + if not gitlab_token: + raise HTTPException( + status_code=401, + detail="GitLab webhooks must include X-Gitlab-Token", + ) + if not hmac.compare_digest(gitlab_token, WEBHOOK_SECRET): + raise HTTPException(status_code=401, detail="Invalid GitLab webhook token") + return + + raise HTTPException( + status_code=401, + detail="Missing supported webhook authentication header", + ) + + if not SECRET_TOKEN: + logger.error( + "Webhook auth misconfigured: set WEBHOOK_SECRET or SECRET_TOKEN before " + "accepting webhook updates" + ) + raise HTTPException( + status_code=503, + detail="Webhook authentication is not configured", + ) + + token_required(request.headers.get("Authorization")) + + +def _extract_repo_url(payload: dict) -> str: + repository = payload.get("repository", {}) + project = payload.get("project", {}) + return ( + repository.get("clone_url") + or repository.get("git_http_url") + or project.get("git_http_url") + or "" + ) + + +def _full_reindex_repository( + repo_name: str, + repo_path: Path, + repo_url: str = "", + ignore: list[str] | None = None, + reason: str = "", +) -> dict: + if ignore is None: + ignore = [] + + logger.warning( + "Falling back to a full reindex for '%s'%s", + repo_name, + f": {reason}" if reason else "", + ) + + with 
repo_update_lock(repo_name): + delete_graph_if_exists(repo_name) + delete_graph_if_exists(git_utils.GitRepoName(repo_name)) + + if repo_path.exists(): + proj = Project.from_local_repository(repo_path) + elif repo_url: + proj = Project.from_git_repository(repo_url) + else: + raise ValueError( + f"Cannot reindex '{repo_name}': local clone is missing and no repo URL is available" + ) + + proj.analyze_sources(ignore) + proj.process_git_history(ignore) + + return { + "mode": "full_reindex", + "files_added": 0, + "files_modified": 0, + "files_deleted": 0, + "commit": get_repo_commit(repo_name), + } + + +def _sync_repo_graph( + repo_name: str, + repo_path: Path, + target_sha: str, + *, + before_sha: str | None = None, + repo_url: str = "", + ignore: list[str] | None = None, +) -> dict: + if ignore is None: + ignore = [] + + if not repo_path.exists(): + return _full_reindex_repository( + repo_name, + repo_path, + repo_url, + ignore, + "local clone missing", + ) + + stored_sha = get_repo_commit(repo_name) + if not stored_sha: + return _full_reindex_repository( + repo_name, + repo_path, + repo_url, + ignore, + "missing stored commit bookmark", + ) + + if not can_incrementally_update(repo_path, stored_sha, target_sha, before_sha): + return _full_reindex_repository( + repo_name, + repo_path, + repo_url, + ignore, + ( + f"stored bookmark {stored_sha} does not align with " + f"before={before_sha or ''} and target={target_sha}" + ), + ) + + return incremental_update(repo_name, stored_sha, target_sha, ignore) + # --------------------------------------------------------------------------- # Background poll-watcher helpers (synchronous, run in thread-pool executor) # --------------------------------------------------------------------------- @@ -174,29 +340,28 @@ def _poll_repo(repo_name: str) -> None: return current_sha = get_repo_commit(repo_name) - if not current_sha: - logger.debug("Poll: no stored commit for '%s', skipping", repo_name) - return - - # Handle comparison between 
short (7-char) and full (40-char) SHAs: a short - # stored SHA is a valid prefix of a full remote SHA for the same commit. - # We only apply prefix matching when the stored SHA is shorter. - if len(current_sha) < len(remote_head): - up_to_date = remote_head.startswith(current_sha) - elif len(current_sha) > len(remote_head): - up_to_date = current_sha.startswith(remote_head) + if current_sha: + # Handle comparison between short (7-char) and full (40-char) SHAs: a short + # stored SHA is a valid prefix of a full remote SHA for the same commit. + # We only apply prefix matching when the stored SHA is shorter. + if len(current_sha) < len(remote_head): + up_to_date = remote_head.startswith(current_sha) + elif len(current_sha) > len(remote_head): + up_to_date = current_sha.startswith(remote_head) + else: + up_to_date = current_sha == remote_head + if up_to_date: + logger.debug("Poll: '%s' is up-to-date at %s", repo_name, current_sha) + return else: - up_to_date = current_sha == remote_head - if up_to_date: - logger.debug("Poll: '%s' is up-to-date at %s", repo_name, current_sha) - return + logger.warning("Poll: '%s' has no stored bookmark; forcing a full reindex", repo_name) logger.info( "Poll: new commits detected for '%s' (%s -> %s), updating …", repo_name, current_sha, remote_head, ) try: - result = incremental_update(repo_name, current_sha, remote_head) + result = _sync_repo_graph(repo_name, path, remote_head) logger.info("Poll: '%s' updated — %s", repo_name, result) except Exception as exc: logger.exception( @@ -231,6 +396,7 @@ async def _poll_loop() -> None: @contextlib.asynccontextmanager async def _lifespan(application: FastAPI): + _log_webhook_auth_mode() poll_task = None if POLL_INTERVAL > 0: poll_task = asyncio.create_task(_poll_loop()) @@ -437,27 +603,21 @@ async def list_commits(data: RepoRequest, _=Depends(public_or_auth)): async def webhook(request: Request): """Receive a GitHub/GitLab push event and trigger an incremental graph update. 
- When ``WEBHOOK_SECRET`` is set the endpoint validates the - ``X-Hub-Signature-256`` header using HMAC-SHA256; requests with a missing - or invalid signature are rejected with **401 Unauthorized**. + When ``WEBHOOK_SECRET`` is set the endpoint validates GitHub's + ``X-Hub-Signature-256`` HMAC signature or GitLab's ``X-Gitlab-Token``. + Without ``WEBHOOK_SECRET`` the endpoint falls back to the standard bearer + token auth used by the other mutating routes. Only pushes to the branch configured via ``TRACKED_BRANCH`` (default ``main``) trigger an update; pushes to other branches are acknowledged with a ``200 ignored`` response so that GitHub does not retry them. - The repository is identified by matching the ``repository.clone_url`` - field in the payload against the URLs stored for already-indexed - repositories. + The repository is identified by matching the payload's repository URL + (`repository.clone_url`, `repository.git_http_url`, or `project.git_http_url`) + against the URLs stored for already-indexed repositories. 
""" body = await request.body() - - # Validate HMAC-SHA256 signature when a secret is configured - if WEBHOOK_SECRET: - sig_header = request.headers.get("X-Hub-Signature-256", "") - mac = hmac.new(WEBHOOK_SECRET.encode(), body, hashlib.sha256) - expected_sig = "sha256=" + mac.hexdigest() - if not hmac.compare_digest(sig_header, expected_sig): - raise HTTPException(status_code=401, detail="Invalid webhook signature") + _authenticate_webhook_request(request, body) try: payload = await request.json() @@ -467,7 +627,7 @@ async def webhook(request: Request): ref = payload.get("ref", "") before = payload.get("before", "") after = payload.get("after", "") - repo_url = payload.get("repository", {}).get("clone_url", "") + repo_url = _extract_repo_url(payload) # Only process pushes to the configured tracked branch expected_ref = f"refs/heads/{TRACKED_BRANCH}" @@ -478,7 +638,10 @@ async def webhook(request: Request): if not before or not after or not repo_url: raise HTTPException( status_code=400, - detail="Payload missing required fields: ref, before, after, repository.clone_url", + detail=( + "Payload missing required fields: ref, before, after, and a repository URL " + "(repository.clone_url, repository.git_http_url, or project.git_http_url)" + ), ) # Resolve the repository name from the stored index @@ -498,7 +661,13 @@ def _update() -> dict: path = repo_local_path(repo_name) if path.exists(): fetch_remote(path) - return incremental_update(repo_name, before, after) + return _sync_repo_graph( + repo_name, + path, + after, + before_sha=before, + repo_url=repo_url, + ) loop = asyncio.get_running_loop() try: diff --git a/tests/test_incremental_update.py b/tests/test_incremental_update.py new file mode 100644 index 00000000..c0ffe0d9 --- /dev/null +++ b/tests/test_incremental_update.py @@ -0,0 +1,151 @@ +from contextlib import contextmanager +import importlib + +from api.analyzers.python.analyzer import PythonAnalyzer + + +class _DummyLSP: + def __init__(self, locations): + 
self._locations = locations + + def request_definition(self, *_args, **_kwargs): + return self._locations + + +class _DummyGraphNode: + def __init__(self, node_id: int): + self.id = node_id + + +class _DummyGraphLookup: + def __init__(self, node_id: int): + self.node_id = node_id + self.calls = [] + + def get_entity_at_position(self, path, line, labels): + self.calls.append((path, line, labels)) + return _DummyGraphNode(self.node_id) + + +def test_python_resolve_symbol_uses_graph_fallback(tmp_path): + """Cross-file resolution falls back to graph lookups for unchanged files.""" + analyzer = PythonAnalyzer() + caller = tmp_path / "caller.py" + target = tmp_path / "target.py" + caller.write_text("foo()\n") + target.write_text("def foo():\n pass\n") + + tree = analyzer.parser.parse(caller.read_bytes()) + call_node = analyzer._captures("(call) @call", tree.root_node)["call"][0] + graph = _DummyGraphLookup(42) + lsp = _DummyLSP( + [ + { + "absolutePath": str(target), + "range": { + "start": {"line": 0, "character": 0}, + "end": {"line": 1, "character": 0}, + }, + } + ] + ) + + resolved = analyzer.resolve_symbol({}, lsp, caller, tmp_path, graph, "call", call_node) + + assert [entity.id for entity in resolved] == [42] + assert graph.calls == [(str(target), 0, ["Function", "Class"])] + + +def test_incremental_update_reprocesses_dependents_under_repo_lock(monkeypatch, tmp_path): + """Incremental updates expand transitive dependents and hold the repo lock.""" + incremental_update_module = importlib.import_module("api.git_utils.incremental_update") + repo_path = tmp_path / "repo" + repo_path.mkdir() + operations = [] + + class _FakeCommit: + def __init__(self, sha): + self.id = sha + self.short_id = sha[:7] + self.tree = object() + + class _FakeRepo: + def revparse_single(self, sha): + return _FakeCommit(sha) + + def diff(self, _from_commit, _to_commit): + return object() + + def checkout_tree(self, _tree, strategy=None): + operations.append(("checkout", strategy)) + + def 
set_head_detached(self, commit_id): + operations.append(("detach", commit_id)) + + class _FakeAnalyzer: + def supported_types(self): + return [".py"] + + def analyze_files(self, files, path, graph): + operations.append(("analyze", [file.name for file in files], path, graph)) + + class _FakeGraph: + def __init__(self, name): + self.name = name + + def get_direct_dependent_files(self, files): + names = tuple(file.name for file in files) + operations.append(("dependents", names)) + if names == ("deleted.py", "modified.py"): + return [repo_path / "caller.py"] + if names == ("caller.py",): + return [repo_path / "transitive.py"] + return [] + + def delete_files(self, files): + operations.append(("delete", [file.name for file in files])) + + @contextmanager + def _fake_repo_lock(repo_name): + operations.append(("lock-enter", repo_name)) + try: + yield + finally: + operations.append(("lock-exit", repo_name)) + + monkeypatch.setattr(incremental_update_module, "repo_local_path", lambda _name: repo_path) + monkeypatch.setattr(incremental_update_module, "Repository", lambda _path: _FakeRepo()) + monkeypatch.setattr(incremental_update_module, "SourceAnalyzer", _FakeAnalyzer) + monkeypatch.setattr(incremental_update_module, "Graph", _FakeGraph) + monkeypatch.setattr( + incremental_update_module, + "classify_changes", + lambda _diff, _repo, _supported, _ignore: ( + [repo_path / "added.py"], + [repo_path / "deleted.py"], + [repo_path / "modified.py"], + ), + ) + monkeypatch.setattr( + incremental_update_module, + "set_repo_commit", + lambda repo_name, commit: operations.append(("bookmark", repo_name, commit)), + ) + monkeypatch.setattr(incremental_update_module, "repo_update_lock", _fake_repo_lock) + + result = incremental_update_module.incremental_update("repo", "aaaa111", "bbbb222") + + assert result == { + "files_added": 1, + "files_modified": 1, + "files_deleted": 1, + "commit": "bbbb222", + } + assert ("delete", ["deleted.py", "modified.py"]) in operations + analyze_call = 
next(op for op in operations if op[0] == "analyze") + assert analyze_call[1] == ["added.py", "modified.py", "caller.py", "transitive.py"] + assert analyze_call[2] == repo_path + assert operations[0] == ("lock-enter", "repo") + assert operations[-1] == ("lock-exit", "repo") + assert operations.index(("lock-enter", "repo")) < operations.index(("checkout", incremental_update_module.CheckoutStrategy.FORCE)) + assert operations.index(("bookmark", "repo", "bbbb222")) < operations.index(("lock-exit", "repo")) diff --git a/tests/test_webhook.py b/tests/test_webhook.py index c5fdc851..7ab0092e 100644 --- a/tests/test_webhook.py +++ b/tests/test_webhook.py @@ -93,28 +93,42 @@ def test_urls_no_match_different_repo(): # --------------------------------------------------------------------------- -# Webhook endpoint – no secret configured (open mode) +# Webhook endpoint – bearer token fallback mode # --------------------------------------------------------------------------- @pytest.fixture() -def client_open(monkeypatch): - """Test client with no webhook secret and no poll-watcher.""" +def client_token_auth(monkeypatch): + """Test client with bearer-token webhook auth and no poll-watcher.""" monkeypatch.setattr(api.index, "WEBHOOK_SECRET", "") + monkeypatch.setattr(api.index, "SECRET_TOKEN", "apitoken") monkeypatch.setattr(api.index, "POLL_INTERVAL", 0) return TestClient(api.index.app, raise_server_exceptions=False) -def test_webhook_ignored_wrong_branch(client_open, monkeypatch): +@pytest.fixture() +def client_misconfigured(monkeypatch): + """Test client with webhook auth disabled entirely.""" + monkeypatch.setattr(api.index, "WEBHOOK_SECRET", "") + monkeypatch.setattr(api.index, "SECRET_TOKEN", None) + monkeypatch.setattr(api.index, "POLL_INTERVAL", 0) + return TestClient(api.index.app, raise_server_exceptions=False) + + +def test_webhook_ignored_wrong_branch(client_token_auth, monkeypatch): """Pushes to non-tracked branches return 200 with status='ignored'.""" 
monkeypatch.setattr(api.index, "TRACKED_BRANCH", "main") payload = _make_push_payload(ref="refs/heads/feature/x") - resp = client_open.post("/api/webhook", json=payload) + resp = client_token_auth.post( + "/api/webhook", + json=payload, + headers={"Authorization": "Bearer apitoken"}, + ) assert resp.status_code == 200 data = resp.json() assert data["status"] == "ignored" -def test_webhook_unknown_repo(client_open, monkeypatch): +def test_webhook_unknown_repo(client_token_auth, monkeypatch): """Webhook for a repo URL that is not indexed returns 404.""" monkeypatch.setattr(api.index, "TRACKED_BRANCH", "main") @@ -125,11 +139,15 @@ async def _fake_get_repos(): monkeypatch.setattr(api.index, "async_get_repos", _fake_get_repos) payload = _make_push_payload() - resp = client_open.post("/api/webhook", json=payload) + resp = client_token_auth.post( + "/api/webhook", + json=payload, + headers={"Authorization": "Bearer apitoken"}, + ) assert resp.status_code == 404 -def test_webhook_success(client_open, monkeypatch): +def test_webhook_success(client_token_auth, monkeypatch): """Valid push to tracked branch triggers incremental_update and returns stats.""" monkeypatch.setattr(api.index, "TRACKED_BRANCH", "main") @@ -141,8 +159,8 @@ async def _fake_get_repo_info(repo_name): update_calls = [] - def _fake_update(repo_name, from_sha, to_sha, ignore=None): - update_calls.append((repo_name, from_sha, to_sha)) + def _fake_sync(repo_name, path, to_sha, before_sha=None, repo_url="", ignore=None): + update_calls.append((repo_name, before_sha, to_sha, repo_url)) return { "files_added": 1, "files_modified": 0, @@ -152,20 +170,45 @@ def _fake_update(repo_name, from_sha, to_sha, ignore=None): monkeypatch.setattr(api.index, "async_get_repos", _fake_get_repos) monkeypatch.setattr(api.index, "async_get_repo_info", _fake_get_repo_info) - monkeypatch.setattr(api.index, "incremental_update", _fake_update) + monkeypatch.setattr(api.index, "_sync_repo_graph", _fake_sync) # Skip git fetch (no real 
clone) monkeypatch.setattr(api.index, "fetch_remote", lambda path: None) monkeypatch.setattr(api.index, "repo_local_path", lambda name: _FakePath(exists=False)) payload = _make_push_payload() - resp = client_open.post("/api/webhook", json=payload) + resp = client_token_auth.post( + "/api/webhook", + json=payload, + headers={"Authorization": "Bearer apitoken"}, + ) assert resp.status_code == 200 data = resp.json() assert data["status"] == "success" assert data["files_added"] == 1 assert len(update_calls) == 1 - assert update_calls[0] == ("myrepo", _FULL_SHA_BEFORE, _FULL_SHA_AFTER) + assert update_calls[0] == ( + "myrepo", + _FULL_SHA_BEFORE, + _FULL_SHA_AFTER, + "https://github.com/example/myrepo.git", + ) + + +def test_webhook_requires_bearer_token_when_secret_missing(client_token_auth, monkeypatch): + """Bearer token auth protects the webhook when WEBHOOK_SECRET is unset.""" + monkeypatch.setattr(api.index, "TRACKED_BRANCH", "main") + payload = _make_push_payload() + resp = client_token_auth.post("/api/webhook", json=payload) + assert resp.status_code == 401 + + +def test_webhook_rejected_when_no_auth_is_configured(client_misconfigured, monkeypatch): + """The webhook returns 503 when neither WEBHOOK_SECRET nor SECRET_TOKEN is set.""" + monkeypatch.setattr(api.index, "TRACKED_BRANCH", "main") + payload = _make_push_payload() + resp = client_misconfigured.post("/api/webhook", json=payload) + assert resp.status_code == 503 # --------------------------------------------------------------------------- @@ -215,7 +258,7 @@ async def _fake_get_repo_info(repo_name): monkeypatch.setattr(api.index, "async_get_repos", _fake_get_repos) monkeypatch.setattr(api.index, "async_get_repo_info", _fake_get_repo_info) - monkeypatch.setattr(api.index, "incremental_update", lambda *a, **kw: { + monkeypatch.setattr(api.index, "_sync_repo_graph", lambda *a, **kw: { "files_added": 0, "files_modified": 0, "files_deleted": 0, "commit": "abc1234", }) monkeypatch.setattr(api.index, 
"fetch_remote", lambda path: None) @@ -234,17 +277,136 @@ async def _fake_get_repo_info(repo_name): assert resp.json()["status"] == "success" -def test_webhook_invalid_json(client_open, monkeypatch): +def test_gitlab_webhook_token_accepted(client_secured, monkeypatch): + """GitLab webhooks authenticate via X-Gitlab-Token and git_http_url payloads.""" + monkeypatch.setattr(api.index, "TRACKED_BRANCH", "main") + + async def _fake_get_repos(): + return ["myrepo"] + + async def _fake_get_repo_info(repo_name): + return {"repo_url": "https://gitlab.com/example/myrepo.git"} + + monkeypatch.setattr(api.index, "async_get_repos", _fake_get_repos) + monkeypatch.setattr(api.index, "async_get_repo_info", _fake_get_repo_info) + monkeypatch.setattr(api.index, "_sync_repo_graph", lambda *a, **kw: { + "files_added": 0, "files_modified": 0, "files_deleted": 0, "commit": "abc1234", + }) + monkeypatch.setattr(api.index, "fetch_remote", lambda path: None) + monkeypatch.setattr(api.index, "repo_local_path", lambda name: _FakePath(exists=False)) + + payload = { + "ref": "refs/heads/main", + "before": _FULL_SHA_BEFORE, + "after": _FULL_SHA_AFTER, + "repository": {"git_http_url": "https://gitlab.com/example/myrepo.git"}, + } + resp = client_secured.post( + "/api/webhook", + json=payload, + headers={"X-Gitlab-Token": "mysecret", "X-Gitlab-Event": "Push Hook"}, + ) + assert resp.status_code == 200 + assert resp.json()["status"] == "success" + + +def test_gitlab_webhook_missing_token_rejected(client_secured, monkeypatch): + """GitLab requests without X-Gitlab-Token are rejected.""" + monkeypatch.setattr(api.index, "TRACKED_BRANCH", "main") + payload = _make_push_payload() + resp = client_secured.post( + "/api/webhook", + json=payload, + headers={"X-Gitlab-Event": "Push Hook"}, + ) + assert resp.status_code == 401 + + +def test_webhook_invalid_json(client_token_auth, monkeypatch): """Non-JSON bodies are rejected with 400.""" monkeypatch.setattr(api.index, "TRACKED_BRANCH", "main") - resp = 
client_open.post( + resp = client_token_auth.post( "/api/webhook", content=b"not-json", - headers={"Content-Type": "application/json"}, + headers={ + "Content-Type": "application/json", + "Authorization": "Bearer apitoken", + }, ) assert resp.status_code == 400 +def test_sync_repo_graph_uses_stored_bookmark(monkeypatch, tmp_path): + """Incremental sync uses the stored bookmark instead of payload.before.""" + repo_path = tmp_path / "repo" + repo_path.mkdir() + + calls = [] + monkeypatch.setattr(api.index, "get_repo_commit", lambda name: "stored123") + monkeypatch.setattr(api.index, "can_incrementally_update", lambda *args, **kwargs: True) + monkeypatch.setattr( + api.index, + "incremental_update", + lambda repo_name, from_sha, to_sha, ignore=None: calls.append( + (repo_name, from_sha, to_sha, ignore) + ) or { + "files_added": 0, + "files_modified": 0, + "files_deleted": 0, + "commit": to_sha[:7], + }, + ) + + api.index._sync_repo_graph( + "myrepo", + repo_path, + _FULL_SHA_AFTER, + before_sha=_FULL_SHA_BEFORE, + ) + + assert calls == [("myrepo", "stored123", _FULL_SHA_AFTER, [])] + + +def test_sync_repo_graph_full_reindexes_without_bookmark(monkeypatch, tmp_path): + """Missing bookmarks fall back to a full reindex instead of partial diffing.""" + repo_path = tmp_path / "repo" + repo_path.mkdir() + + monkeypatch.setattr(api.index, "get_repo_commit", lambda name: None) + monkeypatch.setattr( + api.index, + "_full_reindex_repository", + lambda *args, **kwargs: {"mode": "full_reindex", "commit": "abc1234"}, + ) + + result = api.index._sync_repo_graph("myrepo", repo_path, _FULL_SHA_AFTER) + + assert result["mode"] == "full_reindex" + + +def test_sync_repo_graph_full_reindexes_on_history_gap(monkeypatch, tmp_path): + """History gaps or force-pushes fall back to a full reindex.""" + repo_path = tmp_path / "repo" + repo_path.mkdir() + + monkeypatch.setattr(api.index, "get_repo_commit", lambda name: "stored123") + monkeypatch.setattr(api.index, "can_incrementally_update", 
lambda *args, **kwargs: False) + monkeypatch.setattr( + api.index, + "_full_reindex_repository", + lambda *args, **kwargs: {"mode": "full_reindex", "commit": "abc1234"}, + ) + + result = api.index._sync_repo_graph( + "myrepo", + repo_path, + _FULL_SHA_AFTER, + before_sha=_FULL_SHA_BEFORE, + ) + + assert result["mode"] == "full_reindex" + + # --------------------------------------------------------------------------- # incremental_update – unit tests (no live DB/git) # --------------------------------------------------------------------------- From 29c2b4d1cc1c5b9bbad31773cc1e5ca6f51ea3dc Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Wed, 18 Mar 2026 09:44:35 +0200 Subject: [PATCH 4/8] fix(api): address PR review comments for webhook/incremental update - Sanitize error response in webhook endpoint to avoid exposing internal exception details (information exposure via str(exc)) - Normalize SHA format in incremental_update no-op response to use consistent short SHA (to_sha[:7]) matching the non-noop path - Add comprehensive poll-watcher unit tests covering: missing clone, fetch failure, up-to-date skip, behind trigger, no remote head, and missing bookmark scenarios Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- api/git_utils/incremental_update.py | 2 +- api/index.py | 5 +- tests/test_webhook.py | 137 +++++++++++++++++++++++++++- 3 files changed, 141 insertions(+), 3 deletions(-) diff --git a/api/git_utils/incremental_update.py b/api/git_utils/incremental_update.py index e6514037..87aeacce 100644 --- a/api/git_utils/incremental_update.py +++ b/api/git_utils/incremental_update.py @@ -216,7 +216,7 @@ def incremental_update( "files_added": 0, "files_modified": 0, "files_deleted": 0, - "commit": to_sha, + "commit": to_sha[:7], } repo_path = repo_local_path(repo_name) diff --git a/api/index.py b/api/index.py index c566f6d3..a171ebfd 100644 --- a/api/index.py +++ b/api/index.py @@ -676,7 +676,10 @@ def _update() -> dict: logger.exception( 
"Webhook: incremental update failed for '%s': %s", repo_name, exc ) - return JSONResponse({"status": "error", "detail": str(exc)}, status_code=500) + return JSONResponse( + {"status": "error", "detail": "Incremental update failed"}, + status_code=500, + ) return {"status": "success", **result} diff --git a/tests/test_webhook.py b/tests/test_webhook.py index 7ab0092e..004aa8a1 100644 --- a/tests/test_webhook.py +++ b/tests/test_webhook.py @@ -430,7 +430,7 @@ def test_incremental_update_idempotent(monkeypatch, tmp_path): assert result["files_added"] == 0 assert result["files_modified"] == 0 assert result["files_deleted"] == 0 - assert result["commit"] == sha + assert result["commit"] == sha[:7] assert writes == [], "set_repo_commit must not be called for no-op update" @@ -447,3 +447,138 @@ def test_incremental_update_missing_repo(monkeypatch, tmp_path): with pytest.raises(ValueError, match="Local repository not found"): _iu("some-repo", "aaa1111", "bbb2222") + + +# --------------------------------------------------------------------------- +# Poll-watcher – unit tests +# --------------------------------------------------------------------------- + +def test_poll_repo_skips_missing_clone(monkeypatch, tmp_path): + """_poll_repo returns early when the local clone does not exist.""" + monkeypatch.setattr( + api.index, "repo_local_path", lambda name: tmp_path / "nonexistent" + ) + fetch_calls = [] + monkeypatch.setattr(api.index, "fetch_remote", lambda p: fetch_calls.append(p)) + + api.index._poll_repo("myrepo") + + assert fetch_calls == [], "fetch_remote should not be called for missing clones" + + +def test_poll_repo_handles_fetch_failure(monkeypatch, tmp_path): + """_poll_repo logs a warning and returns when git fetch fails.""" + repo_path = tmp_path / "repo" + repo_path.mkdir() + monkeypatch.setattr(api.index, "repo_local_path", lambda name: repo_path) + monkeypatch.setattr( + api.index, + "fetch_remote", + lambda p: (_ for _ in ()).throw(RuntimeError("network 
error")), + ) + sync_calls = [] + monkeypatch.setattr( + api.index, + "_sync_repo_graph", + lambda *a, **kw: sync_calls.append(1), + ) + + api.index._poll_repo("myrepo") + + assert sync_calls == [], "sync should not be called when fetch fails" + + +def test_poll_repo_skips_when_up_to_date(monkeypatch, tmp_path): + """_poll_repo does nothing when stored bookmark matches remote HEAD.""" + repo_path = tmp_path / "repo" + repo_path.mkdir() + monkeypatch.setattr(api.index, "repo_local_path", lambda name: repo_path) + monkeypatch.setattr(api.index, "fetch_remote", lambda p: None) + monkeypatch.setattr(api.index, "TRACKED_BRANCH", "main") + monkeypatch.setattr( + api.index, "get_remote_head", lambda p, b: "abcdef1234567890" * 2 + "abcdef12" + ) + # Stored bookmark is a short SHA prefix of the remote HEAD + monkeypatch.setattr(api.index, "get_repo_commit", lambda name: "abcdef1") + + sync_calls = [] + monkeypatch.setattr( + api.index, + "_sync_repo_graph", + lambda *a, **kw: sync_calls.append(1), + ) + + api.index._poll_repo("myrepo") + + assert sync_calls == [], "sync should not be called when repo is up-to-date" + + +def test_poll_repo_triggers_sync_when_behind(monkeypatch, tmp_path): + """_poll_repo calls _sync_repo_graph when remote HEAD has advanced.""" + repo_path = tmp_path / "repo" + repo_path.mkdir() + monkeypatch.setattr(api.index, "repo_local_path", lambda name: repo_path) + monkeypatch.setattr(api.index, "fetch_remote", lambda p: None) + monkeypatch.setattr(api.index, "TRACKED_BRANCH", "main") + remote_sha = "bbbb2222" * 5 + monkeypatch.setattr(api.index, "get_remote_head", lambda p, b: remote_sha) + monkeypatch.setattr(api.index, "get_repo_commit", lambda name: "aaaa111") + + sync_calls = [] + + def _fake_sync(repo_name, path, target_sha, **kwargs): + sync_calls.append((repo_name, target_sha)) + return {"files_added": 1, "files_modified": 0, "files_deleted": 0, "commit": "bbbb222"} + + monkeypatch.setattr(api.index, "_sync_repo_graph", _fake_sync) + + 
api.index._poll_repo("myrepo") + + assert len(sync_calls) == 1 + assert sync_calls[0] == ("myrepo", remote_sha) + + +def test_poll_repo_handles_no_remote_head(monkeypatch, tmp_path): + """_poll_repo returns early when the remote branch HEAD cannot be resolved.""" + repo_path = tmp_path / "repo" + repo_path.mkdir() + monkeypatch.setattr(api.index, "repo_local_path", lambda name: repo_path) + monkeypatch.setattr(api.index, "fetch_remote", lambda p: None) + monkeypatch.setattr(api.index, "TRACKED_BRANCH", "main") + monkeypatch.setattr(api.index, "get_remote_head", lambda p, b: None) + + sync_calls = [] + monkeypatch.setattr( + api.index, + "_sync_repo_graph", + lambda *a, **kw: sync_calls.append(1), + ) + + api.index._poll_repo("myrepo") + + assert sync_calls == [], "sync should not be called when remote HEAD is unknown" + + +def test_poll_repo_forces_reindex_without_bookmark(monkeypatch, tmp_path): + """_poll_repo triggers sync even without a stored bookmark (full reindex path).""" + repo_path = tmp_path / "repo" + repo_path.mkdir() + monkeypatch.setattr(api.index, "repo_local_path", lambda name: repo_path) + monkeypatch.setattr(api.index, "fetch_remote", lambda p: None) + monkeypatch.setattr(api.index, "TRACKED_BRANCH", "main") + remote_sha = "cccc3333" * 5 + monkeypatch.setattr(api.index, "get_remote_head", lambda p, b: remote_sha) + monkeypatch.setattr(api.index, "get_repo_commit", lambda name: None) + + sync_calls = [] + + def _fake_sync(repo_name, path, target_sha, **kwargs): + sync_calls.append((repo_name, target_sha)) + return {"mode": "full_reindex", "commit": "cccc333"} + + monkeypatch.setattr(api.index, "_sync_repo_graph", _fake_sync) + + api.index._poll_repo("myrepo") + + assert len(sync_calls) == 1 + assert sync_calls[0] == ("myrepo", remote_sha) From 4f3a97add5ebb05c7f114a6069734c4af06b735b Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Wed, 18 Mar 2026 09:55:51 +0200 Subject: [PATCH 5/8] fix(api): checkout target commit before full reindex to avoid 
stale analysis When _full_reindex_repository is triggered (e.g., after a force-push or missing bookmark), the working tree may still be at the old commit. The reindex would analyze stale files and set the bookmark to the old commit, creating an infinite retry loop. Now accepts target_sha and checks out the target commit before analysis when a local clone exists. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- api/index.py | 11 +++++++++++ tests/test_webhook.py | 14 ++++++++++++-- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/api/index.py b/api/index.py index a171ebfd..f6cbcc92 100644 --- a/api/index.py +++ b/api/index.py @@ -234,6 +234,7 @@ def _full_reindex_repository( repo_url: str = "", ignore: list[str] | None = None, reason: str = "", + target_sha: str | None = None, ) -> dict: if ignore is None: ignore = [] @@ -249,6 +250,14 @@ def _full_reindex_repository( delete_graph_if_exists(git_utils.GitRepoName(repo_name)) if repo_path.exists(): + if target_sha: + from pygit2.enums import CheckoutStrategy + from pygit2.repository import Repository + repo = Repository(str(repo_path)) + target_commit = repo.revparse_single(target_sha) + repo.checkout_tree(target_commit.tree, strategy=CheckoutStrategy.FORCE) + repo.set_head_detached(target_commit.id) + logger.info("Checked out target commit %s before full reindex", target_sha[:8]) proj = Project.from_local_repository(repo_path) elif repo_url: proj = Project.from_git_repository(repo_url) @@ -298,6 +307,7 @@ def _sync_repo_graph( repo_url, ignore, "missing stored commit bookmark", + target_sha=target_sha, ) if not can_incrementally_update(repo_path, stored_sha, target_sha, before_sha): @@ -310,6 +320,7 @@ def _sync_repo_graph( f"stored bookmark {stored_sha} does not align with " f"before={before_sha or ''} and target={target_sha}" ), + target_sha=target_sha, ) return incremental_update(repo_name, stored_sha, target_sha, ignore) diff --git a/tests/test_webhook.py 
b/tests/test_webhook.py index 004aa8a1..b862e906 100644 --- a/tests/test_webhook.py +++ b/tests/test_webhook.py @@ -372,16 +372,21 @@ def test_sync_repo_graph_full_reindexes_without_bookmark(monkeypatch, tmp_path): repo_path = tmp_path / "repo" repo_path.mkdir() + reindex_calls = [] + monkeypatch.setattr(api.index, "get_repo_commit", lambda name: None) monkeypatch.setattr( api.index, "_full_reindex_repository", - lambda *args, **kwargs: {"mode": "full_reindex", "commit": "abc1234"}, + lambda *args, **kwargs: reindex_calls.append(kwargs) or { + "mode": "full_reindex", "commit": "abc1234", + }, ) result = api.index._sync_repo_graph("myrepo", repo_path, _FULL_SHA_AFTER) assert result["mode"] == "full_reindex" + assert reindex_calls[0].get("target_sha") == _FULL_SHA_AFTER def test_sync_repo_graph_full_reindexes_on_history_gap(monkeypatch, tmp_path): @@ -389,12 +394,16 @@ def test_sync_repo_graph_full_reindexes_on_history_gap(monkeypatch, tmp_path): repo_path = tmp_path / "repo" repo_path.mkdir() + reindex_calls = [] + monkeypatch.setattr(api.index, "get_repo_commit", lambda name: "stored123") monkeypatch.setattr(api.index, "can_incrementally_update", lambda *args, **kwargs: False) monkeypatch.setattr( api.index, "_full_reindex_repository", - lambda *args, **kwargs: {"mode": "full_reindex", "commit": "abc1234"}, + lambda *args, **kwargs: reindex_calls.append(kwargs) or { + "mode": "full_reindex", "commit": "abc1234", + }, ) result = api.index._sync_repo_graph( @@ -405,6 +414,7 @@ def test_sync_repo_graph_full_reindexes_on_history_gap(monkeypatch, tmp_path): ) assert result["mode"] == "full_reindex" + assert reindex_calls[0].get("target_sha") == _FULL_SHA_AFTER # --------------------------------------------------------------------------- From 81dc67f687c03c3d9fe84e26d5948e9756e722f8 Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Wed, 18 Mar 2026 10:41:31 +0200 Subject: [PATCH 6/8] docs(readme): add webhook setup instructions for GitHub and GitLab Co-authored-by: 
Copilot <223556219+Copilot@users.noreply.github.com> --- README.md | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/README.md b/README.md index 48dbcaa6..c2bf2ad0 100644 --- a/README.md +++ b/README.md @@ -102,6 +102,29 @@ The chat endpoint also needs the provider credential expected by your chosen `MO Continuous graph updates can be triggered either by posting a GitHub/GitLab push payload to `/api/webhook` or by enabling the background poll-watcher with `POLL_INTERVAL > 0`. When `WEBHOOK_SECRET` is unset, `/api/webhook` falls back to the same bearer-token auth used by the other mutating endpoints. +#### Setting up a webhook + +After indexing a repository with `/api/analyze_repo`, you can register a webhook so the graph stays in sync automatically. + +**GitHub:** + +1. Go to your repository → **Settings** → **Webhooks** → **Add webhook**. +2. Set **Payload URL** to `https://<your-server>/api/webhook`. +3. Set **Content type** to `application/json`. +4. Set **Secret** to the same value as your `WEBHOOK_SECRET` environment variable. +5. Under **Which events?**, select **Just the push event**. +6. Click **Add webhook**. + +**GitLab:** + +1. Go to your project → **Settings** → **Webhooks** → **Add new webhook**. +2. Set **URL** to `https://<your-server>/api/webhook`. +3. Set **Secret token** to the same value as your `WEBHOOK_SECRET` environment variable. +4. Check **Push events** as the trigger. +5. Click **Add webhook**. + +> **Tip:** If you cannot configure a webhook (e.g. you don't have admin access), enable the background poll-watcher instead by setting `POLL_INTERVAL` to a non-zero value (in seconds). It will periodically check the remote for new commits on `TRACKED_BRANCH`. + ### 3. 
Install dependencies ```bash From b99ff2eed20b9e832ca6dcad4f3c447c64f484f4 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 18 Mar 2026 11:13:00 +0000 Subject: [PATCH 7/8] fix(review): address PR review feedback - api/graph.py: Add _VALID_ENTITY_LABELS allowlist to get_entity_at_position to prevent Cypher injection via f-string label interpolation - api/graph.py: Extract _make_falkordb_connection() helper; delete_graph_if_exists now accepts optional db param so callers can share a single connection - api/index.py: Move pygit2 imports (CheckoutStrategy, Repository) to module-level; add subprocess and re to module-level imports - api/index.py: Replace fragile prefix-based SHA comparison with git rev-parse to resolve stored short SHA to full 40-char form for unambiguous comparison; validate SHA hex format before passing to git - api/index.py: Use raise ... from exc for webhook JSON parse error - api/analyzers/{python,java,csharp}/analyzer.py: Fix graph param type to Graph and return types to list[Entity | ResolvedEntityRef] in resolve_type, resolve_method, resolve_symbol; add ResolvedEntityRef and Graph imports - tests/test_webhook.py: Update poll-watcher tests to mock subprocess.run; add command assertions in mock callbacks Co-authored-by: gkorland <753206+gkorland@users.noreply.github.com> --- api/analyzers/csharp/analyzer.py | 9 +++--- api/analyzers/java/analyzer.py | 11 +++---- api/analyzers/python/analyzer.py | 11 +++---- api/graph.py | 50 ++++++++++++++++++++++---------- api/index.py | 49 +++++++++++++++++++------------ tests/test_webhook.py | 27 ++++++++++++++--- 6 files changed, 105 insertions(+), 52 deletions(-) diff --git a/api/analyzers/csharp/analyzer.py b/api/analyzers/csharp/analyzer.py index 61e8c8a6..317888b5 100644 --- a/api/analyzers/csharp/analyzer.py +++ b/api/analyzers/csharp/analyzer.py @@ -5,7 +5,8 @@ from ...entities.entity import Entity from ...entities.file import File from typing 
import Optional -from ..analyzer import AbstractAnalyzer +from ..analyzer import AbstractAnalyzer, ResolvedEntityRef +from ...graph import Graph import tree_sitter_c_sharp as tscsharp from tree_sitter import Language, Node @@ -105,7 +106,7 @@ def is_dependency(self, file_path: str) -> bool: def resolve_path(self, file_path: str, path: Path) -> str: return file_path - def resolve_type(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, graph, node: Node) -> list[Entity]: + def resolve_type(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, graph: Graph, node: Node) -> list[Entity | ResolvedEntityRef]: return self.resolve_entities( files, lsp, @@ -117,7 +118,7 @@ def resolve_type(self, files: dict[Path, File], lsp: SyncLanguageServer, file_pa ['Class', 'Interface', 'Enum', 'Struct'], ) - def resolve_method(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, graph, node: Node) -> list[Entity]: + def resolve_method(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, graph: Graph, node: Node) -> list[Entity | ResolvedEntityRef]: if node.type == 'invocation_expression': func_node = node.child_by_field_name('function') if func_node and func_node.type == 'member_access_expression': @@ -136,7 +137,7 @@ def resolve_method(self, files: dict[Path, File], lsp: SyncLanguageServer, file_ {'class_declaration', 'interface_declaration', 'enum_declaration', 'struct_declaration'}, ) - def resolve_symbol(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, graph, key: str, symbol: Node) -> list[Entity]: + def resolve_symbol(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, graph: Graph, key: str, symbol: Node) -> list[Entity | ResolvedEntityRef]: if key in ["implement_interface", "base_class", "extend_interface", "parameters", "return_type"]: return self.resolve_type(files, lsp, file_path, 
path, graph, symbol) elif key in ["call"]: diff --git a/api/analyzers/java/analyzer.py b/api/analyzers/java/analyzer.py index 1ce80f80..55dc38e2 100644 --- a/api/analyzers/java/analyzer.py +++ b/api/analyzers/java/analyzer.py @@ -4,7 +4,8 @@ from ...entities.entity import Entity from ...entities.file import File from typing import Optional -from ..analyzer import AbstractAnalyzer +from ..analyzer import AbstractAnalyzer, ResolvedEntityRef +from ...graph import Graph from multilspy import SyncLanguageServer @@ -103,7 +104,7 @@ def resolve_path(self, file_path: str, path: Path) -> str: return f"{path}/temp_deps/{args[1]}/{targs}/{args[-1]}" return file_path - def resolve_type(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, graph, node: Node) -> list[Entity]: + def resolve_type(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, graph: Graph, node: Node) -> list[Entity | ResolvedEntityRef]: return self.resolve_entities( files, lsp, @@ -115,7 +116,7 @@ def resolve_type(self, files: dict[Path, File], lsp: SyncLanguageServer, file_pa ['Class', 'Interface', 'Enum'], ) - def resolve_method(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, graph, node: Node) -> list[Entity]: + def resolve_method(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, graph: Graph, node: Node) -> list[Entity | ResolvedEntityRef]: return self.resolve_entities( files, lsp, @@ -127,8 +128,8 @@ def resolve_method(self, files: dict[Path, File], lsp: SyncLanguageServer, file_ ['Method', 'Constructor'], {'class_declaration', 'interface_declaration', 'enum_declaration'}, ) - - def resolve_symbol(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, graph, key: str, symbol: Node) -> list[Entity]: + + def resolve_symbol(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, graph: Graph, key: str, symbol: 
Node) -> list[Entity | ResolvedEntityRef]: if key in ["implement_interface", "base_class", "extend_interface", "parameters", "return_type"]: return self.resolve_type(files, lsp, file_path, path, graph, symbol) elif key in ["call"]: diff --git a/api/analyzers/python/analyzer.py b/api/analyzers/python/analyzer.py index a63d0b4b..200221df 100644 --- a/api/analyzers/python/analyzer.py +++ b/api/analyzers/python/analyzer.py @@ -7,7 +7,8 @@ from ...entities.entity import Entity from ...entities.file import File from typing import Optional -from ..analyzer import AbstractAnalyzer +from ..analyzer import AbstractAnalyzer, ResolvedEntityRef +from ...graph import Graph import tree_sitter_python as tspython from tree_sitter import Language, Node @@ -92,7 +93,7 @@ def is_dependency(self, file_path: str) -> bool: def resolve_path(self, file_path: str, path: Path) -> str: return file_path - def resolve_type(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path, graph, node: Node) -> list[Entity]: + def resolve_type(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, graph: Graph, node: Node) -> list[Entity | ResolvedEntityRef]: if node.type == 'attribute': node = node.child_by_field_name('attribute') return self.resolve_entities( @@ -106,7 +107,7 @@ def resolve_type(self, files: dict[Path, File], lsp: SyncLanguageServer, file_pa ['Class'], ) - def resolve_method(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, graph, node: Node) -> list[Entity]: + def resolve_method(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, graph: Graph, node: Node) -> list[Entity | ResolvedEntityRef]: if node.type == 'call': node = node.child_by_field_name('function') if node.type == 'attribute': @@ -121,8 +122,8 @@ def resolve_method(self, files: dict[Path, File], lsp: SyncLanguageServer, file_ ['function_definition', 'class_definition'], ['Function', 'Class'], ) - - def 
resolve_symbol(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, graph, key: str, symbol: Node) -> list[Entity]: + + def resolve_symbol(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, graph: Graph, key: str, symbol: Node) -> list[Entity | ResolvedEntityRef]: if key in ["base_class", "parameters", "return_type"]: return self.resolve_type(files, lsp, file_path, path, graph, symbol) elif key in ["call"]: diff --git a/api/graph.py b/api/graph.py index 83f959d0..e25e1401 100644 --- a/api/graph.py +++ b/api/graph.py @@ -10,12 +10,18 @@ logging.basicConfig(level=logging.DEBUG, format='%(filename)s - %(asctime)s - %(levelname)s - %(message)s') -def graph_exists(name: str): - db = FalkorDB(host=os.getenv('FALKORDB_HOST', 'localhost'), - port=os.getenv('FALKORDB_PORT', 6379), - username=os.getenv('FALKORDB_USERNAME', None), - password=os.getenv('FALKORDB_PASSWORD', None)) +def _make_falkordb_connection() -> FalkorDB: + """Create a FalkorDB connection using the standard environment variables.""" + return FalkorDB( + host=os.getenv('FALKORDB_HOST', 'localhost'), + port=os.getenv('FALKORDB_PORT', 6379), + username=os.getenv('FALKORDB_USERNAME', None), + password=os.getenv('FALKORDB_PASSWORD', None), + ) + +def graph_exists(name: str): + db = _make_falkordb_connection() return name in db.list_graphs() def get_repos() -> list[str]: @@ -23,22 +29,22 @@ def get_repos() -> list[str]: List processed repositories """ - db = FalkorDB(host=os.getenv('FALKORDB_HOST', 'localhost'), - port=os.getenv('FALKORDB_PORT', 6379), - username=os.getenv('FALKORDB_USERNAME', None), - password=os.getenv('FALKORDB_PASSWORD', None)) - + db = _make_falkordb_connection() graphs = db.list_graphs() graphs = [g for g in graphs if not (g.endswith('_git') or g.endswith('_schema'))] return graphs -def delete_graph_if_exists(name: str) -> bool: - """Delete *name* when it already exists in FalkorDB.""" - db = 
FalkorDB(host=os.getenv('FALKORDB_HOST', 'localhost'), - port=os.getenv('FALKORDB_PORT', 6379), - username=os.getenv('FALKORDB_USERNAME', None), - password=os.getenv('FALKORDB_PASSWORD', None)) +def delete_graph_if_exists(name: str, db: Optional[FalkorDB] = None) -> bool: + """Delete *name* when it already exists in FalkorDB. + + Args: + name: The graph name to delete. + db: Optional existing FalkorDB connection to reuse. When omitted a + new connection is created from environment variables. + """ + if db is None: + db = _make_falkordb_connection() if name not in db.list_graphs(): return False @@ -480,8 +486,20 @@ def get_file(self, path: str, name: str, ext: str) -> Optional[File]: return file + # Allowlist of graph node labels that may be passed to get_entity_at_position. + # Only labels produced by the analyzers are permitted; any other value raises + # ValueError to prevent Cypher injection via f-string interpolation. + _VALID_ENTITY_LABELS: frozenset[str] = frozenset({ + "File", "Class", "Function", "Method", "Interface", + "Enum", "Struct", "Constructor", + }) + def get_entity_at_position(self, path: str, line: int, labels: Optional[list[str]] = None) -> Optional[Node]: """Return the smallest entity spanning *line* within *path*.""" + if labels: + invalid = set(labels) - self._VALID_ENTITY_LABELS + if invalid: + raise ValueError(f"Invalid graph labels: {invalid}") label_filter = ":" + ":".join(labels) if labels else "" q = f"""MATCH (e{label_filter}) WHERE e.path = $path diff --git a/api/index.py b/api/index.py index f6cbcc92..651836cf 100644 --- a/api/index.py +++ b/api/index.py @@ -5,6 +5,8 @@ import asyncio import contextlib import logging +import re +import subprocess from pathlib import Path from dotenv import load_dotenv @@ -23,10 +25,12 @@ repo_local_path, repo_update_lock, ) -from api.graph import Graph, AsyncGraphQuery, async_get_repos, delete_graph_if_exists +from api.graph import Graph, AsyncGraphQuery, async_get_repos, delete_graph_if_exists, 
_make_falkordb_connection from api.info import async_get_repo_info, get_repo_commit from api.llm import ask from api.project import Project +from pygit2.enums import CheckoutStrategy +from pygit2.repository import Repository as GitRepository # Load environment variables from .env file @@ -246,14 +250,13 @@ def _full_reindex_repository( ) with repo_update_lock(repo_name): - delete_graph_if_exists(repo_name) - delete_graph_if_exists(git_utils.GitRepoName(repo_name)) + db = _make_falkordb_connection() + delete_graph_if_exists(repo_name, db) + delete_graph_if_exists(git_utils.GitRepoName(repo_name), db) if repo_path.exists(): if target_sha: - from pygit2.enums import CheckoutStrategy - from pygit2.repository import Repository - repo = Repository(str(repo_path)) + repo = GitRepository(str(repo_path)) target_commit = repo.revparse_single(target_sha) repo.checkout_tree(target_commit.tree, strategy=CheckoutStrategy.FORCE) repo.set_head_detached(target_commit.id) @@ -352,16 +355,26 @@ def _poll_repo(repo_name: str) -> None: current_sha = get_repo_commit(repo_name) if current_sha: - # Handle comparison between short (7-char) and full (40-char) SHAs: a short - # stored SHA is a valid prefix of a full remote SHA for the same commit. - # We only apply prefix matching when the stored SHA is shorter. - if len(current_sha) < len(remote_head): - up_to_date = remote_head.startswith(current_sha) - elif len(current_sha) > len(remote_head): - up_to_date = current_sha.startswith(remote_head) - else: - up_to_date = current_sha == remote_head - if up_to_date: + # Validate SHA format before passing to git: must be 7-40 hex characters. + if not re.fullmatch(r"[0-9a-f]{7,40}", current_sha): + logger.warning( + "Poll: '%s' stored SHA '%s' is not valid hex; skipping", + repo_name, current_sha, + ) + return + + # Resolve the stored (potentially short) SHA to its full 40-char form so + # the comparison is unambiguous — short SHAs can be ambiguous in large repos. 
+ try: + result = subprocess.run( + ["git", "rev-parse", current_sha], + cwd=str(path), capture_output=True, text=True, check=True, + ) + full_current = result.stdout.strip() + except subprocess.CalledProcessError: + full_current = None + + if full_current and full_current == remote_head: logger.debug("Poll: '%s' is up-to-date at %s", repo_name, current_sha) return else: @@ -632,8 +645,8 @@ async def webhook(request: Request): try: payload = await request.json() - except Exception: - raise HTTPException(status_code=400, detail="Invalid JSON payload") + except Exception as exc: + raise HTTPException(status_code=400, detail="Invalid JSON payload") from exc ref = payload.get("ref", "") before = payload.get("before", "") diff --git a/tests/test_webhook.py b/tests/test_webhook.py index b862e906..32c2d413 100644 --- a/tests/test_webhook.py +++ b/tests/test_webhook.py @@ -505,12 +505,20 @@ def test_poll_repo_skips_when_up_to_date(monkeypatch, tmp_path): monkeypatch.setattr(api.index, "repo_local_path", lambda name: repo_path) monkeypatch.setattr(api.index, "fetch_remote", lambda p: None) monkeypatch.setattr(api.index, "TRACKED_BRANCH", "main") - monkeypatch.setattr( - api.index, "get_remote_head", lambda p, b: "abcdef1234567890" * 2 + "abcdef12" - ) - # Stored bookmark is a short SHA prefix of the remote HEAD + full_sha = "abcdef12" * 5 # 40-char remote HEAD + monkeypatch.setattr(api.index, "get_remote_head", lambda p, b: full_sha) + # Stored bookmark is the short (7-char) form of the same commit monkeypatch.setattr(api.index, "get_repo_commit", lambda name: "abcdef1") + # Mock subprocess.run so git rev-parse resolves short SHA to the full SHA + rev_parse_calls = [] + def _fake_run(cmd, **kwargs): + rev_parse_calls.append(cmd) + class _Result: + stdout = full_sha + "\n" + return _Result() + monkeypatch.setattr(api.index.subprocess, "run", _fake_run) + sync_calls = [] monkeypatch.setattr( api.index, @@ -521,6 +529,7 @@ def test_poll_repo_skips_when_up_to_date(monkeypatch, 
tmp_path): api.index._poll_repo("myrepo") assert sync_calls == [], "sync should not be called when repo is up-to-date" + assert rev_parse_calls == [["git", "rev-parse", "abcdef1"]] def test_poll_repo_triggers_sync_when_behind(monkeypatch, tmp_path): @@ -534,6 +543,15 @@ def test_poll_repo_triggers_sync_when_behind(monkeypatch, tmp_path): monkeypatch.setattr(api.index, "get_remote_head", lambda p, b: remote_sha) monkeypatch.setattr(api.index, "get_repo_commit", lambda name: "aaaa111") + # Mock subprocess.run so git rev-parse returns a different full SHA + rev_parse_calls = [] + def _fake_run(cmd, **kwargs): + rev_parse_calls.append(cmd) + class _Result: + stdout = "aaaa1111" * 5 + "\n" # different from remote_sha + return _Result() + monkeypatch.setattr(api.index.subprocess, "run", _fake_run) + sync_calls = [] def _fake_sync(repo_name, path, target_sha, **kwargs): @@ -546,6 +564,7 @@ def _fake_sync(repo_name, path, target_sha, **kwargs): assert len(sync_calls) == 1 assert sync_calls[0] == ("myrepo", remote_sha) + assert rev_parse_calls == [["git", "rev-parse", "aaaa111"]] def test_poll_repo_handles_no_remote_head(monkeypatch, tmp_path): From 369df38309b0481327c92952983499a1f8b1085c Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Fri, 20 Mar 2026 22:35:25 +0200 Subject: [PATCH 8/8] fix(poll): use per-repo tracked branch instead of global TRACKED_BRANCH Store the branch name in Redis during initial analysis (detected from HEAD). The poll watcher and webhook handler now read the per-repo branch first, falling back to the TRACKED_BRANCH env var only when no branch is stored. 
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .env.template | 5 +++-- api/index.py | 17 +++++++++-------- api/info.py | 23 +++++++++++++++++++++++ api/project.py | 8 ++++++-- tests/test_webhook.py | 11 +++++++++++ 5 files changed, 52 insertions(+), 12 deletions(-) diff --git a/.env.template b/.env.template index d738a910..2dc60549 100644 --- a/.env.template +++ b/.env.template @@ -37,8 +37,9 @@ PORT=5000 # Authorization: Bearer on /api/webhook instead. WEBHOOK_SECRET= -# Name of the branch to track for automatic incremental updates. -# Only push events targeting this branch trigger a graph update. +# Default branch to track for automatic incremental updates. +# Each repo remembers its own branch (detected at initial analysis). +# This value is only used as a fallback when no per-repo branch is stored. TRACKED_BRANCH=main # Seconds between automatic poll-watcher checks (0 = disable poll-watcher). diff --git a/api/index.py b/api/index.py index 651836cf..981cf99e 100644 --- a/api/index.py +++ b/api/index.py @@ -26,7 +26,7 @@ repo_update_lock, ) from api.graph import Graph, AsyncGraphQuery, async_get_repos, delete_graph_if_exists, _make_falkordb_connection -from api.info import async_get_repo_info, get_repo_commit +from api.info import async_get_repo_info, get_repo_branch, get_repo_commit from api.llm import ask from api.project import Project from pygit2.enums import CheckoutStrategy @@ -349,7 +349,7 @@ def _poll_repo(repo_name: str) -> None: logger.warning("Poll: git fetch failed for '%s': %s", repo_name, exc) return - remote_head = get_remote_head(path, TRACKED_BRANCH) + remote_head = get_remote_head(path, get_repo_branch(repo_name) or TRACKED_BRANCH) if not remote_head: return @@ -653,12 +653,6 @@ async def webhook(request: Request): after = payload.get("after", "") repo_url = _extract_repo_url(payload) - # Only process pushes to the configured tracked branch - expected_ref = f"refs/heads/{TRACKED_BRANCH}" - if ref != expected_ref: - 
logger.debug("Webhook: ignoring push to '%s' (tracking '%s')", ref, expected_ref) - return {"status": "ignored", "reason": f"Branch not tracked: {ref}"} - if not before or not after or not repo_url: raise HTTPException( status_code=400, @@ -677,6 +671,13 @@ async def webhook(request: Request): status_code=404, ) + # Only process pushes to the repo's tracked branch + tracked = get_repo_branch(repo_name) or TRACKED_BRANCH + expected_ref = f"refs/heads/{tracked}" + if ref != expected_ref: + logger.debug("Webhook: ignoring push to '%s' (tracking '%s')", ref, expected_ref) + return {"status": "ignored", "reason": f"Branch not tracked: {ref}"} + logger.info( "Webhook: updating '%s' from %s to %s", repo_name, before[:8], after[:8] ) diff --git a/api/info.py b/api/info.py index b1d9ea7e..f2603276 100644 --- a/api/info.py +++ b/api/info.py @@ -67,6 +67,29 @@ def get_repo_commit(repo_name: str) -> str: raise +def set_repo_branch(repo_name: str, branch: str) -> None: + """Save the tracked branch name for *repo_name*.""" + try: + r = get_redis_connection() + key = _repo_info_key(repo_name) + r.hset(key, 'branch', branch) + logging.info(f"Repository '{repo_name}' tracked branch set to: {branch}") + except Exception as e: + logging.error(f"Error saving branch for '{repo_name}': {e}") + raise + + +def get_repo_branch(repo_name: str) -> Optional[str]: + """Return the tracked branch for *repo_name*, or ``None`` if not stored.""" + try: + r = get_redis_connection() + key = _repo_info_key(repo_name) + return r.hget(key, "branch") + except Exception as e: + logging.error(f"Error retrieving branch for '{repo_name}': {e}") + raise + + def save_repo_info(repo_name: str, repo_url: str) -> None: """ Saves repository information (URL) to Redis under a hash named {repo_name}_info. 
diff --git a/api/project.py b/api/project.py index aed5a9e7..9b77426b 100644 --- a/api/project.py +++ b/api/project.py @@ -83,12 +83,16 @@ def analyze_sources(self, ignore: Optional[List[str]] = None) -> Graph: self.analyzer.analyze_local_folder(self.path, self.graph, ignore) try: - # Save processed commit hash to the DB + # Save processed commit hash and branch to the DB repo = Repository(self.path) current_commit = repo.walk(repo.head.target).__next__() set_repo_commit(self.name, current_commit.short_id) + + if not repo.head_is_detached: + branch_name = repo.head.shorthand + set_repo_branch(self.name, branch_name) except Exception: - # Probably not .git folder is missing + # Probably .git folder is missing pass return self.graph diff --git a/tests/test_webhook.py b/tests/test_webhook.py index 32c2d413..e314e05c 100644 --- a/tests/test_webhook.py +++ b/tests/test_webhook.py @@ -117,6 +117,17 @@ def client_misconfigured(monkeypatch): def test_webhook_ignored_wrong_branch(client_token_auth, monkeypatch): """Pushes to non-tracked branches return 200 with status='ignored'.""" monkeypatch.setattr(api.index, "TRACKED_BRANCH", "main") + + async def _fake_get_repos(): + return ["myrepo"] + + async def _fake_get_repo_info(repo_name): + return {"repo_url": "https://github.com/example/myrepo.git"} + + monkeypatch.setattr(api.index, "async_get_repos", _fake_get_repos) + monkeypatch.setattr(api.index, "async_get_repo_info", _fake_get_repo_info) + monkeypatch.setattr(api.index, "get_repo_branch", lambda name: None) + payload = _make_push_payload(ref="refs/heads/feature/x") resp = client_token_auth.post( "/api/webhook",