diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000..bba3094 --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,139 @@ +"""Tests for the CLI entry-point shims (pilotprotocol/cli.py). + +The shims seed ``~/.pilot/bin/`` then exec the seeded binary. We replace +``ensure_runtime_seeded`` + ``runtime_binary`` + ``subprocess.call`` so no +real binaries are required. +""" + +from __future__ import annotations + +from pathlib import Path +from unittest import mock + +import pytest + +import pilotprotocol.cli as cli_mod + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def fake_runtime(monkeypatch, tmp_path): + """Stub ensure_runtime_seeded + runtime_binary so cli shims run dry.""" + rt_bin = tmp_path / "rt-bin" + rt_bin.mkdir() + + seeded = {"called": 0} + + def fake_seed(): + seeded["called"] += 1 + return rt_bin + + def fake_binary(name: str) -> Path: + p = rt_bin / name + p.write_text("#!/bin/sh\nexit 0\n") + p.chmod(0o755) + return p + + monkeypatch.setattr(cli_mod, "ensure_runtime_seeded", fake_seed) + monkeypatch.setattr(cli_mod, "runtime_binary", fake_binary) + return {"rt": rt_bin, "seeded": seeded} + + +@pytest.fixture +def fake_call(monkeypatch): + """Capture subprocess.call invocations and return a controlled exit code.""" + calls = [] + + def _call(cmd): + calls.append(cmd) + return 0 + + monkeypatch.setattr(cli_mod.subprocess, "call", _call) + return calls + + +# --------------------------------------------------------------------------- +# Each entry point should seed, then exec the right binary +# --------------------------------------------------------------------------- + + +class TestRunPilotctl: + def test_seeds_and_exits_zero(self, fake_runtime, fake_call, monkeypatch): + monkeypatch.setattr(cli_mod.sys, "argv", ["pilotctl", "info"]) + with pytest.raises(SystemExit) as exc: + cli_mod.run_pilotctl() + assert exc.value.code == 0 + assert fake_runtime["seeded"]["called"] == 1 + assert len(fake_call) == 1 + # Command should be [/pilotctl, "info"] + cmd = fake_call[0] + assert cmd[0].endswith("pilotctl") + assert cmd[1:] == ["info"] + + def test_propagates_nonzero_exit(self, fake_runtime, monkeypatch): + monkeypatch.setattr(cli_mod.subprocess, "call", lambda cmd: 7) + monkeypatch.setattr(cli_mod.sys, "argv", ["pilotctl"]) + with pytest.raises(SystemExit) as exc: + cli_mod.run_pilotctl() + assert exc.value.code == 7 + + def test_argv_passthrough_with_flags(self, fake_runtime, fake_call, monkeypatch): + monkeypatch.setattr( + cli_mod.sys, "argv", + ["pilotctl", "send-message", "agent", "--data", "hi", "--wait"], + ) + with pytest.raises(SystemExit): + cli_mod.run_pilotctl() + assert fake_call[0][1:] == [ + "send-message", "agent", "--data", "hi", "--wait", + ] + + +class TestRunDaemon: + def test_invokes_pilot_daemon(self, fake_runtime, fake_call, monkeypatch): + monkeypatch.setattr(cli_mod.sys, "argv", ["pilot-daemon", "--email", "a@b"]) + with pytest.raises(SystemExit): + cli_mod.run_daemon() + assert fake_call[0][0].endswith("pilot-daemon") + assert fake_call[0][1:] == ["--email", "a@b"] + + +class TestRunGateway: + def test_invokes_pilot_gateway(self, fake_runtime, fake_call, monkeypatch): + monkeypatch.setattr(cli_mod.sys, "argv", ["pilot-gateway"]) + with pytest.raises(SystemExit): + cli_mod.run_gateway() + assert fake_call[0][0].endswith("pilot-gateway") + + +class TestRunUpdater: + def test_invokes_pilot_updater(self, fake_runtime, fake_call, monkeypatch): + monkeypatch.setattr(cli_mod.sys, "argv", ["pilot-updater", "--check"]) + with pytest.raises(SystemExit): + cli_mod.run_updater() + assert fake_call[0][0].endswith("pilot-updater") + assert fake_call[0][1:] == ["--check"] + + +# --------------------------------------------------------------------------- +# Imports / module surface +# --------------------------------------------------------------------------- + + +class TestModuleSurface: + def test_all_entry_points_exist_and_are_callable(self): + for name in ("run_pilotctl", "run_daemon", "run_gateway", "run_updater"): + fn = getattr(cli_mod, name) + assert callable(fn), f"{name} must be callable" + + def test_console_scripts_match_pyproject(self): + # Sanity: the wrappers point at the names declared in pyproject.toml. + # If the binary name list drifts, this test fails loudly. + from pilotprotocol._runtime import _BIN_NAMES + + expected = {"pilotctl", "pilot-daemon", "pilot-gateway", "pilot-updater"} + assert set(_BIN_NAMES) == expected diff --git a/tests/test_client_edges.py b/tests/test_client_edges.py new file mode 100644 index 0000000..b081729 --- /dev/null +++ b/tests/test_client_edges.py @@ -0,0 +1,166 @@ +"""Small additional edge-case tests for pilotprotocol.client helpers. + +Targets the leftover branches not covered by test_client.py: +- ``_find_library`` ~/.pilot/bin lookup path +- ``_void_ptr_to_bytes`` null + non-null branches +- ``_free`` null + non-null branches +- ``Conn.read`` size <= 0 and size > 16 MB cap +- ``__init__`` ``__version__`` import-failure fallback +""" + +from __future__ import annotations + +import ctypes +import platform +import types +from pathlib import Path +from unittest import mock + +import pytest + +import pilotprotocol.client as client_mod + + +# --------------------------------------------------------------------------- +# _find_library: ~/.pilot/bin/ branch +# --------------------------------------------------------------------------- + + +class TestFindLibraryPilotBin: + def test_returns_home_pilot_bin_path(self, tmp_path, monkeypatch): + # Build a fake home with ~/.pilot/bin/ + fake_home = tmp_path / "home" + pilot_bin = fake_home / ".pilot" / "bin" + pilot_bin.mkdir(parents=True) + lib_name = client_mod._LIB_NAMES[platform.system()] + lib_file = pilot_bin / lib_name + lib_file.write_bytes(b"\x7fELF\x00\x00\x00\x00") # whatever + monkeypatch.delenv("PILOT_LIB_PATH", raising=False) + monkeypatch.setattr(Path, "home", classmethod(lambda cls: fake_home)) + result = client_mod._find_library() + assert result == str(lib_file) + + +# --------------------------------------------------------------------------- +# _void_ptr_to_bytes +# --------------------------------------------------------------------------- + + +class TestVoidPtrToBytes: + def test_null_returns_none(self): + assert client_mod._void_ptr_to_bytes(None) is None + assert client_mod._void_ptr_to_bytes(0) is None + + def test_nonnull_reads_c_string(self): + buf = ctypes.create_string_buffer(b"hello\x00") + ptr = ctypes.cast(buf, ctypes.c_void_p).value + result = client_mod._void_ptr_to_bytes(ptr) + assert result == b"hello" + + +# --------------------------------------------------------------------------- +# _free +# --------------------------------------------------------------------------- + + +class TestFree: + def test_null_is_noop(self, monkeypatch): + # Should not even call FreeString. + sentinel = {"called": False} + + class FakeLib: + def FreeString(self, ptr): + sentinel["called"] = True + + monkeypatch.setattr(client_mod, "_get_lib", lambda: FakeLib()) + client_mod._free(None) + client_mod._free(0) + assert sentinel["called"] is False + + def test_nonnull_calls_free_string(self, monkeypatch): + sentinel = {"freed": []} + + class FakeLib: + def FreeString(self, ptr): + sentinel["freed"].append(ptr) + + monkeypatch.setattr(client_mod, "_get_lib", lambda: FakeLib()) + client_mod._free(123) + assert sentinel["freed"] == [123] + + +# --------------------------------------------------------------------------- +# Conn.read size bounds +# --------------------------------------------------------------------------- + + +class _FakeReadLib: + """Captures the size argument passed to PilotConnRead.""" + + def __init__(self): + self.last_size = None + + def FreeString(self, ptr): + pass + + def PilotConnRead(self, h, size): + self.last_size = size + return types.SimpleNamespace(n=0, data=None, err=None) + + +class TestConnReadSizeBounds: + def test_zero_or_negative_returns_empty_without_call(self, monkeypatch): + lib = _FakeReadLib() + monkeypatch.setattr(client_mod, "_get_lib", lambda: lib) + conn = client_mod.Conn(handle=10) + assert conn.read(0) == b"" + assert conn.read(-100) == b"" + # Library was never invoked. + assert lib.last_size is None + + def test_size_over_16mb_is_capped(self, monkeypatch): + lib = _FakeReadLib() + monkeypatch.setattr(client_mod, "_get_lib", lambda: lib) + conn = client_mod.Conn(handle=10) + conn.read(64 * 1024 * 1024) + assert lib.last_size == 16 * 1024 * 1024 + + +# --------------------------------------------------------------------------- +# __init__ version fallback +# --------------------------------------------------------------------------- + + +class TestInitVersionFallback: + def test_version_resolves_to_string(self): + # Just verify the module imports and exposes a string __version__. + # The "unknown" fallback path requires importlib.metadata.version to + # raise; we can simulate that by reloading the module with a stub. + import importlib + import importlib.metadata as md + import pilotprotocol as pp + + assert isinstance(pp.__version__, str) + assert pp.__version__ # non-empty + + def test_version_fallback_when_metadata_missing(self, monkeypatch): + # Reimport pilotprotocol with importlib.metadata.version raising. + import importlib + import importlib.metadata as md + import sys + + def boom(name): + raise md.PackageNotFoundError(name) + + monkeypatch.setattr(md, "version", boom) + # Force a reimport of the top-level package. + if "pilotprotocol" in sys.modules: + del sys.modules["pilotprotocol"] + try: + import pilotprotocol as pp_reimport + assert pp_reimport.__version__ == "unknown" + finally: + # Restore the original module so other tests see the real version. + if "pilotprotocol" in sys.modules: + del sys.modules["pilotprotocol"] + import pilotprotocol # noqa: F401 diff --git a/tests/test_runtime_edges.py b/tests/test_runtime_edges.py new file mode 100644 index 0000000..1b587c1 --- /dev/null +++ b/tests/test_runtime_edges.py @@ -0,0 +1,535 @@ +"""Edge-case tests for pilotprotocol._runtime. + +The main test_runtime.py file covers the happy paths and the documented +state machine; this file targets the small branches the main suite skips: + +- ``_bundled_version`` / ``_runtime_version`` failure paths +- ``_daemon_running`` config-load + socket-close errors +- ``_atomic_install`` rename-failure cleanup +- ``_ensure_dir_writable`` permission failure +- ``_ensure_default_config`` race +- ``run_seeder`` ETXTBSY skip + non-busy OSError reraise +- ``runtime_binary`` / ``runtime_library`` fallback to the wheel +- The unsupported-platform branch in ``_platform_lib_name`` +- ``ensure_runtime_seeded`` cached-true return path +""" + +from __future__ import annotations + +import errno +import json +import os +import platform as platform_mod +import socket +from pathlib import Path + +import pytest + +import pilotprotocol._runtime as rt + + +# --------------------------------------------------------------------------- +# Shared isolation +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def _isolate(tmp_path, monkeypatch): + """Each test starts with a clean tmp PILOT_HOME and reset marker.""" + fake_home = tmp_path / "home" / ".pilot" + monkeypatch.setenv("PILOT_HOME", str(fake_home)) + monkeypatch.setattr(rt, "_daemon_running", lambda: False) + rt.reset_seeded_marker() + yield {"home": fake_home, "tmp": tmp_path, "monkeypatch": monkeypatch} + rt.reset_seeded_marker() + + +# --------------------------------------------------------------------------- +# _platform_lib_name +# --------------------------------------------------------------------------- + + +class TestPlatformLibName: + def test_unsupported_platform_raises(self, monkeypatch): + monkeypatch.setattr(platform_mod, "system", lambda: "Plan9") + with pytest.raises(OSError, match="unsupported platform"): + rt._platform_lib_name() + + +# --------------------------------------------------------------------------- +# _pkg_bin_dir — make sure the un-stubbed code path is exercised +# --------------------------------------------------------------------------- + + +class TestPkgBinDir: + def test_returns_real_bin_dir_next_to_module(self): + p = rt._pkg_bin_dir() + assert isinstance(p, Path) + assert p.name == "bin" + # Anchored at the runtime module directory. + assert p.parent == Path(rt.__file__).resolve().parent + + +# --------------------------------------------------------------------------- +# _runtime_root — both branches +# --------------------------------------------------------------------------- + + +class TestRuntimeRoot: + def test_without_pilot_home_uses_home_dot_pilot(self, tmp_path, monkeypatch): + # The autouse fixture sets PILOT_HOME — undo it. + monkeypatch.delenv("PILOT_HOME", raising=False) + monkeypatch.setattr(Path, "home", classmethod(lambda cls: tmp_path)) + result = rt._runtime_root() + assert result == tmp_path / ".pilot" + + def test_with_pilot_home_uses_override(self, tmp_path, monkeypatch): + monkeypatch.setenv("PILOT_HOME", str(tmp_path / "override")) + result = rt._runtime_root() + assert result == tmp_path / "override" + + +# --------------------------------------------------------------------------- +# _bundled_version / _runtime_version +# --------------------------------------------------------------------------- + + +class TestBundledVersion: + def test_read_failure_falls_through_to_metadata(self, tmp_path, monkeypatch): + # Create a marker file then make read_text raise. + pkg = tmp_path / "pkg" + pkg.mkdir() + (pkg / ".pilot-version").write_text("9.9.9\n") + monkeypatch.setattr(rt, "_pkg_bin_dir", lambda: pkg) + + # Override read_text to raise OSError → fall through to importlib.metadata. + orig_read = Path.read_text + + def boom(self, *a, **kw): + if self.name == ".pilot-version": + raise OSError("io") + return orig_read(self, *a, **kw) + + monkeypatch.setattr(Path, "read_text", boom) + v = rt._bundled_version() + # importlib.metadata path returns the installed package version, + # which exists in this venv. Just assert non-empty fallback. + assert isinstance(v, str) + + def test_no_marker_no_metadata_returns_empty(self, tmp_path, monkeypatch): + pkg = tmp_path / "pkg-no-marker" + pkg.mkdir() + monkeypatch.setattr(rt, "_pkg_bin_dir", lambda: pkg) + + # Make importlib.metadata.version raise. + import importlib.metadata as md + + def boom(_name): + raise md.PackageNotFoundError("missing") + + monkeypatch.setattr(md, "version", boom) + assert rt._bundled_version() == "" + + def test_runtime_version_read_failure(self, tmp_path, monkeypatch): + rtdir = tmp_path / "rt" + rtdir.mkdir() + (rtdir / ".pilot-version").write_text("1.0.0\n") + + orig = Path.read_text + + def boom(self, *a, **kw): + if self.name == ".pilot-version": + raise OSError("denied") + return orig(self, *a, **kw) + + monkeypatch.setattr(Path, "read_text", boom) + assert rt._runtime_version(rtdir) == "" + + +# --------------------------------------------------------------------------- +# _semver_tuple edge cases (already mostly covered, but a few values left) +# --------------------------------------------------------------------------- + + +class TestSemverTuple: + def test_empty_returns_empty_tuple(self): + assert rt._semver_tuple("") == () + assert rt._semver_tuple(None) == () + + def test_nonnumeric_returns_empty_tuple(self): + assert rt._semver_tuple("not-a-version") == () + + def test_strips_leading_v_and_suffixes(self): + assert rt._semver_tuple("v1.2.3-beta+build") == (1, 2, 3) + + +# --------------------------------------------------------------------------- +# _daemon_running edge paths +# --------------------------------------------------------------------------- + + +class TestDaemonRunning: + def test_no_config_no_socket(self, tmp_path, monkeypatch): + # The autouse fixture stubs _daemon_running. Reach the real function. + monkeypatch.undo() + monkeypatch.setenv("PILOT_HOME", str(tmp_path / ".pilot")) + assert rt._daemon_running() is False + + def test_config_unreadable_uses_default_socket(self, tmp_path, monkeypatch): + monkeypatch.undo() + home = tmp_path / ".pilot" + home.mkdir() + # Bad JSON triggers ValueError branch. + (home / "config.json").write_text("not json {{{") + monkeypatch.setenv("PILOT_HOME", str(home)) + # No socket exists → returns False without error. + assert rt._daemon_running() is False + + def test_socket_present_but_connect_fails(self, tmp_path, monkeypatch): + monkeypatch.undo() + home = tmp_path / ".pilot" + home.mkdir() + sock_path = tmp_path / "pilot.sock" + # Make a regular file that exists at sock_path so Path.exists() is True + # but connect() fails. + sock_path.write_text("not a socket") + (home / "config.json").write_text(json.dumps({"socket": str(sock_path)})) + monkeypatch.setenv("PILOT_HOME", str(home)) + assert rt._daemon_running() is False + + def test_socket_close_exception_swallowed(self, tmp_path, monkeypatch): + """Lines 143-144: if s.close() raises in the finally, swallow it.""" + monkeypatch.undo() + home = tmp_path / ".pilot" + home.mkdir() + monkeypatch.setenv("PILOT_HOME", str(home)) + + # Stub socket.socket so close() raises but exists() / connect() can + # be controlled. The path also needs to not exist so we return early + # — actually we need to reach the finally branch, which requires + # the socket path to exist. Use a real file at sock_path and force + # close() to raise. + sock_path = tmp_path / "fake.sock" + sock_path.write_text("x") # exists() True + (home / "config.json").write_text( + json.dumps({"socket": str(sock_path)}) + ) + + class BadCloseSocket: + def __init__(self, *a, **kw): + pass + def settimeout(self, t): + pass + def connect(self, path): + raise OSError("can't connect to real file") + def close(self): + raise OSError("close blew up") + + monkeypatch.setattr(rt.socket, "socket", BadCloseSocket) + # Should not raise — the finally swallows the close error. + result = rt._daemon_running() + assert result is False + + def test_socket_connect_succeeds(self, tmp_path, monkeypatch): + # macOS AF_UNIX limit is ~104 bytes. Use a short /tmp path instead + # of tmp_path which can exceed it under pytest. + import tempfile + monkeypatch.undo() + home = tmp_path / ".pilot" + home.mkdir() + with tempfile.TemporaryDirectory(prefix="psk-") as short: + sock_path = Path(short) / "p.sock" + srv = socket.socket(socket.AF_UNIX) + try: + srv.bind(str(sock_path)) + srv.listen(1) + (home / "config.json").write_text( + json.dumps({"socket": str(sock_path)}) + ) + monkeypatch.setenv("PILOT_HOME", str(home)) + assert rt._daemon_running() is True + finally: + srv.close() + if sock_path.exists(): + sock_path.unlink() + + +# --------------------------------------------------------------------------- +# _atomic_install +# --------------------------------------------------------------------------- + + +class TestAtomicInstall: + def test_happy_path(self, tmp_path): + src = tmp_path / "src.bin" + src.write_bytes(b"payload") + dst = tmp_path / "dst.bin" + rt._atomic_install(src, dst) + assert dst.read_bytes() == b"payload" + assert dst.stat().st_mode & 0o777 == 0o755 + + def test_preexisting_tmp_is_cleared(self, tmp_path, monkeypatch): + src = tmp_path / "src.bin" + src.write_bytes(b"v1") + dst = tmp_path / "dst.bin" + # Predict the tmp name our function will pick. + import threading + tmp_name = f"dst.bin.tmp.{os.getpid()}.{threading.get_ident()}" + (tmp_path / tmp_name).write_text("stale") + rt._atomic_install(src, dst) + assert dst.read_bytes() == b"v1" + + def test_replace_failure_cleans_tmp(self, tmp_path, monkeypatch): + src = tmp_path / "src.bin" + src.write_bytes(b"x") + dst = tmp_path / "dst.bin" + + # Force os.replace to raise once. + original_replace = os.replace + calls = {"n": 0} + + def boom(a, b): + calls["n"] += 1 + raise OSError(errno.EACCES, "denied") + + monkeypatch.setattr(os, "replace", boom) + with pytest.raises(OSError): + rt._atomic_install(src, dst) + # tmp should be gone (cleanup ran) + import threading + tmp = tmp_path / f"dst.bin.tmp.{os.getpid()}.{threading.get_ident()}" + assert not tmp.exists() + + def test_replace_failure_tmp_already_gone(self, tmp_path, monkeypatch): + """If the cleanup unlink also fails, the original OSError still bubbles.""" + src = tmp_path / "src.bin" + src.write_bytes(b"x") + dst = tmp_path / "dst.bin" + + def boom_replace(a, b): + raise OSError(errno.EACCES, "denied") + + original_unlink = Path.unlink + + def boom_unlink(self, *a, **kw): + raise OSError(errno.ENOENT, "gone") + + monkeypatch.setattr(os, "replace", boom_replace) + monkeypatch.setattr(Path, "unlink", boom_unlink) + with pytest.raises(OSError): + rt._atomic_install(src, dst) + + +# --------------------------------------------------------------------------- +# _ensure_dir_writable +# --------------------------------------------------------------------------- + + +class TestEnsureDirWritable: + def test_creates_missing_dir(self, tmp_path): + p = tmp_path / "deep" / "nested" / "dir" + rt._ensure_dir_writable(p) + assert p.is_dir() + + def test_unwritable_raises(self, tmp_path, monkeypatch): + p = tmp_path / "ro" + p.mkdir() + # Pretend it's not writable. + monkeypatch.setattr(os, "access", lambda path, mode: False) + with pytest.raises(PermissionError, match="not writable"): + rt._ensure_dir_writable(p) + + +# --------------------------------------------------------------------------- +# _ensure_default_config race +# --------------------------------------------------------------------------- + + +class TestEnsureDefaultConfigRace: + def test_handles_race_with_other_writer(self, tmp_path, monkeypatch): + # Simulate another writer winning the os.replace race + # (FileNotFoundError branch). + monkeypatch.setenv("PILOT_HOME", str(tmp_path / ".pilot")) + + original_replace = os.replace + calls = {"n": 0} + + def racing_replace(a, b): + calls["n"] += 1 + raise FileNotFoundError(2, "no such file") + + monkeypatch.setattr(os, "replace", racing_replace) + # Should not raise even though replace failed. + rt._ensure_default_config() + + def test_handles_race_when_tmp_already_unlinked(self, tmp_path, monkeypatch): + """tmp.unlink raises inside cleanup — must not propagate.""" + monkeypatch.setenv("PILOT_HOME", str(tmp_path / ".pilot")) + + def racing_replace(a, b): + raise FileNotFoundError(2, "no such file") + + def cleanup_fails(self): + raise OSError(errno.ENOENT, "double race") + + monkeypatch.setattr(os, "replace", racing_replace) + monkeypatch.setattr(Path, "unlink", cleanup_fails) + # Should still return without raising. + rt._ensure_default_config() + + +# --------------------------------------------------------------------------- +# run_seeder OSError paths +# --------------------------------------------------------------------------- + + +class TestSeederOSErrorPaths: + def test_etxtbsy_during_copy_skips(self, tmp_path, monkeypatch): + # Build a fake pkg/bin + pkg = tmp_path / "pkg" + pkg.mkdir() + names = list(rt._BIN_NAMES) + [rt._platform_lib_name()] + for n in names: + (pkg / n).write_text("stub") + (pkg / ".pilot-version").write_text("1.0.0\n") + monkeypatch.setattr(rt, "_pkg_bin_dir", lambda: pkg) + + # Make _atomic_install raise ETXTBSY for the first name copied. + seen = {"hit": False} + original = rt._atomic_install + + def flaky(src, dst): + if not seen["hit"]: + seen["hit"] = True + raise OSError(errno.ETXTBSY, "busy") + original(src, dst) + + monkeypatch.setattr(rt, "_atomic_install", flaky) + report = rt.run_seeder() + assert seen["hit"] is True + assert len(report.skipped) >= 1 # at least the busy one was skipped + + def test_other_oserror_propagates(self, tmp_path, monkeypatch): + pkg = tmp_path / "pkg" + pkg.mkdir() + for n in list(rt._BIN_NAMES) + [rt._platform_lib_name()]: + (pkg / n).write_text("stub") + (pkg / ".pilot-version").write_text("1.0.0\n") + monkeypatch.setattr(rt, "_pkg_bin_dir", lambda: pkg) + + def boom(src, dst): + raise OSError(errno.EACCES, "permission denied") + + monkeypatch.setattr(rt, "_atomic_install", boom) + with pytest.raises(OSError): + rt.run_seeder() + + +# --------------------------------------------------------------------------- +# runtime_binary / runtime_library fallback paths +# --------------------------------------------------------------------------- + + +class TestRuntimeBinaryFallback: + def test_falls_back_to_wheel_when_rt_missing(self, tmp_path, monkeypatch): + # Build a fake pkg with only one binary present. + pkg = tmp_path / "pkg" + pkg.mkdir() + (pkg / "pilotctl").write_text("from-wheel") + (pkg / "pilotctl").chmod(0o755) + (pkg / ".pilot-version").write_text("1.0.0\n") + # libpilot is required by the platform_lib check inside run_seeder. + # Without it the version-marker comparison can decide to skip copy, + # but we want the SEED action to do nothing for our target name — + # simplest way is to bypass the seeder entirely. + monkeypatch.setattr(rt, "_pkg_bin_dir", lambda: pkg) + monkeypatch.setattr(rt, "ensure_runtime_seeded", lambda: tmp_path / "empty-rt") + (tmp_path / "empty-rt").mkdir() + + p = rt.runtime_binary("pilotctl") + assert p == pkg / "pilotctl" + + def test_missing_in_both_raises(self, tmp_path, monkeypatch): + pkg = tmp_path / "pkg" + pkg.mkdir() + monkeypatch.setattr(rt, "_pkg_bin_dir", lambda: pkg) + monkeypatch.setattr(rt, "ensure_runtime_seeded", lambda: tmp_path / "empty-rt") + (tmp_path / "empty-rt").mkdir() + + with pytest.raises(FileNotFoundError, match="not found"): + rt.runtime_binary("does-not-exist") + + +class TestRuntimeLibraryFallback: + def test_present_in_rt_returns_rt(self, tmp_path, monkeypatch): + rtdir = tmp_path / "rt" + rtdir.mkdir() + libname = rt._platform_lib_name() + (rtdir / libname).write_text("lib") + monkeypatch.setattr(rt, "ensure_runtime_seeded", lambda: rtdir) + # _pkg_bin_dir doesn't matter for this branch. + assert rt.runtime_library() == rtdir / libname + + def test_falls_back_to_wheel(self, tmp_path, monkeypatch): + # RT empty, wheel has it. + rtdir = tmp_path / "rt" + rtdir.mkdir() + pkg = tmp_path / "pkg" + pkg.mkdir() + libname = rt._platform_lib_name() + (pkg / libname).write_text("lib") + monkeypatch.setattr(rt, "ensure_runtime_seeded", lambda: rtdir) + monkeypatch.setattr(rt, "_pkg_bin_dir", lambda: pkg) + assert rt.runtime_library() == pkg / libname + + def test_missing_everywhere_raises(self, tmp_path, monkeypatch): + rtdir = tmp_path / "rt" + rtdir.mkdir() + pkg = tmp_path / "pkg" + pkg.mkdir() + monkeypatch.setattr(rt, "ensure_runtime_seeded", lambda: rtdir) + monkeypatch.setattr(rt, "_pkg_bin_dir", lambda: pkg) + with pytest.raises(FileNotFoundError, match="libpilot"): + rt.runtime_library() + + +# --------------------------------------------------------------------------- +# ensure_runtime_seeded cached path +# --------------------------------------------------------------------------- + + +class TestEnsureRuntimeSeededCache: + def test_second_call_returns_without_re_seeding(self, tmp_path, monkeypatch): + pkg = tmp_path / "pkg" + pkg.mkdir() + for n in list(rt._BIN_NAMES) + [rt._platform_lib_name()]: + (pkg / n).write_text("stub") + (pkg / ".pilot-version").write_text("1.0.0\n") + monkeypatch.setattr(rt, "_pkg_bin_dir", lambda: pkg) + + first = rt.ensure_runtime_seeded() + # Now make run_seeder raise — if cache works, it never runs. + monkeypatch.setattr(rt, "run_seeder", lambda: (_ for _ in ()).throw( + RuntimeError("should not be called"))) + second = rt.ensure_runtime_seeded() + assert second == first + + def test_force_bypasses_cache(self, tmp_path, monkeypatch): + pkg = tmp_path / "pkg" + pkg.mkdir() + for n in list(rt._BIN_NAMES) + [rt._platform_lib_name()]: + (pkg / n).write_text("stub") + (pkg / ".pilot-version").write_text("1.0.0\n") + monkeypatch.setattr(rt, "_pkg_bin_dir", lambda: pkg) + + rt.ensure_runtime_seeded() + calls = {"n": 0} + original = rt.run_seeder + + def counting(): + calls["n"] += 1 + return original() + + monkeypatch.setattr(rt, "run_seeder", counting) + rt.ensure_runtime_seeded(force=True) + assert calls["n"] == 1 diff --git a/tests/test_services.py b/tests/test_services.py new file mode 100644 index 0000000..02457d7 --- /dev/null +++ b/tests/test_services.py @@ -0,0 +1,453 @@ +"""Tests for high-level service helpers on Driver. + +Covers ``send_message`` (data-exchange port 1001), ``send_file`` (TypeFile +framing on port 1001), and ``publish_event`` / ``subscribe_event`` (event +stream port 1002). + +The wire formats are documented in the docstrings on Driver: +- data-exchange frame: ``[4-byte type][4-byte length][payload]`` +- file payload: ``[2-byte name len][name][file data]`` +- event frame: ``[2-byte topic len][topic][4-byte payload len][payload]`` + +We mock ``Driver.dial`` to return a fake Conn that records writes and +serves canned reads, so no libpilot or daemon is required. +""" + +from __future__ import annotations + +import json +import struct +import types +from collections import deque +from unittest import mock + +import pytest + +import pilotprotocol.client as client_mod +from pilotprotocol.client import PilotError + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +class FakeConn: + """Minimal Conn replacement: capture writes, serve canned reads.""" + + def __init__(self, reads: list[bytes] | None = None) -> None: + self.writes: list[bytes] = [] + self._reads: deque[bytes] = deque(reads or []) + self.closed = False + + def write(self, data: bytes) -> int: + self.writes.append(data) + return len(data) + + def read(self, size: int = 4096) -> bytes: + if not self._reads: + return b"" + chunk = self._reads.popleft() + # If the caller wants a smaller chunk, slice + push back the remainder. + if len(chunk) > size: + self._reads.appendleft(chunk[size:]) + return chunk[:size] + return chunk + + def close(self) -> None: + self.closed = True + + def __enter__(self): + return self + + def __exit__(self, *exc): + self.close() + + +def _ack_frame(payload: str = "ACK TEXT 5 bytes") -> list[bytes]: + """Build the two reads send_message expects: header(8) then payload.""" + body = payload.encode() + header = struct.pack(">II", 1, len(body)) + return [header, body] + + +def _make_driver_with_dial(monkeypatch, conn: FakeConn) -> client_mod.Driver: + """Construct a Driver whose .dial returns the provided FakeConn. + + We don't go through ``__init__`` (it calls into libpilot) — we build the + instance directly. The high-level methods never read ``self._h`` directly, + they go through ``self.dial`` which we replace. + """ + d = object.__new__(client_mod.Driver) + d._h = 1 + d._closed = False + monkeypatch.setattr(d, "dial", lambda addr, **kw: conn) + return d + + +# --------------------------------------------------------------------------- +# send_message +# --------------------------------------------------------------------------- + + +class TestSendMessage: + def test_protocol_address_skips_resolve(self, monkeypatch): + conn = FakeConn(reads=_ack_frame("ACK TEXT 5 bytes")) + d = _make_driver_with_dial(monkeypatch, conn) + # If resolve_hostname were called this would explode (no real lib). + d.resolve_hostname = mock.Mock(side_effect=AssertionError("should skip")) + + result = d.send_message("0:0001.0000.0002", b"hello") + + assert result["sent"] == 5 + assert result["type"] == "text" + assert result["target"] == "0:0001.0000.0002" + assert result["ack"] == "ACK TEXT 5 bytes" + # Frame on the wire: [type=1][length=5]hello + assert conn.writes[0] == struct.pack(">II", 1, 5) + b"hello" + assert conn.closed is True + + def test_hostname_path_calls_resolve(self, monkeypatch): + conn = FakeConn(reads=_ack_frame()) + d = _make_driver_with_dial(monkeypatch, conn) + d.resolve_hostname = mock.Mock(return_value={"address": "0:0001.0000.0042"}) + + result = d.send_message("agent-hostname", b"hi") + + d.resolve_hostname.assert_called_once_with("agent-hostname") + assert result["target"] == "0:0001.0000.0042" + + def test_resolve_returns_empty_address_raises(self, monkeypatch): + conn = FakeConn() + d = _make_driver_with_dial(monkeypatch, conn) + d.resolve_hostname = mock.Mock(return_value={"address": ""}) + + with pytest.raises(PilotError, match="Could not resolve hostname"): + d.send_message("nonexistent", b"x") + + def test_message_type_maps_correctly(self, monkeypatch): + for label, code in (("text", 1), ("binary", 2), ("json", 3), ("file", 4)): + conn = FakeConn(reads=_ack_frame()) + d = _make_driver_with_dial(monkeypatch, conn) + d.send_message("0:0001.0000.0002", b"abc", msg_type=label) + ftype, flen = struct.unpack(">II", conn.writes[0][:8]) + assert ftype == code + assert flen == 3 + + def test_unknown_msg_type_defaults_to_text(self, monkeypatch): + conn = FakeConn(reads=_ack_frame()) + d = _make_driver_with_dial(monkeypatch, conn) + d.send_message("0:0001.0000.0002", b"x", msg_type="weird-type") + ftype, _ = struct.unpack(">II", conn.writes[0][:8]) + assert ftype == 1 # text fallback + + def test_ack_read_failure_still_returns_sent(self, monkeypatch): + # No ACK frame readable → result lacks 'ack' but call succeeds + conn = FakeConn(reads=[]) + d = _make_driver_with_dial(monkeypatch, conn) + result = d.send_message("0:0001.0000.0002", b"hello") + assert result == {"sent": 5, "type": "text", "target": "0:0001.0000.0002"} + assert "ack" not in result + + def test_short_ack_header_falls_through(self, monkeypatch): + # Header is < 8 bytes → ACK branch skipped + conn = FakeConn(reads=[b"\x00\x00\x00"]) # 3 bytes, not 8 + d = _make_driver_with_dial(monkeypatch, conn) + result = d.send_message("0:0001.0000.0002", b"x") + assert "ack" not in result + + def test_ack_read_raises_caught(self, monkeypatch): + # Conn.read raises after the write — caught silently + class BoomConn(FakeConn): + def read(self, size: int = 4096) -> bytes: + raise PilotError("read broke") + + conn = BoomConn() + d = _make_driver_with_dial(monkeypatch, conn) + result = d.send_message("0:0001.0000.0002", b"x") + assert result["sent"] == 1 + assert "ack" not in result + + def test_dial_target_uses_port_1001(self, monkeypatch): + conn = FakeConn(reads=_ack_frame()) + d = object.__new__(client_mod.Driver) + d._h = 1 + d._closed = False + captured = {} + + def fake_dial(addr, **kw): + captured["addr"] = addr + return conn + + d.dial = fake_dial + d.send_message("0:0001.0000.0002", b"x") + assert captured["addr"] == "0:0001.0000.0002:1001" + + +# --------------------------------------------------------------------------- +# send_file +# --------------------------------------------------------------------------- + + +class TestSendFile: + def test_missing_file_raises(self, monkeypatch, tmp_path): + d = _make_driver_with_dial(monkeypatch, FakeConn()) + with pytest.raises(PilotError, match="File not found"): + d.send_file("0:0001.0000.0002", str(tmp_path / "nope.bin")) + + def test_file_frame_layout(self, monkeypatch, tmp_path): + f = tmp_path / "hello.txt" + f.write_bytes(b"contents") + + conn = FakeConn(reads=_ack_frame("ACK FILE 8 bytes")) + d = _make_driver_with_dial(monkeypatch, conn) + result = d.send_file("0:0001.0000.0002", str(f)) + + assert result["sent"] == 8 + assert result["filename"] == "hello.txt" + assert result["ack"] == "ACK FILE 8 bytes" + + # Frame: [type=4][total_len][2-byte name len][name][file data] + ftype, total_len = struct.unpack(">II", conn.writes[0][:8]) + assert ftype == 4 + payload = conn.writes[0][8:] + assert len(payload) == total_len + name_len = struct.unpack(">H", payload[:2])[0] + assert payload[2 : 2 + name_len] == b"hello.txt" + assert payload[2 + name_len :] == b"contents" + + def test_hostname_resolution(self, monkeypatch, tmp_path): + f = tmp_path / "x.bin" + f.write_bytes(b"\x00") + conn = FakeConn(reads=_ack_frame()) + d = _make_driver_with_dial(monkeypatch, conn) + d.resolve_hostname = mock.Mock(return_value={"address": "0:0001.0000.0009"}) + + result = d.send_file("hostname", str(f)) + d.resolve_hostname.assert_called_once_with("hostname") + assert result["target"] == "0:0001.0000.0009" + + def test_resolve_empty_address_raises(self, monkeypatch, tmp_path): + f = tmp_path / "x.bin" + f.write_bytes(b"\x00") + d = _make_driver_with_dial(monkeypatch, FakeConn()) + d.resolve_hostname = mock.Mock(return_value={"address": ""}) + with pytest.raises(PilotError, match="Could not resolve hostname"): + d.send_file("nope", str(f)) + + def test_ack_failure_does_not_raise(self, monkeypatch, tmp_path): + f = tmp_path / "x.bin" + f.write_bytes(b"data") + + class BoomConn(FakeConn): + def read(self, size: int = 4096) -> bytes: + raise PilotError("network died after write") + + conn = BoomConn() + d = _make_driver_with_dial(monkeypatch, conn) + result = d.send_file("0:0001.0000.0002", str(f)) + assert result["sent"] == 4 + assert "ack" not in result + + def test_short_ack_header_falls_through(self, monkeypatch, tmp_path): + f = tmp_path / "x.bin" + f.write_bytes(b"data") + conn = FakeConn(reads=[b""]) # empty read = falsy → branch skipped + d = _make_driver_with_dial(monkeypatch, conn) + result = d.send_file("0:0001.0000.0002", str(f)) + assert result["sent"] == 4 + assert "ack" not in result + + +# --------------------------------------------------------------------------- +# publish_event +# --------------------------------------------------------------------------- + + +class TestPublishEvent: + def test_subscribe_then_publish_frames(self, monkeypatch): + conn = FakeConn() + d = _make_driver_with_dial(monkeypatch, conn) + r = d.publish_event("0:0001.0000.0002", "temp", b"42C") + + assert r == {"status": "published", "topic": "temp", "bytes": 3} + # First write: subscribe (empty payload) + # Wire: [2-byte topic len][topic][4-byte payload len][payload] + topic_len = struct.unpack(">H", conn.writes[0][:2])[0] + assert conn.writes[0][2 : 2 + topic_len] == b"temp" + payload_len = struct.unpack(">I", conn.writes[0][2 + topic_len : 6 + topic_len])[0] + assert payload_len == 0 + # Second write: actual publish + topic_len2 = struct.unpack(">H", conn.writes[1][:2])[0] + assert conn.writes[1][2 : 2 + topic_len2] == b"temp" + plen2 = struct.unpack(">I", conn.writes[1][2 + topic_len2 : 6 + topic_len2])[0] + assert plen2 == 3 + assert conn.writes[1][6 + topic_len2 :] == b"42C" + assert conn.closed is True + + def test_resolves_hostname(self, monkeypatch): + conn = FakeConn() + d = _make_driver_with_dial(monkeypatch, conn) + d.resolve_hostname = mock.Mock(return_value={"address": "0:0001.0000.0042"}) + d.publish_event("agent-host", "topic-A", b"x") + d.resolve_hostname.assert_called_once_with("agent-host") + + def test_resolve_empty_raises(self, monkeypatch): + d = _make_driver_with_dial(monkeypatch, FakeConn()) + d.resolve_hostname = mock.Mock(return_value={"address": ""}) + with pytest.raises(PilotError, match="Could not resolve hostname"): + d.publish_event("nope", "t", b"x") + + def test_dial_uses_port_1002(self, monkeypatch): + captured = {} + + def fake_dial(addr, **kw): + captured["addr"] = addr + return FakeConn() + + d = object.__new__(client_mod.Driver) + d._h = 1 + d._closed = False + d.dial = fake_dial + d.publish_event("0:0001.0000.0002", "t", b"x") + assert captured["addr"] == "0:0001.0000.0002:1002" + + +# --------------------------------------------------------------------------- +# subscribe_event +# --------------------------------------------------------------------------- + + +def _event_bytes(topic: str, payload: bytes) -> list[bytes]: + """Encode an event frame as the four reads subscribe_event performs. + + Reads (in order): 2-byte topic len, topic, 4-byte payload len, payload. + """ + tb = topic.encode() + return [ + struct.pack(">H", len(tb)), + tb, + struct.pack(">I", len(payload)), + payload, + ] + + +class TestSubscribeEvent: + def test_yields_events(self, monkeypatch): + conn = FakeConn( + reads=_event_bytes("foo", b"hello") + _event_bytes("bar", b"world") + ) + d = _make_driver_with_dial(monkeypatch, conn) + gen = d.subscribe_event("0:0001.0000.0002", "foo", timeout=5) + events = list(gen) + assert events == [("foo", b"hello"), ("bar", b"world")] + # First write is the subscription frame with empty payload + topic_len = struct.unpack(">H", conn.writes[0][:2])[0] + assert conn.writes[0][2 : 2 + topic_len] == b"foo" + + def test_callback_invoked_instead_of_yield(self, monkeypatch): + conn = FakeConn(reads=_event_bytes("t", b"p")) + d = _make_driver_with_dial(monkeypatch, conn) + received = [] + gen = d.subscribe_event( + "0:0001.0000.0002", "t", callback=lambda topic, data: received.append((topic, data)), timeout=2, + ) + # When callback is supplied, the generator yields nothing. + assert list(gen) == [] + assert received == [("t", b"p")] + + def test_short_topic_len_returns_none(self, monkeypatch): + # 1 byte instead of 2 — read_event returns None → loop breaks + conn = FakeConn(reads=[b"\x00"]) + d = _make_driver_with_dial(monkeypatch, conn) + events = list(d.subscribe_event("0:0001.0000.0002", "t", timeout=2)) + assert events == [] + assert conn.closed is True + + def test_short_topic_body_returns_none(self, monkeypatch): + # Topic len says 5 but only 2 bytes follow + reads = [struct.pack(">H", 5), b"ab"] + conn = FakeConn(reads=reads) + d = _make_driver_with_dial(monkeypatch, conn) + events = list(d.subscribe_event("0:0001.0000.0002", "t", timeout=2)) + assert events == [] + + def test_short_payload_len_returns_none(self, monkeypatch): + reads = [struct.pack(">H", 3), b"foo", b"\x00\x01"] # 2 bytes, need 4 + conn = FakeConn(reads=reads) + d = _make_driver_with_dial(monkeypatch, conn) + events = list(d.subscribe_event("0:0001.0000.0002", "t", timeout=2)) + assert events == [] + + def test_short_payload_body_returns_none(self, monkeypatch): + reads = [ + struct.pack(">H", 3), + b"foo", + struct.pack(">I", 10), + b"abc", # only 3 bytes, need 10 + ] + conn = FakeConn(reads=reads) + d = _make_driver_with_dial(monkeypatch, conn) + events = list(d.subscribe_event("0:0001.0000.0002", "t", timeout=2)) + assert events == [] + + def test_connection_closed_error_breaks_cleanly(self, monkeypatch): + class BoomConn(FakeConn): + def __init__(self): + super().__init__() + self._sent_subscribe = False + + def read(self, size: int = 4096) -> bytes: + raise PilotError("connection closed by peer") + + conn = BoomConn() + d = _make_driver_with_dial(monkeypatch, conn) + events = list(d.subscribe_event("0:0001.0000.0002", "t", timeout=2)) + assert events == [] + assert conn.closed is True + + def test_eof_error_breaks_cleanly(self, monkeypatch): + class EofConn(FakeConn): + def read(self, size: int = 4096) -> bytes: + raise RuntimeError("unexpected EOF on stream") + + conn = EofConn() + d = _make_driver_with_dial(monkeypatch, conn) + events = list(d.subscribe_event("0:0001.0000.0002", "t", timeout=2)) + assert events == [] + + def test_other_exception_propagates(self, monkeypatch): + # Error string contains neither "connection closed" nor "EOF" — + # should propagate. + class BadConn(FakeConn): + def read(self, size: int = 4096) -> bytes: + raise PilotError("permission denied") + + conn = BadConn() + d = _make_driver_with_dial(monkeypatch, conn) + with pytest.raises(PilotError, match="permission denied"): + list(d.subscribe_event("0:0001.0000.0002", "t", timeout=2)) + # Connection should still be closed by the finally block + assert conn.closed is True + + def test_timeout_terminates_loop(self, monkeypatch): + # With timeout=0 the while loop never enters → no reads, immediate end + conn = FakeConn(reads=_event_bytes("x", b"y")) + d = _make_driver_with_dial(monkeypatch, conn) + events = list(d.subscribe_event("0:0001.0000.0002", "x", timeout=0)) + assert events == [] + # But we should have written the subscription frame before entering the loop + assert len(conn.writes) == 1 + + def test_resolves_hostname(self, monkeypatch): + conn = FakeConn(reads=[]) + d = _make_driver_with_dial(monkeypatch, conn) + d.resolve_hostname = mock.Mock(return_value={"address": "0:0001.0000.0042"}) + list(d.subscribe_event("hostname", "t", timeout=1)) + d.resolve_hostname.assert_called_once_with("hostname") + + def test_resolve_empty_raises(self, monkeypatch): + d = _make_driver_with_dial(monkeypatch, FakeConn()) + d.resolve_hostname = mock.Mock(return_value={"address": ""}) + with pytest.raises(PilotError, match="Could not resolve hostname"): + list(d.subscribe_event("nope", "t", timeout=1))