From 67c60bb223e684bf24b8902c090b657f38c1119b Mon Sep 17 00:00:00 2001 From: Sohil Kshirsagar Date: Wed, 1 Apr 2026 00:45:35 -0700 Subject: [PATCH 01/24] feat: add coverage snapshot server for code coverage POC When TUSK_COVERAGE_PORT env var is set, the SDK starts a tiny HTTP server that manages coverage.py. On each /snapshot request: - Stop coverage, get data, erase (reset), restart - Returns per-file line counts with clean per-test data - No diffing needed (coverage.py supports stop/erase/start cycle) Works with Flask, FastAPI, Django, gunicorn, uvicorn - any framework, because the SDK runs inside the app process. Requires: pip install coverage (or pip install tusk-drift[coverage]) --- drift/core/coverage_server.py | 135 ++++++++++++++++++++++++++++++++++ drift/core/drift_sdk.py | 4 + 2 files changed, 139 insertions(+) create mode 100644 drift/core/coverage_server.py diff --git a/drift/core/coverage_server.py b/drift/core/coverage_server.py new file mode 100644 index 0000000..3d95daf --- /dev/null +++ b/drift/core/coverage_server.py @@ -0,0 +1,135 @@ +"""Coverage snapshot HTTP server for Python SDK. + +When TUSK_COVERAGE_PORT is set, starts a tiny HTTP server that manages +coverage.py. On each /snapshot request: +1. Stop coverage collection +2. Get coverage data (which lines were executed since last snapshot) +3. Erase coverage data (reset for next test) +4. Restart coverage collection +5. Return per-file line counts as JSON + +This gives clean per-test coverage data - no diffing needed. +""" + +from __future__ import annotations + +import json +import logging +import os +import threading +from http.server import HTTPServer, BaseHTTPRequestHandler + +logger = logging.getLogger("TuskDrift") + + +class CoverageSnapshotHandler(BaseHTTPRequestHandler): + """HTTP handler for coverage snapshot requests.""" + + # Shared state set by start_coverage_server + cov_instance = None + source_root = None + + def do_GET(self): + if self.path == "/snapshot": + self._handle_snapshot() + else: + self.send_response(404) + self.end_headers() + + def _handle_snapshot(self): + try: + cov = self.__class__.cov_instance + source_root = self.__class__.source_root + + if cov is None: + self.send_response(500) + self.send_header("Content-Type", "application/json") + self.end_headers() + self.wfile.write(json.dumps({"ok": False, "error": "coverage not initialized"}).encode()) + return + + # Stop coverage, get data, erase (reset), restart + cov.stop() + data = cov.get_data() + + # Extract per-file line counts + coverage = {} + for filename in data.measured_files(): + # Filter to user source files + if "site-packages" in filename or "lib/python" in filename: + continue + if source_root and not filename.startswith(source_root): + continue + + lines = data.lines(filename) + if lines: + # Convert to { "lineNumber": 1 } format (1 = covered) + coverage[filename] = {str(line): 1 for line in lines} + + # Erase data and restart for next test + cov.erase() + cov.start() + + self.send_response(200) + self.send_header("Content-Type", "application/json") + self.end_headers() + self.wfile.write(json.dumps({"ok": True, "coverage": coverage}).encode()) + + except Exception as e: + self.send_response(500) + self.send_header("Content-Type", "application/json") + self.end_headers() + self.wfile.write(json.dumps({"ok": False, "error": str(e)}).encode()) + + def log_message(self, format, *args): + """Suppress default HTTP server logging.""" + pass + + +def start_coverage_server(port: int | None = None) -> bool: + """Start the coverage snapshot server if TUSK_COVERAGE_PORT is set. + + Returns True if the server was started, False otherwise. + """ + port_str = os.environ.get("TUSK_COVERAGE_PORT") + if not port_str and port is None: + return False + + actual_port = port or int(port_str) + + # Try to import coverage + try: + import coverage as coverage_module + except ImportError: + logger.warning( + "TUSK_COVERAGE_PORT is set but 'coverage' package is not installed. " + "Install it with: pip install coverage" + ) + return False + + source_root = os.getcwd() + + # Start coverage collection + cov = coverage_module.Coverage( + source=[source_root], + omit=[ + "*/site-packages/*", + "*/venv/*", + "*/.venv/*", + "*/test*", + "*/__pycache__/*", + ], + ) + cov.start() + + # Set shared state on the handler class + CoverageSnapshotHandler.cov_instance = cov + CoverageSnapshotHandler.source_root = source_root + + # Start HTTP server in a daemon thread + http_server = HTTPServer(("127.0.0.1", actual_port), CoverageSnapshotHandler) + thread = threading.Thread(target=http_server.serve_forever, daemon=True) + thread.start() + + logger.info(f"Coverage snapshot server listening on port {actual_port}") + return True diff --git a/drift/core/drift_sdk.py b/drift/core/drift_sdk.py index 67927f4..ffacf0b 100644 --- a/drift/core/drift_sdk.py +++ b/drift/core/drift_sdk.py @@ -160,6 +160,10 @@ def initialize( configure_logger(log_level=log_level, prefix="TuskDrift") + # Start coverage server early (before any SDK mode checks that might return early) + from .coverage_server import start_coverage_server + start_coverage_server() + instance._init_params = { "api_key": api_key, "env": env, From 3261a400396f4f7bdca1d5eb2e968c5099294b80 Mon Sep 17 00:00:00 2001 From: Sohil Kshirsagar Date: Wed, 1 Apr 2026 01:16:34 -0700 Subject: [PATCH 02/24] feat: add ?baseline=true parameter using coverage.py analysis2 When /snapshot?baseline=true is called, uses coverage.analysis2() to get ALL coverable statements (including uncovered) for the denominator. Regular /snapshot calls only return executed lines (for per-test data). --- drift/core/coverage_server.py | 57 ++++++++++++++++++++++++----------- 1 file changed, 40 insertions(+), 17 deletions(-) diff --git a/drift/core/coverage_server.py b/drift/core/coverage_server.py index 3d95daf..4e3351b 100644 --- a/drift/core/coverage_server.py +++ b/drift/core/coverage_server.py @@ -30,13 +30,17 @@ class CoverageSnapshotHandler(BaseHTTPRequestHandler): source_root = None def do_GET(self): - if self.path == "/snapshot": - self._handle_snapshot() + from urllib.parse import urlparse, parse_qs + parsed = urlparse(self.path) + if parsed.path == "/snapshot": + params = parse_qs(parsed.query) + is_baseline = params.get("baseline", ["false"])[0] == "true" + self._handle_snapshot(is_baseline) else: self.send_response(404) self.end_headers() - def _handle_snapshot(self): + def _handle_snapshot(self, is_baseline: bool = False): try: cov = self.__class__.cov_instance source_root = self.__class__.source_root @@ -48,23 +52,42 @@ def _handle_snapshot(self): self.wfile.write(json.dumps({"ok": False, "error": "coverage not initialized"}).encode()) return - # Stop coverage, get data, erase (reset), restart + # Stop coverage to read data cov.stop() - data = cov.get_data() - # Extract per-file line counts coverage = {} - for filename in data.measured_files(): - # Filter to user source files - if "site-packages" in filename or "lib/python" in filename: - continue - if source_root and not filename.startswith(source_root): - continue - - lines = data.lines(filename) - if lines: - # Convert to { "lineNumber": 1 } format (1 = covered) - coverage[filename] = {str(line): 1 for line in lines} + + if is_baseline: + # Baseline: return ALL coverable lines (including uncovered at count=0) + # This provides the denominator for coverage percentage. + # analysis2() returns (filename, statements, excluded, missing, formatted) + data = cov.get_data() + for filename in data.measured_files(): + if "site-packages" in filename or "lib/python" in filename: + continue + if source_root and not filename.startswith(source_root): + continue + try: + _, statements, _, missing, _ = cov.analysis2(filename) + missing_set = set(missing) + lines_map = {} + for line in statements: + lines_map[str(line)] = 0 if line in missing_set else 1 + if lines_map: + coverage[filename] = lines_map + except Exception: + continue + else: + # Regular snapshot: only executed lines since last reset + data = cov.get_data() + for filename in data.measured_files(): + if "site-packages" in filename or "lib/python" in filename: + continue + if source_root and not filename.startswith(source_root): + continue + lines = data.lines(filename) + if lines: + coverage[filename] = {str(line): 1 for line in lines} # Erase data and restart for next test cov.erase() From 2a546647b7c295d46c82435914d8c0b5a246c07d Mon Sep 17 00:00:00 2001 From: Sohil Kshirsagar Date: Wed, 1 Apr 2026 01:22:39 -0700 Subject: [PATCH 03/24] chore: add thread safety lock and clean shutdown for coverage server - Threading lock protects stop/get_data/erase/start sequence - stop_coverage_server() for clean shutdown, integrated into SDK shutdown() - Module-level server reference for proper cleanup --- drift/core/coverage_server.py | 116 +++++++++++++++++++--------------- drift/core/drift_sdk.py | 4 ++ 2 files changed, 70 insertions(+), 50 deletions(-) diff --git a/drift/core/coverage_server.py b/drift/core/coverage_server.py index 4e3351b..df4119d 100644 --- a/drift/core/coverage_server.py +++ b/drift/core/coverage_server.py @@ -28,6 +28,7 @@ class CoverageSnapshotHandler(BaseHTTPRequestHandler): # Shared state set by start_coverage_server cov_instance = None source_root = None + _lock = threading.Lock() def do_GET(self): from urllib.parse import urlparse, parse_qs @@ -42,56 +43,57 @@ def do_GET(self): def _handle_snapshot(self, is_baseline: bool = False): try: - cov = self.__class__.cov_instance - source_root = self.__class__.source_root - - if cov is None: - self.send_response(500) - self.send_header("Content-Type", "application/json") - self.end_headers() - self.wfile.write(json.dumps({"ok": False, "error": "coverage not initialized"}).encode()) - return - - # Stop coverage to read data - cov.stop() - - coverage = {} - - if is_baseline: - # Baseline: return ALL coverable lines (including uncovered at count=0) - # This provides the denominator for coverage percentage. - # analysis2() returns (filename, statements, excluded, missing, formatted) - data = cov.get_data() - for filename in data.measured_files(): - if "site-packages" in filename or "lib/python" in filename: - continue - if source_root and not filename.startswith(source_root): - continue - try: - _, statements, _, missing, _ = cov.analysis2(filename) - missing_set = set(missing) - lines_map = {} - for line in statements: - lines_map[str(line)] = 0 if line in missing_set else 1 - if lines_map: - coverage[filename] = lines_map - except Exception: - continue - else: - # Regular snapshot: only executed lines since last reset - data = cov.get_data() - for filename in data.measured_files(): - if "site-packages" in filename or "lib/python" in filename: - continue - if source_root and not filename.startswith(source_root): - continue - lines = data.lines(filename) - if lines: - coverage[filename] = {str(line): 1 for line in lines} - - # Erase data and restart for next test - cov.erase() - cov.start() + with self.__class__._lock: + cov = self.__class__.cov_instance + source_root = self.__class__.source_root + + if cov is None: + self.send_response(500) + self.send_header("Content-Type", "application/json") + self.end_headers() + self.wfile.write(json.dumps({"ok": False, "error": "coverage not initialized"}).encode()) + return + + # Stop coverage to read data + cov.stop() + + coverage = {} + + if is_baseline: + # Baseline: return ALL coverable lines (including uncovered at count=0) + # This provides the denominator for coverage percentage. + # analysis2() returns (filename, statements, excluded, missing, formatted) + data = cov.get_data() + for filename in data.measured_files(): + if "site-packages" in filename or "lib/python" in filename: + continue + if source_root and not filename.startswith(source_root): + continue + try: + _, statements, _, missing, _ = cov.analysis2(filename) + missing_set = set(missing) + lines_map = {} + for line in statements: + lines_map[str(line)] = 0 if line in missing_set else 1 + if lines_map: + coverage[filename] = lines_map + except Exception: + continue + else: + # Regular snapshot: only executed lines since last reset + data = cov.get_data() + for filename in data.measured_files(): + if "site-packages" in filename or "lib/python" in filename: + continue + if source_root and not filename.startswith(source_root): + continue + lines = data.lines(filename) + if lines: + coverage[filename] = {str(line): 1 for line in lines} + + # Erase data and restart for next test + cov.erase() + cov.start() self.send_response(200) self.send_header("Content-Type", "application/json") @@ -109,11 +111,16 @@ def log_message(self, format, *args): pass +_coverage_server: HTTPServer | None = None + + def start_coverage_server(port: int | None = None) -> bool: """Start the coverage snapshot server if TUSK_COVERAGE_PORT is set. Returns True if the server was started, False otherwise. """ + global _coverage_server + port_str = os.environ.get("TUSK_COVERAGE_PORT") if not port_str and port is None: return False @@ -151,8 +158,17 @@ def start_coverage_server(port: int | None = None) -> bool: # Start HTTP server in a daemon thread http_server = HTTPServer(("127.0.0.1", actual_port), CoverageSnapshotHandler) + _coverage_server = http_server thread = threading.Thread(target=http_server.serve_forever, daemon=True) thread.start() logger.info(f"Coverage snapshot server listening on port {actual_port}") return True + + +def stop_coverage_server() -> None: + """Shut down the coverage snapshot server if running.""" + global _coverage_server + if _coverage_server is not None: + _coverage_server.shutdown() + _coverage_server = None diff --git a/drift/core/drift_sdk.py b/drift/core/drift_sdk.py index ffacf0b..f8ad573 100644 --- a/drift/core/drift_sdk.py +++ b/drift/core/drift_sdk.py @@ -831,6 +831,8 @@ def shutdown(self) -> None: """Shutdown the SDK.""" import asyncio + from .coverage_server import stop_coverage_server + # Shutdown OpenTelemetry tracer provider if self._td_span_processor is not None: self._td_span_processor.shutdown() @@ -851,3 +853,5 @@ def shutdown(self) -> None: TraceBlockingManager.get_instance().shutdown() except Exception as e: logger.error(f"Error shutting down trace blocking manager: {e}") + + stop_coverage_server() From da89503f9788ddbd050ffa4eed8e2c8310debe12 Mon Sep 17 00:00:00 2001 From: Sohil Kshirsagar Date: Wed, 1 Apr 2026 17:33:41 -0700 Subject: [PATCH 04/24] feat: add branch coverage tracking - Enable branch=True in coverage.py initialization - Extract branch data via cov._analyze(filename) API: - numbers.n_branches, n_missing_branches for totals - missing_branch_arcs() for per-line branch detail - Return branch data in /snapshot response alongside line coverage - Python shows accurate branch coverage (93.3% for demo app) --- drift/core/coverage_server.py | 71 +++++++++++++++++++++++++++++++++-- 1 file changed, 67 insertions(+), 4 deletions(-) diff --git a/drift/core/coverage_server.py b/drift/core/coverage_server.py index df4119d..709c14d 100644 --- a/drift/core/coverage_server.py +++ b/drift/core/coverage_server.py @@ -61,8 +61,7 @@ def _handle_snapshot(self, is_baseline: bool = False): if is_baseline: # Baseline: return ALL coverable lines (including uncovered at count=0) - # This provides the denominator for coverage percentage. - # analysis2() returns (filename, statements, excluded, missing, formatted) + # plus branch coverage data. data = cov.get_data() for filename in data.measured_files(): if "site-packages" in filename or "lib/python" in filename: @@ -75,8 +74,15 @@ def _handle_snapshot(self, is_baseline: bool = False): lines_map = {} for line in statements: lines_map[str(line)] = 0 if line in missing_set else 1 + + # Branch data from coverage.py + branch_data = _get_branch_data(cov, data, filename) + if lines_map: - coverage[filename] = lines_map + coverage[filename] = { + "lines": lines_map, + **branch_data, + } except Exception: continue else: @@ -89,7 +95,11 @@ def _handle_snapshot(self, is_baseline: bool = False): continue lines = data.lines(filename) if lines: - coverage[filename] = {str(line): 1 for line in lines} + branch_data = _get_branch_data(cov, data, filename) + coverage[filename] = { + "lines": {str(line): 1 for line in lines}, + **branch_data, + } # Erase data and restart for next test cov.erase() @@ -111,6 +121,58 @@ def log_message(self, format, *args): pass +def _get_branch_data(cov, data, filename: str) -> dict: + """Extract branch coverage data for a file. + + Returns dict with totalBranches, coveredBranches, and per-line branch detail. + Uses coverage.py's analysis API which tracks branches as arcs (from_line, to_line). + """ + try: + if not data.has_arcs(): + return {"totalBranches": 0, "coveredBranches": 0, "branches": {}} + + # Use internal _analyze for full branch analysis + analysis = cov._analyze(filename) + numbers = analysis.numbers + + total_branches = numbers.n_branches + covered_branches = total_branches - numbers.n_missing_branches + + # Get per-line branch detail from missing_branch_arcs + missing_arcs = analysis.missing_branch_arcs() + executed_arcs = set(data.arcs(filename) or []) + + # Build per-line branch info + # Collect all branch source lines from both executed and missing + branch_lines: dict[int, dict] = {} # from_line -> {total, covered} + + # Count executed arcs by source line + for from_line, to_line in executed_arcs: + if from_line < 0: # negative = entry/exit arcs, skip + continue + if from_line not in branch_lines: + branch_lines[from_line] = {"total": 0, "covered": 0} + branch_lines[from_line]["total"] += 1 + branch_lines[from_line]["covered"] += 1 + + # Count missing arcs by source line + for from_line, to_lines in missing_arcs.items(): + if from_line not in branch_lines: + branch_lines[from_line] = {"total": 0, "covered": 0} + branch_lines[from_line]["total"] += len(to_lines) + + # Convert to string keys + branches = {str(line): info for line, info in branch_lines.items()} + + return { + "totalBranches": total_branches, + "coveredBranches": covered_branches, + "branches": branches, + } + except Exception: + return {"totalBranches": 0, "coveredBranches": 0, "branches": {}} + + _coverage_server: HTTPServer | None = None @@ -142,6 +204,7 @@ def start_coverage_server(port: int | None = None) -> bool: # Start coverage collection cov = coverage_module.Coverage( source=[source_root], + branch=True, # Enable branch coverage tracking omit=[ "*/site-packages/*", "*/venv/*", From 020e2afe04839e424f85bf99c55ff7ecc2db9cfb Mon Sep 17 00:00:00 2001 From: Sohil Kshirsagar Date: Wed, 1 Apr 2026 19:19:28 -0700 Subject: [PATCH 05/24] wip: migrate coverage to protobuf channel (Python handler timing out - needs debug) --- drift/core/communication/communicator.py | 58 +++++++++++++ drift/core/communication/types.py | 3 + drift/core/coverage_server.py | 101 +++++++++++++++++++---- drift/core/drift_sdk.py | 7 +- 4 files changed, 152 insertions(+), 17 deletions(-) diff --git a/drift/core/communication/communicator.py b/drift/core/communication/communicator.py index 6ae4361..e6afebc 100644 --- a/drift/core/communication/communicator.py +++ b/drift/core/communication/communicator.py @@ -16,8 +16,11 @@ from ..span_serialization import clean_span_to_proto from ..types import CleanSpanData, calling_library_context from .types import ( + BranchInfo, CliMessage, ConnectRequest, + CoverageSnapshotResponse, + FileCoverageData, GetMockRequest, InstrumentationVersionMismatchAlert, MessageType, @@ -750,6 +753,10 @@ def _background_read_loop(self) -> None: self._handle_set_time_travel_sync(cli_message) continue + if cli_message.type == MessageType.COVERAGE_SNAPSHOT: + self._handle_coverage_snapshot_sync(cli_message) + continue + # Route responses to waiting callers by request_id request_id = cli_message.request_id if request_id: @@ -809,6 +816,57 @@ def _handle_set_time_travel_sync(self, cli_message: CliMessage) -> None: except Exception as e: logger.error(f"Failed to send SetTimeTravel response: {e}") + def _handle_coverage_snapshot_sync(self, cli_message: CliMessage) -> None: + """Handle CoverageSnapshot request from CLI and send response.""" + request = cli_message.coverage_snapshot_request + if not request: + return + + logger.debug(f"Received CoverageSnapshot request: baseline={request.baseline}") + + try: + from ..coverage_server import take_coverage_snapshot + + result = take_coverage_snapshot(request.baseline) + + # Convert to protobuf + coverage: dict[str, FileCoverageData] = {} + for file_path, file_data in result.items(): + branches: dict[str, BranchInfo] = {} + for line, branch_info in file_data.get("branches", {}).items(): + branches[line] = BranchInfo( + total=branch_info.get("total", 0), + covered=branch_info.get("covered", 0), + ) + + coverage[file_path] = FileCoverageData( + lines=file_data.get("lines", {}), + total_branches=file_data.get("totalBranches", 0), + covered_branches=file_data.get("coveredBranches", 0), + branches=branches, + ) + + response = CoverageSnapshotResponse( + success=True, + error="", + coverage=coverage, + ) + except Exception as e: + logger.error(f"Failed to take coverage snapshot: {e}") + response = CoverageSnapshotResponse(success=False, error=str(e)) + + sdk_message = SdkMessage( + type=MessageType.COVERAGE_SNAPSHOT, + request_id=cli_message.request_id, + coverage_snapshot_response=response, + ) + + try: + self._send_message_sync(sdk_message) + logger.debug(f"Sent CoverageSnapshot response: success={response.success}") + except Exception as e: + logger.error(f"Failed to send CoverageSnapshot response: {e}") + def _send_message_sync(self, message: SdkMessage) -> None: """Send a message synchronously on the main socket.""" if not self._socket: diff --git a/drift/core/communication/types.py b/drift/core/communication/types.py index 0ecd7aa..298a203 100644 --- a/drift/core/communication/types.py +++ b/drift/core/communication/types.py @@ -43,7 +43,10 @@ from typing import Any from tusk.drift.core.v1 import ( + BranchInfo, CliMessage, + CoverageSnapshotResponse, + FileCoverageData, InstrumentationVersionMismatchAlert, MessageType, Runtime, diff --git a/drift/core/coverage_server.py b/drift/core/coverage_server.py index 709c14d..1a1fd5b 100644 --- a/drift/core/coverage_server.py +++ b/drift/core/coverage_server.py @@ -173,38 +173,90 @@ def _get_branch_data(cov, data, filename: str) -> dict: return {"totalBranches": 0, "coveredBranches": 0, "branches": {}} +def take_coverage_snapshot(baseline: bool = False) -> dict: + """Take a coverage snapshot (callable from both HTTP handler and protobuf handler). + + Returns dict of { filePath: { "lines": {...}, "totalBranches": N, ... } } + """ + cov = CoverageSnapshotHandler.cov_instance + source_root = CoverageSnapshotHandler.source_root + + if cov is None: + raise RuntimeError("Coverage not initialized") + + with CoverageSnapshotHandler._lock: + cov.stop() + coverage = {} + + if baseline: + data = cov.get_data() + for filename in data.measured_files(): + if "site-packages" in filename or "lib/python" in filename: + continue + if source_root and not filename.startswith(source_root): + continue + try: + _, statements, _, missing, _ = cov.analysis2(filename) + missing_set = set(missing) + lines_map = {} + for line in statements: + lines_map[str(line)] = 0 if line in missing_set else 1 + branch_data = _get_branch_data(cov, data, filename) + if lines_map: + coverage[filename] = {"lines": lines_map, **branch_data} + except Exception: + continue + else: + data = cov.get_data() + for filename in data.measured_files(): + if "site-packages" in filename or "lib/python" in filename: + continue + if source_root and not filename.startswith(source_root): + continue + lines = data.lines(filename) + if lines: + branch_data = _get_branch_data(cov, data, filename) + coverage[filename] = { + "lines": {str(line): 1 for line in lines}, + **branch_data, + } + + cov.erase() + cov.start() + + return coverage + + _coverage_server: HTTPServer | None = None -def start_coverage_server(port: int | None = None) -> bool: - """Start the coverage snapshot server if TUSK_COVERAGE_PORT is set. +def start_coverage_collection() -> bool: + """Initialize coverage.py collection if NODE_V8_COVERAGE is set. - Returns True if the server was started, False otherwise. - """ - global _coverage_server + Coverage data is accessed via take_coverage_snapshot() which can be called + from the protobuf handler or HTTP server. - port_str = os.environ.get("TUSK_COVERAGE_PORT") - if not port_str and port is None: + Returns True if coverage was started, False otherwise. + """ + # NODE_V8_COVERAGE is set by the CLI when coverage is enabled. + # Python doesn't use V8 but we use the same env var as the signal. + if not os.environ.get("NODE_V8_COVERAGE"): return False - actual_port = port or int(port_str) - - # Try to import coverage try: import coverage as coverage_module except ImportError: logger.warning( - "TUSK_COVERAGE_PORT is set but 'coverage' package is not installed. " + "Coverage requested but 'coverage' package is not installed. " "Install it with: pip install coverage" ) return False source_root = os.getcwd() - # Start coverage collection cov = coverage_module.Coverage( source=[source_root], - branch=True, # Enable branch coverage tracking + branch=True, omit=[ "*/site-packages/*", "*/venv/*", @@ -215,10 +267,31 @@ def start_coverage_server(port: int | None = None) -> bool: ) cov.start() - # Set shared state on the handler class CoverageSnapshotHandler.cov_instance = cov CoverageSnapshotHandler.source_root = source_root + logger.info("Coverage collection started") + return True + + +def start_coverage_server(port: int | None = None) -> bool: + """Start the coverage HTTP snapshot server (legacy, for non-protobuf mode). + + Returns True if the server was started, False otherwise. + """ + global _coverage_server + + port_str = os.environ.get("TUSK_COVERAGE_PORT") + if not port_str and port is None: + return False + + actual_port = port or int(port_str) + + # Ensure coverage is initialized + if CoverageSnapshotHandler.cov_instance is None: + if not start_coverage_collection(): + return False + # Start HTTP server in a daemon thread http_server = HTTPServer(("127.0.0.1", actual_port), CoverageSnapshotHandler) _coverage_server = http_server diff --git a/drift/core/drift_sdk.py b/drift/core/drift_sdk.py index f8ad573..16787b7 100644 --- a/drift/core/drift_sdk.py +++ b/drift/core/drift_sdk.py @@ -160,9 +160,10 @@ def initialize( configure_logger(log_level=log_level, prefix="TuskDrift") - # Start coverage server early (before any SDK mode checks that might return early) - from .coverage_server import start_coverage_server - start_coverage_server() + # Start coverage collection early (before any SDK mode checks that might return early). + # Coverage data is accessed via protobuf channel (communicator handles requests). + from .coverage_server import start_coverage_collection + start_coverage_collection() instance._init_params = { "api_key": api_key, From b145c08fdb4ae777797323ea2016ae4da3601e05 Mon Sep 17 00:00:00 2001 From: Sohil Kshirsagar Date: Wed, 1 Apr 2026 19:30:30 -0700 Subject: [PATCH 06/24] fix: Python protobuf coverage handler - use 'is None' not truthiness betterproto treats messages with all default values as falsy. CoverageSnapshotRequest(baseline=False) was falsy, causing per-test snapshots to be skipped. Changed 'if not request' to 'if request is None'. Also separated coverage initialization from HTTP server so coverage.py starts via start_coverage_collection() for the protobuf channel. Extracted take_coverage_snapshot() as reusable function. --- drift/core/communication/communicator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/drift/core/communication/communicator.py b/drift/core/communication/communicator.py index e6afebc..b7247c7 100644 --- a/drift/core/communication/communicator.py +++ b/drift/core/communication/communicator.py @@ -819,7 +819,7 @@ def _handle_set_time_travel_sync(self, cli_message: CliMessage) -> None: def _handle_coverage_snapshot_sync(self, cli_message: CliMessage) -> None: """Handle CoverageSnapshot request from CLI and send response.""" request = cli_message.coverage_snapshot_request - if not request: + if request is None: return logger.debug(f"Received CoverageSnapshot request: baseline={request.baseline}") @@ -865,7 +865,7 @@ def _handle_coverage_snapshot_sync(self, cli_message: CliMessage) -> None: self._send_message_sync(sdk_message) logger.debug(f"Sent CoverageSnapshot response: success={response.success}") except Exception as e: - logger.error(f"Failed to send CoverageSnapshot response: {e}") + logger.error(f"[coverage] Failed to send response: {e}") def _send_message_sync(self, message: SdkMessage) -> None: """Send a message synchronously on the main socket.""" From e0e69459401132849806182a0ac6e90bc9f00520 Mon Sep 17 00:00:00 2001 From: Sohil Kshirsagar Date: Wed, 1 Apr 2026 19:41:00 -0700 Subject: [PATCH 07/24] refactor: remove HTTP server, clean up coverage module - Remove HTTP server code (CoverageSnapshotHandler, start_coverage_server, _coverage_server global, HTTPServer import) - Replace with clean module-level state (_cov_instance, _source_root, _lock) - Extract _is_user_file() helper - stop_coverage_server() -> stop_coverage_collection() - Update module docstring to reflect protobuf-only architecture --- drift/core/coverage_server.py | 355 +++++++++++----------------------- drift/core/drift_sdk.py | 4 +- 2 files changed, 118 insertions(+), 241 deletions(-) diff --git a/drift/core/coverage_server.py b/drift/core/coverage_server.py index 1a1fd5b..8ce361b 100644 --- a/drift/core/coverage_server.py +++ b/drift/core/coverage_server.py @@ -1,310 +1,187 @@ -"""Coverage snapshot HTTP server for Python SDK. - -When TUSK_COVERAGE_PORT is set, starts a tiny HTTP server that manages -coverage.py. On each /snapshot request: -1. Stop coverage collection -2. Get coverage data (which lines were executed since last snapshot) -3. Erase coverage data (reset for next test) -4. Restart coverage collection -5. Return per-file line counts as JSON - -This gives clean per-test coverage data - no diffing needed. +"""Coverage collection for Python SDK. + +Manages coverage.py for collecting per-test code coverage data. +Coverage data is accessed via take_coverage_snapshot() which is called +from the protobuf communicator handler. + +Flow: +1. start_coverage_collection() initializes coverage.py at SDK startup +2. Between tests, CLI sends CoverageSnapshotRequest via protobuf +3. Communicator calls take_coverage_snapshot(baseline) +4. For baseline: returns ALL coverable lines (including uncovered at count=0) +5. For per-test: returns only executed lines since last snapshot, then resets """ from __future__ import annotations -import json import logging import os import threading -from http.server import HTTPServer, BaseHTTPRequestHandler logger = logging.getLogger("TuskDrift") +# Shared state for coverage collection +_cov_instance = None +_source_root: str | None = None +_lock = threading.Lock() -class CoverageSnapshotHandler(BaseHTTPRequestHandler): - """HTTP handler for coverage snapshot requests.""" - # Shared state set by start_coverage_server - cov_instance = None - source_root = None - _lock = threading.Lock() - - def do_GET(self): - from urllib.parse import urlparse, parse_qs - parsed = urlparse(self.path) - if parsed.path == "/snapshot": - params = parse_qs(parsed.query) - is_baseline = params.get("baseline", ["false"])[0] == "true" - self._handle_snapshot(is_baseline) - else: - self.send_response(404) - self.end_headers() +def start_coverage_collection() -> bool: + """Initialize coverage.py collection if NODE_V8_COVERAGE is set. - def _handle_snapshot(self, is_baseline: bool = False): - try: - with self.__class__._lock: - cov = self.__class__.cov_instance - source_root = self.__class__.source_root - - if cov is None: - self.send_response(500) - self.send_header("Content-Type", "application/json") - self.end_headers() - self.wfile.write(json.dumps({"ok": False, "error": "coverage not initialized"}).encode()) - return - - # Stop coverage to read data - cov.stop() - - coverage = {} - - if is_baseline: - # Baseline: return ALL coverable lines (including uncovered at count=0) - # plus branch coverage data. - data = cov.get_data() - for filename in data.measured_files(): - if "site-packages" in filename or "lib/python" in filename: - continue - if source_root and not filename.startswith(source_root): - continue - try: - _, statements, _, missing, _ = cov.analysis2(filename) - missing_set = set(missing) - lines_map = {} - for line in statements: - lines_map[str(line)] = 0 if line in missing_set else 1 - - # Branch data from coverage.py - branch_data = _get_branch_data(cov, data, filename) - - if lines_map: - coverage[filename] = { - "lines": lines_map, - **branch_data, - } - except Exception: - continue - else: - # Regular snapshot: only executed lines since last reset - data = cov.get_data() - for filename in data.measured_files(): - if "site-packages" in filename or "lib/python" in filename: - continue - if source_root and not filename.startswith(source_root): - continue - lines = data.lines(filename) - if lines: - branch_data = _get_branch_data(cov, data, filename) - coverage[filename] = { - "lines": {str(line): 1 for line in lines}, - **branch_data, - } - - # Erase data and restart for next test - cov.erase() - cov.start() - - self.send_response(200) - self.send_header("Content-Type", "application/json") - self.end_headers() - self.wfile.write(json.dumps({"ok": True, "coverage": coverage}).encode()) - - except Exception as e: - self.send_response(500) - self.send_header("Content-Type", "application/json") - self.end_headers() - self.wfile.write(json.dumps({"ok": False, "error": str(e)}).encode()) - - def log_message(self, format, *args): - """Suppress default HTTP server logging.""" - pass - - -def _get_branch_data(cov, data, filename: str) -> dict: - """Extract branch coverage data for a file. + NODE_V8_COVERAGE is set by the CLI when coverage is enabled. + Python doesn't use V8 but we use the same env var as the signal. - Returns dict with totalBranches, coveredBranches, and per-line branch detail. - Uses coverage.py's analysis API which tracks branches as arcs (from_line, to_line). + Returns True if coverage was started, False otherwise. """ - try: - if not data.has_arcs(): - return {"totalBranches": 0, "coveredBranches": 0, "branches": {}} + global _cov_instance, _source_root - # Use internal _analyze for full branch analysis - analysis = cov._analyze(filename) - numbers = analysis.numbers - - total_branches = numbers.n_branches - covered_branches = total_branches - numbers.n_missing_branches + if not os.environ.get("NODE_V8_COVERAGE"): + return False - # Get per-line branch detail from missing_branch_arcs - missing_arcs = analysis.missing_branch_arcs() - executed_arcs = set(data.arcs(filename) or []) + try: + import coverage as coverage_module + except ImportError: + logger.warning( + "Coverage requested but 'coverage' package is not installed. " + "Install it with: pip install coverage" + ) + return False - # Build per-line branch info - # Collect all branch source lines from both executed and missing - branch_lines: dict[int, dict] = {} # from_line -> {total, covered} + _source_root = os.getcwd() - # Count executed arcs by source line - for from_line, to_line in executed_arcs: - if from_line < 0: # negative = entry/exit arcs, skip - continue - if from_line not in branch_lines: - branch_lines[from_line] = {"total": 0, "covered": 0} - branch_lines[from_line]["total"] += 1 - branch_lines[from_line]["covered"] += 1 + _cov_instance = coverage_module.Coverage( + source=[_source_root], + branch=True, + omit=[ + "*/site-packages/*", + "*/venv/*", + "*/.venv/*", + "*/test*", + "*/__pycache__/*", + ], + ) + _cov_instance.start() - # Count missing arcs by source line - for from_line, to_lines in missing_arcs.items(): - if from_line not in branch_lines: - branch_lines[from_line] = {"total": 0, "covered": 0} - branch_lines[from_line]["total"] += len(to_lines) + logger.info("Coverage collection started") + return True - # Convert to string keys - branches = {str(line): info for line, info in branch_lines.items()} - return { - "totalBranches": total_branches, - "coveredBranches": covered_branches, - "branches": branches, - } - except Exception: - return {"totalBranches": 0, "coveredBranches": 0, "branches": {}} +def stop_coverage_collection() -> None: + """Stop coverage collection and clean up.""" + global _cov_instance + if _cov_instance is not None: + try: + _cov_instance.stop() + except Exception: + pass + _cov_instance = None def take_coverage_snapshot(baseline: bool = False) -> dict: - """Take a coverage snapshot (callable from both HTTP handler and protobuf handler). + """Take a coverage snapshot. - Returns dict of { filePath: { "lines": {...}, "totalBranches": N, ... } } - """ - cov = CoverageSnapshotHandler.cov_instance - source_root = CoverageSnapshotHandler.source_root + Called from the protobuf communicator handler between tests. - if cov is None: + Args: + baseline: If True, returns ALL coverable lines (including uncovered at count=0) + for computing the total coverage denominator. + If False, returns only lines executed since the last snapshot. + + Returns: + dict of { filePath: { "lines": {...}, "totalBranches": N, "coveredBranches": N, "branches": {...} } } + """ + if _cov_instance is None: raise RuntimeError("Coverage not initialized") - with CoverageSnapshotHandler._lock: - cov.stop() + with _lock: + _cov_instance.stop() coverage = {} if baseline: - data = cov.get_data() + data = _cov_instance.get_data() for filename in data.measured_files(): - if "site-packages" in filename or "lib/python" in filename: - continue - if source_root and not filename.startswith(source_root): + if not _is_user_file(filename): continue try: - _, statements, _, missing, _ = cov.analysis2(filename) + _, statements, _, missing, _ = _cov_instance.analysis2(filename) missing_set = set(missing) lines_map = {} for line in statements: lines_map[str(line)] = 0 if line in missing_set else 1 - branch_data = _get_branch_data(cov, data, filename) + branch_data = _get_branch_data(data, filename) if lines_map: coverage[filename] = {"lines": lines_map, **branch_data} except Exception: continue else: - data = cov.get_data() + data = _cov_instance.get_data() for filename in data.measured_files(): - if "site-packages" in filename or "lib/python" in filename: - continue - if source_root and not filename.startswith(source_root): + if not _is_user_file(filename): continue lines = data.lines(filename) if lines: - branch_data = _get_branch_data(cov, data, filename) + branch_data = _get_branch_data(data, filename) coverage[filename] = { "lines": {str(line): 1 for line in lines}, **branch_data, } - cov.erase() - cov.start() + _cov_instance.erase() + _cov_instance.start() return coverage -_coverage_server: HTTPServer | None = None - - -def start_coverage_collection() -> bool: - """Initialize coverage.py collection if NODE_V8_COVERAGE is set. - - Coverage data is accessed via take_coverage_snapshot() which can be called - from the protobuf handler or HTTP server. - - Returns True if coverage was started, False otherwise. - """ - # NODE_V8_COVERAGE is set by the CLI when coverage is enabled. - # Python doesn't use V8 but we use the same env var as the signal. - if not os.environ.get("NODE_V8_COVERAGE"): +def _is_user_file(filename: str) -> bool: + """Check if a file is a user source file (not third-party).""" + if "site-packages" in filename or "lib/python" in filename: return False - - try: - import coverage as coverage_module - except ImportError: - logger.warning( - "Coverage requested but 'coverage' package is not installed. " - "Install it with: pip install coverage" - ) + if _source_root and not filename.startswith(_source_root): return False - - source_root = os.getcwd() - - cov = coverage_module.Coverage( - source=[source_root], - branch=True, - omit=[ - "*/site-packages/*", - "*/venv/*", - "*/.venv/*", - "*/test*", - "*/__pycache__/*", - ], - ) - cov.start() - - CoverageSnapshotHandler.cov_instance = cov - CoverageSnapshotHandler.source_root = source_root - - logger.info("Coverage collection started") return True -def start_coverage_server(port: int | None = None) -> bool: - """Start the coverage HTTP snapshot server (legacy, for non-protobuf mode). +def _get_branch_data(data, filename: str) -> dict: + """Extract branch coverage data for a file. - Returns True if the server was started, False otherwise. + Uses coverage.py's arc tracking (from_line, to_line) to compute + per-line branch coverage. """ - global _coverage_server + try: + if not data.has_arcs(): + return {"totalBranches": 0, "coveredBranches": 0, "branches": {}} - port_str = os.environ.get("TUSK_COVERAGE_PORT") - if not port_str and port is None: - return False + analysis = _cov_instance._analyze(filename) + numbers = analysis.numbers - actual_port = port or int(port_str) + total_branches = numbers.n_branches + covered_branches = max(0, total_branches - numbers.n_missing_branches) - # Ensure coverage is initialized - if CoverageSnapshotHandler.cov_instance is None: - if not start_coverage_collection(): - return False + missing_arcs = analysis.missing_branch_arcs() + executed_arcs = set(data.arcs(filename) or []) - # Start HTTP server in a daemon thread - http_server = HTTPServer(("127.0.0.1", actual_port), CoverageSnapshotHandler) - _coverage_server = http_server - thread = threading.Thread(target=http_server.serve_forever, daemon=True) - thread.start() + branch_lines: dict[int, dict] = {} - logger.info(f"Coverage snapshot server listening on port {actual_port}") - return True + for from_line, to_line in executed_arcs: + if from_line < 0: + continue + if from_line not in branch_lines: + branch_lines[from_line] = {"total": 0, "covered": 0} + branch_lines[from_line]["total"] += 1 + branch_lines[from_line]["covered"] += 1 + + for from_line, to_lines in missing_arcs.items(): + if from_line not in branch_lines: + branch_lines[from_line] = {"total": 0, "covered": 0} + branch_lines[from_line]["total"] += len(to_lines) + branches = {str(line): info for line, info in branch_lines.items()} -def stop_coverage_server() -> None: - """Shut down the coverage snapshot server if running.""" - global _coverage_server - if _coverage_server is not None: - _coverage_server.shutdown() - _coverage_server = None + return { + "totalBranches": total_branches, + "coveredBranches": covered_branches, + "branches": branches, + } + except Exception: + return {"totalBranches": 0, "coveredBranches": 0, "branches": {}} diff --git a/drift/core/drift_sdk.py b/drift/core/drift_sdk.py index 16787b7..a420547 100644 --- a/drift/core/drift_sdk.py +++ b/drift/core/drift_sdk.py @@ -832,7 +832,7 @@ def shutdown(self) -> None: """Shutdown the SDK.""" import asyncio - from .coverage_server import stop_coverage_server + from .coverage_server import stop_coverage_collection # Shutdown OpenTelemetry tracer provider if self._td_span_processor is not None: @@ -855,4 +855,4 @@ def shutdown(self) -> None: except Exception as e: logger.error(f"Error shutting down trace blocking manager: {e}") - stop_coverage_server() + stop_coverage_collection() From 5e4aa978df0401993a963c9bfd042be6aeb4b72c Mon Sep 17 00:00:00 2001 From: Sohil Kshirsagar Date: Wed, 1 Apr 2026 19:46:37 -0700 Subject: [PATCH 08/24] fix: prod readiness - thread-safe coverage shutdown Add _lock protection to stop_coverage_collection() to prevent race condition where shutdown sets _cov_instance=None while a snapshot is in progress on the background reader thread. --- drift/core/coverage_server.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/drift/core/coverage_server.py b/drift/core/coverage_server.py index 8ce361b..16f935b 100644 --- a/drift/core/coverage_server.py +++ b/drift/core/coverage_server.py @@ -68,14 +68,15 @@ def start_coverage_collection() -> bool: def stop_coverage_collection() -> None: - """Stop coverage collection and clean up.""" + """Stop coverage collection and clean up. Thread-safe.""" global _cov_instance - if _cov_instance is not None: - try: - _cov_instance.stop() - except Exception: - pass - _cov_instance = None + with _lock: + if _cov_instance is not None: + try: + _cov_instance.stop() + except Exception: + pass + _cov_instance = None def take_coverage_snapshot(baseline: bool = False) -> dict: From 18e8349edf5655a1cbc62d748a1c1622815901ed Mon Sep 17 00:00:00 2001 From: Sohil Kshirsagar Date: Wed, 1 Apr 2026 20:10:20 -0700 Subject: [PATCH 09/24] feat: use TUSK_COVERAGE instead of NODE_V8_COVERAGE for Python --- drift/core/coverage_server.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/drift/core/coverage_server.py b/drift/core/coverage_server.py index 16f935b..505c12a 100644 --- a/drift/core/coverage_server.py +++ b/drift/core/coverage_server.py @@ -36,7 +36,8 @@ def start_coverage_collection() -> bool: """ global _cov_instance, _source_root - if not os.environ.get("NODE_V8_COVERAGE"): + # TUSK_COVERAGE is the language-agnostic signal from the CLI + if not os.environ.get("TUSK_COVERAGE"): return False try: From 5d438df3b7e0293d436181b07fe7da31e0d07215 Mon Sep 17 00:00:00 2001 From: Sohil Kshirsagar Date: Fri, 3 Apr 2026 14:36:05 -0700 Subject: [PATCH 10/24] docs: add code coverage documentation Add docs/coverage.md explaining coverage.py integration, branch coverage via arc tracking, thread safety, and limitations. Update environment-variables.md with coverage env vars section. --- docs/coverage.md | 101 ++++++++++++++++++++++++++++++++++ docs/environment-variables.md | 13 +++++ 2 files changed, 114 insertions(+) create mode 100644 docs/coverage.md diff --git a/docs/coverage.md b/docs/coverage.md new file mode 100644 index 0000000..5252657 --- /dev/null +++ b/docs/coverage.md @@ -0,0 +1,101 @@ +# Code Coverage (Python) + +The Python SDK collects per-test code coverage during Tusk Drift replay using `coverage.py`. Unlike Node.js (which uses V8's built-in coverage), Python requires the `coverage` package to be installed. + +## Requirements + +```bash +pip install coverage +# or +pip install tusk-drift[coverage] +``` + +If `coverage` is not installed and `--coverage` is used, the SDK logs a warning and coverage is skipped — tests still run normally. + +## How It Works + +### coverage.py Integration + +When coverage is enabled (via `--show-coverage`, `--coverage-output`, or `coverage.enabled: true` in config), the CLI sets `TUSK_COVERAGE=true`. The SDK detects this during initialization and starts coverage.py: + +```python +# What the SDK does internally: +import coverage +cov = coverage.Coverage( + source=[os.path.realpath(os.getcwd())], + branch=True, + omit=["*/site-packages/*", "*/venv/*", "*/.venv/*", "*/test*", "*/__pycache__/*"], +) +cov.start() +``` + +Key points: +- `branch=True` enables branch coverage (arc-based tracking) +- `source` is set to the real path of the working directory (symlinks resolved) +- Third-party code (site-packages, venv) is excluded by default + +### Snapshot Flow + +1. **Baseline**: CLI sends `CoverageSnapshotRequest(baseline=true)`. The SDK: + - Calls `cov.stop()` + - Uses `cov.analysis2(filename)` for each measured file to get ALL coverable lines (statements + missing) + - Returns lines with count=0 for uncovered, count=1 for covered + - Calls `cov.erase()` then `cov.start()` to reset counters + +2. **Per-test**: CLI sends `CoverageSnapshotRequest(baseline=false)`. The SDK: + - Calls `cov.stop()` + - Uses `cov.get_data().lines(filename)` to get only executed lines since last reset + - Returns only covered lines (count=1) + - Calls `cov.erase()` then `cov.start()` to reset + +3. **Communication**: Results are sent back to the CLI via the existing protobuf channel — same socket used for replay. No HTTP server or extra ports. + +### Branch Coverage + +Branch coverage uses coverage.py's arc tracking. The SDK extracts per-line branch data using: + +```python +analysis = cov._analyze(filename) # Private API +missing_arcs = analysis.missing_branch_arcs() +executed_arcs = set(data.arcs(filename) or []) +``` + +For each branch point (line with multiple execution paths), the SDK reports: +- `total`: number of branch paths from that line +- `covered`: number of paths that were actually taken + +**Note:** `_analyze()` is a private coverage.py API. It's the only way to get per-line branch arc data. The public API (`analysis2()`) only provides aggregate branch counts. This means branch coverage may break on major coverage.py version upgrades. + +### Path Handling + +The SDK uses `os.path.realpath()` for the source root to handle symlinked project directories. File paths reported by coverage.py are also resolved via `realpath` before comparison. This prevents the silent failure where all files get filtered out because symlink paths don't match. + +## Environment Variables + +Set automatically by the CLI. You should not set these manually. + +| Variable | Description | +|----------|-------------| +| `TUSK_COVERAGE` | Set to `true` by the CLI when coverage is enabled. The SDK checks this to decide whether to start coverage.py. | + +Note: `NODE_V8_COVERAGE` is also set by the CLI (for Node.js), but the Python SDK ignores it — it only checks `TUSK_COVERAGE`. + +## Thread Safety + +Coverage collection uses a module-level lock (`threading.Lock`) to ensure thread safety: + +- `start_coverage_collection()`: Acquires lock while initializing. Guards against double initialization — if called twice, stops the existing instance first. +- `take_coverage_snapshot()`: Acquires lock for the entire stop/read/erase/start cycle. +- `stop_coverage_collection()`: Acquires lock while stopping and cleaning up. + +This is important because the protobuf communicator runs coverage handlers in a background thread. + +## Limitations + +- **`coverage` package required**: Unlike Node.js (V8 coverage is built-in), Python needs `pip install coverage`. If not installed, coverage silently doesn't work (warning logged). +- **Performance overhead**: coverage.py uses `sys.settrace()` which adds 10-30% execution overhead. V8 coverage is near-zero. This overhead only applies during `--coverage` replay runs. +- **Multi-process servers**: gunicorn with `--workers > 1` forks worker processes. The SDK starts coverage.py in the main process; forked workers don't inherit it. Use `--workers 1` during coverage runs. +- **Private API for branches**: `_analyze()` is not part of coverage.py's public API. Branch coverage detail may break on future coverage.py versions. +- **Python 3.12+ recommended for async**: coverage.py's `sys.settrace` can miss some async lines on Python < 3.12. Python 3.12+ uses `sys.monitoring` for better async tracking. +- **Startup ordering**: coverage.py starts during SDK initialization. Code that executes before `TuskDrift.initialize()` (e.g., module-level code in `tusk_drift_init.py`) isn't tracked. This is why `tusk_drift_init.py` typically shows 0% coverage. +- **C extensions invisible**: coverage.py can't track C extensions (numpy, Cython modules). Not relevant for typical web API servers. diff --git a/docs/environment-variables.md b/docs/environment-variables.md index 7e1e365..0425f1b 100644 --- a/docs/environment-variables.md +++ b/docs/environment-variables.md @@ -174,7 +174,20 @@ These variables configure how the SDK connects to the Tusk CLI during replay: These are typically set automatically by the Tusk CLI and do not need to be configured manually. +## Coverage Variables + +Set automatically by the CLI when `tusk drift run --coverage` is used. You should **not** set them manually. + +| Variable | Description | +|----------|-------------| +| `TUSK_COVERAGE` | Set to `true` when coverage is enabled. The SDK checks this to start coverage.py. | + +Note: `NODE_V8_COVERAGE` is also set by the CLI (for Node.js) but is ignored by the Python SDK. + +See [Coverage Guide](./coverage.md) for details on how coverage collection works. + ## Related Docs - [Initialization Guide](./initialization.md) - SDK initialization parameters and config file settings - [Quick Start Guide](./quickstart.md) - Record and replay your first trace +- [Coverage Guide](./coverage.md) - Code coverage during test replay From 2a383b0bab42930764c5bba30514fb9f502e0953 Mon Sep 17 00:00:00 2001 From: Sohil Kshirsagar Date: Fri, 3 Apr 2026 14:40:19 -0700 Subject: [PATCH 11/24] fix: coverage code quality improvements - Use getattr() for betterproto oneof field access (prevents AttributeError) - Fix _is_user_file path prefix collision (/app matching /application) - Add os.path.realpath() for symlink-safe path comparison - Add thread lock to start_coverage_collection() - Add double-init guard (stop existing instance before creating new) - Narrow */test* omit pattern to */tests/* and */test_*.py - Log failed file analysis at debug level instead of silent swallow --- drift/core/communication/communicator.py | 6 +-- drift/core/coverage_server.py | 51 +++++++++++++++--------- 2 files changed, 35 insertions(+), 22 deletions(-) diff --git a/drift/core/communication/communicator.py b/drift/core/communication/communicator.py index b7247c7..91da868 100644 --- a/drift/core/communication/communicator.py +++ b/drift/core/communication/communicator.py @@ -781,8 +781,8 @@ def _background_read_loop(self) -> None: def _handle_set_time_travel_sync(self, cli_message: CliMessage) -> None: """Handle SetTimeTravel request from CLI and send response.""" - request = cli_message.set_time_travel_request - if not request: + request = getattr(cli_message, "set_time_travel_request", None) + if request is None: return logger.debug( @@ -818,7 +818,7 @@ def _handle_set_time_travel_sync(self, cli_message: CliMessage) -> None: def _handle_coverage_snapshot_sync(self, cli_message: CliMessage) -> None: """Handle CoverageSnapshot request from CLI and send response.""" - request = cli_message.coverage_snapshot_request + request = getattr(cli_message, "coverage_snapshot_request", None) if request is None: return diff --git a/drift/core/coverage_server.py b/drift/core/coverage_server.py index 505c12a..02ddc4a 100644 --- a/drift/core/coverage_server.py +++ b/drift/core/coverage_server.py @@ -27,10 +27,10 @@ def start_coverage_collection() -> bool: - """Initialize coverage.py collection if NODE_V8_COVERAGE is set. + """Initialize coverage.py collection if TUSK_COVERAGE is set. - NODE_V8_COVERAGE is set by the CLI when coverage is enabled. - Python doesn't use V8 but we use the same env var as the signal. + TUSK_COVERAGE is set by the CLI when coverage is enabled. + This is the language-agnostic signal (Node uses NODE_V8_COVERAGE additionally). Returns True if coverage was started, False otherwise. """ @@ -49,20 +49,29 @@ def start_coverage_collection() -> bool: ) return False - _source_root = os.getcwd() - - _cov_instance = coverage_module.Coverage( - source=[_source_root], - branch=True, - omit=[ - "*/site-packages/*", - "*/venv/*", - "*/.venv/*", - "*/test*", - "*/__pycache__/*", - ], - ) - _cov_instance.start() + with _lock: + # Guard against double initialization — stop existing instance first + if _cov_instance is not None: + try: + _cov_instance.stop() + except Exception: + pass + + _source_root = os.path.realpath(os.getcwd()) + + _cov_instance = coverage_module.Coverage( + source=[_source_root], + branch=True, + omit=[ + "*/site-packages/*", + "*/venv/*", + "*/.venv/*", + "*/tests/*", + "*/test_*.py", + "*/__pycache__/*", + ], + ) + _cov_instance.start() logger.info("Coverage collection started") return True @@ -114,7 +123,8 @@ def take_coverage_snapshot(baseline: bool = False) -> dict: branch_data = _get_branch_data(data, filename) if lines_map: coverage[filename] = {"lines": lines_map, **branch_data} - except Exception: + except Exception as e: + logger.debug(f"Failed to analyze {filename}: {e}") continue else: data = _cov_instance.get_data() @@ -139,7 +149,10 @@ def _is_user_file(filename: str) -> bool: """Check if a file is a user source file (not third-party).""" if "site-packages" in filename or "lib/python" in filename: return False - if _source_root and not filename.startswith(_source_root): + # Resolve symlinks for consistent path comparison + resolved = os.path.realpath(filename) + # Use trailing separator to avoid prefix collisions (/app matching /application) + if _source_root and not (resolved.startswith(_source_root + os.sep) or resolved == _source_root): return False return True From 5b3354be752c1a67b86daa55161cbe51f37431cd Mon Sep 17 00:00:00 2001 From: Sohil Kshirsagar Date: Fri, 3 Apr 2026 15:02:45 -0700 Subject: [PATCH 12/24] docs: clean up AI writing patterns in coverage doc --- docs/coverage.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/coverage.md b/docs/coverage.md index 5252657..ac346ce 100644 --- a/docs/coverage.md +++ b/docs/coverage.md @@ -10,7 +10,7 @@ pip install coverage pip install tusk-drift[coverage] ``` -If `coverage` is not installed and `--coverage` is used, the SDK logs a warning and coverage is skipped — tests still run normally. +If `coverage` is not installed when coverage is enabled, the SDK logs a warning and coverage is skipped. Tests still run normally. ## How It Works @@ -24,7 +24,7 @@ import coverage cov = coverage.Coverage( source=[os.path.realpath(os.getcwd())], branch=True, - omit=["*/site-packages/*", "*/venv/*", "*/.venv/*", "*/test*", "*/__pycache__/*"], + omit=["*/site-packages/*", "*/venv/*", "*/.venv/*", "*/tests/*", "*/test_*.py", "*/__pycache__/*"], ) cov.start() ``` @@ -93,7 +93,7 @@ This is important because the protobuf communicator runs coverage handlers in a ## Limitations - **`coverage` package required**: Unlike Node.js (V8 coverage is built-in), Python needs `pip install coverage`. If not installed, coverage silently doesn't work (warning logged). -- **Performance overhead**: coverage.py uses `sys.settrace()` which adds 10-30% execution overhead. V8 coverage is near-zero. This overhead only applies during `--coverage` replay runs. +- **Performance overhead**: coverage.py uses `sys.settrace()` which adds 10-30% execution overhead. This only applies during coverage replay runs. - **Multi-process servers**: gunicorn with `--workers > 1` forks worker processes. The SDK starts coverage.py in the main process; forked workers don't inherit it. Use `--workers 1` during coverage runs. - **Private API for branches**: `_analyze()` is not part of coverage.py's public API. Branch coverage detail may break on future coverage.py versions. - **Python 3.12+ recommended for async**: coverage.py's `sys.settrace` can miss some async lines on Python < 3.12. Python 3.12+ uses `sys.monitoring` for better async tracking. From 4d156c4f3b5411d77175af74ba91559922ec63dc Mon Sep 17 00:00:00 2001 From: Sohil Kshirsagar Date: Fri, 3 Apr 2026 15:27:49 -0700 Subject: [PATCH 13/24] fix: address bugbot review feedback - Move _cov_instance None check inside lock (TOCTOU race fix) - Fix branch counting to only include actual branch points, not all arcs --- drift/core/coverage_server.py | 37 ++++++++++++++++++++++------------- 1 file changed, 23 insertions(+), 14 deletions(-) diff --git a/drift/core/coverage_server.py b/drift/core/coverage_server.py index 02ddc4a..53ea8aa 100644 --- a/drift/core/coverage_server.py +++ b/drift/core/coverage_server.py @@ -102,10 +102,9 @@ def take_coverage_snapshot(baseline: bool = False) -> dict: Returns: dict of { filePath: { "lines": {...}, "totalBranches": N, "coveredBranches": N, "branches": {...} } } """ - if _cov_instance is None: - raise RuntimeError("Coverage not initialized") - with _lock: + if _cov_instance is None: + raise RuntimeError("Coverage not initialized") _cov_instance.stop() coverage = {} @@ -176,20 +175,30 @@ def _get_branch_data(data, filename: str) -> dict: missing_arcs = analysis.missing_branch_arcs() executed_arcs = set(data.arcs(filename) or []) - branch_lines: dict[int, dict] = {} - + # Group executed arcs by from_line (skip negative entry arcs) + executed_by_line: dict[int, list[int]] = {} for from_line, to_line in executed_arcs: if from_line < 0: continue - if from_line not in branch_lines: - branch_lines[from_line] = {"total": 0, "covered": 0} - branch_lines[from_line]["total"] += 1 - branch_lines[from_line]["covered"] += 1 - - for from_line, to_lines in missing_arcs.items(): - if from_line not in branch_lines: - branch_lines[from_line] = {"total": 0, "covered": 0} - branch_lines[from_line]["total"] += len(to_lines) + executed_by_line.setdefault(from_line, []).append(to_line) + + # A line is a branch point if: + # - it appears in missing_arcs (at least one path wasn't taken), OR + # - it has multiple executed arcs (multiple paths from same line) + branch_point_lines = set(missing_arcs.keys()) + for from_line, to_lines in executed_by_line.items(): + if len(to_lines) > 1: + branch_point_lines.add(from_line) + + branch_lines: dict[int, dict] = {} + + for from_line in branch_point_lines: + executed_count = len(executed_by_line.get(from_line, [])) + missing_count = len(missing_arcs.get(from_line, [])) + branch_lines[from_line] = { + "total": executed_count + missing_count, + "covered": executed_count, + } branches = {str(line): info for line, info in branch_lines.items()} From a9eeb5eb51ff1dd726091826bde4e734f16c2c74 Mon Sep 17 00:00:00 2001 From: Sohil Kshirsagar Date: Mon, 6 Apr 2026 17:08:05 -0700 Subject: [PATCH 14/24] chore: update tusk-drift-schemas to >=0.1.34 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index ac1a6c1..8be1c5f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,7 @@ dependencies = [ "protobuf>=3.20.0", "PyYAML>=6.0", "requests>=2.28.0", - "tusk-drift-schemas>=0.1.24", + "tusk-drift-schemas>=0.1.34", "aiohttp>=3.9.0", "aiofiles>=23.0.0", "opentelemetry-api>=1.20.0", From f0723e8970290347115aa28f58320b461f49621e Mon Sep 17 00:00:00 2001 From: Sohil Kshirsagar Date: Mon, 6 Apr 2026 17:31:17 -0700 Subject: [PATCH 15/24] fix: address lint, type check, and coverage restart safety - Add None guard for _cov_instance before ._analyze() to fix type checker error - Wrap stop/erase/start cycle in try/finally so coverage always restarts on error - Run ruff format on coverage_server.py and drift_sdk.py - Update uv.lock to match pyproject.toml dependency changes --- drift/core/coverage_server.py | 71 ++++++++++++++++++----------------- drift/core/drift_sdk.py | 1 + uv.lock | 16 ++++---- 3 files changed, 47 insertions(+), 41 deletions(-) diff --git a/drift/core/coverage_server.py b/drift/core/coverage_server.py index 53ea8aa..fab95f1 100644 --- a/drift/core/coverage_server.py +++ b/drift/core/coverage_server.py @@ -44,8 +44,7 @@ def start_coverage_collection() -> bool: import coverage as coverage_module except ImportError: logger.warning( - "Coverage requested but 'coverage' package is not installed. " - "Install it with: pip install coverage" + "Coverage requested but 'coverage' package is not installed. Install it with: pip install coverage" ) return False @@ -108,38 +107,39 @@ def take_coverage_snapshot(baseline: bool = False) -> dict: _cov_instance.stop() coverage = {} - if baseline: - data = _cov_instance.get_data() - for filename in data.measured_files(): - if not _is_user_file(filename): - continue - try: - _, statements, _, missing, _ = _cov_instance.analysis2(filename) - missing_set = set(missing) - lines_map = {} - for line in statements: - lines_map[str(line)] = 0 if line in missing_set else 1 - branch_data = _get_branch_data(data, filename) - if lines_map: - coverage[filename] = {"lines": lines_map, **branch_data} - except Exception as e: - logger.debug(f"Failed to analyze {filename}: {e}") - continue - else: - data = _cov_instance.get_data() - for filename in data.measured_files(): - if not _is_user_file(filename): - continue - lines = data.lines(filename) - if lines: - branch_data = _get_branch_data(data, filename) - coverage[filename] = { - "lines": {str(line): 1 for line in lines}, - **branch_data, - } - - _cov_instance.erase() - _cov_instance.start() + try: + if baseline: + data = _cov_instance.get_data() + for filename in data.measured_files(): + if not _is_user_file(filename): + continue + try: + _, statements, _, missing, _ = _cov_instance.analysis2(filename) + missing_set = set(missing) + lines_map = {} + for line in statements: + lines_map[str(line)] = 0 if line in missing_set else 1 + branch_data = _get_branch_data(data, filename) + if lines_map: + coverage[filename] = {"lines": lines_map, **branch_data} + except Exception as e: + logger.debug(f"Failed to analyze {filename}: {e}") + continue + else: + data = _cov_instance.get_data() + for filename in data.measured_files(): + if not _is_user_file(filename): + continue + lines = data.lines(filename) + if lines: + branch_data = _get_branch_data(data, filename) + coverage[filename] = { + "lines": {str(line): 1 for line in lines}, + **branch_data, + } + finally: + _cov_instance.erase() + _cov_instance.start() return coverage @@ -166,6 +166,9 @@ def _get_branch_data(data, filename: str) -> dict: if not data.has_arcs(): return {"totalBranches": 0, "coveredBranches": 0, "branches": {}} + if _cov_instance is None: + return {"totalBranches": 0, "coveredBranches": 0, "branches": {}} + analysis = _cov_instance._analyze(filename) numbers = analysis.numbers diff --git a/drift/core/drift_sdk.py b/drift/core/drift_sdk.py index a420547..87c07ff 100644 --- a/drift/core/drift_sdk.py +++ b/drift/core/drift_sdk.py @@ -163,6 +163,7 @@ def initialize( # Start coverage collection early (before any SDK mode checks that might return early). # Coverage data is accessed via protobuf channel (communicator handles requests). from .coverage_server import start_coverage_collection + start_coverage_collection() instance._init_params = { diff --git a/uv.lock b/uv.lock index 85de3b6..f2a575a 100644 --- a/uv.lock +++ b/uv.lock @@ -1,4 +1,5 @@ version = 1 +revision = 1 requires-python = ">=3.9" resolution-markers = [ "python_full_version >= '3.12'", @@ -2353,7 +2354,7 @@ wheels = [ [[package]] name = "tusk-drift-python-sdk" -version = "0.1.22" +version = "0.1.23" source = { editable = "." } dependencies = [ { name = "aiofiles" }, @@ -2415,33 +2416,34 @@ requires-dist = [ { name = "flask", marker = "extra == 'flask'", specifier = ">=3.1.2" }, { name = "opentelemetry-api", specifier = ">=1.20.0" }, { name = "opentelemetry-sdk", specifier = ">=1.20.0" }, - { name = "protobuf", specifier = ">=6.0" }, + { name = "protobuf", specifier = ">=3.20.0" }, { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0.0,<9.0.0" }, { name = "pytest-cov", marker = "extra == 'dev'", specifier = ">=7.0.0" }, { name = "pytest-mock", marker = "extra == 'dev'", specifier = ">=3.15.0" }, { name = "python-jsonpath", marker = "extra == 'dev'", specifier = ">=0.10" }, { name = "pyyaml", specifier = ">=6.0" }, - { name = "requests", specifier = ">=2.32.5" }, + { name = "requests", specifier = ">=2.28.0" }, { name = "ruff", marker = "extra == 'dev'", specifier = "==0.14.13" }, { name = "starlette", marker = "extra == 'fastapi'", specifier = "<0.42.0" }, { name = "time-machine", specifier = ">=2.10.0" }, - { name = "tusk-drift-schemas", specifier = ">=0.1.24" }, + { name = "tusk-drift-schemas", specifier = ">=0.1.34" }, { name = "ty", marker = "extra == 'dev'", specifier = "==0.0.12" }, { name = "typing-extensions", specifier = ">=4.4.0" }, { name = "uvicorn", marker = "extra == 'dev'", specifier = ">=0.34.2" }, { name = "uvicorn", marker = "extra == 'fastapi'", specifier = ">=0.34.2" }, ] +provides-extras = ["flask", "fastapi", "django", "rust", "dev"] [[package]] name = "tusk-drift-schemas" -version = "0.1.30" +version = "0.1.34" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "betterproto", extra = ["compiler"] }, ] -sdist = { url = "https://files.pythonhosted.org/packages/b8/13/0490df57cef00ddb6e123908ec43d506d791542a26bc39492faea137301a/tusk_drift_schemas-0.1.30.tar.gz", hash = "sha256:ba992f243bbf68eae27ecfac003d33861f9e16abc7cc007fd511cc8ca342450f", size = 14741 } +sdist = { url = "https://files.pythonhosted.org/packages/6b/cb/f83b416768ca099f6438b98970abd02fafce9f5135f0c4450cb8b8fd815c/tusk_drift_schemas-0.1.34.tar.gz", hash = "sha256:8d60c69d21e03f04facb19ad71aa91060941131a2bef9b79cdc55cdac651e89a", size = 17170 } wheels = [ - { url = "https://files.pythonhosted.org/packages/23/f0/33cb9a814e8ad7a6c32e48041428bc41d0829c657c0879e4ed49e7b5144b/tusk_drift_schemas-0.1.30-py3-none-any.whl", hash = "sha256:0f1193548e559122d355d19b432792001059f3fcb3f254f8aba6afe0cdea3e23", size = 14565 }, + { url = "https://files.pythonhosted.org/packages/ef/d4/8e88b76ba1fb60f5dc69ac7cb6c132392b35a99a907191fcfc43dfecb999/tusk_drift_schemas-0.1.34-py3-none-any.whl", hash = "sha256:f1af65fe49b911cddf7814cd5f5b94c420155602648e1d0b85b4fdf3e215c93a", size = 16257 }, ] [[package]] From 5cc5a0772608af3acb32209852e589adbb986ce1 Mon Sep 17 00:00:00 2001 From: Sohil Kshirsagar Date: Mon, 6 Apr 2026 18:05:55 -0700 Subject: [PATCH 16/24] fix: remove unused imports and simplify _is_user_file return --- drift/core/communication/types.py | 3 --- drift/core/coverage_server.py | 4 +--- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/drift/core/communication/types.py b/drift/core/communication/types.py index 298a203..0ecd7aa 100644 --- a/drift/core/communication/types.py +++ b/drift/core/communication/types.py @@ -43,10 +43,7 @@ from typing import Any from tusk.drift.core.v1 import ( - BranchInfo, CliMessage, - CoverageSnapshotResponse, - FileCoverageData, InstrumentationVersionMismatchAlert, MessageType, Runtime, diff --git a/drift/core/coverage_server.py b/drift/core/coverage_server.py index fab95f1..970e1cb 100644 --- a/drift/core/coverage_server.py +++ b/drift/core/coverage_server.py @@ -151,9 +151,7 @@ def _is_user_file(filename: str) -> bool: # Resolve symlinks for consistent path comparison resolved = os.path.realpath(filename) # Use trailing separator to avoid prefix collisions (/app matching /application) - if _source_root and not (resolved.startswith(_source_root + os.sep) or resolved == _source_root): - return False - return True + return not _source_root or resolved.startswith(_source_root + os.sep) or resolved == _source_root def _get_branch_data(data, filename: str) -> dict: From 97d824efa414678bd307fcff4463685c3a81a9b1 Mon Sep 17 00:00:00 2001 From: Sohil Kshirsagar Date: Mon, 6 Apr 2026 18:18:57 -0700 Subject: [PATCH 17/24] fix: restore re-exported imports removed by mistake (BranchInfo, CoverageSnapshotResponse, FileCoverageData) --- drift/core/communication/types.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/drift/core/communication/types.py b/drift/core/communication/types.py index 0ecd7aa..c55cf1a 100644 --- a/drift/core/communication/types.py +++ b/drift/core/communication/types.py @@ -43,7 +43,10 @@ from typing import Any from tusk.drift.core.v1 import ( + BranchInfo, # noqa: F401 (re-exported) CliMessage, + CoverageSnapshotResponse, # noqa: F401 (re-exported) + FileCoverageData, # noqa: F401 (re-exported) InstrumentationVersionMismatchAlert, MessageType, Runtime, From 01a0b3a98854fe3123c0eb94192818fab60043ad Mon Sep 17 00:00:00 2001 From: Sohil Kshirsagar Date: Mon, 6 Apr 2026 18:29:29 -0700 Subject: [PATCH 18/24] ref: remove proto re-exports from types.py, import directly from tusk.drift.core.v1 types.py now only contains SDK wrapper dataclasses and utility functions. Proto types are imported directly where used. Removes SDKMessageType/CLIMessageType aliases. --- drift/core/communication/__init__.py | 13 --------- drift/core/communication/communicator.py | 24 +++++++++------- drift/core/communication/types.py | 35 ++---------------------- 3 files changed, 17 insertions(+), 55 deletions(-) diff --git a/drift/core/communication/__init__.py b/drift/core/communication/__init__.py index 7c6246d..bf96bca 100644 --- a/drift/core/communication/__init__.py +++ b/drift/core/communication/__init__.py @@ -6,28 +6,18 @@ from .communicator import CommunicatorConfig, ProtobufCommunicator from .types import ( - CliMessage, - CLIMessageType, ConnectRequest, ConnectResponse, GetMockRequest, GetMockResponse, - MessageType, MockRequestInput, MockResponseOutput, - # Protobuf types (re-exported) - SdkMessage, - SDKMessageType, dict_to_span, extract_response_data, span_to_proto, ) __all__ = [ - # Message types - "MessageType", - "SDKMessageType", - "CLIMessageType", # Request/Response types "ConnectRequest", "ConnectResponse", @@ -35,9 +25,6 @@ "GetMockResponse", "MockRequestInput", "MockResponseOutput", - # Protobuf types - "SdkMessage", - "CliMessage", # Utilities "span_to_proto", "dict_to_span", diff --git a/drift/core/communication/communicator.py b/drift/core/communication/communicator.py index 91da868..3b32860 100644 --- a/drift/core/communication/communicator.py +++ b/drift/core/communication/communicator.py @@ -10,27 +10,31 @@ from dataclasses import dataclass from typing import Any -from tusk.drift.core.v1 import GetMockRequest as ProtoGetMockRequest - -from ...version import MIN_CLI_VERSION, SDK_VERSION -from ..span_serialization import clean_span_to_proto -from ..types import CleanSpanData, calling_library_context -from .types import ( +from tusk.drift.core.v1 import ( BranchInfo, CliMessage, - ConnectRequest, CoverageSnapshotResponse, FileCoverageData, - GetMockRequest, InstrumentationVersionMismatchAlert, MessageType, - MockRequestInput, - MockResponseOutput, SdkMessage, SendAlertRequest, SendInboundSpanForReplayRequest, SetTimeTravelResponse, UnpatchedDependencyAlert, +) +from tusk.drift.core.v1 import ( + GetMockRequest as ProtoGetMockRequest, +) + +from ...version import MIN_CLI_VERSION, SDK_VERSION +from ..span_serialization import clean_span_to_proto +from ..types import CleanSpanData, calling_library_context +from .types import ( + ConnectRequest, + GetMockRequest, + MockRequestInput, + MockResponseOutput, span_to_proto, ) diff --git a/drift/core/communication/types.py b/drift/core/communication/types.py index c55cf1a..de49ab1 100644 --- a/drift/core/communication/types.py +++ b/drift/core/communication/types.py @@ -12,20 +12,6 @@ from __future__ import annotations __all__ = [ - # Re-exported protobuf types - "CliMessage", - "InstrumentationVersionMismatchAlert", - "MessageType", - "Runtime", - "SdkMessage", - "SendAlertRequest", - "SendInboundSpanForReplayRequest", - "SetTimeTravelRequest", - "SetTimeTravelResponse", - "UnpatchedDependencyAlert", - # Aliases - "SDKMessageType", - "CLIMessageType", # Dataclasses "ConnectRequest", "ConnectResponse", @@ -42,21 +28,6 @@ from dataclasses import dataclass, field from typing import Any -from tusk.drift.core.v1 import ( - BranchInfo, # noqa: F401 (re-exported) - CliMessage, - CoverageSnapshotResponse, # noqa: F401 (re-exported) - FileCoverageData, # noqa: F401 (re-exported) - InstrumentationVersionMismatchAlert, - MessageType, - Runtime, - SdkMessage, - SendAlertRequest, - SendInboundSpanForReplayRequest, - SetTimeTravelRequest, - SetTimeTravelResponse, - UnpatchedDependencyAlert, -) from tusk.drift.core.v1 import ( ConnectRequest as ProtoConnectRequest, ) @@ -69,6 +40,9 @@ from tusk.drift.core.v1 import ( GetMockResponse as ProtoGetMockResponse, ) +from tusk.drift.core.v1 import ( + Runtime, +) from tusk.drift.core.v1 import ( Span as ProtoSpan, ) @@ -82,9 +56,6 @@ StatusCode as ProtoStatusCode, ) -SDKMessageType = MessageType -CLIMessageType = MessageType - def _python_to_value(value: Any) -> Any: """Convert Python value to protobuf Value.""" From c71dd375c50408dad134190b1d960bbca12d91ee Mon Sep 17 00:00:00 2001 From: Sohil Kshirsagar Date: Mon, 6 Apr 2026 19:13:47 -0700 Subject: [PATCH 19/24] fix: guard coverage with REPLAY mode check, add coverage_server unit tests - Skip coverage collection when TUSK_DRIFT_MODE is set to non-REPLAY mode - Add 9 unit tests covering start/stop, mode gating, file filtering, error handling --- drift/core/coverage_server.py | 7 ++ tests/unit/test_coverage_server.py | 118 +++++++++++++++++++++++++++++ 2 files changed, 125 insertions(+) create mode 100644 tests/unit/test_coverage_server.py diff --git a/drift/core/coverage_server.py b/drift/core/coverage_server.py index 970e1cb..05e36c7 100644 --- a/drift/core/coverage_server.py +++ b/drift/core/coverage_server.py @@ -40,6 +40,13 @@ def start_coverage_collection() -> bool: if not os.environ.get("TUSK_COVERAGE"): return False + # Coverage collection only makes sense in REPLAY mode. + # If TUSK_DRIFT_MODE is not set we still proceed for backwards compatibility. + mode = os.environ.get("TUSK_DRIFT_MODE", "").upper() + if mode and mode != "REPLAY": + logger.debug("Coverage collection skipped: not in REPLAY mode (mode=%s)", mode) + return False + try: import coverage as coverage_module except ImportError: diff --git a/tests/unit/test_coverage_server.py b/tests/unit/test_coverage_server.py new file mode 100644 index 0000000..14870e7 --- /dev/null +++ b/tests/unit/test_coverage_server.py @@ -0,0 +1,118 @@ +"""Tests for coverage_server.py - Coverage collection management.""" + +from __future__ import annotations + +import os + +import pytest + +from drift.core import coverage_server +from drift.core.coverage_server import ( + _is_user_file, + start_coverage_collection, + stop_coverage_collection, + take_coverage_snapshot, +) + + +@pytest.fixture(autouse=True) +def _reset_coverage_state(): + """Reset module-level globals between tests.""" + yield + stop_coverage_collection() + # Also make sure _source_root is cleared + coverage_server._source_root = None + + +class TestStartCoverageCollection: + """Tests for start_coverage_collection function.""" + + def test_returns_false_when_tusk_coverage_not_set(self, monkeypatch): + """Should return False when TUSK_COVERAGE env var is not set.""" + monkeypatch.delenv("TUSK_COVERAGE", raising=False) + monkeypatch.delenv("TUSK_DRIFT_MODE", raising=False) + + result = start_coverage_collection() + + assert result is False + + def test_returns_false_when_mode_is_record(self, monkeypatch): + """Should return False when TUSK_DRIFT_MODE=RECORD even if TUSK_COVERAGE=true.""" + monkeypatch.setenv("TUSK_COVERAGE", "true") + monkeypatch.setenv("TUSK_DRIFT_MODE", "RECORD") + + result = start_coverage_collection() + + assert result is False + + def test_returns_true_when_mode_is_replay(self, monkeypatch, mocker): + """Should return True when TUSK_COVERAGE=true and mode is REPLAY.""" + monkeypatch.setenv("TUSK_COVERAGE", "true") + monkeypatch.setenv("TUSK_DRIFT_MODE", "REPLAY") + + mock_cov_instance = mocker.MagicMock() + mock_coverage_module = mocker.MagicMock() + mock_coverage_module.Coverage.return_value = mock_cov_instance + mocker.patch.dict("sys.modules", {"coverage": mock_coverage_module}) + + result = start_coverage_collection() + + assert result is True + mock_cov_instance.start.assert_called_once() + + def test_returns_true_when_mode_not_set(self, monkeypatch, mocker): + """Should return True when TUSK_COVERAGE=true and TUSK_DRIFT_MODE is not set (backwards compat).""" + monkeypatch.setenv("TUSK_COVERAGE", "true") + monkeypatch.delenv("TUSK_DRIFT_MODE", raising=False) + + mock_cov_instance = mocker.MagicMock() + mock_coverage_module = mocker.MagicMock() + mock_coverage_module.Coverage.return_value = mock_cov_instance + mocker.patch.dict("sys.modules", {"coverage": mock_coverage_module}) + + result = start_coverage_collection() + + assert result is True + mock_cov_instance.start.assert_called_once() + + +class TestIsUserFile: + """Tests for _is_user_file function.""" + + def test_returns_false_for_site_packages(self): + """Should return False for paths containing site-packages.""" + assert _is_user_file("/usr/lib/python3.11/site-packages/requests/api.py") is False + + def test_returns_false_for_venv_paths(self): + """Should return False for paths containing lib/python (venv pattern).""" + assert _is_user_file("/app/venv/lib/python3.11/somepkg/mod.py") is False + + def test_returns_true_for_source_root_file(self, monkeypatch): + """Should return True for files within the source root.""" + source_root = os.path.realpath("/tmp/myproject") + monkeypatch.setattr(coverage_server, "_source_root", source_root) + + assert _is_user_file(os.path.join(source_root, "app", "main.py")) is True + + +class TestTakeCoverageSnapshot: + """Tests for take_coverage_snapshot function.""" + + def test_raises_runtime_error_when_not_initialized(self): + """Should raise RuntimeError when coverage is not initialized.""" + with pytest.raises(RuntimeError, match="Coverage not initialized"): + take_coverage_snapshot() + + +class TestStopCoverageCollection: + """Tests for stop_coverage_collection function.""" + + def test_cleans_up_state(self, monkeypatch, mocker): + """Should stop coverage instance and set it to None.""" + mock_cov = mocker.MagicMock() + monkeypatch.setattr(coverage_server, "_cov_instance", mock_cov) + + stop_coverage_collection() + + mock_cov.stop.assert_called_once() + assert coverage_server._cov_instance is None From 06b6c270855f82a3a9f0bad58a36e93a71bdbcfb Mon Sep 17 00:00:00 2001 From: Sohil Kshirsagar Date: Mon, 6 Apr 2026 19:57:06 -0700 Subject: [PATCH 20/24] fix: add coverage optional extra, fix docs install instructions, reorder init - Add coverage>=7.0.0 as optional dependency [coverage] extra in pyproject.toml - Fix docs to reference correct package name tusk-drift-python-sdk[coverage] - Move start_coverage_collection() after _initialized guard to avoid wasteful re-invocation --- docs/coverage.md | 4 +--- drift/core/drift_sdk.py | 13 +++++++------ pyproject.toml | 1 + uv.lock | 7 ++++++- 4 files changed, 15 insertions(+), 10 deletions(-) diff --git a/docs/coverage.md b/docs/coverage.md index ac346ce..56a52ee 100644 --- a/docs/coverage.md +++ b/docs/coverage.md @@ -5,9 +5,7 @@ The Python SDK collects per-test code coverage during Tusk Drift replay using `c ## Requirements ```bash -pip install coverage -# or -pip install tusk-drift[coverage] +pip install tusk-drift-python-sdk[coverage] ``` If `coverage` is not installed when coverage is enabled, the SDK logs a warning and coverage is skipped. Tests still run normally. diff --git a/drift/core/drift_sdk.py b/drift/core/drift_sdk.py index 87c07ff..2dcbf93 100644 --- a/drift/core/drift_sdk.py +++ b/drift/core/drift_sdk.py @@ -160,12 +160,6 @@ def initialize( configure_logger(log_level=log_level, prefix="TuskDrift") - # Start coverage collection early (before any SDK mode checks that might return early). - # Coverage data is accessed via protobuf channel (communicator handles requests). - from .coverage_server import start_coverage_collection - - start_coverage_collection() - instance._init_params = { "api_key": api_key, "env": env, @@ -185,6 +179,13 @@ def initialize( ) env = env_from_var + # Start coverage collection early (before any SDK mode checks that might return early), + # but after the _initialized guard so we don't re-invoke on repeated initialize() calls. + # Coverage data is accessed via protobuf channel (communicator handles requests). + from .coverage_server import start_coverage_collection + + start_coverage_collection() + if cls._initialized: logger.debug("Already initialized, skipping...") return instance diff --git a/pyproject.toml b/pyproject.toml index 8be1c5f..f6f7151 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,7 @@ dependencies = [ flask = ["Flask>=3.1.2"] fastapi = ["fastapi>=0.115.6", "uvicorn>=0.34.2", "starlette<0.42.0"] django = ["Django>=4.2"] +coverage = ["coverage>=7.0.0"] rust = ["drift-core-python>=0.1.9"] dev = [ "Flask>=3.1.2", diff --git a/uv.lock b/uv.lock index f2a575a..7d59570 100644 --- a/uv.lock +++ b/uv.lock @@ -2372,6 +2372,10 @@ dependencies = [ ] [package.optional-dependencies] +coverage = [ + { name = "coverage", version = "7.10.7", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, + { name = "coverage", version = "7.13.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, +] dev = [ { name = "fastapi", version = "0.128.8", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, { name = "fastapi", version = "0.129.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, @@ -2408,6 +2412,7 @@ rust = [ requires-dist = [ { name = "aiofiles", specifier = ">=23.0.0" }, { name = "aiohttp", specifier = ">=3.9.0" }, + { name = "coverage", marker = "extra == 'coverage'", specifier = ">=7.0.0" }, { name = "django", marker = "extra == 'django'", specifier = ">=4.2" }, { name = "drift-core-python", marker = "extra == 'rust'", specifier = ">=0.1.9" }, { name = "fastapi", marker = "extra == 'dev'", specifier = ">=0.115.6" }, @@ -2432,7 +2437,7 @@ requires-dist = [ { name = "uvicorn", marker = "extra == 'dev'", specifier = ">=0.34.2" }, { name = "uvicorn", marker = "extra == 'fastapi'", specifier = ">=0.34.2" }, ] -provides-extras = ["flask", "fastapi", "django", "rust", "dev"] +provides-extras = ["flask", "fastapi", "django", "coverage", "rust", "dev"] [[package]] name = "tusk-drift-schemas" From e1fa8c0cf4df7dfdd8b912faeafdad3ba4d9e4b7 Mon Sep 17 00:00:00 2001 From: Sohil Kshirsagar Date: Mon, 6 Apr 2026 20:09:10 -0700 Subject: [PATCH 21/24] fix: cache branch structure from baseline for deterministic per-test branch counts MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Branch detection via _analyze() depends on observed arcs, which vary with thread timing. Now the baseline snapshot caches branch point structure (totals per line), and per-test snapshots reuse that cache — only computing covered counts from the current test's arcs. This eliminates flaky branch totals (was 12/18 or 12/22 randomly, now consistently 4/10). --- drift/core/coverage_server.py | 60 +++++++++++++++++++++++++++++++++-- 1 file changed, 58 insertions(+), 2 deletions(-) diff --git a/drift/core/coverage_server.py b/drift/core/coverage_server.py index 05e36c7..f2306f4 100644 --- a/drift/core/coverage_server.py +++ b/drift/core/coverage_server.py @@ -25,6 +25,12 @@ _source_root: str | None = None _lock = threading.Lock() +# Cache branch structure from baseline to ensure deterministic branch counts. +# Branch detection via _analyze() depends on observed arcs, which vary with +# thread timing. By caching from the baseline (which has the fullest data), +# per-test snapshots report consistent totals. +_branch_cache: dict[str, dict] | None = None + def start_coverage_collection() -> bool: """Initialize coverage.py collection if TUSK_COVERAGE is set. @@ -85,7 +91,7 @@ def start_coverage_collection() -> bool: def stop_coverage_collection() -> None: """Stop coverage collection and clean up. Thread-safe.""" - global _cov_instance + global _cov_instance, _branch_cache with _lock: if _cov_instance is not None: try: @@ -93,6 +99,7 @@ def stop_coverage_collection() -> None: except Exception: pass _cov_instance = None + _branch_cache = None def take_coverage_snapshot(baseline: bool = False) -> dict: @@ -115,7 +122,10 @@ def take_coverage_snapshot(baseline: bool = False) -> dict: coverage = {} try: + global _branch_cache if baseline: + # Baseline: compute fresh branch data and cache it for per-test reuse + _branch_cache = {} data = _cov_instance.get_data() for filename in data.measured_files(): if not _is_user_file(filename): @@ -127,6 +137,7 @@ def take_coverage_snapshot(baseline: bool = False) -> dict: for line in statements: lines_map[str(line)] = 0 if line in missing_set else 1 branch_data = _get_branch_data(data, filename) + _branch_cache[filename] = branch_data if lines_map: coverage[filename] = {"lines": lines_map, **branch_data} except Exception as e: @@ -139,7 +150,12 @@ def take_coverage_snapshot(baseline: bool = False) -> dict: continue lines = data.lines(filename) if lines: - branch_data = _get_branch_data(data, filename) + # Use cached branch data from baseline for stable totals. + # Fall back to live _analyze() if no cache (e.g., no baseline taken). + if _branch_cache is not None and filename in _branch_cache: + branch_data = _get_per_test_branch_data(data, filename, _branch_cache[filename]) + else: + branch_data = _get_branch_data(data, filename) coverage[filename] = { "lines": {str(line): 1 for line in lines}, **branch_data, @@ -217,3 +233,43 @@ def _get_branch_data(data, filename: str) -> dict: } except Exception: return {"totalBranches": 0, "coveredBranches": 0, "branches": {}} + + +def _get_per_test_branch_data(data, filename: str, cached: dict) -> dict: + """Compute per-test branch coverage using cached branch structure from baseline. + + Uses the cached branch point set (from baseline) for stable totals, + but computes covered counts from the current test's executed arcs. + This avoids flaky branch totals caused by non-deterministic arc detection. + """ + try: + if not data.has_arcs(): + return {"totalBranches": 0, "coveredBranches": 0, "branches": {}} + + executed_arcs = set(data.arcs(filename) or []) + + # Group executed arcs by from_line (skip negative entry arcs) + executed_by_line: dict[int, list[int]] = {} + for from_line, to_line in executed_arcs: + if from_line < 0: + continue + executed_by_line.setdefault(from_line, []).append(to_line) + + # Use cached branch points — only compute covered from current arcs + cached_branches = cached.get("branches", {}) + branches: dict[str, dict] = {} + total_covered = 0 + + for line_str, info in cached_branches.items(): + total = info["total"] + covered = min(len(executed_by_line.get(int(line_str), [])), total) + branches[line_str] = {"total": total, "covered": covered} + total_covered += covered + + return { + "totalBranches": cached.get("totalBranches", 0), + "coveredBranches": total_covered, + "branches": branches, + } + except Exception: + return {"totalBranches": 0, "coveredBranches": 0, "branches": {}} From 9b197f7977cc6fae206fbedee5e05307832d4035 Mon Sep 17 00:00:00 2001 From: Sohil Kshirsagar Date: Mon, 6 Apr 2026 22:12:28 -0700 Subject: [PATCH 22/24] test: add Tusk-generated tests for coverage server and communicator handler --- tests/unit/test_communicator.py | 207 +++++++++++++++ tests/unit/test_coverage_server.py | 400 +++++++++++++++++++++++++++++ 2 files changed, 607 insertions(+) diff --git a/tests/unit/test_communicator.py b/tests/unit/test_communicator.py index a813318..361bc85 100644 --- a/tests/unit/test_communicator.py +++ b/tests/unit/test_communicator.py @@ -598,3 +598,210 @@ def test_has_response_lock(self): assert hasattr(communicator, "_response_lock") assert isinstance(communicator._response_lock, type(threading.Lock())) + + +class TestProtobufCommunicatorHandleSetTimeTravelSync: + """Tests for _handle_set_time_travel_sync method.""" + + def test_returns_early_when_no_request(self, mocker): + """Should return early when set_time_travel_request is None.""" + communicator = ProtobufCommunicator() + mock_send = mocker.patch.object(communicator, "_send_message_sync") + + mock_cli_message = mocker.MagicMock() + mock_cli_message.set_time_travel_request = None + + communicator._handle_set_time_travel_sync(mock_cli_message) + + mock_send.assert_not_called() + + def test_handles_successful_time_travel(self, mocker): + """Should send success response when time travel starts successfully.""" + from tusk.drift.core.v1 import MessageType + + communicator = ProtobufCommunicator() + mock_send = mocker.patch.object(communicator, "_send_message_sync") + + mock_cli_message = mocker.MagicMock() + mock_request = mocker.MagicMock() + mock_request.timestamp_seconds = 1234567890 + mock_request.trace_id = "trace123" + mock_request.timestamp_source = "manual" + mock_cli_message.set_time_travel_request = mock_request + mock_cli_message.request_id = "req123" + + mock_start_time_travel = mocker.patch( + "drift.instrumentation.datetime.instrumentation.start_time_travel", return_value=True + ) + + communicator._handle_set_time_travel_sync(mock_cli_message) + + mock_start_time_travel.assert_called_once_with(1234567890, "trace123") + mock_send.assert_called_once() + sent_message = mock_send.call_args[0][0] + assert sent_message.type == MessageType.SET_TIME_TRAVEL + assert sent_message.request_id == "req123" + assert sent_message.set_time_travel_response.success is True + + def test_handles_time_travel_failure(self, mocker): + """Should send error response when time travel fails.""" + from tusk.drift.core.v1 import MessageType + + communicator = ProtobufCommunicator() + mock_send = mocker.patch.object(communicator, "_send_message_sync") + + mock_cli_message = mocker.MagicMock() + mock_request = mocker.MagicMock() + mock_request.timestamp_seconds = 1234567890 + mock_request.trace_id = "trace123" + mock_request.timestamp_source = "manual" + mock_cli_message.set_time_travel_request = mock_request + mock_cli_message.request_id = "req456" + + mocker.patch("drift.instrumentation.datetime.instrumentation.start_time_travel", return_value=False) + + communicator._handle_set_time_travel_sync(mock_cli_message) + + mock_send.assert_called_once() + sent_message = mock_send.call_args[0][0] + assert sent_message.type == MessageType.SET_TIME_TRAVEL + assert sent_message.set_time_travel_response.success is False + assert "not available" in sent_message.set_time_travel_response.error + + def test_handles_time_travel_exception(self, mocker): + """Should send error response when exception occurs during time travel.""" + + communicator = ProtobufCommunicator() + mock_send = mocker.patch.object(communicator, "_send_message_sync") + + mock_cli_message = mocker.MagicMock() + mock_request = mocker.MagicMock() + mock_request.timestamp_seconds = 1234567890 + mock_request.trace_id = "trace123" + mock_request.timestamp_source = "manual" + mock_cli_message.set_time_travel_request = mock_request + mock_cli_message.request_id = "req789" + + mocker.patch( + "drift.instrumentation.datetime.instrumentation.start_time_travel", + side_effect=Exception("Time travel error"), + ) + + communicator._handle_set_time_travel_sync(mock_cli_message) + + mock_send.assert_called_once() + sent_message = mock_send.call_args[0][0] + assert sent_message.set_time_travel_response.success is False + assert "Time travel error" in sent_message.set_time_travel_response.error + + +class TestProtobufCommunicatorHandleCoverageSnapshotSync: + """Tests for _handle_coverage_snapshot_sync method.""" + + def test_returns_early_when_no_request(self, mocker): + """Should return early when coverage_snapshot_request is None.""" + communicator = ProtobufCommunicator() + mock_send = mocker.patch.object(communicator, "_send_message_sync") + + mock_cli_message = mocker.MagicMock() + mock_cli_message.coverage_snapshot_request = None + + communicator._handle_coverage_snapshot_sync(mock_cli_message) + + mock_send.assert_not_called() + + def test_handles_successful_baseline_snapshot(self, mocker): + """Should send success response with coverage data for baseline snapshot.""" + from tusk.drift.core.v1 import MessageType + + communicator = ProtobufCommunicator() + mock_send = mocker.patch.object(communicator, "_send_message_sync") + + mock_cli_message = mocker.MagicMock() + mock_request = mocker.MagicMock() + mock_request.baseline = True + mock_cli_message.coverage_snapshot_request = mock_request + mock_cli_message.request_id = "cov123" + + mock_snapshot_result = { + "/app/main.py": { + "lines": {"1": 1, "2": 0, "3": 1}, + "totalBranches": 2, + "coveredBranches": 1, + "branches": {"5": {"total": 2, "covered": 1}}, + } + } + mock_take_snapshot = mocker.patch( + "drift.core.coverage_server.take_coverage_snapshot", + return_value=mock_snapshot_result, + ) + + communicator._handle_coverage_snapshot_sync(mock_cli_message) + + mock_take_snapshot.assert_called_once_with(True) + mock_send.assert_called_once() + sent_message = mock_send.call_args[0][0] + assert sent_message.type == MessageType.COVERAGE_SNAPSHOT + assert sent_message.request_id == "cov123" + assert sent_message.coverage_snapshot_response.success is True + assert "/app/main.py" in sent_message.coverage_snapshot_response.coverage + + def test_converts_coverage_data_to_protobuf_format(self, mocker): + """Should properly convert coverage data to protobuf FileCoverageData format.""" + + communicator = ProtobufCommunicator() + mock_send = mocker.patch.object(communicator, "_send_message_sync") + + mock_cli_message = mocker.MagicMock() + mock_request = mocker.MagicMock() + mock_request.baseline = False + mock_cli_message.coverage_snapshot_request = mock_request + mock_cli_message.request_id = "cov789" + + mock_snapshot_result = { + "/app/file1.py": { + "lines": {"1": 1, "2": 1}, + "totalBranches": 4, + "coveredBranches": 3, + "branches": {"5": {"total": 2, "covered": 1}, "10": {"total": 2, "covered": 2}}, + } + } + mocker.patch( + "drift.core.coverage_server.take_coverage_snapshot", + return_value=mock_snapshot_result, + ) + + communicator._handle_coverage_snapshot_sync(mock_cli_message) + + sent_message = mock_send.call_args[0][0] + file_data = sent_message.coverage_snapshot_response.coverage["/app/file1.py"] + assert file_data.lines == {"1": 1, "2": 1} + assert file_data.total_branches == 4 + assert file_data.covered_branches == 3 + assert "5" in file_data.branches + assert file_data.branches["5"].total == 2 + assert file_data.branches["5"].covered == 1 + + def test_handles_snapshot_exception(self, mocker): + """Should send error response when take_coverage_snapshot raises exception.""" + + communicator = ProtobufCommunicator() + mock_send = mocker.patch.object(communicator, "_send_message_sync") + + mock_cli_message = mocker.MagicMock() + mock_request = mocker.MagicMock() + mock_request.baseline = False + mock_cli_message.coverage_snapshot_request = mock_request + mock_cli_message.request_id = "cov999" + + mocker.patch( + "drift.core.coverage_server.take_coverage_snapshot", + side_effect=RuntimeError("Coverage not initialized"), + ) + + communicator._handle_coverage_snapshot_sync(mock_cli_message) + + mock_send.assert_called_once() + sent_message = mock_send.call_args[0][0] + assert sent_message.coverage_snapshot_response.success is False + assert "Coverage not initialized" in sent_message.coverage_snapshot_response.error diff --git a/tests/unit/test_coverage_server.py b/tests/unit/test_coverage_server.py index 14870e7..523bfbe 100644 --- a/tests/unit/test_coverage_server.py +++ b/tests/unit/test_coverage_server.py @@ -116,3 +116,403 @@ def test_cleans_up_state(self, monkeypatch, mocker): mock_cov.stop.assert_called_once() assert coverage_server._cov_instance is None + + +class TestTakeCoverageSnapshotBaseline: + """Tests for take_coverage_snapshot in baseline mode.""" + + def test_baseline_returns_all_coverable_lines_including_uncovered(self, monkeypatch, mocker): + """Should return all coverable lines including uncovered (count=0) in baseline mode.""" + mock_cov = mocker.MagicMock() + mock_data = mocker.MagicMock() + + # Mock measured files + mock_data.measured_files.return_value = ["/app/main.py"] + mock_data.has_arcs.return_value = False + + # analysis2 returns: (filename, statements, excluded, missing, missing_formatted) + # statements=[1,2,3,4], missing=[2,4] means lines 1,3 covered, lines 2,4 uncovered + mock_cov.analysis2.return_value = (None, [1, 2, 3, 4], [], [2, 4], None) + mock_cov.get_data.return_value = mock_data + + monkeypatch.setattr(coverage_server, "_cov_instance", mock_cov) + monkeypatch.setattr(coverage_server, "_source_root", os.path.realpath("/app")) + + result = take_coverage_snapshot(baseline=True) + + assert "/app/main.py" in result + lines = result["/app/main.py"]["lines"] + assert lines["1"] == 1 # covered + assert lines["2"] == 0 # uncovered + assert lines["3"] == 1 # covered + assert lines["4"] == 0 # uncovered + mock_cov.stop.assert_called() + mock_cov.erase.assert_called() + mock_cov.start.assert_called() + + def test_baseline_caches_branch_data_for_per_test_use(self, monkeypatch, mocker): + """Should cache branch structure in baseline mode for deterministic per-test counts.""" + mock_cov = mocker.MagicMock() + mock_data = mocker.MagicMock() + + mock_data.measured_files.return_value = ["/app/branchy.py"] + mock_data.has_arcs.return_value = True + mock_data.arcs.return_value = [(1, 2), (1, 3)] + + mock_cov.analysis2.return_value = (None, [1, 2, 3], [], [], None) + mock_cov.get_data.return_value = mock_data + + # Mock _analyze for branch detection + mock_analysis = mocker.MagicMock() + mock_numbers = mocker.MagicMock() + mock_numbers.n_branches = 2 + mock_numbers.n_missing_branches = 0 + mock_analysis.numbers = mock_numbers + mock_analysis.missing_branch_arcs.return_value = {} + mock_cov._analyze.return_value = mock_analysis + + monkeypatch.setattr(coverage_server, "_cov_instance", mock_cov) + monkeypatch.setattr(coverage_server, "_source_root", os.path.realpath("/app")) + + take_coverage_snapshot(baseline=True) + + # Verify branch cache was populated + assert coverage_server._branch_cache is not None + assert "/app/branchy.py" in coverage_server._branch_cache + + def test_baseline_handles_analysis_exceptions_gracefully(self, monkeypatch, mocker): + """Should continue processing other files if analysis2 raises exception.""" + mock_cov = mocker.MagicMock() + mock_data = mocker.MagicMock() + + mock_data.measured_files.return_value = ["/app/broken.py", "/app/ok.py"] + mock_data.has_arcs.return_value = False + + def analysis_side_effect(filename): + if "broken" in filename: + raise Exception("Analysis failed") + return (None, [1, 2], [], [], None) + + mock_cov.analysis2.side_effect = analysis_side_effect + mock_cov.get_data.return_value = mock_data + + monkeypatch.setattr(coverage_server, "_cov_instance", mock_cov) + monkeypatch.setattr(coverage_server, "_source_root", os.path.realpath("/app")) + + result = take_coverage_snapshot(baseline=True) + + # Should skip broken.py but process ok.py + assert "/app/broken.py" not in result + assert "/app/ok.py" in result + + +class TestTakeCoverageSnapshotPerTest: + """Tests for take_coverage_snapshot in per-test mode.""" + + def test_per_test_returns_only_executed_lines(self, monkeypatch, mocker): + """Should return only executed lines since last snapshot in per-test mode.""" + mock_cov = mocker.MagicMock() + mock_data = mocker.MagicMock() + + mock_data.measured_files.return_value = ["/app/main.py"] + mock_data.has_arcs.return_value = False + mock_data.lines.return_value = [5, 6, 7] # Only these lines executed + + mock_cov.get_data.return_value = mock_data + + monkeypatch.setattr(coverage_server, "_cov_instance", mock_cov) + monkeypatch.setattr(coverage_server, "_source_root", os.path.realpath("/app")) + + result = take_coverage_snapshot(baseline=False) + + assert "/app/main.py" in result + lines = result["/app/main.py"]["lines"] + assert lines == {"5": 1, "6": 1, "7": 1} + + def test_per_test_uses_cached_branch_data_when_available(self, monkeypatch, mocker): + """Should use cached branch structure from baseline for stable totals.""" + mock_cov = mocker.MagicMock() + mock_data = mocker.MagicMock() + + mock_data.measured_files.return_value = ["/app/branchy.py"] + mock_data.has_arcs.return_value = True + mock_data.lines.return_value = [1, 2] + mock_data.arcs.return_value = [(1, 2)] # Only one branch taken this test + + mock_cov.get_data.return_value = mock_data + + # Populate cache with baseline data + cached_branch_data = { + "totalBranches": 2, + "coveredBranches": 2, + "branches": {"1": {"total": 2, "covered": 2}}, + } + + monkeypatch.setattr(coverage_server, "_cov_instance", mock_cov) + monkeypatch.setattr(coverage_server, "_source_root", os.path.realpath("/app")) + monkeypatch.setattr(coverage_server, "_branch_cache", {"/app/branchy.py": cached_branch_data}) + + result = take_coverage_snapshot(baseline=False) + + # Should use cached totalBranches=2, but compute covered from current arcs + assert result["/app/branchy.py"]["totalBranches"] == 2 + assert result["/app/branchy.py"]["coveredBranches"] == 1 # Only 1 arc executed this test + + def test_per_test_falls_back_to_live_analyze_when_no_cache(self, monkeypatch, mocker): + """Should fall back to _get_branch_data if no baseline cache exists.""" + mock_cov = mocker.MagicMock() + mock_data = mocker.MagicMock() + + mock_data.measured_files.return_value = ["/app/main.py"] + mock_data.has_arcs.return_value = True + mock_data.lines.return_value = [1, 2] + mock_data.arcs.return_value = [(1, 2)] + + mock_cov.get_data.return_value = mock_data + + mock_analysis = mocker.MagicMock() + mock_numbers = mocker.MagicMock() + mock_numbers.n_branches = 1 + mock_numbers.n_missing_branches = 0 + mock_analysis.numbers = mock_numbers + mock_analysis.missing_branch_arcs.return_value = {} + mock_cov._analyze.return_value = mock_analysis + + monkeypatch.setattr(coverage_server, "_cov_instance", mock_cov) + monkeypatch.setattr(coverage_server, "_source_root", os.path.realpath("/app")) + monkeypatch.setattr(coverage_server, "_branch_cache", None) # No cache + + result = take_coverage_snapshot(baseline=False) + + # Should call _analyze since no cache + mock_cov._analyze.assert_called_with("/app/main.py") + assert "/app/main.py" in result + + def test_per_test_skips_files_with_no_executed_lines(self, monkeypatch, mocker): + """Should skip files with no executed lines in per-test mode.""" + mock_cov = mocker.MagicMock() + mock_data = mocker.MagicMock() + + mock_data.measured_files.return_value = ["/app/notrun.py", "/app/run.py"] + mock_data.has_arcs.return_value = False + + def lines_side_effect(filename): + if "notrun" in filename: + return [] # No lines executed + return [1, 2] + + mock_data.lines.side_effect = lines_side_effect + mock_cov.get_data.return_value = mock_data + + monkeypatch.setattr(coverage_server, "_cov_instance", mock_cov) + monkeypatch.setattr(coverage_server, "_source_root", os.path.realpath("/app")) + + result = take_coverage_snapshot(baseline=False) + + assert "/app/notrun.py" not in result + assert "/app/run.py" in result + + +class TestGetBranchData: + """Tests for _get_branch_data function.""" + + def test_returns_empty_when_no_arcs(self, monkeypatch, mocker): + """Should return zero branches when data has no arcs (branch coverage disabled).""" + from drift.core.coverage_server import _get_branch_data + + mock_data = mocker.MagicMock() + mock_data.has_arcs.return_value = False + + result = _get_branch_data(mock_data, "/app/main.py") + + assert result == {"totalBranches": 0, "coveredBranches": 0, "branches": {}} + + def test_returns_empty_when_cov_instance_is_none(self, monkeypatch, mocker): + """Should return empty branch data when _cov_instance is None.""" + from drift.core.coverage_server import _get_branch_data + + mock_data = mocker.MagicMock() + mock_data.has_arcs.return_value = True + + monkeypatch.setattr(coverage_server, "_cov_instance", None) + + result = _get_branch_data(mock_data, "/app/main.py") + + assert result == {"totalBranches": 0, "coveredBranches": 0, "branches": {}} + + def test_computes_branch_coverage_from_arcs(self, monkeypatch, mocker): + """Should compute per-line branch coverage from arc data.""" + from drift.core.coverage_server import _get_branch_data + + mock_cov = mocker.MagicMock() + mock_data = mocker.MagicMock() + + mock_data.has_arcs.return_value = True + # Line 5 has two paths: 5->6 (executed) and 5->8 (missing) + mock_data.arcs.return_value = [(5, 6)] + + mock_analysis = mocker.MagicMock() + mock_numbers = mocker.MagicMock() + mock_numbers.n_branches = 2 + mock_numbers.n_missing_branches = 1 + mock_analysis.numbers = mock_numbers + mock_analysis.missing_branch_arcs.return_value = {5: [(5, 8)]} + mock_cov._analyze.return_value = mock_analysis + + monkeypatch.setattr(coverage_server, "_cov_instance", mock_cov) + + result = _get_branch_data(mock_data, "/app/main.py") + + assert result["totalBranches"] == 2 + assert result["coveredBranches"] == 1 + assert "5" in result["branches"] + assert result["branches"]["5"]["total"] == 2 + assert result["branches"]["5"]["covered"] == 1 + + def test_skips_negative_entry_arcs(self, monkeypatch, mocker): + """Should skip negative entry arcs (function entry points) when grouping.""" + from drift.core.coverage_server import _get_branch_data + + mock_cov = mocker.MagicMock() + mock_data = mocker.MagicMock() + + mock_data.has_arcs.return_value = True + # Negative arcs are function entry points, should be ignored + mock_data.arcs.return_value = [(-1, 1), (1, 2), (1, 3)] + + mock_analysis = mocker.MagicMock() + mock_numbers = mocker.MagicMock() + mock_numbers.n_branches = 2 + mock_numbers.n_missing_branches = 0 + mock_analysis.numbers = mock_numbers + mock_analysis.missing_branch_arcs.return_value = {} + mock_cov._analyze.return_value = mock_analysis + + monkeypatch.setattr(coverage_server, "_cov_instance", mock_cov) + + result = _get_branch_data(mock_data, "/app/main.py") + + # Line 1 should have 2 branches (to 2 and 3), -1 should be skipped + assert "1" in result["branches"] + assert result["branches"]["1"]["covered"] == 2 + assert "-1" not in result["branches"] + + def test_handles_exceptions_gracefully(self, monkeypatch, mocker): + """Should return empty branch data on exceptions.""" + from drift.core.coverage_server import _get_branch_data + + mock_cov = mocker.MagicMock() + mock_data = mocker.MagicMock() + + mock_data.has_arcs.return_value = True + mock_cov._analyze.side_effect = Exception("Analysis error") + + monkeypatch.setattr(coverage_server, "_cov_instance", mock_cov) + + result = _get_branch_data(mock_data, "/app/main.py") + + assert result == {"totalBranches": 0, "coveredBranches": 0, "branches": {}} + + +class TestGetPerTestBranchData: + """Tests for _get_per_test_branch_data function.""" + + def test_returns_empty_when_no_arcs(self, mocker): + """Should return zero branches when data has no arcs.""" + from drift.core.coverage_server import _get_per_test_branch_data + + mock_data = mocker.MagicMock() + mock_data.has_arcs.return_value = False + + cached = {"totalBranches": 2, "branches": {}} + result = _get_per_test_branch_data(mock_data, "/app/main.py", cached) + + assert result == {"totalBranches": 0, "coveredBranches": 0, "branches": {}} + + def test_uses_cached_totals_with_current_covered_counts(self, mocker): + """Should use cached totals but compute covered from current test's arcs.""" + from drift.core.coverage_server import _get_per_test_branch_data + + mock_data = mocker.MagicMock() + mock_data.has_arcs.return_value = True + # Current test only executed one path from line 5 + mock_data.arcs.return_value = [(5, 6)] + + cached = { + "totalBranches": 2, + "branches": { + "5": {"total": 2, "covered": 2}, # Baseline saw both branches + }, + } + + result = _get_per_test_branch_data(mock_data, "/app/main.py", cached) + + # Should use cached total but compute covered from current arcs + assert result["totalBranches"] == 2 + assert result["coveredBranches"] == 1 # Only 1 arc in current test + assert result["branches"]["5"]["total"] == 2 + assert result["branches"]["5"]["covered"] == 1 + + def test_caps_covered_at_cached_total(self, mocker): + """Should cap covered count at cached total (prevents count overflow).""" + from drift.core.coverage_server import _get_per_test_branch_data + + mock_data = mocker.MagicMock() + mock_data.has_arcs.return_value = True + # More arcs than cached total (shouldn't happen but defensive) + mock_data.arcs.return_value = [(5, 6), (5, 7), (5, 8)] + + cached = { + "totalBranches": 2, + "branches": { + "5": {"total": 2, "covered": 1}, + }, + } + + result = _get_per_test_branch_data(mock_data, "/app/main.py", cached) + + # Should cap at cached total of 2, not report 3 + assert result["branches"]["5"]["covered"] == 2 + + def test_handles_exceptions_gracefully(self, mocker): + """Should return empty branch data on exceptions.""" + from drift.core.coverage_server import _get_per_test_branch_data + + mock_data = mocker.MagicMock() + mock_data.has_arcs.side_effect = Exception("Arc error") + + cached = {"totalBranches": 2, "branches": {}} + result = _get_per_test_branch_data(mock_data, "/app/main.py", cached) + + assert result == {"totalBranches": 0, "coveredBranches": 0, "branches": {}} + + +class TestIsUserFileEdgeCases: + """Tests for _is_user_file edge cases.""" + + def test_handles_source_root_exactly(self, monkeypatch): + """Should return True for the source root itself (not just children).""" + source_root = os.path.realpath("/app") + monkeypatch.setattr(coverage_server, "_source_root", source_root) + + result = _is_user_file(source_root) + + assert result is True + + def test_avoids_prefix_collision_with_trailing_separator(self, monkeypatch): + """Should use trailing separator to prevent /app matching /application.""" + source_root = os.path.realpath("/app") + monkeypatch.setattr(coverage_server, "_source_root", source_root) + + # /application should NOT match /app due to trailing separator check + result = _is_user_file("/application/file.py") + + assert result is False + + def test_returns_true_when_source_root_is_none(self): + """Should return True when _source_root is None (before initialization).""" + # When _source_root is None, the function returns True + # (see line 177: "not _source_root or ...") + result = _is_user_file("/any/path/file.py") + + assert result is True From 1eb2e39af5b66ef7dfcda1eefc2151df120392dd Mon Sep 17 00:00:00 2001 From: Sohil Kshirsagar Date: Tue, 7 Apr 2026 13:56:35 -0700 Subject: [PATCH 23/24] fix: move start_coverage_collection after _initialized guard, reset _source_root on stop - start_coverage_collection() was before the _initialized check, causing repeated initialize() calls to stop/restart coverage and lose data - stop_coverage_collection() now resets _source_root for cleanup completeness --- drift/core/coverage_server.py | 3 ++- drift/core/drift_sdk.py | 13 ++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/drift/core/coverage_server.py b/drift/core/coverage_server.py index f2306f4..4decc66 100644 --- a/drift/core/coverage_server.py +++ b/drift/core/coverage_server.py @@ -91,7 +91,7 @@ def start_coverage_collection() -> bool: def stop_coverage_collection() -> None: """Stop coverage collection and clean up. Thread-safe.""" - global _cov_instance, _branch_cache + global _cov_instance, _branch_cache, _source_root with _lock: if _cov_instance is not None: try: @@ -100,6 +100,7 @@ def stop_coverage_collection() -> None: pass _cov_instance = None _branch_cache = None + _source_root = None def take_coverage_snapshot(baseline: bool = False) -> dict: diff --git a/drift/core/drift_sdk.py b/drift/core/drift_sdk.py index 2dcbf93..395f8b9 100644 --- a/drift/core/drift_sdk.py +++ b/drift/core/drift_sdk.py @@ -179,17 +179,16 @@ def initialize( ) env = env_from_var - # Start coverage collection early (before any SDK mode checks that might return early), - # but after the _initialized guard so we don't re-invoke on repeated initialize() calls. - # Coverage data is accessed via protobuf channel (communicator handles requests). - from .coverage_server import start_coverage_collection - - start_coverage_collection() - if cls._initialized: logger.debug("Already initialized, skipping...") return instance + # Start coverage collection after the _initialized guard so repeated + # initialize() calls don't stop/restart coverage and lose accumulated data. + from .coverage_server import start_coverage_collection + + start_coverage_collection() + file_config = instance.file_config if ( From c0fac85f6d6382563209ca8b3fcc0325139cbcea Mon Sep 17 00:00:00 2001 From: Sohil Kshirsagar Date: Tue, 7 Apr 2026 14:16:59 -0700 Subject: [PATCH 24/24] ref: extract _group_arcs_by_line helper to deduplicate arc grouping --- drift/core/coverage_server.py | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/drift/core/coverage_server.py b/drift/core/coverage_server.py index 4decc66..16fd344 100644 --- a/drift/core/coverage_server.py +++ b/drift/core/coverage_server.py @@ -178,6 +178,16 @@ def _is_user_file(filename: str) -> bool: return not _source_root or resolved.startswith(_source_root + os.sep) or resolved == _source_root +def _group_arcs_by_line(arcs: set) -> dict[int, list[int]]: + """Group executed arcs by from_line, skipping negative entry arcs.""" + by_line: dict[int, list[int]] = {} + for from_line, to_line in arcs: + if from_line < 0: + continue + by_line.setdefault(from_line, []).append(to_line) + return by_line + + def _get_branch_data(data, filename: str) -> dict: """Extract branch coverage data for a file. @@ -199,13 +209,7 @@ def _get_branch_data(data, filename: str) -> dict: missing_arcs = analysis.missing_branch_arcs() executed_arcs = set(data.arcs(filename) or []) - - # Group executed arcs by from_line (skip negative entry arcs) - executed_by_line: dict[int, list[int]] = {} - for from_line, to_line in executed_arcs: - if from_line < 0: - continue - executed_by_line.setdefault(from_line, []).append(to_line) + executed_by_line = _group_arcs_by_line(executed_arcs) # A line is a branch point if: # - it appears in missing_arcs (at least one path wasn't taken), OR @@ -248,13 +252,7 @@ def _get_per_test_branch_data(data, filename: str, cached: dict) -> dict: return {"totalBranches": 0, "coveredBranches": 0, "branches": {}} executed_arcs = set(data.arcs(filename) or []) - - # Group executed arcs by from_line (skip negative entry arcs) - executed_by_line: dict[int, list[int]] = {} - for from_line, to_line in executed_arcs: - if from_line < 0: - continue - executed_by_line.setdefault(from_line, []).append(to_line) + executed_by_line = _group_arcs_by_line(executed_arcs) # Use cached branch points — only compute covered from current arcs cached_branches = cached.get("branches", {})