From 7657c7d6c188437c2df81d94fc950877f0d86305 Mon Sep 17 00:00:00 2001 From: Daniele Date: Thu, 4 Jun 2026 09:46:51 +0200 Subject: [PATCH] feat: add incremental AST cache to eliminate redundant AstEncoder work Three-level (L1 in-memory mtime / L2 disk content-hash / L3 chunk-aware) incremental AST cache that skips the pure-Python json.dumps(AstEncoder) bottleneck across runs and on partial file changes. - src/pyspector/ast_cache.py: cache implementation (JSON+base64 persistence, no pickle/code-exec on load) - src/pyspector/_ast_encode.py: shared AST->JSON encoder (single source of truth, eliminates encoder drift between cli.py and the cache) - src/pyspector/cli.py: wire the cache into get_python_file_asts - tests/unit/ast_cache_test.py: unit tests Co-Authored-By: Claude Opus 4.8 --- .gitignore | 3 + src/pyspector/_ast_encode.py | 55 +++ src/pyspector/ast_cache.py | 473 +++++++++++++++++++++++++ src/pyspector/cli.py | 65 +--- tests/unit/ast_cache_test.py | 644 +++++++++++++++++++++++++++++++++++ 5 files changed, 1192 insertions(+), 48 deletions(-) create mode 100644 src/pyspector/_ast_encode.py create mode 100644 src/pyspector/ast_cache.py create mode 100644 tests/unit/ast_cache_test.py diff --git a/.gitignore b/.gitignore index 4ed39ab..b0ec2be 100644 --- a/.gitignore +++ b/.gitignore @@ -69,3 +69,6 @@ venv.bak/ target/ .vscode/target/ Cargo.lock + +# PySpector incremental AST cache (generated, never version-controlled) +.pyspector_cache/ diff --git a/src/pyspector/_ast_encode.py b/src/pyspector/_ast_encode.py new file mode 100644 index 0000000..54a4df5 --- /dev/null +++ b/src/pyspector/_ast_encode.py @@ -0,0 +1,55 @@ +""" +Shared AST → JSON encoder for PySpector. + +Single source of truth for the JSON schema consumed by the Rust core. +Imported by both ast_cache.py and cli.py to eliminate encoder drift. 
+""" +from __future__ import annotations + +import ast +import json +from typing import Any, Dict + + +class AstEncoder(json.JSONEncoder): + """Serialize ast.AST nodes to the JSON schema expected by the Rust core.""" + + def default(self, node: Any) -> Any: + if isinstance(node, ast.AST): + out: Dict[str, Any] = { + "node_type": node.__class__.__name__, + "lineno": getattr(node, "lineno", -1), + "col_offset": getattr(node, "col_offset", -1), + } + child_nodes: Dict[str, Any] = {} + simple_fields: Dict[str, Any] = {} + for fname, value in ast.iter_fields(node): + if type(value) is list: + if value and all(isinstance(n, ast.AST) for n in value): + child_nodes[fname] = value + else: + simple_fields[fname] = str(value) if value else [] + elif isinstance(value, ast.AST): + child_nodes[fname] = [value] + else: + if isinstance(value, bytes): + simple_fields[fname] = value.decode("utf-8", errors="replace") + elif isinstance(value, int) and value.bit_length() > 14000: + simple_fields[fname] = 0 + elif isinstance(value, (int, float, str, bool)) or value is None: + simple_fields[fname] = value + else: + simple_fields[fname] = str(value) + out["children"] = child_nodes + out["fields"] = simple_fields + return out + if isinstance(node, bytes): + return node.decode("utf-8", errors="replace") + if hasattr(node, "__dict__"): + return str(node) + return super().default(node) + + +def encode_node(node: ast.AST) -> str: + """Serialize a single AST node to JSON.""" + return json.dumps(node, cls=AstEncoder) diff --git a/src/pyspector/ast_cache.py b/src/pyspector/ast_cache.py new file mode 100644 index 0000000..fc5a34c --- /dev/null +++ b/src/pyspector/ast_cache.py @@ -0,0 +1,473 @@ +""" +Incremental AST cache for PySpector. 
+ +Three-level hierarchy +--------------------- +L1 in-memory mtime guard — zero work on hit within a process run +L2 disk content-hash guard — no parse/encode across runs +L3 chunk-aware per-function/class subtree reuse when a file partially changes + +Bottleneck eliminated: json.dumps(ast_tree, cls=AstEncoder) is pure-Python +O(N nodes). ast.parse() is C and negligible by comparison. + +Persistence format +------------------ +Entries are stored as JSON with zlib-compressed fields base64-encoded. +pickle is deliberately NOT used: it executes arbitrary code on load, making +it unsafe when cache files reside in a repository directory controlled by +an untrusted third party. +""" +from __future__ import annotations + +import ast +import base64 +import dataclasses +import hashlib +import json +import warnings +import zlib +from collections import OrderedDict +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +from ._ast_encode import AstEncoder, encode_node # noqa: F401 (re-exported for tests) + +# v1 used pickle (security risk); v2 uses JSON + base64 +CACHE_VERSION = 2 +_ZLIB_LEVEL = 3 # favour speed over ratio for ephemeral cache data +MAX_L1_ENTRIES: int = 512 + + +# ── Data structures ────────────────────────────────────────────────────────── + + +@dataclass(frozen=True) +class AstChunk: + """Serialised AST for one top-level syntactic block.""" + chunk_id: str # "FunctionDef:my_func", "ClassDef:MyClass", "stmt:42" + start_line: int # 1-based, matches ast.lineno + end_line: int + content_hash: str # sha256 of this chunk's source text + ast_json_z: bytes # zlib-compressed JSON of the AstNode subtree + + +@dataclass(frozen=True) +class FileCacheEntry: + file_path: str + file_hash: str # sha256 of full file content + mtime: float + full_ast_json_z: bytes # zlib-compressed full AST JSON string + chunks: Dict[str, AstChunk] + version: int = CACHE_VERSION + + +# ── Chunking helpers 
───────────────────────────────────────────────────────── + +_NAMED_TYPES = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef) + + +def _make_chunk_id(node: ast.stmt, seen: Dict[str, int]) -> str: + """Produce a stable chunk ID for a top-level AST statement.""" + if isinstance(node, _NAMED_TYPES): + base = f"{node.__class__.__name__}:{node.name}" + idx = seen.get(base, 0) + seen[base] = idx + 1 + return base if idx == 0 else f"{base}:{idx}" + return f"stmt:{node.lineno}" + + +def _source_slice(lines: List[str], node: ast.stmt) -> str: + start = node.lineno - 1 # ast.lineno is 1-based + end = getattr(node, "end_lineno", node.lineno) + return "".join(lines[start:end]) + + +# ── Module JSON assembly ───────────────────────────────────────────────────── + + +def _assemble_module_json( + body_parts: List[str], + type_ignore_parts: List[str], +) -> str: + """ + Build the Module JSON wrapper around pre-serialized body/type_ignore fragments. + + Pre-conditions (caller must ensure): + Every string in body_parts / type_ignore_parts is valid JSON produced by + encode_node(). Values are embedded verbatim — not re-serialized or escaped. 
+ + Mirrors AstEncoder's field/children split: + - non-empty AST-node list → placed under "children" + - empty list → placed under "fields" as [] + """ + ch_items: List[str] = [] + fi_items: List[str] = [] + + if body_parts: + ch_items.append('"body": [' + ",".join(body_parts) + "]") + else: + fi_items.append('"body": []') + + if type_ignore_parts: + ch_items.append('"type_ignores": [' + ",".join(type_ignore_parts) + "]") + else: + fi_items.append('"type_ignores": []') + + ch_json = "{" + ", ".join(ch_items) + "}" + fi_json = "{" + ", ".join(fi_items) + "}" + return ( + '{"node_type": "Module", "lineno": -1, "col_offset": -1, ' + '"children": ' + ch_json + ', ' + '"fields": ' + fi_json + "}" + ) + + +# ── Incremental JSON construction ──────────────────────────────────────────── + + +def _build_ast_json_and_chunks( + tree: ast.Module, + source: str, + old_chunks: Dict[str, AstChunk], +) -> Tuple[str, Dict[str, AstChunk]]: + """ + Serialise *tree* to AST JSON, reusing encoded subtrees from *old_chunks* + for any chunk whose content hash AND start_line are both unchanged. + + Skips encode_node() for every unchanged top-level function/class — + typically 80-100 % of body nodes when only a few lines change. + + Returns (full_ast_json, new_chunks_dict). 
+ """ + lines = source.splitlines(keepends=True) + seen: Dict[str, int] = {} + new_chunks: Dict[str, AstChunk] = {} + body_parts: List[str] = [] + + for node in tree.body: + cid = _make_chunk_id(node, seen) + src = _source_slice(lines, node) + end = getattr(node, "end_lineno", node.lineno) + new_hash = hashlib.sha256(src.encode()).hexdigest() + + old = old_chunks.get(cid) + reuse = ( + old is not None + and old.content_hash == new_hash + and old.start_line == node.lineno + ) + + if reuse: + assert old is not None # type narrowing + node_json = zlib.decompress(old.ast_json_z).decode() + chunk_z = old.ast_json_z + else: + node_json = encode_node(node) + chunk_z = zlib.compress(node_json.encode(), _ZLIB_LEVEL) + + new_chunks[cid] = AstChunk( + chunk_id=cid, + start_line=node.lineno, + end_line=end, + content_hash=new_hash, + ast_json_z=chunk_z, + ) + body_parts.append(node_json) + + type_ignore_parts = [encode_node(ti) for ti in tree.type_ignores] + full_json = _assemble_module_json(body_parts, type_ignore_parts) + return full_json, new_chunks + + +# ── Disk serialization — JSON + base64, no executable deserialization ───────── + + +def _serialize_entry(entry: FileCacheEntry) -> str: + """Serialize a FileCacheEntry to a JSON string. No code-execution paths.""" + return json.dumps({ + "version": entry.version, + "file_path": entry.file_path, + "file_hash": entry.file_hash, + "mtime": entry.mtime, + "full_ast_json_z": base64.b64encode(entry.full_ast_json_z).decode(), + "chunks": { + k: { + "chunk_id": c.chunk_id, + "start_line": c.start_line, + "end_line": c.end_line, + "content_hash": c.content_hash, + "ast_json_z": base64.b64encode(c.ast_json_z).decode(), + } + for k, c in entry.chunks.items() + }, + }) + + +def _deserialize_entry(raw: str) -> FileCacheEntry: + """Deserialize a FileCacheEntry from JSON. 
Raises on malformed data.""" + d = json.loads(raw) + return FileCacheEntry( + file_path=d["file_path"], + file_hash=d["file_hash"], + mtime=float(d["mtime"]), + full_ast_json_z=base64.b64decode(d["full_ast_json_z"]), + chunks={ + k: AstChunk( + chunk_id=v["chunk_id"], + start_line=int(v["start_line"]), + end_line=int(v["end_line"]), + content_hash=v["content_hash"], + ast_json_z=base64.b64decode(v["ast_json_z"]), + ) + for k, v in d["chunks"].items() + }, + version=int(d["version"]), + ) + + +# ── Cache ───────────────────────────────────────────────────────────────────── + + +class IncrementalAstCache: + """ + Three-level incremental AST cache. + + Parameters + ---------- + cache_dir : Path, optional + Directory for the persistent (L2) disk cache. When *None* only the + in-memory (L1) cache is active. If the directory cannot be created, + a warning is issued and the cache operates in L1-only mode. + max_l1_entries : int + Maximum entries kept in the in-memory LRU cache. Oldest entries are + evicted when the limit is exceeded. Default: 512. + + Usage + ----- + :: + + cache = IncrementalAstCache(cache_dir=Path(".pyspector_cache/ast")) + ast_json = cache.get_ast_json(Path("src/foo.py"), content) + """ + + def __init__( + self, + cache_dir: Optional[Path] = None, + max_l1_entries: int = MAX_L1_ENTRIES, + ) -> None: + self._l1: OrderedDict[str, FileCacheEntry] = OrderedDict() + self._max_l1 = max_l1_entries + self._cache_dir: Optional[Path] = None + if cache_dir: + try: + cache_dir.mkdir(parents=True, exist_ok=True) + self._cache_dir = cache_dir + except OSError as e: + warnings.warn( + f"PySpector: cannot create cache directory {cache_dir!r}: {e}. " + "Disk cache disabled for this run.", + stacklevel=2, + ) + + # ── Public API ─────────────────────────────────────────────────────────── + + def get_ast_json(self, file_path: Path, content: str) -> str: + """ + Return the AST JSON string for *file_path*. 
+ + Raises + ------ + SyntaxError + If the file cannot be parsed, so callers can emit user-facing + warnings while keeping cache logic out of the CLI layer. + """ + return zlib.decompress(self._get_entry(file_path, content).full_ast_json_z).decode() + + def invalidate(self, file_path: Path) -> None: + """Remove all cached data for a single file.""" + key = str(file_path.resolve()) + self._l1.pop(key, None) + p = self._disk_path(file_path) + if p and p.exists(): + p.unlink(missing_ok=True) + + def get_changed_chunks( + self, file_path: Path, old_content: str, new_content: str + ) -> List[str]: + """ + Return the IDs of top-level chunks that differ between two versions + of a file, without updating the cache. Useful for incremental + analysis drivers that want to know exactly what changed. + """ + def _chunk_hashes(source: str) -> Dict[str, str]: + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=SyntaxWarning) + try: + tree = ast.parse(source, filename=str(file_path)) + except SyntaxError: + return {} + lines = source.splitlines(keepends=True) + seen: Dict[str, int] = {} + out: Dict[str, str] = {} + for node in tree.body: + cid = _make_chunk_id(node, seen) + out[cid] = hashlib.sha256(_source_slice(lines, node).encode()).hexdigest() + return out + + old_h = _chunk_hashes(old_content) + new_h = _chunk_hashes(new_content) + changed = [cid for cid, h in new_h.items() if old_h.get(cid) != h] + changed += [cid for cid in old_h if cid not in new_h] + return changed + + # ── Internal ───────────────────────────────────────────────────────────── + + def _l1_get(self, key: str) -> Optional[FileCacheEntry]: + entry = self._l1.get(key) + if entry is not None: + self._l1.move_to_end(key) + return entry + + def _l1_put(self, key: str, entry: FileCacheEntry) -> None: + self._l1[key] = entry + self._l1.move_to_end(key) + while len(self._l1) > self._max_l1: + self._l1.popitem(last=False) # evict least-recently-used + + def _get_entry(self, file_path: 
Path, content: str) -> FileCacheEntry: + # Resolve once: L1 key and L2 hash must both use the canonical path. + file_path = file_path.resolve() + key = str(file_path) + + try: + mtime = file_path.stat().st_mtime + except OSError: + mtime = 0.0 + + # L1 – mtime guard (cheapest check: dict lookup + float compare) + l1 = self._l1_get(key) + if l1 and l1.mtime == mtime and l1.version == CACHE_VERSION: + return l1 + + file_hash = hashlib.sha256(content.encode()).hexdigest() + + # L1 – hash guard (file touched externally but content unchanged) + if l1 and l1.file_hash == file_hash and l1.version == CACHE_VERSION: + updated = dataclasses.replace(l1, mtime=mtime) + self._l1_put(key, updated) + return updated + + # L2 – disk (survive across process restarts) + l2 = self._disk_load(file_path, file_hash) + if l2: + updated_l2 = dataclasses.replace(l2, mtime=mtime) + self._l1_put(key, updated_l2) + return updated_l2 + + # L3 – build with chunk-level subtree reuse + old_chunks: Dict[str, AstChunk] = ( + l1.chunks if (l1 and l1.version == CACHE_VERSION) else {} + ) + entry = self._build(file_path, content, file_hash, mtime, old_chunks) + self._l1_put(key, entry) + self._disk_save(entry) + return entry + + def _build( + self, + file_path: Path, + content: str, + file_hash: str, + mtime: float, + old_chunks: Dict[str, AstChunk], + ) -> FileCacheEntry: + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=SyntaxWarning) + tree = ast.parse(content, filename=str(file_path)) # SyntaxError propagates + + full_json, chunks = _build_ast_json_and_chunks(tree, content, old_chunks) + return FileCacheEntry( + file_path=str(file_path), + file_hash=file_hash, + mtime=mtime, + full_ast_json_z=zlib.compress(full_json.encode(), _ZLIB_LEVEL), + chunks=chunks, + ) + + # ── Disk I/O ───────────────────────────────────────────────────────────── + + def _disk_path(self, file_path: Path) -> Optional[Path]: + if not self._cache_dir: + return None + key = 
hashlib.sha256(str(file_path.resolve()).encode()).hexdigest() + return self._cache_dir / f"{key}.json" + + def _disk_load(self, file_path: Path, file_hash: str) -> Optional[FileCacheEntry]: + p = self._disk_path(file_path) + if not p or not p.exists(): + return None + try: + entry = _deserialize_entry(p.read_text(encoding="utf-8")) + if entry.version == CACHE_VERSION and entry.file_hash == file_hash: + return entry + except Exception: + try: + p.unlink(missing_ok=True) + except OSError: + pass + return None + + def _disk_save(self, entry: FileCacheEntry) -> None: + p = self._disk_path(Path(entry.file_path)) + if not p: + return + tmp = p.with_suffix(".tmp") + try: + tmp.write_text(_serialize_entry(entry), encoding="utf-8") + tmp.replace(p) # atomic on POSIX; best-effort on Windows + except OSError as e: + warnings.warn( + f"PySpector: cache write failed for {entry.file_path!r}: {e}", + stacklevel=2, + ) + except Exception as e: + warnings.warn( + f"PySpector: unexpected cache error for {entry.file_path!r}: {e}", + stacklevel=2, + ) + finally: + # Remove temp file if replace() did not atomically rename it. + try: + tmp.unlink(missing_ok=True) + except OSError: + pass + + +# ── Process-level singleton ─────────────────────────────────────────────────── + +_instance: Optional[IncrementalAstCache] = None + + +def get_cache(scan_path: Optional[Path] = None) -> IncrementalAstCache: + """ + Return the process-level cache instance. + + The disk cache is rooted at *<scan_path>/.pyspector_cache/ast* when + *scan_path* is supplied on the first call. Subsequent calls return the + same instance regardless of *scan_path*. + """ + global _instance + if _instance is None: + cache_dir: Optional[Path] = None + if scan_path: + base = scan_path if scan_path.is_dir() else scan_path.parent + cache_dir = base / ".pyspector_cache" / "ast" + _instance = IncrementalAstCache(cache_dir=cache_dir) + return _instance + + +def _reset_cache_singleton() -> None: + """Reset the process-level singleton. 
Use only in tests.""" + global _instance + _instance = None diff --git a/src/pyspector/cli.py b/src/pyspector/cli.py index b22f387..a755a10 100644 --- a/src/pyspector/cli.py +++ b/src/pyspector/cli.py @@ -13,6 +13,8 @@ from pathlib import Path from typing import Optional, Dict, Any, List, cast +from .ast_cache import IncrementalAstCache, get_cache +from ._ast_encode import AstEncoder from .config import load_config, get_default_rules from .reporting import Reporter from .triage import run_triage_tui @@ -48,11 +50,6 @@ def get_startup_note(): pass return random.choice(fallbacks) -_list = list -_tuple = tuple -_ast_AST = ast.AST - - def _dbg(debug: bool, msg: str = "", **style_kwargs) -> None: """Emit *msg* via click.echo only when --debug is enabled. @@ -129,47 +126,6 @@ def _print_banner() -> None: click.echo(click.style(f"{note}\n", fg="bright_black", italic=True)) -_ast_iter_fields = ast.iter_fields - -# --- Helper function for AST serialization --- -class AstEncoder(json.JSONEncoder): - def default(self, node): - if isinstance(node, _ast_AST): - fields = { - "node_type": node.__class__.__name__, - "lineno": getattr(node, 'lineno', -1), - "col_offset": getattr(node, 'col_offset', -1), - } - child_nodes = {} - simple_fields = {} - for field, value in _ast_iter_fields(node): - if type(value).__name__ == 'list': - if value and all(isinstance(n, _ast_AST) for n in value): - child_nodes[field] = value - else: - simple_fields[field] = str(value) if value else [] - elif isinstance(value, _ast_AST): - child_nodes[field] = [value] - else: - if isinstance(value, bytes): - simple_fields[field] = value.decode('utf-8', errors='replace') - elif isinstance(value, int) and value.bit_length() > 14000: - simple_fields[field] = 0 - elif isinstance(value, (int, float, str, bool)) or value is None: - simple_fields[field] = value - else: - simple_fields[field] = str(value) - - fields["children"] = child_nodes - fields["fields"] = simple_fields - return fields - elif isinstance(node, 
bytes): - return node.decode('utf-8', errors='replace') - elif hasattr(node, '__dict__'): - return str(node) - return super().default(node) - - def should_skip_file(file_path: Path) -> bool: """Determine if a file should be skipped during AST parsing.""" path_str = str(file_path) @@ -220,6 +176,7 @@ def get_python_file_asts( _stats_meta: Optional[Dict[str, int]] = None, debug: bool = False, exclude: Optional[List[str]] = None, + cache: Optional[IncrementalAstCache] = None, ) -> List[Dict[str, Any]]: """ Recursively finds Python files and returns their content and AST. @@ -232,6 +189,11 @@ def get_python_file_asts( ``{'skipped': N, 'errors': N}`` for use by StatsCollector. Defaults to None (no tracking). Backward-compatible: callers that do not pass this argument are unaffected. + cache: Optional incremental AST cache. When supplied (and syntax + warnings are not being promoted to errors), the cached AST JSON + is reused instead of re-running ast.parse + json.dumps. The cache + suppresses SyntaxWarning internally, so it is bypassed whenever + ``enable_syntax_warnings`` is True to preserve that diagnostic. 
""" if _stats_meta is not None: _stats_meta['skipped'] = 0 @@ -272,8 +234,11 @@ def get_python_file_asts( try: content = py_file.read_text(encoding="utf-8") - parsed_ast = ast.parse(content, filename=str(py_file)) - ast_json = json.dumps(parsed_ast, cls=AstEncoder) + if cache is not None and not enable_syntax_warnings: + ast_json = cache.get_ast_json(py_file, content) + else: + parsed_ast = ast.parse(content, filename=str(py_file)) + ast_json = json.dumps(parsed_ast, cls=AstEncoder) results.append( { "file_path": str(py_file.resolve()), @@ -722,6 +687,9 @@ def _execute_scan( _dbg(debug, f"[*] Starting PySpector scan on '{scan_path}'...") + # ── AST Cache ───────────────────────────────────────────────────────── + cache = get_cache(scan_path) + # ── Load Baseline ───────────────────────────────────────────────────── baseline_path = ( scan_path / ".pyspector_baseline.json" @@ -758,6 +726,7 @@ def _execute_scan( _stats_meta=ast_stats_meta, debug=debug, exclude=list(config.get("exclude", [])), + cache=cache, ) _dbg(debug, f"[*] Successfully parsed {len(python_files_data)} Python files in {time.time()-t_parse:.2f}s") diff --git a/tests/unit/ast_cache_test.py b/tests/unit/ast_cache_test.py new file mode 100644 index 0000000..224bf53 --- /dev/null +++ b/tests/unit/ast_cache_test.py @@ -0,0 +1,644 @@ +import ast +import hashlib +import json +import os +import tempfile +import unittest +import warnings +import zlib +from pathlib import Path +from unittest.mock import patch + +from pyspector._ast_encode import AstEncoder, encode_node +from pyspector.ast_cache import ( + CACHE_VERSION, + AstChunk, + FileCacheEntry, + IncrementalAstCache, + MAX_L1_ENTRIES, + _assemble_module_json, + _build_ast_json_and_chunks, + _deserialize_entry, + _make_chunk_id, + _reset_cache_singleton, + _serialize_entry, + get_cache, +) + + +def _parse_json(ast_json: str) -> dict: + return json.loads(ast_json) + + +def _make_cache(tmp: Path, max_l1: int = MAX_L1_ENTRIES) -> IncrementalAstCache: + return 
IncrementalAstCache(cache_dir=tmp / "cache", max_l1_entries=max_l1) + + +# ── TestChunkIds ────────────────────────────────────────────────────────────── + + +class TestChunkIds(unittest.TestCase): + def _ids(self, source: str) -> list: + tree = ast.parse(source) + seen: dict = {} + return [_make_chunk_id(n, seen) for n in tree.body] + + def test_function(self): + self.assertEqual(self._ids("def foo(): pass"), ["FunctionDef:foo"]) + + def test_async_function(self): + self.assertEqual(self._ids("async def bar(): pass"), ["AsyncFunctionDef:bar"]) + + def test_class(self): + self.assertEqual(self._ids("class MyClass: pass"), ["ClassDef:MyClass"]) + + def test_bare_statement(self): + self.assertEqual(self._ids("x = 1"), ["stmt:1"]) + + def test_duplicate_names_get_suffix(self): + ids = self._ids("def foo(): pass\ndef foo(): pass") + self.assertEqual(ids, ["FunctionDef:foo", "FunctionDef:foo:1"]) + + def test_mixed(self): + ids = self._ids("x = 1\ndef foo(): pass\nclass Bar: pass") + self.assertEqual(ids, ["stmt:1", "FunctionDef:foo", "ClassDef:Bar"]) + + +# ── TestBuildAstJson ───────────────────────────────────────────────────────── + + +class TestBuildAstJson(unittest.TestCase): + def _build(self, source: str, old: dict | None = None) -> tuple: + tree = ast.parse(source) + return _build_ast_json_and_chunks(tree, source, old or {}) + + def test_empty_module(self): + json_str, chunks = self._build("") + parsed = _parse_json(json_str) + self.assertEqual(parsed["node_type"], "Module") + self.assertEqual(parsed["fields"]["body"], []) + self.assertEqual(chunks, {}) + + def test_single_function_structure(self): + src = "def foo(x):\n return x + 1\n" + json_str, chunks = self._build(src) + parsed = _parse_json(json_str) + body = parsed["children"]["body"] + self.assertEqual(len(body), 1) + self.assertEqual(body[0]["node_type"], "FunctionDef") + self.assertIn("FunctionDef:foo", chunks) + + def test_json_matches_direct_encoder(self): + src = "x = 1\ndef foo(): pass\nclass Bar: 
pass\n" + tree = ast.parse(src) + direct = json.dumps(tree, cls=AstEncoder) + incremental, _ = self._build(src) + self.assertEqual(_parse_json(direct), _parse_json(incremental)) + + def test_chunk_reuse_skips_encoding(self): + src = "def foo(): pass\ndef bar(): pass\n" + _, old_chunks = self._build(src) + + new_src = "def foo(): pass\ndef bar(): return 42\n" + new_tree = ast.parse(new_src) + _, new_chunks = _build_ast_json_and_chunks(new_tree, new_src, old_chunks) + + # Unchanged chunk: identical compressed bytes reused + self.assertEqual(old_chunks["FunctionDef:foo"].ast_json_z, new_chunks["FunctionDef:foo"].ast_json_z) + # Changed chunk: different bytes + self.assertNotEqual( + old_chunks["FunctionDef:bar"].ast_json_z, + new_chunks["FunctionDef:bar"].ast_json_z, + ) + + def test_moved_chunk_not_reused(self): + src = "def foo(): pass\ndef bar(): pass\n" + _, old_chunks = self._build(src) + + # Insert a line at top → foo shifts to line 2 + new_src = "x = 1\ndef foo(): pass\ndef bar(): pass\n" + new_tree = ast.parse(new_src) + _, new_chunks = _build_ast_json_and_chunks(new_tree, new_src, old_chunks) + + # foo moved from line 1 → 2: must NOT reuse + self.assertNotEqual( + old_chunks["FunctionDef:foo"].ast_json_z, + new_chunks["FunctionDef:foo"].ast_json_z, + ) + + +# ── TestAssembleModuleJson ──────────────────────────────────────────────────── + + +class TestAssembleModuleJson(unittest.TestCase): + def test_non_empty_body_goes_to_children(self): + body = ['{"node_type": "Assign"}'] + result = _parse_json(_assemble_module_json(body, [])) + self.assertIn("body", result["children"]) + self.assertNotIn("body", result["fields"]) + + def test_empty_body_goes_to_fields(self): + result = _parse_json(_assemble_module_json([], [])) + self.assertIn("body", result["fields"]) + self.assertEqual(result["fields"]["body"], []) + + def test_non_empty_type_ignores_goes_to_children(self): + ti = ['{"node_type": "TypeIgnore"}'] + result = _parse_json(_assemble_module_json([], ti)) + 
self.assertIn("type_ignores", result["children"]) + + def test_empty_type_ignores_goes_to_fields(self): + result = _parse_json(_assemble_module_json([], [])) + self.assertIn("type_ignores", result["fields"]) + self.assertEqual(result["fields"]["type_ignores"], []) + + def test_module_metadata(self): + result = _parse_json(_assemble_module_json([], [])) + self.assertEqual(result["node_type"], "Module") + self.assertEqual(result["lineno"], -1) + self.assertEqual(result["col_offset"], -1) + + def test_output_is_valid_json(self): + body = ['{"node_type": "Expr", "lineno": 1}'] + json.loads(_assemble_module_json(body, [])) # must not raise + + +# ── TestSerializeDeserialize ────────────────────────────────────────────────── + + +class TestSerializeDeserialize(unittest.TestCase): + def _make_entry(self) -> FileCacheEntry: + src = "def foo(): pass\n" + tree = ast.parse(src) + full_json, chunks = _build_ast_json_and_chunks(tree, src, {}) + return FileCacheEntry( + file_path="/tmp/test_file.py", + file_hash=hashlib.sha256(src.encode()).hexdigest(), + mtime=1234567890.0, + full_ast_json_z=zlib.compress(full_json.encode()), + chunks=chunks, + ) + + def test_roundtrip(self): + entry = self._make_entry() + restored = _deserialize_entry(_serialize_entry(entry)) + self.assertEqual(restored.file_path, entry.file_path) + self.assertEqual(restored.file_hash, entry.file_hash) + self.assertEqual(restored.mtime, entry.mtime) + self.assertEqual(restored.full_ast_json_z, entry.full_ast_json_z) + self.assertEqual(restored.version, entry.version) + self.assertEqual(set(restored.chunks.keys()), set(entry.chunks.keys())) + + def test_serialized_is_json_not_pickle(self): + entry = self._make_entry() + raw = _serialize_entry(entry) + # Must be valid JSON + d = json.loads(raw) + self.assertIn("version", d) + self.assertIn("file_hash", d) + # Must NOT be a pickle stream (pickle starts with 0x80 or b'\x80') + self.assertFalse(raw.encode()[0:1] == b'\x80') + # Must start with '{' (JSON object) + 
self.assertEqual(raw[0], '{') + + def test_deserialize_raises_on_garbage(self): + with self.assertRaises(Exception): + _deserialize_entry("not json at all }{") + + +# ── TestIncrementalAstCache ─────────────────────────────────────────────────── + + +class TestIncrementalAstCache(unittest.TestCase): + def setUp(self): + self._tmpdir = tempfile.TemporaryDirectory() + self.tmp = Path(self._tmpdir.name) + + def tearDown(self): + self._tmpdir.cleanup() + + def _write(self, name: str, content: str) -> Path: + p = self.tmp / name + p.write_text(content, encoding="utf-8") + return p + + def _l1_key(self, p: Path) -> str: + """Return the L1 dict key for a path (always the resolved form).""" + return str(p.resolve()) + + # ── L1 mtime hit ────────────────────────────────────────────────────── + + def test_l1_mtime_hit_skips_hash(self): + cache = _make_cache(self.tmp) + src = "def foo(): pass\n" + p = self._write("a.py", src) + + cache.get_ast_json(p, src) # populate L1 + + with patch("pyspector.ast_cache.hashlib") as mock_hash: + cache.get_ast_json(p, src) # same mtime → must not hash + mock_hash.sha256.assert_not_called() + + # ── L1 hash hit ─────────────────────────────────────────────────────── + + def test_l1_hash_hit_updates_mtime(self): + cache = _make_cache(self.tmp) + src = "x = 1\n" + p = self._write("b.py", src) + + cache.get_ast_json(p, src) + entry_before = cache._l1[self._l1_key(p)] + old_mtime = entry_before.mtime + + # Touch the file (change mtime without changing content) + os.utime(p, (old_mtime + 1, old_mtime + 1)) + cache.get_ast_json(p, src) + + entry_after = cache._l1[self._l1_key(p)] + self.assertNotEqual(entry_after.mtime, old_mtime) + # Same bytes object (shallow copy via dataclasses.replace — no rebuild) + self.assertIs(entry_before.full_ast_json_z, entry_after.full_ast_json_z) + + # ── L2 disk hit ─────────────────────────────────────────────────────── + + def test_l2_disk_survives_l1_eviction(self): + cache = _make_cache(self.tmp) + src = "def 
saved(): pass\n" + p = self._write("c.py", src) + + cache.get_ast_json(p, src) # write to disk + cache._l1.clear() # evict L1 + + ast_json = cache.get_ast_json(p, src) # must load from disk + self.assertEqual(_parse_json(ast_json)["node_type"], "Module") + + def test_l2_stale_on_content_change(self): + cache = _make_cache(self.tmp) + src_v1 = "def foo(): pass\n" + src_v2 = "def foo(): return 1\n" + p = self._write("d.py", src_v1) + + cache.get_ast_json(p, src_v1) + cache._l1.clear() + + p.write_text(src_v2, encoding="utf-8") + ast_json = cache.get_ast_json(p, src_v2) + func = _parse_json(ast_json)["children"]["body"][0] + self.assertEqual(func["node_type"], "FunctionDef") + + # ── Cache version invalidation ──────────────────────────────────────── + + def test_stale_version_triggers_rebuild(self): + cache = _make_cache(self.tmp) + src = "x = 42\n" + p = self._write("e.py", src) + cache.get_ast_json(p, src) + + disk_p = cache._disk_path(p) + assert disk_p is not None and disk_p.exists() + + # Tamper with version in the JSON cache file + data = json.loads(disk_p.read_text(encoding="utf-8")) + data["version"] = 0 + disk_p.write_text(json.dumps(data), encoding="utf-8") + + cache._l1.clear() + ast_json = cache.get_ast_json(p, src) + self.assertIn("Module", ast_json) + + # ── SyntaxError propagation ─────────────────────────────────────────── + + def test_syntax_error_propagates(self): + cache = _make_cache(self.tmp) + p = self._write("bad.py", "def (: pass\n") + with self.assertRaises(SyntaxError): + cache.get_ast_json(p, "def (: pass\n") + + def test_syntax_error_not_cached(self): + cache = _make_cache(self.tmp) + p = self._write("bad2.py", "def (: pass\n") + try: + cache.get_ast_json(p, "def (: pass\n") + except SyntaxError: + pass + self.assertNotIn(self._l1_key(p), cache._l1) + + # ── invalidate() ────────────────────────────────────────────────────── + + def test_invalidate_clears_l1_and_disk(self): + cache = _make_cache(self.tmp) + src = "y = 7\n" + p = 
self._write("f.py", src) + cache.get_ast_json(p, src) + + disk_p = cache._disk_path(p) + assert disk_p is not None and disk_p.exists() + + cache.invalidate(p) + self.assertNotIn(self._l1_key(p), cache._l1) + self.assertFalse(disk_p.exists()) + + # ── get_changed_chunks() ────────────────────────────────────────────── + + def test_get_changed_chunks_detects_modification(self): + cache = _make_cache(self.tmp) + p = self.tmp / "g.py" + old = "def foo(): pass\ndef bar(): pass\n" + new = "def foo(): return 1\ndef bar(): pass\n" + changed = cache.get_changed_chunks(p, old, new) + self.assertIn("FunctionDef:foo", changed) + self.assertNotIn("FunctionDef:bar", changed) + + def test_get_changed_chunks_detects_addition(self): + cache = _make_cache(self.tmp) + p = self.tmp / "h.py" + old = "def foo(): pass\n" + new = "def foo(): pass\ndef baz(): pass\n" + changed = cache.get_changed_chunks(p, old, new) + self.assertIn("FunctionDef:baz", changed) + + def test_get_changed_chunks_detects_deletion(self): + cache = _make_cache(self.tmp) + p = self.tmp / "i.py" + old = "def foo(): pass\ndef bar(): pass\n" + new = "def foo(): pass\n" + changed = cache.get_changed_chunks(p, old, new) + self.assertIn("FunctionDef:bar", changed) + + # ── No-disk-cache mode ──────────────────────────────────────────────── + + def test_works_without_cache_dir(self): + cache = IncrementalAstCache(cache_dir=None) + src = "z = 99\n" + p = self._write("j.py", src) + self.assertIn("Module", cache.get_ast_json(p, src)) + + # ── Output format ───────────────────────────────────────────────────── + + def test_output_is_valid_json(self): + cache = _make_cache(self.tmp) + src = "import os\n\ndef greet(name: str) -> str:\n return f'hello {name}'\n" + p = self._write("k.py", src) + parsed = _parse_json(cache.get_ast_json(p, src)) + self.assertEqual(parsed["node_type"], "Module") + + def test_output_matches_direct_encode(self): + """Cache output must be semantically identical to direct AstEncoder output.""" + cache = 
_make_cache(self.tmp) + src = "x = 1\n\nclass Foo:\n def method(self): pass\n" + p = self._write("l.py", src) + + cached = _parse_json(cache.get_ast_json(p, src)) + direct = _parse_json(json.dumps(ast.parse(src), cls=AstEncoder)) + self.assertEqual(cached, direct) + + # ── Security: no pickle in disk cache ───────────────────────────────── + + def test_disk_cache_uses_json_not_pickle(self): + """Disk cache must store JSON, not pickle (no arbitrary code execution).""" + cache = _make_cache(self.tmp) + src = "def secure(): pass\n" + p = self._write("sec.py", src) + cache.get_ast_json(p, src) + + disk_p = cache._disk_path(p) + assert disk_p is not None and disk_p.exists() + + raw = disk_p.read_bytes() + # JSON object starts with '{' + self.assertEqual(raw[0:1], b"{") + # Must be parseable as JSON + data = json.loads(raw.decode("utf-8")) + self.assertIn("version", data) + self.assertIn("file_hash", data) + self.assertIn("chunks", data) + # Must NOT be a pickle stream (pickle magic bytes 0x80) + self.assertNotEqual(raw[0:1], b"\x80") + + def test_disk_cache_file_extension_is_json(self): + cache = _make_cache(self.tmp) + p = self._write("ext.py", "x = 1\n") + disk_p = cache._disk_path(p) + assert disk_p is not None + self.assertEqual(disk_p.suffix, ".json") + + def test_corrupted_cache_recovers_gracefully(self): + """A corrupted JSON cache file must be discarded and rebuilt without error.""" + cache = _make_cache(self.tmp) + src = "x = 1\n" + p = self._write("corrupt.py", src) + cache.get_ast_json(p, src) + + disk_p = cache._disk_path(p) + assert disk_p is not None + disk_p.write_text("}{not valid json", encoding="utf-8") + + cache._l1.clear() + ast_json = cache.get_ast_json(p, src) + self.assertIn("Module", ast_json) + # File must be rebuilt after recovery + self.assertTrue(disk_p.exists()) + self.assertEqual(json.loads(disk_p.read_text(encoding="utf-8"))["version"], CACHE_VERSION) + + # ── Path canonicalization ───────────────────────────────────────────── + + def 
test_resolved_path_used_as_l1_key(self): + """The L1 key must always be the resolved (canonical) path.""" + cache = _make_cache(self.tmp) + src = "def foo(): pass\n" + p = self._write("canon.py", src) + + cache.get_ast_json(p, src) + + # Key in L1 must be the resolved form + self.assertIn(str(p.resolve()), cache._l1) + + def test_same_file_via_resolve_hits_same_entry(self): + """Calling get_ast_json with an already-resolved path must hit L1.""" + cache = _make_cache(self.tmp) + src = "x = 1\n" + p = self._write("res.py", src) + + cache.get_ast_json(p, src) + initial_len = len(cache._l1) + + # Call again with the resolved path — must not create a second entry + cache.get_ast_json(p.resolve(), src) + self.assertEqual(len(cache._l1), initial_len) + + # ── L1 LRU eviction ─────────────────────────────────────────────────── + + def test_l1_lru_eviction(self): + """L1 must evict LRU entries when max_l1_entries is exceeded.""" + cache = _make_cache(self.tmp, max_l1=2) + + files = [] + for i in range(3): + src = f"def f{i}(): pass\n" + p = self._write(f"lru_{i}.py", src) + files.append(p) + cache.get_ast_json(p, src) + + self.assertEqual(len(cache._l1), 2) + # Most recently used entries should be present + self.assertIn(str(files[2].resolve()), cache._l1) + self.assertIn(str(files[1].resolve()), cache._l1) + # Oldest entry should have been evicted + self.assertNotIn(str(files[0].resolve()), cache._l1) + + def test_l1_lru_access_updates_recency(self): + """Accessing an entry should protect it from eviction.""" + cache = _make_cache(self.tmp, max_l1=2) + + files = [] + for i in range(2): + src = f"def f{i}(): pass\n" + p = self._write(f"lru_rec_{i}.py", src) + files.append(p) + cache.get_ast_json(p, src) + + # Access the first file to make it the most-recently-used + cache.get_ast_json(files[0], f"def f0(): pass\n") + + # Add a third file, which should evict files[1] (LRU), not files[0] + src2 = "def f2(): pass\n" + p2 = self._write("lru_rec_2.py", src2) + 
cache.get_ast_json(p2, src2) + + self.assertIn(str(files[0].resolve()), cache._l1) + self.assertNotIn(str(files[1].resolve()), cache._l1) + + # ── mkdir failure → graceful degradation ────────────────────────────── + + def test_mkdir_failure_degrades_to_no_disk(self): + """If the cache directory cannot be created, the cache runs L1-only.""" + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + with patch("pathlib.Path.mkdir", side_effect=OSError("Permission denied")): + cache = IncrementalAstCache(cache_dir=Path("/fake/no/permission")) + + self.assertIsNone(cache._cache_dir) + self.assertTrue(any("cache directory" in str(w.message) for w in caught)) + + # L1-only mode must still work + src = "x = 1\n" + p = self._write("fallback.py", src) + self.assertIn("Module", cache.get_ast_json(p, src)) + + # ── Disk write failure is non-blocking ──────────────────────────────── + + def test_disk_write_failure_does_not_crash(self): + """A disk write failure must issue a warning but not abort the scan.""" + cache = _make_cache(self.tmp) + src = "def resilient(): pass\n" + p = self._write("write_fail.py", src) + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + with patch("pathlib.Path.write_text", side_effect=OSError("disk full")): + ast_json = cache.get_ast_json(p, src) + + self.assertIn("Module", ast_json) + # A warning must have been issued + self.assertTrue(any("cache" in str(w.message).lower() for w in caught)) + + # ── Encoder parity ──────────────────────────────────────────────────── + + def test_cache_output_matches_ast_encoder(self): + """The cache path and the direct AstEncoder path must be identical.""" + cache = _make_cache(self.tmp) + src = ( + "import os\n\n" + "CONST = 42\n\n" + "class Processor:\n" + " def run(self, data: list) -> dict:\n" + " return {str(i): v for i, v in enumerate(data)}\n" + ) + p = self._write("parity.py", src) + + cached = _parse_json(cache.get_ast_json(p, src)) 
+ direct = _parse_json(json.dumps(ast.parse(src), cls=AstEncoder)) + self.assertEqual(cached, direct) + + def test_encode_node_matches_ast_encoder_for_single_node(self): + """encode_node() must produce the same output as json.dumps(..., cls=AstEncoder).""" + src = "def foo(x: int) -> str: return str(x)\n" + tree = ast.parse(src) + node = tree.body[0] + via_encode_node = json.loads(encode_node(node)) + via_encoder = json.loads(json.dumps(node, cls=AstEncoder)) + self.assertEqual(via_encode_node, via_encoder) + + # ── Large file smoke test ───────────────────────────────────────────── + + def test_large_file_smoke(self): + """Cache must handle files with many top-level functions without error.""" + src = "\n".join(f"def func_{i}(x): return x + {i}" for i in range(200)) + p = self._write("large.py", src) + cache = _make_cache(self.tmp) + + ast_json = cache.get_ast_json(p, src) + parsed = _parse_json(ast_json) + self.assertEqual(parsed["node_type"], "Module") + self.assertEqual(len(parsed["children"]["body"]), 200) + + # ── Singleton ───────────────────────────────────────────────────────── + + def test_singleton_same_instance(self): + _reset_cache_singleton() + c1 = get_cache() + c2 = get_cache() + self.assertIs(c1, c2) + _reset_cache_singleton() + + def test_singleton_reset_yields_new_instance(self): + _reset_cache_singleton() + c1 = get_cache() + _reset_cache_singleton() + c2 = get_cache() + self.assertIsNot(c1, c2) + _reset_cache_singleton() + + # ── Frozen dataclass safety ─────────────────────────────────────────── + + def test_file_cache_entry_is_immutable(self): + """FileCacheEntry must be frozen so callers cannot mutate shared state.""" + import dataclasses + self.assertTrue(dataclasses.fields(FileCacheEntry)) + entry = FileCacheEntry( + file_path="/x", + file_hash="abc", + mtime=1.0, + full_ast_json_z=b"", + chunks={}, + ) + with self.assertRaises((dataclasses.FrozenInstanceError, AttributeError)): + entry.mtime = 2.0 # type: ignore[misc] + + def 
test_ast_chunk_is_immutable(self): + import dataclasses + chunk = AstChunk( + chunk_id="FunctionDef:foo", + start_line=1, + end_line=3, + content_hash="abc", + ast_json_z=b"", + ) + with self.assertRaises((dataclasses.FrozenInstanceError, AttributeError)): + chunk.start_line = 99 # type: ignore[misc] + + # ── CACHE_VERSION stored in disk file ───────────────────────────────── + + def test_cache_version_stored_in_disk_file(self): + cache = _make_cache(self.tmp) + p = self._write("ver.py", "x = 1\n") + cache.get_ast_json(p, "x = 1\n") + + disk_p = cache._disk_path(p) + assert disk_p is not None + data = json.loads(disk_p.read_text(encoding="utf-8")) + self.assertEqual(data["version"], CACHE_VERSION) + + +if __name__ == "__main__": + unittest.main()