diff --git a/docs/transports.md b/docs/transports.md index 12225b5..73a2864 100644 --- a/docs/transports.md +++ b/docs/transports.md @@ -48,6 +48,51 @@ actions = client.predict_action(obs) robot.execute(actions[0]) # first action in the chunk ``` +### Secure ZMQ + +Production ZMQ deployments should use CURVE authentication/encryption and a +control token for operational endpoints such as `ping` and `kill`. + +Generate one server keypair and one client keypair: + +```bash +python - <<'PY' +from pathlib import Path +import zmq.auth + +out = Path("zmq-certs") +(out / "clients").mkdir(parents=True, exist_ok=True) +zmq.auth.create_certificates(out, "server") +zmq.auth.create_certificates(out / "clients", "robot-1") +PY +``` + +Start the server with the server secret certificate and the directory of +allowed client public certificates: + +```bash +tether serve ./my_export/ \ + --transport zmq \ + --port 5555 \ + --zmq-server-cert ./zmq-certs/server.key_secret \ + --zmq-client-cert-dir ./zmq-certs/clients \ + --zmq-control-token "$TETHER_ZMQ_CONTROL_TOKEN" +``` + +Configure the robot-side client with its secret certificate, the server public +certificate, and the same control token: + +```python +import os + +client = ZmqRuntimeClient( + "tcp://gpu-server:5555", + curve_client_cert="./zmq-certs/clients/robot-1.key_secret", + curve_server_public_key="./zmq-certs/server.key", + auth_token=os.environ["TETHER_ZMQ_CONTROL_TOKEN"], +) +``` + ## ZMQ Performance | Metric | HTTP | ZMQ | ZMQ + JPEG | diff --git a/src/tether/cli.py b/src/tether/cli.py index dca60bd..fed30f8 100644 --- a/src/tether/cli.py +++ b/src/tether/cli.py @@ -1730,6 +1730,24 @@ def serve( "3-camera setups via JPEG-on-wire and 40%+ lower tail jitter. " "ROS2 reserved for v1.0.", ), + zmq_server_cert: str = typer.Option( + "", + "--zmq-server-cert", + help="Path to a pyzmq CURVE server secret certificate (.key_secret). " + "Requires --transport zmq and --zmq-client-cert-dir.", + ), + zmq_client_cert_dir: str = typer.Option( + "", + "--zmq-client-cert-dir", + help="Directory of allowed pyzmq CURVE client public certificates. " + "Requires --transport zmq and --zmq-server-cert.", + ), + zmq_control_token: str = typer.Option( + "", + "--zmq-control-token", + help="Token required for ZMQ control endpoints such as ping and kill. " + "Pass the same value to ZmqRuntimeClient(auth_token=...).", + ), device: str = typer.Option("cuda", help="Device: cuda or cpu"), providers: str = typer.Option( "", @@ -2653,8 +2671,19 @@ def _run_mcp_http(): if transport == "zmq": console.print("[bold green]Starting ZMQ server...[/bold green]") from tether.runtime.transports.zmq.factory import create_zmq_server - zmq_server = create_zmq_server(app_instance, host=host, port=port) + zmq_server = create_zmq_server( + app_instance, + host=host, + port=port, + curve_server_cert=zmq_server_cert or None, + curve_client_cert_dir=zmq_client_cert_dir or None, + control_token=zmq_control_token or None, + ) composed.append("[cyan]transport=zmq[/cyan]") + if zmq_server_cert: + composed.append("[cyan]curve=on[/cyan]") + if zmq_control_token: + composed.append("[cyan]control-auth=on[/cyan]") console.print(f"[dim]Features: {' + '.join(composed)}[/dim]") zmq_server.run() elif transport == "http": diff --git a/src/tether/runtime/transports/zmq/client.py b/src/tether/runtime/transports/zmq/client.py index a6812d4..eb7c6a8 100644 --- a/src/tether/runtime/transports/zmq/client.py +++ b/src/tether/runtime/transports/zmq/client.py @@ -24,6 +24,7 @@ import io import time from dataclasses import dataclass +from pathlib import Path from typing import Any import msgpack @@ -32,9 +33,9 @@ from tether.runtime.transports.zmq.serializers import ( SCHEMA_VERSION, - decode_observation, encode_observation, ) +from tether.runtime.transports.zmq.security import load_curve_key @dataclass @@ -63,10 +64,21 @@ def __init__( server_url: str = "tcp://localhost:5555", timeout_ms: int = 5000, jpeg_quality: int = 85, + *, + curve_client_cert: str | Path | None = None, + curve_client_public_key: str | bytes | None = None, + curve_client_secret_key: str | bytes | None = None, + curve_server_public_key: str | bytes | Path | None = None, + auth_token: str | None = None, ) -> None: self._server_url = server_url self._timeout_ms = timeout_ms self._jpeg_quality = jpeg_quality + self._curve_client_cert = curve_client_cert + self._curve_client_public_key = curve_client_public_key + self._curve_client_secret_key = curve_client_secret_key + self._curve_server_public_key = curve_server_public_key + self._auth_token = auth_token self._context = zmq.Context() self._socket: zmq.Socket | None = None self._connect() @@ -78,8 +90,53 @@ def _connect(self) -> None: if self._timeout_ms > 0: self._socket.setsockopt(zmq.RCVTIMEO, self._timeout_ms) self._socket.setsockopt(zmq.SNDTIMEO, self._timeout_ms) + self._configure_curve() self._socket.connect(self._server_url) + def _configure_curve(self) -> None: + if self._curve_server_public_key is None: + if ( + self._curve_client_cert is not None + or self._curve_client_public_key is not None + or self._curve_client_secret_key is not None + ): + raise ValueError("CURVE client mode requires curve_server_public_key") + return + + if self._socket is None: + raise RuntimeError("ZMQ socket has not been initialized") + + if self._curve_client_cert is not None: + if self._curve_client_public_key is not None or self._curve_client_secret_key is not None: + raise ValueError("Pass either curve_client_cert or explicit CURVE keys, not both") + client_public_key = load_curve_key(self._curve_client_cert, secret=False) + client_secret_key = load_curve_key(self._curve_client_cert, secret=True) + else: + if self._curve_client_public_key is None or self._curve_client_secret_key is None: + raise ValueError("CURVE client mode requires both public and secret keys") + client_public_key = load_curve_key(self._curve_client_public_key, secret=False) + client_secret_key = load_curve_key(self._curve_client_secret_key, secret=True) + + self._socket.curve_publickey = client_public_key + self._socket.curve_secretkey = client_secret_key + self._socket.curve_serverkey = load_curve_key( + self._curve_server_public_key, + secret=False, + ) + + def _request(self, endpoint: str, data: dict[str, Any] | None = None) -> dict: + request: dict[str, Any] = { + "endpoint": endpoint, + "schema_version": SCHEMA_VERSION, + } + if data is not None: + request["data"] = data + if self._auth_token is not None: + request["auth_token"] = self._auth_token + + self._socket.send(msgpack.packb(request, use_bin_type=True)) + return msgpack.unpackb(self._socket.recv(), raw=False) + def predict_action( self, obs: dict[str, Any], @@ -107,16 +164,19 @@ def predict_action( # Serialize t0 = time.perf_counter() obs_bytes = encode_observation(obs, jpeg_quality=self._jpeg_quality) - request = msgpack.packb({ + request = { "endpoint": "predict_action", "schema_version": SCHEMA_VERSION, "data": {"obs_data": obs_bytes}, - }, use_bin_type=True) + } + if self._auth_token is not None: + request["auth_token"] = self._auth_token + request_bytes = msgpack.packb(request, use_bin_type=True) serialize_ms = (time.perf_counter() - t0) * 1000 # ZMQ round-trip t0 = time.perf_counter() - self._socket.send(request) + self._socket.send(request_bytes) response_bytes = self._socket.recv() zmq_roundtrip_ms = (time.perf_counter() - t0) * 1000 @@ -147,15 +207,15 @@ def predict_action( def ping(self) -> dict: """Health check — returns server status dict.""" - msg = msgpack.packb({"endpoint": "ping", "schema_version": SCHEMA_VERSION}, use_bin_type=True) - self._socket.send(msg) - return msgpack.unpackb(self._socket.recv(), raw=False) + return self._request("ping") def reset(self) -> dict: """Signal episode boundary to the server.""" - msg = msgpack.packb({"endpoint": "reset", "schema_version": SCHEMA_VERSION}, use_bin_type=True) - self._socket.send(msg) - return msgpack.unpackb(self._socket.recv(), raw=False) + return self._request("reset") + + def kill(self) -> dict: + """Request graceful server shutdown.""" + return self._request("kill") def close(self) -> None: """Clean up socket + context.""" diff --git a/src/tether/runtime/transports/zmq/factory.py b/src/tether/runtime/transports/zmq/factory.py index c3f73a6..355d120 100644 --- a/src/tether/runtime/transports/zmq/factory.py +++ b/src/tether/runtime/transports/zmq/factory.py @@ -18,6 +18,7 @@ import logging import threading import time +from pathlib import Path from typing import Any import numpy as np @@ -32,6 +33,9 @@ def create_zmq_server( *, host: str = "*", port: int = 5555, + curve_server_cert: str | Path | None = None, + curve_client_cert_dir: str | Path | None = None, + control_token: str | None = None, ) -> PolicyServer: """Create a ZMQ server that wraps a tether PolicyRuntime. @@ -49,6 +53,9 @@ def create_zmq_server( ``predict_action_chunk`` / ``predict_async`` method. host: Bind address. port: Bind port. + curve_server_cert: Optional pyzmq CURVE server secret certificate. + curve_client_cert_dir: Directory of allowed client public certificates. + control_token: Optional token required for built-in control endpoints. Returns: A configured ``PolicyServer`` ready to ``run()``. @@ -125,7 +132,13 @@ def get_status() -> dict: "avg_infer_time_ms": round(avg * 1000, 2), } - server = PolicyServer(host=host, port=port) + server = PolicyServer( + host=host, + port=port, + curve_server_cert=curve_server_cert, + curve_client_cert_dir=curve_client_cert_dir, + control_token=control_token, + ) server.register_endpoint("predict_action", predict_action) server.register_endpoint("reset", reset, requires_input=False) server.register_endpoint("get_status", get_status, requires_input=False) diff --git a/src/tether/runtime/transports/zmq/policy_server.py b/src/tether/runtime/transports/zmq/policy_server.py index 5fc76fa..c446875 100644 --- a/src/tether/runtime/transports/zmq/policy_server.py +++ b/src/tether/runtime/transports/zmq/policy_server.py @@ -19,10 +19,14 @@ import logging import time from dataclasses import dataclass +from pathlib import Path from typing import Any, Callable import msgpack import zmq +from zmq.auth.thread import ThreadAuthenticator + +from tether.runtime.transports.zmq.security import load_curve_key logger = logging.getLogger(__name__) @@ -47,6 +51,7 @@ def __init__(self, client_version: int, server_version: int) -> None: class _EndpointHandler: handler: Callable requires_input: bool = True + requires_auth: bool = False class PolicyServer: @@ -65,17 +70,45 @@ class PolicyServer: port: TCP port to listen on. 0 = kernel-assigned (for testing). """ - def __init__(self, host: str = "*", port: int = 5555) -> None: + def __init__( + self, + host: str = "*", + port: int = 5555, + *, + curve_server_cert: str | Path | None = None, + curve_server_public_key: str | bytes | None = None, + curve_server_secret_key: str | bytes | None = None, + curve_client_cert_dir: str | Path | None = None, + control_token: str | None = None, + ) -> None: self.running = True self.context = zmq.Context() + self._authenticator: ThreadAuthenticator | None = None + self._control_token = control_token self.socket = self.context.socket(zmq.REP) + self._configure_curve( + curve_server_cert=curve_server_cert, + curve_server_public_key=curve_server_public_key, + curve_server_secret_key=curve_server_secret_key, + curve_client_cert_dir=curve_client_cert_dir, + ) self.socket.bind(f"tcp://{host}:{port}") self._endpoints: dict[str, _EndpointHandler] = {} self._start_time = time.monotonic() self._request_count = 0 - self.register_endpoint("ping", self._handle_ping, requires_input=False) - self.register_endpoint("kill", self._handle_kill, requires_input=False) + self.register_endpoint( + "ping", + self._handle_ping, + requires_input=False, + requires_auth=self._control_token is not None, + ) + self.register_endpoint( + "kill", + self._handle_kill, + requires_input=False, + requires_auth=self._control_token is not None, + ) @property def bound_address(self) -> str: @@ -90,8 +123,53 @@ def register_endpoint( name: str, handler: Callable, requires_input: bool = True, + requires_auth: bool = False, + ) -> None: + self._endpoints[name] = _EndpointHandler(handler, requires_input, requires_auth) + + def _configure_curve( + self, + *, + curve_server_cert: str | Path | None, + curve_server_public_key: str | bytes | None, + curve_server_secret_key: str | bytes | None, + curve_client_cert_dir: str | Path | None, ) -> None: - self._endpoints[name] = _EndpointHandler(handler, requires_input) + if curve_server_cert is None and curve_server_public_key is None and curve_server_secret_key is None: + if curve_client_cert_dir is not None: + raise ValueError("curve_client_cert_dir requires a CURVE server certificate or keypair") + return + + if curve_server_cert is not None: + if curve_server_public_key is not None or curve_server_secret_key is not None: + raise ValueError("Pass either curve_server_cert or explicit CURVE keys, not both") + public_key = load_curve_key(curve_server_cert, secret=False) + secret_key = load_curve_key(curve_server_cert, secret=True) + else: + if curve_server_public_key is None or curve_server_secret_key is None: + raise ValueError("CURVE server mode requires both public and secret keys") + public_key = load_curve_key(curve_server_public_key, secret=False) + secret_key = load_curve_key(curve_server_secret_key, secret=True) + + if curve_client_cert_dir is None: + raise ValueError("CURVE server mode requires curve_client_cert_dir for client authentication") + + client_cert_dir = Path(curve_client_cert_dir).expanduser() + if not client_cert_dir.is_dir(): + raise ValueError(f"CURVE client certificate directory not found: {client_cert_dir}") + + self._authenticator = ThreadAuthenticator(self.context) + self._authenticator.start() + self._authenticator.configure_curve(domain="*", location=client_cert_dir) + self.socket.curve_publickey = public_key + self.socket.curve_secretkey = secret_key + self.socket.curve_server = True + + def _authorize_control_request(self, request: dict[str, Any]) -> None: + if self._control_token is None: + return + if request.get("auth_token") != self._control_token: + raise PermissionError("ZMQ control endpoint requires a valid auth token") def _handle_ping(self) -> dict: return { @@ -147,6 +225,8 @@ def run(self) -> None: raise ValueError(f"Unknown endpoint: {endpoint!r}") handler = self._endpoints[endpoint] + if handler.requires_auth: + self._authorize_control_request(request) if handler.requires_input: result = handler.handler(**request.get("data", {})) else: @@ -166,6 +246,8 @@ def run(self) -> None: # Clean shutdown self.socket.setsockopt(zmq.LINGER, 0) self.socket.close() + if self._authenticator is not None: + self._authenticator.stop() self.context.term() logger.info("ZMQ server shut down cleanly after %d requests", self._request_count) diff --git a/src/tether/runtime/transports/zmq/security.py b/src/tether/runtime/transports/zmq/security.py new file mode 100644 index 0000000..c947ff8 --- /dev/null +++ b/src/tether/runtime/transports/zmq/security.py @@ -0,0 +1,27 @@ +"""Shared ZMQ transport security helpers.""" +from __future__ import annotations + +from pathlib import Path + +import zmq.auth + + +def load_curve_key(value: str | bytes | Path, *, secret: bool) -> bytes: + """Load a Z85 CURVE key from a raw value or pyzmq certificate file.""" + if isinstance(value, bytes): + return value + + raw_value = str(value) + path = Path(raw_value).expanduser() + if path.exists(): + public_key, secret_key = zmq.auth.load_certificate(path) + key = secret_key if secret else public_key + if key is None: + kind = "secret" if secret else "public" + raise ValueError(f"CURVE certificate {path} does not contain a {kind} key") + return key + + return raw_value.encode("ascii") + + +__all__ = ["load_curve_key"] diff --git a/tests/test_zmq_client.py b/tests/test_zmq_client.py index 0a25591..419cfde 100644 --- a/tests/test_zmq_client.py +++ b/tests/test_zmq_client.py @@ -34,6 +34,20 @@ def server_port(): thread.join(timeout=2) +@pytest.fixture +def auth_server_port(): + """Start a ZMQ server that protects control endpoints with a token.""" + runtime = _MockRuntime() + server = create_zmq_server(runtime, port=0, control_token="secret") + port = server.bound_port + thread = threading.Thread(target=server.run, daemon=True) + thread.start() + time.sleep(0.1) + yield port + server.close() + thread.join(timeout=2) + + # ── predict_action ─────────────────────────────────────────────────── @@ -118,6 +132,24 @@ def test_ping(server_port): assert result["status"] == "ok" +def test_ping_sends_auth_token(auth_server_port): + with ZmqRuntimeClient( + f"tcp://127.0.0.1:{auth_server_port}", + auth_token="secret", + ) as client: + result = client.ping() + assert result["status"] == "ok" + + +def test_kill_sends_auth_token(auth_server_port): + with ZmqRuntimeClient( + f"tcp://127.0.0.1:{auth_server_port}", + auth_token="secret", + ) as client: + result = client.kill() + assert result["status"] == "ok" + + def test_reset(server_port): with ZmqRuntimeClient(f"tcp://127.0.0.1:{server_port}") as client: result = client.reset() diff --git a/tests/test_zmq_policy_server.py b/tests/test_zmq_policy_server.py index 849f646..1015296 100644 --- a/tests/test_zmq_policy_server.py +++ b/tests/test_zmq_policy_server.py @@ -9,27 +9,39 @@ import time import msgpack -import pytest import zmq +import zmq.auth from tether.runtime.transports.zmq.policy_server import ( SCHEMA_VERSION, PolicyServer, - WireSchemaMismatchError, ) -def _start_server(port: int = 0) -> tuple[PolicyServer, threading.Thread]: - server = PolicyServer(port=port) +def _start_server(port: int = 0, **kwargs) -> tuple[PolicyServer, threading.Thread]: + server = PolicyServer(port=port, **kwargs) thread = threading.Thread(target=server.run, daemon=True) thread.start() time.sleep(0.1) # let the socket bind return server, thread -def _client_socket(port: int) -> zmq.Socket: +def _client_socket( + port: int, + *, + curve_client_cert: str | None = None, + curve_server_public_key: str | None = None, +) -> zmq.Socket: ctx = zmq.Context() sock = ctx.socket(zmq.REQ) + if curve_client_cert is not None or curve_server_public_key is not None: + assert curve_client_cert is not None + assert curve_server_public_key is not None + client_public_key, client_secret_key = zmq.auth.load_certificate(curve_client_cert) + server_public_key, _ = zmq.auth.load_certificate(curve_server_public_key) + sock.curve_publickey = client_public_key + sock.curve_secretkey = client_secret_key + sock.curve_serverkey = server_public_key sock.connect(f"tcp://127.0.0.1:{port}") return sock @@ -67,6 +79,52 @@ def test_kill_shuts_down(): assert not server.running +def test_control_token_required_for_ping_and_kill(): + server, thread = _start_server(control_token="secret") + port = server.bound_port + sock = _client_socket(port) + + result = _send_recv(sock, {"endpoint": "ping"}) + assert "error" in result + assert "auth token" in result["error"] + + result = _send_recv(sock, {"endpoint": "ping", "auth_token": "secret"}) + assert result["status"] == "ok" + + result = _send_recv(sock, {"endpoint": "kill", "auth_token": "wrong"}) + assert "error" in result + assert server.running + + result = _send_recv(sock, {"endpoint": "kill", "auth_token": "secret"}) + assert result["status"] == "ok" + thread.join(timeout=2) + assert not server.running + + +def test_curve_client_can_connect_with_allowed_certificate(tmp_path): + server_public, server_secret = zmq.auth.create_certificates(tmp_path, "server") + client_cert_dir = tmp_path / "clients" + client_cert_dir.mkdir() + _client_public, client_secret = zmq.auth.create_certificates(client_cert_dir, "robot") + + server, thread = _start_server( + curve_server_cert=server_secret, + curve_client_cert_dir=client_cert_dir, + ) + port = server.bound_port + sock = _client_socket( + port, + curve_client_cert=client_secret, + curve_server_public_key=server_public, + ) + + result = _send_recv(sock, {"endpoint": "ping"}) + assert result["status"] == "ok" + + server.close() + thread.join(timeout=2) + + # ── Custom endpoint ──────────────────────────────────────────────────