Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions docs/transports.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,51 @@ actions = client.predict_action(obs)
robot.execute(actions[0]) # first action in the chunk
```

### Secure ZMQ

Production ZMQ deployments should use CURVE authentication/encryption and a
control token for operational endpoints such as `ping` and `kill`.

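Generate one server keypair and one client keypair. Each `create_certificates` call writes a public `<name>.key` file and a secret `<name>.key_secret` file: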

```bash
python - <<'PY'
from pathlib import Path
import zmq.auth

out = Path("zmq-certs")
(out / "clients").mkdir(parents=True, exist_ok=True)
zmq.auth.create_certificates(out, "server")
zmq.auth.create_certificates(out / "clients", "robot-1")
PY
```

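Start the server with the server secret certificate and the directory of
allowed client public certificates (the server needs only each client's public
`.key` file; every `.key_secret` file should stay on the machine that owns it):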

```bash
tether serve ./my_export/ \
--transport zmq \
--port 5555 \
--zmq-server-cert ./zmq-certs/server.key_secret \
--zmq-client-cert-dir ./zmq-certs/clients \
--zmq-control-token "$TETHER_ZMQ_CONTROL_TOKEN"
```

Configure the robot-side client with its secret certificate, the server public
certificate, and the same control token:

```python
import os

client = ZmqRuntimeClient(
"tcp://gpu-server:5555",
curve_client_cert="./zmq-certs/clients/robot-1.key_secret",
curve_server_public_key="./zmq-certs/server.key",
auth_token=os.environ["TETHER_ZMQ_CONTROL_TOKEN"],
)
```

## ZMQ Performance

| Metric | HTTP | ZMQ | ZMQ + JPEG |
Expand Down
31 changes: 30 additions & 1 deletion src/tether/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -1730,6 +1730,24 @@ def serve(
"3-camera setups via JPEG-on-wire and 40%+ lower tail jitter. "
"ROS2 reserved for v1.0.",
),
zmq_server_cert: str = typer.Option(
"",
"--zmq-server-cert",
help="Path to a pyzmq CURVE server secret certificate (.key_secret). "
"Requires --transport zmq and --zmq-client-cert-dir.",
),
zmq_client_cert_dir: str = typer.Option(
"",
"--zmq-client-cert-dir",
help="Directory of allowed pyzmq CURVE client public certificates. "
"Requires --transport zmq and --zmq-server-cert.",
),
zmq_control_token: str = typer.Option(
"",
"--zmq-control-token",
help="Token required for ZMQ control endpoints such as ping and kill. "
"Pass the same value to ZmqRuntimeClient(auth_token=...).",
),
device: str = typer.Option("cuda", help="Device: cuda or cpu"),
providers: str = typer.Option(
"",
Expand Down Expand Up @@ -2653,8 +2671,19 @@ def _run_mcp_http():
if transport == "zmq":
console.print("[bold green]Starting ZMQ server...[/bold green]")
from tether.runtime.transports.zmq.factory import create_zmq_server
zmq_server = create_zmq_server(app_instance, host=host, port=port)
zmq_server = create_zmq_server(
app_instance,
host=host,
port=port,
curve_server_cert=zmq_server_cert or None,
curve_client_cert_dir=zmq_client_cert_dir or None,
control_token=zmq_control_token or None,
)
composed.append("[cyan]transport=zmq[/cyan]")
if zmq_server_cert:
composed.append("[cyan]curve=on[/cyan]")
if zmq_control_token:
composed.append("[cyan]control-auth=on[/cyan]")
console.print(f"[dim]Features: {' + '.join(composed)}[/dim]")
zmq_server.run()
elif transport == "http":
Expand Down
80 changes: 70 additions & 10 deletions src/tether/runtime/transports/zmq/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import io
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any

import msgpack
Expand All @@ -32,9 +33,9 @@

from tether.runtime.transports.zmq.serializers import (
SCHEMA_VERSION,
decode_observation,
encode_observation,
)
from tether.runtime.transports.zmq.security import load_curve_key


@dataclass
Expand Down Expand Up @@ -63,10 +64,21 @@ def __init__(
server_url: str = "tcp://localhost:5555",
timeout_ms: int = 5000,
jpeg_quality: int = 85,
*,
curve_client_cert: str | Path | None = None,
curve_client_public_key: str | bytes | None = None,
curve_client_secret_key: str | bytes | None = None,
curve_server_public_key: str | bytes | Path | None = None,
auth_token: str | None = None,
) -> None:
self._server_url = server_url
self._timeout_ms = timeout_ms
self._jpeg_quality = jpeg_quality
self._curve_client_cert = curve_client_cert
self._curve_client_public_key = curve_client_public_key
self._curve_client_secret_key = curve_client_secret_key
self._curve_server_public_key = curve_server_public_key
self._auth_token = auth_token
self._context = zmq.Context()
self._socket: zmq.Socket | None = None
self._connect()
Expand All @@ -78,8 +90,53 @@ def _connect(self) -> None:
if self._timeout_ms > 0:
self._socket.setsockopt(zmq.RCVTIMEO, self._timeout_ms)
self._socket.setsockopt(zmq.SNDTIMEO, self._timeout_ms)
self._configure_curve()
self._socket.connect(self._server_url)

def _configure_curve(self) -> None:
    """Apply CURVE client credentials to the socket when a server key is set.

    With no ``curve_server_public_key`` the socket is left untouched, provided
    no client credential was supplied either. Otherwise the client key pair is
    taken from ``curve_client_cert`` or from the explicit public/secret key
    arguments (never both), and the three CURVE socket attributes are set.

    Raises:
        ValueError: If client credentials were given without a server public
            key, if both a client certificate and explicit keys were given,
            or if the explicit key pair is incomplete.
        RuntimeError: If the socket has not been created yet.
    """
    cert = self._curve_client_cert
    explicit_public = self._curve_client_public_key
    explicit_secret = self._curve_client_secret_key
    server_key = self._curve_server_public_key

    no_explicit_keys = explicit_public is None and explicit_secret is None

    if server_key is None:
        # Without a server key there is nothing to configure; any client
        # credential on its own is a configuration mistake.
        if not (cert is None and no_explicit_keys):
            raise ValueError("CURVE client mode requires curve_server_public_key")
        return

    sock = self._socket
    if sock is None:
        raise RuntimeError("ZMQ socket has not been initialized")

    if cert is None:
        if explicit_public is None or explicit_secret is None:
            raise ValueError("CURVE client mode requires both public and secret keys")
        public_source, secret_source = explicit_public, explicit_secret
    else:
        if not no_explicit_keys:
            raise ValueError("Pass either curve_client_cert or explicit CURVE keys, not both")
        # A single certificate supplies both halves of the client key pair.
        public_source = secret_source = cert

    # Load every client key before touching the socket so a failed load
    # leaves the socket options unchanged.
    client_public_key = load_curve_key(public_source, secret=False)
    client_secret_key = load_curve_key(secret_source, secret=True)

    sock.curve_publickey = client_public_key
    sock.curve_secretkey = client_secret_key
    sock.curve_serverkey = load_curve_key(server_key, secret=False)

def _request(self, endpoint: str, data: dict[str, Any] | None = None) -> dict:
request: dict[str, Any] = {
"endpoint": endpoint,
"schema_version": SCHEMA_VERSION,
}
if data is not None:
request["data"] = data
if self._auth_token is not None:
request["auth_token"] = self._auth_token

self._socket.send(msgpack.packb(request, use_bin_type=True))
return msgpack.unpackb(self._socket.recv(), raw=False)

def predict_action(
self,
obs: dict[str, Any],
Expand Down Expand Up @@ -107,16 +164,19 @@ def predict_action(
# Serialize
t0 = time.perf_counter()
obs_bytes = encode_observation(obs, jpeg_quality=self._jpeg_quality)
request = msgpack.packb({
request = {
"endpoint": "predict_action",
"schema_version": SCHEMA_VERSION,
"data": {"obs_data": obs_bytes},
}, use_bin_type=True)
}
if self._auth_token is not None:
request["auth_token"] = self._auth_token
request_bytes = msgpack.packb(request, use_bin_type=True)
serialize_ms = (time.perf_counter() - t0) * 1000

# ZMQ round-trip
t0 = time.perf_counter()
self._socket.send(request)
self._socket.send(request_bytes)
response_bytes = self._socket.recv()
zmq_roundtrip_ms = (time.perf_counter() - t0) * 1000

Expand Down Expand Up @@ -147,15 +207,15 @@ def predict_action(

def ping(self) -> dict:
    """Health check — returns server status dict.

    Delegates to ``_request`` so the envelope is built in one place and the
    auth token, when configured, is included with the call.
    """
    return self._request("ping")

def reset(self) -> dict:
    """Signal episode boundary to the server.

    Delegates to ``_request`` so the envelope is built in one place and the
    auth token, when configured, is included with the call.
    """
    return self._request("reset")

def kill(self) -> dict:
    """Ask the server to shut down gracefully via the ``kill`` endpoint.

    Returns:
        The reply produced by ``_request`` for the ``"kill"`` call.
    """
    reply = self._request("kill")
    return reply

def close(self) -> None:
"""Clean up socket + context."""
Expand Down
15 changes: 14 additions & 1 deletion src/tether/runtime/transports/zmq/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import logging
import threading
import time
from pathlib import Path
from typing import Any

import numpy as np
Expand All @@ -32,6 +33,9 @@ def create_zmq_server(
*,
host: str = "*",
port: int = 5555,
curve_server_cert: str | Path | None = None,
curve_client_cert_dir: str | Path | None = None,
control_token: str | None = None,
) -> PolicyServer:
"""Create a ZMQ server that wraps a tether PolicyRuntime.

Expand All @@ -49,6 +53,9 @@ def create_zmq_server(
``predict_action_chunk`` / ``predict_async`` method.
host: Bind address.
port: Bind port.
curve_server_cert: Optional pyzmq CURVE server secret certificate.
curve_client_cert_dir: Directory of allowed client public certificates.
control_token: Optional token required for built-in control endpoints.

Returns:
A configured ``PolicyServer`` ready to ``run()``.
Expand Down Expand Up @@ -125,7 +132,13 @@ def get_status() -> dict:
"avg_infer_time_ms": round(avg * 1000, 2),
}

server = PolicyServer(host=host, port=port)
server = PolicyServer(
host=host,
port=port,
curve_server_cert=curve_server_cert,
curve_client_cert_dir=curve_client_cert_dir,
control_token=control_token,
)
server.register_endpoint("predict_action", predict_action)
server.register_endpoint("reset", reset, requires_input=False)
server.register_endpoint("get_status", get_status, requires_input=False)
Expand Down
Loading
Loading