diff --git a/.gitignore b/.gitignore index 07c8e14..89523e2 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,5 @@ vmlinux.h *.o *.skel.h sigsegv_monitor +sample_segfault +__pycache__ diff --git a/Makefile b/Makefile index ff8c666..461452f 100644 --- a/Makefile +++ b/Makefile @@ -2,8 +2,19 @@ CLANG ?= clang BPFTOOL ?= bpftool +# Git version info (build fails if not available) +GIT_REV := $(shell git rev-parse --short HEAD) +GIT_DATE := $(shell git log -1 --format=%cI) +ifeq ($(GIT_REV),) + $(error GIT_REV is not set - not in a git repository?) +endif +ifeq ($(GIT_DATE),) + $(error GIT_DATE is not set - not in a git repository?) +endif + # Output executable name APP = sigsegv_monitor +SAMPLE = sample_segfault # Source files BPF_SRC = sigsegv-monitor.bpf.c @@ -17,15 +28,20 @@ VMLINUX = vmlinux.h # Compiler flags # -g: Debug info (required for BTF) # -O2: Optimization (required for BPF) -CFLAGS := -g -O2 -Wall +CFLAGS := -g -O2 -Wall -DGIT_REV=\"$(GIT_REV)\" -DGIT_DATE=\"$(GIT_DATE)\" BPF_CFLAGS := -g -O2 -target bpf -D__TARGET_ARCH_x86 # Libs to link LIBS := -lbpf -lelf -lz -.PHONY: all clean +.PHONY: all clean sample test + +all: $(APP) $(SAMPLE) + +sample: $(SAMPLE) -all: $(APP) +test: $(SAMPLE) + sudo python3 verify_monitor.py .DELETE_ON_ERROR: @@ -45,6 +61,10 @@ $(APP): $(USER_SRC) $(SKEL_OBJ) @echo " CC $@" $(CLANG) $(CFLAGS) $(USER_SRC) $(LIBS) -o $@ +$(SAMPLE): sample_segfault.c + @echo " CC $@" + $(CLANG) -g -O0 -Wall $< -o $@ + clean: @echo " CLEAN" - rm -f $(APP) $(BPF_OBJ) $(SKEL_OBJ) $(VMLINUX) + rm -f $(APP) $(BPF_OBJ) $(SKEL_OBJ) $(VMLINUX) $(SAMPLE) diff --git a/sample_segfault.c b/sample_segfault.c new file mode 100644 index 0000000..0ae8128 --- /dev/null +++ b/sample_segfault.c @@ -0,0 +1,79 @@ +/* + * Sample program for sigsegv-monitor testing. + * + * This program: + * 1. Allocates several memory pages using mmap() + * 2. Touches each page to trigger page faults + * 3. Dereferences a null pointer to trigger SIGSEGV + * + * The sigsegv-monitor should capture: + * - Multiple page fault events (recorded in page_faults array) + * - One SIGSEGV event with cr2 = 0 (null pointer) + */ + +#include +#include +#include +#include +#include +#include + +#define NUM_PAGES 4 + +/* Prevent compiler from optimizing away memory accesses */ +volatile int sink; + +int main(int argc, char *argv[]) { + void *pages[NUM_PAGES]; + long page_size = sysconf(_SC_PAGESIZE); + + fprintf(stderr, "[sample_segfault] PID: %d\n", getpid()); + fprintf(stderr, "[sample_segfault] System page size: %ld\n", page_size); + fprintf(stderr, "[sample_segfault] Allocating %d pages...\n", NUM_PAGES); + + /* Allocate pages - these won't cause page faults yet (no physical memory assigned) */ + for (int i = 0; i < NUM_PAGES; i++) { + pages[i] = mmap(NULL, page_size, PROT_READ | PROT_WRITE, + MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); + if (pages[i] == MAP_FAILED) { + perror("mmap"); + return 1; + } + fprintf(stderr, "[sample_segfault] Page %d allocated at %p\n", i, pages[i]); + } + + fprintf(stderr, "[sample_segfault] Touching pages to trigger page faults...\n"); + + /* Touch each page to trigger page faults (first access causes PF) */ + for (int i = 0; i < NUM_PAGES; i++) { + volatile char *ptr = (volatile char *)pages[i]; + /* Write to trigger page fault */ + *ptr = (char)(i + 1); + /* Read to ensure the write actually happened */ + sink = *ptr; + fprintf(stderr, "[sample_segfault] Page %d touched at %p (wrote %d)\n", + i, pages[i], i + 1); + } + + /* Output JSON on stdout with the expected page addresses for verification */ + printf("{\"page_size\":%ld,\"pages\":[", page_size); + for (int i = 0; i < NUM_PAGES; i++) { + if (i > 0) printf(","); + printf("\"0x%lx\"", (unsigned long)pages[i]); + } + printf("],\"segfault_addr\":\"0x0\"}\n"); + fflush(stdout); + + /* Small delay to ensure page faults are recorded */ + usleep(10000); + + fprintf(stderr, "[sample_segfault] Triggering SIGSEGV via null pointer dereference...\n"); + + /* Trigger SIGSEGV by dereferencing null pointer */ + volatile int *null_ptr = NULL; + sink = *null_ptr; /* This will cause SIGSEGV with cr2 = 0 */ + + /* Should never reach here */ + fprintf(stderr, "[sample_segfault] ERROR: Should have crashed!\n"); + return 1; +} diff --git a/sigsegv-monitor.bpf.c b/sigsegv-monitor.bpf.c index c0b6e36..39910a6 100644 --- a/sigsegv-monitor.bpf.c +++ b/sigsegv-monitor.bpf.c @@ -57,19 +57,14 @@ inline void cr2stats_push(struct cr2_stats* stats, struct cr2_stat* value) { // The `index` parameter here is not an index in the array, but an index in the ring buffer, // i.e. passing an index 0 would return the oldest element in the ring buffer. -inline struct cr2_stat* cr2stats_get(struct cr2_stats* stats, u32 index) { +inline struct cr2_stat* cr2stats_get(struct cr2_stats* stats, u64 index) { if (stats->count == MAX_USER_PF_ENTRIES) { - index += stats->head; - if (index >= MAX_USER_PF_ENTRIES) { - index -= MAX_USER_PF_ENTRIES; - } - } - - if (index < MAX_USER_PF_ENTRIES) { - return stats->stat + index; + index += stats->head; // this makes index unbounded to the verifier } - return NULL; + // establish bound for index; also helps if above index += ... needs to wrap around + index %= MAX_USER_PF_ENTRIES; + return stats->stat + index; } #endif @@ -166,12 +161,24 @@ int trace_sigsegv(struct trace_event_raw_signal_generate *ctx) { } event->pf_count = 0; - #ifdef TRACE_PF_CR2 - u32 pid = task->pid; +#ifdef TRACE_PF_CR2 + u32 const pid = task->pid; struct cr2_stats *cr2stats = bpf_map_lookup_elem(&pid_cr2, &pid); if (cr2stats) { - for (u32 i = 0; i < cr2stats->count && i < MAX_USER_PF_ENTRIES; i++) { + /* If we use a u32 for i, the verifier loses track of its value and rejects the program: + * 151: (bf) r4 = r5 ; R4_w=scalar(id=4) R5_w=scalar(id=4) + * ... + * 156: (67) r4 <<= 32 ; R4_w=scalar(smax=9223372032559808512,umax=18446744069414584320,var_off=(0x0; 0xffffffff00000000),s32_min=0,s32_max=0,u32_max=0) + * 157: (77) r4 >>= 32 ; R4_w=scalar(umax=4294967295,var_off=(0x0; 0xffffffff)) + * 158: (27) r4 *= 24 ; R4_w=scalar(umax=103079215080,var_off=(0x0; 0x1ffffffff8),s32_max=2147483640,u32_max=-8) + * 159: (bf) r5 = r0 ; R0=map_value(off=0,ks=4,vs=400,imm=0) R5_w=map_value(off=0,ks=4,vs=400,imm=0) + * 160: (0f) r5 += r4 ; R4_w=scalar(umax=103079215080,var_off=(0x0; 0x1ffffffff8),s32_max=2147483640,u32_max=-8) R5_w=map_value(off=0,ks=4,vs=400,umax=103079215080,var_off=(0x0; 0x1ffffffff8),s32_max=2147483640,u32_max=-8) + * ; event->pf[i].cr2 = stat->cr2; + * 161: (79) r4 = *(u64 *)(r5 +0) + * R5 unbounded memory access, make sure to bounds check any such access + */ + for (u64 i = 0; i < cr2stats->count && i < MAX_USER_PF_ENTRIES; i++) { struct cr2_stat* stat = cr2stats_get(cr2stats, i); if (stat) { event->pf[i].cr2 = stat->cr2; @@ -184,7 +191,7 @@ int trace_sigsegv(struct trace_event_raw_signal_generate *ctx) { bpf_map_delete_elem(&pid_cr2, &pid); } - #endif +#endif // TODO: when is this snapshot taken? or does the CPU not do LBR in the kernel? long ret = bpf_get_branch_snapshot(&event->lbr, sizeof(event->lbr), 0); @@ -203,13 +210,12 @@ int trace_sigsegv(struct trace_event_raw_signal_generate *ctx) { #ifdef TRACE_PF_CR2 SEC("tracepoint/exceptions/page_fault_user") int trace_page_fault(struct trace_event_raw_page_fault_user *ctx) { - struct cr2_stat stat; - u32 pid; - - stat.cr2 = ctx->address; - stat.err = ctx->error_code; - stat.tai = bpf_ktime_get_tai_ns(); - pid = (u32)bpf_get_current_pid_tgid(); + struct cr2_stat stat = { + .cr2 = ctx->address, + .err = ctx->error_code, + .tai = bpf_ktime_get_tai_ns() + }; + u32 const pid = (u32)bpf_get_current_pid_tgid(); struct cr2_stats *cr2stats = bpf_map_lookup_elem(&pid_cr2, &pid); if (cr2stats) { diff --git a/sigsegv-monitor.c b/sigsegv-monitor.c index 1cfdc9c..6786a76 100644 --- a/sigsegv-monitor.c +++ b/sigsegv-monitor.c @@ -1,6 +1,7 @@ #define _GNU_SOURCE #include #include +#include #include #include #include @@ -140,7 +141,18 @@ void clean() { free(cpus_fd); } -int main() { +void print_version(char const* prefix, FILE* out) { + fprintf(out, "%scommit %s committed %s\n", prefix, GIT_REV, GIT_DATE); +} + +int main(int argc, char *argv[]) { + if (argc > 1 && (strcmp(argv[1], "-v") == 0 || strcmp(argv[1], "--version") == 0)) { + print_version("", stdout); + return 0; + } else { + print_version("[*] version ", stderr); + } + struct sigsegv_monitor_bpf *skel; struct perf_buffer *pb = NULL; diff --git a/verify_monitor.py b/verify_monitor.py new file mode 100755 index 0000000..a1bb627 --- /dev/null +++ b/verify_monitor.py @@ -0,0 +1,311 @@ +#!/usr/bin/env python3 +""" +Verification test for sigsegv-monitor. + +Launches sigsegv_monitor and sample_segfault, then uses unittest assertions +to verify that the monitor correctly recorded: + - A SIGSEGV event for the sample program + - Page faults whose cr2 addresses fall within the pages allocated by the sample + - The SIGSEGV fault address (cr2 == 0 for null pointer dereference) + +Usage: + sudo python3 verify_monitor.py # normal run + sudo python3 verify_monitor.py -v # verbose +""" + +import json +import os +import signal +import subprocess +import time +import unittest +from typing import Any, Dict, List, NamedTuple, Optional, Set + +# --------------------------------------------------------------------------- +# Configuration – tailored to sample_segfault +# --------------------------------------------------------------------------- +MONITOR_BIN: str = "./sigsegv_monitor" +SAMPLE_BIN: str = "./sample_segfault" + +# The kernel's comm field is 16 bytes including NUL, so at most 15 chars. +_COMM_MAX: int = 15 +PROCESS_NAME: str = os.path.basename(SAMPLE_BIN)[:_COMM_MAX] + +# sample_segfault allocates and touches 4 pages and then loads from the null page. +MIN_PAGE_FAULTS: int = 4+1 + +# Page size is determined at runtime. +PAGE_SIZE: int = os.sysconf("SC_PAGESIZE") + + +# --------------------------------------------------------------------------- +# Data classes – sample_segfault output +# --------------------------------------------------------------------------- + +class SampleOutput(NamedTuple): + """JSON output emitted by sample_segfault on stdout.""" + page_size: int + pages: List[int] + segfault_addr: int + + @classmethod + def from_json(cls, raw: Dict[str, Any]) -> "SampleOutput": + return SampleOutput( + page_size=raw["page_size"], + pages=[int(p, 16) for p in raw["pages"]], + segfault_addr=int(raw["segfault_addr"], 16), + ) + + +# --------------------------------------------------------------------------- +# Data classes – sigsegv_monitor output +# --------------------------------------------------------------------------- + +class ProcessInfo(NamedTuple): + rootns_pid: int + ns_pid: int + comm: str + + @classmethod + def from_json(cls, raw: Dict[str, Any]) -> "ProcessInfo": + return ProcessInfo( + rootns_pid=raw["rootns_pid"], + ns_pid=raw["ns_pid"], + comm=raw["comm"], + ) + + +class ThreadInfo(NamedTuple): + rootns_tid: int + ns_tid: int + comm: str + + @classmethod + def from_json(cls, raw: Dict[str, Any]) -> "ThreadInfo": + return ThreadInfo( + rootns_tid=raw["rootns_tid"], + ns_tid=raw["ns_tid"], + comm=raw["comm"], + ) + + +class Registers(NamedTuple): + rip: int + rsp: int + rax: int + rbx: int + rcx: int + rdx: int + rsi: int + rdi: int + rbp: int + r8: int + r9: int + r10: int + r11: int + r12: int + r13: int + r14: int + r15: int + flags: int + trapno: int + err: int + cr2: int + + @classmethod + def from_json(cls, raw: Dict[str, Any]) -> "Registers": + return Registers(**{name: int(raw[name], 16) for name in cls._fields}) + + +class PageFault(NamedTuple): + cr2: int + err: int + tai: int + + @classmethod + def from_json(cls, raw: Dict[str, Any]) -> "PageFault": + return PageFault( + cr2=int(raw["cr2"], 16), + err=int(raw["err"], 16), + tai=raw["tai"], + ) + + +class LbrEntry(NamedTuple): + from_addr: int + to_addr: int + + @classmethod + def from_json(cls, raw: Dict[str, Any]) -> "LbrEntry": + return LbrEntry( + from_addr=int(raw["from"], 16), + to_addr=int(raw["to"], 16), + ) + + +class MonitorEvent(NamedTuple): + """A single JSON record emitted by sigsegv_monitor.""" + cpu: int + tai: int + process: ProcessInfo + thread: ThreadInfo + si_code: int + registers: Registers + page_faults: List[PageFault] + lbr: List[Optional[LbrEntry]] + + @classmethod + def from_json(cls, raw: Dict[str, Any]) -> "MonitorEvent": + return MonitorEvent( + cpu=raw["cpu"], + tai=raw["tai"], + process=ProcessInfo.from_json(raw["process"]), + thread=ThreadInfo.from_json(raw["thread"]), + si_code=raw["si_code"], + registers=Registers.from_json(raw["registers"]), + page_faults=[PageFault.from_json(pf) for pf in raw["page_faults"]], + lbr=[ + LbrEntry.from_json(e) if e is not None else None + for e in raw["lbr"] + ], + ) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def parse_monitor_events(text: str) -> List[MonitorEvent]: + """Parse newline-delimited JSON into MonitorEvent objects.""" + return [ + MonitorEvent.from_json(json.loads(line)) + for line in text.splitlines() + if line.strip() + ] + + +def parse_sample_output(text: str) -> SampleOutput: + """Parse the single JSON line from sample_segfault.""" + for line in text.splitlines(): + line = line.strip() + if not line: + continue + return SampleOutput.from_json(json.loads(line)) + raise ValueError("No JSON found in sample output") + + +def addr_in_page(addr: int, page_base: int) -> bool: + """Return True if *addr* falls within [page_base, page_base + PAGE_SIZE).""" + return page_base <= addr < page_base + PAGE_SIZE + + +# --------------------------------------------------------------------------- +# Test fixture +# --------------------------------------------------------------------------- + +class TestSigsegvMonitor(unittest.TestCase): + """End-to-end test: launch monitor + sample, verify output.""" + + sample: Optional[SampleOutput] = None + event: Optional[MonitorEvent] = None + + @classmethod + def setUpClass(cls) -> None: + """Start monitor, run sample, stop monitor, collect outputs.""" + # 1. Start monitor + monitor_proc = subprocess.Popen( + [MONITOR_BIN], + stdout=subprocess.PIPE, + ) + # Give the BPF program a moment to attach. + time.sleep(1) + + # Verify the monitor is still running. If it exited already the + # BPF program was never loaded (common causes: missing root + # privileges, BPF verification failure, missing binary). + if monitor_proc.poll() is not None: + stdout_bytes, _ = monitor_proc.communicate() + rc = monitor_proc.returncode + raise RuntimeError( + f"{MONITOR_BIN} exited prematurely (rc={rc}) before the " + f"sample was started. The BPF program was never loaded.\n" + f"Common causes: not running as root, BPF verifier " + f"rejection, or the binary is missing.") + + # 2. Run sample + sample_proc = subprocess.Popen( + [SAMPLE_BIN], + stdout=subprocess.PIPE, + ) + stdout, _ = sample_proc.communicate(timeout=10) + cls.sample = parse_sample_output(stdout.decode()) + + # Small grace period so the monitor can pick up the event from the + # kernel queue. Note that it handles the SIGINT below, so + # processing / writing out the JSON should not be an issue. + time.sleep(0.5) + + # 3. Stop monitor & collect output + monitor_proc.send_signal(signal.SIGINT) + try: + stdout, _ = monitor_proc.communicate(timeout=5) + except subprocess.TimeoutExpired: + monitor_proc.kill() + stdout, _ = monitor_proc.communicate() + + events = parse_monitor_events(stdout.decode()) + + # 4. Filter events by the known child PID; expect exactly one. + pid = sample_proc.pid + matching = [e for e in events if e.process.rootns_pid == pid] + if len(matching) != 1: + all_pids = [e.process.rootns_pid for e in events] + raise AssertionError( + f"Expected exactly 1 event for PID {pid}, " + f"got {len(matching)} (all pids: {all_pids})") + cls.event = matching[0] + + # -- tests -- + + def test_sigsegv_fault_address(self) -> None: + """The SIGSEGV cr2 must equal the expected fault address (0 for NULL).""" + cr2 = self.event.registers.cr2 + expected = self.sample.segfault_addr + self.assertEqual(cr2, expected) + + def test_sigsegv_process_info(self) -> None: + """Process and thread info must be populated.""" + self.assertGreater(self.event.process.rootns_pid, 0) + self.assertGreater(self.event.thread.rootns_tid, 0) + self.assertEqual(PROCESS_NAME, self.event.process.comm) + + def test_minimum_page_faults(self) -> None: + """The monitor must record at least MIN_PAGE_FAULTS page faults.""" + self.assertGreaterEqual( + len(self.event.page_faults), MIN_PAGE_FAULTS, + f"Expected >= {MIN_PAGE_FAULTS} page faults, " + f"got {len(self.event.page_faults)}", + ) + + def test_page_fault_addresses_match_expected_pages(self) -> None: + """Every page allocated by the sample must appear in the PF list.""" + expected_pages = self.sample.pages + + matched: Set[int] = set() + for pf in self.event.page_faults: + for pg_idx, pg_base in enumerate(expected_pages): + if addr_in_page(pf.cr2, pg_base): + matched.add(pg_idx) + break + + for pg_idx, pg_base in enumerate(expected_pages): + self.assertIn( + pg_idx, matched, + f"Expected page {pg_idx} at 0x{pg_base:x} " + f"(range [0x{pg_base:x}, 0x{pg_base + PAGE_SIZE:x})) " + f"was not matched by any recorded page fault", + ) + + +if __name__ == "__main__": + unittest.main()