diff --git a/.agents/skills/build-from-issue/SKILL.md b/.agents/skills/build-from-issue/SKILL.md index dbb5396cd..1225e6248 100644 --- a/.agents/skills/build-from-issue/SKILL.md +++ b/.agents/skills/build-from-issue/SKILL.md @@ -148,6 +148,7 @@ In the prompt, instruct the reviewer to: - **Medium**: Multiple files/components, some design decisions, but well-scoped - **High**: Cross-cutting changes, architectural decisions needed, significant unknowns 8. Call out risks, unknowns, and decisions that need stakeholder input. +9. Assess **LSM compatibility** — if the change touches process identity, `/proc` filesystem access, binary execution, or inter-process visibility, flag whether it will behave differently on hosts running SELinux (enforcing) or AppArmor. In particular, tests that fork+exec into system binaries will fail on SELinux-enforcing hosts due to cross-label `/proc//exe` access restrictions. ### A2: Post the Plan Comment diff --git a/.agents/skills/create-spike/SKILL.md b/.agents/skills/create-spike/SKILL.md index faa7aca08..f141f82ef 100644 --- a/.agents/skills/create-spike/SKILL.md +++ b/.agents/skills/create-spike/SKILL.md @@ -91,7 +91,9 @@ The prompt to the reviewer **must** instruct it to: 9. **Check architecture docs** in the `architecture/` directory for relevant documentation about the affected subsystems. -10. **Determine the issue type:** `feat`, `fix`, `refactor`, `chore`, `perf`, or `docs`. +10. **Assess Linux Security Module (LSM) impact.** If the change involves process identity, `/proc` filesystem access, file labeling, binary execution, or inter-process visibility, call out whether it will behave differently on hosts running SELinux (enforcing) or AppArmor. For example: reading `/proc//exe` across an SELinux domain boundary returns ENOENT, not EACCES. Tests that fork+exec into system binaries (different SELinux label) will fail on enforcing hosts. Flag any LSM-sensitive code paths and recommend mitigations. + +11. 
**Determine the issue type:** `feat`, `fix`, `refactor`, `chore`, `perf`, or `docs`. ### What makes a good investigation prompt diff --git a/.agents/skills/debug-openshell-cluster/SKILL.md b/.agents/skills/debug-openshell-cluster/SKILL.md index 408ef85c7..6c8f73bb6 100644 --- a/.agents/skills/debug-openshell-cluster/SKILL.md +++ b/.agents/skills/debug-openshell-cluster/SKILL.md @@ -63,15 +63,46 @@ Use gateway metadata, deployment values, or the user's setup notes to identify t docker info docker ps --filter name=openshell docker logs --tail=200 +docker run --rm --entrypoint /openshell-sandbox "${OPENSHELL_DOCKER_SUPERVISOR_IMAGE:-ghcr.io/nvidia/openshell/supervisor:latest}" --version openshell status ``` +For Docker GPU failures, check CDI support and NVIDIA CDI discovery separately: + +```bash +docker info --format '{{json .CDISpecDirs}}' +docker info --format '{{json .DiscoveredDevices}}' +for dir in /etc/cdi /var/run/cdi; do + if [ -d "$dir" ]; then + find "$dir" -maxdepth 1 -type f \( -name '*.yaml' -o -name '*.json' \) -print + else + echo "$dir missing" + fi +done +systemctl is-enabled nvidia-cdi-refresh.service nvidia-cdi-refresh.path || true +systemctl is-active nvidia-cdi-refresh.service nvidia-cdi-refresh.path || true +systemctl status nvidia-cdi-refresh.service nvidia-cdi-refresh.path --no-pager --lines=50 +journalctl -u nvidia-cdi-refresh.service --no-pager --lines=100 +``` + +When the NVIDIA Container Toolkit CDI refresh units are not enabled or no NVIDIA CDI spec has been generated, enable them and trigger a refresh: + +```bash +sudo systemctl enable --now nvidia-cdi-refresh.path +sudo systemctl enable --now nvidia-cdi-refresh.service +sudo systemctl restart nvidia-cdi-refresh.service +docker info --format '{{json .DiscoveredDevices}}' +``` + Common findings: - Docker daemon unavailable: start Docker Desktop or Docker Engine. - Gateway process stopped: inspect exit status and logs. 
- Sandbox image missing or pull denied: verify image reference and registry credentials. +- Docker driver cannot initialize because it cannot find `openshell-sandbox`: verify `OPENSHELL_DOCKER_SUPERVISOR_BIN`, the sibling binary next to `openshell-gateway`, or the configured supervisor image contains `/openshell-sandbox`. - Sandbox never registers: check gateway logs and supervisor callback endpoint. +- Supervisor image exits before printing `openshell-sandbox --version`: the image should be the scratch supervisor image from `deploy/docker/Dockerfile.supervisor` and must contain a static executable at `/openshell-sandbox`. +- `mise run e2e:docker:gpu` fails with `docker info --format json did not report any discovered NVIDIA CDI GPU devices`: Docker may report `CDISpecDirs` while still having no generated NVIDIA CDI specs. Verify `.DiscoveredDevices` contains entries such as `nvidia.com/gpu=all`, verify `/etc/cdi` or `/var/run/cdi` contains a generated NVIDIA spec, and check that `nvidia-cdi-refresh.service` and `nvidia-cdi-refresh.path` from NVIDIA Container Toolkit are enabled and healthy. The service is a one-shot unit, so `inactive (dead)` can be normal after a successful run; use `systemctl status` and `journalctl` to distinguish success from a skipped or failed refresh. NVIDIA recommends enabling the path and service units, and restarting `nvidia-cdi-refresh.service` to regenerate missing or stale CDI specs. If specs are generated but Docker still reports no discovered devices, restart Docker or reload the daemon and re-check `docker info`. 
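The GPU discovery check in the findings above can be scripted. This is a minimal sketch, assuming the JSON printed by `docker info --format '{{json .DiscoveredDevices}}'` contains the device name text (for example `nvidia.com/gpu=all`) somewhere in its output. `has_nvidia_cdi_device` is a hypothetical helper name used only for illustration; it is not part of OpenShell, Docker, or NVIDIA Container Toolkit.

```bash
#!/usr/bin/env bash
# Sketch only: has_nvidia_cdi_device is a hypothetical helper, not an
# OpenShell or Docker command.
#
# Reads the output of
#   docker info --format '{{json .DiscoveredDevices}}'
# on stdin and exits 0 if any entry mentions an nvidia.com/gpu device,
# non-zero otherwise (for example when the output is `null` or `[]`).
has_nvidia_cdi_device() {
  # Match on the device name text rather than a specific JSON layout,
  # because the exact field names are an assumption here.
  grep -q 'nvidia\.com/gpu='
}

# Usage (needs a running Docker daemon, so it is left commented out):
#   if docker info --format '{{json .DiscoveredDevices}}' | has_nvidia_cdi_device; then
#     echo "NVIDIA CDI devices discovered"
#   else
#     echo "no NVIDIA CDI devices discovered; check nvidia-cdi-refresh, then restart Docker"
#   fi
```

Matching on the device name instead of parsing the JSON keeps the check independent of how a given Docker version structures each entry. It only answers "did Docker discover anything NVIDIA"; it does not replace inspecting `/etc/cdi`, `/var/run/cdi`, or the refresh unit logs when the answer is no.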
For source checkout development, restart the local gateway with: @@ -111,12 +142,18 @@ Check required Helm deployment secrets: ```bash kubectl -n openshell get secret \ - openshell-ssh-handshake \ openshell-server-tls \ openshell-server-client-ca \ - openshell-client-tls + openshell-client-tls \ + openshell-jwt-keys ``` +If the gateway exits with `failed to read sandbox JWT signing key from +/etc/openshell-jwt/signing.pem`, verify that `openshell-jwt-keys` contains +`signing.pem`, `public.pem`, and `kid`, and that the StatefulSet mounts the +`sandbox-jwt` secret at `/etc/openshell-jwt`. The sandbox JWT mount is required +even when local Helm values disable TLS. + Check the image references currently used by the gateway deployment: ```bash @@ -124,7 +161,11 @@ kubectl -n openshell get statefulset openshell -o jsonpath="{.spec.template.spec helm -n openshell get values openshell | grep -E 'repository|tag|supervisorImage' ``` -The gateway image and `server.supervisorImage` should use the same build tag in branch and E2E deploys. A stale supervisor image can make sandbox behavior lag behind gateway policy or proto changes. +The gateway image built from `deploy/docker/Dockerfile.gateway` and the scratch supervisor image built from `deploy/docker/Dockerfile.supervisor` should use the same build tag in branch and E2E deploys. A stale supervisor image can make sandbox behavior lag behind gateway policy or proto changes. + +For local/external pull mode (the default local path via `mise run cluster`), local images are tagged to the configured local registry base, pushed to that registry, and pulled by k3s via the `registries.yaml` mirror endpoint. The `cluster` task pushes prebuilt local tags (`openshell/*:dev`, falling back to `localhost:5000/openshell/*:dev` or `127.0.0.1:5000/openshell/*:dev`). + +Gateway image builds stage a partial Rust workspace from `deploy/docker/Dockerfile.images`. 
If cargo fails with a missing manifest under `/build/crates/...`, or an imported symbol exists locally but is missing in the image build, verify that every current gateway dependency crate, including `openshell-driver-docker`, `openshell-driver-kubernetes`, and `openshell-ocsf`, is copied into the staged workspace there. For plaintext local evaluation, confirm the chart has: @@ -171,6 +212,18 @@ helm -n openshell get values openshell | grep sandboxNamespace Then inspect sandbox resources in that namespace. +Check the configured sandbox service account when TokenReview bootstrap or +sandbox registration fails. Helm creates a dedicated sandbox service account by +default and writes it to `[openshell.drivers.kubernetes].service_account_name`; +the gateway rejects projected tokens from other service accounts. + +```bash +helm -n openshell get values openshell | grep -A3 sandboxServiceAccount +kubectl -n get serviceaccount openshell-sandbox +kubectl -n openshell get configmap openshell-config -o jsonpath='{.data.gateway\.toml}' +kubectl -n get sandbox -o jsonpath='{.spec.template.spec.serviceAccountName}{"\n"}' +``` + ### Step 6: Check VM-Backed Gateways Use the VM driver logs and host diagnostics available in the user's environment. 
Verify: @@ -194,6 +247,7 @@ openshell logs | `openshell status` fails | Gateway endpoint unreachable or auth mismatch | `openshell gateway info`, gateway logs | | Gateway starts but sandbox create fails | Compute driver cannot reach runtime | Docker/Podman/Kubernetes/VM driver logs | | Docker or Podman sandbox never registers | Wrong callback endpoint or supervisor startup failure | Gateway logs and sandbox container logs | +| Docker GPU e2e fails before GPU sandbox comparison | NVIDIA CDI specs are missing or Docker has not discovered them | `docker info --format '{{json .DiscoveredDevices}}'`, `/etc/cdi`, `/var/run/cdi`, `nvidia-cdi-refresh.service` | | Kubernetes gateway pod pending | PVC unbound, taint, selector, or insufficient resources | `kubectl -n openshell describe pod ` | | Kubernetes gateway pod crash loops | Missing secret, bad DB URL, bad TLS config | `kubectl -n openshell logs statefulset/openshell` | | CLI TLS error | Local mTLS bundle does not match server cert/CA | Check `~/.config/openshell/gateways//mtls/` | diff --git a/.agents/skills/helm-dev-environment/SKILL.md b/.agents/skills/helm-dev-environment/SKILL.md index 18d8c241e..a97395fb1 100644 --- a/.agents/skills/helm-dev-environment/SKILL.md +++ b/.agents/skills/helm-dev-environment/SKILL.md @@ -26,8 +26,9 @@ mise run helm:k3s:create ``` Creates a k3d cluster and merges its kubeconfig into the worktree-local `kubeconfig` file. -Also applies base manifests (`deploy/kube/manifests/agent-sandbox.yaml`). Traefik is -disabled at cluster creation time. +Also applies base manifests (`deploy/kube/manifests/agent-sandbox.yaml`) and preloads the +default community sandbox image into k3d so the first sandbox create does not wait on a +large registry pull. Traefik is disabled at cluster creation time. **Multi-worktree support:** the cluster name is derived from the last component of the current git branch (e.g. 
branch `kube-support/local-dev/tmutch` → cluster @@ -43,6 +44,8 @@ Port mappings created at cluster time (cannot be changed without recreating): Override with env vars before running `helm:k3s:create`: - `HELM_K3S_LB_HOST_PORT` (default: `8080`) +- `HELM_K3S_PRELOAD_SANDBOX_IMAGE` (default: + `ghcr.io/nvidia/openshell-community/sandboxes/base:latest`; set to an empty value to skip) ### 2. Deploy OpenShell @@ -57,7 +60,8 @@ mise run helm:skaffold:run ``` Both commands build the `gateway` and `supervisor` images and deploy the OpenShell Helm -chart. The `pkiInitJob` hook runs on first install to generate mTLS secrets. Envoy Gateway opt-in; see the Optional Add-ons section below. +chart. The `pkiInitJob` hook (a pre-install Job that runs `openshell-gateway generate-certs`) +generates mTLS secrets on first install. Envoy Gateway opt-in; see the Optional Add-ons section below. The gateway Service uses ClusterIP. Access is via Envoy Gateway (port `8080`) or `kubectl port-forward`. diff --git a/.agents/skills/openshell-cli/SKILL.md b/.agents/skills/openshell-cli/SKILL.md index 7451ea03d..4b7501c1f 100644 --- a/.agents/skills/openshell-cli/SKILL.md +++ b/.agents/skills/openshell-cli/SKILL.md @@ -141,6 +141,7 @@ openshell sandbox create \ Key flags: - `--provider`: Attach one or more providers (repeatable) - `--policy`: Custom policy YAML (otherwise uses built-in default or `OPENSHELL_SANDBOX_POLICY` env var) +- `--cpu`, `--memory`: Set per-sandbox compute sizing. Docker/Podman apply limits; Kubernetes applies matching requests and limits. 
- `--upload [:]`: Upload local files into the sandbox (default dest: `/sandbox`) - `--no-keep`: Delete the sandbox after the initial command or shell exits - `--forward `: Forward a local port and keep the sandbox alive diff --git a/.agents/skills/openshell-cli/cli-reference.md b/.agents/skills/openshell-cli/cli-reference.md index adfa849dd..799850232 100644 --- a/.agents/skills/openshell-cli/cli-reference.md +++ b/.agents/skills/openshell-cli/cli-reference.md @@ -143,6 +143,8 @@ Create a sandbox through the active gateway, wait for readiness, then connect or | `--no-keep` | Delete sandbox after the initial command or shell exits | | `--provider ` | Provider to attach (repeatable) | | `--policy ` | Path to custom policy YAML | +| `--cpu ` | CPU amount for the sandbox (for example: `500m`, `1`, `2.5`) | +| `--memory ` | Memory amount for the sandbox (for example: `512Mi`, `4Gi`, `8G`) | | `--forward ` | Forward local port to sandbox (keeps the sandbox alive) | | `--tty` | Force pseudo-terminal allocation | | `--no-tty` | Disable pseudo-terminal allocation | diff --git a/.agents/skills/test-release-canary/SKILL.md b/.agents/skills/test-release-canary/SKILL.md new file mode 100644 index 000000000..4bf7d38ae --- /dev/null +++ b/.agents/skills/test-release-canary/SKILL.md @@ -0,0 +1,122 @@ +--- +name: test-release-canary +description: Manually dispatch and iterate on the Release Canary workflow that smoke-tests published OpenShell artifacts (install.sh on macOS/Ubuntu/Fedora, Helm chart on kind) after each Release Dev publish. Use when changing `.github/workflows/release-canary.yml`, validating a release before tagging, debugging a canary failure, or reproducing a canary job locally. Trigger keywords - release canary, release-canary, canary failed, canary dispatch, test release canary, post-release smoke, install.sh canary, helm chart canary, kind canary, dispatch canary. 
+--- + +# Test Release Canary + +The Release Canary (`.github/workflows/release-canary.yml`) smoke-tests the artifacts a `Release Dev` run just published. It is the last automated checkpoint before tagging a public release: if the canary is red, the published `dev` artifacts do not install on a stock environment. + +## What the canary verifies + +| Job | Runner | Verifies | +|---|---|---| +| `macos` | `macos-latest-xlarge` | `install.sh` resolves the Homebrew formula, brew installs the cask, and `openshell status` reaches the brew-services–backed local gateway with the VM driver. | +| `ubuntu` | `ubuntu-latest` | `install.sh` installs the Debian package, the post-install systemd user service starts, and `openshell status` reaches the local gateway with the Docker driver. | +| `fedora` | `fedora:latest` container | `install.sh` installs the RPM packages, the local gateway starts under Podman, and `openshell status` succeeds. | +| `kubernetes` | `ubuntu-latest` + kind | `helm install oci://ghcr.io/nvidia/openshell/helm-chart --version 0.0.0-dev` succeeds in a kind cluster, the gateway pod becomes Ready, port-forward exposes 8080, and the released CLI registers the in-cluster gateway and runs `openshell status` against it. | + +`install.sh` defaults to the *latest tagged* release — the canary is therefore checking that the most recent public release still installs, not the just-published `dev` build. The `kubernetes` job is the exception: it pins to `0.0.0-dev` chart + `:dev` images. + +## Trigger paths + +The workflow has two triggers: + +```yaml +on: + workflow_dispatch: + workflow_run: + workflows: ["Release Dev"] + types: [completed] +``` + +- **Automatic.** Every successful `Release Dev` run (on `main` or a manual dispatch of Release Dev) fires the canary. Each job gates on `github.event.workflow_run.conclusion == 'success'` so a failed Release Dev does not run the canary. 
+- **Manual.** `workflow_dispatch` lets you run the canary on demand against any branch's workflow definition. + +When dispatched manually, `github.event.workflow_run.head_sha` is empty and the workflow falls back to `github.sha` (the branch tip) for the `install.sh` URL. + +## Manual dispatch + +Run the canary as-is on the current branch: + +```shell +gh workflow run release-canary.yml --ref "$(git branch --show-current)" +``` + +Watch the run that starts: + +```shell +sleep 5 # let GitHub register the dispatch +gh run list --workflow release-canary.yml --limit 1 +gh run watch "$(gh run list --workflow release-canary.yml --limit 1 --json databaseId --jq '.[0].databaseId')" +``` + +View only failed jobs after completion: + +```shell +gh run view --log-failed +``` + +## Iterating on the canary itself + +When you change `release-canary.yml` on a branch, a manual dispatch on that branch tests *your branch's workflow logic* against *main's published artifacts* (`0.0.0-dev` chart, `:dev` images, latest tagged install.sh assets). This is what you want for iterating on the canary — you're validating that the canary still works against known-good artifacts. + +Note `install.sh` is pulled from `raw.githubusercontent.com/NVIDIA/OpenShell/${head_sha}/install.sh`, so changes to `install.sh` on your branch *are* exercised even though the binaries it downloads are from the latest public tag. + +## Testing artifacts from a specific SHA + +`Release Dev` publishes two chart versions for every dev build (see `.github/actions/release-helm-oci/action.yml:89-102`): + +- `oci://ghcr.io/nvidia/openshell/helm-chart:0.0.0-dev` — floating, overwritten on every main push. +- `oci://ghcr.io/nvidia/openshell/helm-chart:0.0.0-dev.` — immutable, `appVersion` set to the same SHA so it pulls `ghcr.io/nvidia/openshell/gateway:` and `:supervisor:`. 
+ +To smoke-test the chart for a specific dev build, dispatch `Release Dev` on the branch first, then run the kind canary steps locally pointed at the SHA-pinned chart (see "Local kind reproduction" below). The release-canary workflow itself does not currently expose `chart_version` / `image_tag` inputs. + +## Local kind reproduction + +The `kubernetes` job can be reproduced on any machine with Docker and `mise install`-provided `kubectl` + `helm`: + +```shell +kind create cluster --name release-canary-local + +helm install openshell oci://ghcr.io/nvidia/openshell/helm-chart \ + --version 0.0.0-dev \ + --namespace openshell --create-namespace \ + --set server.disableTls=true \ + --wait --timeout 5m + +kubectl wait --namespace openshell \ + --for=condition=Ready pod \ + --selector="app.kubernetes.io/name=openshell,app.kubernetes.io/instance=openshell" \ + --timeout=300s + +kubectl port-forward --namespace openshell svc/openshell 8080:8080 & +openshell gateway add http://127.0.0.1:8080 --local --name kind +openshell status +``` + +Keep `pkiInitJob.enabled=true` (the chart default), even when +`server.disableTls=true`. The hook also generates the sandbox JWT signing +secret that the gateway pod always mounts. + +Swap `0.0.0-dev` for `0.0.0-dev.` to pin to a specific dev build. Tear down with `kind delete cluster --name release-canary-local`. + +Loopback registration auto-derives the gateway name to `openshell` if `--name` is omitted, which collides with the `install.sh`-installed local gateway — always pass `--name kind` (or another distinct name) when registering in addition to a local install. + +## Diagnosing failures + +| Symptom | Likely cause | Where to look | +|---|---|---| +| `macos`/`ubuntu`/`fedora` job fails on `install.sh` | Latest tagged release missing an asset, checksum mismatch, or `install.sh` regression on this branch. | Job log around the `curl … install.sh \| sh` step. 
| +| `macos`/`ubuntu`/`fedora` job fails on `openshell status` | Local gateway service did not start (systemd/brew/podman). Often a driver issue. | Service logs in the job log; `OPENSHELL_DRIVERS` env in the "Ensure …" step. | +| `kubernetes` job fails on `helm install --wait` | Chart did not deploy in 5 min — usually image pull failure or readiness probe failing. | "Diagnostics on failure" step dumps `helm status`, manifest, pod describe, pod logs. | +| `kubernetes` job fails on `kubectl wait` | Gateway pod stuck `CrashLoopBackOff` or `ImagePullBackOff`. | Diagnostics dump; check `:dev` image existence at `ghcr.io/nvidia/openshell/gateway`. | +| `kubernetes` job fails on `openshell gateway add` or `status` | Port-forward not reachable, or CLI/gateway proto mismatch. | `port-forward.log` and `openshell gateway list` in the diagnostics dump. | + +The `kubernetes` job's diagnostics step (only runs `if: failure()`) emits, in order: helm status, rendered manifest, `kubectl get all`, pod descriptions, pod logs (200 lines per container), port-forward log, gateway list, CLI version. Read it top-to-bottom — most failures fall out by the manifest or pod logs. + +## Related + +- `helm-dev-environment` skill — local k3d-based dev environment (more featureful than the canary's kind cluster, but uses Skaffold-built local images, not published artifacts). +- `watch-github-actions` skill — generic `gh run` workflow monitoring. +- `debug-openshell-cluster` skill — runtime gateway/sandbox diagnostics that pair with the kind job's diagnostics dump. diff --git a/.cargo/config.toml b/.cargo/config.toml new file mode 100644 index 000000000..0005fc2bd --- /dev/null +++ b/.cargo/config.toml @@ -0,0 +1,8 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +[env] +# z3-sys bindgen needs the z3 include path. On some distros (e.g. 
RHEL/Fedora) +# the header lives in /usr/include/z3/ rather than /usr/include/. The extra -I +# is harmless on systems where the path doesn't exist. +BINDGEN_EXTRA_CLANG_ARGS = "-I/usr/include/z3" diff --git a/.claude/agent-memory/arch-doc-writer/MEMORY.md b/.claude/agent-memory/arch-doc-writer/MEMORY.md index 4e5781a59..1fce46001 100644 --- a/.claude/agent-memory/arch-doc-writer/MEMORY.md +++ b/.claude/agent-memory/arch-doc-writer/MEMORY.md @@ -11,7 +11,7 @@ - Proxy: `crates/openshell-sandbox/src/proxy.rs` - Policy crate: `crates/openshell-policy/src/lib.rs` (YAML<->proto conversion, validation, restrictive default) - Server multiplex: `crates/openshell-server/src/multiplex.rs` -- SSH tunnel: `crates/openshell-server/src/ssh_tunnel.rs` +- SSH sessions: `crates/openshell-server/src/ssh_sessions.rs` (session persistence, reaper) - Sandbox SSH server: `crates/openshell-sandbox/src/ssh.rs` - Providers: `crates/openshell-providers/src/providers/` (per-provider modules) - Bootstrap: `crates/openshell-bootstrap/src/lib.rs` (cluster lifecycle) @@ -27,7 +27,7 @@ - OPA baked-in rules: `include_str!("../data/sandbox-policy.rego")` in opa.rs - Policy loading: gRPC mode (OPENSHELL_SANDBOX_ID + OPENSHELL_ENDPOINT) or file mode (--policy-rules + --policy-data) - Env vars: sandbox uses OPENSHELL_* prefix (e.g., OPENSHELL_SANDBOX_ID, OPENSHELL_ENDPOINT, OPENSHELL_POLICY_RULES) -- CLI flag: `--openshell-endpoint` (NOT `--openshell-endpoint`) +- CLI flag: `--gateway-endpoint` (direct URL to gateway); resolution: --gateway-endpoint > --gateway > OPENSHELL_GATEWAY env > active_gateway file - Provider env injection: both entrypoint process (tokio Command) and SSH shell (std Command) - Cluster bootstrap: `sandbox_create_with_bootstrap()` auto-deploys when no cluster exists (main.rs ~line 632) - CLI cluster resolution: --cluster flag > OPENSHELL_CLUSTER env > active cluster file @@ -50,9 +50,9 @@ - Persistence: single `objects` table, protobuf payloads, Store enum dispatches SQLite vs 
Postgres by URL prefix - Persistence CRUD: upsert ON CONFLICT (id) not (object_type, id); list ORDER BY created_at_ms ASC, name ASC (not id!) - --db-url has no code default; Helm values.yaml sets `sqlite:/var/openshell/openshell.db` -- Object types: "sandbox", "provider", "ssh_session", "inference_route" -- each implements ObjectType/ObjectId/ObjectName +- Object types: "sandbox", "provider", "ssh_session", "inference_route", "service_endpoint", "provider_profile" -- each implements ObjectType/ObjectId/ObjectName - Config: `openshell_core::Config` in `crates/openshell-core/src/config.rs`, all flags have env var fallbacks -- SSH handshake: "NSSH1" preface + HMAC-SHA256, used in both exec proxy (grpc.rs) and tunnel gateway (ssh_tunnel.rs) +- SSH transport: CLI opens ForwardTcp gRPC stream (gated by CreateSshSession short-lived token), gateway relays via DuplexStream to supervisor RelayStream, supervisor connects to sandbox russh server over root-only Unix socket (/run/openshell/ssh.sock); channel uses mTLS when https:// endpoint configured, plaintext when http:// (Podman driver does not yet inject mTLS client materials). NSSH1 appears only in openshell-ocsf examples/tests, not on any live code path. 
- Phase derivation: transient reasons (ReconcilerError, DependenciesNotReady) -> Provisioning; all others -> Error - Broadcast bus buffer sizes: SandboxWatchBus=128, TracingLogBus=1024, PlatformEventBus=1024 - Sandbox CRD: `agents.x-k8s.io/v1alpha1/Sandbox`, labels: `openshell.ai/sandbox-id`, `openshell.ai/managed-by` @@ -75,13 +75,13 @@ - DNS solution in cluster-entrypoint.sh: iptables DNAT proxy (NOT host-gateway resolv.conf) ## Sandbox Connect Details -- CLI SSH module: `crates/openshell-cli/src/ssh.rs` (sandbox_connect, sandbox_exec, sandbox_rsync, sandbox_ssh_proxy) +- CLI SSH module: `crates/openshell-cli/src/ssh.rs` (sandbox_connect, sandbox_connect_editor, sandbox_forward, sandbox_exec, sandbox_sync_up_files, sandbox_sync_up, sandbox_sync_down, sandbox_ssh_proxy, sandbox_ssh_proxy_by_name) - Re-exported from run.rs: `pub use crate::ssh::{...}` for backward compat - ssh-proxy subcommand: `Commands::SshProxy` in main.rs (~line 139) -- Gateway loopback resolution: `resolve_ssh_gateway()` in ssh.rs -- overrides loopback with cluster endpoint host -- ExecSandbox gRPC: uses single-use TCP proxy + russh client in grpc.rs +- Gateway loopback resolution: `resolve_ssh_gateway()` in `crates/openshell-core/src/forward.rs:439` -- overrides loopback with cluster endpoint host; imported by ssh.rs and tui +- ExecSandbox gRPC: uses single-use TCP proxy + russh client in `grpc/sandbox.rs` (handle_exec_sandbox -> stream_exec_over_relay); operates over a relay DuplexStream through the supervisor session, not a direct TCP connection - PTY I/O: 3 std::threads (writer, reader, exit) with reader_done sync for SSH protocol ordering -- SSH daemon: russh server, ephemeral Ed25519 key, pre_exec: setsid -> TIOCSCTTY -> setns -> drop_privileges -> sandbox::apply +- SSH daemon: russh server, ephemeral Ed25519 key, pre_exec: setsid -> TIOCSCTTY -> setns -> drop_privileges -> harden_child_process -> sandbox::linux::enforce(prepared) [Linux] / sandbox::apply [non-Linux]; 
sandbox::linux::prepare() runs before fork ## Policy Reload Details - Poll loop: `run_policy_poll_loop()` in lib.rs, spawned after child process, gRPC mode only diff --git a/.claude/agents/principal-engineer-reviewer.md b/.claude/agents/principal-engineer-reviewer.md index ae7e49ea2..8badf491e 100644 --- a/.claude/agents/principal-engineer-reviewer.md +++ b/.claude/agents/principal-engineer-reviewer.md @@ -146,6 +146,36 @@ applies to every PR — use judgment. - **Supply chain:** Do new dependencies introduce known vulnerabilities or unmaintained transitive dependencies? +### Linux Security Module (LSM) compatibility + +OpenShell runs on hosts with SELinux or AppArmor in enforcing mode. +Review changes that interact with the `/proc` filesystem, process +identity, binary execution, or inter-process visibility for +LSM-related issues: + +- **`/proc//exe` across domain boundaries:** On SELinux-enforcing + hosts, readlink on `/proc//exe` returns ENOENT (not EACCES) when + the target process has a different SELinux label than the caller. + This affects any code that resolves binary identity after fork+exec + into a differently-labeled binary (e.g., system binaries under + `bin_t` vs. build artifacts under `user_home_t`). + +- **Tests that fork+exec into system binaries:** Tests that fork a child + and exec into `/bin/sleep`, `/usr/bin/cat`, or similar will fail on + SELinux-enforcing hosts because the child transitions to a different + domain, making its `/proc` entries unreadable to the parent. Flag + these tests and recommend either using a same-label helper binary or + skipping on enforcing hosts with a TODO. + +- **File labeling and Landlock interaction:** New files created in + non-standard paths may inherit unexpected SELinux labels. Verify that + Landlock and SELinux policies do not conflict. + +- **Socket and IPC visibility:** SELinux can restrict `/proc//fd` + and `/proc//net` visibility across domain boundaries. 
Code that + scans these paths for socket ownership should handle access failures + gracefully. + ## Principles - Don't nitpick style unless it harms readability. Trust `rustfmt` and the diff --git a/.github/actions/pr-gate/action.yml b/.github/actions/pr-gate/action.yml index 5d4b9c183..0c55bf120 100644 --- a/.github/actions/pr-gate/action.yml +++ b/.github/actions/pr-gate/action.yml @@ -16,6 +16,9 @@ outputs: should_run: description: "true if the workflow should proceed, false otherwise" value: ${{ steps.gate.outputs.should_run }} + labels_json: + description: "JSON array of PR label names for push-triggered mirror runs, or [] otherwise" + value: ${{ steps.gate.outputs.labels_json }} runs: using: composite @@ -35,20 +38,23 @@ runs: REQUIRED_LABEL: ${{ inputs.required_label }} run: | if [ "$EVENT_NAME" != "push" ]; then + echo "labels_json=[]" >> "$GITHUB_OUTPUT" echo "should_run=true" >> "$GITHUB_OUTPUT" exit 0 fi if [ "$GET_PR_INFO_OUTCOME" != "success" ]; then + echo "labels_json=[]" >> "$GITHUB_OUTPUT" echo "should_run=false" >> "$GITHUB_OUTPUT" exit 0 fi head_sha="$(jq -r '.head.sha' <<< "$PR_INFO")" + labels_json="$(jq -c '[.labels[].name]' <<< "$PR_INFO")" if [ -z "$REQUIRED_LABEL" ]; then has_label=true else - has_label="$(jq -r --arg L "$REQUIRED_LABEL" '[.labels[].name] | index($L) != null' <<< "$PR_INFO")" + has_label="$(jq -r --arg L "$REQUIRED_LABEL" 'index($L) != null' <<< "$labels_json")" fi # Only trust copied pull-request/* pushes that still match the PR head @@ -59,4 +65,5 @@ runs: should_run=false fi + echo "labels_json=$labels_json" >> "$GITHUB_OUTPUT" echo "should_run=$should_run" >> "$GITHUB_OUTPUT" diff --git a/.github/actions/pr-merge-base/action.yml b/.github/actions/pr-merge-base/action.yml new file mode 100644 index 000000000..a702e9c0b --- /dev/null +++ b/.github/actions/pr-merge-base/action.yml @@ -0,0 +1,37 @@ +name: PR Merge Base +description: Resolve and fetch the merge-base commit needed to diff a copy-pr-bot pull-request/ push 
against the PR base branch. + +inputs: + gh_token: + description: GitHub token for PR and compare API calls. + required: true + +outputs: + base_sha: + description: Merge-base commit SHA for pull-request/ refs, or empty for other refs. + value: ${{ steps.merge-base.outputs.base_sha }} + +runs: + using: composite + steps: + - id: merge-base + shell: bash + env: + GH_TOKEN: ${{ inputs.gh_token }} + GH_REPO: ${{ github.repository }} + REF_NAME: ${{ github.ref_name }} + GITHUB_SHA_VALUE: ${{ github.sha }} + run: | + set -euo pipefail + + if [[ "$REF_NAME" =~ ^pull-request/([0-9]+)$ ]]; then + pr_number="${BASH_REMATCH[1]}" + base_ref=$(gh pr view "$pr_number" --repo "$GH_REPO" --json baseRefName -q '.baseRefName') + # The mirrored branch is a push ref, so changed-files needs the true merge-base + # to diff the PR head against its base branch. + base_sha=$(gh api "repos/$GH_REPO/compare/$base_ref...$GITHUB_SHA_VALUE" --jq '.merge_base_commit.sha') + git fetch --no-tags --depth=1 origin "$base_sha" + echo "base_sha=$base_sha" >> "$GITHUB_OUTPUT" + else + echo "base_sha=" >> "$GITHUB_OUTPUT" + fi diff --git a/.github/actions/setup-buildx/action.yml b/.github/actions/setup-buildx/action.yml index 41210035f..3e1d54521 100644 --- a/.github/actions/setup-buildx/action.yml +++ b/.github/actions/setup-buildx/action.yml @@ -1,23 +1,11 @@ name: Setup Docker Buildx description: > - Create a Docker Buildx builder. Two modes: - * driver=remote (default) — multi-arch builder against in-cluster BuildKit - pods. Requires EKS connectivity. Behaviour unchanged from prior versions. - * driver=local — single-node buildx on the local docker-container driver. - Pair with cache-to/cache-from=type=gha on build steps for persistence. - Works on nv-gha-runners; no EKS needed. + Create a Docker Buildx builder on the local docker-container driver. Pair + with cache-to/cache-from=type=gha on build steps for persistence. Works on + nv-gha-runners; no EKS BuildKit service is required. 
   Cleanup is automatic when the job finishes (docker/setup-buildx-action default).
 
 inputs:
-  driver:
-    description: "buildx driver: 'remote' or 'local'"
-    default: remote
-  amd64-endpoint:
-    description: BuildKit endpoint for linux/amd64 (remote driver only)
-    default: tcp://buildkit-amd64.buildkit:1234
-  arm64-endpoint:
-    description: BuildKit endpoint for linux/arm64 (remote driver only)
-    default: tcp://buildkit-arm64.buildkit:1234
   name:
     description: Builder instance name
     default: openshell
@@ -34,21 +22,7 @@ inputs:
 runs:
   using: composite
   steps:
-    - name: Set up Docker Buildx (remote)
-      if: inputs.driver == 'remote'
-      uses: docker/setup-buildx-action@v3
-      with:
-        name: ${{ inputs.name }}
-        driver: remote
-        endpoint: ${{ inputs.amd64-endpoint }}
-        platforms: linux/amd64
-        append: |
-          - endpoint: ${{ inputs.arm64-endpoint }}
-            platforms: linux/arm64
-        buildkitd-config: ${{ inputs.buildkitd-config }}
-
     - name: Set up Docker Buildx (local)
-      if: inputs.driver == 'local'
       uses: docker/setup-buildx-action@v3
       with:
         name: ${{ inputs.name }}
diff --git a/.github/workflows/branch-checks.yml b/.github/workflows/branch-checks.yml
index abbcef423..eaa4178a6 100644
--- a/.github/workflows/branch-checks.yml
+++ b/.github/workflows/branch-checks.yml
@@ -10,7 +10,6 @@ env:
   CARGO_TERM_COLOR: always
   CARGO_INCREMENTAL: "0"
   MISE_GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-  SCCACHE_GHA_ENABLED: "true"
 
 permissions:
   contents: read
@@ -30,7 +29,7 @@ jobs:
     outputs:
       should_run: ${{ steps.gate.outputs.should_run }}
     steps:
-      - uses: actions/checkout@v6
+      - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6
 
       - id: gate
         uses: ./.github/actions/pr-gate
@@ -46,7 +45,7 @@ jobs:
         username: ${{ github.actor }}
         password: ${{ secrets.GITHUB_TOKEN }}
     steps:
-      - uses: actions/checkout@v6
+      - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6
 
       - name: Mark workspace as safe for git
         run: git config --global --add safe.directory "$GITHUB_WORKSPACE"
@@ -70,7 +69,7 @@ jobs:
         username: ${{ github.actor }}
         password: ${{ secrets.GITHUB_TOKEN }}
     steps:
-      - uses: actions/checkout@v6
+      - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6
 
       - name: Install tools
         run: mise install --locked
@@ -88,6 +87,7 @@ jobs:
         runner: [linux-amd64-cpu8, linux-arm64-cpu8]
     runs-on: ${{ matrix.runner }}
     env:
+      SCCACHE_GHA_ENABLED: "true"
       SCCACHE_GHA_VERSION: branch-checks-rust-${{ matrix.runner }}
     container:
       image: ghcr.io/nvidia/openshell/ci:latest
@@ -95,22 +95,23 @@ jobs:
         username: ${{ github.actor }}
         password: ${{ secrets.GITHUB_TOKEN }}
     steps:
-      - uses: actions/checkout@v6
-
-      - name: Install tools
-        run: mise install --locked
+      - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6
 
       - name: Configure GHA sccache backend
         uses: mozilla-actions/sccache-action@9e7fa8a12102821edf02ca5dbea1acd0f89a2696 # v0.0.10
 
+      - name: Install tools
+        run: mise install --locked
+
       - name: Cache Rust target and registry
         uses: Swatinem/rust-cache@c19371144df3bb44fab255c43d04cbc2ab54d1c4 # v2
         with:
-          # Separate caches for clippy (check-like) vs test (full build)
-          # so they don't thrash each other's artifacts
+          # Keep branch-check caches partitioned by runner architecture; lint
+          # and test intentionally share the same job-local target directory.
           shared-key: rust-checks-${{ matrix.runner }}
-          # Cache the sccache directory too
-          cache-directories: .cache/sccache
+          # Preserve compiled artifacts from failed lint/test runs so the next
+          # push to the same PR branch does not start from a cold cache.
+          cache-on-failure: "true"
 
       - name: Format
         run: mise run rust:format:check
@@ -128,7 +129,7 @@ jobs:
           stats_bin="${SCCACHE_PATH:-sccache}"
           "$stats_bin" --show-stats
           status=$?
- if [[ $status -ne 0 ]]; then + if [ "$status" -ne 0 ]; then echo "::warning::sccache stats unavailable (exit $status)" fi exit 0 @@ -148,7 +149,7 @@ jobs: username: ${{ github.actor }} password: ${{ secrets.GITHUB_TOKEN }} steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 - name: Install tools run: mise install --locked @@ -173,7 +174,7 @@ jobs: username: ${{ github.actor }} password: ${{ secrets.GITHUB_TOKEN }} steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 - name: Install tools run: mise install --locked diff --git a/.github/workflows/branch-docs.yml b/.github/workflows/branch-docs.yml index 1368bc775..3b2a4099e 100644 --- a/.github/workflows/branch-docs.yml +++ b/.github/workflows/branch-docs.yml @@ -19,7 +19,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 - name: Check Fern preview availability id: fern-preview @@ -34,7 +34,7 @@ jobs: fi - name: Setup Node.js - uses: actions/setup-node@v6 + uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e # v6 with: node-version: "24" diff --git a/.github/workflows/branch-e2e.yml b/.github/workflows/branch-e2e.yml index 3d8dd5928..de8bd5551 100644 --- a/.github/workflows/branch-e2e.yml +++ b/.github/workflows/branch-e2e.yml @@ -8,6 +8,10 @@ on: permissions: {} +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + jobs: pr_metadata: name: Resolve PR metadata @@ -17,38 +21,64 @@ jobs: pull-requests: read outputs: should_run: ${{ steps.gate.outputs.should_run }} + run_core_e2e: ${{ steps.labels.outputs.run_core_e2e }} + run_gpu_e2e: ${{ steps.labels.outputs.run_gpu_e2e }} + run_any_e2e: ${{ steps.labels.outputs.run_any_e2e }} steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 - 
id: gate uses: ./.github/actions/pr-gate - with: - required_label: test:e2e + - id: labels + if: steps.gate.outputs.should_run == 'true' + env: + EVENT_NAME: ${{ github.event_name }} + LABELS_JSON: ${{ steps.gate.outputs.labels_json }} + shell: bash + run: | + set -euo pipefail + if [ "$EVENT_NAME" != "push" ]; then + run_core_e2e=true + run_gpu_e2e=true + else + run_core_e2e="$(jq -r 'index("test:e2e") != null' <<< "$LABELS_JSON")" + run_gpu_e2e="$(jq -r 'index("test:e2e-gpu") != null' <<< "$LABELS_JSON")" + fi + if [ "$run_core_e2e" = "true" ] || [ "$run_gpu_e2e" = "true" ]; then + run_any_e2e=true + else + run_any_e2e=false + fi + { + echo "run_core_e2e=$run_core_e2e" + echo "run_gpu_e2e=$run_gpu_e2e" + echo "run_any_e2e=$run_any_e2e" + } >> "$GITHUB_OUTPUT" build-gateway: needs: [pr_metadata] - if: needs.pr_metadata.outputs.should_run == 'true' + if: needs.pr_metadata.outputs.should_run == 'true' && needs.pr_metadata.outputs.run_core_e2e == 'true' permissions: contents: read packages: write uses: ./.github/workflows/docker-build.yml with: component: gateway - platform: linux/arm64 + image-tag: ${{ github.sha }} build-supervisor: needs: [pr_metadata] - if: needs.pr_metadata.outputs.should_run == 'true' + if: needs.pr_metadata.outputs.should_run == 'true' && needs.pr_metadata.outputs.run_any_e2e == 'true' permissions: contents: read packages: write uses: ./.github/workflows/docker-build.yml with: component: supervisor - platform: linux/arm64 + image-tag: ${{ github.sha }} e2e: needs: [pr_metadata, build-gateway, build-supervisor] - if: needs.pr_metadata.outputs.should_run == 'true' + if: needs.pr_metadata.outputs.should_run == 'true' && needs.pr_metadata.outputs.run_core_e2e == 'true' permissions: contents: read packages: read @@ -56,3 +86,77 @@ jobs: with: image-tag: ${{ github.sha }} runner: linux-arm64-cpu8 + + gpu-e2e: + needs: [pr_metadata, build-supervisor] + if: needs.pr_metadata.outputs.should_run == 'true' && needs.pr_metadata.outputs.run_gpu_e2e == 
'true' + permissions: + contents: read + packages: read + uses: ./.github/workflows/e2e-gpu-test.yaml + with: + image-tag: ${{ github.sha }} + + kubernetes-e2e: + needs: [pr_metadata, build-gateway, build-supervisor] + if: needs.pr_metadata.outputs.should_run == 'true' && needs.pr_metadata.outputs.run_core_e2e == 'true' + permissions: + contents: read + packages: read + uses: ./.github/workflows/e2e-kubernetes-test.yml + with: + image-tag: ${{ github.sha }} + + core-e2e-result: + name: Core E2E result + needs: [pr_metadata, build-gateway, build-supervisor, e2e, kubernetes-e2e] + if: always() && needs.pr_metadata.outputs.should_run == 'true' && needs.pr_metadata.outputs.run_core_e2e == 'true' + runs-on: ubuntu-latest + steps: + - name: Verify core E2E jobs + env: + BUILD_GATEWAY_RESULT: ${{ needs.build-gateway.result }} + BUILD_SUPERVISOR_RESULT: ${{ needs.build-supervisor.result }} + E2E_RESULT: ${{ needs.e2e.result }} + KUBERNETES_E2E_RESULT: ${{ needs.kubernetes-e2e.result }} + run: | + set -euo pipefail + failed=0 + for item in \ + "build-gateway:$BUILD_GATEWAY_RESULT" \ + "build-supervisor:$BUILD_SUPERVISOR_RESULT" \ + "e2e:$E2E_RESULT" \ + "kubernetes-e2e:$KUBERNETES_E2E_RESULT"; do + name="${item%%:*}" + result="${item#*:}" + if [ "$result" != "success" ]; then + echo "::error::$name concluded $result" + failed=1 + fi + done + exit "$failed" + + gpu-e2e-result: + name: GPU E2E result + needs: [pr_metadata, build-supervisor, gpu-e2e] + if: always() && needs.pr_metadata.outputs.should_run == 'true' && needs.pr_metadata.outputs.run_gpu_e2e == 'true' + runs-on: ubuntu-latest + steps: + - name: Verify GPU E2E jobs + env: + BUILD_SUPERVISOR_RESULT: ${{ needs.build-supervisor.result }} + GPU_E2E_RESULT: ${{ needs.gpu-e2e.result }} + run: | + set -euo pipefail + failed=0 + for item in \ + "build-supervisor:$BUILD_SUPERVISOR_RESULT" \ + "gpu-e2e:$GPU_E2E_RESULT"; do + name="${item%%:*}" + result="${item#*:}" + if [ "$result" != "success" ]; then + echo "::error::$name 
concluded $result" + failed=1 + fi + done + exit "$failed" diff --git a/.github/workflows/ci-image.yml b/.github/workflows/ci-image.yml index db98022d5..d8d3095f0 100644 --- a/.github/workflows/ci-image.yml +++ b/.github/workflows/ci-image.yml @@ -35,10 +35,10 @@ jobs: runs-on: ${{ matrix.runner }} timeout-minutes: 60 steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 - name: Log in to GitHub Container Registry - uses: docker/login-action@v4 + uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4 with: registry: ${{ env.REGISTRY }} username: ${{ github.actor }} @@ -56,7 +56,6 @@ jobs: - name: Set up Docker Buildx uses: ./.github/actions/setup-buildx with: - driver: local buildkitd-config: ${{ steps.buildkit.outputs.config }} - name: Build and push CI image @@ -92,7 +91,7 @@ jobs: timeout-minutes: 10 steps: - name: Log in to GitHub Container Registry - uses: docker/login-action@v4 + uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4 with: registry: ${{ env.REGISTRY }} username: ${{ github.actor }} diff --git a/.github/workflows/deb-package.yml b/.github/workflows/deb-package.yml index 72628a23a..47a721de9 100644 --- a/.github/workflows/deb-package.yml +++ b/.github/workflows/deb-package.yml @@ -42,24 +42,24 @@ jobs: username: ${{ github.actor }} password: ${{ secrets.GITHUB_TOKEN }} steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 with: ref: ${{ inputs['checkout-ref'] }} - name: Download CLI artifact - uses: actions/download-artifact@v4 + uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1 with: name: cli-linux-${{ matrix.arch }} path: package-input/ - name: Download gateway artifact - uses: actions/download-artifact@v4 + uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1 with: name: gateway-binary-linux-${{ matrix.arch }} path: package-input/ - 
name: Download VM driver artifact - uses: actions/download-artifact@v4 + uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1 with: name: driver-vm-linux-${{ matrix.arch }} path: package-input/ @@ -85,7 +85,7 @@ jobs: tasks/scripts/package-deb.sh - name: Upload Debian package artifact - uses: actions/upload-artifact@v7 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7 with: name: deb-linux-${{ matrix.arch }} path: artifacts/*.deb diff --git a/.github/workflows/docker-build.yml b/.github/workflows/docker-build.yml index 6c3807858..3f98e7b6b 100644 --- a/.github/workflows/docker-build.yml +++ b/.github/workflows/docker-build.yml @@ -42,6 +42,11 @@ on: required: false type: string default: "" + publish-manifest: + description: "Push the bare-SHA manifest. Set false for single-arch branch workflows." + required: false + type: boolean + default: true env: MISE_GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} @@ -150,11 +155,12 @@ jobs: strategy: fail-fast: false matrix: ${{ fromJSON(needs.resolve.outputs.matrix) }} - uses: ./.github/workflows/shadow-rust-native-build.yml + uses: ./.github/workflows/rust-native-build.yml with: component: ${{ needs.resolve.outputs.binary_component }} arch: ${{ matrix.arch }} cargo-version: ${{ inputs['cargo-version'] }} + image-tag: ${{ needs.resolve.outputs.image_tag_base }} checkout-ref: ${{ inputs['checkout-ref'] }} features: openshell-core/dev-settings artifact-name: ${{ needs.resolve.outputs.artifact_prefix }}-linux-${{ matrix.arch }} @@ -180,12 +186,12 @@ jobs: # inside the container so setup-buildx can read it. 
- /etc/buildkit:/etc/buildkit:ro env: - IMAGE_TAG: ${{ needs.resolve.outputs.platform_count == '1' && needs.resolve.outputs.image_tag_base || format('{0}-{1}', needs.resolve.outputs.image_tag_base, matrix.arch) }} + IMAGE_TAG: ${{ format('{0}-{1}', needs.resolve.outputs.image_tag_base, matrix.arch) }} IMAGE_REGISTRY: ghcr.io/nvidia/openshell DOCKER_PUSH: ${{ inputs.push && '1' || '0' }} DOCKER_PLATFORM: ${{ matrix.platform }} steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 with: ref: ${{ inputs['checkout-ref'] || github.sha }} fetch-depth: 0 @@ -203,11 +209,10 @@ jobs: - name: Set up buildx (local driver) uses: ./.github/actions/setup-buildx with: - driver: local buildkitd-config: /etc/buildkit/buildkitd.toml - name: Download Rust binary artifact - uses: actions/download-artifact@v4 + uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1 with: name: ${{ needs.resolve.outputs.artifact_prefix }}-linux-${{ matrix.arch }} path: prebuilt-rust-binary @@ -238,7 +243,6 @@ jobs: --cache-to "type=gha,mode=max,scope=${{ inputs.component }}-${{ matrix.arch }}" - name: Smoke check ${{ inputs.component }} image - if: ${{ !inputs.push }} run: | set -euo pipefail image="${IMAGE_REGISTRY}/${{ inputs.component }}:${IMAGE_TAG}" @@ -249,7 +253,7 @@ jobs: grep -q '^openshell-gateway ' <<<"$output" ;; supervisor) - output="$(docker run --rm --platform "${{ matrix.platform }}" "$image" --version)" + output="$(docker run --rm --platform "${{ matrix.platform }}" --entrypoint /openshell-sandbox "$image" --version)" echo "$output" grep -q '^openshell-sandbox ' <<<"$output" ;; @@ -258,7 +262,7 @@ jobs: merge: name: Merge ${{ inputs.component }} manifest needs: [resolve, build] - if: ${{ inputs.push && needs.resolve.outputs.platform_count != '1' }} + if: ${{ inputs.push && inputs['publish-manifest'] }} runs-on: linux-amd64-cpu8 timeout-minutes: 10 container: diff --git 
a/.github/workflows/driver-vm-linux.yml b/.github/workflows/driver-vm-linux.yml index 42632c5d1..6d42217c3 100644 --- a/.github/workflows/driver-vm-linux.yml +++ b/.github/workflows/driver-vm-linux.yml @@ -32,7 +32,7 @@ jobs: username: ${{ github.actor }} password: ${{ secrets.GITHUB_TOKEN }} steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 with: ref: ${{ inputs['checkout-ref'] }} @@ -66,7 +66,7 @@ jobs: done - name: Upload runtime artifacts - uses: actions/upload-artifact@v7 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7 with: name: vm-driver-kernel-runtime-tarballs path: runtime-artifacts/vm-runtime-*.tar.zst @@ -81,11 +81,13 @@ jobs: - arch: arm64 runner: linux-arm64-cpu8 target: aarch64-unknown-linux-gnu + zig_target: aarch64-unknown-linux-gnu.2.31 platform: linux-aarch64 guest_arch: aarch64 - arch: amd64 runner: linux-amd64-cpu8 target: x86_64-unknown-linux-gnu + zig_target: x86_64-unknown-linux-gnu.2.31 platform: linux-x86_64 guest_arch: x86_64 runs-on: ${{ matrix.runner }} @@ -100,7 +102,7 @@ jobs: MISE_GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} OPENSHELL_IMAGE_TAG: ${{ inputs['image-tag'] }} steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 with: ref: ${{ inputs['checkout-ref'] }} fetch-depth: 0 @@ -125,7 +127,7 @@ jobs: run: apt-get update && apt-get install -y --no-install-recommends zstd && rm -rf /var/lib/apt/lists/* - name: Download kernel runtime tarball - uses: actions/download-artifact@v4 + uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1 with: name: vm-driver-kernel-runtime-tarballs path: runtime-download/ @@ -134,24 +136,10 @@ jobs: run: | set -euo pipefail COMPRESSED_DIR="${PWD}/target/vm-runtime-compressed" - mkdir -p "$COMPRESSED_DIR" - - EXTRACT_DIR=$(mktemp -d) - zstd -d "runtime-download/vm-runtime-${{ matrix.platform }}.tar.zst" --stdout \ - | tar -xf - -C 
"$EXTRACT_DIR" - - echo "Extracted runtime files:" - ls -lah "$EXTRACT_DIR" - - for file in "$EXTRACT_DIR"/*; do - [ -f "$file" ] || continue - name=$(basename "$file") - [ "$name" = "provenance.json" ] && continue - zstd -19 -f -q -T0 -o "${COMPRESSED_DIR}/${name}.zst" "$file" - done - - echo "Staged compressed runtime artifacts:" - ls -lah "$COMPRESSED_DIR" + VM_RUNTIME_TARBALL="${PWD}/runtime-download/vm-runtime-${{ matrix.platform }}.tar.zst" \ + VM_RUNTIME_PLATFORM="${{ matrix.platform }}" \ + OPENSHELL_VM_RUNTIME_COMPRESSED_DIR="$COMPRESSED_DIR" \ + tasks/scripts/vm/compress-vm-runtime.sh - name: Build bundled supervisor run: | @@ -162,7 +150,7 @@ jobs: - name: Verify embedded driver inputs run: | set -euo pipefail - for file in libkrun.so.zst libkrunfw.so.5.zst gvproxy.zst openshell-sandbox.zst; do + for file in libkrun.so.zst libkrunfw.so.5.zst gvproxy.zst umoci.zst openshell-sandbox.zst; do test -s "target/vm-runtime-compressed/${file}" done @@ -177,19 +165,25 @@ jobs: set -euo pipefail sed -i -E '/^\[workspace\.package\]/,/^\[/{s/^version[[:space:]]*=[[:space:]]*".*"/version = "'"${{ inputs['cargo-version'] }}"'"/}' Cargo.toml - - name: Build openshell-driver-vm + - name: Build openshell-driver-vm with glibc 2.31 floor run: | set -euo pipefail + mise x -- rustup target add ${{ matrix.target }} OPENSHELL_VM_RUNTIME_COMPRESSED_DIR="${PWD}/target/vm-runtime-compressed" \ - mise x -- cargo build --release -p openshell-driver-vm + mise x -- cargo zigbuild --release --target ${{ matrix.zig_target }} -p openshell-driver-vm --bin openshell-driver-vm + mkdir -p artifacts/bin + install -m 0755 target/${{ matrix.target }}/release/openshell-driver-vm artifacts/bin/openshell-driver-vm - name: Verify packaged binary run: | set -euo pipefail - OUTPUT="$(target/release/openshell-driver-vm --version)" + OUTPUT="$(artifacts/bin/openshell-driver-vm --version)" echo "$OUTPUT" grep -q '^openshell-driver-vm ' <<<"$OUTPUT" + - name: Verify glibc symbol floor + run: 
tasks/scripts/verify-glibc-symbols.sh 2.31 artifacts/bin/openshell-driver-vm + - name: sccache stats if: always() run: mise x -- sccache --show-stats @@ -199,10 +193,10 @@ jobs: set -euo pipefail mkdir -p artifacts tar -czf "artifacts/openshell-driver-vm-${{ matrix.target }}.tar.gz" \ - -C target/release openshell-driver-vm + -C artifacts/bin openshell-driver-vm - name: Upload artifact - uses: actions/upload-artifact@v7 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7 with: name: driver-vm-linux-${{ matrix.arch }} path: artifacts/*.tar.gz diff --git a/.github/workflows/driver-vm-macos.yml b/.github/workflows/driver-vm-macos.yml index 5b2bac927..a563c972c 100644 --- a/.github/workflows/driver-vm-macos.yml +++ b/.github/workflows/driver-vm-macos.yml @@ -32,7 +32,7 @@ jobs: username: ${{ github.actor }} password: ${{ secrets.GITHUB_TOKEN }} steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 with: ref: ${{ inputs['checkout-ref'] }} @@ -60,7 +60,7 @@ jobs: run: test -f runtime-artifacts/vm-runtime-darwin-aarch64.tar.zst - name: Upload runtime artifact - uses: actions/upload-artifact@v7 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7 with: name: vm-driver-macos-kernel-runtime-tarball path: runtime-artifacts/vm-runtime-darwin-aarch64.tar.zst @@ -79,7 +79,7 @@ jobs: MISE_GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} OPENSHELL_IMAGE_TAG: ${{ inputs['image-tag'] }} steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 with: ref: ${{ inputs['checkout-ref'] }} fetch-depth: 0 @@ -113,7 +113,7 @@ jobs: run: mise x -- sccache --show-stats - name: Upload supervisor bundle - uses: actions/upload-artifact@v7 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7 with: name: driver-vm-supervisor-arm64 path: target/vm-runtime-compressed/openshell-sandbox.zst @@ -135,7 +135,7 @@ jobs: env: 
MISE_GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 with: ref: ${{ inputs['checkout-ref'] }} fetch-depth: 0 @@ -151,14 +151,12 @@ jobs: - name: Set up Docker Buildx uses: ./.github/actions/setup-buildx - with: - driver: local - name: Install zstd run: apt-get update && apt-get install -y --no-install-recommends zstd && rm -rf /var/lib/apt/lists/* - name: Download kernel runtime tarball - uses: actions/download-artifact@v4 + uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1 with: name: vm-driver-macos-kernel-runtime-tarball path: runtime-download/ @@ -167,27 +165,13 @@ jobs: run: | set -euo pipefail COMPRESSED_DIR="${PWD}/target/vm-runtime-compressed-macos" - mkdir -p "$COMPRESSED_DIR" - - EXTRACT_DIR=$(mktemp -d) - zstd -d "runtime-download/vm-runtime-darwin-aarch64.tar.zst" --stdout \ - | tar -xf - -C "$EXTRACT_DIR" - - echo "Extracted darwin runtime files:" - ls -lah "$EXTRACT_DIR" - - for file in "$EXTRACT_DIR"/*; do - [ -f "$file" ] || continue - name=$(basename "$file") - [ "$name" = "provenance.json" ] && continue - zstd -19 -f -q -T0 -o "${COMPRESSED_DIR}/${name}.zst" "$file" - done - - echo "Staged macOS compressed runtime artifacts:" - ls -lah "$COMPRESSED_DIR" + VM_RUNTIME_TARBALL="${PWD}/runtime-download/vm-runtime-darwin-aarch64.tar.zst" \ + VM_RUNTIME_PLATFORM="darwin-aarch64" \ + OPENSHELL_VM_RUNTIME_COMPRESSED_DIR="$COMPRESSED_DIR" \ + tasks/scripts/vm/compress-vm-runtime.sh - name: Download bundled supervisor - uses: actions/download-artifact@v4 + uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1 with: name: driver-vm-supervisor-arm64 path: target/vm-runtime-compressed-macos/ @@ -201,7 +185,7 @@ jobs: - name: Verify embedded driver inputs run: | set -euo pipefail - for file in libkrun.dylib.zst libkrunfw.5.dylib.zst gvproxy.zst openshell-sandbox.zst; do + for file in 
libkrun.dylib.zst libkrunfw.5.dylib.zst gvproxy.zst umoci.zst openshell-sandbox.zst; do test -s "target/vm-runtime-compressed-macos/${file}" done @@ -230,7 +214,7 @@ jobs: ls -lh artifacts/ - name: Upload artifact - uses: actions/upload-artifact@v7 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7 with: name: driver-vm-macos path: artifacts/*.tar.gz diff --git a/.github/workflows/e2e-gate-check.yml b/.github/workflows/e2e-gate-check.yml deleted file mode 100644 index 5065663c1..000000000 --- a/.github/workflows/e2e-gate-check.yml +++ /dev/null @@ -1,113 +0,0 @@ -name: E2E Gate Check - -# Reusable gate that enforces: when `required_label` is present on a PR, -# `workflow_file` must have completed successfully for the PR head SHA. -# -# Callers wire their own triggers (`pull_request` + `workflow_run` for the -# workflow this gate guards) and pass in the label and workflow filename. - -on: - workflow_call: - inputs: - required_label: - description: PR label that makes the gated workflow mandatory. - required: true - type: string - workflow_file: - description: Filename of the workflow whose run must have succeeded (e.g. "branch-e2e.yml"). - required: true - type: string - -permissions: {} - -jobs: - check: - name: Enforce ${{ inputs.required_label }} runs when labeled - runs-on: ubuntu-latest - permissions: - contents: read - pull-requests: read - actions: read - steps: - - name: Resolve PR context - id: pr - env: - GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} - GH_REPO: ${{ github.repository }} - EVENT_NAME: ${{ github.event_name }} - PR_HEAD_SHA_FROM_EVENT: ${{ github.event.pull_request.head.sha }} - PR_LABELS_FROM_EVENT: ${{ toJSON(github.event.pull_request.labels.*.name) }} - WORKFLOW_RUN_HEAD_SHA: ${{ github.event.workflow_run.head_sha }} - shell: bash - run: | - set -euo pipefail - if [ "$EVENT_NAME" = "pull_request" ]; then - head_sha="$PR_HEAD_SHA_FROM_EVENT" - labels_json=$(jq -c . 
<<< "$PR_LABELS_FROM_EVENT") - else - head_sha="$WORKFLOW_RUN_HEAD_SHA" - pr=$(gh api "repos/$GH_REPO/commits/$head_sha/pulls" --jq '.[0] // empty') - if [ -z "$pr" ]; then - echo "No PR associated with $head_sha; gate is a no-op." - echo "skip=true" >> "$GITHUB_OUTPUT" - exit 0 - fi - labels_json=$(jq -c '[.labels[].name]' <<< "$pr") - fi - echo "head_sha=$head_sha" >> "$GITHUB_OUTPUT" - echo "labels_json=$labels_json" >> "$GITHUB_OUTPUT" - - - name: Evaluate gate - if: steps.pr.outputs.skip != 'true' - env: - GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} - GH_REPO: ${{ github.repository }} - HEAD_SHA: ${{ steps.pr.outputs.head_sha }} - LABELS_JSON: ${{ steps.pr.outputs.labels_json }} - REQUIRED_LABEL: ${{ inputs.required_label }} - WORKFLOW_FILE: ${{ inputs.workflow_file }} - shell: bash - run: | - set -euo pipefail - - has_label=$(jq -r --arg L "$REQUIRED_LABEL" 'any(.[]; . == $L)' <<< "$LABELS_JSON") - if [ "$has_label" != "true" ]; then - echo "::notice::$REQUIRED_LABEL not applied; gate passes." - exit 0 - fi - - runs=$(gh api "repos/$GH_REPO/actions/workflows/$WORKFLOW_FILE/runs?head_sha=$HEAD_SHA&event=push" --jq '.workflow_runs') - latest=$(jq -c 'sort_by(.created_at) | reverse | .[0] // empty' <<< "$runs") - - if [ -z "$latest" ]; then - echo "::error::$REQUIRED_LABEL is applied but $WORKFLOW_FILE has not run for $HEAD_SHA. Wait for copy-pr-bot to mirror the PR, or re-run the gate once the workflow completes." - exit 1 - fi - - run_id=$(jq -r '.id' <<< "$latest") - status=$(jq -r '.status' <<< "$latest") - conclusion=$(jq -r '.conclusion' <<< "$latest") - - if [ "$status" != "completed" ]; then - echo "::error::$WORKFLOW_FILE is $status for $HEAD_SHA. This gate will re-evaluate on completion." - exit 1 - fi - - if [ "$conclusion" != "success" ]; then - echo "::error::$WORKFLOW_FILE concluded as $conclusion for $HEAD_SHA." 
-            exit 1
-          fi
-
-          # Top-level success isn't enough: if `pr_metadata` gated downstream
-          # jobs out (label wasn't set at run time), only the gate job itself
-          # concludes `success` and the workflow still reports `success`.
-          # Require at least one non-gate job to have succeeded as proof the
-          # label was present when the workflow ran.
-          real_success=$(gh api "repos/$GH_REPO/actions/runs/$run_id/jobs" --jq '[.jobs[] | select(.conclusion == "success" and .name != "Resolve PR metadata")] | length')
-          if [ "$real_success" -lt 1 ]; then
-            echo "::error::$WORKFLOW_FILE run $run_id only ran the metadata gate — $REQUIRED_LABEL was not set when the workflow last executed. Re-run $WORKFLOW_FILE so the gate re-evaluates with the label present."
-            exit 1
-          fi
-
-          echo "$WORKFLOW_FILE run $run_id executed and succeeded for $HEAD_SHA ($real_success non-gate job(s) passed)."
-          exit 0
diff --git a/.github/workflows/e2e-gate.yml b/.github/workflows/e2e-gate.yml
deleted file mode 100644
index 67959fa8d..000000000
--- a/.github/workflows/e2e-gate.yml
+++ /dev/null
@@ -1,67 +0,0 @@
-name: E2E Gate
-
-on:
-  pull_request:
-    types: [opened, synchronize, reopened, labeled, unlabeled, ready_for_review]
-  workflow_run:
-    workflows: ["Branch E2E Checks", "GPU Test"]
-    types: [completed]
-
-permissions: {}
-
-jobs:
-  # On PR events, actually evaluate the gate — these runs post their check
-  # result to the PR's head SHA, which is what branch protection sees.
-  e2e:
-    name: E2E
-    if: github.event_name == 'pull_request'
-    permissions:
-      contents: read
-      pull-requests: read
-      actions: read
-    uses: ./.github/workflows/e2e-gate-check.yml
-    with:
-      required_label: test:e2e
-      workflow_file: branch-e2e.yml
-
-  gpu:
-    name: GPU E2E
-    if: github.event_name == 'pull_request'
-    permissions:
-      contents: read
-      pull-requests: read
-      actions: read
-    uses: ./.github/workflows/e2e-gate-check.yml
-    with:
-      required_label: test:e2e-gpu
-      workflow_file: test-gpu.yml
-
-  # When the guarded workflow finishes, GitHub fires `workflow_run` in the
-  # default-branch context — any check posted from here would land on `main`,
-  # not on the PR. Instead, find the latest `pull_request`-triggered gate run
-  # for the same head SHA and ask the API to re-run it. The re-run replays the
-  # original event (labels, head SHA) so the check posts to the PR again, and
-  # this time the gate sees the successful upstream run.
-  rerun-on-completion:
-    name: Re-run gate in PR context
-    if: github.event_name == 'workflow_run'
-    runs-on: ubuntu-latest
-    permissions:
-      actions: write
-    steps:
-      - name: Rerun latest PR-context gate for this SHA
-        env:
-          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-          GH_REPO: ${{ github.repository }}
-          HEAD_SHA: ${{ github.event.workflow_run.head_sha }}
-        shell: bash
-        run: |
-          set -euo pipefail
-          run_id=$(gh api "repos/$GH_REPO/actions/workflows/e2e-gate.yml/runs?event=pull_request&head_sha=$HEAD_SHA" \
-            --jq '.workflow_runs | sort_by(.created_at) | reverse | .[0].id // empty')
-          if [ -z "$run_id" ]; then
-            echo "No pull_request-triggered E2E Gate run found for $HEAD_SHA — nothing to re-run."
-            exit 0
-          fi
-          echo "Re-running E2E Gate run $run_id so it re-evaluates and posts a fresh check on the PR."
- gh api --method POST "repos/$GH_REPO/actions/runs/$run_id/rerun" diff --git a/.github/workflows/e2e-gpu-test.yaml b/.github/workflows/e2e-gpu-test.yaml index f61c8c7ae..78cd7e4d1 100644 --- a/.github/workflows/e2e-gpu-test.yaml +++ b/.github/workflows/e2e-gpu-test.yaml @@ -14,7 +14,7 @@ permissions: jobs: e2e-gpu: - name: "E2E GPU (${{ matrix.name }})" + name: "E2E Docker GPU (${{ matrix.name }})" runs-on: ${{ matrix.runner }} continue-on-error: ${{ matrix.experimental }} timeout-minutes: 30 @@ -49,14 +49,21 @@ jobs: OPENSHELL_REGISTRY_USERNAME: ${{ github.actor }} OPENSHELL_REGISTRY_PASSWORD: ${{ secrets.GITHUB_TOKEN }} OPENSHELL_E2E_DOCKER_GPU: "1" + # NVIDIA-managed Ubuntu base used as the GPU probe target: it has the + # filesystem layout CDI injection expects (ldconfig, populated /usr/bin) + # which the distroless gateway runtime lacks. Consumed by the prereq + # probe below and by the e2e tests in e2e/rust/tests/gpu_device_selection.rs. + OPENSHELL_E2E_GPU_PROBE_IMAGE: "nvcr.io/nvidia/base/ubuntu:noble-20251013" steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 - name: Log in to GHCR run: echo "${{ secrets.GITHUB_TOKEN }}" | docker login ghcr.io -u "${{ github.actor }}" --password-stdin - - name: Install Python dependencies and generate protobuf stubs - run: uv sync --frozen && mise run --no-deps python:proto + - name: Check Docker GPU prerequisites + run: | + docker info --format '{{json .CDISpecDirs}}' + docker run --rm --device nvidia.com/gpu=all "${OPENSHELL_E2E_GPU_PROBE_IMAGE}" nvidia-smi -L - name: Run tests - run: mise run --no-deps --skip-deps e2e:python:gpu + run: mise run --no-deps --skip-deps e2e:docker:gpu diff --git a/.github/workflows/e2e-kubernetes-test.yml b/.github/workflows/e2e-kubernetes-test.yml new file mode 100644 index 000000000..c3d16a743 --- /dev/null +++ b/.github/workflows/e2e-kubernetes-test.yml @@ -0,0 +1,98 @@ +name: Kubernetes E2E Test + +on: + workflow_call: + 
inputs: + image-tag: + description: "Image tag to test (typically the commit SHA)" + required: true + type: string + runner: + description: "GitHub Actions runner label" + required: false + type: string + default: "linux-amd64-cpu8" + checkout-ref: + description: "Git ref to check out for test inputs (defaults to the workflow SHA)" + required: false + type: string + default: "" + +permissions: + contents: read + packages: read + +jobs: + e2e-kubernetes: + name: Kubernetes E2E (Rust smoke) + # Bare runner: running kind-in-container hits nested-Docker / kubeconfig + # complications. The runner has Docker; mise installs helm, kubectl, and + # the Rust toolchain. + runs-on: ${{ inputs.runner }} + timeout-minutes: 60 + permissions: + contents: read + packages: read + env: + MISE_GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + KIND_CLUSTER_NAME: kube-e2e-${{ github.run_id }} + steps: + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 + with: + ref: ${{ inputs['checkout-ref'] || github.sha }} + + - name: Install mise + run: | + curl https://mise.run | sh + echo "$HOME/.local/bin" >> "$GITHUB_PATH" + echo "$HOME/.local/share/mise/shims" >> "$GITHUB_PATH" + + - name: Install tools + run: mise install --locked + + # The openshell-policy crate transitively pulls in z3-sys, whose + # build script needs the z3 C/C++ headers and clang/bindgen to + # compile. The bare runner doesn't ship them; the CI container + # image used by other Rust e2e jobs does, but we can't run this job + # there (the runner's container handler injects its own --network + # bridge, which conflicts with the --network host we need so kind's + # API server is reachable from the test process). 
+ - name: Install z3 build deps + run: sudo apt-get update && sudo apt-get install -y --no-install-recommends libz3-dev clang + + - name: Log in to GHCR + run: echo "${{ secrets.GITHUB_TOKEN }}" | docker login ghcr.io -u "${{ github.actor }}" --password-stdin + + - name: Create kind cluster + uses: helm/kind-action@ef37e7f390d99f746eb8b610417061a60e82a6cc # v1.14.0 + with: + cluster_name: ${{ env.KIND_CLUSTER_NAME }} + wait: 120s + + # mise.toml sets KUBECONFIG="{{config_root}}/kubeconfig"; helm/kind-action + # writes to ~/.kube/config. Materialize the kind context at the mise path + # so `mise run e2e:kubernetes` (and the wrapper's `kubectl --context=...`) + # finds the kind cluster. + - name: Export kind kubeconfig to mise path + run: | + set -euo pipefail + kind get kubeconfig --name "$KIND_CLUSTER_NAME" > "$GITHUB_WORKSPACE/kubeconfig" + chmod 600 "$GITHUB_WORKSPACE/kubeconfig" + + - name: Load gateway and supervisor images into kind + run: | + set -euo pipefail + for component in gateway supervisor; do + image="ghcr.io/nvidia/openshell/${component}:${{ inputs.image-tag }}" + archive="${RUNNER_TEMP:-/tmp}/openshell-${component}-linux-amd64.tar" + docker pull --platform linux/amd64 "$image" + docker image save --platform linux/amd64 --output "$archive" "$image" + kind load image-archive "$archive" --name "$KIND_CLUSTER_NAME" + done + + - name: Run Kubernetes E2E (Rust smoke) + env: + OPENSHELL_E2E_KUBE_CONTEXT: kind-${{ env.KIND_CLUSTER_NAME }} + IMAGE_TAG: ${{ inputs.image-tag }} + OPENSHELL_REGISTRY: ghcr.io/nvidia/openshell + run: mise run --no-deps --skip-deps e2e:kubernetes diff --git a/.github/workflows/e2e-label-help.yml b/.github/workflows/e2e-label-help.yml index 2a61660d2..21f4397f7 100644 --- a/.github/workflows/e2e-label-help.yml +++ b/.github/workflows/e2e-label-help.yml @@ -1,6 +1,6 @@ name: E2E Label Help -# When a `test:e2e` / `test:e2e-gpu` label is applied, post a PR comment +# When an E2E label is applied, post a PR comment # telling the 
maintainer the next manual step. We don't dispatch the workflow # ourselves: a workflow_dispatch-triggered run does not surface in the PR's # Checks tab, so we'd lose in-progress visibility. Instead we point the @@ -37,9 +37,17 @@ jobs: run: | set -euo pipefail + workflow_file=branch-e2e.yml + workflow_name="Branch E2E Checks" case "$LABEL_NAME" in - test:e2e) workflow_file=branch-e2e.yml; workflow_name="Branch E2E Checks" ;; - test:e2e-gpu) workflow_file=test-gpu.yml; workflow_name="GPU Test" ;; + test:e2e) + suite_summary="the standard E2E suite" + build_summary="gateway and supervisor images" + ;; + test:e2e-gpu) + suite_summary="GPU E2E" + build_summary="supervisor image" + ;; *) echo "Unrecognized label $LABEL_NAME"; exit 1 ;; esac @@ -61,7 +69,7 @@ jobs: workflow_link="[$workflow_name](https://github.com/$GH_REPO/actions/workflows/$workflow_file)" instructions="Open $workflow_link, find the run for commit \`$short_pr\`, and click **Re-run all jobs** to execute with the label set." fi - body="Label \`$LABEL_NAME\` applied for \`$short_pr\`. $instructions The \`E2E Gate\` check on this PR will flip green automatically once the run finishes." + body="Label \`$LABEL_NAME\` applied for \`$short_pr\`. $instructions The run will execute $suite_summary after building the required $build_summary once. The matching required CI gate status on this PR will flip green automatically once the run finishes." 
fi gh pr comment "$PR_NUMBER" --body "$body" diff --git a/.github/workflows/e2e-test.yml b/.github/workflows/e2e-test.yml index db8010d0f..aabddee96 100644 --- a/.github/workflows/e2e-test.yml +++ b/.github/workflows/e2e-test.yml @@ -58,7 +58,7 @@ jobs: OPENSHELL_REGISTRY_USERNAME: ${{ github.actor }} OPENSHELL_REGISTRY_PASSWORD: ${{ secrets.GITHUB_TOKEN }} steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 with: ref: ${{ inputs['checkout-ref'] || github.sha }} diff --git a/.github/workflows/helm-lint.yml b/.github/workflows/helm-lint.yml index 8b7184133..cf0666bc0 100644 --- a/.github/workflows/helm-lint.yml +++ b/.github/workflows/helm-lint.yml @@ -7,8 +7,6 @@ on: push: branches: - "pull-request/[0-9]+" - paths: - - "deploy/helm/**" workflow_dispatch: env: @@ -32,15 +30,53 @@ jobs: outputs: should_run: ${{ steps.gate.outputs.should_run }} steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 - id: gate uses: ./.github/actions/pr-gate - helm-lint: - name: Helm Lint + helm_changes: + name: Detect Helm changes needs: pr_metadata if: needs.pr_metadata.outputs.should_run == 'true' + runs-on: ubuntu-latest + permissions: + contents: read + pull-requests: read + outputs: + should_run: ${{ steps.default.outputs.should_run || steps.changes.outputs.any_changed }} + steps: + - id: default + if: github.event_name != 'push' + shell: bash + run: echo "should_run=true" >> "$GITHUB_OUTPUT" + + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 + if: github.event_name == 'push' + + - id: merge-base + if: github.event_name == 'push' + uses: ./.github/actions/pr-merge-base + with: + gh_token: ${{ secrets.GITHUB_TOKEN }} + + - id: changes + if: github.event_name == 'push' + uses: tj-actions/changed-files@aa08304bd477b800d468db44fe10f6c61f7f7b11 # v42.1.0 + with: + base_sha: ${{ steps.merge-base.outputs.base_sha }} + skip_initial_fetch: ${{ 
steps.merge-base.outputs.base_sha != '' }} + files: | + deploy/helm/** + mise.toml + mise.lock + tasks/helm.toml + .github/workflows/helm-lint.yml + + helm-lint: + name: Helm Lint + needs: [pr_metadata, helm_changes] + if: needs.pr_metadata.outputs.should_run == 'true' && needs.helm_changes.outputs.should_run == 'true' runs-on: linux-amd64-cpu8 container: image: ghcr.io/nvidia/openshell/ci:latest @@ -48,10 +84,16 @@ jobs: username: ${{ github.actor }} password: ${{ secrets.GITHUB_TOKEN }} steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 - name: Install tools run: mise install --locked - name: Lint Helm chart run: mise run helm:lint + + - name: Check Helm chart README + run: mise run helm:docs:check + + - name: Run Helm chart unit tests + run: mise run helm:test diff --git a/.github/workflows/issue-triage.yml b/.github/workflows/issue-triage.yml index b59d8ba34..4aa3d6697 100644 --- a/.github/workflows/issue-triage.yml +++ b/.github/workflows/issue-triage.yml @@ -14,7 +14,7 @@ jobs: steps: - name: Check contributor permissions id: contributor - uses: actions/github-script@v9 + uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9 with: result-encoding: string script: | @@ -46,7 +46,7 @@ jobs: - name: Add triage label if: steps.contributor.outputs.result == 'true' - uses: actions/github-script@v9 + uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9 with: script: | await github.rest.issues.addLabels({ diff --git a/.github/workflows/release-auto-tag.yml b/.github/workflows/release-auto-tag.yml index f89c506d7..2b10a5b6e 100644 --- a/.github/workflows/release-auto-tag.yml +++ b/.github/workflows/release-auto-tag.yml @@ -20,7 +20,7 @@ jobs: create-tag: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 with: fetch-depth: 0 diff --git a/.github/workflows/release-canary.yml 
b/.github/workflows/release-canary.yml index defe6f32a..9c209d0c9 100644 --- a/.github/workflows/release-canary.yml +++ b/.github/workflows/release-canary.yml @@ -2,293 +2,163 @@ name: Release Canary on: workflow_dispatch: - inputs: - tag: - description: "Release tag to test (e.g. dev, v1.2.3)" - required: true - type: string workflow_run: - workflows: ["Release Dev", "Release Tag"] + workflows: ["Release Dev"] types: [completed] permissions: contents: read - packages: read defaults: run: shell: bash jobs: - # --------------------------------------------------------------------------- - # Verify the default install path (no OPENSHELL_VERSION) resolves to latest - # --------------------------------------------------------------------------- - install-default: - name: Install default (${{ matrix.arch }}) - if: >- - github.event.workflow_run.conclusion == 'success' - && github.event.workflow_run.name == 'Release Tag' - strategy: - fail-fast: false - matrix: - include: - - arch: amd64 - runner: linux-amd64-cpu8 - - arch: arm64 - runner: linux-arm64-cpu8 - runs-on: ${{ matrix.runner }} - timeout-minutes: 10 - container: - image: ghcr.io/nvidia/openshell/ci:latest - credentials: - username: ${{ github.actor }} - password: ${{ secrets.GITHUB_TOKEN }} + macos: + name: macOS Homebrew + if: ${{ github.event_name == 'workflow_dispatch' || github.event.workflow_run.conclusion == 'success' }} + runs-on: macos-latest-xlarge + timeout-minutes: 20 steps: - - name: Install CLI (default / latest) + - name: Ensure VM driver run: | - set -euo pipefail - curl -LsSf https://raw.githubusercontent.com/NVIDIA/OpenShell/main/install.sh | sh + launchctl setenv OPENSHELL_DRIVERS vm - - name: Verify CLI installation + - name: Install and check status run: | - set -euo pipefail - command -v openshell - ACTUAL="$(openshell --version)" - echo "Installed: $ACTUAL" - # This job only runs after Release Tag, so the triggering tag - # should match the latest release the default installer resolves to. 
- TAG="${{ github.event.workflow_run.head_branch }}" - if [[ "$TAG" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then - EXPECTED="${TAG#v}" - if [[ "$ACTUAL" != *"$EXPECTED"* ]]; then - echo "::error::Version mismatch: expected '$EXPECTED' in '$ACTUAL'" - exit 1 - fi - echo "Version check passed: found $EXPECTED in output" - fi + curl -LsSf https://raw.githubusercontent.com/NVIDIA/OpenShell/${{ github.event.workflow_run.head_sha || github.sha }}/install.sh | sh + openshell status - install-dev: - name: Install Debian package (${{ matrix.arch }}) + ubuntu: + name: Ubuntu Docker if: ${{ github.event_name == 'workflow_dispatch' || github.event.workflow_run.conclusion == 'success' }} - strategy: - fail-fast: false - matrix: - include: - - arch: amd64 - runner: linux-amd64-cpu8 - - arch: arm64 - runner: linux-arm64-cpu8 - runs-on: ${{ matrix.runner }} - timeout-minutes: 10 - container: - image: ghcr.io/nvidia/openshell/ci:latest - credentials: - username: ${{ github.actor }} - password: ${{ secrets.GITHUB_TOKEN }} + runs-on: ubuntu-latest + timeout-minutes: 20 steps: - - name: Determine release tag - id: release + - name: Ensure Docker run: | - set -euo pipefail - if [ "${{ github.event_name }}" = "workflow_dispatch" ]; then - echo "tag=${{ inputs.tag }}" >> "$GITHUB_OUTPUT" - else - WORKFLOW_NAME="${{ github.event.workflow_run.name }}" - if [ "$WORKFLOW_NAME" = "Release Dev" ]; then - echo "tag=dev" >> "$GITHUB_OUTPUT" - elif [ "$WORKFLOW_NAME" = "Release Tag" ]; then - TAG="${{ github.event.workflow_run.head_branch }}" - if [ -z "$TAG" ]; then - echo "::error::Could not determine release tag from workflow_run" - exit 1 - fi - echo "tag=${TAG}" >> "$GITHUB_OUTPUT" - else - echo "::error::Unexpected triggering workflow: ${WORKFLOW_NAME}" - exit 1 - fi + if ! 
command -v docker >/dev/null 2>&1; then + sudo apt-get update + sudo apt-get install -y docker.io fi + sudo systemctl start docker || sudo service docker start + mkdir -p "${HOME}/.config/openshell" + printf 'OPENSHELL_DRIVERS=docker\n' > "${HOME}/.config/openshell/gateway.env" + docker info - - name: Install Debian package + - name: Install and check status run: | - set -euo pipefail - curl -LsSf https://raw.githubusercontent.com/NVIDIA/OpenShell/main/install-dev.sh \ - | OPENSHELL_VERSION=${{ steps.release.outputs.tag }} sh + curl -LsSf https://raw.githubusercontent.com/NVIDIA/OpenShell/${{ github.event.workflow_run.head_sha || github.sha }}/install.sh | sh + openshell status - - name: Verify gateway and VM driver versions + fedora: + name: Fedora RPM + if: ${{ github.event_name == 'workflow_dispatch' || github.event.workflow_run.conclusion == 'success' }} + runs-on: linux-amd64-cpu8 + timeout-minutes: 20 + container: + image: fedora:latest + options: --privileged + steps: + - name: Ensure Podman run: | - set -euo pipefail - command -v openshell-gateway - test -x /usr/libexec/openshell/openshell-driver-vm - - GATEWAY_ACTUAL="$(openshell-gateway --version)" - DRIVER_ACTUAL="$(/usr/libexec/openshell/openshell-driver-vm --version)" - echo "Gateway: ${GATEWAY_ACTUAL}" - echo "Driver: ${DRIVER_ACTUAL}" + dnf install -y curl podman + mkdir -p "${HOME}/.config/openshell" + printf 'OPENSHELL_DRIVERS=podman\n' > "${HOME}/.config/openshell/gateway.env" + podman info - TAG="${{ steps.release.outputs.tag }}" - if [[ "$TAG" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then - EXPECTED="${TAG#v}" - for actual in "$GATEWAY_ACTUAL" "$DRIVER_ACTUAL"; do - if [[ "$actual" != *"$EXPECTED"* ]]; then - echo "::error::Version mismatch: expected '$EXPECTED' in '$actual'" - exit 1 - fi - done - echo "Version check passed: found $EXPECTED in both binaries" - else - echo "Non-release tag ($TAG), skipping version check" - fi + - name: Install and check status + run: | + curl -LsSf 
https://raw.githubusercontent.com/NVIDIA/OpenShell/${{ github.event.workflow_run.head_sha || github.sha }}/install.sh | sh + openshell status - canary: - name: Canary package gateway (${{ matrix.arch }}) + kubernetes: + name: Kubernetes Helm (kind) if: ${{ github.event_name == 'workflow_dispatch' || github.event.workflow_run.conclusion == 'success' }} - strategy: - fail-fast: false - matrix: - include: - - arch: amd64 - runner: linux-amd64-cpu8 - - arch: arm64 - runner: linux-arm64-cpu8 - runs-on: ${{ matrix.runner }} - timeout-minutes: 30 + runs-on: ubuntu-latest + timeout-minutes: 20 env: - OPENSHELL_REGISTRY_TOKEN: ${{ secrets.GITHUB_TOKEN }} - OPENSHELL_CANARY_PORT: "17670" + KIND_CLUSTER_NAME: release-canary-${{ github.run_id }} + RELEASE_NAME: openshell + RELEASE_NAMESPACE: openshell + KIND_GATEWAY_NAME: kind steps: - - uses: actions/checkout@v6 + - name: Install Helm + uses: azure/setup-helm@dda3372f752e03dde6b3237bc9431cdc2f7a02a2 # v5.0.0 - - name: Determine release tag - id: release - run: | - set -euo pipefail - if [ "${{ github.event_name }}" = "workflow_dispatch" ]; then - echo "tag=${{ inputs.tag }}" >> "$GITHUB_OUTPUT" - else - WORKFLOW_NAME="${{ github.event.workflow_run.name }}" - if [ "$WORKFLOW_NAME" = "Release Dev" ]; then - echo "tag=dev" >> "$GITHUB_OUTPUT" - elif [ "$WORKFLOW_NAME" = "Release Tag" ]; then - TAG="${{ github.event.workflow_run.head_branch }}" - if [ -z "$TAG" ]; then - echo "::error::Could not determine release tag from workflow_run" - exit 1 - fi - echo "tag=${TAG}" >> "$GITHUB_OUTPUT" - else - echo "::error::Unexpected triggering workflow: ${WORKFLOW_NAME}" - exit 1 - fi - fi + - name: Create kind cluster + uses: helm/kind-action@ef37e7f390d99f746eb8b610417061a60e82a6cc # v1.14.0 + with: + cluster_name: ${{ env.KIND_CLUSTER_NAME }} + wait: 120s - - name: Install Debian package + - name: Install OpenShell Helm chart from GHCR OCI run: | set -euo pipefail - curl -LsSf 
https://raw.githubusercontent.com/NVIDIA/OpenShell/main/install-dev.sh \ - | OPENSHELL_VERSION=${{ steps.release.outputs.tag }} sh + helm install "$RELEASE_NAME" oci://ghcr.io/nvidia/openshell/helm-chart \ + --version 0.0.0-dev \ + --namespace "$RELEASE_NAMESPACE" --create-namespace \ + --set server.disableTls=true \ + --wait --timeout 5m - - name: Verify package binaries + - name: Verify gateway pod is Ready run: | set -euo pipefail - command -v openshell - command -v openshell-gateway - test -x /usr/libexec/openshell/openshell-driver-vm - - CLI_ACTUAL="$(openshell --version)" - GATEWAY_ACTUAL="$(openshell-gateway --version)" - DRIVER_ACTUAL="$(/usr/libexec/openshell/openshell-driver-vm --version)" - echo "CLI: ${CLI_ACTUAL}" - echo "Gateway: ${GATEWAY_ACTUAL}" - echo "Driver: ${DRIVER_ACTUAL}" - - TAG="${{ steps.release.outputs.tag }}" - if [[ "$TAG" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then - EXPECTED="${TAG#v}" - for actual in "$CLI_ACTUAL" "$GATEWAY_ACTUAL" "$DRIVER_ACTUAL"; do - if [[ "$actual" != *"$EXPECTED"* ]]; then - echo "::error::Version mismatch: expected '$EXPECTED' in '$actual'" - exit 1 - fi - done - echo "Version check passed: found $EXPECTED in package binaries" - else - echo "Non-release tag ($TAG), skipping version check" - fi + kubectl wait --namespace "$RELEASE_NAMESPACE" \ + --for=condition=Ready pod \ + --selector="app.kubernetes.io/name=openshell,app.kubernetes.io/instance=${RELEASE_NAME}" \ + --timeout=300s - - name: Start packaged gateway + - name: Port-forward gateway service run: | set -euo pipefail - - # The CLI no longer owns gateway lifecycle. In CI we start the - # gateway binary installed by the Debian package directly, using the - # Docker driver so this canary can launch a real sandbox. 
- systemctl --user stop openshell-gateway >/dev/null 2>&1 || true - - STATE_DIR="$(mktemp -d)" - LOG="${STATE_DIR}/openshell-gateway.log" - echo "GATEWAY_LOG=${LOG}" >> "$GITHUB_ENV" - echo "GATEWAY_STATE_DIR=${STATE_DIR}" >> "$GITHUB_ENV" - - OPENSHELL_BIND_ADDRESS=0.0.0.0 \ - OPENSHELL_SERVER_PORT="${OPENSHELL_CANARY_PORT}" \ - OPENSHELL_DISABLE_TLS=true \ - OPENSHELL_DISABLE_GATEWAY_AUTH=true \ - OPENSHELL_DRIVERS=docker \ - OPENSHELL_DB_URL="sqlite:${STATE_DIR}/openshell.db?mode=rwc" \ - OPENSHELL_GRPC_ENDPOINT="http://host.openshell.internal:${OPENSHELL_CANARY_PORT}" \ - OPENSHELL_SSH_GATEWAY_HOST=127.0.0.1 \ - OPENSHELL_SSH_GATEWAY_PORT="${OPENSHELL_CANARY_PORT}" \ - OPENSHELL_SANDBOX_NAMESPACE="canary-${{ matrix.arch }}-${{ github.run_id }}" \ - nohup openshell-gateway >"${LOG}" 2>&1 & - PID=$! - echo "GATEWAY_PID=${PID}" >> "$GITHUB_ENV" - - for _ in $(seq 1 60); do - if curl -fsS "http://127.0.0.1:${OPENSHELL_CANARY_PORT}/healthz" >/dev/null; then - break - fi - if ! kill -0 "$PID" 2>/dev/null; then - echo "::error::openshell-gateway exited before becoming healthy" - cat "$LOG" - exit 1 + nohup kubectl port-forward --namespace "$RELEASE_NAMESPACE" \ + "svc/${RELEASE_NAME}" 8080:8080 \ + > port-forward.log 2>&1 & + echo $! 
> port-forward.pid + for _ in $(seq 1 30); do + if (echo > /dev/tcp/127.0.0.1/8080) >/dev/null 2>&1; then + echo "port-forward is reachable" + exit 0 fi sleep 1 done + echo "port-forward did not become reachable" >&2 + cat port-forward.log >&2 + exit 1 - curl -fsS "http://127.0.0.1:${OPENSHELL_CANARY_PORT}/healthz" - openshell gateway remove local >/dev/null 2>&1 || true - openshell gateway add "http://127.0.0.1:${OPENSHELL_CANARY_PORT}" --local --name local - - - name: Run canary test + - name: Install OpenShell CLI run: | set -euo pipefail + mkdir -p "${HOME}/.config/openshell" + printf 'OPENSHELL_DRIVERS=docker\n' > "${HOME}/.config/openshell/gateway.env" + curl -LsSf https://raw.githubusercontent.com/NVIDIA/OpenShell/${{ github.event.workflow_run.head_sha || github.sha }}/install.sh | sh - echo "Creating sandbox and running 'echo hello world'..." - OUTPUT=$(openshell sandbox create --no-keep --no-tty -- echo "hello world" 2>&1) || { - EXIT_CODE=$? - echo "::error::openshell sandbox create failed with exit code ${EXIT_CODE}" - echo "$OUTPUT" - exit $EXIT_CODE - } - - echo "$OUTPUT" - - if echo "$OUTPUT" | grep -q "hello world"; then - echo "Canary test passed: 'hello world' found in output" - else - echo "::error::Canary test failed: 'hello world' not found in output" - exit 1 - fi - - - name: Stop packaged gateway - if: always() + - name: Register kind gateway and check status run: | set -euo pipefail - if [ -n "${GATEWAY_PID:-}" ]; then - kill "$GATEWAY_PID" >/dev/null 2>&1 || true - fi - if [ "${{ job.status }}" != "success" ] && [ -n "${GATEWAY_LOG:-}" ] && [ -f "$GATEWAY_LOG" ]; then - echo "Gateway log:" - cat "$GATEWAY_LOG" - fi + openshell gateway add http://127.0.0.1:8080 --local --name "$KIND_GATEWAY_NAME" + openshell status + + - name: Diagnostics on failure + if: failure() + run: | + set +e + echo "--- helm status ---" + helm status "$RELEASE_NAME" --namespace "$RELEASE_NAMESPACE" + echo "--- helm get manifest ---" + helm get manifest "$RELEASE_NAME" 
--namespace "$RELEASE_NAMESPACE" + echo "--- get all ---" + kubectl get all --namespace "$RELEASE_NAMESPACE" + echo "--- describe pods ---" + kubectl describe pods --namespace "$RELEASE_NAMESPACE" + echo "--- pod logs ---" + kubectl logs --namespace "$RELEASE_NAMESPACE" \ + --selector="app.kubernetes.io/name=openshell,app.kubernetes.io/instance=${RELEASE_NAME}" \ + --tail=200 --all-containers --prefix + echo "--- port-forward log ---" + cat port-forward.log 2>/dev/null + echo "--- openshell gateway list ---" + openshell gateway list 2>/dev/null + echo "--- openshell version ---" + openshell --version 2>/dev/null diff --git a/.github/workflows/release-dev.yml b/.github/workflows/release-dev.yml index 0385930bd..5c8eac435 100644 --- a/.github/workflows/release-dev.yml +++ b/.github/workflows/release-dev.yml @@ -33,7 +33,7 @@ jobs: rpm_version: ${{ steps.v.outputs.rpm_version }} rpm_release: ${{ steps.v.outputs.rpm_release }} steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 with: fetch-depth: 0 @@ -123,7 +123,7 @@ jobs: MISE_GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} OPENSHELL_IMAGE_TAG: dev steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 with: fetch-depth: 0 @@ -147,7 +147,7 @@ jobs: ls -la ${{ matrix.output_path }} - name: Upload wheel artifacts - uses: actions/upload-artifact@v7 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7 with: name: python-wheels-${{ matrix.artifact }} path: ${{ matrix.output_path }} @@ -170,7 +170,7 @@ jobs: MISE_GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} OPENSHELL_IMAGE_TAG: dev steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 with: fetch-depth: 0 @@ -179,8 +179,6 @@ jobs: - name: Set up Docker Buildx uses: ./.github/actions/setup-buildx - with: - driver: local - name: Mark workspace safe for git run: git config --global --add 
safe.directory "$GITHUB_WORKSPACE" @@ -195,7 +193,7 @@ jobs: ls -la target/wheels/*.whl - name: Upload wheel artifacts - uses: actions/upload-artifact@v7 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7 with: name: python-wheels-macos path: target/wheels/*.whl @@ -233,7 +231,7 @@ jobs: MISE_GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} OPENSHELL_IMAGE_TAG: dev steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 with: fetch-depth: 0 @@ -309,7 +307,7 @@ jobs: ls -lh artifacts/ - name: Upload artifact - uses: actions/upload-artifact@v7 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7 with: name: cli-linux-${{ matrix.arch }} path: artifacts/*.tar.gz @@ -334,7 +332,7 @@ jobs: env: MISE_GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 with: fetch-depth: 0 @@ -349,8 +347,6 @@ jobs: - name: Set up Docker Buildx uses: ./.github/actions/setup-buildx - with: - driver: local - name: Build macOS binary via Docker run: | @@ -373,14 +369,14 @@ jobs: ls -lh artifacts/ - name: Upload artifact - uses: actions/upload-artifact@v7 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7 with: name: cli-macos path: artifacts/*.tar.gz retention-days: 5 # --------------------------------------------------------------------------- - # Build standalone gateway binaries (Linux GNU — native on each arch) + # Build standalone gateway binaries (Linux GNU — glibc 2.31 floor) # --------------------------------------------------------------------------- build-gateway-binary-linux: name: Build Gateway Binary (Linux ${{ matrix.arch }}) @@ -391,9 +387,11 @@ jobs: - arch: amd64 runner: linux-amd64-cpu8 target: x86_64-unknown-linux-gnu + zig_target: x86_64-unknown-linux-gnu.2.31 - arch: arm64 runner: linux-arm64-cpu8 target: aarch64-unknown-linux-gnu + zig_target: 
aarch64-unknown-linux-gnu.2.31 runs-on: ${{ matrix.runner }} timeout-minutes: 60 container: @@ -405,7 +403,7 @@ jobs: env: MISE_GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 with: fetch-depth: 0 @@ -431,18 +429,26 @@ jobs: set -euo pipefail sed -i -E '/^\[workspace\.package\]/,/^\[/{s/^version[[:space:]]*=[[:space:]]*".*"/version = "'"${{ needs.compute-versions.outputs.cargo_version }}"'"/}' Cargo.toml - - name: Build ${{ matrix.target }} + - name: Build ${{ matrix.zig_target }} + env: + OPENSHELL_IMAGE_TAG: ${{ github.sha }} run: | set -euo pipefail - mise x -- cargo build --release --target ${{ matrix.target }} -p openshell-server + mise x -- rustup target add ${{ matrix.target }} + mise x -- cargo zigbuild --release --target ${{ matrix.zig_target }} -p openshell-server --bin openshell-gateway + mkdir -p artifacts/bin + install -m 0755 target/${{ matrix.target }}/release/openshell-gateway artifacts/bin/openshell-gateway - name: Verify packaged binary run: | set -euo pipefail - OUTPUT="$(target/${{ matrix.target }}/release/openshell-gateway --version)" + OUTPUT="$(artifacts/bin/openshell-gateway --version)" echo "$OUTPUT" grep -q '^openshell-gateway ' <<<"$OUTPUT" + - name: Verify glibc symbol floor + run: tasks/scripts/verify-glibc-symbols.sh 2.31 artifacts/bin/openshell-gateway + - name: sccache stats if: always() run: mise x -- sccache --show-stats @@ -452,11 +458,11 @@ jobs: set -euo pipefail mkdir -p artifacts tar -czf artifacts/openshell-gateway-${{ matrix.target }}.tar.gz \ - -C target/${{ matrix.target }}/release openshell-gateway + -C artifacts/bin openshell-gateway ls -lh artifacts/ - name: Upload artifact - uses: actions/upload-artifact@v7 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7 with: name: gateway-binary-linux-${{ matrix.arch }} path: artifacts/*.tar.gz @@ -481,7 +487,7 @@ jobs: env: MISE_GITHUB_TOKEN: ${{ 
secrets.GITHUB_TOKEN }} steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 with: fetch-depth: 0 @@ -496,8 +502,6 @@ jobs: - name: Set up Docker Buildx uses: ./.github/actions/setup-buildx - with: - driver: local - name: Build macOS binary via Docker run: | @@ -505,6 +509,7 @@ jobs: docker buildx build \ --file deploy/docker/Dockerfile.gateway-macos \ --build-arg OPENSHELL_CARGO_VERSION="${{ needs.compute-versions.outputs.cargo_version }}" \ + --build-arg OPENSHELL_IMAGE_TAG=dev \ --build-arg CARGO_TARGET_CACHE_SCOPE="${{ github.sha }}" \ --target binary \ --output type=local,dest=out/ \ @@ -524,7 +529,7 @@ jobs: ls -lh artifacts/ - name: Upload artifact - uses: actions/upload-artifact@v7 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7 with: name: gateway-binary-macos path: artifacts/*.tar.gz @@ -556,7 +561,7 @@ jobs: env: MISE_GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 with: fetch-depth: 0 @@ -607,7 +612,7 @@ jobs: ls -lh artifacts/ - name: Upload artifact - uses: actions/upload-artifact@v7 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7 with: name: supervisor-binary-linux-${{ matrix.arch }} path: artifacts/*.tar.gz @@ -644,7 +649,7 @@ jobs: build-rpm: name: Build RPM Packages - needs: [compute-versions] + needs: [compute-versions, build-cli-linux, build-gateway-binary-linux] uses: ./.github/workflows/rpm-package.yml with: checkout-ref: ${{ github.sha }} @@ -653,12 +658,46 @@ jobs: cargo-version: ${{ needs.compute-versions.outputs.cargo_version }} secrets: inherit + smoke-linux-dev-artifacts: + name: Smoke Linux Dev Artifacts (${{ matrix.name }}) + needs: [build-gateway-binary-linux, build-driver-vm-linux, build-deb] + timeout-minutes: 20 + strategy: + fail-fast: false + matrix: + include: + - name: ubuntu-22.04-deb-amd64 + runner: 
linux-amd64-cpu8 + image: ubuntu:22.04 + artifact_arch: amd64 + - name: ubuntu-22.04-deb-arm64 + runner: linux-arm64-cpu8 + image: ubuntu:22.04 + artifact_arch: arm64 + runs-on: ${{ matrix.runner }} + container: + image: ${{ matrix.image }} + steps: + - name: Download Debian package artifact + uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1 + with: + name: deb-linux-${{ matrix.artifact_arch }} + path: package-input/ + + - name: Smoke Debian package on Ubuntu 22.04 + run: | + set -euo pipefail + apt-get update + apt-get install -y --no-install-recommends ./package-input/*.deb + openshell-gateway --version + /usr/libexec/openshell/openshell-driver-vm --version + # --------------------------------------------------------------------------- # Create / update the dev GitHub Release with CLI, gateway, driver, and wheels # --------------------------------------------------------------------------- release-dev: name: Release Dev - needs: [compute-versions, build-cli-linux, build-cli-macos, build-gateway-binary-linux, build-gateway-binary-macos, build-supervisor-binary-linux, build-python-wheels-linux, build-python-wheel-macos, build-driver-vm-linux, build-driver-vm-macos, build-deb, build-rpm] + needs: [compute-versions, build-cli-linux, build-cli-macos, build-gateway-binary-linux, build-gateway-binary-macos, build-supervisor-binary-linux, build-python-wheels-linux, build-python-wheel-macos, build-driver-vm-linux, build-driver-vm-macos, build-deb, build-rpm, smoke-linux-dev-artifacts] runs-on: linux-amd64-cpu8 timeout-minutes: 10 permissions: @@ -669,52 +708,52 @@ jobs: outputs: wheel_filenames: ${{ steps.wheel_filenames.outputs.wheel_filenames }} steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 - name: Download all CLI artifacts - uses: actions/download-artifact@v4 + uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1 with: pattern: cli-* path: 
release/ merge-multiple: true - name: Download gateway binary artifacts - uses: actions/download-artifact@v4 + uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1 with: pattern: gateway-binary-* path: release/ merge-multiple: true - name: Download supervisor binary artifacts - uses: actions/download-artifact@v4 + uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1 with: pattern: supervisor-binary-* path: release/ merge-multiple: true - name: Download VM driver artifacts - uses: actions/download-artifact@v4 + uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1 with: pattern: driver-vm-* path: release/ merge-multiple: true - name: Download wheel artifacts - uses: actions/download-artifact@v4 + uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1 with: pattern: python-wheels-* path: release/ merge-multiple: true - name: Download Debian package artifacts - uses: actions/download-artifact@v4 + uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1 with: pattern: deb-linux-* path: release/ merge-multiple: true - name: Download RPM package artifacts - uses: actions/download-artifact@v4 + uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1 with: pattern: rpm-linux-* path: release/ @@ -788,7 +827,7 @@ jobs: cat release/openshell.rb - name: Attest VM driver artifacts - uses: actions/attest@v4 + uses: actions/attest@59d89421af93a897026c735860bf21b6eb4f7b26 # v4 with: subject-path: | release/openshell-driver-vm-x86_64-unknown-linux-gnu.tar.gz @@ -796,7 +835,7 @@ jobs: release/openshell-driver-vm-aarch64-apple-darwin.tar.gz - name: Prune managed assets from dev release - uses: actions/github-script@v9 + uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9 with: script: | const [owner, repo] = process.env.GITHUB_REPOSITORY.split('/'); @@ -851,7 +890,7 @@ jobs: git push --force origin 
dev - name: Create / update GitHub Release - uses: softprops/action-gh-release@v2 + uses: softprops/action-gh-release@b4309332981a82ec1c5618f44dd2e27cc8bfbfda # v3.0.0 with: name: OpenShell Development Build prerelease: true @@ -897,7 +936,7 @@ jobs: permissions: packages: write steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 - uses: ./.github/actions/release-helm-oci with: diff --git a/.github/workflows/release-tag.yml b/.github/workflows/release-tag.yml index 97c8422a2..fc0f47480 100644 --- a/.github/workflows/release-tag.yml +++ b/.github/workflows/release-tag.yml @@ -48,7 +48,7 @@ jobs: # Commit resolved from RELEASE_TAG, used for image tags and downstream metadata source_sha: ${{ steps.v.outputs.source_sha }} steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 with: ref: ${{ inputs.tag || github.ref }} fetch-depth: 0 @@ -152,7 +152,7 @@ jobs: MISE_GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} OPENSHELL_IMAGE_TAG: ${{ needs.compute-versions.outputs.semver }} steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 with: ref: ${{ inputs.tag || github.ref }} fetch-depth: 0 @@ -177,7 +177,7 @@ jobs: ls -la ${{ matrix.output_path }} - name: Upload wheel artifacts - uses: actions/upload-artifact@v7 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7 with: name: python-wheels-${{ matrix.artifact }} path: ${{ matrix.output_path }} @@ -200,7 +200,7 @@ jobs: MISE_GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} OPENSHELL_IMAGE_TAG: ${{ needs.compute-versions.outputs.semver }} steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 with: ref: ${{ inputs.tag || github.ref }} fetch-depth: 0 @@ -210,8 +210,6 @@ jobs: - name: Set up Docker Buildx uses: ./.github/actions/setup-buildx - with: - driver: local - name: Mark workspace safe for 
git run: git config --global --add safe.directory "$GITHUB_WORKSPACE" @@ -226,7 +224,7 @@ jobs: ls -la target/wheels/*.whl - name: Upload wheel artifacts - uses: actions/upload-artifact@v7 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7 with: name: python-wheels-macos path: target/wheels/*.whl @@ -264,7 +262,7 @@ jobs: MISE_GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} OPENSHELL_IMAGE_TAG: ${{ needs.compute-versions.outputs.semver }} steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 with: ref: ${{ inputs.tag || github.ref }} fetch-depth: 0 @@ -341,7 +339,7 @@ jobs: ls -lh artifacts/ - name: Upload artifact - uses: actions/upload-artifact@v7 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7 with: name: cli-linux-${{ matrix.arch }} path: artifacts/*.tar.gz @@ -366,7 +364,7 @@ jobs: env: MISE_GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 with: ref: ${{ inputs.tag || github.ref }} fetch-depth: 0 @@ -382,8 +380,6 @@ jobs: - name: Set up Docker Buildx uses: ./.github/actions/setup-buildx - with: - driver: local - name: Build macOS binary via Docker run: | @@ -406,14 +402,14 @@ jobs: ls -lh artifacts/ - name: Upload artifact - uses: actions/upload-artifact@v7 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7 with: name: cli-macos path: artifacts/*.tar.gz retention-days: 5 # --------------------------------------------------------------------------- - # Build standalone gateway binaries (Linux GNU — native on each arch) + # Build standalone gateway binaries (Linux GNU — glibc 2.31 floor) # --------------------------------------------------------------------------- build-gateway-binary-linux: name: Build Gateway Binary (Linux ${{ matrix.arch }}) @@ -424,9 +420,11 @@ jobs: - arch: amd64 runner: linux-amd64-cpu8 target: 
x86_64-unknown-linux-gnu + zig_target: x86_64-unknown-linux-gnu.2.31 - arch: arm64 runner: linux-arm64-cpu8 target: aarch64-unknown-linux-gnu + zig_target: aarch64-unknown-linux-gnu.2.31 runs-on: ${{ matrix.runner }} timeout-minutes: 60 container: @@ -438,7 +436,7 @@ jobs: env: MISE_GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 with: ref: ${{ inputs.tag || github.ref }} fetch-depth: 0 @@ -465,18 +463,26 @@ jobs: set -euo pipefail sed -i -E '/^\[workspace\.package\]/,/^\[/{s/^version[[:space:]]*=[[:space:]]*".*"/version = "'"${{ needs.compute-versions.outputs.cargo_version }}"'"/}' Cargo.toml - - name: Build ${{ matrix.target }} + - name: Build ${{ matrix.zig_target }} + env: + OPENSHELL_IMAGE_TAG: ${{ needs.compute-versions.outputs.source_sha }} run: | set -euo pipefail - mise x -- cargo build --release --target ${{ matrix.target }} -p openshell-server + mise x -- rustup target add ${{ matrix.target }} + mise x -- cargo zigbuild --release --target ${{ matrix.zig_target }} -p openshell-server --bin openshell-gateway + mkdir -p artifacts/bin + install -m 0755 target/${{ matrix.target }}/release/openshell-gateway artifacts/bin/openshell-gateway - name: Verify packaged binary run: | set -euo pipefail - OUTPUT="$(target/${{ matrix.target }}/release/openshell-gateway --version)" + OUTPUT="$(artifacts/bin/openshell-gateway --version)" echo "$OUTPUT" grep -q '^openshell-gateway ' <<<"$OUTPUT" + - name: Verify glibc symbol floor + run: tasks/scripts/verify-glibc-symbols.sh 2.31 artifacts/bin/openshell-gateway + - name: sccache stats if: always() run: mise x -- sccache --show-stats @@ -486,11 +492,11 @@ jobs: set -euo pipefail mkdir -p artifacts tar -czf artifacts/openshell-gateway-${{ matrix.target }}.tar.gz \ - -C target/${{ matrix.target }}/release openshell-gateway + -C artifacts/bin openshell-gateway ls -lh artifacts/ - name: Upload artifact - uses: 
actions/upload-artifact@v7 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7 with: name: gateway-binary-linux-${{ matrix.arch }} path: artifacts/*.tar.gz @@ -522,7 +528,7 @@ jobs: env: MISE_GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 with: ref: ${{ inputs.tag || github.ref }} fetch-depth: 0 @@ -574,7 +580,7 @@ jobs: ls -lh artifacts/ - name: Upload artifact - uses: actions/upload-artifact@v7 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7 with: name: supervisor-binary-linux-${{ matrix.arch }} path: artifacts/*.tar.gz @@ -599,7 +605,7 @@ jobs: env: MISE_GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 with: ref: ${{ inputs.tag || github.ref }} fetch-depth: 0 @@ -615,8 +621,6 @@ jobs: - name: Set up Docker Buildx uses: ./.github/actions/setup-buildx - with: - driver: local - name: Build macOS binary via Docker run: | @@ -624,6 +628,7 @@ jobs: docker buildx build \ --file deploy/docker/Dockerfile.gateway-macos \ --build-arg OPENSHELL_CARGO_VERSION="${{ needs.compute-versions.outputs.cargo_version }}" \ + --build-arg OPENSHELL_IMAGE_TAG="${{ needs.compute-versions.outputs.semver }}" \ --build-arg CARGO_TARGET_CACHE_SCOPE="${{ github.sha }}" \ --target binary \ --output type=local,dest=out/ \ @@ -643,7 +648,7 @@ jobs: ls -lh artifacts/ - name: Upload artifact - uses: actions/upload-artifact@v7 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7 with: name: gateway-binary-macos path: artifacts/*.tar.gz @@ -680,7 +685,7 @@ jobs: build-rpm: name: Build RPM Packages - needs: [compute-versions] + needs: [compute-versions, build-cli-linux, build-gateway-binary-linux] uses: ./.github/workflows/rpm-package.yml with: checkout-ref: ${{ inputs.tag || github.ref }} @@ -689,12 +694,120 @@ 
jobs: cargo-version: ${{ needs.compute-versions.outputs.cargo_version }} secrets: inherit + smoke-linux-release-artifacts: + name: Smoke Linux Release Artifacts (${{ matrix.name }}) + needs: [build-gateway-binary-linux, build-driver-vm-linux, build-deb, build-rpm] + timeout-minutes: 20 + strategy: + fail-fast: false + matrix: + include: + - name: ubuntu-20.04-binaries + runner: linux-amd64-cpu8 + image: ubuntu:20.04 + kind: binary + artifact_arch: amd64 + rpm_arch: x86_64 + target: x86_64-unknown-linux-gnu + - name: ubuntu-20.04-binaries-arm64 + runner: linux-arm64-cpu8 + image: ubuntu:20.04 + kind: binary + artifact_arch: arm64 + rpm_arch: aarch64 + target: aarch64-unknown-linux-gnu + - name: ubuntu-22.04-deb + runner: linux-amd64-cpu8 + image: ubuntu:22.04 + kind: deb + artifact_arch: amd64 + rpm_arch: x86_64 + target: x86_64-unknown-linux-gnu + - name: ubuntu-22.04-deb-arm64 + runner: linux-arm64-cpu8 + image: ubuntu:22.04 + kind: deb + artifact_arch: arm64 + rpm_arch: aarch64 + target: aarch64-unknown-linux-gnu + - name: fedora-rpm + runner: linux-amd64-cpu8 + image: fedora:latest + kind: rpm + artifact_arch: amd64 + rpm_arch: x86_64 + target: x86_64-unknown-linux-gnu + - name: fedora-rpm-aarch64 + runner: linux-arm64-cpu8 + image: fedora:latest + kind: rpm + artifact_arch: arm64 + rpm_arch: aarch64 + target: aarch64-unknown-linux-gnu + runs-on: ${{ matrix.runner }} + container: + image: ${{ matrix.image }} + steps: + - name: Download gateway binary artifact + if: matrix.kind == 'binary' + uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1 + with: + name: gateway-binary-linux-${{ matrix.artifact_arch }} + path: smoke-input/ + + - name: Download VM driver binary artifact + if: matrix.kind == 'binary' + uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1 + with: + name: driver-vm-linux-${{ matrix.artifact_arch }} + path: smoke-input/ + + - name: Smoke binary artifacts + if: matrix.kind == 'binary' + 
run: | + set -euo pipefail + mkdir -p smoke-bin + tar -xzf smoke-input/openshell-gateway-${{ matrix.target }}.tar.gz -C smoke-bin + tar -xzf smoke-input/openshell-driver-vm-${{ matrix.target }}.tar.gz -C smoke-bin + smoke-bin/openshell-gateway --version + smoke-bin/openshell-driver-vm --version + + - name: Download Debian package artifact + if: matrix.kind == 'deb' + uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1 + with: + name: deb-linux-${{ matrix.artifact_arch }} + path: package-input/ + + - name: Smoke Debian package + if: matrix.kind == 'deb' + run: | + set -euo pipefail + apt-get update + apt-get install -y --no-install-recommends ./package-input/*.deb + openshell-gateway --version + /usr/libexec/openshell/openshell-driver-vm --version + + - name: Download RPM package artifacts + if: matrix.kind == 'rpm' + uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1 + with: + name: rpm-linux-${{ matrix.rpm_arch }} + path: package-input/ + + - name: Smoke RPM packages + if: matrix.kind == 'rpm' + run: | + set -euo pipefail + dnf install -y ./package-input/openshell-[0-9]*.rpm ./package-input/openshell-gateway-*.rpm + openshell-gateway --version + # --------------------------------------------------------------------------- # Create a tagged GitHub Release with CLI, gateway, driver, and wheels # --------------------------------------------------------------------------- release: name: Release - needs: [compute-versions, build-cli-linux, build-cli-macos, build-gateway-binary-linux, build-gateway-binary-macos, build-supervisor-binary-linux, build-python-wheels-linux, build-python-wheel-macos, tag-ghcr-release, build-driver-vm-linux, build-driver-vm-macos, build-deb, build-rpm] + needs: [compute-versions, build-cli-linux, build-cli-macos, build-gateway-binary-linux, build-gateway-binary-macos, build-supervisor-binary-linux, build-python-wheels-linux, build-python-wheel-macos, tag-ghcr-release, 
build-driver-vm-linux, build-driver-vm-macos, build-deb, build-rpm, smoke-linux-release-artifacts] runs-on: linux-amd64-cpu8 timeout-minutes: 10 permissions: @@ -705,54 +818,54 @@ jobs: outputs: wheel_filenames: ${{ steps.wheel_filenames.outputs.wheel_filenames }} steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 with: ref: ${{ inputs.tag || github.ref }} - name: Download all CLI artifacts - uses: actions/download-artifact@v4 + uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1 with: pattern: cli-* path: release/ merge-multiple: true - name: Download gateway binary artifacts - uses: actions/download-artifact@v4 + uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1 with: pattern: gateway-binary-* path: release/ merge-multiple: true - name: Download supervisor binary artifacts - uses: actions/download-artifact@v4 + uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1 with: pattern: supervisor-binary-* path: release/ merge-multiple: true - name: Download VM driver artifacts - uses: actions/download-artifact@v4 + uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1 with: pattern: driver-vm-* path: release/ merge-multiple: true - name: Download wheel artifacts - uses: actions/download-artifact@v4 + uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1 with: pattern: python-wheels-* path: release/ merge-multiple: true - name: Download Debian package artifacts - uses: actions/download-artifact@v4 + uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1 with: pattern: deb-linux-* path: release/ merge-multiple: true - name: Download RPM package artifacts - uses: actions/download-artifact@v4 + uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1 with: pattern: rpm-linux-* path: release/ @@ -801,15 +914,16 @@ jobs: 
cat release/openshell.rb - name: Attest VM driver artifacts - uses: actions/attest@v4 + uses: actions/attest@59d89421af93a897026c735860bf21b6eb4f7b26 # v4 with: subject-path: | - release/openshell-driver-vm-x86_64-unknown-linux-gnu.tar.gz - release/openshell-driver-vm-aarch64-unknown-linux-gnu.tar.gz - release/openshell-driver-vm-aarch64-apple-darwin.tar.gz + release/*.tar.gz + release/*.deb + release/*.rpm + release/*.whl - name: Prune removed VM checksum asset - uses: actions/github-script@v9 + uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9 with: script: | const [owner, repo] = process.env.GITHUB_REPOSITORY.split('/'); @@ -831,7 +945,7 @@ jobs: } - name: Create GitHub Release - uses: softprops/action-gh-release@v2 + uses: softprops/action-gh-release@b4309332981a82ec1c5618f44dd2e27cc8bfbfda # v3.0.0 with: name: OpenShell ${{ env.RELEASE_TAG }} prerelease: false @@ -872,12 +986,12 @@ jobs: runs-on: ubuntu-latest timeout-minutes: 15 steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 with: ref: ${{ inputs.tag || github.ref }} - name: Setup Node.js - uses: actions/setup-node@v6 + uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e # v6 with: node-version: "24" @@ -900,7 +1014,7 @@ jobs: permissions: packages: write steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 with: ref: ${{ inputs.tag || github.ref }} diff --git a/.github/workflows/release-vm-kernel.yml b/.github/workflows/release-vm-kernel.yml index 5216a79c7..da4ec132d 100644 --- a/.github/workflows/release-vm-kernel.yml +++ b/.github/workflows/release-vm-kernel.yml @@ -47,7 +47,7 @@ jobs: env: MISE_GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 - name: Mark workspace safe for git run: git config --global --add safe.directory "$GITHUB_WORKSPACE" @@ 
-63,7 +63,7 @@ jobs: --output artifacts/vm-runtime-linux-aarch64.tar.zst - name: Upload runtime artifact - uses: actions/upload-artifact@v7 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7 with: name: vm-runtime-linux-arm64 path: artifacts/vm-runtime-linux-aarch64.tar.zst @@ -73,7 +73,7 @@ jobs: # the aarch64 Linux kernel as a byte array — it is OS-agnostic and can # be compiled into a .dylib by Apple's cc without rebuilding the kernel. - name: Upload kernel.c for macOS build - uses: actions/upload-artifact@v7 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7 with: name: kernel-c-arm64 path: | @@ -97,7 +97,7 @@ jobs: env: MISE_GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 - name: Mark workspace safe for git run: git config --global --add safe.directory "$GITHUB_WORKSPACE" @@ -113,7 +113,7 @@ jobs: --output artifacts/vm-runtime-linux-x86_64.tar.zst - name: Upload artifact - uses: actions/upload-artifact@v7 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7 with: name: vm-runtime-linux-amd64 path: artifacts/vm-runtime-linux-x86_64.tar.zst @@ -130,7 +130,7 @@ jobs: env: RUSTC_WRAPPER: "" steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 - name: Install dependencies run: | @@ -140,7 +140,7 @@ jobs: brew install lld dtc xz - name: Download pre-built kernel.c - uses: actions/download-artifact@v4 + uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1 with: name: kernel-c-arm64 path: target/kernel-artifact @@ -156,7 +156,7 @@ jobs: --output artifacts/vm-runtime-darwin-aarch64.tar.zst - name: Upload artifact - uses: actions/upload-artifact@v7 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7 with: name: vm-runtime-macos-arm64 path: 
artifacts/vm-runtime-darwin-aarch64.tar.zst @@ -176,17 +176,17 @@ jobs: attestations: write artifact-metadata: write steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 - name: Download all runtime artifacts - uses: actions/download-artifact@v4 + uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1 with: pattern: vm-runtime-* path: release/ merge-multiple: true - name: Attest VM runtime artifacts - uses: actions/attest@v4 + uses: actions/attest@59d89421af93a897026c735860bf21b6eb4f7b26 # v4 with: subject-path: | release/vm-runtime-linux-aarch64.tar.zst @@ -201,7 +201,7 @@ jobs: git push --force origin vm-runtime - name: Prune stale runtime assets from vm-runtime release - uses: actions/github-script@v9 + uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9 with: script: | const [owner, repo] = process.env.GITHUB_REPOSITORY.split('/'); @@ -224,7 +224,7 @@ jobs: } - name: Create / update vm-runtime GitHub Release - uses: softprops/action-gh-release@v2 + uses: softprops/action-gh-release@b4309332981a82ec1c5618f44dd2e27cc8bfbfda # v3.0.0 with: name: OpenShell VM Runtime prerelease: true @@ -237,9 +237,9 @@ jobs: ### Kernel Runtime Artifacts - Pre-built kernel runtime (libkrunfw + libkrun + gvproxy) for embedding into - the `openshell-driver-vm` binary. These are rebuilt on demand when the kernel - config or pinned dependency versions change. + Pre-built kernel runtime (libkrunfw + libkrun + gvproxy + umoci) for embedding + into the `openshell-driver-vm` binary. These are rebuilt on demand when the + kernel config or pinned dependency versions change. 
| Platform | Artifact | |----------|----------| diff --git a/.github/workflows/required-ci-gates.yml b/.github/workflows/required-ci-gates.yml new file mode 100644 index 000000000..ca068cf5c --- /dev/null +++ b/.github/workflows/required-ci-gates.yml @@ -0,0 +1,233 @@ +name: Required CI Gates + +on: + pull_request_target: + types: [opened, synchronize, reopened, ready_for_review, labeled, unlabeled] + workflow_run: + workflows: + - Branch Checks + - Branch E2E Checks + - Helm Lint + types: [completed] + +permissions: + actions: read + contents: read + pull-requests: read + statuses: write + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.event.workflow_run.head_sha || github.run_id }} + cancel-in-progress: true + +jobs: + publish: + name: Publish required CI gate statuses + runs-on: ubuntu-latest + steps: + - name: Evaluate required CI gates + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + GH_REPO: ${{ github.repository }} + EVENT_NAME: ${{ github.event_name }} + PR_NUMBER_FROM_EVENT: ${{ github.event.pull_request.number }} + PR_HEAD_SHA_FROM_EVENT: ${{ github.event.pull_request.head.sha }} + PR_LABELS_FROM_EVENT: ${{ toJSON(github.event.pull_request.labels.*.name) }} + WORKFLOW_RUN_HEAD_SHA: ${{ github.event.workflow_run.head_sha }} + WORKFLOW_RUN_HEAD_BRANCH: ${{ github.event.workflow_run.head_branch }} + WORKFLOW_RUN_EVENT: ${{ github.event.workflow_run.event }} + shell: bash + run: | + set -euo pipefail + + post_status() { + local context="$1" + local state="$2" + local description="$3" + local target_url="${4:-}" + + args=( + --method POST + "repos/$GH_REPO/statuses/$HEAD_SHA" + -f "state=$state" + -f "context=$context" + -f "description=$description" + ) + if [ -n "$target_url" ]; then + args+=(-f "target_url=$target_url") + fi + + echo "$context: $state - $description" + gh api "${args[@]}" >/dev/null + } + + has_label() { + local label="$1" + jq -e --arg label "$label" 'index($label) != null' <<< "$LABELS_JSON" 
>/dev/null + } + + resolve_pull_request_event() { + PR_NUMBER="$PR_NUMBER_FROM_EVENT" + HEAD_SHA="$PR_HEAD_SHA_FROM_EVENT" + LABELS_JSON=$(jq -c . <<< "$PR_LABELS_FROM_EVENT") + } + + load_pr_context() { + PR_NUMBER="$1" + + local pr state + pr=$(gh api "repos/$GH_REPO/pulls/$PR_NUMBER") + state=$(jq -r '.state' <<< "$pr") + if [ "$state" != "open" ]; then + echo "PR #$PR_NUMBER is $state; nothing to publish." + exit 0 + fi + + HEAD_SHA=$(jq -r '.head.sha' <<< "$pr") + LABELS_JSON=$(gh api "repos/$GH_REPO/issues/$PR_NUMBER" --jq '[.labels[].name]') + } + + resolve_workflow_run_event() { + if [ "$WORKFLOW_RUN_EVENT" != "push" ]; then + echo "Ignoring workflow_run from event '$WORKFLOW_RUN_EVENT'." + exit 0 + fi + + if [[ "$WORKFLOW_RUN_HEAD_BRANCH" =~ ^pull-request/([0-9]+)$ ]]; then + load_pr_context "${BASH_REMATCH[1]}" + return + fi + + local associated_prs pr + associated_prs=$(gh api "repos/$GH_REPO/commits/$WORKFLOW_RUN_HEAD_SHA/pulls") + pr=$(jq -c 'map(select(.state == "open"))[0] // empty' <<< "$associated_prs") + if [ -z "$pr" ]; then + echo "No open PR associated with $WORKFLOW_RUN_HEAD_SHA; nothing to publish." + exit 0 + fi + + load_pr_context "$(jq -r '.number' <<< "$pr")" + } + + resolve_context() { + if [ "$EVENT_NAME" = "pull_request_target" ]; then + resolve_pull_request_event + elif [ "$EVENT_NAME" = "workflow_run" ]; then + resolve_workflow_run_event + else + echo "Unsupported event '$EVENT_NAME'." 
+ exit 1 + fi + + PR_URL="https://github.com/$GH_REPO/pull/$PR_NUMBER" + MIRROR_REF="pull-request/$PR_NUMBER" + } + + verify_mirror() { + local context="$1" + local mirror_sha + + mirror_sha=$(gh api "repos/$GH_REPO/branches/$MIRROR_REF" --jq '.commit.sha' 2>/dev/null || true) + if [ -z "$mirror_sha" ]; then + post_status "$context" pending "Waiting for /ok to test mirror" "$PR_URL" + return 1 + fi + + if [ "$mirror_sha" != "$HEAD_SHA" ]; then + post_status "$context" pending "Waiting for /ok to test mirror" "$PR_URL" + return 1 + fi + + return 0 + } + + evaluate_workflow() { + local context="$1" + local workflow_file="$2" + local workflow_name="$3" + local required_label="${4:-}" + local required_job_name="${5:-}" + local workflow_url="https://github.com/$GH_REPO/actions/workflows/$workflow_file" + + if [ -n "$required_label" ] && ! has_label "$required_label"; then + post_status "$context" success "$required_label not applied" "$PR_URL" + return 0 + fi + + if ! verify_mirror "$context"; then + return 0 + fi + + local runs latest run_id status conclusion run_url real_success + runs=$(gh api "repos/$GH_REPO/actions/workflows/$workflow_file/runs?head_sha=$HEAD_SHA&event=push" --jq '.workflow_runs') + latest=$(jq -c --arg branch "$MIRROR_REF" '[.[] | select(.head_branch == $branch)] | sort_by(.created_at) | reverse | .[0] // empty' <<< "$runs") + + if [ -z "$latest" ]; then + post_status "$context" pending "Waiting for $workflow_name" "$workflow_url" + return 0 + fi + + run_id=$(jq -r '.id' <<< "$latest") + status=$(jq -r '.status' <<< "$latest") + conclusion=$(jq -r '.conclusion' <<< "$latest") + run_url=$(jq -r '.html_url' <<< "$latest") + + if [ "$status" != "completed" ]; then + post_status "$context" pending "$workflow_name is $status" "$run_url" + return 0 + fi + + if [ -n "$required_job_name" ]; then + local jobs required_job job_status job_conclusion + jobs=$(gh api "repos/$GH_REPO/actions/runs/$run_id/jobs?per_page=100" --jq '.jobs') + required_job=$(jq -c 
--arg name "$required_job_name" '[.[] | select(.name == $name)] | .[0] // empty' <<< "$jobs") + + if [ -z "$required_job" ]; then + if [ "$conclusion" = "success" ]; then + post_status "$context" pending "Waiting for $required_job_name" "$run_url" + else + post_status "$context" failure "$required_job_name did not run" "$run_url" + fi + return 0 + fi + + job_status=$(jq -r '.status' <<< "$required_job") + job_conclusion=$(jq -r '.conclusion' <<< "$required_job") + + if [ "$job_status" != "completed" ]; then + post_status "$context" pending "$required_job_name is $job_status" "$run_url" + return 0 + fi + + if [ "$job_conclusion" = "success" ]; then + post_status "$context" success "$required_job_name passed" "$run_url" + elif [ "$job_conclusion" = "skipped" ] && [ "$conclusion" = "success" ]; then + post_status "$context" pending "Waiting for $required_job_name" "$run_url" + else + post_status "$context" failure "$required_job_name concluded $job_conclusion" "$run_url" + fi + return 0 + fi + + if [ "$conclusion" != "success" ]; then + post_status "$context" failure "$workflow_name concluded $conclusion" "$run_url" + return 0 + fi + + real_success=$(gh api "repos/$GH_REPO/actions/runs/$run_id/jobs?per_page=100" \ + --jq '[.jobs[] | select(.conclusion == "success" and .name != "Resolve PR metadata")] | length') + + if [ "$real_success" -lt 1 ]; then + post_status "$context" failure "No real CI jobs ran" "$run_url" + return 0 + fi + + post_status "$context" success "$workflow_name passed" "$run_url" + } + + resolve_context + + evaluate_workflow "OpenShell / Branch Checks" "branch-checks.yml" "Branch Checks" + evaluate_workflow "OpenShell / E2E" "branch-e2e.yml" "Branch E2E Checks" "test:e2e" "Core E2E result" + evaluate_workflow "OpenShell / GPU E2E" "branch-e2e.yml" "Branch E2E Checks" "test:e2e-gpu" "GPU E2E result" + evaluate_workflow "OpenShell / Helm Lint" "helm-lint.yml" "Helm Lint" diff --git a/.github/workflows/rpm-package.yml 
b/.github/workflows/rpm-package.yml index e0607c3ff..078563c4e 100644 --- a/.github/workflows/rpm-package.yml +++ b/.github/workflows/rpm-package.yml @@ -37,9 +37,15 @@ jobs: matrix: include: - arch: x86_64 + artifact_arch: amd64 runner: linux-amd64-cpu8 + cli_target: x86_64-unknown-linux-musl + gnu_target: x86_64-unknown-linux-gnu - arch: aarch64 + artifact_arch: arm64 runner: linux-arm64-cpu8 + cli_target: aarch64-unknown-linux-musl + gnu_target: aarch64-unknown-linux-gnu runs-on: ${{ matrix.runner }} timeout-minutes: 60 container: @@ -54,11 +60,31 @@ jobs: pandoc python3-devel git-core \ cargo-rpm-macros - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 with: ref: ${{ inputs.checkout-ref }} fetch-depth: 0 + - name: Download CLI artifact + uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1 + with: + name: cli-linux-${{ matrix.artifact_arch }} + path: package-input/ + + - name: Download gateway artifact + uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1 + with: + name: gateway-binary-linux-${{ matrix.artifact_arch }} + path: package-input/ + + - name: Extract package inputs + run: | + set -euo pipefail + mkdir -p package-binaries + tar -xzf "package-input/openshell-${{ matrix.cli_target }}.tar.gz" -C package-binaries + tar -xzf "package-input/openshell-gateway-${{ matrix.gnu_target }}.tar.gz" -C package-binaries + ls -lah package-binaries + - name: Mark workspace safe for git run: git config --global --add safe.directory "$GITHUB_WORKSPACE" @@ -70,6 +96,7 @@ jobs: OPENSHELL_RPM_VERSION: ${{ inputs['rpm-version'] }} OPENSHELL_RPM_RELEASE: ${{ inputs['rpm-release'] }} OPENSHELL_CARGO_VERSION: ${{ inputs['cargo-version'] }} + OPENSHELL_PREBUILT_BINARIES_DIR: ${{ github.workspace }}/package-binaries run: packit build locally - name: Collect RPM artifacts @@ -87,7 +114,7 @@ jobs: ls -lah artifacts/ - name: Upload RPM artifacts - uses: 
actions/upload-artifact@v7 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7 with: name: rpm-linux-${{ matrix.arch }} path: artifacts/*.rpm diff --git a/.github/workflows/rust-cache-seed.yml b/.github/workflows/rust-cache-seed.yml new file mode 100644 index 000000000..3b369c611 --- /dev/null +++ b/.github/workflows/rust-cache-seed.yml @@ -0,0 +1,72 @@ +name: Rust Cache Seed + +on: + push: + branches: + - main + workflow_dispatch: + +env: + CARGO_TERM_COLOR: always + CARGO_INCREMENTAL: "0" + MISE_GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + SCCACHE_GHA_ENABLED: "true" + +permissions: + contents: read + packages: read + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + rust: + name: Rust (${{ matrix.runner }}) + strategy: + fail-fast: false + matrix: + runner: [linux-amd64-cpu8, linux-arm64-cpu8] + runs-on: ${{ matrix.runner }} + env: + SCCACHE_GHA_VERSION: branch-checks-rust-${{ matrix.runner }} + container: + image: ghcr.io/nvidia/openshell/ci:latest + credentials: + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + steps: + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 + + - name: Install tools + run: mise install --locked + + - name: Configure GHA sccache backend + uses: mozilla-actions/sccache-action@9e7fa8a12102821edf02ca5dbea1acd0f89a2696 # v0.0.10 + + - name: Cache Rust target and registry + uses: Swatinem/rust-cache@c19371144df3bb44fab255c43d04cbc2ab54d1c4 # v2 + with: + shared-key: rust-checks-${{ matrix.runner }} + cache-on-failure: true + + - name: Format + run: mise run rust:format:check + + - name: Lint + run: mise run rust:lint + + - name: Test + run: mise run test:rust + + - name: sccache stats + if: always() + run: | + set +e + stats_bin="${SCCACHE_PATH:-sccache}" + "$stats_bin" --show-stats + status=$? 
+ if [ "$status" -ne 0 ]; then + echo "::warning::sccache stats unavailable (exit $status)" + fi + exit 0 diff --git a/.github/workflows/shadow-rust-native-build.yml b/.github/workflows/rust-native-build.yml similarity index 64% rename from .github/workflows/shadow-rust-native-build.yml rename to .github/workflows/rust-native-build.yml index b943a1ddb..682d5eb88 100644 --- a/.github/workflows/shadow-rust-native-build.yml +++ b/.github/workflows/rust-native-build.yml @@ -1,10 +1,14 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -name: Rust Native Build (openshell-gateway / openshell-sandbox) +name: Rust Image Binary Build (openshell-gateway / openshell-sandbox) -# OS-128 Phase 4: build Rust binaries natively per Linux architecture before -# the Docker image build consumes them as prebuilt artifacts. +# Build Rust binaries per Linux architecture before the Docker image build +# consumes them as prebuilt artifacts. Gateway images use GNU-linked binaries +# for the NVIDIA distroless C/C++ runtime; supervisor images use musl/static +# binaries so the final image can remain scratch. Gateway GNU binaries are +# built with an explicit glibc 2.31 floor so image, package, and tarball +# artifacts share the same host portability contract. 
on: workflow_call: @@ -42,50 +46,11 @@ on: required: false type: string default: "" - workflow_dispatch: - inputs: - component: - description: "Binary component to build" - required: true - type: choice - default: gateway - options: - - gateway - - sandbox - arch: - description: "Linux architecture to build" - required: true - type: choice - default: amd64 - options: - - amd64 - - arm64 - cargo-version: - description: "Cargo version override" - required: false - type: string - default: "" - features: - description: "Cargo features to enable" - required: false - type: string - default: "openshell-core/dev-settings" - retention-days: - description: "Artifact retention period" - required: false - type: number - default: 5 - artifact-name: - description: "Artifact name override" - required: false - type: string - default: "" - checkout-ref: - description: "Git ref to check out for build inputs (defaults to the workflow SHA)" + image-tag: + description: "Supervisor image tag to bake into gateway binaries" required: false type: string default: "" - permissions: contents: read packages: read @@ -113,7 +78,7 @@ jobs: FEATURES: ${{ inputs.features }} # Partition the GHA sccache cache per (component, arch). Without this, # concurrent jobs collide on the same cache key and later-starting - # writers hit 409 Conflict (PR #961 fix for shadow-shared-cpu-spike). + # writers hit 409 Conflict. 
SCCACHE_GHA_VERSION: ${{ inputs.component }}-${{ inputs.arch }} container: image: ghcr.io/nvidia/openshell/ci:latest @@ -121,7 +86,7 @@ jobs: username: ${{ github.actor }} password: ${{ secrets.GITHUB_TOKEN }} steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 with: ref: ${{ inputs['checkout-ref'] || github.sha }} fetch-depth: 0 @@ -132,6 +97,11 @@ jobs: - name: Fetch tags run: git fetch --tags --force + - name: Configure GHA sccache backend + # Exposes ACTIONS_CACHE_URL / ACTIONS_RUNTIME_TOKEN before `mise install` + # compiles cargo-installed tools through RUSTC_WRAPPER=sccache. + uses: mozilla-actions/sccache-action@9e7fa8a12102821edf02ca5dbea1acd0f89a2696 # v0.0.10 + - name: Install tools run: mise install --locked @@ -144,10 +114,12 @@ jobs: gateway) crate=openshell-server binary=openshell-gateway + zig_target= ;; sandbox) crate=openshell-sandbox binary=openshell-sandbox + zig_target= ;; *) echo "unsupported component: $COMPONENT" >&2 @@ -157,10 +129,22 @@ jobs: case "$ARCH" in amd64) - target=x86_64-unknown-linux-gnu + if [[ "$COMPONENT" == "sandbox" ]]; then + target=x86_64-unknown-linux-musl + zig_target=x86_64-linux-musl + else + target=x86_64-unknown-linux-gnu + zig_target=x86_64-unknown-linux-gnu.2.31 + fi ;; arm64) - target=aarch64-unknown-linux-gnu + if [[ "$COMPONENT" == "sandbox" ]]; then + target=aarch64-unknown-linux-musl + zig_target=aarch64-linux-musl + else + target=aarch64-unknown-linux-gnu + zig_target=aarch64-unknown-linux-gnu.2.31 + fi ;; *) echo "unsupported arch: $ARCH" >&2 @@ -172,13 +156,9 @@ jobs: echo "crate=$crate" echo "binary=$binary" echo "target=$target" + echo "zig_target=$zig_target" } >> "$GITHUB_OUTPUT" - - name: Configure GHA sccache backend - # Exposes ACTIONS_CACHE_URL / ACTIONS_RUNTIME_TOKEN so sccache (wrapped - # around rustc via mise's RUSTC_WRAPPER) can initialize the GHA cache. 
- uses: mozilla-actions/sccache-action@9e7fa8a12102821edf02ca5dbea1acd0f89a2696 # v0.0.10 - - name: Cache Rust target and registry uses: Swatinem/rust-cache@c19371144df3bb44fab255c43d04cbc2ab54d1c4 # v2 with: @@ -202,16 +182,48 @@ jobs: set -euo pipefail sed -i -E '/^\[workspace\.package\]/,/^\[/{s/^version[[:space:]]*=[[:space:]]*".*"/version = "'"${{ steps.version.outputs.cargo_version }}"'"/}' Cargo.toml - - name: Build ${{ steps.target.outputs.binary }} (${{ steps.target.outputs.target }}) + - name: Set up zig musl wrappers + if: contains(steps.target.outputs.target, 'musl') + run: | + set -euo pipefail + ZIG="$(mise which zig)" + ZIG_TARGET="${{ steps.target.outputs.zig_target }}" + mkdir -p /tmp/zig-musl + + # cc-rs injects --target=, which zig does not parse. + # Strip caller-provided --target and use the wrapper's zig target. + for tool in cc c++; do + printf '#!/bin/bash\nargs=()\nfor arg in "$@"; do\n case "$arg" in\n --target=*) ;;\n *) args+=("$arg") ;;\n esac\ndone\nexec "%s" %s --target=%s "${args[@]}"\n' \ + "$ZIG" "$tool" "$ZIG_TARGET" > "/tmp/zig-musl/${tool}" + chmod +x "/tmp/zig-musl/${tool}" + done + + TARGET_ENV=$(echo "${{ steps.target.outputs.target }}" | tr '-' '_') + TARGET_ENV_UPPER=${TARGET_ENV^^} + + echo "CC_${TARGET_ENV}=/tmp/zig-musl/cc" >> "$GITHUB_ENV" + echo "CXX_${TARGET_ENV}=/tmp/zig-musl/c++" >> "$GITHUB_ENV" + echo "CARGO_TARGET_${TARGET_ENV_UPPER}_LINKER=/tmp/zig-musl/cc" >> "$GITHUB_ENV" + echo "CARGO_TARGET_${TARGET_ENV_UPPER}_RUSTFLAGS=-Clink-self-contained=no" >> "$GITHUB_ENV" + + - name: Build ${{ steps.target.outputs.binary }} (${{ steps.target.outputs.zig_target || steps.target.outputs.target }}) env: # Preserve the release-codegen setting used by the old Dockerfile # Rust build path so image artifacts keep the same release profile. 
CARGO_PROFILE_RELEASE_CODEGEN_UNITS: "1" + OPENSHELL_IMAGE_TAG: ${{ inputs['image-tag'] }} run: | set -euo pipefail + mise x -- rustup target add "${{ steps.target.outputs.target }}" + cargo_cmd=(cargo build) + build_target="${{ steps.target.outputs.target }}" + if [[ "${{ inputs.component }}" == "gateway" ]]; then + cargo_cmd=(cargo zigbuild) + build_target="${{ steps.target.outputs.zig_target }}" + fi args=( --release - --target "${{ steps.target.outputs.target }}" + --target "$build_target" -p "${{ steps.target.outputs.crate }}" --bin "${{ steps.target.outputs.binary }}" ) @@ -221,7 +233,7 @@ jobs: if [[ -n "${{ steps.version.outputs.cargo_version }}" ]]; then export GIT_DIR=/nonexistent fi - mise x -- cargo build "${args[@]}" + mise x -- "${cargo_cmd[@]}" "${args[@]}" - name: Verify packaged binary run: | @@ -230,11 +242,17 @@ jobs: OUTPUT="$("$BIN" --version)" echo "$OUTPUT" grep -q "^${{ steps.target.outputs.binary }} " <<<"$OUTPUT" - # Record glibc linkage so drift from the Ubuntu noble runtime base - # image is visible in logs. + # Record linkage so image runtime drift is visible in logs. 
ldd --version ldd "$BIN" || true + - name: Verify glibc symbol floor + if: inputs.component == 'gateway' + run: | + set -euo pipefail + BIN="target/${{ steps.target.outputs.target }}/release/${{ steps.target.outputs.binary }}" + tasks/scripts/verify-glibc-symbols.sh 2.31 "$BIN" + - name: Stage binary for prebuilt layout run: | set -euo pipefail @@ -246,7 +264,7 @@ jobs: ls -lh "$STAGE/" - name: Upload artifact - uses: actions/upload-artifact@v7 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7 with: name: ${{ inputs['artifact-name'] != '' && inputs['artifact-name'] || format('rust-binary-{0}-linux-{1}', inputs.component, inputs.arch) }} path: prebuilt-binaries/${{ inputs.arch }}/${{ steps.target.outputs.binary }} diff --git a/.github/workflows/shadow-docker-build.yml b/.github/workflows/shadow-docker-build.yml deleted file mode 100644 index 3c7642ab3..000000000 --- a/.github/workflows/shadow-docker-build.yml +++ /dev/null @@ -1,35 +0,0 @@ -name: Shadow Docker Build - -# OS-128 Phase 4: manual non-publishing exercise of the production Docker -# image workflow. This stays off main's push surface because the image path is -# not a required signal while the prebuilt-binary rollout is being measured. 
- -on: - workflow_dispatch: - inputs: - platform: - description: "Target platform(s)" - required: false - type: string - default: "linux/amd64,linux/arm64" - -permissions: - contents: read - packages: write - -jobs: - gateway: - uses: ./.github/workflows/docker-build.yml - with: - component: gateway - platform: ${{ inputs.platform }} - push: false - secrets: inherit - - supervisor: - uses: ./.github/workflows/docker-build.yml - with: - component: supervisor - platform: ${{ inputs.platform }} - push: false - secrets: inherit diff --git a/.github/workflows/shadow-shared-cpu-spike.yml b/.github/workflows/shadow-shared-cpu-spike.yml deleted file mode 100644 index 5a072c8e1..000000000 --- a/.github/workflows/shadow-shared-cpu-spike.yml +++ /dev/null @@ -1,73 +0,0 @@ -name: Shadow — Shared CPU Spike - -# OS-49 Phase 2 / PR 1 — non-blocking spike. Runs `cargo check` on the -# nv-gha-runners shared CPU pool (`linux-{amd64,arm64}-cpu8`) with a -# GHA-backed sccache. -# -# Plan, decision thresholds, and results live in the Linear doc attached -# to OS-126 ("OS-126 — Shared CPU spike plan & results"). Dispatch this -# workflow 4–5 times after merge and record numbers there. - -on: - workflow_dispatch: - -permissions: - contents: read - packages: read - -env: - CARGO_TERM_COLOR: always - CARGO_INCREMENTAL: "0" - MISE_GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - # Wire sccache (already RUSTC_WRAPPER in mise.toml) to the GHA cache - # backend instead of the EKS memcached used by ARC. - SCCACHE_GHA_ENABLED: "true" - -jobs: - rust-check: - name: cargo check (${{ matrix.runner }}) - strategy: - fail-fast: false - matrix: - runner: [linux-amd64-cpu8, linux-arm64-cpu8] - runs-on: ${{ matrix.runner }} - env: - # Partition the GHA sccache cache per-arch. Without this, matrix jobs - # collide on the same cache key, and the later-starting job's writes - # fail with 409 Conflict while the earlier one's writes land — leaving - # subsequent runs with asymmetric and mostly-empty caches. 
- # (Run 1 on 2026-04-24 showed amd64 with 0/1062 writes succeeding - # while arm64 got 575/1064.) - SCCACHE_GHA_VERSION: ${{ matrix.runner }} - container: - image: ghcr.io/nvidia/openshell/ci:latest - credentials: - username: ${{ github.actor }} - password: ${{ secrets.GITHUB_TOKEN }} - timeout-minutes: 60 - steps: - - uses: actions/checkout@v6 - - - name: Install tools - run: mise install - - - name: Configure GHA sccache backend - # Exposes ACTIONS_CACHE_URL / ACTIONS_RUNTIME_TOKEN to subsequent steps - # so sccache (wrapped around rustc via RUSTC_WRAPPER in mise.toml) can - # initialize the GHA cache. Without this, sccache fails at startup with - # "cache url for ghac not found". The action also installs its own - # sccache binary; harmless since mise's sccache remains on PATH. - uses: mozilla-actions/sccache-action@9e7fa8a12102821edf02ca5dbea1acd0f89a2696 # v0.0.10 - - - name: Cache Rust target and registry - uses: Swatinem/rust-cache@c19371144df3bb44fab255c43d04cbc2ab54d1c4 # v2 - with: - shared-key: shadow-shared-cpu-spike-${{ matrix.runner }} - cache-directories: .cache/sccache - - - name: cargo check - run: mise x -- cargo check --workspace --all-targets - - - name: sccache stats - if: always() - run: mise x -- sccache --show-stats diff --git a/.github/workflows/test-gpu.yml b/.github/workflows/test-gpu.yml deleted file mode 100644 index 37fdcbb94..000000000 --- a/.github/workflows/test-gpu.yml +++ /dev/null @@ -1,46 +0,0 @@ -name: GPU Test - -on: - push: - branches: - - "pull-request/[0-9]+" - workflow_dispatch: {} - # Add `schedule:` here when we want nightly coverage from the same workflow. 
- -permissions: {} - -jobs: - pr_metadata: - name: Resolve PR metadata - runs-on: ubuntu-latest - permissions: - contents: read - pull-requests: read - outputs: - should_run: ${{ steps.gate.outputs.should_run }} - steps: - - uses: actions/checkout@v6 - - id: gate - uses: ./.github/actions/pr-gate - with: - required_label: test:e2e-gpu - - build-supervisor: - needs: [pr_metadata] - if: needs.pr_metadata.outputs.should_run == 'true' - permissions: - contents: read - packages: write - uses: ./.github/workflows/docker-build.yml - with: - component: supervisor - - e2e-gpu: - needs: [pr_metadata, build-supervisor] - if: needs.pr_metadata.outputs.should_run == 'true' - permissions: - contents: read - packages: read - uses: ./.github/workflows/e2e-gpu-test.yaml - with: - image-tag: ${{ github.sha }} diff --git a/.github/workflows/test-install.yml b/.github/workflows/test-install.yml deleted file mode 100644 index 06b1e007f..000000000 --- a/.github/workflows/test-install.yml +++ /dev/null @@ -1,51 +0,0 @@ -name: Test Install Script - -on: - pull_request: - paths: - - 'install.sh' - - 'e2e/install/**' - - '.github/workflows/test-install.yml' - push: - branches: [main] - paths: - - 'install.sh' - - 'e2e/install/**' - - '.github/workflows/test-install.yml' - workflow_dispatch: - -permissions: - contents: read - -jobs: - test-install: - name: install.sh (${{ matrix.shell }}) - runs-on: ubuntu-latest - strategy: - fail-fast: false - matrix: - include: - - shell: sh - test: e2e/install/sh_test.sh - run: sh e2e/install/sh_test.sh - - shell: bash - test: e2e/install/bash_test.sh - run: bash e2e/install/bash_test.sh - - shell: zsh - test: e2e/install/zsh_test.sh - run: zsh e2e/install/zsh_test.sh - install: zsh - - shell: fish - test: e2e/install/fish_test.fish - run: fish e2e/install/fish_test.fish - install: fish - - steps: - - uses: actions/checkout@v6 - - - name: Install ${{ matrix.shell }} - if: matrix.install - run: sudo apt-get update && sudo apt-get install -y ${{ 
matrix.install }} - - - name: Run tests (${{ matrix.shell }}) - run: ${{ matrix.run }} diff --git a/.github/workflows/vouch-check.yml b/.github/workflows/vouch-check.yml index db7a540eb..2eeeb949f 100644 --- a/.github/workflows/vouch-check.yml +++ b/.github/workflows/vouch-check.yml @@ -18,7 +18,7 @@ jobs: - name: Check org membership id: org-check if: env.ORG_READ_TOKEN != '' - uses: actions/github-script@v9 + uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9 with: github-token: ${{ secrets.ORG_READ_TOKEN }} result-encoding: string @@ -42,7 +42,7 @@ jobs: - name: Check if contributor is vouched if: steps.org-check.outputs.result != 'skip' - uses: actions/github-script@v9 + uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9 with: script: | const author = context.payload.pull_request.user.login; diff --git a/.github/workflows/vouch-command.yml b/.github/workflows/vouch-command.yml index 309a4ae36..366dd6a0e 100644 --- a/.github/workflows/vouch-command.yml +++ b/.github/workflows/vouch-command.yml @@ -20,7 +20,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Process /vouch command - uses: actions/github-script@v9 + uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9 with: script: | const commenter = context.payload.comment.user.login; diff --git a/.gitignore b/.gitignore index 1b37bfd49..24a77fce2 100644 --- a/.gitignore +++ b/.gitignore @@ -196,6 +196,9 @@ artifacts/ # Local mise settings mise.local.toml +# Local Codex app state +.codex/ + # Ignore plans for now architecture/plans @@ -212,5 +215,9 @@ rfc.md *.tar.xz *.tar.bz2 +# Snap build artifacts +*.snap +*.comp + # Markdown/mermaid lint tooling deps scripts/lint-mermaid/node_modules/ diff --git a/.markdownlint-cli2.jsonc b/.markdownlint-cli2.jsonc index 4c7f68e5a..125df0f81 100644 --- a/.markdownlint-cli2.jsonc +++ b/.markdownlint-cli2.jsonc @@ -1,14 +1,25 @@ { "globs": [ - "**/*.md", - "**/*.mdx" + "*.md", + "architecture/**/*.md", + 
"crates/**/*.md", + "deploy/**/*.md", + "docs/**/*.md", + "docs/**/*.mdx", + "e2e/**/*.md", + "examples/**/*.md", + "rfc/**/*.md" ], - "gitignore": true, + "gitignore": false, "ignores": [ ".agents/**", ".claude/**", ".opencode/**", ".github/**", + "architecture/plans/**", + "**/node_modules/**", + "target/**", + ".pytest_cache/**", "THIRD-PARTY-NOTICES/**", "CLAUDE.md", // Man page sources use pandoc markdown with multiple H1 sections diff --git a/.packit.yaml b/.packit.yaml index 3a608111b..5d1f65063 100644 --- a/.packit.yaml +++ b/.packit.yaml @@ -35,12 +35,10 @@ actions: - 'bash -c "echo openshell-${PACKIT_PROJECT_VERSION}.tar.gz"' fix-spec-file: - # Update Source0 to the generated tarball name - - 'bash -c "sed -i \"s|^Source0:.*|Source0: openshell-${PACKIT_PROJECT_VERSION}.tar.gz|\" openshell.spec"' - # Update Source1 to the generated vendor tarball name - - 'bash -c "sed -i \"s|^Source1:.*|Source1: openshell-${PACKIT_PROJECT_VERSION}-vendor.tar.xz|\" openshell.spec"' - # Update Version - - 'bash -c "sed -i -r \"s/^Version:(\\s*)\\S+/Version:\\1${PACKIT_RPMSPEC_VERSION}/\" openshell.spec"' + # Update the canonical version macro. Version:, Source0:, Source1:, and all + # other version references expand from %{openshell_version} so only this + # one line needs updating. + - 'bash -c "sed -i -r \"s/^%global openshell_version .*/%global openshell_version ${PACKIT_RPMSPEC_VERSION}/\" openshell.spec"' # Update Release - 'bash -c "RELEASE=${OPENSHELL_RPM_RELEASE:-${PACKIT_RPMSPEC_RELEASE}} && sed -i -r \"s/^Release:(\\s*)\\S+/Release:\\1${RELEASE}%{?dist}/\" openshell.spec"' # Keep embedded binary metadata aligned with the release workflow. Python diff --git a/CI.md b/CI.md index 57e6627ed..a7ca79c9d 100644 --- a/CI.md +++ b/CI.md @@ -8,14 +8,17 @@ For local test commands see [TESTING.md](TESTING.md). For PR conventions see [CO PR CI that runs on NVIDIA self-hosted runners uses NVIDIA's copy-pr-bot. 
The bot mirrors trusted PR commits to internal `pull-request/` branches in this repository. The gated workflows trigger on pushes to those branches, not on the original PR. -`Branch Checks` run automatically after copy-pr-bot mirrors the PR. E2E suites are opt-in because they are more expensive and publish temporary images. +`Branch Checks` run automatically after copy-pr-bot mirrors the PR. `Required CI Gates` posts PR-head statuses that verify the mirror exists, is current, and ran the expected push-based workflows. E2E suites are opt-in because they are more expensive and publish temporary images. -Two opt-in labels enable the suites: +Two opt-in labels enable the long-running E2E suites: -- `test:e2e` runs `Branch E2E Checks` (non-GPU E2E) -- `test:e2e-gpu` runs `GPU Test` +- `test:e2e` runs the standard E2E suite in `Branch E2E Checks` +- `test:e2e-gpu` runs GPU E2E in `Branch E2E Checks` -Both are required to merge once the corresponding `E2E Gate` checks are marked required in branch protection. +When both labels are present, `Branch E2E Checks` builds the shared gateway and supervisor images once and fans out all enabled suites in parallel. +The `OpenShell / E2E` and `OpenShell / GPU E2E` required statuses are evaluated from separate suite result jobs inside that workflow, so the expensive GPU suite stays independently gated. + +The GitHub ruleset should require the `OpenShell / ...` statuses published by `Required CI Gates`, not the push-triggered workflow jobs directly. ## Commit signing @@ -65,11 +68,11 @@ Prerequisites: Flow: 1. Open the PR. copy-pr-bot mirrors it to `pull-request/` automatically. -2. The mirror push runs `Branch Checks` automatically. The first `Branch E2E Checks` / `GPU Test` run only resolves metadata and skips expensive jobs unless the matching label is already set. +2. The mirror push runs `Branch Checks` automatically. 
`Required CI Gates` keeps the PR blocked until the mirror exists, matches the PR head SHA, and the required push-based workflow succeeds. The first `Branch E2E Checks` run only resolves metadata and skips expensive jobs unless an E2E label is already set. 3. A maintainer applies `test:e2e` and/or `test:e2e-gpu`. `E2E Label Help` posts a comment with a link to the existing gated workflow run. 4. The maintainer opens that link and clicks **Re-run all jobs**. This time `pr_metadata` sees the label and the build/E2E jobs run. -5. When the run finishes, the `E2E Gate` check on the PR flips to green automatically. -6. New commits push to the mirror automatically and re-trigger `Branch Checks` plus any labeled E2E/GPU workflows. +5. When the run finishes, the matching `OpenShell / ...` gate status flips to green automatically. +6. New commits push to the mirror automatically and re-trigger `Branch Checks` plus any labeled E2E jobs in `Branch E2E Checks`. ### Forked PR @@ -82,9 +85,9 @@ Flow: 1. Open the PR. The vouch check confirms you are vouched (otherwise the PR is auto-closed). 2. copy-pr-bot does not mirror forks automatically. A maintainer reviews the diff and comments `/ok to test ` with your latest commit SHA. -3. After `/ok to test`, copy-pr-bot mirrors to `pull-request/`. From here the flow is identical to internal PRs: maintainer applies the label, follows the comment from `E2E Label Help`, and re-runs the workflow. +3. After `/ok to test`, copy-pr-bot mirrors to `pull-request/`. From here the flow is identical to internal PRs: `Required CI Gates` verifies the mirror and required push workflows, and maintainers apply the E2E label when the extra suites are needed. -Important: every new commit you push requires another `/ok to test ` from a maintainer before E2E will run on it. If a label is applied while the mirror is stale, `E2E Label Help` will post a comment explaining what's needed. 
+Important: every new commit you push requires another `/ok to test ` from a maintainer before push-based CI will run on it. If a label is applied while the mirror is stale, `E2E Label Help` will post a comment explaining what's needed. ## copy-pr-bot @@ -105,9 +108,30 @@ The bot's full administrator documentation is internal to NVIDIA. The only comma | File | Role | |---|---| | `.github/workflows/branch-checks.yml` | Required non-E2E PR checks. Triggers on `push: pull-request/[0-9]+`. | -| `.github/workflows/branch-e2e.yml` | Non-GPU E2E. Triggers on `push: pull-request/[0-9]+`. | -| `.github/workflows/test-gpu.yml` | GPU E2E. Triggers on `push: pull-request/[0-9]+`. | +| `.github/workflows/branch-e2e.yml` | Opt-in standard and GPU E2E. Triggers on `push: pull-request/[0-9]+` and runs jobs selected by `test:e2e` / `test:e2e-gpu`. | +| `.github/workflows/helm-lint.yml` | Helm chart validation. Triggers on `push: pull-request/[0-9]+` and skips lint jobs unless Helm inputs changed. | | `.github/actions/pr-gate/action.yml` | Composite action that resolves PR metadata and verifies the required label is set. | -| `.github/workflows/e2e-gate.yml` | Posts the required `E2E Gate` check on the PR. Re-evaluates after the gated workflow completes. | -| `.github/workflows/e2e-gate-check.yml` | Reusable gate logic shared by E2E and GPU E2E. | +| `.github/actions/pr-merge-base/action.yml` | Composite action that resolves and fetches the merge-base commit for `pull-request/` push workflows. | +| `.github/workflows/required-ci-gates.yml` | Posts required PR-head statuses for push-based CI workflows. This is what branch protection should require. | | `.github/workflows/e2e-label-help.yml` | When a `test:e2e*` label is applied, posts a PR comment telling the maintainer the next manual step (re-run an existing workflow run, or `/ok to test ` to refresh the mirror). | + +## Release workflows + +These workflows run after merge to publish dev/tagged artifacts and verify them. 
They are not PR-gated. + +| File | Role | +|---|---| +| `.github/workflows/release-dev.yml` | Publishes the rolling `dev` build on every push to `main`. Builds gateway/supervisor images and binaries, packages, wheels, and pushes the Helm chart as `oci://ghcr.io/nvidia/openshell/helm-chart:0.0.0-dev` (plus an immutable `0.0.0-dev.` pin). Also dispatchable manually. | +| `.github/workflows/release-tag.yml` | Publishes a tagged public release. | +| `.github/workflows/release-canary.yml` | Smoke-tests published artifacts on `macos`, `ubuntu`, `fedora`, and `kubernetes` (kind + Helm) runners. Triggers automatically when `Release Dev` succeeds, and via `workflow_dispatch` on any branch (`gh workflow run release-canary.yml --ref `). The `kubernetes` job pins to `0.0.0-dev` artifacts; the other jobs install the latest tagged release via `install.sh`. See the `test-release-canary` skill for the manual-dispatch playbook and local kind reproduction. | + +## Required status contexts + +Require these statuses in the branch ruleset for push-based CI: + +- `OpenShell / Branch Checks` +- `OpenShell / E2E` +- `OpenShell / GPU E2E` +- `OpenShell / Helm Lint` + +Do not require the underlying push workflow jobs directly. Those jobs only appear after copy-pr-bot mirrors trusted code, so they cannot independently prove that an untrusted or stale PR head was tested. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 5c091c6c4..0f42d3469 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -75,6 +75,7 @@ Skills live in `.agents/skills/`. 
Your agent's harness can discover and load the | Reviewing | `review-github-pr` | Summarize PR diffs and key design decisions | | Reviewing | `review-security-issue` | Assess security issues for severity and remediation | | Reviewing | `watch-github-actions` | Monitor CI pipeline status and logs | +| Reviewing | `test-release-canary` | Dispatch and iterate on the Release Canary workflow that smoke-tests published artifacts | | Triage | `triage-issue` | Assess, classify, and route community-filed issues | | Platform | `generate-sandbox-policy` | Generate YAML sandbox policies from requirements or API docs | | Platform | `tui-development` | Development guide for the ratatui-based terminal UI | @@ -125,6 +126,25 @@ Project requirements: - Docker (running) - Z3 solver library (for the policy prover crate) +### macOS build tools + +Install Apple Command Line Tools before building locally: + +```bash +xcode-select --install +``` + +If Cargo fails while building `protobuf-src` with an error such as +`fatal error: 'utility' file not found`, `fatal error: 'cstdlib' file not +found`, or `A compiler with support for C++11 language features is required`, +your Command Line Tools install may not expose the libc++ headers on the +compiler's default include path. Reinstall Command Line Tools to correct the error: + +```bash +sudo rm -rf /Library/Developer/CommandLineTools +xcode-select --install +``` + ### Z3 installation The `openshell-prover` crate links against the system Z3 library via pkg-config. 
@@ -174,15 +194,16 @@ openshell sandbox create -- codex These are the primary `mise` tasks for day-to-day development: -| Task | Purpose | -| ------------------ | ------------------------------------------------------- | -| `mise run gateway` | Run a standalone gateway for local development | -| `mise run sandbox` | Create or reconnect to the dev sandbox | -| `mise run test` | Default test suite | -| `mise run e2e` | Default end-to-end test lane | -| `mise run ci` | Full local CI checks (lint, compile/type checks, tests) | -| `mise run docs` | Validate Fern docs locally | -| `mise run clean` | Clean build artifacts | +| Task | Purpose | +| -------------------- | ------------------------------------------------------- | +| `mise run gateway` | Run a standalone gateway for local development | +| `mise run sandbox` | Create or reconnect to the dev sandbox | +| `mise run test` | Default test suite | +| `mise run e2e` | Default end-to-end test lane | +| `mise run ci` | Full local CI checks (lint, compile/type checks, tests) | +| `mise run docs` | Validate Fern docs locally | +| `mise run helm:docs` | Regenerate the Helm chart README | +| `mise run clean` | Clean build artifacts | ## Project Structure @@ -281,4 +302,4 @@ DCO sign-off is separate from cryptographic commit signing. CI requires signing ## CI -How E2E runs in CI, the `test:e2e` / `test:e2e-gpu` labels, copy-pr-bot, and commit-signing setup are documented in [CI.md](CI.md). +How PR CI runs, the `test:e2e` / `test:e2e-gpu` labels, copy-pr-bot, and commit-signing setup are documented in [CI.md](CI.md). 
diff --git a/Cargo.lock b/Cargo.lock index a1c390646..08e1d052f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -168,6 +168,45 @@ dependencies = [ "password-hash", ] +[[package]] +name = "asn1-rs" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5493c3bedbacf7fd7382c6346bbd66687d12bbaad3a89a2d2c303ee6cf20b048" +dependencies = [ + "asn1-rs-derive", + "asn1-rs-impl", + "displaydoc", + "nom", + "num-traits", + "rusticata-macros", + "thiserror 1.0.69", + "time", +] + +[[package]] +name = "asn1-rs-derive" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "965c2d33e53cb6b267e148a4cb0760bc01f4904c1cd4bb4002a085bb016d1490" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", + "synstructure", +] + +[[package]] +name = "asn1-rs-impl" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b18050c2cd6fe86c3a76584ef5e0baf286d038cda203eb6223df2cc413565f7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + [[package]] name = "assert-json-diff" version = "2.0.2" @@ -1209,6 +1248,20 @@ dependencies = [ "zeroize", ] +[[package]] +name = "der-parser" +version = "9.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5cd0a5c643689626bec213c4d8bd4d96acc8ffdb4ad4bb6bc16abf27d5f4b553" +dependencies = [ + "asn1-rs", + "displaydoc", + "nom", + "num-bigint", + "num-traits", + "rusticata-macros", +] + [[package]] name = "deranged" version = "0.5.8" @@ -3292,6 +3345,15 @@ dependencies = [ "thiserror 2.0.18", ] +[[package]] +name = "oid-registry" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8d8034d9489cdaf79228eb9f6a3b8d7bb32ba00d6645ebd48eef4077ceb5bd9" +dependencies = [ + "asn1-rs", +] + [[package]] name = "olpc-cjson" version = "0.1.4" @@ -3333,6 +3395,7 @@ dependencies = [ "rcgen", "serde", "serde_json", + "sha2 0.10.9", "tar", "tempfile", 
"tokio", @@ -3346,6 +3409,7 @@ dependencies = [ "anyhow", "base64 0.22.1", "bytes", + "chrono", "clap", "clap_complete", "crossterm 0.28.1", @@ -3373,6 +3437,7 @@ dependencies = [ "rustls-pemfile", "serde", "serde_json", + "serde_yml", "tar", "temp-env", "tempfile", @@ -3414,6 +3479,7 @@ dependencies = [ "bytes", "futures", "openshell-core", + "serde", "tar", "tempfile", "tokio", @@ -3436,6 +3502,7 @@ dependencies = [ "openshell-core", "prost", "prost-types", + "serde", "serde_json", "thiserror 2.0.18", "tokio", @@ -3457,6 +3524,7 @@ dependencies = [ "miette", "nix", "openshell-core", + "rustix 1.1.4", "serde", "serde_json", "temp-env", @@ -3484,7 +3552,9 @@ dependencies = [ "openshell-core", "openshell-vfio", "polling", + "prost", "prost-types", + "rustix 1.1.4", "serde", "serde_json", "sha2 0.10.9", @@ -3517,6 +3587,7 @@ dependencies = [ "miette", "openshell-core", "serde", + "serde_json", "serde_yml", ] @@ -3571,6 +3642,7 @@ dependencies = [ "base64 0.22.1", "bytes", "clap", + "flate2", "futures", "glob", "hex", @@ -3588,12 +3660,14 @@ dependencies = [ "rcgen", "regorus", "russh", + "rustix 1.1.4", "rustls", "rustls-pemfile", "seccompiler", "serde", "serde_json", "serde_yml", + "sha1 0.10.6", "sha2 0.10.9", "temp-env", "tempfile", @@ -3615,6 +3689,7 @@ name = "openshell-server" version = "0.0.0" dependencies = [ "anyhow", + "async-trait", "axum 0.8.9", "bytes", "clap", @@ -3630,9 +3705,12 @@ dependencies = [ "hyper-util", "ipnet", "jsonwebtoken 9.3.1", + "k8s-openapi", + "kube", "metrics", "metrics-exporter-prometheus", "miette", + "openshell-bootstrap", "openshell-core", "openshell-driver-docker", "openshell-driver-kubernetes", @@ -3649,6 +3727,7 @@ dependencies = [ "rcgen", "reqwest 0.12.28", "russh", + "rustix 1.1.4", "rustls", "rustls-pemfile", "serde", @@ -3661,13 +3740,16 @@ dependencies = [ "tokio-rustls", "tokio-stream", "tokio-tungstenite 0.26.2", + "toml", "tonic", "tower 0.5.3", "tower-http 0.6.8", "tracing", "tracing-subscriber", + "url", "uuid", 
"wiremock", + "x509-parser", ] [[package]] @@ -4535,6 +4617,7 @@ dependencies = [ "pin-project-lite", "quinn", "rustls", + "rustls-native-certs", "rustls-pki-types", "serde", "serde_json", @@ -4782,6 +4865,15 @@ dependencies = [ "semver", ] +[[package]] +name = "rusticata-macros" +version = "4.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "faf0c4a6ece9950b9abdb62b1cfcf2a68b3b67a10ba445b3bb85be2a293d0632" +dependencies = [ + "nom", +] + [[package]] name = "rustix" version = "0.38.44" @@ -4884,9 +4976,9 @@ checksum = "f87165f0995f63a9fbeea62b64d10b4d9d8e78ec6d7d51fb2125fda7bb36788f" [[package]] name = "rustls-webpki" -version = "0.103.12" +version = "0.103.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8279bb85272c9f10811ae6a6c547ff594d6a7f3c6c6b02ee9726d1d0dcfcdd06" +checksum = "61c429a8649f110dddef65e2a5ad240f747e85f7758a6bccc7e5777bd33f756e" dependencies = [ "aws-lc-rs", "ring", @@ -5122,6 +5214,15 @@ dependencies = [ "syn 2.0.117", ] +[[package]] +name = "serde_spanned" +version = "0.6.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf41e0cfaf7226dca15e8197172c295a782857fcb97fad1808a166870dee75a3" +dependencies = [ + "serde", +] + [[package]] name = "serde_urlencoded" version = "0.7.1" @@ -6059,6 +6160,47 @@ dependencies = [ "tokio", ] +[[package]] +name = "toml" +version = "0.8.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc1beb996b9d83529a9e75c17a1686767d148d70663143c7854d8b4a09ced362" +dependencies = [ + "serde", + "serde_spanned", + "toml_datetime", + "toml_edit", +] + +[[package]] +name = "toml_datetime" +version = "0.6.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22cddaf88f4fbc13c51aebbf5f8eceb5c7c5a9da2ac40a13519eb5b0a0e8f11c" +dependencies = [ + "serde", +] + +[[package]] +name = "toml_edit" +version = "0.22.27" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "41fe8c660ae4257887cf66394862d21dbca4a6ddd26f04a3560410406a2f819a" +dependencies = [ + "indexmap 2.14.0", + "serde", + "serde_spanned", + "toml_datetime", + "toml_write", + "winnow", +] + +[[package]] +name = "toml_write" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d99f8c9a7727884afe522e9bd5edbfc91a3312b36a77b5fb8926e4c31a41801" + [[package]] name = "tonic" version = "0.12.3" @@ -7148,6 +7290,15 @@ version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" +[[package]] +name = "winnow" +version = "0.7.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df79d97927682d2fd8adb29682d1140b343be4ac0f08fd68b7765d9c059d3945" +dependencies = [ + "memchr", +] + [[package]] name = "wiremock" version = "0.6.5" @@ -7271,6 +7422,23 @@ version = "0.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1ffae5123b2d3fc086436f8834ae3ab053a283cfac8fe0a0b8eaae044768a4c4" +[[package]] +name = "x509-parser" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcbc162f30700d6f3f82a24bf7cc62ffe7caea42c0b2cba8bf7f3ae50cf51f69" +dependencies = [ + "asn1-rs", + "data-encoding", + "der-parser", + "lazy_static", + "nom", + "oid-registry", + "rusticata-macros", + "thiserror 1.0.69", + "time", +] + [[package]] name = "xattr" version = "1.6.1" diff --git a/Cargo.toml b/Cargo.toml index 9bc3f9ea2..079e1e172 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -64,15 +64,17 @@ metrics-exporter-prometheus = { version = "0.18", default-features = false, feat # Unix/Process nix = { version = "0.29", features = ["signal", "process", "user", "fs", "term"] } +rustix = { version = "1.1", features = ["process"] } # Serialization serde = { version = "1", features = ["derive"] } serde_json = "1" serde_yml = "0.0.12" +toml = "0.8" apollo-parser = 
"0.8.5" # HTTP client -reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls"] } +reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls-native-roots"] } # WebSocket tokio-tungstenite = { version = "0.26", features = ["rustls-tls-native-roots"] } diff --git a/README.md b/README.md index 02447e421..e16a57190 100644 --- a/README.md +++ b/README.md @@ -60,7 +60,7 @@ The sandbox container includes the following tools by default: | Category | Tools | | ---------- | -------------------------------------------------------- | | Agent | `claude`, `opencode`, `codex`, `copilot` | -| Language | `python` (3.13), `node` (22) | +| Language | `python` (3.14), `node` (22) | | Developer | `gh`, `git`, `vim`, `nano` | | Networking | `ping`, `dig`, `nslookup`, `nc`, `traceroute`, `netstat` | @@ -153,8 +153,9 @@ Docker-backed GPU sandboxes auto-select CDI when available and otherwise fall ba | [OpenCode](https://opencode.ai/) | [`base`](https://github.com/NVIDIA/OpenShell-Community/tree/main/sandboxes/base) | Works out of the box. Provider uses `OPENAI_API_KEY` or `OPENROUTER_API_KEY`. | | [Codex](https://developers.openai.com/codex) | [`base`](https://github.com/NVIDIA/OpenShell-Community/tree/main/sandboxes/base) | Works out of the box. Provider uses `OPENAI_API_KEY`. | | [GitHub Copilot CLI](https://docs.github.com/en/copilot/github-copilot-in-the-cli) | [`base`](https://github.com/NVIDIA/OpenShell-Community/tree/main/sandboxes/base) | Works out of the box. Provider uses `GITHUB_TOKEN` or `COPILOT_GITHUB_TOKEN`. | -| [OpenClaw](https://openclaw.ai/) | [Community](https://github.com/NVIDIA/OpenShell-Community) | Launch with `openshell sandbox create --from openclaw`. | +| [OpenClaw](https://openclaw.ai/) | [NemoClaw](https://github.com/NVIDIA/NemoClaw) | Run OpenClaw more securely inside NVIDIA OpenShell with managed inference using NemoClaw. 
| | [Ollama](https://ollama.com/) | [Community](https://github.com/NVIDIA/OpenShell-Community) | Launch with `openshell sandbox create --from ollama`. | +| [Pi](https://pi.dev/) | [Community](https://github.com/NVIDIA/OpenShell-Community) | Launch with `openshell sandbox create --from pi`. | ## Key Commands @@ -191,7 +192,7 @@ The TUI gives you a live, keyboard-driven view of your gateway and sandboxes. Na Use `--from` to create sandboxes from the [OpenShell Community](https://github.com/NVIDIA/OpenShell-Community) catalog, a local directory, or a container image: ```bash -openshell sandbox create --from openclaw # community catalog +openshell sandbox create --from gemini # community catalog openshell sandbox create --from ./my-sandbox-dir # local Dockerfile openshell sandbox create --from registry.io/img:v1 # container image ``` @@ -234,6 +235,7 @@ All implementation work is human-gated — agents propose plans, humans approve, - [Quickstart](https://docs.nvidia.com/openshell/latest/get-started/quickstart) — detailed install and first sandbox walkthrough - [GitHub Sandbox Tutorial](https://docs.nvidia.com/openshell/latest/tutorials/github-sandbox) — end-to-end scoped GitHub repo access - [Architecture](https://github.com/NVIDIA/OpenShell/tree/main/architecture) — detailed architecture docs and design decisions +- [Roadmap](https://github.com/orgs/NVIDIA/projects/233) — planned work and project priorities - [Support Matrix](https://docs.nvidia.com/openshell/latest/reference/support-matrix) — platforms, versions, and kernel requirements - [Brev Launchable](https://brev.nvidia.com/launchable/deploy/now?launchableID=env-3Ap3tL55zq4a8kew1AuW0FpSLsg) — try OpenShell on cloud compute without local setup - [Agent Instructions](AGENTS.md) — system prompt and workflow documentation for agent contributors diff --git a/TESTING.md b/TESTING.md index 49c9b781a..7bcf2d203 100644 --- a/TESTING.md +++ b/TESTING.md @@ -151,6 +151,14 @@ Suites: - Docker suite (`--features e2e-docker`) 
- common suite plus Docker-only coverage such as Dockerfile image builds, Docker preflight checks, and managed Docker gateway resume. - Docker GPU suite (`--features e2e-docker-gpu`) - Docker suite plus GPU sandbox smoke coverage. +GPU device-selection tests compare OpenShell sandboxes against a plain Docker or +Podman container that requests `--device nvidia.com/gpu=all`. The probe image +defaults to the image used by the `gateway` stage in +`deploy/docker/Dockerfile.images`; set `OPENSHELL_E2E_GPU_PROBE_IMAGE` to +override it. Per-device checks run only for NVIDIA CDI device IDs reported by +the runtime's discovered devices list, so WSL2 hosts that expose only +`nvidia.com/gpu=all` skip the index-based cases. + Run the Docker-backed Rust CLI e2e suite: ```shell diff --git a/architecture/build.md b/architecture/build.md index baf44eba9..200be8b1e 100644 --- a/architecture/build.md +++ b/architecture/build.md @@ -12,35 +12,78 @@ OpenShell builds these main artifacts: |---|---| | Gateway binary | `crates/openshell-server` | | CLI package and Python SDK | `python/openshell` plus Rust binaries where packaged | -| Gateway container image | `deploy/docker/Dockerfile.images` | +| Gateway container image | `deploy/docker/Dockerfile.gateway` | +| Supervisor container image | `deploy/docker/Dockerfile.supervisor` | | Helm chart | `deploy/helm/openshell` | | VM driver/runtime assets | `crates/openshell-driver-vm` | | Published docs site | `docs/` rendered by Fern config in `fern/` | Sandbox community images are built outside this repository. +## Linux Runtime Environments + +OpenShell uses different Linux libc environments for different host artifacts. +The standalone `openshell` CLI is built as a static musl binary so it can run on +a wide range of Linux distributions without depending on the host's glibc. 
Host +runtime binaries that use the GNU/Linux runtime environment, including +`openshell-gateway` and `openshell-driver-vm`, are GNU-linked and built with a +glibc 2.31 floor. + ## Container Builds -The Docker image pipeline stages prebuilt Rust binaries, then builds container -images from `deploy/docker/Dockerfile.images`. CI builds native artifacts on the -target architecture, stages them under `deploy/docker/.build/`, and then uses -Buildx to publish per-architecture images and multi-architecture tags. +The Docker image pipeline is a two-step flow: build the Rust binary natively +for the target architecture, then assemble the container image from the +prebuilt binary. The gateway image is built from `deploy/docker/Dockerfile.gateway` +and the supervisor image from `deploy/docker/Dockerfile.supervisor`. Neither +Dockerfile compiles Rust — both copy a staged binary out of +`deploy/docker/.build/prebuilt-binaries//` into the final image. + +Binary staging is driven by `tasks/scripts/stage-prebuilt-binaries.sh`. Gateway +binaries use `cargo zigbuild` with GNU targets pinned to glibc 2.31, including +native-architecture builds, so the gateway image, standalone tarballs, and Linux +packages share the same host portability floor. Supervisor binaries remain +static musl. Local Docker image tasks infer the target architecture from +`DOCKER_PLATFORM` when set, otherwise from the container engine host metadata +with the kernel architecture as the fallback. CI invokes the same staging step +via the `rust-native-build.yml` workflow (per-architecture, per-component) and +uploads the result as an artifact that the image build job downloads back into +the staging directory before running Buildx. + +Runtime layout: + +- **Gateway**: `gcr.io/distroless/cc-debian13:nonroot` base, GNU-linked binary at + `/usr/local/bin/openshell-gateway`, runs as UID/GID `1000:1000`. 
Linux GNU + gateway and VM driver binaries must not reference `GLIBC_*` symbols newer than + `GLIBC_2.31`; release workflows verify this before publishing artifacts. +- **Supervisor**: `scratch` base, static musl binary at `/openshell-sandbox`. + Static linkage is required because the image is mounted/extracted into + sandbox environments (Docker extraction, Podman image volumes, Kubernetes + init-container copy-self) and cannot rely on a dynamic loader. + +Gateway image builds bake the corresponding supervisor image tag into the +gateway binary so Docker sandboxes do not depend on `:latest` by default. +Package formulas also pin Docker supervisor extraction to the matching release +image tag so standalone gateway binaries do not infer image tags from package +versions. +The Homebrew service keeps gateway TLS under the Homebrew state directory but +mirrors Docker sandbox client TLS into `$HOME/.local/state/openshell/homebrew/tls` +at service start, because Docker Desktop bind mounts must use paths visible to +the macOS user's shared home directory. Local image work should use `mise` tasks rather than direct Docker commands so the same staging and tagging assumptions are used locally and in CI. ## CI and E2E -Required checks run on GitHub Actions. E2E and GPU workflows use NVIDIA -self-hosted runners, so trusted PRs are mirrored by copy-pr-bot into -`pull-request/` branches before those workflows run. +Required checks run on GitHub Actions. Workflows that use NVIDIA self-hosted runners trigger from copy-pr-bot mirror branches, so trusted PRs are mirrored into `pull-request/` branches before those workflows run. The high-level CI model: -1. Standard branch checks run on normal PR activity. -2. Label-gated E2E and GPU checks run from trusted mirror branches. -3. Gate jobs verify that the expected non-gate workflow actually ran. -4. Release workflows rebuild and publish binaries, wheels, images, and docs. +1. 
PR-context gate jobs publish required statuses for the PR head commit. +2. Standard branch checks run from trusted mirror branches. +3. Label-gated E2E, GPU, and Kubernetes checks run from trusted mirror branches. +4. Gate jobs verify that the mirror branch matches the PR head and that the expected non-gate workflow actually ran. +5. Release workflows rebuild and publish binaries, wheels, images, and docs. See `CI.md` for the contributor workflow and labels. diff --git a/architecture/compute-runtimes.md b/architecture/compute-runtimes.md index 095b7d020..d79b7366d 100644 --- a/architecture/compute-runtimes.md +++ b/architecture/compute-runtimes.md @@ -16,6 +16,12 @@ Each runtime receives a sandbox spec from the gateway and is responsible for: - Reporting lifecycle and platform events back to the gateway. - Cleaning up runtime-owned resources. +Drivers own runtime-specific platform event interpretation. When an event should +drive client provisioning UI, the driver attaches the shared +`openshell.progress.*` metadata defined in `openshell-core` instead of requiring +clients to parse Kubernetes reasons, VM cache states, or other driver-local +reason strings. + ## Runtime Summary | Runtime | Best fit | Sandbox boundary | Notes | @@ -23,7 +29,18 @@ Each runtime receives a sandbox spec from the gateway and is responsible for: | Docker | Local development with Docker available. | Container plus nested sandbox namespace. | Uses host networking so loopback gateway endpoints work from the supervisor. | | Podman | Rootless or single-machine deployments. | Container plus nested sandbox namespace. | Uses the Podman REST API, OCI image volumes, and CDI GPU devices when available. | | Kubernetes | Cluster deployment through Helm. | Pod plus nested sandbox namespace. | Uses Kubernetes API objects, service accounts, secrets, PVC-backed workspace storage, and GPU resources. | -| VM | Experimental microVM isolation. | Per-sandbox libkrun VM. 
| Gateway spawns `openshell-driver-vm` as a subprocess over a Unix socket. | +| VM | Experimental microVM isolation. | Per-sandbox libkrun VM. | Gateway spawns `openshell-driver-vm` as a subprocess over a private, state-local Unix socket. The VM driver boots a cached bootstrap `rootfs.ext4`, prepares requested OCI images inside a bootstrap VM with `umoci`, attaches the prepared image disk read-only, and gives each sandbox a writable `overlay.ext4` for merged-root changes and runtime material. The driver persists each accepted launch request beside the overlay and restarts those VMs on driver startup without recreating the overlay. | + +Per-sandbox CPU and memory values currently enter the driver layer through +template resource limits. Docker and Podman apply them as runtime limits. +Kubernetes mirrors each limit into the matching request. VM accepts the fields +but currently ignores them. + +VM runtime state paths are derived only from driver-validated sandbox IDs +matching `[A-Za-z0-9._-]{1,128}`. The gateway-owned VM driver socket uses a +private `run/` directory plus Unix peer UID/PID checks. Standalone +unauthenticated TCP mode is disabled unless explicitly enabled for local +development. Runtime-specific implementation notes belong in the driver crate README: @@ -38,7 +55,7 @@ The supervisor must be available inside each sandbox workload: | Runtime | Delivery model | |---|---| -| Docker | Bind-mounted or extracted supervisor binary configured by the gateway. | +| Docker | Bind-mounted local supervisor binary, or a binary extracted from the configured supervisor image. | | Podman | Read-only OCI image volume containing the supervisor binary. | | Kubernetes | Sandbox pod image or pod template configuration. | | VM | Embedded in the guest rootfs bundle. | diff --git a/architecture/gateway.md b/architecture/gateway.md index f36878cf1..a7e8f1a00 100644 --- a/architecture/gateway.md +++ b/architecture/gateway.md @@ -9,11 +9,12 @@ workloads. 
- Authenticate clients and sandbox callbacks. - Serve gRPC APIs for sandbox lifecycle, provider management, policy updates, - settings, inference configuration, logs, and watch streams. -- Serve HTTP endpoints for health, SSH tunnel upgrades, and edge-auth flows. + settings, inference configuration, logs, watch streams, and relay forwarding. +- Serve HTTP endpoints for health, WebSocket tunnels, and edge-auth flows. - Persist domain objects in SQLite or Postgres. - Resolve provider credentials and inference bundles for sandbox supervisors. -- Coordinate supervisor relay sessions for connect, exec, and file sync. +- Coordinate supervisor relay sessions for connect, exec, file sync, and + service forwarding. The gateway does not enforce agent network policy at request time. That happens inside each sandbox, where the supervisor and proxy can observe local process @@ -22,21 +23,62 @@ identity. ## Protocol and Auth The gateway listens on one service port and multiplexes gRPC and HTTP traffic. -The default deployment mode is mTLS: clients and sandbox workloads present a -certificate signed by the deployment CA before reaching application handlers. +The default local single-user deployment mode is mTLS user authentication: +clients present a certificate signed by the local deployment CA, and the +gateway maps the verified certificate subject to a user principal. Kubernetes +deployments use mTLS for transport only and require OIDC or a trusted access +proxy for user authentication unless the explicit unsafe local-development +`allow_unauthenticated_users` switch is enabled. +When that service port is bound to loopback, the listener can also accept +plaintext HTTP on the same port for sandbox service subdomains only. That local +browser path is enabled by default and disabled with +`--enable-loopback-service-http=false`; it never serves gateway APIs, auth, +health, metrics, or tunnel routes. 
The plaintext service router also rejects +browser requests whose Fetch Metadata, Origin, or Referer headers indicate a +cross-origin or sibling-subdomain request. Supported auth modes: | Mode | Use | |---|---| -| mTLS | Default direct gateway access for CLI, SDK, TUI, and sandbox callbacks. | +| mTLS user auth | Local single-user Docker, Podman, and VM gateway access. | | Plaintext | Local development or a trusted reverse proxy boundary. | +| Unauthenticated local users | Trusted Kubernetes dev or fully trusted proxy deployments only. | | Cloudflare JWT | Edge-authenticated deployments where Cloudflare Access supplies identity. | | OIDC | Bearer-token auth for users, with browser PKCE or client credentials login. | -Sandbox supervisor RPCs authenticate with either mTLS material or a sandbox -secret depending on the runtime and deployment mode. User-facing mutations are -authorized by role policy when OIDC or edge identity is enabled. +Sandbox supervisor RPCs authenticate with gateway-minted sandbox JWTs when that +authenticator is configured; mTLS does not grant sandbox identity. User-facing +mutations are authorized by role policy when OIDC or edge identity is enabled. + +Sandbox secrets are gateway-signed JWTs bound to a single sandbox ID. Docker, +Podman, and VM drivers deliver the initial token through supervisor-only +runtime material; Kubernetes supervisors exchange a projected ServiceAccount +token through `IssueSandboxToken`. The gateway validates that projected token +with Kubernetes `TokenReview`, requires the configured sandbox service account, +checks the returned pod binding against the live pod UID, and verifies the pod's +controlling `Sandbox` ownerReference against the live Sandbox CR UID and +sandbox-id label before minting the gateway JWT. Supervisors renew gateway JWTs +in memory before expiry only while the sandbox record still exists. 
Older tokens +are not server-revoked; deployments bound replay exposure with short +`gateway_jwt.ttl_secs` lifetimes. + +Gateway JWT signing-key rotation is currently an offline operator action. The +runtime loads one active signing key and one matching public verification key +from the configured secret at startup. To rotate that key material today, +operators must delete or replace the JWT key secret, let certgen recreate it, +and restart the gateway pods. This invalidates outstanding supervisor tokens; +running supervisors recover by re-running their bootstrap path where available +or by reconnecting after sandbox restart. Online rotation with multiple +verification keys keyed by `kid` is tracked separately. + +Sandbox JWTs are not user credentials. The gRPC router accepts +`Principal::Sandbox` only on the supervisor-to-gateway RPC allowlist +(`ConnectSupervisor`, `RelayStream`, token renewal, config sync, policy status, +log push, and policy-analysis callbacks). Handlers then compare the +authenticated sandbox ID with any sandbox ID or name resolved from the request. +Supervisor control and relay streams require a matching sandbox principal before +the gateway registers the session or bridges relay bytes. ## API Surface @@ -44,7 +86,7 @@ The gateway API is organized around platform objects and operational streams: | Area | Examples | |---|---| -| Sandbox lifecycle | Create, list, delete, watch, exec, SSH session bootstrap. | +| Sandbox lifecycle | Create, list, delete, watch, exec, SSH session bootstrap, ForwardTcp service forwarding. | | Providers | Store provider records, discover credentials, resolve runtime environment. | | Policy and settings | Get effective sandbox config, update sandbox policy, manage global settings. | | Inference | Set gateway-level model/provider config and resolve sandbox route bundles. | @@ -75,6 +117,7 @@ The storage schema is intentionally narrow: | `version` | Optional monotonically increasing version for scoped records. 
| | `status` | Optional workflow state for records such as policy revisions or draft policy chunks. | | `dedup_key` and `hit_count` | Optional policy-advisor fields for coalescing repeated observations. | +| `resource_version` | Monotonically increasing counter for optimistic concurrency control. Incremented atomically on each update. | | `payload` | Prost-encoded protobuf payload for the full domain object. | | `created_at_ms` and `updated_at_ms` | Gateway timestamps used for ordering and list output. | | `labels` | JSON object carrying Kubernetes-style object labels for filtering and organization. | @@ -96,9 +139,101 @@ This keeps the gateway data model portable across storage backends and leaves room for future stores that can provide the same object, label, version, and scope semantics. +The SQLite adapter tightens the on-disk database file to mode `0o600` on every +connect so that provider API keys, SSH session tokens, and sandbox metadata are +not readable by other local users on shared hosts. The same restriction is +reapplied to the `-wal` and `-shm` sidecars (created by SQLite's +default WAL journal mode), which mirror the same sensitive contents. + Persisted state includes sandboxes, providers, SSH sessions, policy revisions, settings, inference configuration, and deployment records. +### Optimistic Concurrency (CAS) + +Every object row carries a `resource_version` that the database increments +atomically on each write. Concurrent mutations use compare-and-swap (CAS): the +writer reads the current version, applies changes, and writes back with a +`WHERE resource_version = ` guard. If another writer updated the row +in between, the guard fails and the caller receives a `Conflict` error. + +This matters for HA deployments where multiple gateway replicas share the same +Postgres database, and for single-node deployments where concurrent gRPC +handlers or the reconciler mutate the same sandbox. 
+ +**Compile-time enforcement.** The unconditional write methods `put` and +`put_message` are gated behind `#[cfg(test)]`. Production code must use +`put_if` with an explicit `WriteCondition` or `update_message_cas`. The +compiler rejects any other write path, making non-CAS writes structurally +impossible outside of tests. + +Every write goes through one of three conditions: + +- `MustCreate` -- insert-only. The database rejects the write with a + `UniqueViolation` error if a row with that ID already exists. Handlers match + on the structured `PersistenceError::UniqueViolation { .. }` variant to + distinguish creation conflicts from other failures. +- `MatchResourceVersion(v)` -- update-only. The database rejects the write + with a `Conflict` error if the current version differs from `v`. +- `Unconditional` -- test-only; not reachable in production builds. + +**Creates.** All create paths use `MustCreate` and hydrate the response +directly from the `WriteResult` returned by `put_if`, which carries the +assigned `resource_version`, `created_at_ms`, and `updated_at_ms`. This +eliminates a read-after-write round trip and the race window that would come +with it. + +**Updates.** The `update_message_cas` helper makes a single CAS attempt: it +fetches the current object, applies a mutation closure, and writes with a +`MatchResourceVersion` condition. On conflict the persistence layer returns a +`Conflict` error, which gRPC handlers map to `ABORTED` status so the client +(or the next watch/reconcile event) can retry with fresh state. There is no +automatic retry loop. + +The helper accepts an `expected_version` parameter that selects between two +modes: + +- **Server-driven** (`expected_version = 0`): the helper uses the version it + just read from the database. Internal operations (reconciler, policy status + reports, compute phase transitions) use this mode because the caller does + not track versions. 
+- **Client-driven** (`expected_version != 0`): the helper validates that the + caller's version matches the current database version before applying the + mutation. If they diverge it returns `Conflict` without attempting the + write. Client-facing operations that carry an `expected_resource_version` + field use this mode: `AttachSandboxProvider`, `DetachSandboxProvider`, + `UpdateProvider`, and `UpdateConfig` (policy backfill path). + +**Lists.** The `list_messages` and `list_messages_with_selector` helpers decode +protobuf payloads from list results and hydrate `resource_version` from the +authoritative database column into each decoded message, mirroring the +`get_message` pattern. This ensures list responses carry correct versions +without requiring callers to manually hydrate each record. + +**Deletes.** Delete operations are not yet CAS-protected -- the delete request +protos do not carry `expected_resource_version`. A `delete_if` primitive exists +in the persistence layer but is not wired into gRPC handlers. + +**Coverage.** All `ObjectMeta`-bearing message types have write-condition +coverage: + +| Type | Create | Update | List | +|---|---|---|---| +| Sandbox | `MustCreate` | `update_message_cas` | `list_messages` | +| Provider | `MustCreate` | `update_message_cas` | `list_messages` | +| ProviderProfile | `MustCreate` | (immutable) | `list_messages` | +| InferenceRoute | `MustCreate` | `update_message_cas` | `list_messages` | +| SandboxPolicy | scoped versioning | scoped versioning | scoped query | +| Settings | `Mutex`-guarded | `Mutex`-guarded | single-row | + +Global settings updates use a Tokio `Mutex` to serialize multi-step +validation within a single gateway process, with CAS on the underlying +persistence write as defense in depth. In an HA deployment with multiple +gateways, the Mutex alone would be insufficient. Sandbox-scoped settings +rely entirely on CAS without a Mutex. 
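The write conditions and the two `expected_version` modes described in this section can be illustrated with a small in-memory sketch. This is not the gateway's persistence layer: `Store`, `Row`, `update_cas`, the `NotFound` variant, and all signatures are hypothetical. Only the names `MustCreate`, `MatchResourceVersion`, `UniqueViolation`, `Conflict`, `put_if`, and the `expected_version = 0` convention come from the text. Starting new rows at version 1 is an assumption.

```rust
use std::collections::HashMap;

/// Write guard named after the conditions in the text.
/// The test-only `Unconditional` variant is left out of this sketch.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WriteCondition {
    MustCreate,
    MatchResourceVersion(u64),
}

#[derive(Debug, PartialEq, Eq)]
pub enum PersistenceError {
    UniqueViolation,
    Conflict,
    NotFound, // hypothetical; not named in the text
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Row {
    pub resource_version: u64,
    pub payload: String,
}

/// Hypothetical in-memory stand-in for the object table.
#[derive(Default)]
pub struct Store {
    rows: HashMap<String, Row>,
}

impl Store {
    /// Conditional write. Returns the resource_version assigned to the row.
    pub fn put_if(
        &mut self,
        id: &str,
        payload: &str,
        cond: WriteCondition,
    ) -> Result<u64, PersistenceError> {
        let current = self.rows.get(id).map(|r| r.resource_version);
        let next = match (cond, current) {
            // Insert-only: succeeds only when no row exists (assumed to start at 1).
            (WriteCondition::MustCreate, None) => 1,
            (WriteCondition::MustCreate, Some(_)) => {
                return Err(PersistenceError::UniqueViolation)
            }
            // Update-only: the stored version must equal the guard.
            (WriteCondition::MatchResourceVersion(v), Some(cur)) if v == cur => cur + 1,
            // Version mismatch, or no row at all (this sketch reports both as Conflict).
            (WriteCondition::MatchResourceVersion(_), _) => {
                return Err(PersistenceError::Conflict)
            }
        };
        self.rows.insert(
            id.to_string(),
            Row { resource_version: next, payload: payload.to_string() },
        );
        Ok(next)
    }

    /// Single CAS attempt with no retry loop.
    /// `expected_version == 0` is server-driven; any other value is client-driven.
    pub fn update_cas(
        &mut self,
        id: &str,
        expected_version: u64,
        mutate: impl FnOnce(&str) -> String,
    ) -> Result<u64, PersistenceError> {
        let row = self.rows.get(id).cloned().ok_or(PersistenceError::NotFound)?;
        if expected_version != 0 && expected_version != row.resource_version {
            // Client-driven mode: a stale caller version fails before any write.
            return Err(PersistenceError::Conflict);
        }
        let new_payload = mutate(&row.payload);
        self.put_if(
            id,
            &new_payload,
            WriteCondition::MatchResourceVersion(row.resource_version),
        )
    }
}
```

In this sketch a caller that receives `Conflict` is expected to re-read and try again itself, which mirrors the "no automatic retry loop" behavior described for the real helper.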
+ +The `resource_version` is surfaced to clients through `ObjectMeta` in proto +responses. Database migrations backfill existing rows with version 1. + Policy and runtime settings are delivered together through the effective sandbox config path. A gateway-global policy can override sandbox-scoped policy. The sandbox supervisor polls for config revisions and hot-reloads dynamic policy @@ -115,22 +250,107 @@ sequenceDiagram participant CLI participant GW as Gateway participant SUP as Sandbox supervisor - participant SSH as Sandbox SSH socket + participant Target as Sandbox target SUP->>GW: ConnectSupervisor stream - CLI->>GW: connect / exec / sync request - GW->>SUP: RelayOpen(channel) + CLI->>GW: ForwardTcp / exec / sync request + GW->>SUP: RelayOpen(channel, target) + SUP->>Target: Dial SSH socket or loopback service SUP->>GW: RelayStream(channel) - SUP->>SSH: Bridge bytes to Unix socket CLI->>GW: Client bytes GW-->>CLI: Client bytes GW->>SUP: Relay bytes SUP-->>GW: Relay bytes ``` -The same relay pattern backs interactive SSH, command execution, and file sync. -The gateway tracks live sessions in memory and persists session records so -tokens can expire or be revoked. +The same relay pattern backs interactive SSH, command execution, file sync, and +local service forwarding. The gateway tracks live sessions in memory and +persists session records so tokens can expire or be revoked. + +`ForwardTcp` is the client-facing byte stream for SSH and service forwarding. +The first frame is a `TcpForwardInit` that carries the sandbox ID, an +authorization token from `CreateSshSession`, and an explicit target: +`target.ssh` for the sandbox SSH socket or `target.tcp` for a loopback service +inside the sandbox. The gateway validates the token and sandbox readiness, +sends a targeted `RelayOpen` to the supervisor, then bridges +`TcpForwardFrame::Data` to `RelayFrame::Data` until either side closes. 
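The `ForwardTcp` framing can be sketched as follows. The type and variant names `TcpForwardInit`, `TcpForwardFrame::Data`, `target.ssh`, and `target.tcp` come from the text; the field names, the `ForwardError` type, and the `open_forward` and `is_loopback_host` helpers are hypothetical and stand in for the generated protobuf types and the real handler. Token validation and sandbox readiness checks are omitted. The loopback-only rule for `target.tcp` is modeled with the standard library's `IpAddr::is_loopback`, which is an assumption about how such a check might be written.

```rust
use std::net::IpAddr;

/// Explicit relay target: the sandbox SSH socket or a loopback service.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Target {
    Ssh,
    Tcp { host: String, port: u16 },
}

/// First frame of the stream (field names are illustrative).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TcpForwardInit {
    pub sandbox_id: String,
    pub token: String,
    pub target: Target,
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TcpForwardFrame {
    Init(TcpForwardInit),
    Data(Vec<u8>),
}

#[derive(Debug, PartialEq, Eq)]
pub enum ForwardError {
    MissingInit,
    UnexpectedInit,
    NonLoopbackTarget,
}

/// Accepts `localhost`, any address in 127.0.0.0/8, and `::1`.
pub fn is_loopback_host(host: &str) -> bool {
    if host.eq_ignore_ascii_case("localhost") {
        return true;
    }
    host.parse::<IpAddr>().map(|ip| ip.is_loopback()).unwrap_or(false)
}

/// Consumes a client stream: the first frame must be Init, the rest Data.
/// Returns the init frame plus the bytes that would be bridged to the relay.
pub fn open_forward(
    frames: Vec<TcpForwardFrame>,
) -> Result<(TcpForwardInit, Vec<u8>), ForwardError> {
    let mut iter = frames.into_iter();
    let init = match iter.next() {
        Some(TcpForwardFrame::Init(init)) => init,
        _ => return Err(ForwardError::MissingInit),
    };
    // Only loopback destinations are accepted for target.tcp.
    if let Target::Tcp { host, .. } = &init.target {
        if !is_loopback_host(host) {
            return Err(ForwardError::NonLoopbackTarget);
        }
    }
    let mut relayed = Vec::new();
    for frame in iter {
        match frame {
            TcpForwardFrame::Data(bytes) => relayed.extend(bytes),
            // A second Init on the same stream is treated as a protocol error here.
            TcpForwardFrame::Init(_) => return Err(ForwardError::UnexpectedInit),
        }
    }
    Ok((init, relayed))
}
```

A real handler would stream frames rather than collect them into a `Vec`; the vector keeps the sketch synchronous and easy to test.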
+ +Browser service URLs use the same supervisor relay path after host-based +routing resolves `sandbox--service.` to a stored +service endpoint. Accepted service routing domains are derived from wildcard +DNS SANs configured on the gateway server certificate, with +`openshell.localhost` available by default for loopback gateways. TLS-enabled +loopback gateways print `http://` URLs when loopback plaintext service HTTP is +enabled; non-loopback TLS gateways continue to print `https://` URLs. + +For `target.tcp`, the gateway only accepts loopback destinations such as +`localhost`, `127.0.0.0/8`, or `::1`. The gateway never needs to know or dial a +sandbox pod IP; supervisors connect outbound and bridge only the explicit target +requested for that relay. + +## PKI Bootstrap + +`openshell-gateway generate-certs` is the one place mTLS materials are +created. Both deployment paths use it: + +| Output mode | Selector | Layout | +|---|---|---| +| Kubernetes Secrets | (default) `--namespace`, `--server-secret-name`, `--client-secret-name` | Two `kubernetes.io/tls` Secrets with `tls.crt` / `tls.key` / `ca.crt`. | +| Filesystem | `--output-dir ` | `/{ca.crt, ca.key, server/tls.{crt,key}, client/tls.{crt,key}}`. Also copies client materials to `$XDG_CONFIG_HOME/openshell/gateways/openshell/mtls/` for CLI auto-discovery. | + +On Kubernetes, the Helm chart runs the command via a pre-install/pre-upgrade +hook Job using the gateway image itself -- no separate cert-generation image, +no extra mirror burden in air-gapped environments. On the RPM gateway, the +same command runs from the systemd unit's `ExecStartPre` to bootstrap PKI +into the user's state directory on first start. + +Both modes share the same idempotency contract: all targets present -> skip; +partial state -> fail with a recovery hint; nothing present -> generate and +write. This guards mTLS continuity across restarts and upgrades while still +recovering cleanly if an operator deletes everything and starts over. 
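The three-way idempotency contract above reduces to a small decision function. The sketch below is illustrative only: `PkiAction`, `PartialState`, and `plan_pki` are hypothetical names, and the caller is assumed to have already checked which target files or Secrets exist.

```rust
/// What the bootstrap step should do when every target is in a consistent state.
#[derive(Debug, PartialEq, Eq)]
pub enum PkiAction {
    Skip,     // all targets present
    Generate, // nothing present
}

/// Partial state: some targets exist and some do not.
#[derive(Debug, PartialEq, Eq)]
pub struct PartialState {
    /// Names of the missing targets, usable in a recovery hint.
    pub missing: Vec<String>,
}

/// `targets` pairs a target name with whether it currently exists.
/// Note: an empty target list counts as "all present" in this sketch.
pub fn plan_pki(targets: &[(&str, bool)]) -> Result<PkiAction, PartialState> {
    let present = targets.iter().filter(|(_, exists)| *exists).count();
    if present == targets.len() {
        Ok(PkiAction::Skip)
    } else if present == 0 {
        Ok(PkiAction::Generate)
    } else {
        Err(PartialState {
            missing: targets
                .iter()
                .filter(|(_, exists)| !*exists)
                .map(|(name, _)| name.to_string())
                .collect(),
        })
    }
}
```

Failing on partial state rather than filling in the gaps avoids silently mixing materials signed by different CAs, which is one plausible reading of why the contract is strict.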
+ +Operators who manage PKI externally (cert-manager, an enterprise CA, or +pre-created Secrets) disable the Helm hook via `pkiInitJob.enabled=false`. +The chart also ships a `certManager.*` path that produces equivalent Secrets +through cert-manager `Issuer`/`Certificate` resources. + +## Configuration + +The gateway reads its configuration from three sources, merged in this +precedence (highest first): + +``` +Gateway CLI flag > gateway OPENSHELL_* env var > TOML file > built-in default +``` + +The TOML file is opt-in via `--config ` / `OPENSHELL_GATEWAY_CONFIG`. +Driver implementation settings live in the TOML driver tables. See +`docs/reference/gateway-config.mdx` for worked per-driver examples and RFC +0003 for the full schema. + +`database_url` is env-only and rejected when present in the file +(`OPENSHELL_DB_URL` / `--db-url`). + +### Driver inheritance + +`[openshell.gateway]` carries a small set of values (`sandbox_namespace`, +`default_image`, +`supervisor_image`, `guest_tls_ca/cert/key`, `client_tls_secret_name`, +`host_gateway_ip`, `enable_user_namespaces`) that are inherited into each +driver's `[openshell.drivers.]` table when the driver-specific table +does not override them. The allowlist is per-driver so a gateway-wide +default cannot land in a driver that does not understand it (e.g. +`client_tls_secret_name` is K8s-only). + +`image_pull_policy` is intentionally **not** inheritable: Kubernetes uses +`Always | IfNotPresent | Never` (passed verbatim to the K8s API) while +Podman uses the lowercase enum `always | missing | never | newer`. No +value means the same thing in both, so the key lives only under each +driver's own table. + +Driver-specific values that are not part of the inheritance allowlist +(e.g. Podman `socket_path`, VM `vcpus`) only come from the driver's own +table. ## Operational Constraints @@ -138,6 +358,9 @@ tokens can expire or be revoked. by the operator or packaging layer. 
- Compute runtimes own the mechanics of starting workloads and injecting callback configuration. +- Docker-backed local gateways use Docker's `host-gateway` callback alias on + macOS and Docker Desktop-style runtimes. Native Linux Docker may expose an + additional bridge-gateway listener because the host can bind that bridge IP. - Gateway restarts recover persisted objects from storage, but live relay streams must be re-established by supervisors. - User-facing behavior changes must update published docs in `docs/`; this file diff --git a/architecture/sandbox.md b/architecture/sandbox.md index 71dd35227..4bc6803eb 100644 --- a/architecture/sandbox.md +++ b/architecture/sandbox.md @@ -89,6 +89,21 @@ Sandbox logs are emitted locally and can also be pushed back to the gateway. Security-relevant sandbox behavior uses OCSF structured events; internal diagnostics use ordinary tracing. +## Policy Proposals + +When an L4 CONNECT is denied, the proxy emits a `DenialEvent`. The denial +aggregator batches these events and flushes summaries to the gateway every 10 +seconds (configurable via `OPENSHELL_DENIAL_FLUSH_INTERVAL_SECS`). The gateway +runs them through the mechanistic mapper, which generates a pending +`NetworkPolicyRule` proposal visible under `openshell rule get --status pending`. + +L7 denials (HTTP 403 from method/path rules) are intentionally excluded from +mechanistic mapping. L4 denials carry only `host:port`, which a deterministic mapper can handle. +L7 denials carry method, path, query, and body context. The agent loop reads +the structured 403 and authors the narrowest rule. Mechanistically mapping L7 +would either over-broaden rules or require path-templating logic that rots +quickly. + ## Failure Behavior - If gateway config polling fails, the sandbox keeps its last-known-good policy. 
diff --git a/architecture/security-policy.md b/architecture/security-policy.md index e5f179dc1..bc7b0c7a8 100644 --- a/architecture/security-policy.md +++ b/architecture/security-policy.md @@ -36,6 +36,27 @@ Ordinary network traffic follows this order: Explicit deny and hardening checks win over allow rules. If no rule matches, the request is denied. +## Host Wildcards + +Network endpoint `host` patterns accept a `*` wildcard inside the first DNS +label only. The OPA runtime matches with a `.` label boundary, so a wildcard +never spans dots. The validator enforces the same boundary so that policy load +fails fast instead of silently mismatching at the proxy. + +| Pattern | Accepted | Example match | Notes | +|---|---|---|---| +| `*.example.com` | Yes | `api.example.com` | Single first label of any value. | +| `**.example.com` | Yes | `a.b.example.com` | Recursive wildcard as the entire first label. | +| `*-aiplatform.googleapis.com` | Yes | `us-central1-aiplatform.googleapis.com` | Intra-label wildcard inside the first DNS label. | +| `*` or `**` | No | — | Matches every host. | +| `*.com`, `**.com` | No | — | TLD wildcards (`labels <= 2`). | +| `foo.*.example.com` | No | — | Wildcard outside the first DNS label. | +| `foo**.example.com` | No | — | Recursive `**` mixed inside a label; allowed only as the entire first label. | + +Validation rejects the disallowed patterns at policy load time with a message +that names the offending host. Exact hosts and IP addresses do not use this +path. + ## TLS and L7 Inspection For HTTP endpoints that need request-level controls, the proxy can terminate TLS @@ -43,9 +64,13 @@ with the sandbox's ephemeral CA and inspect method/path or protocol-specific metadata before forwarding. The proxy also supports credential injection on terminated HTTP streams when policy allows the endpoint. -Raw streams, HTTP upgrades, and long-lived response bodies are connection -scoped. 
Policy reloads affect the next connection or the next parsed HTTP -request; they do not rewrite bytes already being relayed. +Raw streams and long-lived response bodies are connection scoped. Policy +reloads affect the next connection or the next parsed HTTP request; they do not +rewrite bytes already being relayed. HTTP upgrades switch to raw relay by +default. A `protocol: rest` endpoint can opt in to +`websocket_credential_rewrite` for client-to-server WebSocket text messages +after an allowed `101` upgrade; server-to-client traffic and all other upgraded +protocols remain raw passthrough. ## Live Updates @@ -72,6 +97,11 @@ recommendations: 4. A human or admin workflow approves or rejects drafts. 5. Approved drafts merge into the target sandbox policy. +Proposals intentionally omit `allowed_ips`. If a proposed rule targets a host +that resolves to a private IP, the proxy's runtime SSRF classification blocks +the connection. The operator must then add an explicit `allowed_ips` entry to +permit it — a two-step flow that keeps SSRF protection on by default. + The advisor should propose narrow additions and preserve explicit-deny behavior. It is a workflow aid, not an automatic permission grant. diff --git a/crates/openshell-bootstrap/Cargo.toml b/crates/openshell-bootstrap/Cargo.toml index c0fb7e9f4..578d59e65 100644 --- a/crates/openshell-bootstrap/Cargo.toml +++ b/crates/openshell-bootstrap/Cargo.toml @@ -16,6 +16,7 @@ bytes = { workspace = true } futures = { workspace = true } miette = { workspace = true } rcgen = { workspace = true } +sha2 = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } tar = "0.4" diff --git a/crates/openshell-bootstrap/src/jwt.rs b/crates/openshell-bootstrap/src/jwt.rs new file mode 100644 index 000000000..cf8ab0dc1 --- /dev/null +++ b/crates/openshell-bootstrap/src/jwt.rs @@ -0,0 +1,112 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +//! Gateway-minted JWT signing-key generation. +//! +//! The gateway mints per-sandbox identity tokens (see PR 2 of the +//! per-sandbox identity series, issue #1354) signed with an Ed25519 +//! keypair generated once at gateway init and persisted alongside the +//! existing PKI bundle. The signing key never leaves the gateway; the +//! public key plus a stable `kid` are consumed by the gateway's own +//! validator and any future external verifiers. + +use miette::{IntoDiagnostic, Result, WrapErr}; +use rcgen::{KeyPair, PKCS_ED25519}; +use sha2::{Digest, Sha256}; + +/// All PEM-encoded material needed to mint and validate sandbox JWTs. +/// +/// The signing key stays in the gateway process. The public key is shared +/// across gateway replicas (so any replica can validate a JWT minted by +/// any other replica). The `kid` is published in every minted JWT's +/// header so the validator can pick the right key after a future rotation. +pub struct JwtKeyMaterial { + /// PKCS#8 PEM-encoded Ed25519 private key. + pub signing_key_pem: String, + /// `SubjectPublicKeyInfo` PEM-encoded Ed25519 public key. + pub public_key_pem: String, + /// Stable identifier derived from the public key (SHA-256 hex prefix). + /// Embedded in every minted JWT's `kid` header so future rotation can + /// be performed in-place by adding a second key without breaking + /// in-flight tokens. + pub kid: String, +} + +/// Generate a fresh Ed25519 JWT signing key. +/// +/// Output PEM is in the formats `jsonwebtoken` consumes via +/// `EncodingKey::from_ed_pem` (signing) and `DecodingKey::from_ed_pem` +/// (validation), so the gateway can round-trip its own tokens with no +/// further conversion. 
+pub fn generate_jwt_key() -> Result { + let keypair = KeyPair::generate_for(&PKCS_ED25519) + .into_diagnostic() + .wrap_err("failed to generate Ed25519 JWT signing key")?; + let signing_key_pem = keypair.serialize_pem(); + let public_key_pem = keypair.public_key_pem(); + let kid = kid_from_public_key_der(&keypair.public_key_der()); + Ok(JwtKeyMaterial { + signing_key_pem, + public_key_pem, + kid, + }) +} + +/// Stable `kid` derived from the SHA-256 of the public-key DER. +/// +/// First 16 bytes hex-encoded — collision-resistant for the small N of +/// signing keys a single deployment ever has, while staying short enough +/// to keep JWT headers compact. +fn kid_from_public_key_der(public_key_der: &[u8]) -> String { + let digest = Sha256::digest(public_key_der); + hex_encode_prefix(&digest, 16) +} + +fn hex_encode_prefix(bytes: &[u8], n: usize) -> String { + use std::fmt::Write as _; + let mut out = String::with_capacity(n * 2); + for byte in bytes.iter().take(n) { + let _ = write!(out, "{byte:02x}"); + } + out +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn generate_jwt_key_produces_parseable_pem() { + let material = generate_jwt_key().expect("generate_jwt_key"); + assert!(material.signing_key_pem.contains("BEGIN PRIVATE KEY")); + assert!(material.public_key_pem.contains("BEGIN PUBLIC KEY")); + assert_eq!(material.kid.len(), 32, "kid is 16 bytes hex-encoded"); + assert!(material.kid.chars().all(|c| c.is_ascii_hexdigit())); + } + + #[test] + fn kid_is_stable_for_identical_public_keys() { + // Same input -> same kid. Hash of a fixed byte string. 
+ let kid_a = kid_from_public_key_der(b"abc"); + let kid_b = kid_from_public_key_der(b"abc"); + assert_eq!(kid_a, kid_b); + } + + #[test] + fn kid_differs_for_different_public_keys() { + let kid_a = kid_from_public_key_der(b"first"); + let kid_b = kid_from_public_key_der(b"second"); + assert_ne!(kid_a, kid_b); + } + + #[test] + fn generated_keys_are_unique() { + let a = generate_jwt_key().expect("generate_jwt_key"); + let b = generate_jwt_key().expect("generate_jwt_key"); + assert_ne!( + a.kid, b.kid, + "fresh keypairs must produce distinct public keys" + ); + assert_ne!(a.signing_key_pem, b.signing_key_pem); + } +} diff --git a/crates/openshell-bootstrap/src/lib.rs b/crates/openshell-bootstrap/src/lib.rs index 0988c4b6b..8845f0392 100644 --- a/crates/openshell-bootstrap/src/lib.rs +++ b/crates/openshell-bootstrap/src/lib.rs @@ -3,6 +3,7 @@ pub mod build; pub mod edge_token; +pub mod jwt; pub mod oidc_token; mod metadata; diff --git a/crates/openshell-bootstrap/src/metadata.rs b/crates/openshell-bootstrap/src/metadata.rs index abe51335e..108a99b8a 100644 --- a/crates/openshell-bootstrap/src/metadata.rs +++ b/crates/openshell-bootstrap/src/metadata.rs @@ -25,7 +25,7 @@ pub struct GatewayMetadata { #[serde(skip_serializing_if = "Option::is_none", default)] pub resolved_host: Option, - /// Auth mode: `None` or `"mtls"` = mTLS (default), `"plaintext"` = direct HTTP, + /// Auth mode: `None` or `"mtls"` = mTLS, `"plaintext"` = direct HTTP, /// `"cloudflare_jwt"` = CF JWT. #[serde(default, skip_serializing_if = "Option::is_none")] pub auth_mode: Option, diff --git a/crates/openshell-bootstrap/src/pki.rs b/crates/openshell-bootstrap/src/pki.rs index f3a30211e..adc2c48f1 100644 --- a/crates/openshell-bootstrap/src/pki.rs +++ b/crates/openshell-bootstrap/src/pki.rs @@ -1,6 +1,7 @@ // SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
// SPDX-License-Identifier: Apache-2.0 +use crate::jwt::{JwtKeyMaterial, generate_jwt_key}; use miette::{IntoDiagnostic, Result, WrapErr}; use rcgen::{BasicConstraints, CertificateParams, DnType, Ia5String, IsCa, KeyPair, SanType}; use std::net::IpAddr; @@ -15,15 +16,28 @@ pub struct PkiBundle { pub server_key_pem: String, pub client_cert_pem: String, pub client_key_pem: String, + /// PKCS#8 PEM Ed25519 private key for minting per-sandbox JWTs. + pub jwt_signing_key_pem: String, + /// SPKI PEM Ed25519 public key, paired with `jwt_signing_key_pem`. + pub jwt_public_key_pem: String, + /// Stable identifier embedded in the `kid` header of every minted JWT. + pub jwt_key_id: String, } /// Default SANs always included on the server certificate. -const DEFAULT_SERVER_SANS: &[&str] = &[ +/// +/// Covers the host aliases used by every supported runtime: Kubernetes service DNS, +/// `host.docker.internal` for Docker Desktop and rootless Docker on Linux, +/// and `host.containers.internal` for Podman containers reaching their host. 
+pub const DEFAULT_SERVER_SANS: &[&str] = &[ "openshell", "openshell.openshell.svc", "openshell.openshell.svc.cluster.local", "localhost", + "openshell.localhost", + "*.openshell.localhost", "host.docker.internal", + "host.containers.internal", "127.0.0.1", ]; @@ -84,12 +98,22 @@ pub fn generate_pki(extra_sans: &[String]) -> Result { client_params .distinguished_name .push(DnType::CommonName, "openshell-client"); + client_params + .distinguished_name + .push(DnType::OrganizationalUnitName, "openshell-user"); let client_cert = client_params .signed_by(&client_key, &ca_cert, &ca_key) .into_diagnostic() .wrap_err("failed to sign client certificate")?; + // --- JWT signing key (Ed25519, used to mint per-sandbox identity tokens) --- + let JwtKeyMaterial { + signing_key_pem: jwt_signing_key_pem, + public_key_pem: jwt_public_key_pem, + kid: jwt_key_id, + } = generate_jwt_key().wrap_err("failed to generate JWT signing key")?; + Ok(PkiBundle { ca_cert_pem: ca_cert.pem(), ca_key_pem: ca_key.serialize_pem(), @@ -97,6 +121,9 @@ pub fn generate_pki(extra_sans: &[String]) -> Result { server_key_pem: server_key.serialize_pem(), client_cert_pem: client_cert.pem(), client_key_pem: client_key.serialize_pem(), + jwt_signing_key_pem, + jwt_public_key_pem, + jwt_key_id, }) } @@ -139,6 +166,9 @@ mod tests { assert!(bundle.server_key_pem.contains("BEGIN PRIVATE KEY")); assert!(bundle.client_cert_pem.contains("BEGIN CERTIFICATE")); assert!(bundle.client_key_pem.contains("BEGIN PRIVATE KEY")); + assert!(bundle.jwt_signing_key_pem.contains("BEGIN PRIVATE KEY")); + assert!(bundle.jwt_public_key_pem.contains("BEGIN PUBLIC KEY")); + assert_eq!(bundle.jwt_key_id.len(), 32, "kid is 16 bytes hex-encoded"); } #[test] @@ -155,4 +185,10 @@ mod tests { // Should have all default SANs + 2 extras assert_eq!(sans.len(), DEFAULT_SERVER_SANS.len() + 2); } + + #[test] + fn default_server_sans_include_local_container_hostnames() { + assert!(DEFAULT_SERVER_SANS.contains(&"host.docker.internal")); + 
assert!(DEFAULT_SERVER_SANS.contains(&"host.containers.internal")); + } } diff --git a/crates/openshell-cli/Cargo.toml b/crates/openshell-cli/Cargo.toml index 8b86544b7..b69a9629b 100644 --- a/crates/openshell-cli/Cargo.toml +++ b/crates/openshell-cli/Cargo.toml @@ -23,6 +23,7 @@ openshell-prover = { path = "../openshell-prover" } openshell-tui = { path = "../openshell-tui" } serde = { workspace = true } serde_json = { workspace = true } +serde_yml = { workspace = true } prost-types = { workspace = true } # Async runtime @@ -32,6 +33,7 @@ tokio = { workspace = true } tonic = { workspace = true, features = ["tls", "tls-native-roots"] } # CLI +chrono = "0.4" clap = { workspace = true } clap_complete = { workspace = true } crossterm = { workspace = true } @@ -58,6 +60,7 @@ anyhow = { workspace = true } # File archiving (tar-over-SSH sync) tar = "0.4" +tempfile = "3" # OIDC/Auth oauth2 = "5" @@ -68,6 +71,7 @@ tokio-tungstenite = { workspace = true } # Streams futures = { workspace = true } +tokio-stream = { workspace = true } nix = { workspace = true } # URL parsing @@ -90,6 +94,5 @@ rcgen = { version = "0.13", features = ["crypto", "pem"] } reqwest = { workspace = true } serde_json = { workspace = true } temp-env = "0.3" -tempfile = "3" tokio-stream = { workspace = true } url = { workspace = true } diff --git a/crates/openshell-cli/src/completers.rs b/crates/openshell-cli/src/completers.rs index d5d9a0a88..a421b418a 100644 --- a/crates/openshell-cli/src/completers.rs +++ b/crates/openshell-cli/src/completers.rs @@ -3,16 +3,20 @@ use std::ffi::OsStr; use std::future::Future; -use std::time::Duration; use clap_complete::engine::CompletionCandidate; +use openshell_bootstrap::edge_token::load_edge_token; +use openshell_bootstrap::oidc_token::{is_token_expired, load_oidc_token, store_oidc_token}; use openshell_bootstrap::{list_gateways, load_active_gateway, load_gateway_metadata}; use openshell_core::ObjectName; +use openshell_core::auth::EdgeAuthInterceptor; use 
openshell_core::proto::open_shell_client::OpenShellClient; use openshell_core::proto::{ListProvidersRequest, ListSandboxesRequest}; -use tonic::transport::{Channel, Endpoint}; +use tonic::service::interceptor::InterceptedService; +use tonic::transport::Channel; -use crate::tls::{TlsOptions, build_tonic_tls_config, require_tls_materials}; +use crate::oidc_auth::oidc_refresh_token; +use crate::tls::{TlsOptions, build_channel}; /// Complete gateway names from local metadata files (no network call). pub fn complete_gateway_names(_prefix: &OsStr) -> Vec { @@ -84,17 +88,46 @@ fn resolve_active_gateway() -> Option<(String, String)> { async fn completion_grpc_client( server: &str, gateway_name: &str, -) -> Option> { - let tls_opts = TlsOptions::default().with_gateway_name(gateway_name); - let materials = require_tls_materials(server, &tls_opts).ok()?; - let tls_config = build_tonic_tls_config(&materials); - let endpoint = Endpoint::from_shared(server.to_string()) - .ok()? - .connect_timeout(Duration::from_secs(2)) - .tls_config(tls_config) - .ok()?; - let channel = endpoint.connect().await.ok()?; - Some(OpenShellClient::new(channel)) +) -> Option>> { + let mut tls_opts = TlsOptions::default().with_gateway_name(gateway_name); + tls_opts.gateway_insecure = std::env::var("OPENSHELL_GATEWAY_INSECURE") + .is_ok_and(|v| !v.is_empty() && v != "0" && v != "false"); + + if let Ok(meta) = load_gateway_metadata(gateway_name) { + match meta.auth_mode.as_deref() { + Some("oidc") => { + if let Some(bundle) = load_oidc_token(gateway_name) { + if is_token_expired(&bundle) { + match oidc_refresh_token(&bundle, tls_opts.gateway_insecure).await { + Ok(refreshed) => { + let _ = store_oidc_token(gateway_name, &refreshed); + tls_opts.oidc_token = Some(refreshed.access_token); + } + Err(_) => { + tls_opts.oidc_token = Some(bundle.access_token); + } + } + } else { + tls_opts.oidc_token = Some(bundle.access_token); + } + } + } + Some("cloudflare_jwt") => { + if let Some(token) = 
load_edge_token(gateway_name) { + tls_opts.edge_token = Some(token); + } + } + _ => {} + } + } + + let channel = build_channel(server, &tls_opts).await.ok()?; + let interceptor = EdgeAuthInterceptor::new( + tls_opts.oidc_token.as_deref(), + tls_opts.edge_token.as_deref(), + ) + .ok()?; + Some(OpenShellClient::with_interceptor(channel, interceptor)) } /// Run an async future on a dedicated thread to avoid nested tokio runtime panics. diff --git a/crates/openshell-cli/src/main.rs b/crates/openshell-cli/src/main.rs index cd14568ef..25c677986 100644 --- a/crates/openshell-cli/src/main.rs +++ b/crates/openshell-cli/src/main.rs @@ -8,6 +8,7 @@ use clap_complete::engine::ArgValueCompleter; use clap_complete::env::CompleteEnv; use miette::Result; use owo_colors::OwoColorize; +use std::collections::HashMap; use std::io::Write; use std::path::PathBuf; @@ -121,9 +122,9 @@ fn resolve_gateway_name(gateway_flag: &Option) -> Option { /// Apply authentication token from local storage based on gateway auth mode. /// -/// Handles both Cloudflare Access (`edge_token`) and OIDC (`oidc_token`) -/// auth modes by loading the stored token and setting it on `TlsOptions`. -/// For OIDC, automatically refreshes the token if it's near expiry. +/// Handles Cloudflare Access and OIDC auth modes by loading the stored token +/// and setting it on `TlsOptions`. For OIDC, automatically refreshes the token +/// if it's near expiry. fn apply_auth(tls: &mut TlsOptions, gateway_name: &str) { let Some(meta) = get_gateway_metadata(gateway_name) else { return; @@ -140,11 +141,14 @@ fn apply_auth(tls: &mut TlsOptions, gateway_name: &str) { return; }; if openshell_bootstrap::oidc_token::is_token_expired(&bundle) { + let insecure = std::env::var("OPENSHELL_GATEWAY_INSECURE") + .is_ok_and(|v| !v.is_empty() && v != "0" && v != "false"); // Try to refresh the token in-place using block_in_place // so the async refresh can run within the sync apply_auth call. 
match tokio::task::block_in_place(|| { - tokio::runtime::Handle::current() - .block_on(openshell_cli::oidc_auth::oidc_refresh_token(&bundle)) + tokio::runtime::Handle::current().block_on( + openshell_cli::oidc_auth::oidc_refresh_token(&bundle, insecure), + ) }) { Ok(refreshed) => { let _ = openshell_bootstrap::oidc_token::store_oidc_token( @@ -180,7 +184,7 @@ fn resolve_sandbox_name(name: Option, gateway: &str) -> Result { let last = load_last_sandbox(gateway).ok_or_else(|| { miette::miette!( "No sandbox name provided and no last-used sandbox.\n\ - Specify a sandbox name or connect to one first: nav sandbox connect " + Specify a sandbox name or connect to one first: openshell sandbox connect " ) })?; eprintln!("{} Using sandbox '{}' (last used)", "→".bold(), last.bold()); @@ -197,6 +201,7 @@ const HELP_TEMPLATE: &str = "\ \x1b[1mSANDBOX COMMANDS\x1b[0m sandbox: Manage sandboxes + service: Expose sandbox services forward: Manage port forwarding to a sandbox logs: View sandbox logs policy: Manage sandbox policy @@ -266,10 +271,23 @@ const FORWARD_EXAMPLES: &str = "\x1b[1mALIAS\x1b[0m \x1b[1mEXAMPLES\x1b[0m $ openshell forward start 8080 $ openshell forward start 3000 my-sandbox + $ openshell forward service my-sandbox --target-port 8000 --local 8000 $ openshell forward stop 8080 $ openshell forward list "; +const SERVICE_EXAMPLES: &str = "\x1b[1mALIAS\x1b[0m + svc + +\x1b[1mEXAMPLES\x1b[0m + $ openshell service expose my-sandbox 8080 + $ openshell service expose my-sandbox 8080 web + $ openshell service list + $ openshell service list my-sandbox + $ openshell service get my-sandbox web + $ openshell service delete my-sandbox web +"; + const LOGS_EXAMPLES: &str = "\x1b[1mALIAS\x1b[0m lg @@ -287,6 +305,7 @@ const POLICY_EXAMPLES: &str = "\x1b[1mALIAS\x1b[0m $ openshell policy get my-sandbox $ openshell policy set my-sandbox --policy policy.yaml $ openshell policy update my-sandbox --add-endpoint api.github.com:443:read-only:rest:enforce + $ openshell policy update 
my-sandbox --add-endpoint realtime.example.com:443:read-write:websocket:enforce:websocket-credential-rewrite,allowed-ip=10.0.0.0/8 $ openshell policy update my-sandbox --add-allow 'api.github.com:443:GET:/repos/**' $ openshell policy set --global --policy policy.yaml $ openshell policy delete --global @@ -407,6 +426,13 @@ enum Commands { command: Option, }, + /// Manage sandbox services. + #[command(alias = "svc", after_help = SERVICE_EXAMPLES, help_template = SUBCOMMAND_HELP_TEMPLATE)] + Service { + #[command(subcommand)] + command: Option, + }, + /// View sandbox logs. #[command(alias = "lg", after_help = LOGS_EXAMPLES, help_template = LEAF_HELP_TEMPLATE, next_help_heading = "FLAGS")] Logs { @@ -624,13 +650,30 @@ fn normalize_completion_script(output: Vec, executable: &std::path::Path) -> } #[derive(Clone, Debug, ValueEnum)] -enum ProviderProfileOutput { +enum OutputFormat { Table, Yaml, Json, } -impl ProviderProfileOutput { +#[derive(Clone, Debug, ValueEnum)] +enum CliProviderRefreshStrategy { + Oauth2RefreshToken, + Oauth2ClientCredentials, + GoogleServiceAccountJwt, +} + +impl CliProviderRefreshStrategy { + fn as_str(&self) -> &'static str { + match self { + Self::Oauth2RefreshToken => "oauth2_refresh_token", + Self::Oauth2ClientCredentials => "oauth2_client_credentials", + Self::GoogleServiceAccountJwt => "google_service_account_jwt", + } + } +} + +impl OutputFormat { fn as_str(&self) -> &'static str { match self { Self::Table => "table", @@ -640,6 +683,21 @@ impl ProviderProfileOutput { } } +#[derive(Clone, Debug, ValueEnum)] +enum PolicyGetOutput { + Table, + Json, +} + +impl PolicyGetOutput { + fn as_str(&self) -> &'static str { + match self { + Self::Table => "table", + Self::Json => "json", + } + } +} + #[derive(Clone, Debug, ValueEnum)] enum CliEditor { Vscode, @@ -685,6 +743,10 @@ enum ProviderCommands { config: Vec, }, + /// Manage provider credential refresh. 
+ #[command(subcommand, help_template = SUBCOMMAND_HELP_TEMPLATE)] + Refresh(ProviderRefreshCommands), + /// Fetch a provider by name. #[command(help_template = LEAF_HELP_TEMPLATE, next_help_heading = "FLAGS")] Get { @@ -713,8 +775,8 @@ enum ProviderCommands { #[command(help_template = LEAF_HELP_TEMPLATE, next_help_heading = "FLAGS")] ListProfiles { /// Output format. - #[arg(short = 'o', long = "output", value_enum, default_value_t = ProviderProfileOutput::Table)] - output: ProviderProfileOutput, + #[arg(short = 'o', long = "output", value_enum, default_value_t = OutputFormat::Table)] + output: OutputFormat, }, /// Manage provider profiles. @@ -743,6 +805,10 @@ enum ProviderCommands { /// Provider config key/value pair. #[arg(long = "config", value_name = "KEY=VALUE")] config: Vec, + + /// Credential expiry (`KEY=TIMESTAMP`). Accepts epoch milliseconds or RFC3339. A zero timestamp clears expiry. + #[arg(long = "credential-expires-at", value_name = "KEY=TIMESTAMP")] + credential_expires_at: Vec, }, /// Delete providers by name. @@ -754,6 +820,77 @@ enum ProviderCommands { }, } +#[derive(Subcommand, Debug)] +enum ProviderRefreshCommands { + /// Show provider credential refresh status. + #[command(help_template = LEAF_HELP_TEMPLATE, next_help_heading = "FLAGS")] + Status { + /// Provider name. + #[arg(add = ArgValueCompleter::new(completers::complete_provider_names))] + name: String, + + /// Optional credential key to filter by. + #[arg(long = "credential-key")] + credential_key: Option, + }, + + /// Configure refresh metadata for a provider credential. + #[command(help_template = LEAF_HELP_TEMPLATE, next_help_heading = "FLAGS")] + Configure { + /// Provider name. + #[arg(add = ArgValueCompleter::new(completers::complete_provider_names))] + name: String, + + /// Injectable credential key, for example `MS_GRAPH_ACCESS_TOKEN`. + #[arg(long = "credential-key")] + credential_key: String, + + /// Refresh strategy. 
+ #[arg(long, value_enum)] + strategy: CliProviderRefreshStrategy, + + /// Non-injectable refresh material (`KEY=VALUE`). + #[arg(long = "material", value_name = "KEY=VALUE")] + material: Vec, + + /// Material keys that are secret and must not be exposed. + #[arg(long = "secret-material-key", value_name = "KEY")] + secret_material_keys: Vec, + + /// Expiry for the current credential. Accepts epoch milliseconds or RFC3339. + #[arg( + long = "credential-expires-at", + value_name = "TIMESTAMP", + value_parser = run::parse_credential_expiry_cli_value + )] + credential_expires_at: Option, + }, + + /// Record a gateway-owned credential rotation request. + #[command(help_template = LEAF_HELP_TEMPLATE, next_help_heading = "FLAGS")] + Rotate { + /// Provider name. + #[arg(add = ArgValueCompleter::new(completers::complete_provider_names))] + name: String, + + /// Injectable credential key, for example `MS_GRAPH_ACCESS_TOKEN`. + #[arg(long = "credential-key")] + credential_key: String, + }, + + /// Delete refresh metadata for a provider credential. + #[command(help_template = LEAF_HELP_TEMPLATE, next_help_heading = "FLAGS")] + Delete { + /// Provider name. + #[arg(add = ArgValueCompleter::new(completers::complete_provider_names))] + name: String, + + /// Injectable credential key, for example `MS_GRAPH_ACCESS_TOKEN`. + #[arg(long = "credential-key")] + credential_key: String, + }, +} + #[derive(Subcommand, Debug)] enum ProviderProfileCommands { /// Export a provider profile. @@ -763,8 +900,8 @@ enum ProviderProfileCommands { id: String, /// Output format. - #[arg(short = 'o', long = "output", value_enum, default_value_t = ProviderProfileOutput::Yaml)] - output: ProviderProfileOutput, + #[arg(short = 'o', long = "output", value_enum, default_value_t = OutputFormat::Yaml)] + output: OutputFormat, }, /// Import provider profiles from a file or directory. 
@@ -920,7 +1057,11 @@ enum GatewayCommands { /// Prints a table of all registered gateways with their endpoint, type, /// and authentication mode. The active gateway is marked with `*`. #[command(help_template = LEAF_HELP_TEMPLATE, next_help_heading = "FLAGS")] - List, + List { + /// Output format. + #[arg(short = 'o', long = "output", value_enum, default_value_t = OutputFormat::Table)] + output: OutputFormat, + }, } // ----------------------------------------------------------------------- @@ -1015,7 +1156,7 @@ enum SandboxCommands { #[arg(long, add = ArgValueCompleter::new(completers::complete_sandbox_names))] name: Option, - /// Sandbox source: a community sandbox name (e.g., `openclaw`), a path + /// Sandbox source: a community sandbox name (e.g., `ollama`), a path /// to a Dockerfile or directory containing one, or a full container /// image reference (e.g., `myregistry.com/img:tag`). /// @@ -1061,11 +1202,20 @@ enum SandboxCommands { #[arg(long)] gpu: bool, - /// Target a specific GPU by PCI address (e.g. "0000:2d:00.0") or index (e.g. "0", "1"). - /// Only valid with --gpu. When omitted with --gpu, the first available GPU is assigned. + /// Target a driver-specific GPU device. Docker and Podman use CDI device IDs + /// (for example "nvidia.com/gpu=0"); VM uses a PCI BDF or index. + /// Only valid with --gpu. When omitted with --gpu, the driver uses its default GPU selection. #[arg(long, requires = "gpu")] gpu_device: Option, + /// CPU limit for the sandbox (for example: 500m, 1, 2.5). + #[arg(long)] + cpu: Option, + + /// Memory limit for the sandbox (for example: 512Mi, 4Gi, 8G). + #[arg(long)] + memory: Option, + /// Provider names to attach to this sandbox. #[arg(long = "provider")] providers: Vec, @@ -1135,16 +1285,20 @@ enum SandboxCommands { offset: u32, /// Print only sandbox ids (one per line). - #[arg(long, conflicts_with = "names")] + #[arg(long, conflicts_with_all = ["names", "output"])] ids: bool, /// Print only sandbox names (one per line). 
- #[arg(long, conflicts_with = "ids")] + #[arg(long, conflicts_with_all = ["ids", "output"])] names: bool, /// Filter sandboxes by label selector (key1=value1,key2=value2). #[arg(long)] selector: Option, + + /// Output format. + #[arg(short = 'o', long = "output", value_enum, default_value_t = OutputFormat::Table, conflicts_with_all = ["ids", "names"])] + output: OutputFormat, }, /// Delete a sandbox by name. @@ -1260,6 +1414,45 @@ enum SandboxCommands { #[arg(add = ArgValueCompleter::new(completers::complete_sandbox_names))] name: Option, }, + + /// Manage providers attached to a sandbox. + #[command(subcommand)] + Provider(SandboxProviderCommands), +} + +#[derive(Subcommand, Debug)] +enum SandboxProviderCommands { + /// List providers attached to a sandbox. + #[command(help_template = LEAF_HELP_TEMPLATE, next_help_heading = "FLAGS")] + List { + /// Sandbox name (defaults to last-used sandbox). + #[arg(add = ArgValueCompleter::new(completers::complete_sandbox_names))] + name: Option, + }, + + /// Attach a provider to a sandbox. + #[command(help_template = LEAF_HELP_TEMPLATE, next_help_heading = "FLAGS")] + Attach { + /// Sandbox name. + #[arg(add = ArgValueCompleter::new(completers::complete_sandbox_names))] + name: String, + + /// Provider name to attach. + #[arg(add = ArgValueCompleter::new(completers::complete_provider_names))] + provider: String, + }, + + /// Detach a provider from a sandbox. + #[command(help_template = LEAF_HELP_TEMPLATE, next_help_heading = "FLAGS")] + Detach { + /// Sandbox name. + #[arg(add = ArgValueCompleter::new(completers::complete_sandbox_names))] + name: String, + + /// Provider name to detach. + #[arg(add = ArgValueCompleter::new(completers::complete_provider_names))] + provider: String, + }, } #[derive(Subcommand, Debug)] @@ -1364,7 +1557,7 @@ enum PolicyCommands { #[arg(add = ArgValueCompleter::new(completers::complete_sandbox_names))] name: Option, - /// Add or merge an endpoint: host:port[:access[:protocol[:enforcement]]]. 
+ /// Add or merge an endpoint: host:port[:access[:protocol[:enforcement[:options]]]]. #[arg(long = "add-endpoint")] add_endpoints: Vec, @@ -1372,11 +1565,11 @@ enum PolicyCommands { #[arg(long = "remove-endpoint")] remove_endpoints: Vec, - /// Add a REST allow rule: `host:port:METHOD:path_glob`. + /// Add a REST or WebSocket method/path allow rule: `host:port:METHOD:path_glob`. #[arg(long = "add-allow")] add_allow: Vec, - /// Add a REST deny rule: `host:port:METHOD:path_glob`. + /// Add a REST or WebSocket method/path deny rule: `host:port:METHOD:path_glob`. #[arg(long = "add-deny")] add_deny: Vec, @@ -1416,10 +1609,14 @@ enum PolicyCommands { #[arg(long = "rev", default_value_t = 0)] rev: u32, - /// Print the full policy as YAML. + /// Include the full policy payload. #[arg(long)] full: bool, + /// Output format. + #[arg(short = 'o', long = "output", value_enum, default_value_t = PolicyGetOutput::Table)] + output: PolicyGetOutput, + /// Show the global policy revision. #[arg(long)] global: bool, @@ -1572,6 +1769,82 @@ enum ForwardCommands { /// List active port forwards. #[command(help_template = LEAF_HELP_TEMPLATE, next_help_heading = "FLAGS")] List, + + /// Forward a local TCP port to a loopback service inside a sandbox over gRPC. + #[command(help_template = LEAF_HELP_TEMPLATE, next_help_heading = "FLAGS")] + Service { + /// Sandbox name (defaults to last-used sandbox). + #[arg(add = ArgValueCompleter::new(completers::complete_sandbox_names))] + name: Option, + + /// Target service port inside the sandbox. + #[arg(long)] + target_port: u16, + + /// Target service host inside the sandbox. Phase 1 accepts loopback only. + #[arg(long, default_value = "127.0.0.1")] + target_host: String, + + /// Local bind address and port: `[bind_address:]port`. Defaults to the target port. Use port 0 for dynamic assignment. + #[arg(long)] + local: Option, + }, +} + +#[derive(Subcommand, Debug)] +enum ServiceCommands { + /// Expose an HTTP service running inside a sandbox. 
+ #[command(help_template = LEAF_HELP_TEMPLATE, next_help_heading = "FLAGS")] + Expose { + /// Sandbox name. + #[arg(add = ArgValueCompleter::new(completers::complete_sandbox_names))] + sandbox: String, + + /// Loopback TCP port inside the sandbox. + #[arg(value_name = "TARGET-PORT")] + target_port: u16, + + /// Service name. + service: Option, + }, + + /// List exposed sandbox service endpoints. + #[command(help_template = LEAF_HELP_TEMPLATE, next_help_heading = "FLAGS")] + List { + /// Sandbox name. + #[arg(add = ArgValueCompleter::new(completers::complete_sandbox_names))] + sandbox: Option, + + /// Maximum number of endpoints to return. + #[arg(long, default_value_t = 100)] + limit: u32, + + /// Number of endpoints to skip. + #[arg(long, default_value_t = 0)] + offset: u32, + }, + + /// Show one exposed sandbox service endpoint. + #[command(help_template = LEAF_HELP_TEMPLATE, next_help_heading = "FLAGS")] + Get { + /// Sandbox name. + #[arg(add = ArgValueCompleter::new(completers::complete_sandbox_names))] + sandbox: String, + + /// Service name. Omit for the unnamed endpoint. + service: Option, + }, + + /// Delete one exposed sandbox service endpoint. + #[command(help_template = LEAF_HELP_TEMPLATE, next_help_heading = "FLAGS")] + Delete { + /// Sandbox name. + #[arg(add = ArgValueCompleter::new(completers::complete_sandbox_names))] + sandbox: String, + + /// Service name. Omit for the unnamed endpoint. 
+ service: Option, + }, } #[tokio::main] @@ -1647,6 +1920,7 @@ async fn main() -> Result<()> { &oidc_client_id, oidc_audience.as_deref(), oidc_scopes.as_deref(), + cli.gateway_insecure, ) .await?; } @@ -1672,7 +1946,7 @@ async fn main() -> Result<()> { Or set one with: openshell gateway select " ) })?; - run::gateway_login(&name).await?; + run::gateway_login(&name, cli.gateway_insecure).await?; } GatewayCommands::Logout { name } => { let name = name @@ -1695,8 +1969,8 @@ async fn main() -> Result<()> { .unwrap_or_else(|| "openshell".to_string()); run::gateway_admin_info(&name)?; } - GatewayCommands::List => { - run::gateway_list(&cli.gateway)?; + GatewayCommands::List { output } => { + run::gateway_list(&cli.gateway, output.as_str())?; } }, @@ -1814,6 +2088,27 @@ async fn main() -> Result<()> { } } } + ForwardCommands::Service { + name, + target_port, + target_host, + local, + } => { + let ctx = resolve_gateway(&cli.gateway, &cli.gateway_endpoint)?; + let mut tls = tls.with_gateway_name(&ctx.name); + apply_auth(&mut tls, &ctx.name); + let name = resolve_sandbox_name(name, &ctx.name)?; + let local = local.unwrap_or_else(|| target_port.to_string()); + run::service_forward_tcp( + &ctx.endpoint, + &name, + Some(&local), + &target_host, + target_port, + &tls, + ) + .await?; + } ForwardCommands::Start { port, name, @@ -1837,6 +2132,43 @@ async fn main() -> Result<()> { } }, + // ----------------------------------------------------------- + // Service exposure + // ----------------------------------------------------------- + Some(Commands::Service { + command: Some(command), + }) => { + let ctx = resolve_gateway(&cli.gateway, &cli.gateway_endpoint)?; + let mut tls = tls.with_gateway_name(&ctx.name); + apply_auth(&mut tls, &ctx.name); + match command { + ServiceCommands::Expose { + sandbox, + service, + target_port, + } => { + let service = service.unwrap_or_default(); + run::service_expose(&ctx.endpoint, &sandbox, &service, target_port, &tls) + .await?; + } + 
ServiceCommands::List { + sandbox, + limit, + offset, + } => { + run::service_list(&ctx.endpoint, sandbox.as_deref(), limit, offset, &tls) + .await?; + } + ServiceCommands::Get { sandbox, service } => { + let service = service.unwrap_or_default(); + run::service_get(&ctx.endpoint, &sandbox, &service, &tls).await?; + } + ServiceCommands::Delete { sandbox, service } => { + let service = service.unwrap_or_default(); + run::service_delete(&ctx.endpoint, &sandbox, &service, &tls).await?; + } + } + } // ----------------------------------------------------------- // Top-level logs (was `sandbox logs`) // ----------------------------------------------------------- @@ -1962,13 +2294,29 @@ async fn main() -> Result<()> { name, rev, full, + output, global, } => { if global { - run::sandbox_policy_get_global(&ctx.endpoint, rev, full, &tls).await?; + run::sandbox_policy_get_global( + &ctx.endpoint, + rev, + full, + output.as_str(), + &tls, + ) + .await?; } else { let name = resolve_sandbox_name(name, &ctx.name)?; - run::sandbox_policy_get(&ctx.endpoint, &name, rev, full, &tls).await?; + run::sandbox_policy_get( + &ctx.endpoint, + &name, + rev, + full, + output.as_str(), + &tls, + ) + .await?; } } PolicyCommands::List { @@ -2168,6 +2516,8 @@ async fn main() -> Result<()> { editor, gpu, gpu_device, + cpu, + memory, providers, policy, forward, @@ -2197,7 +2547,7 @@ async fn main() -> Result<()> { }; // Parse --label flags into a HashMap. 
- let mut labels_map = std::collections::HashMap::new(); + let mut labels_map = HashMap::new(); for label_str in &labels { let parts: Vec<&str> = label_str.splitn(2, '=').collect(); if parts.len() != 2 { @@ -2234,6 +2584,8 @@ async fn main() -> Result<()> { keep, gpu, gpu_device.as_deref(), + cpu.as_deref(), + memory.as_deref(), editor, &providers, policy.as_deref(), @@ -2291,12 +2643,8 @@ async fn main() -> Result<()> { let ctx = resolve_gateway(&cli.gateway, &cli.gateway_endpoint)?; let mut tls = tls.with_gateway_name(&ctx.name); apply_auth(&mut tls, &ctx.name); - let local_dest = std::path::Path::new(dest.as_deref().unwrap_or(".")); - eprintln!( - "Downloading sandbox:{} -> {}", - sandbox_path, - local_dest.display() - ); + let local_dest = dest.as_deref().unwrap_or("."); + eprintln!("Downloading sandbox:{sandbox_path} -> {local_dest}"); run::sandbox_sync_down(&ctx.endpoint, &name, &sandbox_path, local_dest, &tls) .await?; eprintln!("{} Download complete", "✓".green().bold()); @@ -2322,6 +2670,7 @@ async fn main() -> Result<()> { ids, names, selector, + output, } => { run::sandbox_list( endpoint, @@ -2330,6 +2679,7 @@ async fn main() -> Result<()> { ids, names, selector.as_deref(), + output.as_str(), &tls, ) .await?; @@ -2385,6 +2735,20 @@ async fn main() -> Result<()> { let name = resolve_sandbox_name(name, &ctx.name)?; run::print_ssh_config(&ctx.name, &name); } + SandboxCommands::Provider(command) => match command { + SandboxProviderCommands::List { name } => { + let name = resolve_sandbox_name(name, &ctx.name)?; + run::sandbox_provider_list(endpoint, &name, &tls).await?; + } + SandboxProviderCommands::Attach { name, provider } => { + run::sandbox_provider_attach(endpoint, &name, &provider, &tls) + .await?; + } + SandboxProviderCommands::Detach { name, provider } => { + run::sandbox_provider_detach(endpoint, &name, &provider, &tls) + .await?; + } + }, } } } @@ -2416,6 +2780,55 @@ async fn main() -> Result<()> { ) .await?; } + ProviderCommands::Refresh(command) 
=> match command { + ProviderRefreshCommands::Status { + name, + credential_key, + } => { + run::provider_refresh_status( + endpoint, + &name, + credential_key.as_deref(), + &tls, + ) + .await?; + } + ProviderRefreshCommands::Configure { + name, + credential_key, + strategy, + material, + secret_material_keys, + credential_expires_at, + } => { + run::provider_refresh_config( + endpoint, + run::ProviderRefreshConfigInput { + name: &name, + credential_key: &credential_key, + strategy: strategy.as_str(), + material: &material, + secret_material_keys: &secret_material_keys, + credential_expires_at_ms: credential_expires_at, + }, + &tls, + ) + .await?; + } + ProviderRefreshCommands::Rotate { + name, + credential_key, + } => { + run::provider_rotate(endpoint, &name, &credential_key, &tls).await?; + } + ProviderRefreshCommands::Delete { + name, + credential_key, + } => { + run::provider_refresh_delete(endpoint, &name, &credential_key, &tls) + .await?; + } + }, ProviderCommands::Get { name } => { run::provider_get(endpoint, &name, &tls).await?; } @@ -2460,6 +2873,7 @@ async fn main() -> Result<()> { from_existing, credentials, config, + credential_expires_at, } => { run::provider_update( endpoint, @@ -2467,6 +2881,7 @@ async fn main() -> Result<()> { from_existing, &credentials, &config, + &credential_expires_at, &tls, ) .await?; @@ -2481,7 +2896,11 @@ async fn main() -> Result<()> { let mut tls = tls.with_gateway_name(&ctx.name); apply_auth(&mut tls, &ctx.name); let channel = openshell_cli::tls::build_channel(&ctx.endpoint, &tls).await?; - openshell_tui::run(channel, &ctx.name, &ctx.endpoint, theme).await?; + let interceptor = openshell_core::auth::EdgeAuthInterceptor::new( + tls.oidc_token.as_deref(), + tls.edge_token.as_deref(), + )?; + openshell_tui::run(channel, interceptor, &ctx.name, &ctx.endpoint, theme).await?; } Some(Commands::Completions { shell }) => { let exe = std::env::current_exe() @@ -2560,6 +2979,13 @@ async fn main() -> Result<()> { .print_help() 
.expect("Failed to print help"); } + Some(Commands::Service { command: None }) => { + Cli::command() + .find_subcommand_mut("service") + .expect("service subcommand exists") + .print_help() + .expect("Failed to print help"); + } Some(Commands::Policy { command: None }) => { Cli::command() .find_subcommand_mut("policy") @@ -2721,6 +3147,30 @@ mod tests { ); } + #[test] + fn sandbox_provider_subcommands_parse() { + let cli = Cli::try_parse_from([ + "openshell", + "sandbox", + "provider", + "attach", + "work-sandbox", + "work-github", + ]) + .expect("sandbox provider attach should parse"); + + let Some(Commands::Sandbox { + command: + Some(SandboxCommands::Provider(SandboxProviderCommands::Attach { name, provider })), + }) = cli.command + else { + panic!("expected sandbox provider attach command"); + }; + + assert_eq!(name, "work-sandbox"); + assert_eq!(provider, "work-github"); + } + #[test] fn completions_policy_flag_falls_back_to_file_paths() { let temp = tempfile::tempdir().expect("failed to create tempdir"); @@ -3032,7 +3482,7 @@ mod tests { let err = resolve_sandbox_name(None, "unknown-gateway").unwrap_err(); let msg = err.to_string(); assert!( - msg.contains("nav sandbox connect"), + msg.contains("openshell sandbox connect"), "expected helpful hint in error, got: {msg}" ); }); @@ -3144,7 +3594,7 @@ mod tests { cli.command, Some(Commands::Provider { command: Some(ProviderCommands::ListProfiles { - output: ProviderProfileOutput::Table + output: OutputFormat::Table }) }) )); @@ -3159,7 +3609,7 @@ mod tests { cli.command, Some(Commands::Provider { command: Some(ProviderCommands::ListProfiles { - output: ProviderProfileOutput::Json + output: OutputFormat::Json }) }) )); @@ -3182,7 +3632,7 @@ mod tests { Some(Commands::Provider { command: Some(ProviderCommands::Profile(ProviderProfileCommands::Export { id, - output: ProviderProfileOutput::Yaml + output: OutputFormat::Yaml })) }) if id == "custom-api" )); @@ -3219,6 +3669,111 @@ mod tests { )); } + #[test] + fn 
sandbox_list_default_output_is_table() { + let cli = Cli::try_parse_from(["openshell", "sandbox", "list"]) + .expect("sandbox list should parse"); + + assert!(matches!( + cli.command, + Some(Commands::Sandbox { + command: Some(SandboxCommands::List { + output: OutputFormat::Table, + .. + }) + }) + )); + } + + #[test] + fn sandbox_list_accepts_output_json() { + let cli = Cli::try_parse_from(["openshell", "sandbox", "list", "-o", "json"]) + .expect("sandbox list -o json should parse"); + + assert!(matches!( + cli.command, + Some(Commands::Sandbox { + command: Some(SandboxCommands::List { + output: OutputFormat::Json, + .. + }) + }) + )); + } + + #[test] + fn sandbox_list_accepts_output_yaml() { + let cli = Cli::try_parse_from(["openshell", "sandbox", "list", "-o", "yaml"]) + .expect("sandbox list -o yaml should parse"); + + assert!(matches!( + cli.command, + Some(Commands::Sandbox { + command: Some(SandboxCommands::List { + output: OutputFormat::Yaml, + .. + }) + }) + )); + } + + #[test] + fn sandbox_list_json_conflicts_with_ids() { + let result = Cli::try_parse_from(["openshell", "sandbox", "list", "-o", "json", "--ids"]); + assert!(result.is_err(), "--ids and -o json should conflict"); + } + + #[test] + fn sandbox_list_json_conflicts_with_names() { + let result = Cli::try_parse_from(["openshell", "sandbox", "list", "-o", "json", "--names"]); + assert!(result.is_err(), "--names and -o json should conflict"); + } + + #[test] + fn gateway_list_default_output_is_table() { + let cli = Cli::try_parse_from(["openshell", "gateway", "list"]) + .expect("gateway list should parse"); + + assert!(matches!( + cli.command, + Some(Commands::Gateway { + command: Some(GatewayCommands::List { + output: OutputFormat::Table, + }) + }) + )); + } + + #[test] + fn gateway_list_accepts_output_json() { + let cli = Cli::try_parse_from(["openshell", "gateway", "list", "-o", "json"]) + .expect("gateway list -o json should parse"); + + assert!(matches!( + cli.command, + Some(Commands::Gateway { 
+ command: Some(GatewayCommands::List { + output: OutputFormat::Json, + }) + }) + )); + } + + #[test] + fn gateway_list_accepts_output_yaml() { + let cli = Cli::try_parse_from(["openshell", "gateway", "list", "-o", "yaml"]) + .expect("gateway list -o yaml should parse"); + + assert!(matches!( + cli.command, + Some(Commands::Gateway { + command: Some(GatewayCommands::List { + output: OutputFormat::Yaml, + }) + }) + )); + } + #[test] fn provider_create_accepts_custom_profile_type_ids() { let cli = Cli::try_parse_from([ @@ -3252,6 +3807,155 @@ mod tests { } } + #[test] + fn provider_refresh_commands_parse() { + let status = Cli::try_parse_from([ + "openshell", + "provider", + "refresh", + "status", + "my-graph", + "--credential-key", + "MS_GRAPH_ACCESS_TOKEN", + ]) + .expect("provider refresh status should parse"); + assert!(matches!( + status.command, + Some(Commands::Provider { + command: Some(ProviderCommands::Refresh(ProviderRefreshCommands::Status { + name, + credential_key: Some(key) + })) + }) if name == "my-graph" && key == "MS_GRAPH_ACCESS_TOKEN" + )); + + let config = Cli::try_parse_from([ + "openshell", + "provider", + "refresh", + "configure", + "my-graph", + "--credential-key", + "MS_GRAPH_ACCESS_TOKEN", + "--strategy", + "oauth2-client-credentials", + "--material", + "tenant_id=abc", + "--secret-material-key", + "client_secret", + "--credential-expires-at", + "1767225600000", + ]) + .expect("provider refresh configure should parse"); + assert!(matches!( + config.command, + Some(Commands::Provider { + command: Some(ProviderCommands::Refresh( + ProviderRefreshCommands::Configure { + strategy: CliProviderRefreshStrategy::Oauth2ClientCredentials, + credential_expires_at: Some(1_767_225_600_000), + .. 
+ } + )) + }) + )); + + let rotate = Cli::try_parse_from([ + "openshell", + "provider", + "refresh", + "rotate", + "my-graph", + "--credential-key", + "MS_GRAPH_ACCESS_TOKEN", + ]) + .expect("provider refresh rotate should parse"); + assert!(matches!( + rotate.command, + Some(Commands::Provider { + command: Some(ProviderCommands::Refresh(ProviderRefreshCommands::Rotate { + name, + credential_key + })) + }) if name == "my-graph" && credential_key == "MS_GRAPH_ACCESS_TOKEN" + )); + + let delete = Cli::try_parse_from([ + "openshell", + "provider", + "refresh", + "delete", + "my-graph", + "--credential-key", + "MS_GRAPH_ACCESS_TOKEN", + ]) + .expect("provider refresh delete should parse"); + assert!(matches!( + delete.command, + Some(Commands::Provider { + command: Some(ProviderCommands::Refresh(ProviderRefreshCommands::Delete { + name, + credential_key + })) + }) if name == "my-graph" && credential_key == "MS_GRAPH_ACCESS_TOKEN" + )); + } + + #[test] + fn provider_update_accepts_credential_expiry() { + let cli = Cli::try_parse_from([ + "openshell", + "provider", + "update", + "my-graph", + "--credential", + "MS_GRAPH_ACCESS_TOKEN=abc", + "--credential-expires-at", + "MS_GRAPH_ACCESS_TOKEN=1767225600000", + ]) + .expect("provider update should parse credential expiry"); + + assert!(matches!( + cli.command, + Some(Commands::Provider { + command: Some(ProviderCommands::Update { + credential_expires_at, + .. 
+ }) + }) if credential_expires_at == vec!["MS_GRAPH_ACCESS_TOKEN=1767225600000"] + )); + } + + #[test] + fn provider_refresh_config_accepts_rfc3339_credential_expiry() { + let cli = Cli::try_parse_from([ + "openshell", + "provider", + "refresh", + "configure", + "my-graph", + "--credential-key", + "MS_GRAPH_ACCESS_TOKEN", + "--strategy", + "oauth2-client-credentials", + "--credential-expires-at", + "2026-01-01T00:00:00Z", + ]) + .expect("provider refresh configure should parse RFC3339 credential expiry"); + + assert!(matches!( + cli.command, + Some(Commands::Provider { + command: Some(ProviderCommands::Refresh( + ProviderRefreshCommands::Configure { + credential_expires_at: Some(1_767_225_600_000), + .. + } + )) + }) + )); + } + #[test] fn settings_set_global_parses_yes_flag() { let cli = Cli::try_parse_from([ @@ -3303,6 +4007,34 @@ mod tests { } } + #[test] + fn policy_get_json_output_parses() { + let cli = Cli::try_parse_from([ + "openshell", + "policy", + "get", + "my-sandbox", + "--full", + "-o", + "json", + ]) + .expect("policy get -o json should parse"); + + match cli.command { + Some(Commands::Policy { + command: + Some(PolicyCommands::Get { + name, full, output, .. + }), + }) => { + assert_eq!(name.as_deref(), Some("my-sandbox")); + assert!(full); + assert!(matches!(output, PolicyGetOutput::Json)); + } + other => panic!("expected policy get command, got: {other:?}"), + } + } + #[test] fn policy_delete_global_parses() { let cli = Cli::try_parse_from(["openshell", "policy", "delete", "--global", "--yes"]) @@ -3398,4 +4130,219 @@ mod tests { } } } + + #[test] + fn sandbox_create_resource_flags_parse() { + let cli = Cli::try_parse_from([ + "openshell", + "sandbox", + "create", + "--cpu", + "500m", + "--memory", + "2Gi", + "--", + "claude", + ]) + .expect("sandbox create resource flags should parse"); + + match cli.command { + Some(Commands::Sandbox { + command: + Some(SandboxCommands::Create { + cpu, + memory, + command, + .. + }), + .. 
+ }) => { + assert_eq!(cpu.as_deref(), Some("500m")); + assert_eq!(memory.as_deref(), Some("2Gi")); + assert_eq!(command, vec!["claude".to_string()]); + } + other => panic!("expected SandboxCommands::Create, got: {other:?}"), + } + } + + #[test] + fn service_expose_accepts_positional_target_port_and_service() { + let cli = Cli::try_parse_from([ + "openshell", + "service", + "expose", + "my-sandbox", + "8080", + "api", + ]) + .expect("service expose positional target port should parse"); + + match cli.command { + Some(Commands::Service { + command: + Some(ServiceCommands::Expose { + sandbox, + target_port, + service, + }), + }) => { + assert_eq!(sandbox, "my-sandbox"); + assert_eq!(target_port, 8080); + assert_eq!(service.as_deref(), Some("api")); + } + other => panic!("expected service expose command, got: {other:?}"), + } + } + + #[test] + fn service_expose_allows_omitted_service_name() { + let cli = Cli::try_parse_from(["openshell", "service", "expose", "my-sandbox", "8080"]) + .expect("service expose should allow omitting the service name"); + + match cli.command { + Some(Commands::Service { + command: + Some(ServiceCommands::Expose { + sandbox, + target_port, + service, + }), + }) => { + assert_eq!(sandbox, "my-sandbox"); + assert_eq!(target_port, 8080); + assert_eq!(service, None); + } + other => panic!("expected service expose command, got: {other:?}"), + } + } + + #[test] + fn service_alias_parses_service_commands() { + let cli = Cli::try_parse_from(["openshell", "svc", "expose", "my-sandbox", "8080"]) + .expect("svc alias should parse service commands"); + + match cli.command { + Some(Commands::Service { + command: + Some(ServiceCommands::Expose { + sandbox, + target_port, + service, + }), + }) => { + assert_eq!(sandbox, "my-sandbox"); + assert_eq!(target_port, 8080); + assert_eq!(service, None); + } + other => panic!("expected service expose command, got: {other:?}"), + } + } + + #[test] + fn service_list_accepts_optional_sandbox_and_paging() { + let cli = 
Cli::try_parse_from([ + "openshell", + "service", + "list", + "my-sandbox", + "--limit", + "10", + "--offset", + "2", + ]) + .expect("service list should parse optional sandbox and paging"); + + match cli.command { + Some(Commands::Service { + command: + Some(ServiceCommands::List { + sandbox, + limit, + offset, + }), + }) => { + assert_eq!(sandbox.as_deref(), Some("my-sandbox")); + assert_eq!(limit, 10); + assert_eq!(offset, 2); + } + other => panic!("expected service list command, got: {other:?}"), + } + + let cli = Cli::try_parse_from(["openshell", "service", "list"]) + .expect("service list should allow omitting sandbox"); + + match cli.command { + Some(Commands::Service { + command: + Some(ServiceCommands::List { + sandbox, + limit, + offset, + }), + }) => { + assert_eq!(sandbox, None); + assert_eq!(limit, 100); + assert_eq!(offset, 0); + } + other => panic!("expected service list command, got: {other:?}"), + } + } + + #[test] + fn service_get_accepts_optional_service_name() { + let cli = Cli::try_parse_from(["openshell", "service", "get", "my-sandbox", "api"]) + .expect("service get should parse service name"); + + match cli.command { + Some(Commands::Service { + command: Some(ServiceCommands::Get { sandbox, service }), + }) => { + assert_eq!(sandbox, "my-sandbox"); + assert_eq!(service.as_deref(), Some("api")); + } + other => panic!("expected service get command, got: {other:?}"), + } + + let cli = Cli::try_parse_from(["openshell", "service", "get", "my-sandbox"]) + .expect("service get should allow omitting service name"); + + match cli.command { + Some(Commands::Service { + command: Some(ServiceCommands::Get { sandbox, service }), + }) => { + assert_eq!(sandbox, "my-sandbox"); + assert_eq!(service, None); + } + other => panic!("expected service get command, got: {other:?}"), + } + } + + #[test] + fn service_delete_accepts_optional_service_name() { + let cli = Cli::try_parse_from(["openshell", "service", "delete", "my-sandbox", "api"]) + .expect("service 
delete should parse service name"); + + match cli.command { + Some(Commands::Service { + command: Some(ServiceCommands::Delete { sandbox, service }), + }) => { + assert_eq!(sandbox, "my-sandbox"); + assert_eq!(service.as_deref(), Some("api")); + } + other => panic!("expected service delete command, got: {other:?}"), + } + + let cli = Cli::try_parse_from(["openshell", "service", "delete", "my-sandbox"]) + .expect("service delete should allow omitting service name"); + + match cli.command { + Some(Commands::Service { + command: Some(ServiceCommands::Delete { sandbox, service }), + }) => { + assert_eq!(sandbox, "my-sandbox"); + assert_eq!(service, None); + } + other => panic!("expected service delete command, got: {other:?}"), + } + } } diff --git a/crates/openshell-cli/src/oidc_auth.rs b/crates/openshell-cli/src/oidc_auth.rs index fd742a418..379a53112 100644 --- a/crates/openshell-cli/src/oidc_auth.rs +++ b/crates/openshell-cli/src/oidc_auth.rs @@ -42,10 +42,13 @@ struct OidcDiscovery { /// /// Validates that the discovery document's `issuer` field matches the /// configured issuer URL to prevent SSRF or misdirection. -async fn discover(issuer: &str) -> Result { +async fn discover(issuer: &str, insecure: bool) -> Result { let normalized_issuer = issuer.trim_end_matches('/'); let url = format!("{normalized_issuer}/.well-known/openid-configuration"); - let resp: OidcDiscovery = reqwest::get(&url) + let client = http_client(insecure); + let resp: OidcDiscovery = client + .get(&url) + .send() .await .into_diagnostic()? 
.json() @@ -63,11 +66,12 @@ async fn discover(issuer: &str) -> Result { Ok(resp) } -fn http_client() -> reqwest::Client { - reqwest::ClientBuilder::new() - .redirect(reqwest::redirect::Policy::none()) - .build() - .expect("failed to build HTTP client") +fn http_client(insecure: bool) -> reqwest::Client { + let mut builder = reqwest::ClientBuilder::new().redirect(reqwest::redirect::Policy::none()); + if insecure { + builder = builder.danger_accept_invalid_certs(true); + } + builder.build().expect("failed to build HTTP client") } fn build_scopes(scopes: Option<&str>) -> Vec { @@ -100,8 +104,9 @@ pub async fn oidc_browser_auth_flow( client_id: &str, audience: Option<&str>, scopes: Option<&str>, + insecure: bool, ) -> Result { - let discovery = discover(issuer).await?; + let discovery = discover(issuer, insecure).await?; let listener = TcpListener::bind("127.0.0.1:0").await.into_diagnostic()?; let port = listener.local_addr().into_diagnostic()?.port(); @@ -161,7 +166,7 @@ pub async fn oidc_browser_auth_flow( server_handle.abort(); - let http = http_client(); + let http = http_client(insecure); let token_response = client .exchange_code(AuthorizationCode::new(code)) .set_pkce_verifier(pkce_verifier) @@ -184,6 +189,7 @@ pub async fn oidc_client_credentials_flow( client_id: &str, audience: Option<&str>, scopes: Option<&str>, + insecure: bool, ) -> Result { let client_secret = std::env::var("OPENSHELL_OIDC_CLIENT_SECRET").map_err(|_| { miette::miette!( @@ -191,7 +197,7 @@ pub async fn oidc_client_credentials_flow( ) })?; - let discovery = discover(issuer).await?; + let discovery = discover(issuer, insecure).await?; let client = BasicClient::new(ClientId::new(client_id.to_string())) .set_client_secret(ClientSecret::new(client_secret)) @@ -206,7 +212,7 @@ pub async fn oidc_client_credentials_flow( request = request.add_extra_param("audience", aud); } - let http = http_client(); + let http = http_client(insecure); let token_response = request .request_async(&http) .await @@ 
-223,19 +229,22 @@ pub async fn oidc_client_credentials_flow( /// /// Preserves the existing refresh token if the server does not return a new /// one (per OAuth 2.0 spec, the refresh response may omit `refresh_token`). -pub async fn oidc_refresh_token(bundle: &OidcTokenBundle) -> Result { +pub async fn oidc_refresh_token( + bundle: &OidcTokenBundle, + insecure: bool, +) -> Result { let refresh_token = bundle.refresh_token.as_deref().ok_or_else(|| { miette::miette!( "no refresh token available — re-authenticate with: openshell gateway login" ) })?; - let discovery = discover(&bundle.issuer).await?; + let discovery = discover(&bundle.issuer, insecure).await?; let client = BasicClient::new(ClientId::new(bundle.client_id.clone())) .set_token_uri(TokenUrl::new(discovery.token_endpoint).into_diagnostic()?); - let http = http_client(); + let http = http_client(insecure); let token_response = client .exchange_refresh_token(&RefreshToken::new(refresh_token.to_string())) .request_async(&http) @@ -253,7 +262,7 @@ pub async fn oidc_refresh_token(bundle: &OidcTokenBundle) -> Result Result { +pub async fn ensure_valid_oidc_token(gateway_name: &str, insecure: bool) -> Result { let bundle = openshell_bootstrap::oidc_token::load_oidc_token(gateway_name).ok_or_else(|| { miette::miette!( @@ -270,7 +279,7 @@ pub async fn ensure_valid_oidc_token(gateway_name: &str) -> Result { gateway = gateway_name, "OIDC token expired, attempting refresh" ); - let refreshed = oidc_refresh_token(&bundle).await?; + let refreshed = oidc_refresh_token(&bundle, insecure).await?; openshell_bootstrap::oidc_token::store_oidc_token(gateway_name, &refreshed)?; Ok(refreshed.access_token) } @@ -436,3 +445,90 @@ fn html_response(status: StatusCode, message: &str) -> Response> { .body(Full::new(Bytes::from(body))) .expect("response") } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn http_client_secure_rejects_self_signed() { + let client = http_client(false); + let rt = 
tokio::runtime::Runtime::new().unwrap(); + // A real self-signed server isn't available in unit tests, but we can + // verify the client is constructed and makes requests. The secure client + // should exist and function for valid endpoints. + let result = rt.block_on(async { client.get("https://127.0.0.1:1").send().await }); + assert!(result.is_err(), "connection to closed port should fail"); + } + + #[test] + fn http_client_insecure_builds_without_panic() { + let client = http_client(true); + // Verify the client is usable (doesn't panic on construction). + let rt = tokio::runtime::Runtime::new().unwrap(); + let result = rt.block_on(async { client.get("https://127.0.0.1:1").send().await }); + assert!(result.is_err(), "connection to closed port should fail"); + } + + #[test] + fn discover_validates_issuer_mismatch() { + let rt = tokio::runtime::Runtime::new().unwrap(); + // Discovery against a non-existent issuer should fail with a + // connection error, not silently succeed. + let result = rt.block_on(discover("http://127.0.0.1:1/realms/test", false)); + assert!(result.is_err()); + } + + #[test] + fn discover_insecure_passes_flag_through() { + let rt = tokio::runtime::Runtime::new().unwrap(); + // Same as above but with insecure=true. Should still fail on + // connection (no server) but must not panic. 
+ let result = rt.block_on(discover("https://127.0.0.1:1/realms/test", true)); + assert!(result.is_err()); + } + + #[test] + fn percent_decode_basic() { + assert_eq!(percent_decode("hello%20world"), "hello world"); + assert_eq!(percent_decode("a%2Fb"), "a/b"); + assert_eq!(percent_decode("no+encoding+here"), "no encoding here"); + } + + #[test] + fn build_scopes_always_includes_openid() { + let scopes = build_scopes(None); + assert_eq!(scopes.len(), 1); + + let scopes = build_scopes(Some("profile email")); + assert_eq!(scopes.len(), 3); + } + + #[test] + fn build_scopes_deduplicates_openid() { + let scopes = build_scopes(Some("openid profile")); + assert_eq!(scopes.len(), 2); + } + + #[test] + fn build_ci_scopes_empty_on_none() { + let scopes = build_ci_scopes(None); + assert!(scopes.is_empty()); + } + + #[test] + fn bundle_from_response_sets_fields() { + use oauth2::basic::BasicTokenResponse; + + let token_response: BasicTokenResponse = serde_json::from_str( + r#"{"access_token":"test-access","token_type":"bearer","expires_in":300,"refresh_token":"test-refresh"}"#, + ) + .unwrap(); + let bundle = bundle_from_oauth2_response(&token_response, "https://issuer", "my-client"); + assert_eq!(bundle.access_token, "test-access"); + assert_eq!(bundle.refresh_token.as_deref(), Some("test-refresh")); + assert_eq!(bundle.issuer, "https://issuer"); + assert_eq!(bundle.client_id, "my-client"); + assert!(bundle.expires_at.is_some()); + } +} diff --git a/crates/openshell-cli/src/policy_update.rs b/crates/openshell-cli/src/policy_update.rs index 322a28df6..57656b878 100644 --- a/crates/openshell-cli/src/policy_update.rs +++ b/crates/openshell-cli/src/policy_update.rs @@ -18,6 +18,7 @@ pub struct PolicyUpdatePlan { pub preview_operations: Vec, } +#[allow(clippy::too_many_arguments)] pub fn build_policy_update_plan( add_endpoints: &[String], remove_endpoints: &[String], @@ -41,7 +42,6 @@ pub fn build_policy_update_plan( "--rule-name is only supported when exactly one --add-endpoint 
is provided" )); } - let mut merge_operations = Vec::new(); let mut preview_operations = Vec::new(); @@ -155,6 +155,40 @@ pub fn build_policy_update_plan( }) } +fn ensure_websocket_credential_rewrite_protocol( + spec: &str, + endpoint: &NetworkEndpoint, +) -> Result<()> { + if matches!(endpoint.protocol.as_str(), "rest" | "websocket") { + return Ok(()); + } + let protocol = if endpoint.protocol.is_empty() { + "" + } else { + endpoint.protocol.as_str() + }; + Err(miette!( + "websocket-credential-rewrite endpoint option requires --add-endpoint protocol segment to be 'rest' or 'websocket'; got '{protocol}' in '{spec}'" + )) +} + +fn ensure_request_body_credential_rewrite_protocol( + spec: &str, + endpoint: &NetworkEndpoint, +) -> Result<()> { + if endpoint.protocol == "rest" { + return Ok(()); + } + let protocol = if endpoint.protocol.is_empty() { + "" + } else { + endpoint.protocol.as_str() + }; + Err(miette!( + "request-body-credential-rewrite endpoint option requires --add-endpoint protocol segment to be 'rest'; got '{protocol}' in '{spec}'" + )) +} + fn group_allow_rules(specs: &[String]) -> Result>> { let mut grouped = BTreeMap::new(); for spec in specs { @@ -257,9 +291,9 @@ fn parse_remove_endpoint_spec(spec: &str) -> Result<(String, u32)> { fn parse_add_endpoint_spec(spec: &str) -> Result { let parts = spec.split(':').collect::>(); - if !(2..=5).contains(&parts.len()) { + if !(2..=6).contains(&parts.len()) { return Err(miette!( - "--add-endpoint expects host:port[:access[:protocol[:enforcement]]], got '{spec}'" + "--add-endpoint expects host:port[:access[:protocol[:enforcement[:options]]]], got '{spec}'" )); } @@ -269,12 +303,18 @@ fn parse_add_endpoint_spec(spec: &str) -> Result { let access = parts.get(2).copied().unwrap_or("").trim(); let protocol = parts.get(3).copied().unwrap_or("").trim(); let enforcement = parts.get(4).copied().unwrap_or("").trim(); + let options = parts.get(5).copied().unwrap_or("").trim(); if parts.len() == 3 && access.is_empty() { 
return Err(miette!( "--add-endpoint has an empty access segment in '{spec}'; omit it entirely if you do not need access or protocol fields" )); } + if parts.len() == 6 && options.is_empty() { + return Err(miette!( + "--add-endpoint has an empty options segment in '{spec}'; omit it entirely if you do not need endpoint options" + )); + } if !enforcement.is_empty() && protocol.is_empty() { return Err(miette!( "--add-endpoint cannot set enforcement without protocol in '{spec}'" @@ -285,9 +325,9 @@ fn parse_add_endpoint_spec(spec: &str) -> Result { "--add-endpoint access segment must be one of read-only, read-write, or full; got '{access}' in '{spec}'" )); } - if !protocol.is_empty() && !matches!(protocol, "rest" | "sql") { + if !protocol.is_empty() && !matches!(protocol, "rest" | "websocket" | "sql") { return Err(miette!( - "--add-endpoint protocol segment must be 'rest' or 'sql'; got '{protocol}' in '{spec}'" + "--add-endpoint protocol segment must be 'rest', 'websocket', or 'sql'; got '{protocol}' in '{spec}'" )); } if !enforcement.is_empty() && !matches!(enforcement, "enforce" | "audit") { @@ -296,7 +336,7 @@ fn parse_add_endpoint_spec(spec: &str) -> Result { )); } - Ok(NetworkEndpoint { + let mut endpoint = NetworkEndpoint { host, port, ports: vec![port], @@ -304,7 +344,65 @@ fn parse_add_endpoint_spec(spec: &str) -> Result { enforcement: enforcement.to_string(), access: access.to_string(), ..Default::default() - }) + }; + apply_add_endpoint_options(spec, &mut endpoint, options)?; + Ok(endpoint) +} + +fn apply_add_endpoint_options( + spec: &str, + endpoint: &mut NetworkEndpoint, + options: &str, +) -> Result<()> { + if options.is_empty() { + return Ok(()); + } + + for option in options.split(',') { + let option = option.trim(); + if option.is_empty() { + return Err(miette!( + "--add-endpoint options segment must not contain empty options in '{spec}'" + )); + } + match option { + "websocket-credential-rewrite" => { + 
ensure_websocket_credential_rewrite_protocol(spec, endpoint)?; + endpoint.websocket_credential_rewrite = true; + } + "request-body-credential-rewrite" => { + ensure_request_body_credential_rewrite_protocol(spec, endpoint)?; + endpoint.request_body_credential_rewrite = true; + } + _ => { + let Some(allowed_ip) = option.strip_prefix("allowed-ip=") else { + return Err(miette!( + "--add-endpoint options segment supports only 'websocket-credential-rewrite', 'request-body-credential-rewrite', and 'allowed-ip='; got '{option}' in '{spec}'" + )); + }; + let allowed_ip = allowed_ip.trim(); + if allowed_ip.is_empty() { + return Err(miette!( + "--add-endpoint allowed-ip option must include a CIDR or IP value in '{spec}'" + )); + } + if allowed_ip.contains(char::is_whitespace) { + return Err(miette!( + "--add-endpoint allowed-ip option must not contain whitespace in '{spec}'" + )); + } + if !endpoint + .allowed_ips + .iter() + .any(|existing| existing == allowed_ip) + { + endpoint.allowed_ips.push(allowed_ip.to_string()); + } + } + } + } + + Ok(()) } fn parse_host(flag: &str, spec: &str, host: &str) -> Result { @@ -352,7 +450,30 @@ fn dedup_strings(values: &[String]) -> Vec { #[cfg(test)] mod tests { - use super::build_policy_update_plan; + use super::{ + PolicyUpdatePlan, build_policy_update_plan as build_policy_update_plan_with_options, + }; + use openshell_policy::PolicyMergeOp; + + fn build_policy_update_plan( + add_endpoints: &[String], + remove_endpoints: &[String], + add_deny: &[String], + add_allow: &[String], + remove_rules: &[String], + binaries: &[String], + rule_name: Option<&str>, + ) -> miette::Result { + build_policy_update_plan_with_options( + add_endpoints, + remove_endpoints, + add_deny, + add_allow, + remove_rules, + binaries, + rule_name, + ) + } #[test] fn parse_add_endpoint_basic_l4() { @@ -392,6 +513,229 @@ mod tests { .expect("plan should build"); } + #[test] + fn parse_add_endpoint_accepts_websocket_protocol() { + let plan = build_policy_update_plan( + 
&["realtime.example.com:443:read-write:websocket:enforce".to_string()], + &[], + &[], + &[], + &[], + &[], + None, + ) + .expect("plan should build"); + + let PolicyMergeOp::AddRule { rule, .. } = &plan.preview_operations[0] else { + panic!("expected add-rule preview"); + }; + let endpoint = &rule.endpoints[0]; + assert_eq!(endpoint.host, "realtime.example.com"); + assert_eq!(endpoint.protocol, "websocket"); + assert_eq!(endpoint.access, "read-write"); + assert_eq!(endpoint.enforcement, "enforce"); + } + + #[test] + fn parse_add_endpoint_enables_websocket_credential_rewrite() { + let plan = build_policy_update_plan( + &["realtime.example.com:443:read-write:websocket:enforce:websocket-credential-rewrite" + .to_string()], + &[], + &[], + &[], + &[], + &[], + None, + ) + .expect("plan should build"); + + let PolicyMergeOp::AddRule { rule, .. } = &plan.preview_operations[0] else { + panic!("expected add-rule preview"); + }; + assert!(rule.endpoints[0].websocket_credential_rewrite); + } + + #[test] + fn parse_add_endpoint_enables_websocket_credential_rewrite_on_rest_compat_endpoint() { + let plan = build_policy_update_plan( + &[ + "realtime.example.com:443:read-write:rest:enforce:websocket-credential-rewrite" + .to_string(), + ], + &[], + &[], + &[], + &[], + &[], + None, + ) + .expect("plan should build"); + + let PolicyMergeOp::AddRule { rule, .. } = &plan.preview_operations[0] else { + panic!("expected add-rule preview"); + }; + assert!(rule.endpoints[0].websocket_credential_rewrite); + } + + #[test] + fn parse_add_endpoint_enables_request_body_credential_rewrite_on_rest_endpoint() { + let plan = build_policy_update_plan( + &[ + "api.example.com:443:read-write:rest:enforce:request-body-credential-rewrite" + .to_string(), + ], + &[], + &[], + &[], + &[], + &[], + None, + ) + .expect("plan should build"); + + let PolicyMergeOp::AddRule { rule, .. 
} = &plan.preview_operations[0] else { + panic!("expected add-rule preview"); + }; + let endpoint = &rule.endpoints[0]; + assert_eq!(endpoint.protocol, "rest"); + assert!(endpoint.request_body_credential_rewrite); + } + + #[test] + fn parse_add_endpoint_merges_allowed_ips_with_websocket_options() { + let plan = build_policy_update_plan( + &[ + "realtime.example.com:443:read-write:websocket:enforce:websocket-credential-rewrite,allowed-ip=10.0.0.0/8,allowed-ip=172.16.0.0/12,allowed-ip=10.0.0.0/8" + .to_string(), + ], + &[], + &[], + &[], + &[], + &[], + None, + ) + .expect("plan should build"); + + let PolicyMergeOp::AddRule { rule, .. } = &plan.preview_operations[0] else { + panic!("expected add-rule preview"); + }; + let endpoint = &rule.endpoints[0]; + assert!(endpoint.websocket_credential_rewrite); + assert_eq!( + endpoint.allowed_ips, + vec!["10.0.0.0/8".to_string(), "172.16.0.0/12".to_string()] + ); + } + + #[test] + fn parse_add_endpoint_accepts_allowed_ip_on_rest_endpoint() { + let plan = build_policy_update_plan( + &["api.example.com:443:read-write:rest:enforce:allowed-ip=192.168.0.0/16".to_string()], + &[], + &[], + &[], + &[], + &[], + None, + ) + .expect("plan should build"); + + let PolicyMergeOp::AddRule { rule, .. 
} = &plan.preview_operations[0] else { + panic!("expected add-rule preview"); + }; + assert_eq!(rule.endpoints[0].allowed_ips, vec!["192.168.0.0/16"]); + } + + #[test] + fn parse_add_endpoint_rejects_empty_allowed_ip() { + let error = build_policy_update_plan( + &["api.example.com:443:read-write:rest:enforce:allowed-ip=".to_string()], + &[], + &[], + &[], + &[], + &[], + None, + ) + .expect_err("plan should fail"); + assert!(error.to_string().contains("allowed-ip option")); + } + + #[test] + fn websocket_credential_rewrite_rejects_l4_endpoint() { + let error = build_policy_update_plan( + &["realtime.example.com:443::::websocket-credential-rewrite".to_string()], + &[], + &[], + &[], + &[], + &[], + None, + ) + .expect_err("plan should fail"); + assert!(error.to_string().contains("protocol segment")); + } + + #[test] + fn request_body_credential_rewrite_rejects_non_rest_endpoint() { + let error = build_policy_update_plan( + &[ + "realtime.example.com:443:read-write:websocket:enforce:request-body-credential-rewrite" + .to_string(), + ], + &[], + &[], + &[], + &[], + &[], + None, + ) + .expect_err("plan should fail"); + + assert!(error.to_string().contains("protocol segment")); + assert!(error.to_string().contains("'rest'")); + } + + #[test] + fn parse_add_endpoint_rejects_unknown_options() { + let error = build_policy_update_plan( + &["realtime.example.com:443:read-write:websocket:enforce:future-option".to_string()], + &[], + &[], + &[], + &[], + &[], + None, + ) + .expect_err("plan should fail"); + assert!(error.to_string().contains("options segment")); + } + + #[test] + fn parse_add_allow_accepts_websocket_text_method() { + let plan = build_policy_update_plan( + &[], + &[], + &[], + &["realtime.example.com:443:websocket_text:/v1/messages/**".to_string()], + &[], + &[], + None, + ) + .expect("plan should build"); + + let PolicyMergeOp::AddAllowRules { host, port, rules } = &plan.preview_operations[0] else { + panic!("expected add-allow preview"); + }; + 
assert_eq!(host, "realtime.example.com"); + assert_eq!(*port, 443); + let allow = rules[0].allow.as_ref().expect("allow rule"); + assert_eq!(allow.method, "WEBSOCKET_TEXT"); + assert_eq!(allow.path, "/v1/messages/**"); + } + #[test] fn parse_add_deny_rejects_empty_method() { let error = build_policy_update_plan( diff --git a/crates/openshell-cli/src/run.rs b/crates/openshell-cli/src/run.rs index 102bc87ab..f1a44ad31 100644 --- a/crates/openshell-cli/src/run.rs +++ b/crates/openshell-cli/src/run.rs @@ -9,6 +9,7 @@ use crate::tls::{ grpc_inference_client, require_tls_materials, }; use bytes::Bytes; +use chrono::DateTime; use dialoguer::{Confirm, Select, theme::ColorfulTheme}; use futures::StreamExt; use http_body_util::Full; @@ -23,27 +24,39 @@ use openshell_bootstrap::{ remove_gateway_metadata, resolve_ssh_hostname, save_active_gateway, save_last_sandbox, store_gateway_metadata, }; +use openshell_core::progress::{ + PROGRESS_ACTIVE_DETAIL_KEY, PROGRESS_ACTIVE_STEP_KEY, PROGRESS_COMPLETE_LABEL_KEY, + PROGRESS_COMPLETE_STEP_KEY, PROGRESS_STEP_PULLING_IMAGE, PROGRESS_STEP_REQUESTING_SANDBOX, + PROGRESS_STEP_STARTING_SANDBOX, +}; use openshell_core::proto::ProviderProfileCategory; use openshell_core::proto::{ - ApproveAllDraftChunksRequest, ApproveDraftChunkRequest, ClearDraftChunksRequest, - CreateProviderRequest, CreateSandboxRequest, DeleteProviderProfileRequest, - DeleteProviderRequest, DeleteSandboxRequest, ExecSandboxRequest, GetClusterInferenceRequest, - GetDraftHistoryRequest, GetDraftPolicyRequest, GetGatewayConfigRequest, - GetProviderProfileRequest, GetProviderRequest, GetSandboxConfigRequest, GetSandboxLogsRequest, - GetSandboxPolicyStatusRequest, GetSandboxRequest, HealthRequest, ImportProviderProfilesRequest, - LintProviderProfilesRequest, ListProviderProfilesRequest, ListProvidersRequest, - ListSandboxPoliciesRequest, ListSandboxesRequest, PolicySource, PolicyStatus, Provider, - ProviderProfile, ProviderProfileDiagnostic, ProviderProfileImportItem, 
RejectDraftChunkRequest, - Sandbox, SandboxPhase, SandboxPolicy, SandboxSpec, SandboxTemplate, SetClusterInferenceRequest, - SettingScope, SettingValue, UpdateConfigRequest, UpdateProviderRequest, WatchSandboxRequest, - exec_sandbox_event, setting_value, + ApproveAllDraftChunksRequest, ApproveDraftChunkRequest, AttachSandboxProviderRequest, + ClearDraftChunksRequest, ConfigureProviderRefreshRequest, CreateProviderRequest, + CreateSandboxRequest, CreateSshSessionRequest, DeleteProviderProfileRequest, + DeleteProviderRefreshRequest, DeleteProviderRequest, DeleteSandboxRequest, + DeleteServiceRequest, DetachSandboxProviderRequest, ExecSandboxRequest, ExposeServiceRequest, + GetClusterInferenceRequest, GetDraftHistoryRequest, GetDraftPolicyRequest, + GetGatewayConfigRequest, GetProviderProfileRequest, GetProviderRefreshStatusRequest, + GetProviderRequest, GetSandboxConfigRequest, GetSandboxLogsRequest, + GetSandboxPolicyStatusRequest, GetSandboxRequest, GetServiceRequest, HealthRequest, + ImportProviderProfilesRequest, LintProviderProfilesRequest, ListProviderProfilesRequest, + ListProvidersRequest, ListSandboxPoliciesRequest, ListSandboxProvidersRequest, + ListSandboxesRequest, ListServicesRequest, PlatformEvent, PolicySource, PolicyStatus, Provider, + ProviderCredentialRefreshStatus, ProviderCredentialRefreshStrategy, ProviderProfile, + ProviderProfileDiagnostic, ProviderProfileImportItem, RejectDraftChunkRequest, + RevokeSshSessionRequest, RotateProviderCredentialRequest, Sandbox, SandboxPhase, SandboxPolicy, + SandboxSpec, SandboxTemplate, ServiceEndpointResponse, SetClusterInferenceRequest, + SettingScope, SettingValue, TcpForwardFrame, TcpForwardInit, TcpRelayTarget, + UpdateConfigRequest, UpdateProviderRequest, WatchSandboxRequest, exec_sandbox_event, + setting_value, tcp_forward_init, }; use openshell_core::settings::{self, SettingValueKind}; use openshell_core::{ObjectId, ObjectName}; use openshell_providers::{ - ProviderRegistry, ProviderTypeProfile, 
detect_provider_from_command, normalize_provider_type, - parse_profile_json, parse_profile_yaml, profile_to_json, profile_to_yaml, profiles_to_json, - profiles_to_yaml, + ProviderRegistry, ProviderTypeProfile, RealDiscoveryContext, detect_provider_from_command, + discover_from_profile, normalize_provider_type, parse_profile_json, parse_profile_yaml, + profile_to_json, profile_to_yaml, profiles_to_json, profiles_to_yaml, }; use owo_colors::OwoColorize; use std::collections::{HashMap, HashSet}; @@ -192,26 +205,6 @@ impl ProvisioningStep { } } -/// Kubernetes event reason codes we care about. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -enum KubeEventReason { - Scheduled, - Pulling, - Pulled, - Started, -} - -/// Map a Kubernetes event reason string to an enum. -fn parse_kube_event_reason(reason: &str) -> Option { - match reason { - "Scheduled" => Some(KubeEventReason::Scheduled), - "Pulling" => Some(KubeEventReason::Pulling), - "Pulled" => Some(KubeEventReason::Pulled), - "Started" => Some(KubeEventReason::Started), - _ => None, - } -} - /// Live-updating display showing a provisioning step checklist with spinner. /// /// Completed steps are printed as static `✓ Step` lines. The current @@ -268,14 +261,6 @@ impl ProvisioningDisplay { } } - /// Record a completed provisioning step. - /// - /// The step is printed as a static `✓` line and the spinner advances - /// to the next expected state. - fn complete_step(&mut self, step: ProvisioningStep) { - self.complete_step_with_label(step, step.completed_label()); - } - /// Record a completed provisioning step with a custom label. fn complete_step_with_label(&mut self, step: ProvisioningStep, label: &str) { // Don't duplicate steps we've already printed. @@ -378,34 +363,106 @@ fn format_timestamp(d: Duration) -> String { format!("[{secs:.1}s]") } -/// Extract image size in bytes from a Kubernetes Pulled event message. -/// Example: "Successfully pulled image ... Image size: 620405524 bytes." 
-fn extract_image_size(message: &str) -> Option { - let size_prefix = "Image size: "; - let start = message.find(size_prefix)? + size_prefix.len(); - let rest = &message[start..]; - let end = rest.find(' ')?; - rest[..end].parse().ok() -} - -/// Format bytes as a human-readable string (e.g., "620 MB"). -fn format_bytes(bytes: u64) -> String { - const KB: u64 = 1024; - const MB: u64 = 1024 * KB; - const GB: u64 = 1024 * MB; - - if bytes >= GB { - // GB-scale precision loss is acceptable for a human-readable label. - #[allow(clippy::cast_precision_loss)] - let gb = bytes as f64 / GB as f64; - format!("{gb:.1} GB") - } else if bytes >= MB { - format!("{} MB", bytes / MB) - } else if bytes >= KB { - format!("{} KB", bytes / KB) - } else { - format!("{bytes} B") +fn progress_step_from_metadata(value: &str) -> Option { + match value { + PROGRESS_STEP_REQUESTING_SANDBOX => Some(ProvisioningStep::RequestingSandbox), + PROGRESS_STEP_PULLING_IMAGE => Some(ProvisioningStep::PullingSandboxImage), + PROGRESS_STEP_STARTING_SANDBOX => Some(ProvisioningStep::StartingSandbox), + _ => None, + } +} + +fn noninteractive_active_label(step: ProvisioningStep) -> String { + step.active_label().trim_end_matches('.').to_string() +} + +fn handle_platform_progress_event( + event: &PlatformEvent, + display: &mut Option, + provision_start: Instant, +) -> bool { + let completed_step = event + .metadata + .get(PROGRESS_COMPLETE_STEP_KEY) + .and_then(|step| progress_step_from_metadata(step)); + let active_step = event + .metadata + .get(PROGRESS_ACTIVE_STEP_KEY) + .and_then(|step| progress_step_from_metadata(step)); + let active_detail = event + .metadata + .get(PROGRESS_ACTIVE_DETAIL_KEY) + .filter(|detail| !detail.is_empty()); + + let handled = completed_step.is_some() || active_step.is_some() || active_detail.is_some(); + if !handled { + return false; + } + + if let Some(step) = completed_step { + let label = event + .metadata + .get(PROGRESS_COMPLETE_LABEL_KEY) + .map_or_else(|| 
step.completed_label(), String::as_str); + if let Some(d) = display.as_mut() { + d.complete_step_with_label(step, label); + } else { + let ts = format_timestamp(provision_start.elapsed()); + println!("{} {}", ts.dimmed(), label); + } + } + + if let Some(step) = active_step + && let Some(d) = display.as_mut() + { + d.set_active_step(step); + } + + if let Some(detail) = active_detail { + if let Some(d) = display.as_mut() { + d.set_active_detail(detail); + } else { + let ts = format_timestamp(provision_start.elapsed()); + if let Some(step) = active_step { + println!( + "{} {} {}", + ts.dimmed(), + noninteractive_active_label(step), + detail + ); + } else { + println!("{} {}", ts.dimmed(), detail); + } + } + } + + true +} + +fn is_provisioning_progress_event(event: &PlatformEvent) -> bool { + if event.metadata.contains_key(PROGRESS_COMPLETE_STEP_KEY) + || event.metadata.contains_key(PROGRESS_ACTIVE_STEP_KEY) + || event.metadata.contains_key(PROGRESS_ACTIVE_DETAIL_KEY) + { + return true; } + + event.source == "vm" + && matches!( + event.reason.as_str(), + "PullingLayer" + | "ResolvingImage" + | "AuthenticatingRegistry" + | "FetchingManifest" + | "CacheHit" + | "CacheMiss" + | "WaitingForImageCacheLock" + | "ExportingRootfs" + | "PreparingRootfs" + | "CreatingRootDisk" + | "PreparingOverlay" + | "Started" + ) } fn print_sandbox_header(sandbox: &Sandbox, display: Option<&ProvisioningDisplay>) { @@ -622,8 +679,8 @@ fn is_loopback_gateway_endpoint(endpoint: &str) -> bool { /// would serve this endpoint. /// /// Loopback endpoints (`localhost`, `127.0.0.1`, `::1`) resolve to the -/// `"openshell"` gateway name, matching the convention used by -/// `init-pki.sh` and the TLS cert resolver in `tls.rs`. +/// `"openshell"` gateway name, matching the convention used by local +/// `openshell-gateway generate-certs` and the TLS cert resolver in `tls.rs`. 
fn mtls_certs_exist_for_endpoint(name: &str, endpoint: &str) -> bool { let cert_name = if is_loopback_gateway_endpoint(endpoint) { "openshell" @@ -642,6 +699,65 @@ fn mtls_certs_exist_for_endpoint(name: &str, endpoint: &str) -> bool { }) } +fn package_managed_tls_dirs() -> Vec { + if let Some(path) = std::env::var_os("OPENSHELL_LOCAL_TLS_DIR") { + return vec![PathBuf::from(path)]; + } + + let mut dirs = Vec::new(); + + if cfg!(target_os = "macos") { + dirs.push(PathBuf::from("/opt/homebrew/var/openshell/tls")); + dirs.push(PathBuf::from("/usr/local/var/openshell/tls")); + } + + let state_dir = std::env::var_os("XDG_STATE_HOME") + .map(PathBuf::from) + .or_else(|| std::env::var_os("HOME").map(|home| PathBuf::from(home).join(".local/state"))); + if let Some(state_dir) = state_dir { + dirs.push(state_dir.join("openshell/tls")); + } + + dirs +} + +fn import_local_package_mtls_bundle(name: &str) -> Result> { + for dir in package_managed_tls_dirs() { + let ca = dir.join("ca.crt"); + let cert = dir.join("client/tls.crt"); + let key = dir.join("client/tls.key"); + if !(ca.is_file() && cert.is_file() && key.is_file()) { + continue; + } + + let bundle = openshell_bootstrap::pki::PkiBundle { + ca_cert_pem: std::fs::read_to_string(&ca) + .into_diagnostic() + .wrap_err_with(|| format!("failed to read {}", ca.display()))?, + ca_key_pem: String::new(), + server_cert_pem: String::new(), + server_key_pem: String::new(), + client_cert_pem: std::fs::read_to_string(&cert) + .into_diagnostic() + .wrap_err_with(|| format!("failed to read {}", cert.display()))?, + client_key_pem: std::fs::read_to_string(&key) + .into_diagnostic() + .wrap_err_with(|| format!("failed to read {}", key.display()))?, + // CLI never holds the gateway's JWT signing material — only the + // gateway needs it. Fill the JWT fields with placeholders. 
+ jwt_signing_key_pem: String::new(), + jwt_public_key_pem: String::new(), + jwt_key_id: String::new(), + }; + openshell_bootstrap::mtls::store_pki_bundle(name, &bundle) + .wrap_err_with(|| format!("failed to store mTLS bundle for gateway '{name}'"))?; + + return Ok(Some(dir)); + } + + Ok(None) +} + fn plaintext_gateway_is_remote(endpoint: &str, remote: Option<&str>, local: bool) -> bool { if local { return false; @@ -691,7 +807,7 @@ where let gateways = list_gateways()?; if gateways.is_empty() || !interactive { - gateway_list(gateway_flag)?; + gateway_list(gateway_flag, "table")?; if !gateways.is_empty() { eprintln!(); eprintln!( @@ -742,6 +858,7 @@ pub async fn gateway_add( oidc_client_id: &str, oidc_audience: Option<&str>, oidc_scopes: Option<&str>, + gateway_insecure: bool, ) -> Result<()> { // If the endpoint starts with ssh://, parse it into an SSH destination // and a gateway endpoint automatically. The host is resolved via @@ -794,7 +911,7 @@ pub async fn gateway_add( // Derive a gateway name from the hostname when none is provided. // Loopback endpoints use the canonical "openshell" name, matching the - // convention in init-pki.sh and default_tls_dir. + // convention in local cert generation and default_tls_dir. let derived_name; let name = if let Some(n) = name { n @@ -855,6 +972,7 @@ pub async fn gateway_add( oidc_client_id, oidc_audience, oidc_scopes, + gateway_insecure, ) .await { @@ -875,6 +993,7 @@ pub async fn gateway_add( oidc_client_id, oidc_audience, oidc_scopes, + gateway_insecure, ) .await { @@ -923,16 +1042,13 @@ pub async fn gateway_add( // Verify the gateway is reachable. 
let tls = TlsOptions::default(); - match http_health_check(&endpoint, &tls).await { - Ok(Some(status)) if status.is_success() => {} - _ => { - eprintln!( - "{} Gateway is not reachable at {endpoint}", - "⚠".yellow().bold(), - ); - if !has_mtls_certs { - eprintln!(" Verify the gateway is running and the endpoint is correct."); - } + if !gateway_reachable(&endpoint, &tls).await { + eprintln!( + "{} Gateway is not reachable at {endpoint}", + "⚠".yellow().bold(), + ); + if !has_mtls_certs { + eprintln!(" Verify the gateway is running and the endpoint is correct."); } } @@ -950,7 +1066,13 @@ pub async fn gateway_add( if remote.is_some() || local { // mTLS gateway (remote or local). - let certs_on_disk = mtls_certs_exist_for_endpoint(name, &endpoint); + let imported_mtls_dir = if local { + import_local_package_mtls_bundle(name)? + } else { + None + }; + let certs_on_disk = + imported_mtls_dir.is_some() || mtls_certs_exist_for_endpoint(name, &endpoint); if !certs_on_disk { return Err(miette::miette!( "mTLS certificates for gateway '{name}' were not found.\n\ @@ -983,14 +1105,11 @@ pub async fn gateway_add( // Verify the gateway is reachable over mTLS. let tls = TlsOptions::default().with_gateway_name(name); - match http_health_check(&endpoint, &tls).await { - Ok(Some(status)) if status.is_success() => {} - _ => { - eprintln!( - "{} Gateway is not reachable at {endpoint}. Verify the gateway is running.", - "⚠".yellow().bold(), - ); - } + if !gateway_reachable(&endpoint, &tls).await { + eprintln!( + "{} Gateway is not reachable at {endpoint}. Verify the gateway is running.", + "⚠".yellow().bold(), + ); } eprintln!( @@ -1048,7 +1167,7 @@ pub async fn gateway_add( /// Re-authenticate with an edge-authenticated or OIDC gateway. /// /// Dispatches to the appropriate auth flow based on `auth_mode`. 
-pub async fn gateway_login(name: &str) -> Result<()> { +pub async fn gateway_login(name: &str, gateway_insecure: bool) -> Result<()> { let metadata = openshell_bootstrap::load_gateway_metadata(name).map_err(|_| { miette::miette!( "Unknown gateway '{name}'.\n\ @@ -1074,11 +1193,23 @@ pub async fn gateway_login(name: &str) -> Result<()> { let scopes = metadata.oidc_scopes.as_deref(); let bundle = if std::env::var("OPENSHELL_OIDC_CLIENT_SECRET").is_ok() { - crate::oidc_auth::oidc_client_credentials_flow(issuer, client_id, audience, scopes) - .await? + crate::oidc_auth::oidc_client_credentials_flow( + issuer, + client_id, + audience, + scopes, + gateway_insecure, + ) + .await? } else { - crate::oidc_auth::oidc_browser_auth_flow(issuer, client_id, audience, scopes) - .await? + crate::oidc_auth::oidc_browser_auth_flow( + issuer, + client_id, + audience, + scopes, + gateway_insecure, + ) + .await? }; let username = jwt_preferred_username(&bundle.access_token); @@ -1145,10 +1276,34 @@ pub fn gateway_logout(name: &str) -> Result<()> { } /// List all registered gateways. -pub fn gateway_list(gateway_flag: &Option) -> Result<()> { +pub fn gateway_list(gateway_flag: &Option, output: &str) -> Result<()> { let gateways = list_gateways()?; let active = gateway_flag.clone().or_else(load_active_gateway); + match output { + "json" => { + let items: Vec = gateways + .iter() + .map(|g| gateway_to_json(g, &active)) + .collect(); + println!( + "{}", + serde_json::to_string_pretty(&items).into_diagnostic()? 
+ ); + return Ok(()); + } + "yaml" => { + let items: Vec = gateways + .iter() + .map(|g| gateway_to_json(g, &active)) + .collect(); + print!("{}", serde_yml::to_string(&items).into_diagnostic()?); + return Ok(()); + } + "table" => {} + _ => return Err(miette!("unsupported output format: {output}")), + } + if gateways.is_empty() { println!("No gateways found."); println!(); @@ -1208,6 +1363,16 @@ pub fn gateway_list(gateway_flag: &Option) -> Result<()> { Ok(()) } +fn gateway_to_json(gateway: &GatewayMetadata, active: &Option) -> serde_json::Value { + serde_json::json!({ + "name": gateway.name, + "endpoint": gateway.gateway_endpoint, + "type": gateway_type_label(gateway), + "auth": gateway_auth_label(gateway), + "active": active.as_deref() == Some(&gateway.name), + }) +} + async fn http_health_check(server: &str, tls: &TlsOptions) -> Result> { let base = server.trim_end_matches('/'); let uri: hyper::Uri = format!("{base}/healthz").parse().into_diagnostic()?; @@ -1251,6 +1416,16 @@ async fn http_health_check(server: &str, tls: &TlsOptions) -> Result bool { + if let Ok(mut client) = grpc_client(server, tls).await + && client.health(HealthRequest {}).await.is_ok() + { + return true; + } + + matches!(http_health_check(server, tls).await, Ok(Some(status)) if status.is_success()) +} + fn remove_gateway_registration(name: &str) { if let Err(err) = openshell_bootstrap::edge_token::remove_edge_token(name) { tracing::debug!("failed to remove edge token: {err}"); @@ -1366,6 +1541,101 @@ fn sandbox_should_persist( keep || forward.is_some() } +fn build_sandbox_resource_limits( + cpu: Option<&str>, + memory: Option<&str>, +) -> Result> { + use prost_types::{Struct, Value, value::Kind}; + + fn string_value(value: String) -> Value { + Value { + kind: Some(Kind::StringValue(value)), + } + } + + let mut limits = std::collections::BTreeMap::new(); + if let Some(cpu) = cpu { + limits.insert("cpu".to_string(), string_value(validate_cpu_quantity(cpu)?)); + } + if let Some(memory) = memory 
{ + limits.insert( + "memory".to_string(), + string_value(validate_memory_quantity(memory)?), + ); + } + + if limits.is_empty() { + return Ok(None); + } + + let mut fields = std::collections::BTreeMap::new(); + fields.insert( + "limits".to_string(), + Value { + kind: Some(Kind::StructValue(Struct { fields: limits })), + }, + ); + Ok(Some(Struct { fields })) +} + +fn validate_cpu_quantity(value: &str) -> Result { + let value = value.trim(); + if value.is_empty() { + return Err(miette!("--cpu must not be empty")); + } + + if let Some(millicores) = value.strip_suffix('m') { + if millicores.is_empty() || !millicores.bytes().all(|b| b.is_ascii_digit()) { + return Err(miette!( + "invalid --cpu value '{value}': expected positive cores or millicores, for example 2, 0.5, or 500m" + )); + } + let millicores = millicores.parse::().into_diagnostic()?; + if millicores == 0 { + return Err(miette!("--cpu must be greater than zero")); + } + return Ok(value.to_string()); + } + + let cores = value.parse::().map_err(|_| { + miette!( + "invalid --cpu value '{value}': expected positive cores or millicores, for example 2, 0.5, or 500m" + ) + })?; + if !cores.is_finite() || cores <= 0.0 { + return Err(miette!("--cpu must be greater than zero")); + } + Ok(value.to_string()) +} + +fn validate_memory_quantity(value: &str) -> Result { + let value = value.trim(); + if value.is_empty() { + return Err(miette!("--memory must not be empty")); + } + + let number_end = value + .find(|ch: char| !ch.is_ascii_digit()) + .unwrap_or(value.len()); + let (number, suffix) = value.split_at(number_end); + if number.is_empty() + || !matches!( + suffix, + "" | "Ki" | "Mi" | "Gi" | "Ti" | "Pi" | "Ei" | "K" | "M" | "G" | "T" | "P" | "E" + ) + { + return Err(miette!( + "invalid --memory value '{value}': expected positive bytes or a quantity such as 512Mi, 4Gi, or 8G" + )); + } + + let amount = number.parse::().into_diagnostic()?; + if amount == 0 { + return Err(miette!("--memory must be greater than zero")); + } 
+ Ok(value.to_string()) +} + async fn finalize_sandbox_create_session( server: &str, sandbox_name: &str, @@ -1400,6 +1670,8 @@ pub async fn sandbox_create( keep: bool, gpu: bool, gpu_device: Option<&str>, + cpu: Option<&str>, + memory: Option<&str>, editor: Option, providers: &[String], policy: Option<&str>, @@ -1452,7 +1724,12 @@ pub async fn sandbox_create( }; let requested_gpu = gpu || image.as_deref().is_some_and(image_requests_gpu); - let inferred_types: Vec = inferred_provider_type(command).into_iter().collect(); + let providers_v2_enabled = gateway_providers_v2_enabled(&mut client).await?; + let inferred_types: Vec = if providers_v2_enabled { + Vec::new() + } else { + inferred_provider_type(command).into_iter().collect() + }; let configured_providers = ensure_required_providers( &mut client, providers, @@ -1462,11 +1739,17 @@ pub async fn sandbox_create( .await?; let policy = load_sandbox_policy(policy)?; + let resource_limits = build_sandbox_resource_limits(cpu, memory)?; - let template = image.map(|img| SandboxTemplate { - image: img, - ..SandboxTemplate::default() - }); + let template = if image.is_some() || resource_limits.is_some() { + Some(SandboxTemplate { + image: image.unwrap_or_default(), + resources: resource_limits, + ..SandboxTemplate::default() + }) + } else { + None + }; let request = CreateSandboxRequest { spec: Some(SandboxSpec { @@ -1489,7 +1772,7 @@ pub async fn sandbox_create( status.message() )); } - Err(status) => return Err(status).into_diagnostic(), + Err(status) => return Err(miette::miette!(status.to_string())), }; let sandbox = response .into_inner() @@ -1549,7 +1832,7 @@ pub async fn sandbox_create( follow_logs: true, follow_events: true, log_tail_lines: 200, - event_tail: 0, + event_tail: 50, stop_on_terminal: false, log_since_ms: 0, log_sources: vec!["gateway".to_string()], @@ -1564,20 +1847,35 @@ pub async fn sandbox_create( let mut last_condition_message = ready_false_condition_message(sandbox.status.as_ref()); // Track 
whether we have seen a non-Ready phase during the watch. let mut saw_non_ready = SandboxPhase::try_from(sandbox.phase) != Ok(SandboxPhase::Ready); - let start_time = Instant::now(); let provision_timeout = Duration::from_secs( std::env::var("OPENSHELL_PROVISION_TIMEOUT") .ok() .and_then(|v| v.parse().ok()) .unwrap_or(300), ); + let mut provisioning_idle_deadline = Instant::now() + provision_timeout; // Track whether we saw the gateway become ready (from log messages). let mut saw_gateway_ready = false; loop { - // Compute remaining time so the timeout fires even when the stream - // produces no events (e.g. server-side producer died). - let remaining = provision_timeout.saturating_sub(start_time.elapsed()); + // Timeout only when provisioning goes idle. VM first-create can spend + // longer than the default timeout pulling and preparing large images, + // but only recognized progress events extend the idle deadline. Logs + // and generic status churn must not keep a stuck sandbox alive forever. + let remaining = provisioning_idle_deadline.saturating_duration_since(Instant::now()); + if remaining.is_zero() { + let timeout_message = provisioning_timeout_message( + provision_timeout.as_secs(), + requested_gpu, + last_condition_message.as_deref(), + ); + if let Some(d) = display.as_mut() { + d.finish_error(&timeout_message); + } + println!(); + return Err(miette::miette!(timeout_message)); + } + let maybe_item = tokio::time::timeout(remaining, stream.next()).await; let item = match maybe_item { @@ -1624,6 +1922,7 @@ pub async fn sandbox_create( format!("{}: {}", condition.reason, condition.message); } } + break; } // Only accept Ready as terminal after we've observed a @@ -1642,76 +1941,18 @@ pub async fn sandbox_create( } } Some(openshell_core::proto::sandbox_stream_event::Payload::Event(ev)) => { - // Map Kubernetes events to provisioning steps. 
- // We simplify the display to: Sandbox allocated -> Pulling image -> Ready - if let Some(reason) = parse_kube_event_reason(&ev.reason) { - match reason { - KubeEventReason::Scheduled => { - if let Some(d) = display.as_mut() { - d.complete_step_with_label( - ProvisioningStep::RequestingSandbox, - "Sandbox allocated", - ); - d.set_active_step(ProvisioningStep::PullingSandboxImage); - } else { - let ts = format_timestamp(provision_start.elapsed()); - println!("{} Sandbox allocated", ts.dimmed()); - } - } - KubeEventReason::Pulling => { - // Extract image name from the event message. - let image_name = ev - .message - .strip_prefix("Pulling image ") - .map_or("", |s| s.trim_matches('"')); - if let Some(d) = display.as_mut() { - d.set_active("Pulling image..."); - if !image_name.is_empty() { - d.set_active_detail(image_name); - } - } else { - let ts = format_timestamp(provision_start.elapsed()); - if image_name.is_empty() { - println!("{} Pulling image...", ts.dimmed()); - } else { - println!("{} Pulling image {image_name}", ts.dimmed()); - } - } - } - KubeEventReason::Pulled => { - // Extract image size from message like: - // "Successfully pulled image ... Image size: 620405524 bytes." - let size_label = extract_image_size(&ev.message) - .map(format_bytes) - .unwrap_or_default(); - let label = if size_label.is_empty() { - "Image pulled".to_string() - } else { - format!("Image pulled ({size_label})") - }; - if let Some(d) = display.as_mut() { - d.complete_step_with_label( - ProvisioningStep::PullingSandboxImage, - &label, - ); - d.set_active_step(ProvisioningStep::StartingSandbox); - } else { - let ts = format_timestamp(provision_start.elapsed()); - println!("{} {}", ts.dimmed(), label); - } - } - KubeEventReason::Started => { - // Only complete StartingSandbox if we've already completed - // PullingSandboxImage (meaning the container is starting). 
- if let Some(d) = display.as_mut() - && d.completed_steps - .contains(&ProvisioningStep::PullingSandboxImage) - { - d.complete_step(ProvisioningStep::StartingSandbox); - } - } + let extends_timeout = is_provisioning_progress_event(&ev); + if handle_platform_progress_event(&ev, &mut display, provision_start) { + if extends_timeout { + provisioning_idle_deadline = Instant::now() + provision_timeout; } - } else if let Some(d) = display.as_mut() { + continue; + } + if extends_timeout { + provisioning_idle_deadline = Instant::now() + provision_timeout; + } + + if let Some(d) = display.as_mut() { // Unknown events: show as detail on the current spinner. if !ev.message.is_empty() { d.set_active_detail(&ev.message); @@ -1792,7 +2033,7 @@ pub async fn sandbox_create( // If --forward was requested, start the background port forward // *before* running the command so that long-running processes - // (e.g. `openclaw gateway`) are reachable immediately. + // (e.g. a web gateway) are reachable immediately. 
if let Some(ref spec) = forward { sandbox_forward( &effective_server, @@ -2134,12 +2375,8 @@ pub async fn sandbox_sync_command( eprintln!("{} Sync complete", "✓".green().bold()); } (None, Some(sandbox_path)) => { - let local_dest = Path::new(dest.unwrap_or(".")); - eprintln!( - "Syncing sandbox:{} -> {}", - sandbox_path, - local_dest.display() - ); + let local_dest = dest.unwrap_or("."); + eprintln!("Syncing sandbox:{sandbox_path} -> {local_dest}"); sandbox_sync_down(server, name, sandbox_path, local_dest, tls).await?; eprintln!("{} Sync complete", "✓".green().bold()); } @@ -2215,6 +2452,11 @@ pub async fn sandbox_get( println!(" {} {}", "Id:".dimmed(), id); println!(" {} {}", "Name:".dimmed(), name); println!(" {} {}", "Phase:".dimmed(), phase_name(sandbox.phase)); + println!( + " {} {}", + "Resource version:".dimmed(), + sandbox.metadata.as_ref().map_or(0, |m| m.resource_version) + ); // Display labels if present if let Some(metadata) = &sandbox.metadata @@ -2330,6 +2572,11 @@ pub async fn sandbox_exec_grpc( let tty = tty_override .unwrap_or_else(|| std::io::stdin().is_terminal() && std::io::stdout().is_terminal()); + if tty_override == Some(true) && std::io::stdin().is_terminal() { + return sandbox_exec_interactive_grpc(client, &sandbox, command, workdir, timeout_seconds) + .await; + } + // Make the streaming gRPC call. let mut stream = client .exec_sandbox(ExecSandboxRequest { @@ -2340,6 +2587,7 @@ pub async fn sandbox_exec_grpc( timeout_seconds, stdin: stdin_payload, tty, + ..Default::default() }) .await .into_diagnostic()? @@ -2373,74 +2621,510 @@ pub async fn sandbox_exec_grpc( Ok(exit_code) } -/// Print a single YAML line with dimmed keys and regular values. 
-fn print_yaml_line(line: &str) { - // Find leading whitespace - let trimmed = line.trim_start(); - let indent = &line[..line.len() - trimmed.len()]; - - // Handle list items - if let Some(rest) = trimmed.strip_prefix("- ") { - print!("{indent}"); - print!("{}", "- ".dimmed()); - print!("{rest}"); - println!(); - return; - } +pub async fn service_forward_tcp( + server: &str, + name: &str, + local: Option<&str>, + target_host: &str, + target_port: u16, + tls: &TlsOptions, +) -> Result<()> { + let (bind_addr, bind_port) = parse_tcp_forward_spec(local, target_port)?; + let mut client = grpc_client(server, tls).await?; - // Handle key: value pairs - if let Some(colon_pos) = trimmed.find(':') { - let key = &trimmed[..colon_pos]; - let after_colon = &trimmed[colon_pos + 1..]; + let sandbox = fetch_ready_sandbox_for_forward(&mut client, name).await?; - print!("{indent}"); - print!("{}", key.dimmed()); - print!("{}", ":".dimmed()); + let listener = tokio::net::TcpListener::bind((bind_addr.as_str(), bind_port)) + .await + .into_diagnostic() + .wrap_err_with(|| format!("failed to bind local forward on {bind_addr}:{bind_port}"))?; + let local_addr = listener + .local_addr() + .into_diagnostic() + .wrap_err("failed to read local forward address")?; + eprintln!( + "{} Forwarding {} -> {}:{} in sandbox {} via gRPC", + "✓".green().bold(), + local_addr, + target_host, + target_port, + name, + ); - if after_colon.is_empty() { - // Key with nested content (no value on this line) - } else if let Some(value) = after_colon.strip_prefix(' ') { - // Key: value - print!(" {value}"); - } else { - // Shouldn't happen in valid YAML, but handle it - print!("{after_colon}"); - } - println!(); - return; - } + let sandbox_id = sandbox.object_id().to_string(); + let (fatal_tx, mut fatal_rx) = tokio::sync::mpsc::channel::<String>(1); + let mut health_check = tokio::time::interval(Duration::from_secs(2)); + health_check.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); + loop { + 
tokio::select! { + Some(reason) = fatal_rx.recv() => { + return Err(miette::miette!("service forward stopped: {reason}")); + } - // Plain line (shouldn't happen often in YAML) - println!("{line}"); -} + _ = health_check.tick() => { + fetch_ready_sandbox_for_forward(&mut client, name).await?; + } -/// Print sandbox policy as YAML with dimmed keys. -fn print_sandbox_policy(policy: &SandboxPolicy) { - println!("{}", "Policy:".cyan().bold()); - println!(); - if let Ok(yaml_str) = openshell_policy::serialize_sandbox_policy(policy) { - // Indent the YAML output and skip the initial "---" line - for line in yaml_str.lines() { - if line == "---" { - continue; + accepted = listener.accept() => { + let (socket, peer) = accepted + .into_diagnostic() + .wrap_err("failed to accept local forward connection")?; + let mut client = client.clone(); + let sandbox_id = sandbox_id.clone(); + let target_host = target_host.to_string(); + let service_id = format!("service-forward:{name}:{target_host}:{target_port}"); + let fatal_tx = fatal_tx.clone(); + tokio::spawn(async move { + let token = match create_forward_session_token(&mut client, &sandbox_id).await { + Ok(token) => token, + Err(err) => { + tracing::warn!(peer = %peer, error = %err, "service forward session creation failed"); + if err.fatal { + let _ = fatal_tx.send(err.message).await; + } + return; + } + }; + if let Err(err) = forward_one_tcp_connection( + &mut client, + socket, + sandbox_id, + target_host, + target_port, + service_id, + token.clone(), + ) + .await + { + tracing::warn!(peer = %peer, error = %err, "service forward connection failed"); + if err.fatal { + let _ = fatal_tx.send(err.message).await; + } + } + let _ = client + .revoke_ssh_session(RevokeSshSessionRequest { token }) + .await; + }); } - print!(" "); - print_yaml_line(line); } } } -/// List sandboxes. 
-pub async fn sandbox_list( - server: &str, - limit: u32, - offset: u32, - ids_only: bool, - names_only: bool, - label_selector: Option<&str>, - tls: &TlsOptions, -) -> Result<()> { - let mut client = grpc_client(server, tls).await?; +async fn create_forward_session_token( + client: &mut crate::tls::GrpcClient, + sandbox_id: &str, +) -> std::result::Result<String, ForwardTcpConnectionError> { + let response = client + .create_ssh_session(CreateSshSessionRequest { + sandbox_id: sandbox_id.to_string(), + }) + .await + .map_err(ForwardTcpConnectionError::from_status)?; + Ok(response.into_inner().token) +} + +async fn fetch_ready_sandbox_for_forward( + client: &mut crate::tls::GrpcClient, + name: &str, +) -> Result<Sandbox> { + let response = match client + .get_sandbox(GetSandboxRequest { + name: name.to_string(), + }) + .await + { + Ok(response) => response, + Err(status) if status.code() == Code::NotFound => { + return Err(miette::miette!( + "sandbox '{name}' no longer exists; stopping service forward" + )); + } + Err(status) => return Err(status).into_diagnostic(), + }; + + let sandbox = response + .into_inner() + .sandbox + .ok_or_else(|| miette::miette!("sandbox '{name}' not found"))?; + + if SandboxPhase::try_from(sandbox.phase) != Ok(SandboxPhase::Ready) { + return Err(miette::miette!( + "sandbox '{}' is no longer ready (phase: {}); stopping service forward", + name, + phase_name(sandbox.phase) + )); + } + + Ok(sandbox) +} + +#[derive(Debug)] +struct ForwardTcpConnectionError { + message: String, + fatal: bool, +} + +impl ForwardTcpConnectionError { + fn transient(message: impl Into<String>) -> Self { + Self { + message: message.into(), + fatal: false, + } + } + + fn from_status(status: Status) -> Self { + let fatal = matches!(status.code(), Code::NotFound | Code::FailedPrecondition); + Self { + message: status.to_string(), + fatal, + } + } +} + +impl std::fmt::Display for ForwardTcpConnectionError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(&self.message) + } +} + 
+impl std::error::Error for ForwardTcpConnectionError {} + +fn parse_tcp_forward_spec(local: Option<&str>, default_port: u16) -> Result<(String, u16)> { + let Some(spec) = local else { + return Ok(("127.0.0.1".to_string(), default_port)); + }; + + if let Some(pos) = spec.rfind(':') { + let addr = &spec[..pos]; + let port_str = &spec[pos + 1..]; + if let Ok(port) = port_str.parse::<u16>() { + if addr.is_empty() { + return Err(miette::miette!("bind address is required before ':'")); + } + return Ok((addr.to_string(), port)); + } + } + + let port: u16 = spec.parse().map_err(|_| { + miette::miette!("invalid local forward spec '{spec}': expected [bind_address:]port") + })?; + Ok(("127.0.0.1".to_string(), port)) +} + +async fn forward_one_tcp_connection( + client: &mut crate::tls::GrpcClient, + socket: tokio::net::TcpStream, + sandbox_id: String, + target_host: String, + target_port: u16, + service_id: String, + authorization_token: String, +) -> std::result::Result<(), ForwardTcpConnectionError> { + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tokio_stream::wrappers::ReceiverStream; + + let (tx, rx) = tokio::sync::mpsc::channel::<TcpForwardFrame>(16); + tx.send(TcpForwardFrame { + payload: Some(openshell_core::proto::tcp_forward_frame::Payload::Init( + TcpForwardInit { + sandbox_id, + service_id, + target: Some(tcp_forward_init::Target::Tcp(TcpRelayTarget { + host: target_host, + port: u32::from(target_port), + })), + authorization_token, + }, + )), + }) + .await + .map_err(|_| ForwardTcpConnectionError::transient("failed to initialize forward stream"))?; + + let mut response = match client.forward_tcp(ReceiverStream::new(rx)).await { + Ok(response) => response.into_inner(), + Err(status) => { + let err = ForwardTcpConnectionError::from_status(status); + drain_and_shutdown_local_socket(socket).await; + return Err(err); + } + }; + + let (mut local_read, mut local_write) = socket.into_split(); + + let to_gateway = tokio::spawn(async move { + let mut buf = vec![0u8; 64 * 1024]; + loop { 
+ let n = local_read.read(&mut buf).await?; + if n == 0 { + break; + } + if tx + .send(TcpForwardFrame { + payload: Some(openshell_core::proto::tcp_forward_frame::Payload::Data( + buf[..n].to_vec(), + )), + }) + .await + .is_err() + { + break; + } + } + Ok::<(), std::io::Error>(()) + }); + + while let Some(frame) = response + .message() + .await + .map_err(ForwardTcpConnectionError::from_status)? + { + let Some(openshell_core::proto::tcp_forward_frame::Payload::Data(data)) = frame.payload + else { + continue; + }; + if data.is_empty() { + continue; + } + local_write + .write_all(&data) + .await + .map_err(|err| ForwardTcpConnectionError::transient(err.to_string()))?; + } + + let _ = local_write.shutdown().await; + to_gateway.abort(); + Ok(()) +} + +async fn drain_and_shutdown_local_socket(mut socket: tokio::net::TcpStream) { + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + let mut buf = [0u8; 4096]; + while matches!( + tokio::time::timeout(Duration::from_millis(25), socket.read(&mut buf)).await, + Ok(Ok(n)) if n != 0 + ) {} + let _ = socket.shutdown().await; +} + +struct RawModeGuard; + +impl Drop for RawModeGuard { + fn drop(&mut self) { + let _ = crossterm::terminal::disable_raw_mode(); + } +} + +struct TaskGuard(tokio::task::JoinHandle<()>); + +impl Drop for TaskGuard { + fn drop(&mut self) { + self.0.abort(); + } +} + +async fn sandbox_exec_interactive_grpc( + mut client: crate::tls::GrpcClient, + sandbox: &Sandbox, + command: &[String], + workdir: Option<&str>, + timeout_seconds: u32, +) -> Result<i32> { + use openshell_core::proto::{ExecSandboxInput, ExecSandboxWindowResize, exec_sandbox_input}; + use tokio_stream::wrappers::ReceiverStream; + + let (cols, rows) = crossterm::terminal::size().unwrap_or((80, 24)); + + let (input_tx, input_rx) = tokio::sync::mpsc::channel::<ExecSandboxInput>(4096); + + // Send the start message with exec metadata. 
+ input_tx + .send(ExecSandboxInput { + payload: Some(exec_sandbox_input::Payload::Start(ExecSandboxRequest { + sandbox_id: sandbox.object_id().to_string(), + command: command.to_vec(), + workdir: workdir.unwrap_or_default().to_string(), + environment: HashMap::new(), + timeout_seconds, + stdin: Vec::new(), + tty: true, + cols: u32::from(cols), + rows: u32::from(rows), + })), + }) + .await + .into_diagnostic()?; + + let mut stream = client + .exec_sandbox_interactive(ReceiverStream::new(input_rx)) + .await + .into_diagnostic()? + .into_inner(); + + // Enable raw mode so keystrokes are forwarded immediately. + crossterm::terminal::enable_raw_mode().into_diagnostic()?; + let raw_guard = RawModeGuard; + + // Stdin reader on a detached OS thread. Using std::thread (not + // spawn_blocking) so the tokio runtime shutdown doesn't wait for a + // thread blocked on stdin.read(). The thread exits when the channel + // closes (blocking_send returns Err) or stdin hits EOF. + #[cfg(unix)] + { + let stdin_tx = input_tx.clone(); + std::thread::spawn(move || { + let mut stdin = std::io::stdin().lock(); + let mut buf = [0u8; 4096]; + loop { + match stdin.read(&mut buf) { + Ok(0) | Err(_) => break, + Ok(n) => { + if stdin_tx + .blocking_send(ExecSandboxInput { + payload: Some(exec_sandbox_input::Payload::Stdin( + buf[..n].to_vec(), + )), + }) + .is_err() + { + break; + } + } + } + } + }); + } + + // SIGWINCH handler: forward terminal resize events. 
+ #[cfg(unix)] + let resize_task = { + let resize_tx = input_tx.clone(); + tokio::spawn(async move { + let mut sig = + tokio::signal::unix::signal(tokio::signal::unix::SignalKind::window_change()) + .expect("failed to register SIGWINCH handler"); + while sig.recv().await.is_some() { + if let Ok((c, r)) = crossterm::terminal::size() { + let msg = ExecSandboxInput { + payload: Some(exec_sandbox_input::Payload::Resize( + ExecSandboxWindowResize { + cols: u32::from(c), + rows: u32::from(r), + }, + )), + }; + if resize_tx.send(msg).await.is_err() { + break; + } + } + } + }) + }; + #[cfg(unix)] + let _resize_guard = TaskGuard(resize_task); + + let mut exit_code = 0i32; + let stdout = std::io::stdout(); + + while let Some(event) = stream.next().await { + let event = event.into_diagnostic()?; + match event.payload { + Some(exec_sandbox_event::Payload::Stdout(out)) => { + let mut handle = stdout.lock(); + handle.write_all(&out.data).into_diagnostic()?; + handle.flush().into_diagnostic()?; + } + Some(exec_sandbox_event::Payload::Stderr(err)) => { + let mut handle = stdout.lock(); + handle.write_all(&err.data).into_diagnostic()?; + handle.flush().into_diagnostic()?; + } + Some(exec_sandbox_event::Payload::Exit(exit)) => { + exit_code = exit.exit_code; + break; + } + None => {} + } + } + + drop(input_tx); + + // Drop the raw mode guard to restore the terminal before returning. + drop(raw_guard); + + Ok(exit_code) +} + +/// Print a single YAML line with dimmed keys and regular values. 
+fn print_yaml_line(line: &str) { + // Find leading whitespace + let trimmed = line.trim_start(); + let indent = &line[..line.len() - trimmed.len()]; + + // Handle list items + if let Some(rest) = trimmed.strip_prefix("- ") { + print!("{indent}"); + print!("{}", "- ".dimmed()); + print!("{rest}"); + println!(); + return; + } + + // Handle key: value pairs + if let Some(colon_pos) = trimmed.find(':') { + let key = &trimmed[..colon_pos]; + let after_colon = &trimmed[colon_pos + 1..]; + + print!("{indent}"); + print!("{}", key.dimmed()); + print!("{}", ":".dimmed()); + + if after_colon.is_empty() { + // Key with nested content (no value on this line) + } else if let Some(value) = after_colon.strip_prefix(' ') { + // Key: value + print!(" {value}"); + } else { + // Shouldn't happen in valid YAML, but handle it + print!("{after_colon}"); + } + println!(); + return; + } + + // Plain line (shouldn't happen often in YAML) + println!("{line}"); +} + +/// Print sandbox policy as YAML with dimmed keys. +fn print_sandbox_policy(policy: &SandboxPolicy) { + println!("{}", "Policy:".cyan().bold()); + println!(); + if let Ok(yaml_str) = openshell_policy::serialize_sandbox_policy(policy) { + // Indent the YAML output and skip the initial "---" line + for line in yaml_str.lines() { + if line == "---" { + continue; + } + print!(" "); + print_yaml_line(line); + } + } +} + +/// List sandboxes. 
+#[allow(clippy::too_many_arguments)] +pub async fn sandbox_list( + server: &str, + limit: u32, + offset: u32, + ids_only: bool, + names_only: bool, + label_selector: Option<&str>, + output: &str, + tls: &TlsOptions, +) -> Result<()> { + let mut client = grpc_client(server, tls).await?; let response = client .list_sandboxes(ListSandboxesRequest { @@ -2452,6 +3136,25 @@ pub async fn sandbox_list( .into_diagnostic()?; let sandboxes = response.into_inner().sandboxes; + + match output { + "json" => { + let items: Vec<serde_json::Value> = sandboxes.iter().map(sandbox_to_json).collect(); + println!( + "{}", + serde_json::to_string_pretty(&items).into_diagnostic()? + ); + return Ok(()); + } + "yaml" => { + let items: Vec<serde_json::Value> = sandboxes.iter().map(sandbox_to_json).collect(); + print!("{}", serde_yml::to_string(&items).into_diagnostic()?); + return Ok(()); + } + "table" => {} + _ => return Err(miette!("unsupported output format: {output}")), + } + if sandboxes.is_empty() { if !ids_only && !names_only { println!("No sandboxes found."); @@ -2512,6 +3215,205 @@ pub async fn sandbox_list( Ok(()) } +fn sandbox_to_json(sandbox: &Sandbox) -> serde_json::Value { + let meta = sandbox.metadata.as_ref(); + let labels = meta.map_or_else(|| serde_json::json!({}), |m| serde_json::json!(m.labels)); + serde_json::json!({ + "id": sandbox.object_id(), + "name": sandbox.object_name(), + "labels": labels, + "resource_version": meta.map_or(0, |m| m.resource_version), + "created_at": format_epoch_ms(meta.map_or(0, |m| m.created_at_ms)), + "phase": phase_name(sandbox.phase), + "current_policy_version": sandbox.current_policy_version, + }) +} + +pub async fn sandbox_provider_list(server: &str, name: &str, tls: &TlsOptions) -> Result<()> { + let mut client = grpc_client(server, tls).await?; + let response = client + .list_sandbox_providers(ListSandboxProvidersRequest { + sandbox_name: name.to_string(), + }) + .await + .into_diagnostic()?; + let providers = response.into_inner().providers; + + if providers.is_empty() { + 
println!("No providers attached to sandbox {name}."); + return Ok(()); + } + + print_provider_attachment_table(&providers); + Ok(()) +} + +pub async fn sandbox_provider_attach( + server: &str, + name: &str, + provider: &str, + tls: &TlsOptions, +) -> Result<()> { + let mut client = grpc_client(server, tls).await?; + + // Fetch current sandbox to get resource_version for CAS + let sandbox = client + .get_sandbox(GetSandboxRequest { + name: name.to_string(), + }) + .await + .into_diagnostic()? + .into_inner() + .sandbox + .ok_or_else(|| miette::miette!("sandbox not found"))?; + + let resource_version = sandbox.metadata.as_ref().map_or(0, |m| m.resource_version); + + let response = match client + .attach_sandbox_provider(AttachSandboxProviderRequest { + sandbox_name: name.to_string(), + provider_name: provider.to_string(), + expected_resource_version: resource_version, + }) + .await + { + Ok(response) => response.into_inner(), + Err(status) if status.code() == Code::Aborted => { + return Err(miette::miette!( + "Failed to attach provider: sandbox was modified by another operation.\n\ + Please retry the command." + ) + .with_source_code(status.message().to_string())); + } + Err(e) => return Err(e).into_diagnostic(), + }; + + if response.attached { + println!( + "{} Attached provider {} to sandbox {}", + "✓".green().bold(), + provider, + name + ); + } else { + println!("Provider {provider} is already attached to sandbox {name}."); + } + Ok(()) +} + +pub async fn sandbox_provider_detach( + server: &str, + name: &str, + provider: &str, + tls: &TlsOptions, +) -> Result<()> { + let mut client = grpc_client(server, tls).await?; + + // Fetch current sandbox to get resource_version for CAS + let sandbox = client + .get_sandbox(GetSandboxRequest { + name: name.to_string(), + }) + .await + .into_diagnostic()? 
+ .into_inner() + .sandbox + .ok_or_else(|| miette::miette!("sandbox not found"))?; + + let resource_version = sandbox.metadata.as_ref().map_or(0, |m| m.resource_version); + + let response = match client + .detach_sandbox_provider(DetachSandboxProviderRequest { + sandbox_name: name.to_string(), + provider_name: provider.to_string(), + expected_resource_version: resource_version, + }) + .await + { + Ok(response) => response.into_inner(), + Err(status) if status.code() == Code::Aborted => { + return Err(miette::miette!( + "Failed to detach provider: sandbox was modified by another operation.\n\ + Please retry the command." + ) + .with_source_code(status.message().to_string())); + } + Err(e) => return Err(e).into_diagnostic(), + }; + + if response.detached { + println!( + "{} Detached provider {} from sandbox {}", + "✓".green().bold(), + provider, + name + ); + } else { + println!("Provider {provider} was not attached to sandbox {name}."); + } + Ok(()) +} + +fn print_provider_attachment_table(providers: &[Provider]) { + print!("{}", format_provider_attachment_table(providers, true)); +} + +fn format_provider_attachment_table(providers: &[Provider], color: bool) -> String { + use std::fmt::Write as _; + + let name_width = providers + .iter() + .map(|provider| provider.object_name().len()) + .max() + .unwrap_or(4) + .max(4); + let type_width = providers + .iter() + .map(|provider| provider.r#type.len()) + .max() + .unwrap_or(4) + .max(4); + + let name_header = if color { + "NAME".bold().to_string() + } else { + "NAME".to_string() + }; + let type_header = if color { + "TYPE".bold().to_string() + } else { + "TYPE".to_string() + }; + let credential_keys_header = if color { + "CREDENTIAL_KEYS".bold().to_string() + } else { + "CREDENTIAL_KEYS".to_string() + }; + let config_keys_header = if color { + "CONFIG_KEYS".bold().to_string() + } else { + "CONFIG_KEYS".to_string() + }; + + let mut output = String::new(); + let _ = writeln!( + output, + "{name_header: Option { /// 
passed through directly; the server validates they exist at sandbox creation. /// /// `inferred_types` are provider **types** inferred from the trailing command -/// (e.g. `claude` → type `"claude"`). These are resolved to provider names via +/// (e.g. `claude` -> type `"claude-code"`). These are resolved to provider names via /// a type→name lookup, and missing types may be auto-created interactively. /// /// Returns a deduplicated list of provider **names** suitable for @@ -2749,9 +3651,8 @@ async fn auto_create_provider( return Ok(()); } - let registry = ProviderRegistry::new(); - let discovered = registry - .discover_existing(provider_type) + let discovered = discover_existing_provider_data(client, provider_type) + .await .map_err(|err| miette::miette!("failed to discover provider '{provider_type}': {err}"))?; let Some(discovered) = discovered else { eprintln!( @@ -2772,10 +3673,12 @@ async fn auto_create_provider( name: exact_name.to_string(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), r#type: provider_type.to_string(), credentials: discovered.credentials.clone(), config: discovered.config.clone(), + credential_expires_at_ms: HashMap::new(), }), }; @@ -2812,10 +3715,12 @@ async fn auto_create_provider( name: name.clone(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), r#type: provider_type.to_string(), credentials: discovered.credentials.clone(), config: discovered.config.clone(), + credential_expires_at_ms: HashMap::new(), }), }; @@ -2853,63 +3758,421 @@ async fn auto_create_provider( } } - eprintln!(); + eprintln!(); + Ok(()) +} + +fn parse_key_value_pairs(items: &[String], flag: &str) -> Result> { + let mut map = HashMap::new(); + + for item in items { + let Some((key, value)) = item.split_once('=') else { + return Err(miette::miette!("{flag} expects KEY=VALUE, got '{item}'")); + }; + + let key = key.trim(); + if key.is_empty() { + return Err(miette::miette!("{flag} key cannot be empty")); + } + + 
map.insert(key.to_string(), value.to_string()); + } + + Ok(map) +} + +fn parse_credential_pairs(items: &[String]) -> Result<HashMap<String, String>> { + let mut map = HashMap::new(); + + for item in items { + if let Some((key, value)) = item.split_once('=') { + let key = key.trim(); + if key.is_empty() { + return Err(miette::miette!("--credential key cannot be empty")); + } + map.insert(key.to_string(), value.to_string()); + continue; + } + + let key = item.trim(); + if key.is_empty() { + return Err(miette::miette!("--credential key cannot be empty")); + } + + let value = std::env::var(key).map_err(|_| { + miette::miette!( + "--credential {key} requires local env var '{key}' to be set to a non-empty value" + ) + })?; + + if value.trim().is_empty() { + return Err(miette::miette!( + "--credential {key} requires local env var '{key}' to be set to a non-empty value" + )); + } + + map.insert(key.to_string(), value); + } + + Ok(map) +} + +pub fn parse_credential_expiry_cli_value(value: &str) -> std::result::Result<i64, String> { + parse_credential_expiry_value(value, None).map_err(|err| err.to_string()) +} + +fn credential_expiry_value_error(key: Option<&str>, detail: &str) -> miette::Report { + key.map_or_else( + || miette::miette!("--credential-expires-at value {detail}"), + |key| miette::miette!("--credential-expires-at value for '{key}' {detail}"), + ) +} + +fn parse_credential_expiry_value(value: &str, key: Option<&str>) -> Result<i64> { + let value = value.trim(); + if value.is_empty() { + return Err(credential_expiry_value_error(key, "cannot be empty")); + } + + if let Ok(value_ms) = value.parse::<i64>() { + if value_ms < 0 { + return Err(credential_expiry_value_error( + key, + "must be greater than or equal to 0", + )); + } + return Ok(value_ms); + } + + let parsed = DateTime::parse_from_rfc3339(value).map_err(|_| { + credential_expiry_value_error( + key, + "must be a Unix epoch millisecond timestamp or RFC3339 timestamp", + ) + })?; + let value_ms = parsed.timestamp_millis(); + if value_ms < 0 { + return 
Err(credential_expiry_value_error( + key, + "must be greater than or equal to 0", + )); + } + + Ok(value_ms) +} + +fn parse_credential_expiry_pairs(items: &[String]) -> Result<HashMap<String, i64>> { + let mut map = HashMap::new(); + + for item in items { + let Some((key, value)) = item.split_once('=') else { + return Err(miette::miette!( + "--credential-expires-at expects KEY=TIMESTAMP, got '{item}'" + )); + }; + let key = key.trim(); + if key.is_empty() { + return Err(miette::miette!( + "--credential-expires-at key cannot be empty" + )); + } + let value = parse_credential_expiry_value(value, Some(key))?; + map.insert(key.to_string(), value); + } + + Ok(map) +} + +pub async fn service_expose( + server: &str, + sandbox: &str, + service: &str, + target_port: u16, + tls: &TlsOptions, +) -> Result<()> { + let mut client = grpc_client(server, tls).await?; + let response = client + .expose_service(ExposeServiceRequest { + sandbox: sandbox.to_string(), + service: service.to_string(), + target_port: u32::from(target_port), + domain: true, + }) + .await + .map_err(service_expose_status_error)? 
+ .into_inner(); + + if service.is_empty() { + println!( + "{} Exposed sandbox {} -> 127.0.0.1:{}", + "✓".green().bold(), + sandbox.bold(), + target_port, + ); + } else { + println!( + "{} Exposed service {} on sandbox {} -> 127.0.0.1:{}", + "✓".green().bold(), + service.bold(), + sandbox.bold(), + target_port, + ); + } + if !response.url.is_empty() { + let url = service_url_for_gateway(&response.url, server); + println!(" URL: {}", url.cyan()); + } + Ok(()) +} + +fn service_expose_status_error(status: Status) -> miette::Report { + service_status_error("expose service", "sandbox:write", status) +} + +pub async fn service_list( + server: &str, + sandbox: Option<&str>, + limit: u32, + offset: u32, + tls: &TlsOptions, +) -> Result<()> { + let mut client = grpc_client(server, tls).await?; + let response = client + .list_services(ListServicesRequest { + sandbox: sandbox.unwrap_or_default().to_string(), + limit, + offset, + }) + .await + .map_err(|status| service_status_error("list services", "sandbox:read", status))? + .into_inner(); + + if response.services.is_empty() { + if let Some(sandbox) = sandbox { + println!("No services exposed for sandbox {sandbox}."); + } else { + println!("No services exposed."); + } + return Ok(()); + } + + print_service_endpoint_table(&response.services, server); Ok(()) } -fn parse_key_value_pairs(items: &[String], flag: &str) -> Result> { - let mut map = HashMap::new(); +pub async fn service_get( + server: &str, + sandbox: &str, + service: &str, + tls: &TlsOptions, +) -> Result<()> { + let mut client = grpc_client(server, tls).await?; + let response = client + .get_service(GetServiceRequest { + sandbox: sandbox.to_string(), + service: service.to_string(), + }) + .await + .map_err(|status| service_status_error("get service", "sandbox:read", status))? 
+ .into_inner(); - for item in items { - let Some((key, value)) = item.split_once('=') else { - return Err(miette::miette!("{flag} expects KEY=VALUE, got '{item}'")); - }; + print_service_endpoint_table(&[response], server); + Ok(()) +} - let key = key.trim(); - if key.is_empty() { - return Err(miette::miette!("{flag} key cannot be empty")); +pub async fn service_delete( + server: &str, + sandbox: &str, + service: &str, + tls: &TlsOptions, +) -> Result<()> { + let mut client = grpc_client(server, tls).await?; + let response = client + .delete_service(DeleteServiceRequest { + sandbox: sandbox.to_string(), + service: service.to_string(), + }) + .await + .map_err(|status| service_status_error("delete service", "sandbox:write", status))? + .into_inner(); + + if !response.deleted { + return Err(miette!("delete service failed: service endpoint not found")); + } + + if service.is_empty() { + println!( + "{} Deleted exposed sandbox {}", + "✓".green().bold(), + sandbox.bold(), + ); + } else { + println!( + "{} Deleted service {} on sandbox {}", + "✓".green().bold(), + service.bold(), + sandbox.bold(), + ); + } + Ok(()) +} + +fn service_status_error(action: &str, required_scope: &str, status: Status) -> miette::Report { + let message = status.message(); + match status.code() { + Code::PermissionDenied => { + miette!("{action} failed: permission denied (requires {required_scope})") + } + Code::Unauthenticated => miette!("{action} failed: authentication required"), + Code::NotFound if message == "sandbox not found" => { + miette!("{action} failed: sandbox not found") + } + Code::NotFound if message == "service endpoint not found" => { + miette!("{action} failed: service endpoint not found") } + Code::InvalidArgument if !message.is_empty() => { + miette!("{action} failed: invalid request: {message}") + } + _ => miette!("{action} failed: {status}"), + } +} - map.insert(key.to_string(), value.to_string()); +fn print_service_endpoint_table(services: &[ServiceEndpointResponse], 
gateway_endpoint: &str) { + let rows = services + .iter() + .filter_map(|response| { + let endpoint = response.endpoint.as_ref()?; + let service = service_display_name(&endpoint.service_name).to_string(); + let target = format!("127.0.0.1:{}", endpoint.target_port); + let url = if response.url.is_empty() { + String::new() + } else { + service_url_for_gateway(&response.url, gateway_endpoint) + }; + Some((endpoint.sandbox_name.clone(), service, target, url)) + }) + .collect::>(); + + if rows.is_empty() { + return; } - Ok(map) + let sandbox_width = rows + .iter() + .map(|(sandbox, _, _, _)| sandbox.len()) + .max() + .unwrap_or(7) + .max(7); + let service_width = rows + .iter() + .map(|(_, service, _, _)| service.len()) + .max() + .unwrap_or(7) + .max(7); + let target_width = rows + .iter() + .map(|(_, _, target, _)| target.len()) + .max() + .unwrap_or(6) + .max(6); + + println!( + "{: Result> { - let mut map = HashMap::new(); +fn service_display_name(service: &str) -> &str { + if service.is_empty() { "-" } else { service } +} - for item in items { - if let Some((key, value)) = item.split_once('=') { - let key = key.trim(); - if key.is_empty() { - return Err(miette::miette!("--credential key cannot be empty")); - } - map.insert(key.to_string(), value.to_string()); - continue; - } +fn service_url_for_gateway(service_url: &str, gateway_endpoint: &str) -> String { + let (Ok(mut service_url), Ok(gateway_endpoint)) = ( + url::Url::parse(service_url), + url::Url::parse(gateway_endpoint), + ) else { + return service_url.to_string(); + }; - let key = item.trim(); - if key.is_empty() { - return Err(miette::miette!("--credential key cannot be empty")); - } + if service_url + .set_port(gateway_endpoint.port_or_known_default()) + .is_err() + { + return service_url.to_string(); + } - let value = std::env::var(key).map_err(|_| { - miette::miette!( - "--credential {key} requires local env var '{key}' to be set to a non-empty value" - ) + service_url.to_string() +} + +async fn 
gateway_providers_v2_enabled(client: &mut crate::tls::GrpcClient) -> Result { + let response = client + .get_gateway_config(GetGatewayConfigRequest {}) + .await + .into_diagnostic()? + .into_inner(); + let Some(setting) = response.settings.get(settings::PROVIDERS_V2_ENABLED_KEY) else { + return Ok(false); + }; + match setting.value.as_ref() { + Some(setting_value::Value::BoolValue(enabled)) => Ok(*enabled), + None => Ok(false), + Some(_) => Err(miette::miette!( + "gateway setting '{}' has invalid value type; expected bool", + settings::PROVIDERS_V2_ENABLED_KEY + )), + } +} + +async fn fetch_provider_profile( + client: &mut crate::tls::GrpcClient, + provider_type: &str, +) -> Result { + let response = client + .get_provider_profile(GetProviderProfileRequest { + id: provider_type.to_string(), + }) + .await + .map_err(|status| { + if status.code() == Code::NotFound { + miette::miette!( + "provider profile '{provider_type}' not found; providers v2 discovery requires a provider profile" + ) + } else { + miette::miette!(status.to_string()) + } })?; - if value.trim().is_empty() { - return Err(miette::miette!( - "--credential {key} requires local env var '{key}' to be set to a non-empty value" - )); - } + response + .into_inner() + .profile + .ok_or_else(|| miette::miette!("provider profile '{provider_type}' missing from response")) +} - map.insert(key.to_string(), value); +async fn discover_existing_provider_data( + client: &mut crate::tls::GrpcClient, + provider_type: &str, +) -> Result> { + if gateway_providers_v2_enabled(client).await? 
{ + let profile = fetch_provider_profile(client, provider_type).await?; + let profile = ProviderTypeProfile::from_proto(&profile); + discover_from_profile(&profile, &RealDiscoveryContext).map_err(|err| { + miette::miette!("failed to discover existing provider data from profile: {err}") + }) + } else { + let registry = ProviderRegistry::new(); + registry + .discover_existing(provider_type) + .map_err(|err| miette::miette!("failed to discover existing provider data: {err}")) } - - Ok(map) } pub async fn provider_create( @@ -2961,10 +4224,7 @@ pub async fn provider_create( let mut config_map = parse_key_value_pairs(config, "--config")?; if from_existing { - let registry = ProviderRegistry::new(); - let discovered = registry - .discover_existing(&provider_type) - .map_err(|err| miette::miette!("failed to discover existing provider data: {err}"))?; + let discovered = discover_existing_provider_data(&mut client, &provider_type).await?; let Some(discovered) = discovered else { return Err(miette::miette!( "no existing local credentials/config found for provider type '{provider_type}'" @@ -2980,10 +4240,16 @@ pub async fn provider_create( } if credential_map.is_empty() { - return Err(miette::miette!( - "no credentials resolved for provider type '{provider_type}'. \ - Use --credential KEY[=VALUE] or --from-existing with the appropriate env vars set." - )); + let allows_refresh_bootstrap = fetch_provider_profile(&mut client, &provider_type) + .await + .ok() + .is_some_and(|profile| provider_profile_allows_refresh_bootstrap(&profile)); + if !allows_refresh_bootstrap { + return Err(miette::miette!( + "no credentials resolved for provider type '{provider_type}'. \ + Use --credential KEY[=VALUE] or --from-existing with the appropriate env vars set." 
+ )); + } } let response = client @@ -2994,10 +4260,12 @@ pub async fn provider_create( name: name.to_string(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), r#type: provider_type.clone(), credentials: credential_map, config: config_map, + credential_expires_at_ms: HashMap::new(), }), }) .await @@ -3016,6 +4284,30 @@ pub async fn provider_create( Ok(()) } +fn provider_profile_allows_refresh_bootstrap(profile: &ProviderProfile) -> bool { + let required_credentials = profile + .credentials + .iter() + .filter(|credential| credential.required) + .collect::<Vec<_>>(); + !required_credentials.is_empty() + && required_credentials.iter().all(|credential| { + credential + .refresh + .as_ref() + .is_some_and(|refresh| is_gateway_mintable_refresh_strategy(refresh.strategy)) + }) +} + +fn is_gateway_mintable_refresh_strategy(strategy: i32) -> bool { + matches!( + ProviderCredentialRefreshStrategy::try_from(strategy), + Ok(ProviderCredentialRefreshStrategy::Oauth2RefreshToken + | ProviderCredentialRefreshStrategy::Oauth2ClientCredentials + | ProviderCredentialRefreshStrategy::GoogleServiceAccountJwt) + ) +} + pub async fn provider_get(server: &str, name: &str, tls: &TlsOptions) -> Result<()> { let mut client = grpc_client(server, tls).await?; let response = client @@ -3038,6 +4330,11 @@ pub async fn provider_get(server: &str, name: &str, tls: &TlsOptions) -> Result< println!(" {} {}", "Id:".dimmed(), provider.object_id()); println!(" {} {}", "Name:".dimmed(), provider.object_name()); println!(" {} {}", "Type:".dimmed(), provider.r#type); + println!( + " {} {}", + "Resource version:".dimmed(), + provider.metadata.as_ref().map_or(0, |m| m.resource_version) + ); println!( " {} {}", "Credential keys:".dimmed(), @@ -3271,23 +4568,241 @@ pub async fn provider_profile_lint( return Err(miette!("provider profile lint failed")); } - println!("Provider profile lint passed."); - Ok(()) + println!("Provider profile lint passed."); + Ok(()) +} + +pub async fn 
provider_profile_delete(server: &str, id: &str, tls: &TlsOptions) -> Result<()> { + let mut client = grpc_client(server, tls).await?; + let response = client + .delete_provider_profile(DeleteProviderProfileRequest { id: id.to_string() }) + .await + .into_diagnostic()? + .into_inner(); + if response.deleted { + println!("Deleted provider profile '{id}'."); + } else { + println!("Provider profile '{id}' was not deleted."); + } + Ok(()) +} + +pub async fn provider_refresh_status( + server: &str, + name: &str, + credential_key: Option<&str>, + tls: &TlsOptions, +) -> Result<()> { + let mut client = grpc_client(server, tls).await?; + let response = client + .get_provider_refresh_status(GetProviderRefreshStatusRequest { + provider: name.to_string(), + credential_key: credential_key.unwrap_or_default().to_string(), + }) + .await + .into_diagnostic()? + .into_inner(); + + if response.credentials.is_empty() { + if let Some(credential_key) = credential_key { + println!( + "No refresh configuration found for provider '{name}' credential '{credential_key}'." 
+ ); + } else { + println!("No refresh configurations found for provider '{name}'."); + } + return Ok(()); + } + + println!("{}", refresh_status_header()); + for status in response.credentials { + print_refresh_status_row(&status); + } + Ok(()) +} + +fn refresh_status_header() -> String { + format!( + "{:<24} {:<28} {:<28} {:<18} {:<20} {:<20} {:<20} {}", + "PROVIDER".bold(), + "CREDENTIAL_KEY".bold(), + "STRATEGY".bold(), + "STATUS".bold(), + "EXPIRES_AT".bold(), + "NEXT_REFRESH".bold(), + "LAST_REFRESH".bold(), + "LAST_ERROR".bold(), + ) +} + +pub struct ProviderRefreshConfigInput<'a> { + pub name: &'a str, + pub credential_key: &'a str, + pub strategy: &'a str, + pub material: &'a [String], + pub secret_material_keys: &'a [String], + pub credential_expires_at_ms: Option, +} + +pub async fn provider_refresh_config( + server: &str, + input: ProviderRefreshConfigInput<'_>, + tls: &TlsOptions, +) -> Result<()> { + let strategy = provider_refresh_strategy(input.strategy)?; + let material = parse_key_value_pairs(input.material, "--material")?; + let mut client = grpc_client(server, tls).await?; + let status = client + .configure_provider_refresh(ConfigureProviderRefreshRequest { + provider: input.name.to_string(), + credential_key: input.credential_key.to_string(), + strategy: strategy as i32, + material, + secret_material_keys: input.secret_material_keys.to_vec(), + expires_at_ms: input.credential_expires_at_ms, + }) + .await + .into_diagnostic()? 
+ .into_inner() + .status + .ok_or_else(|| miette!("provider refresh status missing from response"))?; + + println!( + "{} Configured refresh for {} {}", + "✓".green().bold(), + status.provider_name, + status.credential_key + ); + Ok(()) +} + +pub async fn provider_rotate( + server: &str, + name: &str, + credential_key: &str, + tls: &TlsOptions, +) -> Result<()> { + let mut client = grpc_client(server, tls).await?; + let status = client + .rotate_provider_credential(RotateProviderCredentialRequest { + provider: name.to_string(), + credential_key: credential_key.to_string(), + }) + .await + .into_diagnostic()? + .into_inner() + .status + .ok_or_else(|| miette!("provider refresh status missing from response"))?; + + if status.last_error.is_empty() { + println!( + "{} Rotation requested for {} {} ({})", + "✓".green().bold(), + status.provider_name, + status.credential_key, + status.status + ); + } else { + println!( + "Rotation request recorded for {} {} ({}): {}", + status.provider_name, status.credential_key, status.status, status.last_error + ); + } + Ok(()) +} + +pub async fn provider_refresh_delete( + server: &str, + name: &str, + credential_key: &str, + tls: &TlsOptions, +) -> Result<()> { + let mut client = grpc_client(server, tls).await?; + let response = client + .delete_provider_refresh(DeleteProviderRefreshRequest { + provider: name.to_string(), + credential_key: credential_key.to_string(), + }) + .await + .into_diagnostic()? 
+ .into_inner(); + + if response.deleted { + println!( + "{} Deleted refresh config for {} {}", + "✓".green().bold(), + name, + credential_key + ); + } else { + println!("No refresh config found for provider '{name}' credential '{credential_key}'."); + } + Ok(()) +} + +fn provider_refresh_strategy(strategy: &str) -> Result<ProviderCredentialRefreshStrategy> { + match strategy { + "oauth2_refresh_token" => Ok(ProviderCredentialRefreshStrategy::Oauth2RefreshToken), + "oauth2_client_credentials" => { + Ok(ProviderCredentialRefreshStrategy::Oauth2ClientCredentials) + } + "google_service_account_jwt" => { + Ok(ProviderCredentialRefreshStrategy::GoogleServiceAccountJwt) + } + _ => Err(miette!("unsupported provider refresh strategy: {strategy}")), + } +} + +fn print_refresh_status_row(status: &ProviderCredentialRefreshStatus) { + println!("{}", refresh_status_row(status)); } -pub async fn provider_profile_delete(server: &str, id: &str, tls: &TlsOptions) -> Result<()> { - let mut client = grpc_client(server, tls).await?; - let response = client - .delete_provider_profile(DeleteProviderProfileRequest { id: id.to_string() }) - .await - .into_diagnostic()? 
- .into_inner(); - if response.deleted { - println!("Deleted provider profile '{id}'."); +fn refresh_status_row(status: &ProviderCredentialRefreshStatus) -> String { + let strategy = ProviderCredentialRefreshStrategy::try_from(status.strategy) + .unwrap_or(ProviderCredentialRefreshStrategy::Unspecified); + format!( + "{:<24} {:<28} {:<28} {:<18} {:<20} {:<20} {:<20} {}", + status.provider_name, + status.credential_key, + provider_refresh_strategy_name(strategy), + status.status, + format_optional_epoch_ms(status.expires_at_ms), + format_optional_epoch_ms(status.next_refresh_at_ms), + format_optional_epoch_ms(status.last_refresh_at_ms), + truncate_status_field(&status.last_error, 72), + ) +} + +fn format_optional_epoch_ms(ms: i64) -> String { + if ms > 0 { + format_epoch_ms(ms) } else { - println!("Provider profile '{id}' was not deleted."); + "-".to_string() + } +} + +fn truncate_status_field(value: &str, max_chars: usize) -> String { + if value.is_empty() { + return "-".to_string(); + } + let mut chars = value.chars(); + let truncated = chars.by_ref().take(max_chars).collect::<String>(); + if chars.next().is_some() { + format!("{truncated}...") + } else { + truncated + } +} + +fn provider_refresh_strategy_name(strategy: ProviderCredentialRefreshStrategy) -> &'static str { + match strategy { + ProviderCredentialRefreshStrategy::Static => "static", + ProviderCredentialRefreshStrategy::External => "external", + ProviderCredentialRefreshStrategy::Oauth2RefreshToken => "oauth2_refresh_token", + ProviderCredentialRefreshStrategy::Oauth2ClientCredentials => "oauth2_client_credentials", + ProviderCredentialRefreshStrategy::GoogleServiceAccountJwt => "google_service_account_jwt", + ProviderCredentialRefreshStrategy::Unspecified => "unspecified", } - Ok(()) } fn load_profile_import_items( @@ -3442,6 +4957,7 @@ pub async fn provider_update( from_existing: bool, credentials: &[String], config: &[String], + credential_expires_at: &[String], tls: &TlsOptions, ) -> 
Result<()> { if 
from_existing && !credentials.is_empty() { @@ -3454,6 +4970,7 @@ pub async fn provider_update( let mut credential_map = parse_credential_pairs(credentials)?; let mut config_map = parse_key_value_pairs(config, "--config")?; + let credential_expires_at_ms = parse_credential_expiry_pairs(credential_expires_at)?; if from_existing { // Fetch the existing provider to discover its type for credential lookup. @@ -3468,10 +4985,7 @@ pub async fn provider_update( .ok_or_else(|| miette::miette!("provider '{name}' not found"))?; let provider_type = existing.r#type; - let registry = ProviderRegistry::new(); - let discovered = registry - .discover_existing(&provider_type) - .map_err(|err| miette::miette!("failed to discover existing provider data: {err}"))?; + let discovered = discover_existing_provider_data(&mut client, &provider_type).await?; let Some(discovered) = discovered else { return Err(miette::miette!( "no existing local credentials/config found for provider type '{provider_type}'" @@ -3494,11 +5008,14 @@ pub async fn provider_update( name: name.to_string(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), r#type: String::new(), credentials: credential_map, config: config_map, + credential_expires_at_ms: HashMap::new(), }), + credential_expires_at_ms, }) .await .into_diagnostic()?; @@ -4048,6 +5565,7 @@ pub async fn sandbox_policy_set_global( delete_setting: false, global: true, merge_operations: vec![], + expected_resource_version: 0, }) .await .into_diagnostic()? @@ -4246,6 +5764,7 @@ pub async fn gateway_setting_set( delete_setting: false, global: true, merge_operations: vec![], + expected_resource_version: 0, }) .await .into_diagnostic()? @@ -4280,6 +5799,7 @@ pub async fn sandbox_setting_set( delete_setting: false, global: false, merge_operations: vec![], + expected_resource_version: 0, }) .await .into_diagnostic()? 
@@ -4314,6 +5834,7 @@ pub async fn gateway_setting_delete( delete_setting: true, global: true, merge_operations: vec![], + expected_resource_version: 0, }) .await .into_diagnostic()? @@ -4348,6 +5869,7 @@ pub async fn sandbox_setting_delete( delete_setting: true, global: false, merge_operations: vec![], + expected_resource_version: 0, }) .await .into_diagnostic()? @@ -4406,6 +5928,7 @@ pub async fn sandbox_policy_set( delete_setting: false, global: false, merge_operations: vec![], + expected_resource_version: 0, }) .await .into_diagnostic()?; @@ -4580,6 +6103,7 @@ pub async fn sandbox_policy_update( delete_setting: false, global: false, merge_operations: plan.merge_operations, + expected_resource_version: 0, }) .await .into_diagnostic()? @@ -4672,8 +6196,49 @@ pub async fn sandbox_policy_get( name: &str, version: u32, full: bool, + output: &str, tls: &TlsOptions, ) -> Result<()> { + let mut stdout = Vec::new(); + let mut stderr = Vec::new(); + sandbox_policy_get_to_writer( + server, + name, + version, + full, + output, + tls, + (&mut stdout, &mut stderr), + ) + .await?; + + { + let mut terminal_stdout = std::io::stdout().lock(); + terminal_stdout.write_all(&stdout).into_diagnostic()?; + } + { + let mut terminal_stderr = std::io::stderr().lock(); + terminal_stderr.write_all(&stderr).into_diagnostic()?; + } + + Ok(()) +} + +#[doc(hidden)] +pub async fn sandbox_policy_get_to_writer<W, E>( + server: &str, + name: &str, + version: u32, + full: bool, + output: &str, + tls: &TlsOptions, + writers: (&mut W, &mut E), +) -> Result<()> +where + W: Write + Send, + E: Write + Send, +{ + let (stdout, stderr) = writers; let mut client = grpc_client(server, tls).await?; let status_resp = client @@ -4688,32 +6253,55 @@ pub async fn sandbox_policy_get( let inner = status_resp.into_inner(); if let Some(rev) = inner.revision { let status = PolicyStatus::try_from(rev.status).unwrap_or(PolicyStatus::Unspecified); - println!("Version: {}", rev.version); - println!("Hash: {}", 
rev.policy_hash); 
- println!("Status: {status:?}"); - println!("Active: {}", inner.active_version); + match output { + "json" => { + let obj = policy_revision_to_json( + "sandbox", + Some(name), + Some(inner.active_version), + &rev, + status, + full, + )?; + writeln!( + stdout, + "{}", + serde_json::to_string_pretty(&obj).into_diagnostic()? + ) + .into_diagnostic()?; + return Ok(()); + } + "table" => {} + _ => return Err(miette!("unsupported output format: {output}")), + } + + writeln!(stdout, "Version: {}", rev.version).into_diagnostic()?; + writeln!(stdout, "Hash: {}", rev.policy_hash).into_diagnostic()?; + writeln!(stdout, "Status: {status:?}").into_diagnostic()?; + writeln!(stdout, "Active: {}", inner.active_version).into_diagnostic()?; if rev.created_at_ms > 0 { - println!("Created: {} ms", rev.created_at_ms); + writeln!(stdout, "Created: {} ms", rev.created_at_ms).into_diagnostic()?; } if rev.loaded_at_ms > 0 { - println!("Loaded: {} ms", rev.loaded_at_ms); + writeln!(stdout, "Loaded: {} ms", rev.loaded_at_ms).into_diagnostic()?; } if !rev.load_error.is_empty() { - println!("Error: {}", rev.load_error); + writeln!(stdout, "Error: {}", rev.load_error).into_diagnostic()?; } if full { if let Some(ref policy) = rev.policy { - println!("---"); + writeln!(stdout, "---").into_diagnostic()?; let yaml_str = openshell_policy::serialize_sandbox_policy(policy) .wrap_err("failed to serialize policy to YAML")?; - print!("{yaml_str}"); + write!(stdout, "{yaml_str}").into_diagnostic()?; } else { - eprintln!("Policy payload not available for this version"); + writeln!(stderr, "Policy payload not available for this version") + .into_diagnostic()?; } } } else { - eprintln!("No policy history found for sandbox '{name}'"); + writeln!(stderr, "No policy history found for sandbox '{name}'").into_diagnostic()?; } Ok(()) @@ -4723,6 +6311,7 @@ pub async fn sandbox_policy_get_global( server: &str, version: u32, full: bool, + output: &str, tls: &TlsOptions, ) -> Result<()> { let mut client = 
grpc_client(server, tls).await?; @@ -4739,6 +6328,16 @@ pub async fn sandbox_policy_get_global( let inner = status_resp.into_inner(); if let Some(rev) = inner.revision { let status = PolicyStatus::try_from(rev.status).unwrap_or(PolicyStatus::Unspecified); + match output { + "json" => { + let obj = policy_revision_to_json("global", None, None, &rev, status, full)?; + println!("{}", serde_json::to_string_pretty(&obj).into_diagnostic()?); + return Ok(()); + } + "table" => {} + _ => return Err(miette!("unsupported output format: {output}")), + } + println!("Scope: global"); println!("Version: {}", rev.version); println!("Hash: {}", rev.policy_hash); @@ -4767,6 +6366,66 @@ pub async fn sandbox_policy_get_global( Ok(()) } +fn policy_status_json_name(status: PolicyStatus) -> &'static str { + match status { + PolicyStatus::Unspecified => "unspecified", + PolicyStatus::Pending => "pending", + PolicyStatus::Loaded => "loaded", + PolicyStatus::Failed => "failed", + PolicyStatus::Superseded => "superseded", + } +} + +fn policy_revision_to_json( + scope: &str, + sandbox: Option<&str>, + active_version: Option, + rev: &openshell_core::proto::SandboxPolicyRevision, + status: PolicyStatus, + full: bool, +) -> Result { + let mut obj = serde_json::Map::new(); + obj.insert("scope".to_string(), serde_json::json!(scope)); + if let Some(sandbox) = sandbox { + obj.insert("sandbox".to_string(), serde_json::json!(sandbox)); + } + obj.insert("version".to_string(), serde_json::json!(rev.version)); + obj.insert("hash".to_string(), serde_json::json!(rev.policy_hash)); + obj.insert( + "status".to_string(), + serde_json::json!(policy_status_json_name(status)), + ); + if let Some(active_version) = active_version { + obj.insert( + "active_version".to_string(), + serde_json::json!(active_version), + ); + } + if rev.created_at_ms > 0 { + obj.insert( + "created_at_ms".to_string(), + serde_json::json!(rev.created_at_ms), + ); + } + if rev.loaded_at_ms > 0 { + obj.insert( + "loaded_at_ms".to_string(), 
+ serde_json::json!(rev.loaded_at_ms), + ); + } + if !rev.load_error.is_empty() { + obj.insert("load_error".to_string(), serde_json::json!(rev.load_error)); + } + if full { + let policy = match rev.policy.as_ref() { + Some(policy) => openshell_policy::sandbox_policy_to_json_value(policy)?, + None => serde_json::Value::Null, + }; + obj.insert("policy".to_string(), policy); + } + Ok(serde_json::Value::Object(obj)) +} + pub async fn sandbox_policy_list( server: &str, name: &str, @@ -5220,17 +6879,57 @@ pub async fn sandbox_draft_history(server: &str, name: &str, tls: &TlsOptions) - fn format_endpoints(rule: &openshell_core::proto::NetworkPolicyRule) -> String { rule.endpoints .iter() - .map(|e| { - if e.port > 0 { - format!("{}:{}", e.host, e.port) - } else { - e.host.clone() - } - }) + .map(format_endpoint) .collect::<Vec<_>>() .join(", ") } +/// Render an endpoint as `host:port [layer, …allows…, …denies…]` so a reader +/// can tell L4-only access apart from a method/path-scoped L7 grant. The L7 +/// fields (`protocol: rest`, `rules`, `access`) materially change what gets +/// allowed; surfacing them in the default text output is what makes +/// `openshell rule get` useful for approval review. 
+fn format_endpoint(endpoint: &openshell_core::proto::NetworkEndpoint) -> String { + let host_port = if endpoint.port > 0 { + format!("{}:{}", endpoint.host, endpoint.port) + } else { + endpoint.host.clone() + }; + + let mut tags: Vec<String> = Vec::new(); + let layer_tag = if endpoint.protocol.eq_ignore_ascii_case("rest") { + "L7 rest" + } else if endpoint.protocol.is_empty() { + "L4" + } else { + endpoint.protocol.as_str() + }; + tags.push(layer_tag.to_string()); + + if !endpoint.access.is_empty() { + tags.push(format!("access={}", endpoint.access)); + } + + for r in &endpoint.rules { + if let Some(allow) = &r.allow { + let method = non_empty_or(&allow.method, "*"); + let path = non_empty_or(&allow.path, "*"); + tags.push(format!("allow {method} {path}")); + } + } + for r in &endpoint.deny_rules { + let method = non_empty_or(&r.method, "*"); + let path = non_empty_or(&r.path, "*"); + tags.push(format!("deny {method} {path}")); + } + + format!("{host_port} [{}]", tags.join(", ")) +} + +fn non_empty_or<'a>(value: &'a str, fallback: &'a str) -> &'a str { + if value.is_empty() { fallback } else { value } +} + /// Format a millisecond timestamp into a readable string. 
fn format_timestamp_ms(ms: i64) -> String { if ms <= 0 { @@ -5250,12 +6949,17 @@ fn format_timestamp_ms(ms: i64) -> String { #[cfg(test)] mod tests { use super::{ - TlsOptions, dockerfile_sources_supported_for_gateway, format_gateway_select_header, - format_gateway_select_items, gateway_add, gateway_auth_label, gateway_env_override_warning, - gateway_select_with, gateway_type_label, git_sync_files, http_health_check, - image_requests_gpu, inferred_provider_type, parse_cli_setting_value, - parse_credential_pairs, plaintext_gateway_is_remote, provisioning_timeout_message, - ready_false_condition_message, resolve_from, sandbox_should_persist, + ProvisioningStep, TlsOptions, build_sandbox_resource_limits, + dockerfile_sources_supported_for_gateway, format_endpoint, format_gateway_select_header, + format_gateway_select_items, format_provider_attachment_table, gateway_add, + gateway_auth_label, gateway_env_override_warning, gateway_select_with, gateway_type_label, + git_sync_files, http_health_check, image_requests_gpu, import_local_package_mtls_bundle, + inferred_provider_type, package_managed_tls_dirs, parse_cli_setting_value, + parse_credential_expiry_cli_value, parse_credential_expiry_pairs, parse_credential_pairs, + plaintext_gateway_is_remote, progress_step_from_metadata, + provider_profile_allows_refresh_bootstrap, provisioning_timeout_message, + ready_false_condition_message, refresh_status_header, refresh_status_row, resolve_from, + sandbox_should_persist, service_expose_status_error, service_url_for_gateway, }; use crate::TEST_ENV_LOCK; use hyper::StatusCode; @@ -5263,12 +6967,21 @@ mod tests { use std::fs; use std::io::{Read, Write}; use std::net::TcpListener; - use std::path::Path; + use std::path::{Path, PathBuf}; use std::process::Command; use std::thread; + use tonic::Status; use openshell_bootstrap::GatewayMetadata; - use openshell_core::proto::{SandboxCondition, SandboxStatus}; + use openshell_core::progress::{ + PROGRESS_STEP_PULLING_IMAGE, 
PROGRESS_STEP_REQUESTING_SANDBOX, + PROGRESS_STEP_STARTING_SANDBOX, + }; + use openshell_core::proto::{ + Provider, ProviderCredentialRefresh, ProviderCredentialRefreshStatus, + ProviderCredentialRefreshStrategy, ProviderProfile, ProviderProfileCredential, + SandboxCondition, SandboxStatus, datamodel::v1::ObjectMeta, + }; struct EnvVarGuard { key: &'static str, @@ -5371,6 +7084,187 @@ mod tests { )); } + #[test] + fn parse_credential_expiry_pairs_accepts_epoch_millis_and_rfc3339() { + let parsed = parse_credential_expiry_pairs(&[ + "API_TOKEN=1767225600000".to_string(), + "MS_GRAPH_ACCESS_TOKEN=2026-01-01T00:00:00Z".to_string(), + ]) + .expect("parse"); + + assert_eq!(parsed.get("API_TOKEN"), Some(&1_767_225_600_000)); + assert_eq!( + parsed.get("MS_GRAPH_ACCESS_TOKEN"), + Some(&1_767_225_600_000) + ); + } + + #[test] + fn parse_credential_expiry_pairs_accepts_zero_to_clear_expiry() { + let parsed = + parse_credential_expiry_pairs(&["API_TOKEN=0".to_string()]).expect("parse zero"); + + assert_eq!(parsed.get("API_TOKEN"), Some(&0)); + } + + #[test] + fn parse_credential_expiry_rejects_invalid_timestamp() { + let err = parse_credential_expiry_pairs(&["API_TOKEN=next-week".to_string()]) + .expect_err("invalid timestamp should error"); + + assert!( + err.to_string() + .contains("must be a Unix epoch millisecond timestamp or RFC3339 timestamp") + ); + } + + #[test] + fn parse_credential_expiry_cli_value_accepts_rfc3339_offsets() { + let parsed = parse_credential_expiry_cli_value("2026-01-01T01:00:00+01:00") + .expect("parse RFC3339 with offset"); + + assert_eq!(parsed, 1_767_225_600_000); + } + + #[test] + fn provider_attachment_table_formats_provider_counts() { + let output = format_provider_attachment_table( + &[Provider { + metadata: Some(ObjectMeta { + name: "work-custom".to_string(), + ..Default::default() + }), + r#type: "custom-api".to_string(), + credentials: [ + ("CUSTOM_API_KEY".to_string(), "REDACTED".to_string()), + ("CUSTOM_API_SECRET".to_string(), 
"REDACTED".to_string()), + ] + .into_iter() + .collect(), + config: std::iter::once(( + "BASE_URL".to_string(), + "https://api.custom.example".to_string(), + )) + .collect(), + credential_expires_at_ms: std::collections::HashMap::new(), + }], + false, + ); + + assert!(output.contains("NAME")); + assert!(output.contains("TYPE")); + assert!(output.contains("CREDENTIAL_KEYS")); + assert!(output.contains("CONFIG_KEYS")); + assert!(output.contains("work-custom")); + assert!(output.contains("custom-api")); + assert!(output.contains('2')); + assert!(output.contains('1')); + } + + #[test] + fn progress_step_metadata_values_map_to_cli_steps() { + assert_eq!( + progress_step_from_metadata(PROGRESS_STEP_REQUESTING_SANDBOX), + Some(ProvisioningStep::RequestingSandbox) + ); + assert_eq!( + progress_step_from_metadata(PROGRESS_STEP_PULLING_IMAGE), + Some(ProvisioningStep::PullingSandboxImage) + ); + assert_eq!( + progress_step_from_metadata(PROGRESS_STEP_STARTING_SANDBOX), + Some(ProvisioningStep::StartingSandbox) + ); + assert_eq!(progress_step_from_metadata("driver-private-step"), None); + } + + #[test] + fn refresh_status_table_includes_operational_fields() { + let header = refresh_status_header(); + assert!(header.contains("NEXT_REFRESH")); + assert!(header.contains("LAST_REFRESH")); + assert!(header.contains("LAST_ERROR")); + + let row = refresh_status_row(&ProviderCredentialRefreshStatus { + provider_name: "my-graph".to_string(), + provider_id: "provider-id".to_string(), + credential_key: "MS_GRAPH_ACCESS_TOKEN".to_string(), + strategy: ProviderCredentialRefreshStrategy::Oauth2ClientCredentials as i32, + status: "error".to_string(), + expires_at_ms: 1_767_225_600_000, + next_refresh_at_ms: 1_767_225_660_000, + last_refresh_at_ms: 1_767_225_000_000, + last_error: "token endpoint returned a very long error message that should be truncated for table readability" + .to_string(), + }); + + assert!(row.contains("my-graph")); + assert!(row.contains("MS_GRAPH_ACCESS_TOKEN")); + 
assert!(row.contains("oauth2_client_credentials")); + assert!(row.contains("error")); + assert!(row.contains("2026-01-01 00:00:00")); + assert!(row.contains("...")); + } + + #[test] + fn refresh_bootstrap_requires_all_required_credentials_to_be_gateway_mintable() { + let refresh_token_profile = ProviderProfile { + credentials: vec![ProviderProfileCredential { + name: "MS_GRAPH_ACCESS_TOKEN".to_string(), + required: true, + refresh: Some(ProviderCredentialRefresh { + strategy: ProviderCredentialRefreshStrategy::Oauth2RefreshToken as i32, + ..Default::default() + }), + ..Default::default() + }], + ..Default::default() + }; + assert!(provider_profile_allows_refresh_bootstrap( + &refresh_token_profile + )); + + let mixed_static_profile = ProviderProfile { + credentials: vec![ + ProviderProfileCredential { + name: "ACCESS_TOKEN".to_string(), + required: true, + refresh: Some(ProviderCredentialRefresh { + strategy: ProviderCredentialRefreshStrategy::Oauth2ClientCredentials as i32, + ..Default::default() + }), + ..Default::default() + }, + ProviderProfileCredential { + name: "STATIC_API_KEY".to_string(), + required: true, + refresh: None, + ..Default::default() + }, + ], + ..Default::default() + }; + assert!(!provider_profile_allows_refresh_bootstrap( + &mixed_static_profile + )); + + let optional_refresh_profile = ProviderProfile { + credentials: vec![ProviderProfileCredential { + name: "OPTIONAL_TOKEN".to_string(), + required: false, + refresh: Some(ProviderCredentialRefresh { + strategy: ProviderCredentialRefreshStrategy::GoogleServiceAccountJwt as i32, + ..Default::default() + }), + ..Default::default() + }], + ..Default::default() + }; + assert!(!provider_profile_allows_refresh_bootstrap( + &optional_refresh_profile + )); + } + #[cfg(feature = "dev-settings")] #[test] fn parse_cli_setting_value_parses_bool_aliases() { @@ -5414,10 +7308,59 @@ mod tests { assert!(err.to_string().contains("unknown setting key")); } + #[test] + fn 
build_sandbox_resource_limits_sets_limits_only() { + let resources = build_sandbox_resource_limits(Some("500m"), Some("2Gi")) + .expect("resource limits should parse") + .expect("resource limits should be present"); + + let limits = resources + .fields + .get("limits") + .and_then(|value| value.kind.as_ref()) + .and_then(|kind| match kind { + prost_types::value::Kind::StructValue(inner) => Some(inner), + _ => None, + }) + .expect("limits should be a struct"); + + assert_eq!( + limits + .fields + .get("cpu") + .and_then(|value| value.kind.as_ref()) + .and_then(|kind| match kind { + prost_types::value::Kind::StringValue(value) => Some(value.as_str()), + _ => None, + }), + Some("500m") + ); + assert_eq!( + limits + .fields + .get("memory") + .and_then(|value| value.kind.as_ref()) + .and_then(|kind| match kind { + prost_types::value::Kind::StringValue(value) => Some(value.as_str()), + _ => None, + }), + Some("2Gi") + ); + assert!(!resources.fields.contains_key("requests")); + } + + #[test] + fn build_sandbox_resource_limits_rejects_invalid_quantities() { + assert!(build_sandbox_resource_limits(Some("0"), None).is_err()); + assert!(build_sandbox_resource_limits(Some("half"), None).is_err()); + assert!(build_sandbox_resource_limits(None, Some("0Gi")).is_err()); + assert!(build_sandbox_resource_limits(None, Some("1.5Gi")).is_err()); + } + #[test] fn inferred_provider_type_returns_type_for_known_command() { let result = inferred_provider_type(&["claude".to_string(), "--help".to_string()]); - assert_eq!(result, Some("claude".to_string())); + assert_eq!(result, Some("claude-code".to_string())); } #[test] @@ -5446,7 +7389,7 @@ mod tests { #[test] fn inferred_provider_type_handles_full_path() { let result = inferred_provider_type(&["/usr/local/bin/claude".to_string()]); - assert_eq!(result, Some("claude".to_string())); + assert_eq!(result, Some("claude-code".to_string())); } #[test] @@ -5484,7 +7427,7 @@ mod tests { for image in [ 
"ghcr.io/nvidia/openshell-community/sandboxes/base:latest", "registry.example.com/gpu/team/base:latest", - "registry.example.com/team/openclaw:latest", + "registry.example.com/team/notebook:latest", "cuda-toolkit:latest", "registry.example.com/team/graphics:latest", ] { @@ -5592,6 +7535,62 @@ mod tests { assert!(dockerfile_sources_supported_for_gateway(None)); } + #[test] + fn service_url_for_gateway_uses_external_gateway_port() { + assert_eq!( + service_url_for_gateway( + "https://quiet-flamingo--notebook.navigator.openshell.localhost:8080/", + "https://127.0.0.1:31886" + ), + "https://quiet-flamingo--notebook.navigator.openshell.localhost:31886/" + ); + } + + #[test] + fn service_url_for_gateway_omits_default_external_port() { + assert_eq!( + service_url_for_gateway( + "https://quiet-flamingo--notebook.navigator.openshell.localhost:8080/", + "https://gateway.example.com" + ), + "https://quiet-flamingo--notebook.navigator.openshell.localhost/" + ); + } + + #[test] + fn service_url_for_gateway_preserves_service_scheme() { + assert_eq!( + service_url_for_gateway( + "http://quiet-flamingo--notebook.navigator.openshell.localhost:8080/", + "https://127.0.0.1:31886" + ), + "http://quiet-flamingo--notebook.navigator.openshell.localhost:31886/" + ); + } + + #[test] + fn service_url_for_gateway_uses_gateway_default_port() { + assert_eq!( + service_url_for_gateway( + "http://quiet-flamingo--notebook.navigator.openshell.localhost:8080/", + "https://gateway.example.com" + ), + "http://quiet-flamingo--notebook.navigator.openshell.localhost:443/" + ); + } + + #[test] + fn service_expose_status_error_mentions_required_scope() { + let report = service_expose_status_error(Status::permission_denied( + "scope 'sandbox:write' required", + )); + + assert_eq!( + report.to_string(), + "expose service failed: permission denied (requires sandbox:write)" + ); + } + #[test] fn ready_false_condition_message_prefers_reason_and_message() { let status = SandboxStatus { @@ -5848,6 +7847,52 @@ 
mod tests { assert_eq!(gateway_auth_label(&gateway), "mtls"); } + #[test] + fn package_managed_tls_dirs_respects_override() { + let _guard = TEST_ENV_LOCK + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let _tls_dir = EnvVarGuard::set("OPENSHELL_LOCAL_TLS_DIR", "/tmp/openshell-test-tls"); + + assert_eq!( + package_managed_tls_dirs(), + vec![PathBuf::from("/tmp/openshell-test-tls")], + ); + } + + #[test] + fn import_local_package_mtls_bundle_copies_client_materials() { + let tmpdir = tempfile::tempdir().expect("create tmpdir"); + let package_tls = tmpdir.path().join("package-tls"); + fs::create_dir_all(package_tls.join("client")).expect("create package tls dir"); + fs::write(package_tls.join("ca.crt"), "ca").expect("write ca"); + fs::write(package_tls.join("client/tls.crt"), "client cert").expect("write cert"); + fs::write(package_tls.join("client/tls.key"), "client key").expect("write key"); + + with_tmp_xdg(tmpdir.path(), || { + let _tls_dir = EnvVarGuard::set( + "OPENSHELL_LOCAL_TLS_DIR", + package_tls.to_str().expect("temp path should be utf-8"), + ); + + let imported = + import_local_package_mtls_bundle("openshell").expect("import local bundle"); + + assert_eq!(imported.as_deref(), Some(package_tls.as_path())); + + let mtls = tmpdir.path().join("openshell/gateways/openshell/mtls"); + assert_eq!(fs::read_to_string(mtls.join("ca.crt")).unwrap(), "ca"); + assert_eq!( + fs::read_to_string(mtls.join("tls.crt")).unwrap(), + "client cert", + ); + assert_eq!( + fs::read_to_string(mtls.join("tls.key")).unwrap(), + "client key", + ); + }); + } + #[test] fn plaintext_gateway_locality_infers_loopback_endpoints_as_local() { assert!(!plaintext_gateway_is_remote( @@ -5897,13 +7942,14 @@ mod tests { "openshell-cli", None, None, + false, ) .await .expect("register plaintext gateway"); }); // Loopback endpoints derive the canonical "openshell" gateway - // name, matching init-pki.sh and default_tls_dir conventions. 
+ // name, matching local cert generation and default_tls_dir conventions. let metadata = load_gateway_metadata("openshell").expect("load stored gateway"); assert_eq!(metadata.auth_mode.as_deref(), Some("plaintext")); assert!(!metadata.is_remote); @@ -5928,6 +7974,7 @@ mod tests { "openshell-cli", None, None, + false, ) .await .expect("register plaintext gateway"); @@ -5969,4 +8016,50 @@ mod tests { server.join().expect("server thread"); assert_eq!(status, Some(StatusCode::OK)); } + #[test] + fn format_endpoint_distinguishes_l4_from_l7_rest() { + use openshell_core::proto::{L7Allow, L7DenyRule, L7Rule, NetworkEndpoint}; + + let l4 = NetworkEndpoint { + host: "host.example.test".to_string(), + port: 443, + ..Default::default() + }; + assert_eq!(format_endpoint(&l4), "host.example.test:443 [L4]"); + + let l7_readonly = NetworkEndpoint { + host: "host.example.test".to_string(), + port: 443, + protocol: "rest".to_string(), + access: "read-only".to_string(), + ..Default::default() + }; + assert_eq!( + format_endpoint(&l7_readonly), + "host.example.test:443 [L7 rest, access=read-only]" + ); + + let l7_scoped = NetworkEndpoint { + host: "host.example.test".to_string(), + port: 443, + protocol: "rest".to_string(), + rules: vec![L7Rule { + allow: Some(L7Allow { + method: "PUT".to_string(), + path: "/v1/example/resource".to_string(), + ..Default::default() + }), + }], + deny_rules: vec![L7DenyRule { + method: "DELETE".to_string(), + path: "/v1/example/resource".to_string(), + ..Default::default() + }], + ..Default::default() + }; + assert_eq!( + format_endpoint(&l7_scoped), + "host.example.test:443 [L7 rest, allow PUT /v1/example/resource, deny DELETE /v1/example/resource]" + ); + } } diff --git a/crates/openshell-cli/src/ssh.rs b/crates/openshell-cli/src/ssh.rs index 89e9071e1..204128d34 100644 --- a/crates/openshell-cli/src/ssh.rs +++ b/crates/openshell-cli/src/ssh.rs @@ -3,30 +3,30 @@ //! SSH connection and proxy utilities. 
-use crate::tls::{TlsOptions, build_rustls_config, grpc_client, require_tls_materials}; +use crate::tls::{TlsOptions, grpc_client}; use miette::{IntoDiagnostic, Result, WrapErr}; #[cfg(unix)] use nix::sys::signal::{SaFlags, SigAction, SigHandler, SigSet, Signal, sigaction}; use openshell_core::ObjectId; use openshell_core::forward::{ - build_proxy_command, find_ssh_forward_pid, resolve_ssh_gateway, shell_escape, - validate_ssh_session_response, write_forward_pid, + build_proxy_command, find_ssh_forward_pid, format_gateway_url, resolve_ssh_gateway, + shell_escape, validate_ssh_session_response, write_forward_pid, +}; +use openshell_core::proto::{ + CreateSshSessionRequest, GetSandboxRequest, SshRelayTarget, TcpForwardFrame, TcpForwardInit, + tcp_forward_init, }; -use openshell_core::proto::{CreateSshSessionRequest, GetSandboxRequest}; use owo_colors::OwoColorize; -use rustls::pki_types::ServerName; use std::fs; use std::io::{IsTerminal, Write}; #[cfg(unix)] use std::os::unix::process::CommandExt; use std::path::{Path, PathBuf}; use std::process::{Command, Stdio}; -use std::sync::Arc; use std::time::Duration; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader}; -use tokio::net::TcpStream; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::process::Command as TokioCommand; -use tokio_rustls::TlsConnector; +use tokio_stream::wrappers::ReceiverStream; const FOREGROUND_FORWARD_STARTUP_GRACE_PERIOD: Duration = Duration::from_secs(2); @@ -100,8 +100,7 @@ async fn ssh_session_config( // external tunnel endpoint (the cluster URL), not the server's internal // scheme/host/port which may be plaintext HTTP on 127.0.0.1. let gateway_url = if tls.is_bearer_auth() { - let base = server.trim_end_matches('/'); - format!("{base}{}", session.connect_path) + server.trim_end_matches('/').to_string() } else { // If the server returned a loopback gateway address, override it with the // cluster endpoint's host. 
This handles the case where the server defaults @@ -110,10 +109,7 @@ async fn ssh_session_config( let gateway_port_u16 = session.gateway_port as u16; let (gateway_host, gateway_port) = resolve_ssh_gateway(&session.gateway_host, gateway_port_u16, server); - format!( - "{}://{}:{}{}", - session.gateway_scheme, gateway_host, gateway_port, session.connect_path - ) + format_gateway_url(&session.gateway_scheme, &gateway_host, gateway_port) }; let gateway_name = tls .gateway_name() @@ -609,6 +605,129 @@ fn split_sandbox_path(path: &str) -> (&str, &str) { } } +/// Writable root inside every sandbox. Used as the boundary for path-traversal +/// checks on sandbox-side source paths in download flows. +const SANDBOX_WORKSPACE_ROOT: &str = "/sandbox"; + +/// Lexically clean a POSIX-style absolute path by resolving `.` and `..` +/// components, collapsing repeated separators, and stripping any trailing +/// slash. Returns `None` if the input is empty or relative — the caller is +/// expected to reject those before reaching this helper. +/// +/// This is *lexical* only: it does not consult the filesystem and so cannot +/// follow symlinks. That trade-off is intentional — the function is used +/// client-side to refuse obvious path-traversal attempts before issuing the +/// SSH command. Symlink-based escapes inside the sandbox must be addressed +/// server-side. +fn lexical_clean_absolute_path(path: &str) -> Option<String> { + if !path.starts_with('/') { + return None; + } + let mut stack: Vec<&str> = Vec::new(); + for component in path.split('/') { + match component { + "" | "." => {} + ".." => { + stack.pop(); + } + other => stack.push(other), + } + } + if stack.is_empty() { + return Some("/".to_string()); + } + let mut out = String::with_capacity(path.len()); + for component in stack { + out.push('/'); + out.push_str(component); + } + Some(out) +} + +/// Validate that a sandbox-side source path passed to `sandbox download` +/// resolves under the sandbox writable root. 
+/// +/// Returns the cleaned, traversal-resolved path on success. Refuses any +/// path that lexically escapes `/sandbox` (e.g. `/etc/passwd`, +/// `/sandbox/../etc/passwd`) with a user-facing error. +/// +/// This is a lexical guard only — it does not follow symlinks. Call +/// `resolve_sandbox_source_path` after this on any path that will be passed +/// to a subsequent SSH I/O operation, so a symlink such as +/// `/sandbox/etc-link -> /etc` cannot leak files outside the workspace. +fn validate_sandbox_source_path(path: &str) -> Result<String> { + if path.is_empty() { + return Err(miette::miette!("sandbox source path is empty")); + } + let cleaned = lexical_clean_absolute_path(path) + .ok_or_else(|| miette::miette!("sandbox source path must be absolute (got '{path}')"))?; + if !is_under_sandbox_workspace(&cleaned) { + return Err(miette::miette!( + "sandbox source path '{path}' is outside the sandbox workspace ({SANDBOX_WORKSPACE_ROOT})" + )); + } + Ok(cleaned) +} + +/// Pure helper: is `path` equal to `/sandbox` or a descendant of it? +fn is_under_sandbox_workspace(path: &str) -> bool { + path == SANDBOX_WORKSPACE_ROOT || path.starts_with(&format!("{SANDBOX_WORKSPACE_ROOT}/")) +} + +/// Resolve every symlink in `sandbox_path` on the sandbox side and refuse the +/// result if it lands outside `/sandbox`. +/// +/// The lexical guard in `validate_sandbox_source_path` cannot see symlinks; a +/// path such as `/sandbox/etc-link/passwd` (where `etc-link -> /etc`) clears +/// the lexical check but would still leak `/etc/passwd` once `tar -C` follows +/// the link. Resolving symlinks on the remote side and re-validating closes +/// that gap. The returned fully-resolved path is what the caller should hand +/// to probe and tar invocations. 
+async fn resolve_sandbox_source_path( + session: &SshSessionConfig, + sandbox_path: &str, +) -> Result<String> { + let resolve_cmd = format!("realpath -e -- {path}", path = shell_escape(sandbox_path)); + let resolved = ssh_run_capture_stdout(session, &resolve_cmd) + .await + .wrap_err_with(|| format!("failed to resolve sandbox source path '{sandbox_path}'"))?; + if resolved.is_empty() { + return Err(miette::miette!( + "sandbox source path '{sandbox_path}' does not exist" + )); + } + if !is_under_sandbox_workspace(&resolved) { + return Err(miette::miette!( + "sandbox source path '{sandbox_path}' resolves to '{resolved}', outside the sandbox workspace ({SANDBOX_WORKSPACE_ROOT})" + )); + } + Ok(resolved) +} + +/// Resolve the host-side target path for a downloaded *file*, following +/// `cp`-style semantics. +/// +/// - If `dest_str` ends with `/` or already exists as a directory, the file is +/// placed inside it as `<dest_str>/<source_basename>`. +/// - Otherwise `dest_str` is treated as the exact file path to write. +/// +/// `dest_exists_as_dir` is taken as a parameter (rather than queried inside) +/// so this function stays pure and unit-testable; the caller performs the +/// filesystem check. +fn resolve_file_download_target( + dest_str: &str, + source_basename: &str, + dest_exists_as_dir: bool, +) -> PathBuf { + let trailing_slash = dest_str.ends_with('/'); + let dest_path = Path::new(dest_str); + if trailing_slash || dest_exists_as_dir { + dest_path.join(source_basename) + } else { + dest_path.to_path_buf() + } +} + /// Push a list of files from a local directory into a sandbox using tar-over-SSH. /// /// Files are streamed as a tar archive to `ssh ... tar xf - -C ` on @@ -734,41 +853,106 @@ fn file_list_archive_prefix(local_path: &Path) -> Option { } } +/// Run a small command on the sandbox over SSH and capture its stdout. +/// +/// Used by the download flow to probe whether the source path is a regular +/// file or a directory before streaming the tar archive. 
Stderr is inherited +/// so the user still sees any diagnostic output from ssh itself. +async fn ssh_run_capture_stdout(session: &SshSessionConfig, command: &str) -> Result<String> { + let mut ssh = ssh_base_command(&session.proxy_command); + ssh.arg("-T") + .arg("-o") + .arg("RequestTTY=no") + .arg("sandbox") + .arg(command) + .stdin(Stdio::null()) + .stdout(Stdio::piped()) + .stderr(Stdio::inherit()); + let output = tokio::task::spawn_blocking(move || ssh.output()) + .await + .into_diagnostic()? + .into_diagnostic()?; + if !output.status.success() { + return Err(miette::miette!( + "ssh probe exited with status {}", + output.status + )); + } + Ok(String::from_utf8_lossy(&output.stdout).trim().to_string()) +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum SandboxSourceKind { + File, + Directory, +} + +/// Probe the sandbox-side source path. The path is assumed to have already +/// been validated by `validate_sandbox_source_path`. +async fn probe_sandbox_source_kind( + session: &SshSessionConfig, + sandbox_path: &str, +) -> Result<SandboxSourceKind> { + let probe_cmd = format!( + "if [ -d {path} ]; then printf dir; elif [ -e {path} ]; then printf file; else printf missing; fi", + path = shell_escape(sandbox_path), + ); + let kind = ssh_run_capture_stdout(session, &probe_cmd).await?; + match kind.as_str() { + "dir" => Ok(SandboxSourceKind::Directory), + "file" => Ok(SandboxSourceKind::File), + "missing" => Err(miette::miette!( + "sandbox source path '{sandbox_path}' does not exist" + )), + other => Err(miette::miette!( + "unexpected probe output for sandbox source path '{sandbox_path}': '{other}'" + )), + } +} + /// Pull a path from a sandbox to a local destination using tar-over-SSH. +/// +/// Follows `cp`-style semantics for the destination: +/// +/// - If the source is a single file: +/// - When `dest` ends with `/` or already exists as a directory on the host, +/// the file lands at `<dest>/<basename>`. +/// - Otherwise `dest` is taken to be the exact file path to write. 
+/// - If the source is a directory, its contents are extracted into `dest` +/// (creating `dest` if it does not yet exist). This preserves prior +/// behaviour for the directory-source case. +/// +/// The sandbox source path is also subjected to a workspace-boundary check +/// before any SSH command is issued; paths that lexically resolve outside +/// `/sandbox` are refused. pub async fn sandbox_sync_down( server: &str, name: &str, sandbox_path: &str, - local_path: &Path, + dest: &str, tls: &TlsOptions, ) -> Result<()> { + let sandbox_path = validate_sandbox_source_path(sandbox_path)?; let session = ssh_session_config(server, name, tls).await?; + let sandbox_path = resolve_sandbox_source_path(&session, &sandbox_path).await?; + let kind = probe_sandbox_source_kind(&session, &sandbox_path).await?; - // Build tar command. When the sandbox path is a directory we tar its - // *contents* (using `-C .`) so the caller gets the files directly - // without an extra wrapper directory. For a single file we split into - // the parent directory and the filename. - let sandbox_path_clean = sandbox_path.trim_end_matches('/'); - - let tar_cmd = format!( - "if [ -d {path} ]; then tar cf - -C {path} .; else tar cf - -C {parent} {name}; fi", - path = shell_escape(sandbox_path_clean), - parent = shell_escape( - sandbox_path_clean - .rfind('/') - .map_or(".", |pos| if pos == 0 { - "/" - } else { - &sandbox_path_clean[..pos] - }) - ), - name = shell_escape( - sandbox_path_clean - .rfind('/') - .map_or(sandbox_path_clean, |pos| &sandbox_path_clean[pos + 1..]) - ), - ); + match kind { + SandboxSourceKind::File => sandbox_sync_down_file(&session, &sandbox_path, dest).await, + SandboxSourceKind::Directory => { + sandbox_sync_down_directory(&session, &sandbox_path, dest).await + } + } +} +/// Stream a tar archive from the sandbox and extract it into a fresh +/// destination directory. The source is always wrapped on the sandbox side so +/// the host can pick a basename when needed. 
+async fn stream_sandbox_tar( + session: &SshSessionConfig, + tar_cmd: String, + extract_into: &Path, +) -> Result<()> { let mut ssh = ssh_base_command(&session.proxy_command); ssh.arg("-T") .arg("-o") @@ -785,14 +969,11 @@ pub async fn sandbox_sync_down( .take() .ok_or_else(|| miette::miette!("failed to open stdout for ssh process"))?; - let local_path = local_path.to_path_buf(); + let extract_into = extract_into.to_path_buf(); tokio::task::spawn_blocking(move || -> Result<()> { - fs::create_dir_all(&local_path) - .into_diagnostic() - .wrap_err("failed to create local destination directory")?; let mut archive = tar::Archive::new(stdout); archive - .unpack(&local_path) + .unpack(&extract_into) .into_diagnostic() .wrap_err("failed to extract tar archive from sandbox")?; Ok(()) @@ -810,10 +991,125 @@ pub async fn sandbox_sync_down( "ssh tar create exited with status {status}" )); } + Ok(()) +} + +/// Build the `tar cf - -C <parent> -- <basename>` command used to wrap a +/// single sandbox-side file for download. +/// +/// The trailing `--` is required: a sandbox-side file whose basename starts +/// with `-` (e.g. `--checkpoint-action=...`) would otherwise be parsed by GNU +/// tar as an option rather than a member to archive. 
+fn build_single_file_tar_cmd(parent: &str, basename: &str) -> String { + format!( + "tar cf - -C {parent} -- {name}", + parent = shell_escape(parent), + name = shell_escape(basename), + ) +} + +async fn sandbox_sync_down_file( + session: &SshSessionConfig, + sandbox_path: &str, + dest: &str, +) -> Result<()> { + let (parent, basename) = split_sandbox_path(sandbox_path); + let dest_exists_as_dir = fs::symlink_metadata(Path::new(dest)).is_ok_and(|m| m.is_dir()); + let final_path = resolve_file_download_target(dest, basename, dest_exists_as_dir); + + let staging_parent = final_path + .parent() + .ok_or_else(|| miette::miette!("destination '{}' has no parent directory", dest))?; + fs::create_dir_all(staging_parent) + .into_diagnostic() + .wrap_err_with(|| { + format!( + "failed to create local destination directory '{}'", + staging_parent.display() + ) + })?; + + let staging = tempfile::TempDir::new_in(staging_parent) + .into_diagnostic() + .wrap_err("failed to create download staging directory")?; + + let tar_cmd = build_single_file_tar_cmd(parent, basename); + stream_sandbox_tar(session, tar_cmd, staging.path()).await?; + place_downloaded_file(staging.path(), basename, &final_path).wrap_err_with(|| { + format!( + "failed to place downloaded file at '{}'", + final_path.display() + ) + })?; Ok(()) } +/// Move a single file extracted by `stream_sandbox_tar` into its final +/// position on the host. +/// +/// `staging_dir` must contain a single regular-file entry named +/// `source_basename` (the wrapper produced by `tar cf - -C `). +/// The entry is renamed onto `final_path`, atomically when `staging_dir` is +/// on the same filesystem. Refuses to overwrite an existing directory at +/// `final_path` to match `cp` behaviour. 
+fn place_downloaded_file( + staging_dir: &Path, + source_basename: &str, + final_path: &Path, +) -> Result<()> { + let staged_file = staging_dir.join(source_basename); + let staged_meta = fs::symlink_metadata(&staged_file) + .into_diagnostic() + .wrap_err("downloaded archive did not contain the expected entry")?; + if !staged_meta.is_file() { + return Err(miette::miette!( + "downloaded entry '{source_basename}' is not a regular file" + )); + } + + if let Ok(existing) = fs::symlink_metadata(final_path) + && existing.is_dir() + { + return Err(miette::miette!( + "cannot overwrite directory '{}' with downloaded file", + final_path.display() + )); + } + + fs::rename(&staged_file, final_path) + .into_diagnostic() + .wrap_err_with(|| format!("failed to rename into '{}'", final_path.display()))?; + Ok(()) +} + +async fn sandbox_sync_down_directory( + session: &SshSessionConfig, + sandbox_path: &str, + dest: &str, +) -> Result<()> { + let dest_path = Path::new(dest); + if let Ok(existing) = fs::symlink_metadata(dest_path) + && !existing.is_dir() + { + return Err(miette::miette!( + "cannot extract directory '{sandbox_path}' over non-directory destination '{}'", + dest_path.display() + )); + } + fs::create_dir_all(dest_path) + .into_diagnostic() + .wrap_err_with(|| { + format!( + "failed to create local destination directory '{}'", + dest_path.display() + ) + })?; + + let tar_cmd = format!("tar cf - -C {path} .", path = shell_escape(sandbox_path)); + stream_sandbox_tar(session, tar_cmd, dest_path).await +} + /// Run the SSH proxy, connecting stdin/stdout to the gateway. pub async fn sandbox_ssh_proxy( gateway_url: &str, @@ -821,18 +1117,82 @@ pub async fn sandbox_ssh_proxy( token: &str, tls: &TlsOptions, ) -> Result<()> { - // The gateway returns 412 (Precondition Failed) when the sandbox pod - // exists but hasn't reached Ready phase yet. This is a transient state - // after sandbox allocation — retry with backoff instead of failing - // immediately. 
- const MAX_CONNECT_WAIT: Duration = Duration::from_secs(60); - const INITIAL_BACKOFF: Duration = Duration::from_secs(1); + let server = grpc_server_from_ssh_gateway_url(gateway_url)?; + let mut client = grpc_client(&server, tls).await?; + + let (tx, rx) = tokio::sync::mpsc::channel::<TcpForwardFrame>(16); + tx.send(TcpForwardFrame { + payload: Some(openshell_core::proto::tcp_forward_frame::Payload::Init( + TcpForwardInit { + sandbox_id: sandbox_id.to_string(), + service_id: format!("ssh-proxy:{sandbox_id}"), + target: Some(tcp_forward_init::Target::Ssh(SshRelayTarget {})), + authorization_token: token.to_string(), + }, + )), + }) + .await + .map_err(|_| miette::miette!("failed to initialize SSH forward stream"))?; + + let mut response = client + .forward_tcp(ReceiverStream::new(rx)) + .await + .into_diagnostic()? + .into_inner(); + + let stdin = tokio::io::stdin(); + let stdout = tokio::io::stdout(); + let to_remote = tokio::spawn(async move { + let mut stdin = stdin; + let mut buf = vec![0u8; 64 * 1024]; + while let Ok(n) = stdin.read(&mut buf).await { + if n == 0 { + break; + } + if tx + .send(TcpForwardFrame { + payload: Some(openshell_core::proto::tcp_forward_frame::Payload::Data( + buf[..n].to_vec(), + )), + }) + .await + .is_err() + { + break; + } + } + }); + let from_remote = tokio::spawn(async move { + let mut stdout = stdout; + loop { + let Ok(Some(frame)) = response.message().await else { + break; + }; + let Some(openshell_core::proto::tcp_forward_frame::Payload::Data(data)) = frame.payload + else { + continue; + }; + if data.is_empty() { + continue; + } + if stdout.write_all(&data).await.is_err() { + break; + } + let _ = stdout.flush().await; + } + }); + let _ = from_remote.await; + to_remote.abort(); + + Ok(()) +} + +fn grpc_server_from_ssh_gateway_url(gateway_url: &str) -> Result<String> { let url: url::Url = gateway_url .parse() .into_diagnostic() .wrap_err("invalid gateway URL")?; - let scheme = url.scheme(); let gateway_host = url .host_str() @@ -840,69 +1200,7 @@ pub 
async fn sandbox_ssh_proxy( let gateway_port = url .port_or_known_default() .ok_or_else(|| miette::miette!("gateway URL missing port"))?; - let connect_path = url.path(); - - let request = format!( - "CONNECT {connect_path} HTTP/1.1\r\nHost: {gateway_host}\r\nX-Sandbox-Id: {sandbox_id}\r\nX-Sandbox-Token: {token}\r\n\r\n" - ); - - let start = std::time::Instant::now(); - let mut backoff = INITIAL_BACKOFF; - let mut buf_stream; - - loop { - let mut stream: Box = - connect_gateway(scheme, gateway_host, gateway_port, tls).await?; - stream - .write_all(request.as_bytes()) - .await - .into_diagnostic()?; - - // Wrap in a BufReader **before** reading the HTTP response. The gateway - // may send the 200 OK response and the first SSH protocol bytes in the - // same TCP segment / WebSocket frame. A plain `read()` would consume - // those SSH bytes into our buffer and discard them, causing SSH to see a - // truncated protocol banner and exit with code 255. BufReader ensures - // any bytes read past the `\r\n\r\n` header boundary stay buffered and - // are returned by subsequent reads during the bidirectional copy phase. - buf_stream = BufReader::new(stream); - let status = read_connect_status(&mut buf_stream).await?; - if status == 200 { - break; - } - if status == 412 && start.elapsed() < MAX_CONNECT_WAIT { - tracing::debug!( - elapsed = ?start.elapsed(), - "sandbox not yet ready (HTTP 412), retrying in {backoff:?}" - ); - tokio::time::sleep(backoff).await; - backoff = (backoff * 2).min(Duration::from_secs(8)); - continue; - } - return Err(miette::miette!( - "gateway CONNECT failed with status {status}" - )); - } - - let (reader, writer) = tokio::io::split(buf_stream); - let stdin = tokio::io::stdin(); - let stdout = tokio::io::stdout(); - - // Spawn both copy directions as independent tasks. Using separate spawned - // tasks (instead of try_join!/select!) 
ensures that when one direction - // completes or errors, the other continues independently until it also - // finishes. This is critical: when the remote side closes the connection, - // we must keep the stdin→gateway copy alive so SSH can finish sending its - // protocol-close packets, and vice-versa. - let to_remote = tokio::spawn(copy_ignoring_errors(stdin, writer)); - let from_remote = tokio::spawn(copy_ignoring_errors(reader, stdout)); - let _ = from_remote.await; - // Once the remote→stdout direction is done, SSH has received all the data - // it needs. Drop the stdin→gateway task – SSH will close its pipe when - // it's done regardless. - to_remote.abort(); - - Ok(()) + Ok(format_gateway_url(scheme, gateway_host, gateway_port)) } /// Run the SSH proxy in "name mode": create a session on the fly, then proxy. @@ -1122,93 +1420,6 @@ pub fn print_ssh_config(gateway: &str, name: &str) { print!("{}", render_ssh_config(gateway, name)); } -/// Copy all bytes from `reader` to `writer`, flushing on completion. -/// Errors are intentionally discarded – connection teardown errors are -/// expected during normal SSH session shutdown. -async fn copy_ignoring_errors(mut reader: R, mut writer: W) -where - R: AsyncRead + Unpin, - W: AsyncWrite + Unpin, -{ - let _ = tokio::io::copy(&mut reader, &mut writer).await; - let _ = AsyncWriteExt::flush(&mut writer).await; - let _ = AsyncWriteExt::shutdown(&mut writer).await; -} - -async fn connect_gateway( - scheme: &str, - host: &str, - port: u16, - tls: &TlsOptions, -) -> Result> { - // When using Cloudflare edge bearer auth, route through the WebSocket - // tunnel proxy regardless of the origin scheme. The proxy handles edge - // auth headers and TLS termination at the edge; the origin may be - // plaintext HTTP behind the tunnel. OIDC tokens bypass the tunnel. 
- if let Some(token) = tls.edge_token.as_deref() { - let gateway_url = format!("https://{host}:{port}"); - let proxy = crate::edge_tunnel::start_tunnel_proxy(&gateway_url, token).await?; - let tcp = TcpStream::connect(proxy.local_addr) - .await - .into_diagnostic()?; - tcp.set_nodelay(true).into_diagnostic()?; - return Ok(Box::new(tcp)); - } - - let tcp = TcpStream::connect((host, port)).await.into_diagnostic()?; - tcp.set_nodelay(true).into_diagnostic()?; - if scheme.eq_ignore_ascii_case("https") { - let materials = require_tls_materials(&format!("https://{host}:{port}"), tls)?; - let config = build_rustls_config(&materials)?; - let connector = TlsConnector::from(Arc::new(config)); - let server_name = ServerName::try_from(host.to_string()) - .map_err(|_| miette::miette!("invalid server name: {host}"))?; - let tls = connector - .connect(server_name, tcp) - .await - .into_diagnostic()?; - Ok(Box::new(tls)) - } else { - Ok(Box::new(tcp)) - } -} - -/// Read exactly the HTTP response status line and headers up to `\r\n\r\n`. -/// -/// Uses byte-at-a-time reads so that the caller's `BufReader` retains any -/// bytes that arrived after the header boundary (e.g. the SSH protocol -/// banner that the gateway may send in the same TCP segment). -async fn read_connect_status(stream: &mut R) -> Result { - let mut buf = Vec::new(); - let mut byte = [0u8; 1]; - loop { - let n = stream.read(&mut byte).await.into_diagnostic()?; - if n == 0 { - break; - } - buf.push(byte[0]); - if buf.len() >= 4 && &buf[buf.len() - 4..] 
== b"\r\n\r\n" { - break; - } - if buf.len() > 8192 { - break; - } - } - let text = String::from_utf8_lossy(&buf); - let line = text.lines().next().unwrap_or(""); - let status = line - .split_whitespace() - .nth(1) - .unwrap_or("0") - .parse::() - .unwrap_or(0); - Ok(status) -} - -trait ProxyStream: AsyncRead + AsyncWrite + Unpin + Send {} - -impl ProxyStream for T where T: AsyncRead + AsyncWrite + Unpin + Send {} - #[cfg(test)] mod tests { use super::*; @@ -1328,6 +1539,254 @@ mod tests { assert_eq!(split_sandbox_path("/a/b/c/d.txt"), ("/a/b/c", "d.txt")); } + #[test] + fn lexical_clean_resolves_dot_and_dotdot_segments() { + assert_eq!( + lexical_clean_absolute_path("/sandbox/./a"), + Some("/sandbox/a".to_string()) + ); + assert_eq!( + lexical_clean_absolute_path("/sandbox/sub/../a"), + Some("/sandbox/a".to_string()) + ); + assert_eq!( + lexical_clean_absolute_path("/sandbox/../etc/passwd"), + Some("/etc/passwd".to_string()) + ); + assert_eq!( + lexical_clean_absolute_path("//sandbox///foo//"), + Some("/sandbox/foo".to_string()) + ); + assert_eq!(lexical_clean_absolute_path("/"), Some("/".to_string())); + } + + #[test] + fn lexical_clean_refuses_relative_paths() { + assert_eq!(lexical_clean_absolute_path(""), None); + assert_eq!(lexical_clean_absolute_path("sandbox/a"), None); + assert_eq!(lexical_clean_absolute_path("./a"), None); + } + + #[test] + fn validate_sandbox_source_path_accepts_workspace_paths() { + assert_eq!( + validate_sandbox_source_path("/sandbox/file.txt").unwrap(), + "/sandbox/file.txt" + ); + assert_eq!( + validate_sandbox_source_path("/sandbox/.agent/workspace/hello.txt").unwrap(), + "/sandbox/.agent/workspace/hello.txt" + ); + assert_eq!( + validate_sandbox_source_path("/sandbox").unwrap(), + "/sandbox" + ); + assert_eq!( + validate_sandbox_source_path("/sandbox/").unwrap(), + "/sandbox" + ); + assert_eq!( + validate_sandbox_source_path("/sandbox/sub/../file").unwrap(), + "/sandbox/file" + ); + } + + #[test] + fn 
validate_sandbox_source_path_rejects_traversal_and_escapes() { + let traversal = validate_sandbox_source_path("/etc/passwd").unwrap_err(); + assert!( + format!("{traversal}").contains("outside the sandbox workspace"), + "unexpected error: {traversal}" + ); + + let parent_escape = validate_sandbox_source_path("/sandbox/../etc/passwd").unwrap_err(); + assert!( + format!("{parent_escape}").contains("outside the sandbox workspace"), + "unexpected error: {parent_escape}" + ); + + let prefix_only = validate_sandbox_source_path("/sandboxed/secrets").unwrap_err(); + assert!( + format!("{prefix_only}").contains("outside the sandbox workspace"), + "unexpected error: {prefix_only}" + ); + + let empty = validate_sandbox_source_path("").unwrap_err(); + assert!(format!("{empty}").contains("empty")); + + let relative = validate_sandbox_source_path("sandbox/file").unwrap_err(); + assert!(format!("{relative}").contains("must be absolute")); + } + + #[test] + fn is_under_sandbox_workspace_accepts_root_and_descendants() { + assert!(is_under_sandbox_workspace("/sandbox")); + assert!(is_under_sandbox_workspace("/sandbox/file")); + assert!(is_under_sandbox_workspace("/sandbox/sub/nested")); + } + + #[test] + fn is_under_sandbox_workspace_rejects_outside_paths_and_prefix_collisions() { + assert!(!is_under_sandbox_workspace("/etc/passwd")); + assert!(!is_under_sandbox_workspace("/sandboxed/secrets")); + assert!(!is_under_sandbox_workspace("/")); + assert!(!is_under_sandbox_workspace("")); + } + + #[test] + fn build_single_file_tar_cmd_inserts_double_dash_before_basename() { + // Without `--`, a basename such as `--checkpoint-action=...` would be + // parsed by GNU tar as an option. Guard the wire format against this + // regression. 
+ let cmd = build_single_file_tar_cmd("/sandbox", "--checkpoint-action=exec=id"); + assert!( + cmd.contains(" -- "), + "expected `--` separator in tar command, got: {cmd}" + ); + assert!( + cmd.ends_with(&shell_escape("--checkpoint-action=exec=id")), + "expected basename at end of tar command, got: {cmd}" + ); + } + + #[test] + fn build_single_file_tar_cmd_escapes_parent_and_basename() { + let cmd = build_single_file_tar_cmd("/sandbox/with space", "name with space"); + assert!(cmd.contains(" -- "), "missing `--` separator: {cmd}"); + assert!( + cmd.contains(&shell_escape("/sandbox/with space")), + "parent not shell-escaped: {cmd}" + ); + assert!( + cmd.contains(&shell_escape("name with space")), + "basename not shell-escaped: {cmd}" + ); + } + + #[test] + fn resolve_file_download_target_writes_to_dest_when_not_a_directory() { + assert_eq!( + resolve_file_download_target("/tmp/out.txt", "hello.txt", false), + PathBuf::from("/tmp/out.txt") + ); + } + + #[test] + fn resolve_file_download_target_places_inside_existing_directory() { + assert_eq!( + resolve_file_download_target("/tmp", "hello.txt", true), + PathBuf::from("/tmp/hello.txt") + ); + } + + #[test] + fn resolve_file_download_target_honors_trailing_slash() { + assert_eq!( + resolve_file_download_target("/tmp/newdir/", "hello.txt", false), + PathBuf::from("/tmp/newdir/hello.txt") + ); + } + + fn build_single_file_archive(entry_path: &str, bytes: &[u8]) -> Vec { + let mut buf = Vec::new(); + { + let mut builder = tar::Builder::new(&mut buf); + let mut header = tar::Header::new_gnu(); + header.set_path(entry_path).expect("set tar entry path"); + header.set_size(bytes.len() as u64); + header.set_mode(0o644); + header.set_entry_type(tar::EntryType::Regular); + header.set_cksum(); + builder.append(&header, bytes).expect("append tar entry"); + builder.finish().expect("finish tar archive"); + } + buf + } + + fn unpack_into(archive_bytes: &[u8], staging: &Path) { + let mut archive = 
tar::Archive::new(std::io::Cursor::new(archive_bytes)); + archive.unpack(staging).expect("unpack archive"); + } + + #[test] + fn place_downloaded_file_writes_regular_file_at_dest() { + let workdir = tempfile::tempdir().expect("create workdir"); + let staging = workdir.path().join("staging"); + fs::create_dir_all(&staging).expect("create staging"); + let archive = build_single_file_archive("hello.txt", b"trust me"); + unpack_into(&archive, &staging); + + let dest = workdir.path().join("out.txt"); + place_downloaded_file(&staging, "hello.txt", &dest).expect("place file"); + + let meta = fs::symlink_metadata(&dest).expect("stat dest"); + assert!(meta.is_file(), "dest must be a regular file, got {meta:?}"); + assert_eq!(fs::read(&dest).expect("read dest"), b"trust me"); + } + + #[test] + fn place_downloaded_file_refuses_to_clobber_existing_directory() { + let workdir = tempfile::tempdir().expect("create workdir"); + let staging = workdir.path().join("staging"); + fs::create_dir_all(&staging).expect("create staging"); + let archive = build_single_file_archive("hello.txt", b"trust me"); + unpack_into(&archive, &staging); + + let dest = workdir.path().join("conflict-dir"); + fs::create_dir(&dest).expect("create conflict dir"); + + let err = place_downloaded_file(&staging, "hello.txt", &dest) + .expect_err("expected directory-clobber refusal"); + assert!( + format!("{err}").contains("cannot overwrite directory"), + "unexpected error: {err}" + ); + assert!( + fs::symlink_metadata(&dest).expect("stat dest").is_dir(), + "dest should remain a directory after refusal" + ); + } + + #[test] + fn download_full_pipeline_lands_file_at_exact_dest_path() { + let workdir = tempfile::tempdir().expect("create workdir"); + let staging_parent = workdir.path(); + let archive = build_single_file_archive("hello.txt", b"trust me"); + + let dest_str = staging_parent.join("out.txt"); + let dest_str = dest_str.to_str().unwrap(); + let final_path = resolve_file_download_target(dest_str, 
"hello.txt", false); + assert_eq!(final_path, Path::new(dest_str)); + + let staging = tempfile::TempDir::new_in(staging_parent).expect("staging dir"); + unpack_into(&archive, staging.path()); + place_downloaded_file(staging.path(), "hello.txt", &final_path).expect("place"); + + let meta = fs::symlink_metadata(&final_path).expect("stat final"); + assert!(meta.is_file(), "expected regular file, got {meta:?}"); + assert_eq!(fs::read(&final_path).expect("read final"), b"trust me"); + } + + #[test] + fn download_full_pipeline_places_inside_existing_directory_destination() { + let workdir = tempfile::tempdir().expect("create workdir"); + let archive = build_single_file_archive("hello.txt", b"trust me"); + + let dest_dir = workdir.path().join("out-dir"); + fs::create_dir(&dest_dir).expect("create dest dir"); + let dest_str = dest_dir.to_str().unwrap(); + let final_path = resolve_file_download_target(dest_str, "hello.txt", true); + assert_eq!(final_path, dest_dir.join("hello.txt")); + + let staging = tempfile::TempDir::new_in(workdir.path()).expect("staging dir"); + unpack_into(&archive, staging.path()); + place_downloaded_file(staging.path(), "hello.txt", &final_path).expect("place"); + + let meta = fs::symlink_metadata(&final_path).expect("stat final"); + assert!(meta.is_file()); + assert_eq!(fs::read(&final_path).expect("read final"), b"trust me"); + } + #[test] fn directory_upload_prefix_uses_basename_for_named_directories() { assert_eq!( diff --git a/crates/openshell-cli/src/tls.rs b/crates/openshell-cli/src/tls.rs index c733d3db3..10df401a5 100644 --- a/crates/openshell-cli/src/tls.rs +++ b/crates/openshell-cli/src/tls.rs @@ -2,6 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 use miette::{IntoDiagnostic, Result, WrapErr}; +use openshell_core::auth::EdgeAuthInterceptor; use openshell_core::proto::inference_client::InferenceClient; use openshell_core::proto::open_shell_client::OpenShellClient; use rustls::{ @@ -99,7 +100,7 @@ impl TlsOptions { } } - /// Returns `true` 
when using bearer token auth (edge or OIDC). + /// Returns `true` when using bearer token auth. pub fn is_bearer_auth(&self) -> bool { self.edge_token.is_some() || self.oidc_token.is_some() } @@ -342,6 +343,7 @@ pub async fn build_channel(server: &str, tls: &TlsOptions) -> Result { let endpoint = Endpoint::from_shared(server.to_string()) .into_diagnostic()? .connect_timeout(Duration::from_secs(10)) + .http2_adaptive_window(true) .http2_keep_alive_interval(Duration::from_secs(10)) .keep_alive_while_idle(true); return endpoint.connect().await.into_diagnostic(); @@ -362,6 +364,7 @@ pub async fn build_channel(server: &str, tls: &TlsOptions) -> Result { let endpoint = Endpoint::from_shared(local_url) .into_diagnostic()? .connect_timeout(Duration::from_secs(10)) + .http2_adaptive_window(true) .http2_keep_alive_interval(Duration::from_secs(10)) .keep_alive_while_idle(true); return endpoint.connect().await.into_diagnostic(); @@ -389,13 +392,14 @@ pub async fn build_channel(server: &str, tls: &TlsOptions) -> Result { let mut endpoint = Endpoint::from_shared(server.to_string()) .into_diagnostic()? .connect_timeout(Duration::from_secs(10)) + .http2_adaptive_window(true) .http2_keep_alive_interval(Duration::from_secs(10)) .keep_alive_while_idle(true); let tls_config = if tls.oidc_token.is_some() { - // OIDC bearer auth over HTTPS: use mTLS certs for the transport layer - // when available (server may still require client certs), and layer - // the Bearer token on top via the interceptor. + // Bearer auth over HTTPS: use mTLS certs for the transport layer when + // available (server may still require client certs), and layer the + // Bearer token on top via the interceptor. 
require_tls_materials(server, tls).map_or_else( |_| { let resolved = tls.with_default_paths(server); @@ -403,9 +407,12 @@ pub async fn build_channel(server: &str, tls: &TlsOptions) -> Result { .ca .as_ref() .and_then(|ca_path| std::fs::read(ca_path).ok()) - .map_or_else(ClientTlsConfig::new, |ca_pem| { - ClientTlsConfig::new().ca_certificate(Certificate::from_pem(ca_pem)) - }) + .map_or_else( + || ClientTlsConfig::new().with_enabled_roots(), + |ca_pem| { + ClientTlsConfig::new().ca_certificate(Certificate::from_pem(ca_pem)) + }, + ) }, |materials| build_tonic_tls_config(&materials), ) @@ -429,85 +436,16 @@ pub async fn build_channel(server: &str, tls: &TlsOptions) -> Result { /// Otherwise, standard mTLS is used (interceptor is a no-op). pub async fn grpc_client(server: &str, tls: &TlsOptions) -> Result { let channel = build_channel(server, tls).await?; - let interceptor = EdgeAuthInterceptor::maybe_from(tls)?; + let interceptor = interceptor_from_tls(tls)?; Ok(OpenShellClient::with_interceptor(channel, interceptor)) } -/// Interceptor that injects authentication headers into every outgoing gRPC request. -/// -/// Supports OIDC Bearer tokens (standard `authorization` header) and -/// Cloudflare Access tokens (custom headers). When no token is set, acts -/// as a no-op. OIDC takes precedence over edge tokens. -#[derive(Clone)] -#[allow(clippy::struct_field_names)] -pub struct EdgeAuthInterceptor { - /// Standard `authorization: Bearer ` for OIDC. - bearer_value: Option>, - /// CF-specific `Cf-Access-Jwt-Assertion` header. - header_value: Option>, - /// CF-specific `Cookie: CF_Authorization=` header. - cookie_value: Option>, -} - -impl EdgeAuthInterceptor { - /// Create an interceptor from [`TlsOptions`]. Returns a no-op interceptor - /// when no auth token is configured. - pub fn maybe_from(tls: &TlsOptions) -> Result { - // OIDC bearer token takes precedence. 
- if let Some(ref token) = tls.oidc_token { - let bearer: tonic::metadata::MetadataValue = - format!("Bearer {token}") - .parse() - .map_err(|_| miette::miette!("invalid OIDC token value"))?; - return Ok(Self { - bearer_value: Some(bearer), - header_value: None, - cookie_value: None, - }); - } - - let (header_value, cookie_value) = match tls.edge_token.as_deref() { - Some(t) => { - let hv: tonic::metadata::MetadataValue = t - .parse() - .map_err(|_| miette::miette!("invalid edge token value"))?; - let cv: tonic::metadata::MetadataValue = - format!("CF_Authorization={t}") - .parse() - .map_err(|_| miette::miette!("invalid edge token value for cookie"))?; - (Some(hv), Some(cv)) - } - None => (None, None), - }; - Ok(Self { - bearer_value: None, - header_value, - cookie_value, - }) - } -} - -impl tonic::service::Interceptor for EdgeAuthInterceptor { - fn call( - &mut self, - mut req: tonic::Request<()>, - ) -> std::result::Result, tonic::Status> { - if let Some(ref val) = self.bearer_value { - req.metadata_mut().insert("authorization", val.clone()); - } - if let Some(ref val) = self.header_value { - req.metadata_mut() - .insert("cf-access-jwt-assertion", val.clone()); - } - if let Some(ref val) = self.cookie_value { - req.metadata_mut().insert("cookie", val.clone()); - } - Ok(req) - } +fn interceptor_from_tls(tls: &TlsOptions) -> Result { + EdgeAuthInterceptor::new(tls.oidc_token.as_deref(), tls.edge_token.as_deref()) } pub async fn grpc_inference_client(server: &str, tls: &TlsOptions) -> Result { let channel = build_channel(server, tls).await?; - let interceptor = EdgeAuthInterceptor::maybe_from(tls)?; + let interceptor = interceptor_from_tls(tls)?; Ok(InferenceClient::with_interceptor(channel, interceptor)) } diff --git a/crates/openshell-cli/tests/ensure_providers_integration.rs b/crates/openshell-cli/tests/ensure_providers_integration.rs index 15f620e8e..ea2d5a465 100644 --- a/crates/openshell-cli/tests/ensure_providers_integration.rs +++ 
b/crates/openshell-cli/tests/ensure_providers_integration.rs @@ -5,25 +5,30 @@ //! `--provider` names are auto-created when they match a known provider type, //! pass through when they already exist, and error for unrecognised names. +mod helpers; + +use helpers::{ + EnvVarGuard, build_ca, build_client_cert, build_server_cert, install_rustls_provider, +}; use openshell_cli::run; use openshell_cli::tls::TlsOptions; use openshell_core::proto::open_shell_server::{OpenShell, OpenShellServer}; use openshell_core::proto::{ - CreateProviderRequest, CreateSandboxRequest, CreateSshSessionRequest, CreateSshSessionResponse, - DeleteProviderRequest, DeleteProviderResponse, DeleteSandboxRequest, DeleteSandboxResponse, - ExecSandboxEvent, ExecSandboxRequest, GatewayMessage, GetGatewayConfigRequest, + AttachSandboxProviderRequest, AttachSandboxProviderResponse, CreateProviderRequest, + CreateSandboxRequest, CreateSshSessionRequest, CreateSshSessionResponse, DeleteProviderRequest, + DeleteProviderResponse, DeleteSandboxRequest, DeleteSandboxResponse, + DetachSandboxProviderRequest, DetachSandboxProviderResponse, ExecSandboxEvent, + ExecSandboxInput, ExecSandboxRequest, GatewayMessage, GetGatewayConfigRequest, GetGatewayConfigResponse, GetProviderRequest, GetSandboxConfigRequest, GetSandboxConfigResponse, GetSandboxProviderEnvironmentRequest, GetSandboxProviderEnvironmentResponse, GetSandboxRequest, HealthRequest, HealthResponse, - ListProvidersRequest, ListProvidersResponse, ListSandboxesRequest, ListSandboxesResponse, - Provider, ProviderResponse, RevokeSshSessionRequest, RevokeSshSessionResponse, SandboxResponse, + ListProvidersRequest, ListProvidersResponse, ListSandboxProvidersRequest, + ListSandboxProvidersResponse, ListSandboxesRequest, ListSandboxesResponse, Provider, + ProviderResponse, RevokeSshSessionRequest, RevokeSshSessionResponse, SandboxResponse, SandboxStreamEvent, ServiceStatus, SupervisorMessage, UpdateProviderRequest, WatchSandboxRequest, }; use 
openshell_core::{ObjectId, ObjectName}; -use rcgen::{ - BasicConstraints, Certificate, CertificateParams, ExtendedKeyUsagePurpose, IsCa, KeyPair, -}; use std::collections::HashMap; use std::sync::Arc; use tempfile::TempDir; @@ -33,60 +38,6 @@ use tokio_stream::wrappers::TcpListenerStream; use tonic::transport::{Certificate as TlsCertificate, Identity, Server, ServerTlsConfig}; use tonic::{Response, Status}; -// ── EnvVarGuard ────────────────────────────────────────────────────── - -// Serialise tests that mutate environment variables so concurrent -// threads don't clobber each other. -static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(()); - -struct SavedVar { - key: &'static str, - original: Option, -} - -/// Holds the global env lock and restores all modified variables on drop. -struct EnvVarGuard { - vars: Vec, - _lock: std::sync::MutexGuard<'static, ()>, -} - -#[allow(unsafe_code)] -impl EnvVarGuard { - /// Acquire the lock and set one or more environment variables. - fn set(pairs: &[(&'static str, &str)]) -> Self { - let lock = ENV_LOCK - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner); - let mut vars = Vec::with_capacity(pairs.len()); - for &(key, value) in pairs { - let original = std::env::var(key).ok(); - unsafe { - std::env::set_var(key, value); - } - vars.push(SavedVar { key, original }); - } - Self { vars, _lock: lock } - } -} - -#[allow(unsafe_code)] -impl Drop for EnvVarGuard { - fn drop(&mut self) { - for var in &self.vars { - if let Some(value) = &var.original { - unsafe { - std::env::set_var(var.key, value); - } - } else { - unsafe { - std::env::remove_var(var.key); - } - } - } - // _lock drops here, releasing the mutex - } -} - // ── mock OpenShell server ───────────────────────────────────────────── #[derive(Clone, Default)] @@ -111,10 +62,12 @@ impl TestOpenShell { name: name.to_string(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), r#type: provider_type.to_string(), credentials: HashMap::new(), 
config: HashMap::new(), + credential_expires_at_ms: HashMap::new(), }, ); } @@ -153,6 +106,27 @@ impl OpenShell for TestOpenShell { Ok(Response::new(ListSandboxesResponse::default())) } + async fn list_sandbox_providers( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(ListSandboxProvidersResponse::default())) + } + + async fn attach_sandbox_provider( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(AttachSandboxProviderResponse::default())) + } + + async fn detach_sandbox_provider( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(DetachSandboxProviderResponse::default())) + } + async fn delete_sandbox( &self, _request: tonic::Request, @@ -190,6 +164,36 @@ impl OpenShell for TestOpenShell { Ok(Response::new(CreateSshSessionResponse::default())) } + async fn expose_service( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new( + openshell_core::proto::ServiceEndpointResponse::default(), + )) + } + + async fn get_service( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn list_services( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn delete_service( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + async fn revoke_ssh_session( &self, _request: tonic::Request, @@ -316,6 +320,19 @@ impl OpenShell for TestOpenShell { } base }; + let merge_expiry = |mut base: HashMap, incoming: HashMap| { + if incoming.is_empty() { + return base; + } + for (k, v) in incoming { + if v <= 0 { + base.remove(&k); + } else { + base.insert(k, v); + } + } + base + }; let existing_metadata = existing.metadata.clone().unwrap_or_default(); let provider_metadata = provider.metadata.clone().unwrap_or_default(); let updated = Provider { @@ -324,10 +341,15 @@ impl OpenShell for TestOpenShell { name: 
provider_metadata.name, created_at_ms: existing_metadata.created_at_ms, labels: existing_metadata.labels, + resource_version: 0, }), r#type: existing.r#type, credentials: merge(existing.credentials, provider.credentials), config: merge(existing.config, provider.config), + credential_expires_at_ms: merge_expiry( + existing.credential_expires_at_ms, + provider.credential_expires_at_ms, + ), }; let updated_name = updated.object_name().to_string(); providers.insert(updated_name, updated.clone()); @@ -335,6 +357,33 @@ impl OpenShell for TestOpenShell { provider: Some(updated), })) } + async fn get_provider_refresh_status( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn configure_provider_refresh( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn rotate_provider_credential( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn delete_provider_refresh( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } async fn delete_provider( &self, @@ -372,6 +421,15 @@ impl OpenShell for TestOpenShell { ))) } + type ExecSandboxInteractiveStream = + tokio_stream::wrappers::ReceiverStream>; + async fn exec_sandbox_interactive( + &self, + _request: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + async fn update_config( &self, _request: tonic::Request, @@ -477,6 +535,20 @@ impl OpenShell for TestOpenShell { Err(Status::unimplemented("not implemented in test")) } + async fn issue_sandbox_token( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + + async fn refresh_sandbox_token( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + async fn connect_supervisor( &self, 
_request: tonic::Request>, @@ -493,36 +565,17 @@ impl OpenShell for TestOpenShell { ) -> Result, Status> { Err(Status::unimplemented("not implemented in test")) } -} - -// ── TLS helpers ────────────────────────────────────────────────────── -fn install_rustls_provider() { - let _ = rustls::crypto::ring::default_provider().install_default(); -} - -fn build_ca() -> (Certificate, KeyPair) { - let key_pair = KeyPair::generate().unwrap(); - let mut params = CertificateParams::new(Vec::::new()).unwrap(); - params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained); - let cert = params.self_signed(&key_pair).unwrap(); - (cert, key_pair) -} + type ForwardTcpStream = tokio_stream::wrappers::ReceiverStream< + Result, + >; -fn build_server_cert(ca: &Certificate, ca_key: &KeyPair) -> (String, String) { - let key_pair = KeyPair::generate().unwrap(); - let mut params = CertificateParams::new(vec!["localhost".to_string()]).unwrap(); - params.extended_key_usages = vec![ExtendedKeyUsagePurpose::ServerAuth]; - let cert = params.signed_by(&key_pair, ca, ca_key).unwrap(); - (cert.pem(), key_pair.serialize_pem()) -} - -fn build_client_cert(ca: &Certificate, ca_key: &KeyPair) -> (String, String) { - let key_pair = KeyPair::generate().unwrap(); - let mut params = CertificateParams::new(Vec::::new()).unwrap(); - params.extended_key_usages = vec![ExtendedKeyUsagePurpose::ClientAuth]; - let cert = params.signed_by(&key_pair, ca, ca_key).unwrap(); - (cert.pem(), key_pair.serialize_pem()) + async fn forward_tcp( + &self, + _request: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } } // ── test server fixture ────────────────────────────────────────────── @@ -693,19 +746,19 @@ async fn inferred_type_auto_creates_provider() { let result = run::ensure_required_providers( &mut client, &[], - &["claude".to_string()], + &["claude-code".to_string()], Some(true), // --auto-providers ) .await .expect("should auto-create the inferred provider"); - 
assert_eq!(result, vec!["claude".to_string()]); + assert_eq!(result, vec!["claude-code".to_string()]); let providers = ts.openshell.state.providers.lock().await; let provider = providers - .get("claude") - .expect("claude provider should exist"); - assert_eq!(provider.r#type, "claude"); + .get("claude-code") + .expect("claude-code provider should exist"); + assert_eq!(provider.r#type, "claude-code"); } /// When `--no-auto-providers` is set, missing explicit providers that would @@ -757,7 +810,7 @@ async fn explicit_and_inferred_providers_combined() { let result = run::ensure_required_providers( &mut client, &["nvidia".to_string()], - &["claude".to_string()], + &["claude-code".to_string()], Some(true), ) .await @@ -765,12 +818,12 @@ async fn explicit_and_inferred_providers_combined() { assert_eq!(result.len(), 2); assert!(result.contains(&"nvidia".to_string())); - assert!(result.contains(&"claude".to_string())); + assert!(result.contains(&"claude-code".to_string())); let providers = ts.openshell.state.providers.lock().await; assert_eq!(providers.len(), 2); assert!(providers.contains_key("nvidia")); - assert!(providers.contains_key("claude")); + assert!(providers.contains_key("claude-code")); } /// When an explicit provider name matches an inferred type, the provider diff --git a/crates/openshell-cli/tests/helpers/mod.rs b/crates/openshell-cli/tests/helpers/mod.rs new file mode 100644 index 000000000..a58e750b9 --- /dev/null +++ b/crates/openshell-cli/tests/helpers/mod.rs @@ -0,0 +1,114 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Shared helpers for CLI integration tests. +//! +//! Include this module from a test file with: +//! ```ignore +//! mod helpers; +//! 
``` + +use rcgen::{ + BasicConstraints, Certificate, CertificateParams, ExtendedKeyUsagePurpose, IsCa, KeyPair, +}; + +// ── EnvVarGuard ────────────────────────────────────────────────────────────── + +/// Global mutex that serialises tests which mutate environment variables so +/// concurrent threads don't clobber each other's state. +static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(()); + +struct SavedVar { + key: &'static str, + original: Option, +} + +/// RAII guard that acquires `ENV_LOCK` and restores all modified environment +/// variables on drop. +pub struct EnvVarGuard { + vars: Vec, + _lock: std::sync::MutexGuard<'static, ()>, +} + +#[allow(dead_code, unsafe_code)] +impl EnvVarGuard { + /// Acquire the global env-var lock and atomically set one or more + /// environment variables. All variables are restored to their prior + /// state (or removed) when the guard is dropped. + pub fn set(pairs: &[(&'static str, &str)]) -> Self { + let lock = ENV_LOCK + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let mut vars = Vec::with_capacity(pairs.len()); + for &(key, value) in pairs { + let original = std::env::var(key).ok(); + unsafe { + std::env::set_var(key, value); + } + vars.push(SavedVar { key, original }); + } + Self { vars, _lock: lock } + } +} + +#[allow(unsafe_code)] +impl Drop for EnvVarGuard { + fn drop(&mut self) { + for var in &self.vars { + if let Some(value) = &var.original { + unsafe { + std::env::set_var(var.key, value); + } + } else { + unsafe { + std::env::remove_var(var.key); + } + } + } + // _lock drops here, releasing the mutex + } +} + +// ── TLS helpers ────────────────────────────────────────────────────────────── + +/// Install the `rustls` ring crypto provider as the process default. +/// +/// Safe to call multiple times — subsequent calls are no-ops. 
+#[allow(dead_code)] +pub fn install_rustls_provider() { + let _ = rustls::crypto::ring::default_provider().install_default(); +} + +/// Generate a self-signed CA certificate and its key pair. +#[allow(dead_code)] +pub fn build_ca() -> (Certificate, KeyPair) { + let key_pair = KeyPair::generate().unwrap(); + let mut params = CertificateParams::new(Vec::::new()).unwrap(); + params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained); + let cert = params.self_signed(&key_pair).unwrap(); + (cert, key_pair) +} + +/// Generate a server certificate signed by `ca`, valid for `localhost`. +/// +/// Returns `(cert_pem, key_pem)`. +#[allow(dead_code)] +pub fn build_server_cert(ca: &Certificate, ca_key: &KeyPair) -> (String, String) { + let key_pair = KeyPair::generate().unwrap(); + let mut params = CertificateParams::new(vec!["localhost".to_string()]).unwrap(); + params.extended_key_usages = vec![ExtendedKeyUsagePurpose::ServerAuth]; + let cert = params.signed_by(&key_pair, ca, ca_key).unwrap(); + (cert.pem(), key_pair.serialize_pem()) +} + +/// Generate a client authentication certificate signed by `ca`. +/// +/// Returns `(cert_pem, key_pem)`. +#[allow(dead_code)] +pub fn build_client_cert(ca: &Certificate, ca_key: &KeyPair) -> (String, String) { + let key_pair = KeyPair::generate().unwrap(); + let mut params = CertificateParams::new(Vec::::new()).unwrap(); + params.extended_key_usages = vec![ExtendedKeyUsagePurpose::ClientAuth]; + let cert = params.signed_by(&key_pair, ca, ca_key).unwrap(); + (cert.pem(), key_pair.serialize_pem()) +} diff --git a/crates/openshell-cli/tests/mtls_integration.rs b/crates/openshell-cli/tests/mtls_integration.rs index 866048a81..8f83599b1 100644 --- a/crates/openshell-cli/tests/mtls_integration.rs +++ b/crates/openshell-cli/tests/mtls_integration.rs @@ -1,18 +1,20 @@ // SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
// SPDX-License-Identifier: Apache-2.0 +mod helpers; + +use helpers::{ + EnvVarGuard, build_ca, build_client_cert, build_server_cert, install_rustls_provider, +}; use openshell_cli::tls::{TlsOptions, grpc_client}; use openshell_core::proto::{ CreateProviderRequest, CreateSshSessionRequest, CreateSshSessionResponse, - DeleteProviderRequest, DeleteProviderResponse, ExecSandboxEvent, ExecSandboxRequest, - GetProviderRequest, HealthRequest, HealthResponse, ListProvidersRequest, ListProvidersResponse, - ProviderResponse, RevokeSshSessionRequest, RevokeSshSessionResponse, ServiceStatus, - UpdateProviderRequest, + DeleteProviderRequest, DeleteProviderResponse, ExecSandboxEvent, ExecSandboxInput, + ExecSandboxRequest, GetProviderRequest, HealthRequest, HealthResponse, ListProvidersRequest, + ListProvidersResponse, ProviderResponse, RevokeSshSessionRequest, RevokeSshSessionResponse, + ServiceStatus, UpdateProviderRequest, open_shell_server::{OpenShell, OpenShellServer}, }; -use rcgen::{ - BasicConstraints, Certificate, CertificateParams, ExtendedKeyUsagePurpose, IsCa, KeyPair, -}; use tempfile::tempdir; use tokio::net::TcpListener; use tokio::sync::mpsc; @@ -22,41 +24,6 @@ use tonic::{ transport::{Certificate as TlsCertificate, Identity, Server, ServerTlsConfig}, }; -struct EnvVarGuard { - key: &'static str, - original: Option, -} - -#[allow(unsafe_code)] -impl EnvVarGuard { - fn set(key: &'static str, value: &str) -> Self { - let original = std::env::var(key).ok(); - unsafe { - std::env::set_var(key, value); - } - Self { key, original } - } -} - -#[allow(unsafe_code)] -impl Drop for EnvVarGuard { - fn drop(&mut self) { - if let Some(value) = &self.original { - unsafe { - std::env::set_var(self.key, value); - } - } else { - unsafe { - std::env::remove_var(self.key); - } - } - } -} - -fn install_rustls_provider() { - let _ = rustls::crypto::ring::default_provider().install_default(); -} - #[derive(Clone, Default)] struct TestOpenShell; @@ -99,6 +66,33 @@ impl OpenShell for 
TestOpenShell { )) } + async fn list_sandbox_providers( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new( + openshell_core::proto::ListSandboxProvidersResponse::default(), + )) + } + + async fn attach_sandbox_provider( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new( + openshell_core::proto::AttachSandboxProviderResponse::default(), + )) + } + + async fn detach_sandbox_provider( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new( + openshell_core::proto::DetachSandboxProviderResponse::default(), + )) + } + async fn delete_sandbox( &self, _request: tonic::Request, @@ -143,6 +137,36 @@ impl OpenShell for TestOpenShell { Ok(Response::new(CreateSshSessionResponse::default())) } + async fn expose_service( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new( + openshell_core::proto::ServiceEndpointResponse::default(), + )) + } + + async fn get_service( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn list_services( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn delete_service( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + async fn revoke_ssh_session( &self, _request: tonic::Request, @@ -220,6 +244,33 @@ impl OpenShell for TestOpenShell { "update_provider not implemented in test", )) } + async fn get_provider_refresh_status( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn configure_provider_refresh( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn rotate_provider_credential( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn delete_provider_refresh( + &self, + _: tonic::Request, + ) -> 
Result, Status> { + Err(Status::unimplemented("unused")) + } async fn delete_provider( &self, @@ -259,6 +310,15 @@ impl OpenShell for TestOpenShell { ))) } + type ExecSandboxInteractiveStream = + tokio_stream::wrappers::ReceiverStream>; + async fn exec_sandbox_interactive( + &self, + _request: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + async fn update_config( &self, _request: tonic::Request, @@ -364,6 +424,20 @@ impl OpenShell for TestOpenShell { Err(Status::unimplemented("not implemented in test")) } + async fn issue_sandbox_token( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + + async fn refresh_sandbox_token( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + async fn connect_supervisor( &self, _request: tonic::Request>, @@ -380,34 +454,17 @@ impl OpenShell for TestOpenShell { ) -> Result, Status> { Err(Status::unimplemented("not implemented in test")) } -} -fn build_ca() -> (Certificate, KeyPair) { - let key_pair = KeyPair::generate().unwrap(); - let mut params = CertificateParams::new(Vec::::new()).unwrap(); - params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained); - let cert = params.self_signed(&key_pair).unwrap(); - (cert, key_pair) -} - -fn build_server_cert(ca: &Certificate, ca_key: &KeyPair) -> (String, String) { - let key_pair = KeyPair::generate().unwrap(); - let mut params = CertificateParams::new(vec!["localhost".to_string()]).unwrap(); - params.extended_key_usages = vec![ExtendedKeyUsagePurpose::ServerAuth]; - let cert = params.signed_by(&key_pair, ca, ca_key).unwrap(); - let cert_pem = cert.pem(); - let key_pem = key_pair.serialize_pem(); - (cert_pem, key_pem) -} + type ForwardTcpStream = tokio_stream::wrappers::ReceiverStream< + Result, + >; -fn build_client_cert(ca: &Certificate, ca_key: &KeyPair) -> (String, String) { - let key_pair = 
KeyPair::generate().unwrap(); - let mut params = CertificateParams::new(Vec::::new()).unwrap(); - params.extended_key_usages = vec![ExtendedKeyUsagePurpose::ClientAuth]; - let cert = params.signed_by(&key_pair, ca, ca_key).unwrap(); - let cert_pem = cert.pem(); - let key_pem = key_pair.serialize_pem(); - (cert_pem, key_pem) + async fn forward_tcp( + &self, + _request: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } } async fn run_server( @@ -476,7 +533,8 @@ async fn cli_requires_client_cert_for_https() { let dir = tempdir().unwrap(); // Point XDG_CONFIG_HOME at the isolated temp dir so that default_tls_dir // cannot discover real client certs from the developer's machine. - let _xdg_env = EnvVarGuard::set("XDG_CONFIG_HOME", &dir.path().to_string_lossy()); + let xdg_path = dir.path().to_string_lossy(); + let _xdg_env = EnvVarGuard::set(&[("XDG_CONFIG_HOME", &xdg_path)]); let ca_path = dir.path().join("ca.crt"); std::fs::write(&ca_path, ca_cert).unwrap(); diff --git a/crates/openshell-cli/tests/provider_commands_integration.rs b/crates/openshell-cli/tests/provider_commands_integration.rs index 55ed69500..090097a20 100644 --- a/crates/openshell-cli/tests/provider_commands_integration.rs +++ b/crates/openshell-cli/tests/provider_commands_integration.rs @@ -1,25 +1,34 @@ // SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
// SPDX-License-Identifier: Apache-2.0 +mod helpers; + +use helpers::{ + EnvVarGuard, build_ca, build_client_cert, build_server_cert, install_rustls_provider, +}; use openshell_cli::run; use openshell_cli::tls::TlsOptions; use openshell_core::proto::open_shell_server::{OpenShell, OpenShellServer}; use openshell_core::proto::{ - CreateProviderRequest, CreateSandboxRequest, CreateSshSessionRequest, CreateSshSessionResponse, - DeleteProviderRequest, DeleteProviderResponse, DeleteSandboxRequest, DeleteSandboxResponse, - ExecSandboxEvent, ExecSandboxRequest, GatewayMessage, GetGatewayConfigRequest, - GetGatewayConfigResponse, GetProviderRequest, GetSandboxConfigRequest, - GetSandboxConfigResponse, GetSandboxProviderEnvironmentRequest, - GetSandboxProviderEnvironmentResponse, GetSandboxRequest, HealthRequest, HealthResponse, - ListProvidersRequest, ListProvidersResponse, ListSandboxesRequest, ListSandboxesResponse, - Provider, ProviderProfile, ProviderResponse, RevokeSshSessionRequest, RevokeSshSessionResponse, - SandboxResponse, SandboxStreamEvent, ServiceStatus, SupervisorMessage, UpdateProviderRequest, - WatchSandboxRequest, + AttachSandboxProviderRequest, AttachSandboxProviderResponse, CreateProviderRequest, + CreateSandboxRequest, CreateSshSessionRequest, CreateSshSessionResponse, + DeleteProviderRefreshRequest, DeleteProviderRefreshResponse, DeleteProviderRequest, + DeleteProviderResponse, DeleteSandboxRequest, DeleteSandboxResponse, + DetachSandboxProviderRequest, DetachSandboxProviderResponse, ExecSandboxEvent, + ExecSandboxInput, ExecSandboxRequest, GatewayMessage, GetGatewayConfigRequest, + GetGatewayConfigResponse, GetProviderRefreshStatusRequest, GetProviderRefreshStatusResponse, + GetProviderRequest, GetSandboxConfigRequest, GetSandboxConfigResponse, + GetSandboxProviderEnvironmentRequest, GetSandboxProviderEnvironmentResponse, GetSandboxRequest, + HealthRequest, HealthResponse, ListProvidersRequest, ListProvidersResponse, + ListSandboxProvidersRequest, 
ListSandboxProvidersResponse, ListSandboxesRequest, + ListSandboxesResponse, Provider, ProviderCredentialRefresh, ProviderCredentialRefreshStatus, + ProviderCredentialRefreshStrategy, ProviderProfile, ProviderProfileCredential, + ProviderProfileDiscovery, ProviderResponse, RevokeSshSessionRequest, RevokeSshSessionResponse, + RotateProviderCredentialRequest, RotateProviderCredentialResponse, Sandbox, SandboxResponse, + SandboxStreamEvent, ServiceStatus, SettingValue, SupervisorMessage, UpdateProviderRequest, + WatchSandboxRequest, setting_value, }; use openshell_core::{ObjectId, ObjectName}; -use rcgen::{ - BasicConstraints, Certificate, CertificateParams, ExtendedKeyUsagePurpose, IsCa, KeyPair, -}; use std::collections::HashMap; use std::sync::Arc; use tempfile::TempDir; @@ -29,41 +38,51 @@ use tokio_stream::wrappers::TcpListenerStream; use tonic::transport::{Certificate as TlsCertificate, Identity, Server, ServerTlsConfig}; use tonic::{Response, Status}; -struct EnvVarGuard { - key: &'static str, - original: Option, -} - -#[allow(unsafe_code)] -impl EnvVarGuard { - fn set(key: &'static str, value: &str) -> Self { - let original = std::env::var(key).ok(); - unsafe { - std::env::set_var(key, value); - } - Self { key, original } - } -} - -#[allow(unsafe_code)] -impl Drop for EnvVarGuard { - fn drop(&mut self) { - if let Some(value) = &self.original { - unsafe { - std::env::set_var(self.key, value); - } - } else { - unsafe { - std::env::remove_var(self.key); - } - } - } -} - #[derive(Clone, Default)] struct ProviderState { providers: Arc>>, profiles: Arc>>, + refresh_statuses: Arc>>, + refresh_requests: Arc>>, + sandbox_providers: Arc>>>, + sandbox_provider_requests: Arc>>, + global_settings: Arc>>, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +enum ProviderRefreshRequestLog { + Status { + provider_name: String, + credential_key: String, + }, + Configure { + provider_name: String, + credential_key: String, + expires_at_ms: Option, + }, + Rotate { + provider_name: 
String, + credential_key: String, + }, + Delete { + provider_name: String, + credential_key: String, + }, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +enum SandboxProviderRequestLog { + List { + sandbox_name: String, + }, + Attach { + sandbox_name: String, + provider_name: String, + }, + Detach { + sandbox_name: String, + provider_name: String, + }, } #[derive(Clone, Default)] @@ -92,9 +111,25 @@ impl OpenShell for TestOpenShell { async fn get_sandbox( &self, - _request: tonic::Request, + request: tonic::Request, ) -> Result, Status> { - Ok(Response::new(SandboxResponse::default())) + let name = request.into_inner().name; + // Return a minimal sandbox with metadata for CAS operations + Ok(Response::new(SandboxResponse { + sandbox: Some(Sandbox { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: format!("sb-{name}"), + name, + created_at_ms: 0, + labels: HashMap::new(), + resource_version: 1, + }), + spec: None, + status: None, + phase: 0, + current_policy_version: 0, + }), + })) } async fn list_sandboxes( @@ -104,6 +139,120 @@ impl OpenShell for TestOpenShell { Ok(Response::new(ListSandboxesResponse::default())) } + async fn list_sandbox_providers( + &self, + request: tonic::Request, + ) -> Result, Status> { + let sandbox_name = request.into_inner().sandbox_name; + self.state + .sandbox_provider_requests + .lock() + .await + .push(SandboxProviderRequestLog::List { + sandbox_name: sandbox_name.clone(), + }); + let provider_names = self + .state + .sandbox_providers + .lock() + .await + .get(&sandbox_name) + .cloned() + .unwrap_or_default(); + let providers_by_name = self.state.providers.lock().await; + let providers = provider_names + .iter() + .filter_map(|name| providers_by_name.get(name).cloned()) + .collect(); + Ok(Response::new(ListSandboxProvidersResponse { providers })) + } + + async fn attach_sandbox_provider( + &self, + request: tonic::Request, + ) -> Result, Status> { + let request = request.into_inner(); + self.state + 
.sandbox_provider_requests + .lock() + .await + .push(SandboxProviderRequestLog::Attach { + sandbox_name: request.sandbox_name.clone(), + provider_name: request.provider_name.clone(), + }); + if !self + .state + .providers + .lock() + .await + .contains_key(&request.provider_name) + { + return Err(Status::failed_precondition("provider not found")); + } + let mut sandbox_providers = self.state.sandbox_providers.lock().await; + let providers = sandbox_providers + .entry(request.sandbox_name.clone()) + .or_default(); + let attached = if providers.contains(&request.provider_name) { + false + } else { + providers.push(request.provider_name.clone()); + true + }; + let sandbox = Sandbox { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + name: request.sandbox_name, + ..Default::default() + }), + spec: Some(openshell_core::proto::SandboxSpec { + providers: providers.clone(), + ..Default::default() + }), + ..Default::default() + }; + Ok(Response::new(AttachSandboxProviderResponse { + sandbox: Some(sandbox), + attached, + })) + } + + async fn detach_sandbox_provider( + &self, + request: tonic::Request, + ) -> Result, Status> { + let request = request.into_inner(); + self.state + .sandbox_provider_requests + .lock() + .await + .push(SandboxProviderRequestLog::Detach { + sandbox_name: request.sandbox_name.clone(), + provider_name: request.provider_name.clone(), + }); + let mut sandbox_providers = self.state.sandbox_providers.lock().await; + let providers = sandbox_providers + .entry(request.sandbox_name.clone()) + .or_default(); + let before_len = providers.len(); + providers.retain(|name| name != &request.provider_name); + let detached = providers.len() != before_len; + let sandbox = Sandbox { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + name: request.sandbox_name, + ..Default::default() + }), + spec: Some(openshell_core::proto::SandboxSpec { + providers: providers.clone(), + ..Default::default() + }), + ..Default::default() + }; + 
Ok(Response::new(DetachSandboxProviderResponse { + sandbox: Some(sandbox), + detached, + })) + } + async fn delete_sandbox( &self, _request: tonic::Request, @@ -122,7 +271,10 @@ impl OpenShell for TestOpenShell { &self, _request: tonic::Request, ) -> Result, Status> { - Ok(Response::new(GetGatewayConfigResponse::default())) + Ok(Response::new(GetGatewayConfigResponse { + settings: self.state.global_settings.lock().await.clone(), + settings_revision: 1, + })) } async fn get_sandbox_provider_environment( @@ -141,6 +293,36 @@ impl OpenShell for TestOpenShell { Ok(Response::new(CreateSshSessionResponse::default())) } + async fn expose_service( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new( + openshell_core::proto::ServiceEndpointResponse::default(), + )) + } + + async fn get_service( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn list_services( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn delete_service( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + async fn revoke_ssh_session( &self, _request: tonic::Request, @@ -304,6 +486,19 @@ impl OpenShell for TestOpenShell { } base }; + let merge_expiry = |mut base: HashMap, incoming: HashMap| { + if incoming.is_empty() { + return base; + } + for (k, v) in incoming { + if v <= 0 { + base.remove(&k); + } else { + base.insert(k, v); + } + } + base + }; let existing_metadata = existing.metadata.clone().unwrap_or_default(); let provider_metadata = provider.metadata.clone().unwrap_or_default(); let updated = Provider { @@ -312,10 +507,15 @@ impl OpenShell for TestOpenShell { name: provider_metadata.name, created_at_ms: existing_metadata.created_at_ms, labels: existing_metadata.labels, + resource_version: 0, }), r#type: existing.r#type, credentials: merge(existing.credentials, provider.credentials), config: 
merge(existing.config, provider.config), + credential_expires_at_ms: merge_expiry( + existing.credential_expires_at_ms, + provider.credential_expires_at_ms, + ), }; let updated_name = updated.object_name().to_string(); providers.insert(updated_name, updated.clone()); @@ -323,6 +523,125 @@ impl OpenShell for TestOpenShell { provider: Some(updated), })) } + async fn get_provider_refresh_status( + &self, + request: tonic::Request, + ) -> Result, Status> { + let request = request.into_inner(); + self.state + .refresh_requests + .lock() + .await + .push(ProviderRefreshRequestLog::Status { + provider_name: request.provider.clone(), + credential_key: request.credential_key.clone(), + }); + let refresh_statuses = self.state.refresh_statuses.lock().await; + let credentials = if request.credential_key.is_empty() { + refresh_statuses + .values() + .filter(|status| status.provider_name == request.provider) + .cloned() + .collect() + } else { + refresh_statuses + .get(&(request.provider, request.credential_key)) + .cloned() + .into_iter() + .collect() + }; + Ok(Response::new(GetProviderRefreshStatusResponse { + credentials, + })) + } + + async fn configure_provider_refresh( + &self, + request: tonic::Request, + ) -> Result, Status> { + let request = request.into_inner(); + self.state + .refresh_requests + .lock() + .await + .push(ProviderRefreshRequestLog::Configure { + provider_name: request.provider.clone(), + credential_key: request.credential_key.clone(), + expires_at_ms: request.expires_at_ms, + }); + let providers = self.state.providers.lock().await; + let provider = providers + .get(&request.provider) + .ok_or_else(|| Status::not_found("provider not found"))?; + let status = ProviderCredentialRefreshStatus { + provider_name: request.provider.clone(), + provider_id: provider.object_id().to_string(), + credential_key: request.credential_key.clone(), + strategy: request.strategy, + status: "configured".to_string(), + expires_at_ms: request.expires_at_ms.unwrap_or_default(), 
+ next_refresh_at_ms: 0, + last_refresh_at_ms: 0, + last_error: String::new(), + }; + drop(providers); + self.state + .refresh_statuses + .lock() + .await + .insert((request.provider, request.credential_key), status.clone()); + Ok(Response::new( + openshell_core::proto::ConfigureProviderRefreshResponse { + status: Some(status), + }, + )) + } + + async fn rotate_provider_credential( + &self, + request: tonic::Request, + ) -> Result, Status> { + let request = request.into_inner(); + self.state + .refresh_requests + .lock() + .await + .push(ProviderRefreshRequestLog::Rotate { + provider_name: request.provider.clone(), + credential_key: request.credential_key.clone(), + }); + let mut refresh_statuses = self.state.refresh_statuses.lock().await; + let status = refresh_statuses + .get_mut(&(request.provider, request.credential_key)) + .ok_or_else(|| Status::not_found("provider refresh state not found"))?; + status.status = "rotation_requested".to_string(); + Ok(Response::new(RotateProviderCredentialResponse { + status: Some(status.clone()), + })) + } + + async fn delete_provider_refresh( + &self, + request: tonic::Request, + ) -> Result, Status> { + let request = request.into_inner(); + self.state + .refresh_requests + .lock() + .await + .push(ProviderRefreshRequestLog::Delete { + provider_name: request.provider.clone(), + credential_key: request.credential_key.clone(), + }); + let deleted = self + .state + .refresh_statuses + .lock() + .await + .remove(&(request.provider, request.credential_key)) + .is_some(); + Ok(Response::new(DeleteProviderRefreshResponse { deleted })) + } async fn delete_provider( &self, @@ -371,6 +690,15 @@ impl OpenShell for TestOpenShell { ))) } + type ExecSandboxInteractiveStream = + tokio_stream::wrappers::ReceiverStream>; + async fn exec_sandbox_interactive( + &self, + _request: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + async fn update_config( &self, _request: tonic::Request, @@ 
-476,6 +804,20 @@ impl OpenShell for TestOpenShell { Err(Status::unimplemented("not implemented in test")) } + async fn issue_sandbox_token( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + + async fn refresh_sandbox_token( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + async fn connect_supervisor( &self, _request: tonic::Request>, @@ -492,34 +834,17 @@ impl OpenShell for TestOpenShell { ) -> Result, Status> { Err(Status::unimplemented("not implemented in test")) } -} -fn install_rustls_provider() { - let _ = rustls::crypto::ring::default_provider().install_default(); -} + type ForwardTcpStream = tokio_stream::wrappers::ReceiverStream< + Result, + >; -fn build_ca() -> (Certificate, KeyPair) { - let key_pair = KeyPair::generate().unwrap(); - let mut params = CertificateParams::new(Vec::::new()).unwrap(); - params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained); - let cert = params.self_signed(&key_pair).unwrap(); - (cert, key_pair) -} - -fn build_server_cert(ca: &Certificate, ca_key: &KeyPair) -> (String, String) { - let key_pair = KeyPair::generate().unwrap(); - let mut params = CertificateParams::new(vec!["localhost".to_string()]).unwrap(); - params.extended_key_usages = vec![ExtendedKeyUsagePurpose::ServerAuth]; - let cert = params.signed_by(&key_pair, ca, ca_key).unwrap(); - (cert.pem(), key_pair.serialize_pem()) -} - -fn build_client_cert(ca: &Certificate, ca_key: &KeyPair) -> (String, String) { - let key_pair = KeyPair::generate().unwrap(); - let mut params = CertificateParams::new(Vec::::new()).unwrap(); - params.extended_key_usages = vec![ExtendedKeyUsagePurpose::ClientAuth]; - let cert = params.signed_by(&key_pair, ca, ca_key).unwrap(); - (cert.pem(), key_pair.serialize_pem()) + async fn forward_tcp( + &self, + _request: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented 
in test")) + } } /// Test fixture: TLS-enabled server with matching client certs. @@ -580,6 +905,15 @@ async fn run_server() -> TestServer { } } +async fn enable_providers_v2(ts: &TestServer) { + ts.state.global_settings.lock().await.insert( + openshell_core::settings::PROVIDERS_V2_ENABLED_KEY.to_string(), + SettingValue { + value: Some(setting_value::Value::BoolValue(true)), + }, + ); +} + #[tokio::test] async fn provider_cli_run_functions_support_full_crud_flow() { let ts = run_server().await; @@ -609,6 +943,7 @@ async fn provider_cli_run_functions_support_full_crud_flow() { false, &["API_KEY=rotated".to_string()], &["profile=prod".to_string()], + &[], &ts.tls, ) .await @@ -628,6 +963,199 @@ async fn provider_list_profiles_cli_uses_profile_browsing_rpc() { .expect("provider list-profiles"); } +#[tokio::test] +async fn provider_refresh_cli_run_functions_wire_requests() { + let ts = run_server().await; + + run::provider_create( + &ts.endpoint, + "my-graph", + "outlook", + false, + &["MS_GRAPH_ACCESS_TOKEN=token".to_string()], + &[], + &ts.tls, + ) + .await + .expect("provider create"); + + run::provider_refresh_config( + &ts.endpoint, + run::ProviderRefreshConfigInput { + name: "my-graph", + credential_key: "MS_GRAPH_ACCESS_TOKEN", + strategy: "oauth2_client_credentials", + material: &["tenant_id=tenant".to_string()], + secret_material_keys: &["client_secret".to_string()], + credential_expires_at_ms: Some(1_767_225_600_000), + }, + &ts.tls, + ) + .await + .expect("provider refresh configure"); + run::provider_refresh_status( + &ts.endpoint, + "my-graph", + Some("MS_GRAPH_ACCESS_TOKEN"), + &ts.tls, + ) + .await + .expect("provider refresh status"); + run::provider_rotate(&ts.endpoint, "my-graph", "MS_GRAPH_ACCESS_TOKEN", &ts.tls) + .await + .expect("provider refresh rotate"); + run::provider_refresh_delete(&ts.endpoint, "my-graph", "MS_GRAPH_ACCESS_TOKEN", &ts.tls) + .await + .expect("provider refresh delete"); + + let requests = 
ts.state.refresh_requests.lock().await.clone(); + assert_eq!( + requests, + vec![ + ProviderRefreshRequestLog::Configure { + provider_name: "my-graph".to_string(), + credential_key: "MS_GRAPH_ACCESS_TOKEN".to_string(), + expires_at_ms: Some(1_767_225_600_000), + }, + ProviderRefreshRequestLog::Status { + provider_name: "my-graph".to_string(), + credential_key: "MS_GRAPH_ACCESS_TOKEN".to_string(), + }, + ProviderRefreshRequestLog::Rotate { + provider_name: "my-graph".to_string(), + credential_key: "MS_GRAPH_ACCESS_TOKEN".to_string(), + }, + ProviderRefreshRequestLog::Delete { + provider_name: "my-graph".to_string(), + credential_key: "MS_GRAPH_ACCESS_TOKEN".to_string(), + }, + ] + ); +} + +#[tokio::test] +async fn provider_create_allows_empty_credentials_for_gateway_refresh_profiles() { + let ts = run_server().await; + ts.state.profiles.lock().await.insert( + "custom-refresh".to_string(), + ProviderProfile { + id: "custom-refresh".to_string(), + display_name: "Custom Refresh".to_string(), + credentials: vec![ProviderProfileCredential { + name: "ACCESS_TOKEN".to_string(), + required: true, + refresh: Some(ProviderCredentialRefresh { + strategy: ProviderCredentialRefreshStrategy::Oauth2RefreshToken as i32, + ..Default::default() + }), + ..Default::default() + }], + ..Default::default() + }, + ); + + run::provider_create( + &ts.endpoint, + "custom-refresh-provider", + "custom-refresh", + false, + &[], + &[], + &ts.tls, + ) + .await + .expect("provider create"); + + let stored = ts.state.providers.lock().await; + let provider = stored.get("custom-refresh-provider").expect("provider"); + assert_eq!(provider.r#type, "custom-refresh"); + assert!(provider.credentials.is_empty()); +} + +#[tokio::test] +async fn sandbox_provider_cli_run_functions_wire_requests_and_idempotent_results() { + let ts = run_server().await; + + run::provider_create( + &ts.endpoint, + "work-github", + "github", + false, + &["GITHUB_TOKEN=ghp-test".to_string()], + &[], + &ts.tls, + ) + .await + 
.expect("provider create"); + + run::sandbox_provider_attach(&ts.endpoint, "dev-sandbox", "work-github", &ts.tls) + .await + .expect("sandbox provider attach"); + run::sandbox_provider_attach(&ts.endpoint, "dev-sandbox", "work-github", &ts.tls) + .await + .expect("sandbox provider attach is idempotent"); + run::sandbox_provider_list(&ts.endpoint, "dev-sandbox", &ts.tls) + .await + .expect("sandbox provider list"); + run::sandbox_provider_detach(&ts.endpoint, "dev-sandbox", "work-github", &ts.tls) + .await + .expect("sandbox provider detach"); + run::sandbox_provider_detach(&ts.endpoint, "dev-sandbox", "work-github", &ts.tls) + .await + .expect("sandbox provider detach is idempotent"); + + let requests = ts.state.sandbox_provider_requests.lock().await.clone(); + assert_eq!( + requests, + vec![ + SandboxProviderRequestLog::Attach { + sandbox_name: "dev-sandbox".to_string(), + provider_name: "work-github".to_string(), + }, + SandboxProviderRequestLog::Attach { + sandbox_name: "dev-sandbox".to_string(), + provider_name: "work-github".to_string(), + }, + SandboxProviderRequestLog::List { + sandbox_name: "dev-sandbox".to_string(), + }, + SandboxProviderRequestLog::Detach { + sandbox_name: "dev-sandbox".to_string(), + provider_name: "work-github".to_string(), + }, + SandboxProviderRequestLog::Detach { + sandbox_name: "dev-sandbox".to_string(), + provider_name: "work-github".to_string(), + }, + ] + ); + + let providers = ts.state.sandbox_providers.lock().await; + assert!(providers.get("dev-sandbox").is_none_or(Vec::is_empty)); +} + +#[tokio::test] +async fn sandbox_provider_attach_cli_surfaces_server_errors() { + let ts = run_server().await; + + let err = + run::sandbox_provider_attach(&ts.endpoint, "dev-sandbox", "missing-provider", &ts.tls) + .await + .expect_err("missing provider should fail"); + + assert!( + err.to_string().contains("provider not found"), + "unexpected error: {err}" + ); + assert_eq!( + ts.state.sandbox_provider_requests.lock().await.as_slice(), + 
[SandboxProviderRequestLog::Attach { + sandbox_name: "dev-sandbox".to_string(), + provider_name: "missing-provider".to_string(), + }] + ); +} + #[tokio::test] async fn provider_profile_cli_run_functions_support_custom_profiles() { let ts = run_server().await; @@ -644,6 +1172,8 @@ credentials: env_vars: [CUSTOM_API_KEY] auth_style: bearer header_name: authorization +discovery: + credentials: [api_key] endpoints: - host: api.custom.example port: 443 @@ -694,6 +1224,209 @@ binaries: [/usr/bin/custom] .expect("profile delete"); } +#[tokio::test] +async fn provider_create_from_existing_uses_profile_discovery_when_v2_enabled() { + let ts = run_server().await; + enable_providers_v2(&ts).await; + ts.state.profiles.lock().await.insert( + "custom-discovery".to_string(), + ProviderProfile { + id: "custom-discovery".to_string(), + display_name: "Custom Discovery".to_string(), + credentials: vec![ProviderProfileCredential { + name: "api_key".to_string(), + env_vars: vec!["CUSTOM_DISCOVERY_API_KEY".to_string()], + required: true, + ..Default::default() + }], + discovery: Some(ProviderProfileDiscovery { + credentials: vec!["api_key".to_string()], + }), + ..Default::default() + }, + ); + let _env = EnvVarGuard::set(&[("CUSTOM_DISCOVERY_API_KEY", "profile-secret")]); + + run::provider_create( + &ts.endpoint, + "custom-discovered", + "custom-discovery", + true, + &[], + &[], + &ts.tls, + ) + .await + .expect("profile-backed provider create --from-existing"); + + let provider = ts + .state + .providers + .lock() + .await + .get("custom-discovered") + .cloned() + .expect("custom provider should be stored"); + assert_eq!(provider.r#type, "custom-discovery"); + assert_eq!( + provider.credentials.get("CUSTOM_DISCOVERY_API_KEY"), + Some(&"profile-secret".to_string()) + ); +} + +#[tokio::test] +async fn provider_create_from_existing_uses_registry_discovery_when_v2_disabled() { + let ts = run_server().await; + let _env = EnvVarGuard::set(&[("OPENAI_API_KEY", "legacy-openai-secret")]); + + 
run::provider_create( + &ts.endpoint, + "legacy-openai", + "openai", + true, + &[], + &[], + &ts.tls, + ) + .await + .expect("legacy provider create --from-existing"); + + let provider = ts + .state + .providers + .lock() + .await + .get("legacy-openai") + .cloned() + .expect("legacy provider should be stored"); + assert_eq!(provider.r#type, "openai"); + assert_eq!( + provider.credentials.get("OPENAI_API_KEY"), + Some(&"legacy-openai-secret".to_string()) + ); +} + +#[tokio::test] +async fn provider_create_from_existing_requires_profile_when_v2_enabled() { + let ts = run_server().await; + enable_providers_v2(&ts).await; + let _env = EnvVarGuard::set(&[("OPENAI_API_KEY", "legacy-openai-secret")]); + + let err = run::provider_create(&ts.endpoint, "v2-openai", "openai", true, &[], &[], &ts.tls) + .await + .expect_err("v2 discovery without a profile should fail"); + + assert!( + err.to_string() + .contains("providers v2 discovery requires a provider profile"), + "unexpected error: {err}" + ); + assert!(!ts.state.providers.lock().await.contains_key("v2-openai")); +} + +#[tokio::test] +async fn provider_create_from_existing_fails_when_profile_discovery_finds_nothing() { + let ts = run_server().await; + enable_providers_v2(&ts).await; + ts.state.profiles.lock().await.insert( + "empty-discovery".to_string(), + ProviderProfile { + id: "empty-discovery".to_string(), + display_name: "Empty Discovery".to_string(), + credentials: vec![ProviderProfileCredential { + name: "api_key".to_string(), + env_vars: vec!["CUSTOM_DISCOVERY_TOKEN_NOT_SET_1460".to_string()], + required: false, + ..Default::default() + }], + discovery: Some(ProviderProfileDiscovery { + credentials: vec!["api_key".to_string()], + }), + ..Default::default() + }, + ); + + let err = run::provider_create( + &ts.endpoint, + "empty-discovered", + "empty-discovery", + true, + &[], + &[], + &ts.tls, + ) + .await + .expect_err("empty profile-backed discovery should fail"); + + assert!( + err.to_string() + .contains("no 
existing local credentials/config found"), + "unexpected error: {err}" + ); + assert!( + !ts.state + .providers + .lock() + .await + .contains_key("empty-discovered") + ); +} + +#[tokio::test] +async fn provider_update_from_existing_uses_profile_discovery_when_v2_enabled() { + let ts = run_server().await; + enable_providers_v2(&ts).await; + ts.state.profiles.lock().await.insert( + "custom-update-discovery".to_string(), + ProviderProfile { + id: "custom-update-discovery".to_string(), + display_name: "Custom Update Discovery".to_string(), + credentials: vec![ProviderProfileCredential { + name: "api_key".to_string(), + env_vars: vec!["CUSTOM_UPDATE_DISCOVERY_API_KEY".to_string()], + required: true, + ..Default::default() + }], + discovery: Some(ProviderProfileDiscovery { + credentials: vec!["api_key".to_string()], + }), + ..Default::default() + }, + ); + ts.state.providers.lock().await.insert( + "custom-update".to_string(), + Provider { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: "id-custom-update".to_string(), + name: "custom-update".to_string(), + ..Default::default() + }), + r#type: "custom-update-discovery".to_string(), + credentials: HashMap::new(), + config: HashMap::new(), + credential_expires_at_ms: HashMap::new(), + }, + ); + let _env = EnvVarGuard::set(&[("CUSTOM_UPDATE_DISCOVERY_API_KEY", "updated-profile-secret")]); + + run::provider_update(&ts.endpoint, "custom-update", true, &[], &[], &[], &ts.tls) + .await + .expect("profile-backed provider update --from-existing"); + + let provider = ts + .state + .providers + .lock() + .await + .get("custom-update") + .cloned() + .expect("custom provider should still be stored"); + assert_eq!( + provider.credentials.get("CUSTOM_UPDATE_DISCOVERY_API_KEY"), + Some(&"updated-profile-secret".to_string()) + ); +} + #[tokio::test] async fn provider_profile_import_from_directory_imports_supported_profile_files() { let ts = run_server().await; @@ -888,7 +1621,7 @@ async fn 
provider_create_rejects_key_only_credentials_without_local_env_value() #[tokio::test] async fn provider_create_supports_generic_type_and_env_lookup_credentials() { let ts = run_server().await; - let _guard = EnvVarGuard::set("NAV_GENERIC_TEST_KEY", "generic-value"); + let _guard = EnvVarGuard::set(&[("NAV_GENERIC_TEST_KEY", "generic-value")]); run::provider_create( &ts.endpoint, @@ -946,7 +1679,7 @@ async fn provider_create_rejects_combined_from_existing_and_credentials() { #[tokio::test] async fn provider_create_rejects_empty_env_var_for_key_only_credential() { let ts = run_server().await; - let _guard = EnvVarGuard::set("NAV_EMPTY_ENV_KEY", ""); + let _guard = EnvVarGuard::set(&[("NAV_EMPTY_ENV_KEY", "")]); let err = run::provider_create( &ts.endpoint, @@ -970,7 +1703,7 @@ async fn provider_create_rejects_empty_env_var_for_key_only_credential() { #[tokio::test] async fn provider_create_supports_nvidia_type_with_nvidia_api_key() { let ts = run_server().await; - let _guard = EnvVarGuard::set("NVIDIA_API_KEY", "nvapi-live-test"); + let _guard = EnvVarGuard::set(&[("NVIDIA_API_KEY", "nvapi-live-test")]); run::provider_create( &ts.endpoint, diff --git a/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs b/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs index a8e359d54..3ed43b2fc 100644 --- a/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs +++ b/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs @@ -1,29 +1,37 @@ // SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
// SPDX-License-Identifier: Apache-2.0 +mod helpers; + +use helpers::{ + EnvVarGuard, build_ca, build_client_cert, build_server_cert, install_rustls_provider, +}; use openshell_bootstrap::load_last_sandbox; use openshell_cli::run; use openshell_cli::tls::TlsOptions; use openshell_core::proto::open_shell_server::{OpenShell, OpenShellServer}; use openshell_core::proto::{ - CreateProviderRequest, CreateSandboxRequest, CreateSshSessionRequest, CreateSshSessionResponse, - DeleteProviderRequest, DeleteProviderResponse, DeleteSandboxRequest, DeleteSandboxResponse, - ExecSandboxEvent, ExecSandboxRequest, GatewayMessage, GetGatewayConfigRequest, + AttachSandboxProviderRequest, AttachSandboxProviderResponse, CreateProviderRequest, + CreateSandboxRequest, CreateSshSessionRequest, CreateSshSessionResponse, DeleteProviderRequest, + DeleteProviderResponse, DeleteSandboxRequest, DeleteSandboxResponse, + DetachSandboxProviderRequest, DetachSandboxProviderResponse, ExecSandboxEvent, + ExecSandboxInput, ExecSandboxRequest, GatewayMessage, GetGatewayConfigRequest, GetGatewayConfigResponse, GetProviderRequest, GetSandboxConfigRequest, GetSandboxConfigResponse, GetSandboxProviderEnvironmentRequest, GetSandboxProviderEnvironmentResponse, GetSandboxRequest, HealthRequest, HealthResponse, - ListProvidersRequest, ListProvidersResponse, ListSandboxesRequest, ListSandboxesResponse, - PlatformEvent, ProviderResponse, RevokeSshSessionRequest, RevokeSshSessionResponse, Sandbox, - SandboxPhase, SandboxResponse, SandboxStreamEvent, ServiceStatus, SupervisorMessage, - UpdateProviderRequest, WatchSandboxRequest, sandbox_stream_event, -}; -use rcgen::{ - BasicConstraints, Certificate, CertificateParams, ExtendedKeyUsagePurpose, IsCa, KeyPair, + ListProvidersRequest, ListProvidersResponse, ListSandboxProvidersRequest, + ListSandboxProvidersResponse, ListSandboxesRequest, ListSandboxesResponse, PlatformEvent, + ProviderResponse, RevokeSshSessionRequest, RevokeSshSessionResponse, Sandbox, 
SandboxCondition, + SandboxLogLine, SandboxPhase, SandboxResponse, SandboxStatus, SandboxStreamEvent, + ServiceStatus, SettingValue, SupervisorMessage, UpdateProviderRequest, WatchSandboxRequest, + sandbox_stream_event, setting_value, }; use std::collections::HashMap; use std::fs; use std::os::unix::fs::PermissionsExt; use std::sync::Arc; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::time::{Duration, Instant}; use tempfile::TempDir; use tokio::net::TcpListener; use tokio::sync::{Mutex, mpsc}; @@ -31,56 +39,14 @@ use tokio_stream::wrappers::TcpListenerStream; use tonic::transport::{Certificate as TlsCertificate, Identity, Server, ServerTlsConfig}; use tonic::{Response, Status}; -static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(()); - -struct SavedVar { - key: &'static str, - original: Option, -} - -struct EnvVarGuard { - vars: Vec, - _lock: std::sync::MutexGuard<'static, ()>, -} - -#[allow(unsafe_code)] -impl EnvVarGuard { - fn set(pairs: &[(&'static str, String)]) -> Self { - let lock = ENV_LOCK - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner); - let mut vars = Vec::with_capacity(pairs.len()); - for (key, value) in pairs { - let original = std::env::var(key).ok(); - unsafe { - std::env::set_var(key, value); - } - vars.push(SavedVar { key, original }); - } - Self { vars, _lock: lock } - } -} - -#[allow(unsafe_code)] -impl Drop for EnvVarGuard { - fn drop(&mut self) { - for var in &self.vars { - if let Some(value) = &var.original { - unsafe { - std::env::set_var(var.key, value); - } - } else { - unsafe { - std::env::remove_var(var.key); - } - } - } - } -} - #[derive(Clone, Default)] struct SandboxState { deleted_names: Arc>>>, + create_requests: Arc>>, + vm_error_after_started: Arc, + vm_slow_progress_before_ready: Arc, + vm_log_churn_before_ready: Arc, + global_settings: Arc>>, } #[derive(Clone, Default)] @@ -104,7 +70,9 @@ impl OpenShell for TestOpenShell { &self, request: tonic::Request, ) -> Result, Status> { - let name = 
request.into_inner().name; + let request = request.into_inner(); + let name = request.name.clone(); + self.state.create_requests.lock().await.push(request); let sandbox_name = if name.is_empty() { "test-sandbox".to_string() } else { @@ -118,6 +86,7 @@ impl OpenShell for TestOpenShell { name: sandbox_name, created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), phase: SandboxPhase::Provisioning as i32, ..Sandbox::default() @@ -137,6 +106,7 @@ impl OpenShell for TestOpenShell { name, created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), phase: SandboxPhase::Ready as i32, ..Sandbox::default() @@ -151,6 +121,27 @@ impl OpenShell for TestOpenShell { Ok(Response::new(ListSandboxesResponse::default())) } + async fn list_sandbox_providers( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(ListSandboxProvidersResponse::default())) + } + + async fn attach_sandbox_provider( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(AttachSandboxProviderResponse::default())) + } + + async fn detach_sandbox_provider( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(DetachSandboxProviderResponse::default())) + } + async fn delete_sandbox( &self, request: tonic::Request, @@ -175,7 +166,10 @@ impl OpenShell for TestOpenShell { &self, _request: tonic::Request, ) -> Result, Status> { - Ok(Response::new(GetGatewayConfigResponse::default())) + Ok(Response::new(GetGatewayConfigResponse { + settings: self.state.global_settings.lock().await.clone(), + settings_revision: 1, + })) } async fn get_sandbox_provider_environment( @@ -198,11 +192,40 @@ impl OpenShell for TestOpenShell { gateway_scheme: "https".to_string(), gateway_host: "localhost".to_string(), gateway_port: 443, - connect_path: "/connect/ssh".to_string(), ..CreateSshSessionResponse::default() })) } + async fn expose_service( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new( + 
openshell_core::proto::ServiceEndpointResponse::default(), + )) + } + + async fn get_service( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn list_services( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn delete_service( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + async fn revoke_ssh_session( &self, _request: tonic::Request, @@ -272,6 +295,33 @@ impl OpenShell for TestOpenShell { ) -> Result, Status> { Ok(Response::new(ProviderResponse::default())) } + async fn get_provider_refresh_status( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn configure_provider_refresh( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn rotate_provider_credential( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn delete_provider_refresh( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } async fn delete_provider( &self, @@ -293,6 +343,12 @@ impl OpenShell for TestOpenShell { ) -> Result, Status> { let sandbox_id = request.into_inner().id; let (tx, rx) = mpsc::channel(4); + let vm_error_after_started = self.state.vm_error_after_started.load(Ordering::SeqCst); + let vm_slow_progress_before_ready = self + .state + .vm_slow_progress_before_ready + .load(Ordering::SeqCst); + let vm_log_churn_before_ready = self.state.vm_log_churn_before_ready.load(Ordering::SeqCst); tokio::spawn(async move { let provisioning = Sandbox { @@ -301,10 +357,28 @@ impl OpenShell for TestOpenShell { name: sandbox_id.trim_start_matches("id-").to_string(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), phase: SandboxPhase::Provisioning as i32, ..Sandbox::default() }; + let error = Sandbox { 
+ phase: SandboxPhase::Error as i32, + status: Some(SandboxStatus { + sandbox_name: sandbox_id.trim_start_matches("id-").to_string(), + agent_pod: String::new(), + agent_fd: String::new(), + sandbox_fd: String::new(), + conditions: vec![SandboxCondition { + r#type: "Ready".to_string(), + status: "False".to_string(), + reason: "ProcessExited".to_string(), + message: "VM process exited with status 0".to_string(), + last_transition_time: String::new(), + }], + }), + ..provisioning.clone() + }; let ready = Sandbox { phase: SandboxPhase::Ready as i32, ..provisioning.clone() @@ -315,6 +389,80 @@ impl OpenShell for TestOpenShell { payload: Some(sandbox_stream_event::Payload::Sandbox(provisioning)), })) .await; + if vm_error_after_started { + let _ = tx + .send(Ok(SandboxStreamEvent { + payload: Some(sandbox_stream_event::Payload::Event(PlatformEvent { + source: "vm".to_string(), + reason: "Started".to_string(), + message: "Started VM launcher".to_string(), + ..PlatformEvent::default() + })), + })) + .await; + let _ = tx + .send(Ok(SandboxStreamEvent { + payload: Some(sandbox_stream_event::Payload::Sandbox(error)), + })) + .await; + tokio::time::sleep(Duration::from_secs(5)).await; + return; + } + if vm_log_churn_before_ready { + for message in ["still booting", "still booting again"] { + tokio::time::sleep(Duration::from_millis(600)).await; + let _ = tx + .send(Ok(SandboxStreamEvent { + payload: Some(sandbox_stream_event::Payload::Log(SandboxLogLine { + sandbox_id: sandbox_id.clone(), + timestamp_ms: 0, + level: "INFO".to_string(), + target: "test".to_string(), + message: message.to_string(), + source: "gateway".to_string(), + fields: HashMap::new(), + })), + })) + .await; + } + let _ = tx + .send(Ok(SandboxStreamEvent { + payload: Some(sandbox_stream_event::Payload::Sandbox(ready)), + })) + .await; + return; + } + if vm_slow_progress_before_ready { + tokio::time::sleep(Duration::from_millis(600)).await; + let _ = tx + .send(Ok(SandboxStreamEvent { + payload: 
Some(sandbox_stream_event::Payload::Event(PlatformEvent { + source: "vm".to_string(), + reason: "PreparingRootfs".to_string(), + message: "Preparing rootfs".to_string(), + ..PlatformEvent::default() + })), + })) + .await; + tokio::time::sleep(Duration::from_millis(600)).await; + let _ = tx + .send(Ok(SandboxStreamEvent { + payload: Some(sandbox_stream_event::Payload::Event(PlatformEvent { + source: "vm".to_string(), + reason: "CreatingRootDisk".to_string(), + message: "Formatting root disk".to_string(), + ..PlatformEvent::default() + })), + })) + .await; + tokio::time::sleep(Duration::from_millis(600)).await; + let _ = tx + .send(Ok(SandboxStreamEvent { + payload: Some(sandbox_stream_event::Payload::Sandbox(ready)), + })) + .await; + return; + } let _ = tx .send(Ok(SandboxStreamEvent { payload: Some(sandbox_stream_event::Payload::Event(PlatformEvent { @@ -346,6 +494,15 @@ impl OpenShell for TestOpenShell { ))) } + type ExecSandboxInteractiveStream = + tokio_stream::wrappers::ReceiverStream>; + async fn exec_sandbox_interactive( + &self, + _request: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + async fn update_config( &self, _request: tonic::Request, @@ -451,6 +608,20 @@ impl OpenShell for TestOpenShell { Err(Status::unimplemented("not implemented in test")) } + async fn issue_sandbox_token( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + + async fn refresh_sandbox_token( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + async fn connect_supervisor( &self, _request: tonic::Request>, @@ -467,34 +638,17 @@ impl OpenShell for TestOpenShell { ) -> Result, Status> { Err(Status::unimplemented("not implemented in test")) } -} - -fn install_rustls_provider() { - let _ = rustls::crypto::ring::default_provider().install_default(); -} - -fn build_ca() -> (Certificate, 
KeyPair) { - let key_pair = KeyPair::generate().unwrap(); - let mut params = CertificateParams::new(Vec::::new()).unwrap(); - params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained); - let cert = params.self_signed(&key_pair).unwrap(); - (cert, key_pair) -} -fn build_server_cert(ca: &Certificate, ca_key: &KeyPair) -> (String, String) { - let key_pair = KeyPair::generate().unwrap(); - let mut params = CertificateParams::new(vec!["localhost".to_string()]).unwrap(); - params.extended_key_usages = vec![ExtendedKeyUsagePurpose::ServerAuth]; - let cert = params.signed_by(&key_pair, ca, ca_key).unwrap(); - (cert.pem(), key_pair.serialize_pem()) -} + type ForwardTcpStream = tokio_stream::wrappers::ReceiverStream< + Result, + >; -fn build_client_cert(ca: &Certificate, ca_key: &KeyPair) -> (String, String) { - let key_pair = KeyPair::generate().unwrap(); - let mut params = CertificateParams::new(Vec::::new()).unwrap(); - params.extended_key_usages = vec![ExtendedKeyUsagePurpose::ClientAuth]; - let cert = params.signed_by(&key_pair, ca, ca_key).unwrap(); - (cert.pem(), key_pair.serialize_pem()) + async fn forward_tcp( + &self, + _request: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } } struct TestServer { @@ -564,26 +718,52 @@ fn install_fake_ssh(dir: &TempDir) -> std::path::PathBuf { } fn test_env(fake_ssh_dir: &TempDir, xdg_dir: &TempDir) -> EnvVarGuard { + test_env_with(fake_ssh_dir, xdg_dir, &[]) +} + +fn test_env_with( + fake_ssh_dir: &TempDir, + xdg_dir: &TempDir, + extra: &[(&'static str, String)], +) -> EnvVarGuard { let path = format!( "{}:{}", fake_ssh_dir.path().display(), std::env::var("PATH").unwrap_or_default() ); + let xdg = xdg_dir.path().to_str().unwrap().to_string(); - EnvVarGuard::set(&[ + let mut owned_pairs = vec![ ("PATH", path), - ( - "XDG_CONFIG_HOME", - xdg_dir.path().to_str().unwrap().to_string(), - ), - ("HOME", xdg_dir.path().to_str().unwrap().to_string()), - ]) + ("XDG_CONFIG_HOME", 
xdg.clone()), + ("HOME", xdg), + ]; + owned_pairs.extend(extra.iter().cloned()); + let pairs = owned_pairs + .iter() + .map(|(key, value)| (*key, value.as_str())) + .collect::>(); + + EnvVarGuard::set(&pairs) } async fn deleted_names(server: &TestServer) -> Vec> { server.openshell.state.deleted_names.lock().await.clone() } +async fn create_requests(server: &TestServer) -> Vec { + server.openshell.state.create_requests.lock().await.clone() +} + +async fn enable_providers_v2(server: &TestServer) { + server.openshell.state.global_settings.lock().await.insert( + openshell_core::settings::PROVIDERS_V2_ENABLED_KEY.to_string(), + SettingValue { + value: Some(setting_value::Value::BoolValue(true)), + }, + ); +} + fn test_tls(server: &TestServer) -> TlsOptions { server.tls.with_gateway_name("openshell") } @@ -607,6 +787,8 @@ async fn sandbox_create_keeps_command_sessions_by_default() { false, None, None, + None, + None, &[], None, None, @@ -627,6 +809,274 @@ async fn sandbox_create_keeps_command_sessions_by_default() { ); } +#[tokio::test] +async fn sandbox_create_sends_cpu_and_memory_limits_only() { + let server = run_server().await; + let fake_ssh_dir = tempfile::tempdir().unwrap(); + let xdg_dir = tempfile::tempdir().unwrap(); + let _env = test_env(&fake_ssh_dir, &xdg_dir); + let tls = test_tls(&server); + install_fake_ssh(&fake_ssh_dir); + + run::sandbox_create( + &server.endpoint, + Some("resources"), + None, + "openshell", + None, + true, + false, + None, + Some("500m"), + Some("2Gi"), + None, + &[], + None, + None, + &["echo".to_string(), "OK".to_string()], + Some(false), + Some(false), + &HashMap::new(), + &tls, + ) + .await + .expect("sandbox create should succeed"); + + let requests = create_requests(&server).await; + let resources = requests[0] + .spec + .as_ref() + .and_then(|spec| spec.template.as_ref()) + .and_then(|template| template.resources.as_ref()) + .expect("resource limits should be sent"); + let limits = resources + .fields + .get("limits") + 
.and_then(|value| value.kind.as_ref()) + .and_then(|kind| match kind { + prost_types::value::Kind::StructValue(inner) => Some(inner), + _ => None, + }) + .expect("limits should be a struct"); + + assert_eq!( + limits + .fields + .get("cpu") + .and_then(|value| value.kind.as_ref()) + .and_then(|kind| match kind { + prost_types::value::Kind::StringValue(value) => Some(value.as_str()), + _ => None, + }), + Some("500m") + ); + assert_eq!( + limits + .fields + .get("memory") + .and_then(|value| value.kind.as_ref()) + .and_then(|kind| match kind { + prost_types::value::Kind::StringValue(value) => Some(value.as_str()), + _ => None, + }), + Some("2Gi") + ); + assert!(!resources.fields.contains_key("requests")); +} + +#[tokio::test] +async fn sandbox_create_does_not_infer_command_providers_when_v2_enabled() { + let server = run_server().await; + enable_providers_v2(&server).await; + let fake_ssh_dir = tempfile::tempdir().unwrap(); + let xdg_dir = tempfile::tempdir().unwrap(); + let _env = test_env(&fake_ssh_dir, &xdg_dir); + let tls = test_tls(&server); + install_fake_ssh(&fake_ssh_dir); + + run::sandbox_create( + &server.endpoint, + Some("v2-no-inferred-provider"), + None, + "openshell", + None, + true, + false, + None, + None, + None, + None, + &[], + None, + None, + &["claude".to_string(), "--version".to_string()], + Some(true), + Some(false), + &HashMap::new(), + &tls, + ) + .await + .expect("sandbox create should succeed without inferred provider"); + + let requests = create_requests(&server).await; + let providers = requests[0] + .spec + .as_ref() + .expect("sandbox spec should be sent") + .providers + .clone(); + assert!( + providers.is_empty(), + "providers v2 should not infer command providers, got {providers:?}" + ); +} + +#[tokio::test] +async fn sandbox_create_returns_vm_error_without_waiting_for_timeout() { + let server = run_server().await; + server + .openshell + .state + .vm_error_after_started + .store(true, Ordering::SeqCst); + let fake_ssh_dir = 
tempfile::tempdir().unwrap(); + let xdg_dir = tempfile::tempdir().unwrap(); + let _env = test_env_with( + &fake_ssh_dir, + &xdg_dir, + &[("OPENSHELL_PROVISION_TIMEOUT", "1".to_string())], + ); + let tls = test_tls(&server); + install_fake_ssh(&fake_ssh_dir); + + let started_at = Instant::now(); + let err = run::sandbox_create( + &server.endpoint, + Some("vm-error"), + None, + "openshell", + None, + true, + false, + None, + None, + None, + None, + &[], + None, + None, + &["echo".to_string(), "OK".to_string()], + Some(false), + Some(false), + &HashMap::new(), + &tls, + ) + .await + .expect_err("sandbox create should fail on terminal VM error"); + + assert!( + started_at.elapsed() < Duration::from_secs(2), + "terminal VM errors should not wait for the provisioning timeout" + ); + let rendered = err.to_string(); + assert!(rendered.contains("sandbox entered error phase while provisioning")); + assert!(rendered.contains("ProcessExited: VM process exited with status 0")); + assert!(!rendered.contains("timed out")); +} + +#[tokio::test] +async fn sandbox_create_keeps_waiting_while_vm_progress_arrives() { + let server = run_server().await; + server + .openshell + .state + .vm_slow_progress_before_ready + .store(true, Ordering::SeqCst); + let fake_ssh_dir = tempfile::tempdir().unwrap(); + let xdg_dir = tempfile::tempdir().unwrap(); + let _env = test_env_with( + &fake_ssh_dir, + &xdg_dir, + &[("OPENSHELL_PROVISION_TIMEOUT", "1".to_string())], + ); + let tls = test_tls(&server); + install_fake_ssh(&fake_ssh_dir); + + run::sandbox_create( + &server.endpoint, + Some("vm-slow-progress"), + None, + "openshell", + None, + true, + false, + None, + None, + None, + None, + &[], + None, + None, + &["echo".to_string(), "OK".to_string()], + Some(false), + Some(false), + &HashMap::new(), + &tls, + ) + .await + .expect("sandbox create should not time out while VM progress is active"); +} + +#[tokio::test] +async fn sandbox_create_times_out_when_only_logs_arrive() { + let server = 
run_server().await; + server + .openshell + .state + .vm_log_churn_before_ready + .store(true, Ordering::SeqCst); + let fake_ssh_dir = tempfile::tempdir().unwrap(); + let xdg_dir = tempfile::tempdir().unwrap(); + let _env = test_env_with( + &fake_ssh_dir, + &xdg_dir, + &[("OPENSHELL_PROVISION_TIMEOUT", "1".to_string())], + ); + let tls = test_tls(&server); + install_fake_ssh(&fake_ssh_dir); + + let started_at = Instant::now(); + let err = run::sandbox_create( + &server.endpoint, + Some("vm-log-churn"), + None, + "openshell", + None, + true, + false, + None, + None, + None, + None, + &[], + None, + None, + &["echo".to_string(), "OK".to_string()], + Some(false), + Some(false), + &HashMap::new(), + &tls, + ) + .await + .expect_err("sandbox create should time out when only logs arrive"); + + assert!( + started_at.elapsed() < Duration::from_secs(2), + "logs should not extend the provisioning timeout" + ); + assert!(err.to_string().contains("sandbox provisioning timed out")); +} + #[tokio::test] async fn sandbox_create_deletes_command_sessions_with_no_keep() { let server = run_server().await; @@ -646,6 +1096,8 @@ async fn sandbox_create_deletes_command_sessions_with_no_keep() { false, None, None, + None, + None, &[], None, None, @@ -688,6 +1140,8 @@ async fn sandbox_create_deletes_shell_sessions_with_no_keep() { false, None, None, + None, + None, &[], None, None, @@ -730,6 +1184,8 @@ async fn sandbox_create_keeps_sandbox_with_hidden_keep_flag() { false, None, None, + None, + None, &[], None, None, @@ -758,10 +1214,9 @@ async fn sandbox_create_keeps_sandbox_with_forwarding() { let _env = test_env(&fake_ssh_dir, &xdg_dir); let tls = test_tls(&server); install_fake_ssh(&fake_ssh_dir); - let forward_port = { - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - listener.local_addr().unwrap().port() - }; + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let forward_port = listener.local_addr().unwrap().port(); + drop(listener); 
run::sandbox_create( &server.endpoint, @@ -773,6 +1228,8 @@ async fn sandbox_create_keeps_sandbox_with_forwarding() { false, None, None, + None, + None, &[], None, Some(openshell_core::forward::ForwardSpec::new(forward_port)), diff --git a/crates/openshell-cli/tests/sandbox_name_fallback_integration.rs b/crates/openshell-cli/tests/sandbox_name_fallback_integration.rs index ac1ff37c6..f49ca71db 100644 --- a/crates/openshell-cli/tests/sandbox_name_fallback_integration.rs +++ b/crates/openshell-cli/tests/sandbox_name_fallback_integration.rs @@ -1,23 +1,29 @@ // SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 +mod helpers; + +use helpers::{ + EnvVarGuard, build_ca, build_client_cert, build_server_cert, install_rustls_provider, +}; use openshell_bootstrap::{load_last_sandbox, save_last_sandbox}; use openshell_cli::run; use openshell_cli::tls::TlsOptions; use openshell_core::proto::open_shell_server::{OpenShell, OpenShellServer}; use openshell_core::proto::{ - CreateProviderRequest, CreateSandboxRequest, CreateSshSessionRequest, CreateSshSessionResponse, - DeleteProviderRequest, DeleteProviderResponse, DeleteSandboxRequest, DeleteSandboxResponse, - ExecSandboxEvent, ExecSandboxRequest, GatewayMessage, GetGatewayConfigRequest, + AttachSandboxProviderRequest, AttachSandboxProviderResponse, CreateProviderRequest, + CreateSandboxRequest, CreateSshSessionRequest, CreateSshSessionResponse, DeleteProviderRequest, + DeleteProviderResponse, DeleteSandboxRequest, DeleteSandboxResponse, + DetachSandboxProviderRequest, DetachSandboxProviderResponse, ExecSandboxEvent, + ExecSandboxInput, ExecSandboxRequest, GatewayMessage, GetGatewayConfigRequest, GetGatewayConfigResponse, GetProviderRequest, GetSandboxConfigRequest, - GetSandboxConfigResponse, GetSandboxProviderEnvironmentRequest, - GetSandboxProviderEnvironmentResponse, GetSandboxRequest, HealthRequest, HealthResponse, - 
ListProvidersRequest, ListProvidersResponse, ListSandboxesRequest, ListSandboxesResponse, - ProviderResponse, Sandbox, SandboxPolicy, SandboxResponse, SandboxStreamEvent, ServiceStatus, - SupervisorMessage, UpdateProviderRequest, WatchSandboxRequest, -}; -use rcgen::{ - BasicConstraints, Certificate, CertificateParams, ExtendedKeyUsagePurpose, IsCa, KeyPair, + GetSandboxConfigResponse, GetSandboxPolicyStatusRequest, GetSandboxPolicyStatusResponse, + GetSandboxProviderEnvironmentRequest, GetSandboxProviderEnvironmentResponse, GetSandboxRequest, + HealthRequest, HealthResponse, ListProvidersRequest, ListProvidersResponse, + ListSandboxProvidersRequest, ListSandboxProvidersResponse, ListSandboxesRequest, + ListSandboxesResponse, NetworkEndpoint, NetworkPolicyRule, PolicyStatus, ProviderResponse, + Sandbox, SandboxPolicy, SandboxPolicyRevision, SandboxResponse, SandboxStreamEvent, + ServiceStatus, SupervisorMessage, UpdateProviderRequest, WatchSandboxRequest, }; use std::sync::Arc; use tempfile::TempDir; @@ -27,50 +33,6 @@ use tokio_stream::wrappers::TcpListenerStream; use tonic::transport::{Certificate as TlsCertificate, Identity, Server, ServerTlsConfig}; use tonic::{Response, Status}; -// Serialise tests that mutate XDG_CONFIG_HOME so concurrent threads -// don't clobber each other's environment. 
-static XDG_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(()); - -struct EnvVarGuard { - key: &'static str, - original: Option, - _xdg_lock: std::sync::MutexGuard<'static, ()>, -} - -#[allow(unsafe_code)] -impl EnvVarGuard { - fn set(key: &'static str, value: &str) -> Self { - let lock = XDG_LOCK - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner); - let original = std::env::var(key).ok(); - unsafe { - std::env::set_var(key, value); - } - Self { - key, - original, - _xdg_lock: lock, - } - } -} - -#[allow(unsafe_code)] -impl Drop for EnvVarGuard { - fn drop(&mut self) { - if let Some(value) = &self.original { - unsafe { - std::env::set_var(self.key, value); - } - } else { - unsafe { - std::env::remove_var(self.key); - } - } - // _xdg_lock drops here, releasing the mutex - } -} - // ── mock OpenShell server ───────────────────────────────────────────── /// Records which sandbox name was requested via `get_sandbox`. @@ -116,6 +78,7 @@ impl OpenShell for TestOpenShell { name, created_at_ms: 0, labels: std::collections::HashMap::new(), + resource_version: 0, }), ..Default::default() }), @@ -129,6 +92,27 @@ impl OpenShell for TestOpenShell { Ok(Response::new(ListSandboxesResponse::default())) } + async fn list_sandbox_providers( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(ListSandboxProvidersResponse::default())) + } + + async fn attach_sandbox_provider( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(AttachSandboxProviderResponse::default())) + } + + async fn detach_sandbox_provider( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(DetachSandboxProviderResponse::default())) + } + async fn delete_sandbox( &self, _request: tonic::Request, @@ -177,6 +161,36 @@ impl OpenShell for TestOpenShell { Ok(Response::new(CreateSshSessionResponse::default())) } + async fn expose_service( + &self, + _request: tonic::Request, + ) -> Result, Status> { + 
Ok(Response::new( + openshell_core::proto::ServiceEndpointResponse::default(), + )) + } + + async fn get_service( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn list_services( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn delete_service( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + async fn revoke_ssh_session( &self, _request: tonic::Request, @@ -248,6 +262,33 @@ impl OpenShell for TestOpenShell { ) -> Result, Status> { Ok(Response::new(ProviderResponse::default())) } + async fn get_provider_refresh_status( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn configure_provider_refresh( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn rotate_provider_credential( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn delete_provider_refresh( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } async fn delete_provider( &self, @@ -283,6 +324,15 @@ impl OpenShell for TestOpenShell { ))) } + type ExecSandboxInteractiveStream = + tokio_stream::wrappers::ReceiverStream>; + async fn exec_sandbox_interactive( + &self, + _request: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + async fn update_config( &self, _request: tonic::Request, @@ -292,9 +342,46 @@ impl OpenShell for TestOpenShell { async fn get_sandbox_policy_status( &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) + request: tonic::Request, + ) -> Result, Status> { + let req = request.into_inner(); + assert_eq!(req.name, "my-sandbox"); + assert_eq!(req.version, 0); + assert!(!req.global); + + let 
policy = SandboxPolicy { + version: 7, + network_policies: std::iter::once(( + "api".to_string(), + NetworkPolicyRule { + name: "api".to_string(), + endpoints: vec![NetworkEndpoint { + host: "api.example.com".to_string(), + port: 443, + protocol: "rest".to_string(), + enforcement: "enforce".to_string(), + access: "read-only".to_string(), + ..Default::default() + }], + ..Default::default() + }, + )) + .collect(), + ..Default::default() + }; + + Ok(Response::new(GetSandboxPolicyStatusResponse { + revision: Some(SandboxPolicyRevision { + version: 7, + policy_hash: "sha256:test-policy".to_string(), + status: PolicyStatus::Loaded.into(), + created_at_ms: 1_700_000_000_000, + loaded_at_ms: 1_700_000_000_500, + policy: Some(policy), + ..Default::default() + }), + active_version: 7, + })) } async fn list_sandbox_policies( @@ -388,6 +475,20 @@ impl OpenShell for TestOpenShell { Err(Status::unimplemented("not implemented in test")) } + async fn issue_sandbox_token( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + + async fn refresh_sandbox_token( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + async fn connect_supervisor( &self, _request: tonic::Request>, @@ -404,36 +505,17 @@ impl OpenShell for TestOpenShell { ) -> Result, Status> { Err(Status::unimplemented("not implemented in test")) } -} - -// ── helpers ─────────────────────────────────────────────────────────── - -fn install_rustls_provider() { - let _ = rustls::crypto::ring::default_provider().install_default(); -} -fn build_ca() -> (Certificate, KeyPair) { - let key_pair = KeyPair::generate().unwrap(); - let mut params = CertificateParams::new(Vec::::new()).unwrap(); - params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained); - let cert = params.self_signed(&key_pair).unwrap(); - (cert, key_pair) -} + type ForwardTcpStream = tokio_stream::wrappers::ReceiverStream< + 
Result, + >; -fn build_server_cert(ca: &Certificate, ca_key: &KeyPair) -> (String, String) { - let key_pair = KeyPair::generate().unwrap(); - let mut params = CertificateParams::new(vec!["localhost".to_string()]).unwrap(); - params.extended_key_usages = vec![ExtendedKeyUsagePurpose::ServerAuth]; - let cert = params.signed_by(&key_pair, ca, ca_key).unwrap(); - (cert.pem(), key_pair.serialize_pem()) -} - -fn build_client_cert(ca: &Certificate, ca_key: &KeyPair) -> (String, String) { - let key_pair = KeyPair::generate().unwrap(); - let mut params = CertificateParams::new(Vec::::new()).unwrap(); - params.extended_key_usages = vec![ExtendedKeyUsagePurpose::ClientAuth]; - let cert = params.signed_by(&key_pair, ca, ca_key).unwrap(); - (cert.pem(), key_pair.serialize_pem()) + async fn forward_tcp( + &self, + _request: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } } struct TestServer { @@ -532,7 +614,7 @@ async fn sandbox_get_policy_only_round_trip() { async fn sandbox_get_with_persisted_last_sandbox() { let ts = run_server().await; let xdg_dir = tempfile::tempdir().unwrap(); - let _guard = EnvVarGuard::set("XDG_CONFIG_HOME", xdg_dir.path().to_str().unwrap()); + let _guard = EnvVarGuard::set(&[("XDG_CONFIG_HOME", xdg_dir.path().to_str().unwrap())]); // Persist a last-used sandbox for "integration-cluster". 
save_last_sandbox("integration-cluster", "persisted-sb") @@ -556,12 +638,51 @@ async fn sandbox_get_with_persisted_last_sandbox() { ); } +#[tokio::test] +async fn policy_get_full_json_cli_prints_policy_payload() { + let ts = run_server().await; + let mut stdout = Vec::new(); + let mut stderr = Vec::new(); + + run::sandbox_policy_get_to_writer( + &ts.endpoint, + "my-sandbox", + 0, + true, + "json", + &ts.tls, + (&mut stdout, &mut stderr), + ) + .await + .expect("policy get should succeed"); + + assert!( + stderr.is_empty(), + "policy get should not print stderr: {}", + String::from_utf8_lossy(&stderr) + ); + + let json: serde_json::Value = + serde_json::from_slice(&stdout).expect("stdout should be valid JSON"); + assert_eq!(json["scope"], "sandbox"); + assert_eq!(json["sandbox"], "my-sandbox"); + assert_eq!(json["version"], 7); + assert_eq!(json["active_version"], 7); + assert_eq!(json["hash"], "sha256:test-policy"); + assert_eq!(json["status"], "loaded"); + assert_eq!(json["policy"]["network_policies"]["api"]["name"], "api"); + assert_eq!( + json["policy"]["network_policies"]["api"]["endpoints"][0]["host"], + "api.example.com" + ); +} + /// Verify that an explicit name takes precedence over the persisted one. #[tokio::test] async fn explicit_name_takes_precedence_over_persisted() { let ts = run_server().await; let xdg_dir = tempfile::tempdir().unwrap(); - let _guard = EnvVarGuard::set("XDG_CONFIG_HOME", xdg_dir.path().to_str().unwrap()); + let _guard = EnvVarGuard::set(&[("XDG_CONFIG_HOME", xdg_dir.path().to_str().unwrap())]); // Persist one name, but supply a different one explicitly. save_last_sandbox("my-cluster", "old-sandbox").expect("save should succeed"); diff --git a/crates/openshell-core/src/auth.rs b/crates/openshell-core/src/auth.rs new file mode 100644 index 000000000..16d513346 --- /dev/null +++ b/crates/openshell-core/src/auth.rs @@ -0,0 +1,86 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +//! gRPC authentication interceptor shared by CLI and TUI. + +use miette::Result; + +/// Interceptor that injects authentication headers into every outgoing gRPC request. +/// +/// Supports application-layer Bearer tokens (standard `authorization` +/// header) and Cloudflare Access tokens (custom headers). When no token is +/// set, acts as a no-op. OIDC takes precedence over edge tokens. +#[derive(Clone)] +#[allow(clippy::struct_field_names)] +pub struct EdgeAuthInterceptor { + bearer_value: Option>, + header_value: Option>, + cookie_value: Option>, +} + +impl EdgeAuthInterceptor { + /// Create an interceptor from optional token strings. + /// + /// OIDC bearer tokens take precedence over edge tokens. Returns a no-op + /// interceptor when no token is provided. + pub fn new(oidc_token: Option<&str>, edge_token: Option<&str>) -> Result { + if let Some(token) = oidc_token { + let bearer: tonic::metadata::MetadataValue = + format!("Bearer {token}") + .parse() + .map_err(|_| miette::miette!("invalid bearer token value"))?; + return Ok(Self { + bearer_value: Some(bearer), + header_value: None, + cookie_value: None, + }); + } + + let (header_value, cookie_value) = match edge_token { + Some(t) => { + let hv: tonic::metadata::MetadataValue = t + .parse() + .map_err(|_| miette::miette!("invalid edge token value"))?; + let cv: tonic::metadata::MetadataValue = + format!("CF_Authorization={t}") + .parse() + .map_err(|_| miette::miette!("invalid edge token value for cookie"))?; + (Some(hv), Some(cv)) + } + None => (None, None), + }; + Ok(Self { + bearer_value: None, + header_value, + cookie_value, + }) + } + + /// No-op interceptor that passes requests through without modification. 
+ pub fn noop() -> Self { + Self { + bearer_value: None, + header_value: None, + cookie_value: None, + } + } +} + +impl tonic::service::Interceptor for EdgeAuthInterceptor { + fn call( + &mut self, + mut req: tonic::Request<()>, + ) -> std::result::Result, tonic::Status> { + if let Some(ref val) = self.bearer_value { + req.metadata_mut().insert("authorization", val.clone()); + } + if let Some(ref val) = self.header_value { + req.metadata_mut() + .insert("cf-access-jwt-assertion", val.clone()); + } + if let Some(ref val) = self.cookie_value { + req.metadata_mut().insert("cookie", val.clone()); + } + Ok(req) + } +} diff --git a/crates/openshell-core/src/config.rs b/crates/openshell-core/src/config.rs index 1ec06677b..98562c8a6 100644 --- a/crates/openshell-core/src/config.rs +++ b/crates/openshell-core/src/config.rs @@ -6,7 +6,9 @@ use serde::{Deserialize, Serialize}; use std::fmt; use std::net::SocketAddr; -use std::path::PathBuf; +#[cfg(unix)] +use std::os::unix::fs::FileTypeExt; +use std::path::{Path, PathBuf}; use std::process::Command; use std::str::FromStr; @@ -19,29 +21,20 @@ use std::str::FromStr; /// Default SSH port inside sandbox containers. pub const DEFAULT_SSH_PORT: u16 = 2222; -/// Default server / SSH gateway port. -pub const DEFAULT_SERVER_PORT: u16 = 8080; +/// Default gateway server port. +pub const DEFAULT_SERVER_PORT: u16 = 17670; /// Default container stop timeout in seconds (SIGTERM → SIGKILL). pub const DEFAULT_STOP_TIMEOUT_SECS: u32 = 10; -/// Default allowed clock skew for SSH handshake validation, in seconds. -pub const DEFAULT_SSH_HANDSHAKE_SKEW_SECS: u64 = 300; - -/// Default Podman bridge network name. -pub const DEFAULT_NETWORK_NAME: &str = "openshell"; - /// Default Docker bridge network name for local sandboxes. pub const DEFAULT_DOCKER_NETWORK_NAME: &str = "openshell-docker"; -/// Default OCI image for the openshell-sandbox supervisor binary. 
-pub const DEFAULT_SUPERVISOR_IMAGE: &str = "openshell/supervisor:latest"; - -/// Default image pull policy for sandbox images. -pub const DEFAULT_IMAGE_PULL_POLICY: &str = "missing"; +/// Default domain used for browser-facing sandbox service URLs. +pub const DEFAULT_SERVICE_ROUTING_DOMAIN: &str = "openshell.localhost"; -/// Default Kubernetes namespace for sandbox resources. -pub const DEFAULT_K8S_NAMESPACE: &str = "openshell"; +/// Default OCI image for the openshell-sandbox supervisor binary. +pub const DEFAULT_SUPERVISOR_IMAGE: &str = "ghcr.io/nvidia/openshell/supervisor:latest"; /// CDI device identifier for requesting all NVIDIA GPUs. pub const CDI_GPU_DEVICE_ALL: &str = "nvidia.com/gpu=all"; @@ -108,8 +101,8 @@ pub fn detect_driver() -> Option { return Some(ComputeDriverKind::Podman); } - // Docker: check if docker binary is available - if is_binary_available("docker") { + // Docker: check if the CLI is available or a local Docker socket exists. + if is_docker_available() { return Some(ComputeDriverKind::Docker); } @@ -124,6 +117,55 @@ fn is_binary_available(name: &str) -> bool { .is_ok_and(|output| output.status.success()) } +fn is_docker_available() -> bool { + is_binary_available("docker") || docker_socket_available() +} + +fn docker_socket_available() -> bool { + docker_socket_candidates() + .iter() + .any(|path| is_unix_socket(path)) +} + +fn docker_socket_candidates() -> Vec { + let mut candidates = Vec::new(); + + if let Ok(host) = std::env::var("DOCKER_HOST") + && let Some(path) = docker_host_unix_socket_path(&host) + { + candidates.push(path); + } + + candidates.push(PathBuf::from("/var/run/docker.sock")); + + if let Some(home) = std::env::var_os("HOME") { + candidates.push(PathBuf::from(home).join(".docker/run/docker.sock")); + } + + if let Some(runtime_dir) = std::env::var_os("XDG_RUNTIME_DIR") { + candidates.push(PathBuf::from(runtime_dir).join("docker.sock")); + } + + candidates +} + +fn docker_host_unix_socket_path(host: &str) -> Option { + 
let path = host.trim().strip_prefix("unix://")?; + (!path.is_empty()).then(|| PathBuf::from(path)) +} + +#[cfg(unix)] +fn is_unix_socket(path: &Path) -> bool { + path.metadata() + .is_ok_and(|metadata| metadata.file_type().is_socket()) +} + +#[cfg(not(unix))] +fn is_unix_socket(path: &Path) -> bool { + let _ = path; + false +} + /// Server configuration. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Config { @@ -163,6 +205,24 @@ pub struct Config { #[serde(default)] pub oidc: Option, + /// Gateway user authentication behavior. + #[serde(default)] + pub auth: GatewayAuthConfig, + + /// mTLS user authentication configuration. When enabled, a verified TLS + /// client certificate can authenticate CLI/SDK callers as a + /// `Principal::User`. This is for local single-user gateways only; + /// sandbox identity is always carried by gateway-minted sandbox JWTs. + #[serde(default)] + pub mtls_auth: MtlsAuthConfig, + + /// Gateway-minted sandbox JWT configuration. When `Some`, the gateway + /// loads the signing key from disk and accepts gateway-issued sandbox + /// JWTs as `Principal::Sandbox`. Required for the per-sandbox identity + /// flow (issue #1354). + #[serde(default)] + pub gateway_jwt: Option, + /// Database URL for persistence. pub database_url: String, @@ -174,86 +234,40 @@ pub struct Config { #[serde(default)] pub compute_drivers: Vec, - /// Kubernetes namespace for sandboxes. - #[serde(default = "default_sandbox_namespace")] - pub sandbox_namespace: String, - - /// Default container image for sandboxes. - #[serde(default = "default_sandbox_image")] - pub sandbox_image: String, - - /// Kubernetes `imagePullPolicy` for sandbox pods (e.g. `Always`, - /// `IfNotPresent`, `Never`). Defaults to empty, which lets Kubernetes - /// apply its own default (`:latest` → `Always`, anything else → - /// `IfNotPresent`). - #[serde(default)] - pub sandbox_image_pull_policy: String, - - /// gRPC endpoint for sandboxes to connect back to `OpenShell`. 
- /// Used by sandbox pods to fetch their policy at startup. - #[serde(default)] - pub grpc_endpoint: String, - - /// Public gateway host for SSH proxy connections. - #[serde(default = "default_ssh_gateway_host")] - pub ssh_gateway_host: String, - - /// Public gateway port for SSH proxy connections. - #[serde(default = "default_ssh_gateway_port")] - pub ssh_gateway_port: u16, - - /// Path for SSH CONNECT/upgrade requests. - #[serde(default = "default_ssh_connect_path")] - pub ssh_connect_path: String, - - /// SSH listen port inside sandbox containers that expose a TCP endpoint. - #[serde(default = "default_sandbox_ssh_port")] - pub sandbox_ssh_port: u16, - - /// Filesystem path where the sandbox supervisor binds its SSH Unix - /// socket. The supervisor is passed this path via - /// `OPENSHELL_SSH_SOCKET_PATH` / `--ssh-socket-path` and connects its - /// relay bridge to the same path. - /// - /// When the gateway orchestrates sandboxes that each live in their own - /// filesystem (K8s pod, libkrun VM, etc.), the default is safe. For - /// local dev where multiple supervisors share `/run`, override this to - /// something unique per sandbox. - #[serde(default = "default_sandbox_ssh_socket_path")] - pub sandbox_ssh_socket_path: String, - - /// Shared secret for gateway-to-sandbox SSH handshake. - #[serde(default)] - pub ssh_handshake_secret: String, - - /// Allowed clock skew for SSH handshake validation, in seconds. - #[serde(default = "default_ssh_handshake_skew_secs")] - pub ssh_handshake_skew_secs: u64, - /// TTL for SSH session tokens, in seconds. 0 disables expiry. #[serde(default = "default_ssh_session_ttl_secs")] pub ssh_session_ttl_secs: u64, - /// Kubernetes secret name containing client TLS materials for sandbox pods. - /// When set, sandbox pods get this secret mounted so they can connect to - /// the server over mTLS. + /// Browser-facing sandbox service routing configuration. 
#[serde(default)] - pub client_tls_secret_name: String, + pub service_routing: ServiceRoutingConfig, +} - /// Host gateway IP for sandbox pod hostAliases. - /// When set, sandbox pods get hostAliases entries mapping - /// `host.docker.internal` and `host.openshell.internal` to this IP, - /// allowing them to reach services running on the Docker host. - #[serde(default)] - pub host_gateway_ip: String, +/// Browser-facing sandbox service routing configuration. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ServiceRoutingConfig { + /// Base domains accepted for `sandbox--service.` routes. + /// The first domain is used when the gateway prints endpoint URLs. + #[serde(default = "default_service_routing_domains")] + pub base_domains: Vec, + + /// Enable TLS-enabled loopback gateway listeners to also accept plaintext + /// HTTP for sandbox service hostnames. + #[serde(default = "default_enable_loopback_service_http")] + pub enable_loopback_service_http: bool, } /// TLS configuration. /// -/// By default mTLS is enforced — all clients must present a certificate -/// signed by the given CA. When `allow_unauthenticated` is `true`, the -/// TLS handshake also accepts connections without a client certificate -/// (needed for reverse-proxy deployments like Cloudflare Tunnel). +/// Two modes are supported: +/// - **HTTPS with optional mTLS** (`client_ca_path = Some`): +/// Client certificates are validated against the given CA when presented, +/// but never required. Clients may connect with or without a certificate. +/// - **HTTPS-only** (`client_ca_path = None`): +/// Server-side TLS only; no client certificates are requested. +/// +/// In both modes, authentication is handled at the application layer +/// (e.g. OIDC bearer tokens). mTLS is an additional mechanism. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct TlsConfig { /// Path to the TLS certificate file. @@ -262,16 +276,17 @@ pub struct TlsConfig { /// Path to the TLS private key file. 
pub key_path: PathBuf, - /// Path to the CA certificate file for client certificate verification (mTLS). - /// The server requires all clients to present a valid certificate signed by - /// this CA. - pub client_ca_path: PathBuf, + /// Path to the CA certificate file for client certificate verification. + /// When `Some`, client certs signed by this CA are validated. + /// When `None`, the server does not request client certs. + #[serde(default)] + pub client_ca_path: Option, - /// When `true`, the TLS handshake succeeds even without a client - /// certificate. Application-layer middleware must then enforce auth - /// (e.g. via a CF JWT header). + /// When `true` and `client_ca_path` is `Some`, the TLS handshake rejects + /// connections that do not present a valid client certificate. + /// When `false`, client certificates are accepted but not required. #[serde(default)] - pub allow_unauthenticated: bool, + pub require_client_auth: bool, } /// OIDC (`OpenID` Connect) configuration for JWT-based authentication. @@ -316,10 +331,62 @@ pub struct OidcConfig { pub scopes_claim: String, } +/// mTLS user authentication for local, single-user gateways. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct MtlsAuthConfig { + /// When true, the gateway maps a verified TLS client certificate into a + /// user principal. Keep disabled for Kubernetes deployments because + /// Kubernetes sandbox pods and external users must not share user auth. + #[serde(default)] + pub enabled: bool, +} + +/// Gateway user authentication settings. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct GatewayAuthConfig { + /// When true, unauthenticated user/CLI calls are accepted as a local + /// developer principal. This is an unsafe local-development escape hatch + /// for trusted, non-shared gateways. Sandbox supervisor calls still use + /// gateway-minted sandbox JWTs. 
+ #[serde(default)] + pub allow_unauthenticated_users: bool, +} + const fn default_jwks_ttl_secs() -> u64 { 3600 } +/// Gateway-minted sandbox JWT configuration. +/// +/// Points the gateway at the Ed25519 signing key (produced by `certgen`) +/// and identifies the issuer string embedded in every minted token. The +/// signing key never leaves the gateway process; the public key is loaded +/// by the same gateway so it can validate its own tokens. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GatewayJwtConfig { + /// Path to the Ed25519 signing key (PKCS#8 PEM). + pub signing_key_path: PathBuf, + /// Path to the matching public key (SPKI PEM). + pub public_key_path: PathBuf, + /// Path to the `kid` value (plain text, one line). + pub kid_path: PathBuf, + /// Stable gateway identity embedded in `iss`/`aud`. Defaults to the + /// hostname-or-`openshell` placeholder if unset. + #[serde(default = "default_gateway_id")] + pub gateway_id: String, + /// Token lifetime in seconds. Defaults to 1 hour. 
+ #[serde(default = "default_sandbox_token_ttl_secs")] + pub ttl_secs: u64, +} + +fn default_gateway_id() -> String { + "openshell".to_string() +} + +const fn default_sandbox_token_ttl_secs() -> u64 { + 3_600 +} + fn default_roles_claim() -> String { "realm_access.roles".to_string() } @@ -343,22 +410,13 @@ impl Config { log_level: default_log_level(), tls, oidc: None, + auth: GatewayAuthConfig::default(), + mtls_auth: MtlsAuthConfig::default(), + gateway_jwt: None, database_url: String::new(), compute_drivers: vec![], - sandbox_namespace: default_sandbox_namespace(), - sandbox_image: default_sandbox_image(), - sandbox_image_pull_policy: String::new(), - grpc_endpoint: String::new(), - ssh_gateway_host: default_ssh_gateway_host(), - ssh_gateway_port: default_ssh_gateway_port(), - ssh_connect_path: default_ssh_connect_path(), - sandbox_ssh_port: default_sandbox_ssh_port(), - sandbox_ssh_socket_path: default_sandbox_ssh_socket_path(), - ssh_handshake_secret: String::new(), - ssh_handshake_skew_secs: default_ssh_handshake_skew_secs(), ssh_session_ttl_secs: default_ssh_session_ttl_secs(), - client_tls_secret_name: String::new(), - host_gateway_ip: String::new(), + service_routing: ServiceRoutingConfig::default(), } } @@ -418,76 +476,6 @@ impl Config { self } - /// Create a new configuration with a sandbox namespace. - #[must_use] - pub fn with_sandbox_namespace(mut self, namespace: impl Into) -> Self { - self.sandbox_namespace = namespace.into(); - self - } - - /// Create a new configuration with a default sandbox image. - #[must_use] - pub fn with_sandbox_image(mut self, image: impl Into) -> Self { - self.sandbox_image = image.into(); - self - } - - /// Create a new configuration with a sandbox image pull policy. - #[must_use] - pub fn with_sandbox_image_pull_policy(mut self, policy: impl Into) -> Self { - self.sandbox_image_pull_policy = policy.into(); - self - } - - /// Create a new configuration with a gRPC endpoint for sandbox callback. 
- #[must_use] - pub fn with_grpc_endpoint(mut self, endpoint: impl Into) -> Self { - self.grpc_endpoint = endpoint.into(); - self - } - - /// Create a new configuration with the SSH gateway host. - #[must_use] - pub fn with_ssh_gateway_host(mut self, host: impl Into) -> Self { - self.ssh_gateway_host = host.into(); - self - } - - /// Create a new configuration with the SSH gateway port. - #[must_use] - pub const fn with_ssh_gateway_port(mut self, port: u16) -> Self { - self.ssh_gateway_port = port; - self - } - - /// Create a new configuration with the SSH connect path. - #[must_use] - pub fn with_ssh_connect_path(mut self, path: impl Into) -> Self { - self.ssh_connect_path = path.into(); - self - } - - /// Create a new configuration with the sandbox SSH port. - #[must_use] - pub const fn with_sandbox_ssh_port(mut self, port: u16) -> Self { - self.sandbox_ssh_port = port; - self - } - - /// Create a new configuration with the SSH handshake secret. - #[must_use] - pub fn with_ssh_handshake_secret(mut self, secret: impl Into) -> Self { - self.ssh_handshake_secret = secret.into(); - self - } - - /// Create a new configuration with SSH handshake skew allowance. - #[must_use] - pub const fn with_ssh_handshake_skew_secs(mut self, secs: u64) -> Self { - self.ssh_handshake_skew_secs = secs; - self - } - /// Create a new configuration with the SSH session TTL. #[must_use] pub const fn with_ssh_session_ttl_secs(mut self, secs: u64) -> Self { @@ -495,66 +483,108 @@ impl Config { self } - /// Set the Kubernetes secret name for sandbox client TLS materials. + /// Set the OIDC configuration for JWT-based authentication. #[must_use] - pub fn with_client_tls_secret_name(mut self, name: impl Into) -> Self { - self.client_tls_secret_name = name.into(); + pub fn with_oidc(mut self, oidc: OidcConfig) -> Self { + self.oidc = Some(oidc); self } - /// Set the host gateway IP for sandbox pod hostAliases. + /// Derive browser-facing sandbox service domains from gateway server SANs. 
+ /// + /// Wildcard DNS SANs such as `*.apps.example.com` enable service URLs + /// under `apps.example.com`. Non-wildcard DNS names and IP SANs do not + /// enable service subdomains. #[must_use] - pub fn with_host_gateway_ip(mut self, ip: impl Into) -> Self { - self.host_gateway_ip = ip.into(); + pub fn with_server_sans(mut self, sans: I) -> Self + where + I: IntoIterator, + S: Into, + { + self.service_routing.base_domains = service_routing_domains_from_server_sans(sans); self } - /// Set the OIDC configuration for JWT-based authentication. + /// Enable or disable plaintext HTTP routing for loopback sandbox service + /// hostnames on TLS-enabled gateway listeners. #[must_use] - pub fn with_oidc(mut self, oidc: OidcConfig) -> Self { - self.oidc = Some(oidc); + pub const fn with_loopback_service_http(mut self, enabled: bool) -> Self { + self.service_routing.enable_loopback_service_http = enabled; self } } -fn default_bind_address() -> SocketAddr { - "127.0.0.1:8080".parse().expect("valid default address") -} - -fn default_log_level() -> String { - "info".to_string() +impl Default for ServiceRoutingConfig { + fn default() -> Self { + Self { + base_domains: default_service_routing_domains(), + enable_loopback_service_http: default_enable_loopback_service_http(), + } + } } -fn default_sandbox_namespace() -> String { - "default".to_string() +fn default_bind_address() -> SocketAddr { + "127.0.0.1:17670".parse().expect("valid default address") } -fn default_sandbox_image() -> String { - format!("{}/base:latest", crate::image::DEFAULT_COMMUNITY_REGISTRY) +fn default_service_routing_domains() -> Vec { + vec![DEFAULT_SERVICE_ROUTING_DOMAIN.to_string()] } -fn default_ssh_gateway_host() -> String { - "127.0.0.1".to_string() +const fn default_enable_loopback_service_http() -> bool { + true } -const fn default_ssh_gateway_port() -> u16 { - DEFAULT_SERVER_PORT +fn service_routing_domains_from_server_sans(sans: I) -> Vec +where + I: IntoIterator, + S: Into, +{ + let mut domains 
= Vec::new(); + for san in sans { + if let Some(domain) = service_routing_domain_from_server_san(&san.into()) + && !domains.contains(&domain) + { + domains.push(domain); + } + } + for domain in default_service_routing_domains() { + if !domains.contains(&domain) { + domains.push(domain); + } + } + domains } -fn default_ssh_connect_path() -> String { - "/connect/ssh".to_string() +fn service_routing_domain_from_server_san(san: &str) -> Option { + let san = san.trim().trim_matches('.').to_ascii_lowercase(); + let domain = san.strip_prefix("*.")?; + normalize_service_routing_domain(domain) } -fn default_sandbox_ssh_socket_path() -> String { - "/run/openshell/ssh.sock".to_string() +fn normalize_service_routing_domain(domain: &str) -> Option { + let domain = domain.trim().trim_matches('.'); + if domain.is_empty() || domain.len() > 253 { + return None; + } + let labels = domain.split('.'); + if labels.clone().any(|label| !is_dns_label(label)) { + return None; + } + Some(domain.to_string()) } -const fn default_sandbox_ssh_port() -> u16 { - DEFAULT_SSH_PORT +fn is_dns_label(label: &str) -> bool { + if label.is_empty() || label.len() > 63 || label.starts_with('-') || label.ends_with('-') { + return false; + } + label + .bytes() + .all(|byte| byte.is_ascii_lowercase() || byte.is_ascii_digit() || byte == b'-') } -const fn default_ssh_handshake_skew_secs() -> u64 { - DEFAULT_SSH_HANDSHAKE_SKEW_SECS +fn default_log_level() -> String { + "info".to_string() } const fn default_ssh_session_ttl_secs() -> u64 { @@ -563,8 +593,14 @@ const fn default_ssh_session_ttl_secs() -> u64 { #[cfg(test)] mod tests { - use super::{ComputeDriverKind, Config, detect_driver}; + use super::{ + ComputeDriverKind, Config, DEFAULT_SERVICE_ROUTING_DOMAIN, detect_driver, + docker_host_unix_socket_path, is_unix_socket, + }; use std::net::SocketAddr; + #[cfg(unix)] + use std::os::unix::net::UnixListener; + use std::path::PathBuf; #[test] fn compute_driver_kind_parses_supported_values() { @@ -594,7 +630,7 @@ 
mod tests { #[test] fn config_defaults_to_loopback_bind_address() { - let expected: SocketAddr = "127.0.0.1:8080".parse().expect("valid address"); + let expected: SocketAddr = "127.0.0.1:17670".parse().expect("valid address"); assert_eq!(Config::new(None).bind_address, expected); } @@ -604,6 +640,58 @@ mod tests { assert!(cfg.health_bind_address.is_none()); } + #[test] + fn config_disables_unauthenticated_users_by_default() { + let cfg = Config::new(None); + assert!(!cfg.auth.allow_unauthenticated_users); + } + + #[test] + fn service_routing_allows_loopback_plaintext_http_by_default() { + let cfg = Config::new(None); + assert_eq!( + cfg.service_routing.base_domains, + vec![DEFAULT_SERVICE_ROUTING_DOMAIN.to_string()] + ); + assert!(cfg.service_routing.enable_loopback_service_http); + } + + #[test] + fn server_sans_update_preserves_loopback_plaintext_http_flag() { + let cfg = Config::new(None) + .with_loopback_service_http(false) + .with_server_sans(["*.dev.openshell.localhost"]); + + assert_eq!( + cfg.service_routing.base_domains, + vec![ + "dev.openshell.localhost".to_string(), + DEFAULT_SERVICE_ROUTING_DOMAIN.to_string() + ] + ); + assert!(!cfg.service_routing.enable_loopback_service_http); + } + + #[test] + fn service_routing_domains_are_derived_from_wildcard_server_sans() { + let cfg = Config::new(None).with_server_sans([ + "gateway.example.com", + "*.apps.example.com", + "127.0.0.1", + "*.apps.example.com", + "*.dev.example.com.", + ]); + + assert_eq!( + cfg.service_routing.base_domains, + vec![ + "apps.example.com".to_string(), + "dev.example.com".to_string(), + DEFAULT_SERVICE_ROUTING_DOMAIN.to_string(), + ] + ); + } + #[test] fn config_with_health_bind_address_sets_address() { let addr: SocketAddr = "0.0.0.0:9090".parse().expect("valid address"); @@ -614,12 +702,36 @@ mod tests { #[test] fn detect_driver_returns_none_without_k8s_env_or_binaries() { // When KUBERNETES_SERVICE_HOST is not set and no docker/podman binaries - // are available, detect_driver 
should return None. + // or Docker socket are available, detect_driver should return None. // This test may pass or fail depending on the test environment, // but it documents the expected behavior. let _ = detect_driver(); // Returns Some or None based on environment } + #[test] + fn docker_host_unix_socket_path_parses_unix_hosts() { + assert_eq!( + docker_host_unix_socket_path("unix:///var/run/docker.sock"), + Some(PathBuf::from("/var/run/docker.sock")) + ); + assert_eq!(docker_host_unix_socket_path("tcp://127.0.0.1:2375"), None); + assert_eq!(docker_host_unix_socket_path("unix://"), None); + } + + #[cfg(unix)] + #[test] + fn is_unix_socket_detects_socket_files() { + let temp_dir = tempfile::tempdir().expect("create temp dir"); + let socket_path = temp_dir.path().join("docker.sock"); + let _listener = UnixListener::bind(&socket_path).expect("bind unix socket"); + + assert!(is_unix_socket(&socket_path)); + + let regular_file = temp_dir.path().join("not-a-socket"); + std::fs::write(&regular_file, b"not a socket").expect("write regular file"); + assert!(!is_unix_socket(&regular_file)); + } + #[test] #[allow(unsafe_code)] // std::env::set_var/remove_var require unsafe in Rust 2024 fn detect_driver_prefers_kubernetes_when_k8s_env_is_set() { diff --git a/crates/openshell-core/src/driver_utils.rs b/crates/openshell-core/src/driver_utils.rs new file mode 100644 index 000000000..4438fbbe4 --- /dev/null +++ b/crates/openshell-core/src/driver_utils.rs @@ -0,0 +1,95 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Utility helpers shared across compute-driver crates. 
+ +use std::path::PathBuf; + +use crate::proto::compute::v1::{DriverSandbox, GetCapabilitiesResponse}; + +// --------------------------------------------------------------------------- +// Sandbox container/pod label keys (openshell.ai/ namespace) +// --------------------------------------------------------------------------- + +/// Container/pod label that identifies this resource as managed by `OpenShell`. +/// Value should be `"openshell"`. +pub const LABEL_MANAGED_BY: &str = "openshell.ai/managed-by"; + +/// Expected value for [`LABEL_MANAGED_BY`]. +pub const LABEL_MANAGED_BY_VALUE: &str = "openshell"; + +/// Container/pod label carrying the sandbox ID. +pub const LABEL_SANDBOX_ID: &str = "openshell.ai/sandbox-id"; + +/// Container/pod label carrying the sandbox name. +pub const LABEL_SANDBOX_NAME: &str = "openshell.ai/sandbox-name"; + +/// Container/pod label carrying the sandbox namespace. +pub const LABEL_SANDBOX_NAMESPACE: &str = "openshell.ai/sandbox-namespace"; + +// --------------------------------------------------------------------------- + +/// Path to the sandbox supervisor binary inside the container image. +/// +/// All compute drivers must launch this binary as the container entrypoint to +/// start the sandboxed environment. The value must be kept in sync with the +/// path used when building the `openshell-sandbox` image layer. +pub const SUPERVISOR_IMAGE_BINARY_PATH: &str = "/openshell-sandbox"; + +/// Return the XDG state path for a driver's sandbox JWT token file. +/// +/// The resulting path is `$XDG_STATE_HOME/openshell/[/]//sandbox.jwt`. +/// +/// `driver_subdir` is driver-specific, e.g. `"docker-sandbox-tokens"` or +/// `"podman-sandbox-tokens"`. When `namespace` is `Some`, it is appended as +/// an additional path component (with `/` and `\` replaced by `-`). +/// +/// # Errors +/// Returns an error if the XDG state directory cannot be resolved. 
+pub fn sandbox_token_path( + driver_subdir: &str, + namespace: Option<&str>, + sandbox_id: &str, +) -> miette::Result { + let mut path = crate::paths::xdg_state_dir()? + .join("openshell") + .join(driver_subdir); + if let Some(ns) = namespace { + path = path.join(ns.replace(['/', '\\'], "-")); + } + Ok(path.join(sandbox_id).join("sandbox.jwt")) +} + +/// Build a [`GetCapabilitiesResponse`] from the common driver capability fields. +/// +/// Every compute driver constructs this response with the same fields. Shared +/// here to avoid repeating the struct literal (and the always-zero `gpu_count` +/// default) in each driver crate. +pub fn build_capabilities_response( + driver_name: &str, + driver_version: impl Into, + default_image: impl Into, + supports_gpu: bool, +) -> GetCapabilitiesResponse { + GetCapabilitiesResponse { + driver_name: driver_name.to_string(), + driver_version: driver_version.into(), + default_image: default_image.into(), + supports_gpu, + gpu_count: 0, + } +} + +/// Return the effective log level for a sandbox. +/// +/// Uses the level from the sandbox spec when non-empty, falling back to +/// `default_level` otherwise. 
+pub fn sandbox_log_level(sandbox: &DriverSandbox, default_level: &str) -> String { + sandbox + .spec + .as_ref() + .map(|spec| spec.log_level.as_str()) + .filter(|level| !level.is_empty()) + .unwrap_or(default_level) + .to_string() +} diff --git a/crates/openshell-core/src/error.rs b/crates/openshell-core/src/error.rs index 7c33c9eaf..6f04ebece 100644 --- a/crates/openshell-core/src/error.rs +++ b/crates/openshell-core/src/error.rs @@ -120,3 +120,13 @@ pub enum ComputeDriverError { #[error("{0}")] Message(String), } + +impl From for tonic::Status { + fn from(err: ComputeDriverError) -> Self { + match err { + ComputeDriverError::AlreadyExists => Self::already_exists("sandbox already exists"), + ComputeDriverError::Precondition(m) => Self::failed_precondition(m), + ComputeDriverError::Message(m) => Self::internal(m), + } + } +} diff --git a/crates/openshell-core/src/forward.rs b/crates/openshell-core/src/forward.rs index b48e5594a..a5e373e61 100644 --- a/crates/openshell-core/src/forward.rs +++ b/crates/openshell-core/src/forward.rs @@ -462,13 +462,33 @@ pub fn resolve_ssh_gateway( // Remote cluster: use the remote host but keep the cluster URL port. return (host.to_string(), cluster_port); } - // Local cluster: both loopback — use cluster URL's port (Docker-mapped). + // Both endpoints loopback. The unspecified addresses (0.0.0.0 / ::) + // are bind-only — they aren't valid connect targets and aren't in TLS + // cert SANs, so fall back to the cluster URL's host (which the CLI + // is already using to reach the gateway). + if gateway_host == "0.0.0.0" || gateway_host == "::" { + return (host.to_string(), cluster_port); + } return (gateway_host.to_string(), cluster_port); } (gateway_host.to_string(), gateway_port) } +/// Format a gateway URL, bracketing IPv6 literals when needed. 
+pub fn format_gateway_url(scheme: &str, host: &str, port: u16) -> String { + let host = if host + .parse::() + .is_ok_and(|ip| ip.is_ipv6()) + && !host.starts_with('[') + { + format!("[{host}]") + } else { + host.to_string() + }; + format!("{scheme}://{host}:{port}") +} + /// Shell-escape a value for use inside a `ProxyCommand` string. pub fn shell_escape(value: &str) -> String { if value.is_empty() { @@ -525,14 +545,11 @@ pub enum SshSessionResponseError { InvalidScheme, #[error("gateway_port must be in range 1..=65535")] InvalidPort, - #[error("connect_path must start with '/'")] - ConnectPathNotAbsolute, } const MAX_SANDBOX_ID_LEN: usize = 128; const MAX_TOKEN_LEN: usize = 4096; const MAX_GATEWAY_HOST_LEN: usize = 253; -const MAX_CONNECT_PATH_LEN: usize = 2048; const MAX_FINGERPRINT_LEN: usize = 256; fn is_sandbox_id_byte(b: u8) -> bool { @@ -551,33 +568,6 @@ fn is_gateway_host_byte(b: u8) -> bool { b.is_ascii_alphanumeric() || matches!(b, b'.' | b'-' | b':' | b'[' | b']') } -fn is_connect_path_byte(b: u8) -> bool { - // RFC 3986 path charset (pchar) without `?`, `#`, space, backtick, or - // backslash. `%` is permitted so percent-encoded segments round-trip. - b.is_ascii_alphanumeric() - || matches!( - b, - b'-' | b'.' - | b'_' - | b'~' - | b'!' 
- | b'$' - | b'&' - | b'\'' - | b'(' - | b')' - | b'*' - | b'+' - | b',' - | b';' - | b'=' - | b':' - | b'@' - | b'/' - | b'%' - ) -} - fn is_fingerprint_byte(b: u8) -> bool { b.is_ascii_alphanumeric() || matches!(b, b':' | b'+' | b'/' | b'=' | b'-') } @@ -612,25 +602,6 @@ pub fn validate_ssh_session_response( if resp.gateway_port == 0 || resp.gateway_port > u32::from(u16::MAX) { return Err(SshSessionResponseError::InvalidPort); } - if resp.connect_path.is_empty() { - return Err(SshSessionResponseError::Empty { - field: "connect_path", - }); - } - if !resp.connect_path.starts_with('/') { - return Err(SshSessionResponseError::ConnectPathNotAbsolute); - } - if resp.connect_path.len() > MAX_CONNECT_PATH_LEN { - return Err(SshSessionResponseError::TooLong { - field: "connect_path", - max: MAX_CONNECT_PATH_LEN, - }); - } - if !resp.connect_path.bytes().all(is_connect_path_byte) { - return Err(SshSessionResponseError::InvalidChars { - field: "connect_path", - }); - } if !resp.host_key_fingerprint.is_empty() { if resp.host_key_fingerprint.len() > MAX_FINGERPRINT_LEN { return Err(SshSessionResponseError::TooLong { @@ -728,6 +699,16 @@ mod tests { assert_eq!(port, 8080); } + #[test] + fn resolve_ssh_gateway_swaps_zeros_for_loopback_cluster_host() { + // The gateway binds 0.0.0.0 but advertises that bind address via the + // SSH session response. 0.0.0.0 is not a valid connect target and is + // not in any TLS cert SAN; fall through to the cluster URL's host. 
+ let (host, port) = resolve_ssh_gateway("0.0.0.0", 8080, "https://127.0.0.1:9000"); + assert_eq!(host, "127.0.0.1"); + assert_eq!(port, 9000); + } + #[test] fn resolve_ssh_gateway_handles_invalid_cluster_url() { let (host, port) = resolve_ssh_gateway("127.0.0.1", 8080, "not-a-url"); @@ -735,6 +716,26 @@ mod tests { assert_eq!(port, 8080); } + #[test] + fn format_gateway_url_brackets_ipv6_literals() { + assert_eq!( + format_gateway_url("https", "::1", 8080), + "https://[::1]:8080" + ); + } + + #[test] + fn format_gateway_url_leaves_dns_and_bracketed_ipv6_unchanged() { + assert_eq!( + format_gateway_url("https", "gateway.example.com", 443), + "https://gateway.example.com:443" + ); + assert_eq!( + format_gateway_url("https", "[::1]", 8080), + "https://[::1]:8080" + ); + } + #[test] fn shell_escape_empty() { assert_eq!(shell_escape(""), "''"); @@ -757,7 +758,6 @@ mod tests { gateway_scheme: "https".to_string(), gateway_host: "gateway.example.com".to_string(), gateway_port: 443, - connect_path: "/connect/ssh".to_string(), host_key_fingerprint: String::new(), expires_at_ms: 0, } @@ -857,33 +857,6 @@ mod tests { } } - #[test] - fn validate_ssh_session_response_rejects_connect_path_without_leading_slash() { - let mut r = valid_session_response(); - r.connect_path = "connect/ssh".to_string(); - assert!(matches!( - validate_ssh_session_response(&r), - Err(SshSessionResponseError::ConnectPathNotAbsolute) - )); - } - - #[test] - fn validate_ssh_session_response_rejects_injected_connect_path() { - // `$`, `(`, `)` are valid RFC 3986 sub-delims (pchar) so the validator - // permits them; shell_escape is the second defensive layer. The - // following characters are rejected at the validator boundary because - // they are either unambiguously hostile in a shell context or invalid - // per RFC 3986 in the path component. 
- for bad in ["/x`id`y", "/x y", "/x\nb", "/x\\b", "/x?q=1", "/x#frag"] { - let mut r = valid_session_response(); - r.connect_path = bad.to_string(); - assert!( - validate_ssh_session_response(&r).is_err(), - "expected reject for connect_path={bad:?}" - ); - } - } - #[test] fn build_proxy_command_escapes_shell_metacharacters() { // Attacker-controlled values in every escapable position. diff --git a/crates/openshell-core/src/gpu.rs b/crates/openshell-core/src/gpu.rs new file mode 100644 index 000000000..5df8702ed --- /dev/null +++ b/crates/openshell-core/src/gpu.rs @@ -0,0 +1,48 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Shared GPU request helpers. + +use crate::config::CDI_GPU_DEVICE_ALL; + +/// Resolve the existing GPU request fields into CDI device identifiers. +/// +/// `None` means no GPU was requested. A GPU request with no explicit device +/// ID uses the CDI all-GPU request; otherwise the driver-native ID passes +/// through unchanged. 
+#[must_use] +pub fn cdi_gpu_device_ids(gpu: bool, gpu_device: &str) -> Option<Vec<String>> { + gpu.then(|| { + if gpu_device.is_empty() { + vec![CDI_GPU_DEVICE_ALL.to_string()] + } else { + vec![gpu_device.to_string()] + } + }) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn cdi_gpu_device_ids_returns_none_when_absent() { + assert_eq!(cdi_gpu_device_ids(false, ""), None); + } + + #[test] + fn cdi_gpu_device_ids_defaults_empty_request_to_all_gpus() { + assert_eq!( + cdi_gpu_device_ids(true, ""), + Some(vec![CDI_GPU_DEVICE_ALL.to_string()]) + ); + } + + #[test] + fn cdi_gpu_device_ids_passes_explicit_device_id_through() { + assert_eq!( + cdi_gpu_device_ids(true, "nvidia.com/gpu=0"), + Some(vec!["nvidia.com/gpu=0".to_string()]) + ); + } +} diff --git a/crates/openshell-core/src/image.rs b/crates/openshell-core/src/image.rs index 6a628e2a9..e804afd60 100644 --- a/crates/openshell-core/src/image.rs +++ b/crates/openshell-core/src/image.rs @@ -13,6 +13,15 @@ /// Override at runtime with the `OPENSHELL_COMMUNITY_REGISTRY` env var. pub const DEFAULT_COMMUNITY_REGISTRY: &str = "ghcr.io/nvidia/openshell-community/sandboxes"; +/// Return the default sandbox image reference (`{registry}/base:latest`). +/// +/// Used by all compute drivers as the fallback image when none is specified in +/// the sandbox spec. +#[must_use] +pub fn default_sandbox_image() -> String { + format!("{DEFAULT_COMMUNITY_REGISTRY}/base:latest") +} + /// Resolve a user-supplied image string into a fully-qualified reference. /// /// Resolution rules (applied in order): diff --git a/crates/openshell-core/src/lib.rs b/crates/openshell-core/src/lib.rs index a4a1ea822..2c003f38c 100644 --- a/crates/openshell-core/src/lib.rs +++ b/crates/openshell-core/src/lib.rs @@ -9,20 +9,29 @@ //! - Common error types //! 
- Build version metadata +pub mod auth; pub mod config; +pub mod driver_utils; pub mod error; pub mod forward; +pub mod gpu; pub mod image; pub mod inference; pub mod metadata; pub mod net; pub mod paths; +pub mod progress; pub mod proto; +pub mod sandbox_env; pub mod settings; +pub mod time; -pub use config::{ComputeDriverKind, Config, OidcConfig, TlsConfig}; +pub use config::{ + ComputeDriverKind, Config, GatewayAuthConfig, GatewayJwtConfig, MtlsAuthConfig, OidcConfig, + TlsConfig, +}; pub use error::{ComputeDriverError, Error, Result}; -pub use metadata::{ObjectId, ObjectLabels, ObjectName}; +pub use metadata::{GetResourceVersion, ObjectId, ObjectLabels, ObjectName, SetResourceVersion}; /// Build version string derived from git metadata. /// diff --git a/crates/openshell-core/src/metadata.rs b/crates/openshell-core/src/metadata.rs index e7ffea61a..78533e1e0 100644 --- a/crates/openshell-core/src/metadata.rs +++ b/crates/openshell-core/src/metadata.rs @@ -6,7 +6,8 @@ //! These traits provide uniform access to `ObjectMeta` fields across all resource types. use crate::proto::{ - InferenceRoute, ObjectForTest, Provider, Sandbox, SshSession, StoredProviderProfile, + InferenceRoute, ObjectForTest, Provider, Sandbox, ServiceEndpoint, SshSession, + StoredProviderCredentialRefreshState, StoredProviderProfile, }; use std::collections::HashMap; @@ -25,6 +26,16 @@ pub trait ObjectLabels { fn object_labels(&self) -> Option>; } +/// Provides mutable access to set the object's resource version from persistence. +pub trait SetResourceVersion { + fn set_resource_version(&mut self, version: u64); +} + +/// Provides read access to the object's current resource version. 
+pub trait GetResourceVersion { + fn get_resource_version(&self) -> u64; +} + // Implementations for Sandbox impl ObjectId for Sandbox { fn object_id(&self) -> &str { @@ -44,6 +55,20 @@ impl ObjectLabels for Sandbox { } } +impl SetResourceVersion for Sandbox { + fn set_resource_version(&mut self, version: u64) { + if let Some(meta) = self.metadata.as_mut() { + meta.resource_version = version; + } + } +} + +impl GetResourceVersion for Sandbox { + fn get_resource_version(&self) -> u64 { + self.metadata.as_ref().map_or(0, |m| m.resource_version) + } +} + // Implementations for Provider impl ObjectId for Provider { fn object_id(&self) -> &str { @@ -63,6 +88,20 @@ impl ObjectLabels for Provider { } } +impl SetResourceVersion for Provider { + fn set_resource_version(&mut self, version: u64) { + if let Some(meta) = self.metadata.as_mut() { + meta.resource_version = version; + } + } +} + +impl GetResourceVersion for Provider { + fn get_resource_version(&self) -> u64 { + self.metadata.as_ref().map_or(0, |m| m.resource_version) + } +} + // Implementations for StoredProviderProfile impl ObjectId for StoredProviderProfile { fn object_id(&self) -> &str { @@ -82,6 +121,53 @@ impl ObjectLabels for StoredProviderProfile { } } +impl SetResourceVersion for StoredProviderProfile { + fn set_resource_version(&mut self, version: u64) { + if let Some(meta) = self.metadata.as_mut() { + meta.resource_version = version; + } + } +} + +impl GetResourceVersion for StoredProviderProfile { + fn get_resource_version(&self) -> u64 { + self.metadata.as_ref().map_or(0, |m| m.resource_version) + } +} + +// Implementations for StoredProviderCredentialRefreshState +impl ObjectId for StoredProviderCredentialRefreshState { + fn object_id(&self) -> &str { + self.metadata.as_ref().map_or("", |m| m.id.as_str()) + } +} + +impl ObjectName for StoredProviderCredentialRefreshState { + fn object_name(&self) -> &str { + self.metadata.as_ref().map_or("", |m| m.name.as_str()) + } +} + +impl ObjectLabels for 
StoredProviderCredentialRefreshState { + fn object_labels(&self) -> Option> { + self.metadata.as_ref().map(|m| m.labels.clone()) + } +} + +impl SetResourceVersion for StoredProviderCredentialRefreshState { + fn set_resource_version(&mut self, version: u64) { + if let Some(meta) = self.metadata.as_mut() { + meta.resource_version = version; + } + } +} + +impl GetResourceVersion for StoredProviderCredentialRefreshState { + fn get_resource_version(&self) -> u64 { + self.metadata.as_ref().map_or(0, |m| m.resource_version) + } +} + // Implementations for SshSession impl ObjectId for SshSession { fn object_id(&self) -> &str { @@ -101,6 +187,53 @@ impl ObjectLabels for SshSession { } } +impl SetResourceVersion for SshSession { + fn set_resource_version(&mut self, version: u64) { + if let Some(meta) = self.metadata.as_mut() { + meta.resource_version = version; + } + } +} + +impl GetResourceVersion for SshSession { + fn get_resource_version(&self) -> u64 { + self.metadata.as_ref().map_or(0, |m| m.resource_version) + } +} + +// Implementations for ServiceEndpoint +impl ObjectId for ServiceEndpoint { + fn object_id(&self) -> &str { + self.metadata.as_ref().map_or("", |m| m.id.as_str()) + } +} + +impl ObjectName for ServiceEndpoint { + fn object_name(&self) -> &str { + self.metadata.as_ref().map_or("", |m| m.name.as_str()) + } +} + +impl ObjectLabels for ServiceEndpoint { + fn object_labels(&self) -> Option> { + self.metadata.as_ref().map(|m| m.labels.clone()) + } +} + +impl SetResourceVersion for ServiceEndpoint { + fn set_resource_version(&mut self, version: u64) { + if let Some(meta) = self.metadata.as_mut() { + meta.resource_version = version; + } + } +} + +impl GetResourceVersion for ServiceEndpoint { + fn get_resource_version(&self) -> u64 { + self.metadata.as_ref().map_or(0, |m| m.resource_version) + } +} + // Implementations for InferenceRoute impl ObjectId for InferenceRoute { fn object_id(&self) -> &str { @@ -120,6 +253,20 @@ impl ObjectLabels for InferenceRoute { } } 
+impl SetResourceVersion for InferenceRoute { + fn set_resource_version(&mut self, version: u64) { + if let Some(meta) = self.metadata.as_mut() { + meta.resource_version = version; + } + } +} + +impl GetResourceVersion for InferenceRoute { + fn get_resource_version(&self) -> u64 { + self.metadata.as_ref().map_or(0, |m| m.resource_version) + } +} + // Implementations for ObjectForTest (test-only proto type) impl ObjectId for ObjectForTest { fn object_id(&self) -> &str { @@ -138,3 +285,16 @@ impl ObjectLabels for ObjectForTest { None } } + +impl SetResourceVersion for ObjectForTest { + fn set_resource_version(&mut self, _version: u64) { + // ObjectForTest doesn't have metadata, so this is a no-op + } +} + +impl GetResourceVersion for ObjectForTest { + fn get_resource_version(&self) -> u64 { + // ObjectForTest doesn't have metadata + 0 + } +} diff --git a/crates/openshell-core/src/net.rs b/crates/openshell-core/src/net.rs index 5dca4feb6..0e2654fc3 100644 --- a/crates/openshell-core/src/net.rs +++ b/crates/openshell-core/src/net.rs @@ -12,6 +12,31 @@ use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; +/// Check if an IP address is link-local. +/// +/// Covers IPv4 `169.254.0.0/16`, IPv6 `fe80::/10`, and IPv4-mapped IPv6 +/// addresses whose embedded IPv4 address is link-local (`::ffff:169.254.x.x`). +/// +/// This is a point-check helper used to build the always-blocked and +/// trusted-gateway exemption predicates. For CIDR-range overlap checks see +/// [`is_always_blocked_net`]. +pub fn is_link_local_ip(ip: IpAddr) -> bool { + match ip { + IpAddr::V4(v4) => v4.is_link_local(), + IpAddr::V6(v6) => { + // fe80::/10 — IPv6 link-local + if (v6.segments()[0] & 0xffc0) == 0xfe80 { + return true; + } + // ::ffff:169.254.x.x — IPv4-mapped link-local + if let Some(v4) = v6.to_ipv4_mapped() { + return v4.is_link_local(); + } + false + } + } +} + /// Check if an IP address is always blocked regardless of policy. 
/// /// Loopback, link-local, and unspecified addresses are never allowed even when @@ -24,13 +49,12 @@ pub fn is_always_blocked_ip(ip: IpAddr) -> bool { if v6.is_loopback() || v6.is_unspecified() { return true; } - // fe80::/10 — IPv6 link-local - if (v6.segments()[0] & 0xffc0) == 0xfe80 { + if is_link_local_ip(IpAddr::V6(v6)) { return true; } // Check IPv4-mapped IPv6 (::ffff:x.x.x.x) if let Some(v4) = v6.to_ipv4_mapped() { - return v4.is_loopback() || v4.is_link_local() || v4.is_unspecified(); + return v4.is_loopback() || v4.is_unspecified(); } false } @@ -138,8 +162,7 @@ pub fn is_internal_ip(ip: IpAddr) -> bool { if v6.is_loopback() || v6.is_unspecified() { return true; } - // fe80::/10 — IPv6 link-local - if (v6.segments()[0] & 0xffc0) == 0xfe80 { + if is_link_local_ip(IpAddr::V6(v6)) { return true; } // fc00::/7 — IPv6 unique local addresses (ULA) @@ -190,6 +213,69 @@ fn is_internal_v4(v4: Ipv4Addr) -> bool { mod tests { use super::*; + // -- is_link_local_ip -- + + #[test] + fn test_link_local_ip_v4() { + assert!(is_link_local_ip(IpAddr::V4(Ipv4Addr::new(169, 254, 0, 1)))); + assert!(is_link_local_ip(IpAddr::V4(Ipv4Addr::new( + 169, 254, 169, 254 + )))); + assert!(is_link_local_ip(IpAddr::V4(Ipv4Addr::new( + 169, 254, 255, 255 + )))); + } + + #[test] + fn test_link_local_ip_v6_fe80() { + assert!(is_link_local_ip(IpAddr::V6(Ipv6Addr::new( + 0xfe80, 0, 0, 0, 0, 0, 0, 1 + )))); + // Upper boundary of fe80::/10 (febf:...) 
+ assert!(is_link_local_ip(IpAddr::V6(Ipv6Addr::new( + 0xfebf, 0, 0, 0, 0, 0, 0, 1 + )))); + } + + #[test] + fn test_link_local_ip_v6_mapped_v4() { + let mapped = Ipv4Addr::new(169, 254, 1, 2).to_ipv6_mapped(); + assert!(is_link_local_ip(IpAddr::V6(mapped))); + } + + #[test] + fn test_link_local_ip_not_loopback() { + assert!(!is_link_local_ip(IpAddr::V4(Ipv4Addr::LOCALHOST))); + assert!(!is_link_local_ip(IpAddr::V6(Ipv6Addr::LOCALHOST))); + } + + #[test] + fn test_link_local_ip_not_unspecified() { + assert!(!is_link_local_ip(IpAddr::V4(Ipv4Addr::UNSPECIFIED))); + assert!(!is_link_local_ip(IpAddr::V6(Ipv6Addr::UNSPECIFIED))); + } + + #[test] + fn test_link_local_ip_not_rfc1918() { + assert!(!is_link_local_ip(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)))); + assert!(!is_link_local_ip(IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1)))); + assert!(!is_link_local_ip(IpAddr::V4(Ipv4Addr::new(172, 16, 0, 1)))); + } + + #[test] + fn test_link_local_ip_not_public() { + assert!(!is_link_local_ip(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)))); + assert!(!is_link_local_ip(IpAddr::V6(Ipv6Addr::new( + 0x2001, 0x4860, 0x4860, 0, 0, 0, 0, 0x8888 + )))); + } + + #[test] + fn test_link_local_ip_not_v6_mapped_loopback() { + let mapped = Ipv4Addr::LOCALHOST.to_ipv6_mapped(); + assert!(!is_link_local_ip(IpAddr::V6(mapped))); + } + // -- is_always_blocked_ip -- #[test] diff --git a/crates/openshell-core/src/paths.rs b/crates/openshell-core/src/paths.rs index 00104f3c2..65000c6cf 100644 --- a/crates/openshell-core/src/paths.rs +++ b/crates/openshell-core/src/paths.rs @@ -29,6 +29,24 @@ pub fn openshell_config_dir() -> Result { Ok(xdg_config_dir()?.join("openshell")) } +/// Resolve the XDG state base directory. +/// +/// Returns `$XDG_STATE_HOME` if set, otherwise `$HOME/.local/state`. 
+pub fn xdg_state_dir() -> Result<PathBuf> { + if let Ok(path) = std::env::var("XDG_STATE_HOME") { + return Ok(PathBuf::from(path)); + } + let home = std::env::var("HOME") + .into_diagnostic() + .wrap_err("HOME is not set")?; + Ok(PathBuf::from(home).join(".local").join("state")) +} + +/// The top-level `OpenShell` state directory: `$XDG_STATE_HOME/openshell/`. +pub fn openshell_state_dir() -> Result<PathBuf> { + Ok(xdg_state_dir()?.join("openshell")) +} + /// Resolve the XDG data base directory. /// /// Returns `$XDG_DATA_HOME` if set, otherwise `$HOME/.local/share`. @@ -130,6 +148,15 @@ mod tests { ); } + #[test] + fn openshell_state_dir_appends_openshell() { + let dir = openshell_state_dir().unwrap(); + assert!( + dir.ends_with("openshell"), + "expected path ending with 'openshell', got: {dir:?}" + ); + } + #[cfg(unix)] #[test] fn create_dir_restricted_sets_0o700() { diff --git a/crates/openshell-core/src/progress.rs b/crates/openshell-core/src/progress.rs new file mode 100644 index 000000000..793f8e7a1 --- /dev/null +++ b/crates/openshell-core/src/progress.rs @@ -0,0 +1,39 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Shared metadata keys for driver-provided sandbox provisioning progress. 
+ +use std::collections::HashMap; +use std::hash::BuildHasher; + +pub const PROGRESS_COMPLETE_STEP_KEY: &str = "openshell.progress.complete_step"; +pub const PROGRESS_COMPLETE_LABEL_KEY: &str = "openshell.progress.complete_label"; +pub const PROGRESS_ACTIVE_STEP_KEY: &str = "openshell.progress.active_step"; +pub const PROGRESS_ACTIVE_DETAIL_KEY: &str = "openshell.progress.active_detail"; + +pub const PROGRESS_STEP_REQUESTING_SANDBOX: &str = "requesting_sandbox"; +pub const PROGRESS_STEP_PULLING_IMAGE: &str = "pulling_image"; +pub const PROGRESS_STEP_STARTING_SANDBOX: &str = "starting_sandbox"; + +pub fn mark_progress_complete( + metadata: &mut HashMap, + step: &'static str, + label: impl Into, +) { + metadata.insert(PROGRESS_COMPLETE_STEP_KEY.to_string(), step.to_string()); + metadata.insert(PROGRESS_COMPLETE_LABEL_KEY.to_string(), label.into()); +} + +pub fn mark_progress_active( + metadata: &mut HashMap, + step: &'static str, +) { + metadata.insert(PROGRESS_ACTIVE_STEP_KEY.to_string(), step.to_string()); +} + +pub fn mark_progress_detail( + metadata: &mut HashMap, + detail: impl Into, +) { + metadata.insert(PROGRESS_ACTIVE_DETAIL_KEY.to_string(), detail.into()); +} diff --git a/crates/openshell-core/src/sandbox_env.rs b/crates/openshell-core/src/sandbox_env.rs new file mode 100644 index 000000000..b367e450c --- /dev/null +++ b/crates/openshell-core/src/sandbox_env.rs @@ -0,0 +1,55 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Environment-variable names used to configure the sandbox supervisor. +//! +//! These constants are the shared protocol between the compute drivers (which +//! set the variables when launching a sandbox container/VM) and the sandbox +//! supervisor process (which reads them on startup). Using constants here +//! prevents typos from producing silently broken sandboxes. 
+ +/// Name of the sandbox (used for policy sync and identification). +pub const SANDBOX: &str = "OPENSHELL_SANDBOX"; + +/// gRPC endpoint of the `OpenShell` gateway that the sandbox reports to. +pub const ENDPOINT: &str = "OPENSHELL_ENDPOINT"; + +/// Unique identifier of the sandbox being supervised. +pub const SANDBOX_ID: &str = "OPENSHELL_SANDBOX_ID"; + +/// Filesystem path to the UNIX socket used for the in-sandbox SSH server. +pub const SSH_SOCKET_PATH: &str = "OPENSHELL_SSH_SOCKET_PATH"; + +/// Log level for the sandbox supervisor (e.g. `"debug"`, `"info"`, `"warn"`). +pub const LOG_LEVEL: &str = "OPENSHELL_LOG_LEVEL"; + +/// Shell command to run inside the sandbox. +pub const SANDBOX_COMMAND: &str = "OPENSHELL_SANDBOX_COMMAND"; + +/// Path to the CA certificate for mTLS communication with the gateway. +pub const TLS_CA: &str = "OPENSHELL_TLS_CA"; + +/// Path to the client certificate for mTLS communication with the gateway. +pub const TLS_CERT: &str = "OPENSHELL_TLS_CERT"; + +/// Path to the private key for mTLS communication with the gateway. +pub const TLS_KEY: &str = "OPENSHELL_TLS_KEY"; + +/// Raw gateway-minted JWT identifying this sandbox. Mutually exclusive with +/// [`SANDBOX_TOKEN_FILE`] / [`K8S_SA_TOKEN_FILE`]; used only by test harnesses +/// that bypass the file-mount path. +pub const SANDBOX_TOKEN: &str = "OPENSHELL_SANDBOX_TOKEN"; + +/// Path to the file holding a gateway-minted sandbox JWT. +/// +/// Set by the Docker, Podman, and VM drivers, which write the token to a +/// bundle file at sandbox-create time. Read once at supervisor startup; +/// the token is held in process memory thereafter. +pub const SANDBOX_TOKEN_FILE: &str = "OPENSHELL_SANDBOX_TOKEN_FILE"; + +/// Path to the projected `ServiceAccount` JWT (Kubernetes driver). +/// +/// Used to bootstrap a gateway-minted JWT via `IssueSandboxToken`. Kubelet +/// writes and rotates this file; the supervisor exchanges its contents +/// for a gateway JWT at startup and on refresh. 
+pub const K8S_SA_TOKEN_FILE: &str = "OPENSHELL_K8S_SA_TOKEN_FILE"; diff --git a/crates/openshell-core/src/settings.rs b/crates/openshell-core/src/settings.rs index 2765ebeda..897317a5a 100644 --- a/crates/openshell-core/src/settings.rs +++ b/crates/openshell-core/src/settings.rs @@ -50,6 +50,15 @@ pub struct RegisteredSetting { /// 5. Add a unit test in this module's `tests` section to cover the new key. pub const PROVIDERS_V2_ENABLED_KEY: &str = "providers_v2_enabled"; +/// Sandbox-level opt-in for the agent-driven policy proposal surface. +/// +/// When true, the supervisor installs the `policy_advisor` skill, serves +/// the `policy.local` API routes, and includes `next_steps` in L7 deny +/// bodies. See `crates/openshell-sandbox/src/policy_local.rs`. Defaults to +/// false. Independent of the per-proposal developer approval gate, which +/// still applies when this flag is on. +pub const AGENT_POLICY_PROPOSALS_ENABLED_KEY: &str = "agent_policy_proposals_enabled"; + pub const REGISTERED_SETTINGS: &[RegisteredSetting] = &[ // Gateway-level opt-in for provider profile policy composition. Defaults // to false when unset. @@ -64,6 +73,12 @@ pub const REGISTERED_SETTINGS: &[RegisteredSetting] = &[ key: "ocsf_json_enabled", kind: SettingValueKind::Bool, }, + // Sandbox-level opt-in for the agent-driven policy proposal surface. + // See AGENT_POLICY_PROPOSALS_ENABLED_KEY for details. Defaults to false. + RegisteredSetting { + key: AGENT_POLICY_PROPOSALS_ENABLED_KEY, + kind: SettingValueKind::Bool, + }, // Test-only keys live behind the `dev-settings` feature flag so they // don't appear in production builds. #[cfg(feature = "dev-settings")] diff --git a/crates/openshell-core/src/time.rs b/crates/openshell-core/src/time.rs new file mode 100644 index 000000000..15dc0c40d --- /dev/null +++ b/crates/openshell-core/src/time.rs @@ -0,0 +1,16 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +//! Time utilities shared across `OpenShell` crates. + +use std::time::{SystemTime, UNIX_EPOCH}; + +/// Return the current Unix timestamp in milliseconds, saturating to [`i64::MAX`] +/// on overflow. Returns `0` if the system clock is before the Unix epoch. +/// +/// Prefer this over local implementations of the same pattern. +pub fn now_ms() -> i64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .map_or(0, |d| i64::try_from(d.as_millis()).unwrap_or(i64::MAX)) +} diff --git a/crates/openshell-driver-docker/Cargo.toml b/crates/openshell-driver-docker/Cargo.toml index 79d4fb37d..e2c97532a 100644 --- a/crates/openshell-driver-docker/Cargo.toml +++ b/crates/openshell-driver-docker/Cargo.toml @@ -19,6 +19,7 @@ futures = { workspace = true } tokio-stream = { workspace = true } tracing = { workspace = true } bytes = { workspace = true } +serde = { workspace = true } bollard = { version = "0.20" } tar = "0.4" tempfile = "3" diff --git a/crates/openshell-driver-docker/README.md b/crates/openshell-driver-docker/README.md index 7bc8048b2..96caa7b12 100644 --- a/crates/openshell-driver-docker/README.md +++ b/crates/openshell-driver-docker/README.md @@ -12,11 +12,12 @@ The gateway runs as a host process. The Docker driver creates one container per sandbox and starts the `openshell-sandbox` supervisor inside that container. The supervisor then creates the nested sandbox namespace for the agent process. -Docker containers currently use host networking. This lets a supervisor reach a -gateway bound to `127.0.0.1` without requiring a separate bridge listener, NAT -rule, or userland proxy. The container also receives -`host.openshell.internal -> 127.0.0.1` so local host services have a stable -OpenShell-owned name. +Docker containers join an OpenShell-managed bridge network. The driver injects +`host.openshell.internal` and `host.docker.internal` so supervisors have stable +names for reaching the gateway host. 
On Docker Desktop, Colima, Rancher +Desktop, OrbStack, and macOS-hosted gateways, those names use Docker's +`host-gateway` alias. On native Linux Docker, the gateway also binds the bridge +gateway IP so containers can call back to the host process. ## Container Contract @@ -26,19 +27,37 @@ contract: | Setting | Purpose | |---|---| | `user = "0"` | The supervisor needs root inside the container to prepare namespaces, mounts, Landlock, and seccomp. | -| `network_mode = "host"` | Lets the supervisor call back to loopback gateway endpoints. | +| `network_mode = openshell` | Places the supervisor on the managed Docker bridge network. | | `cap_add` | Grants supervisor-only capabilities required for namespace setup and process inspection. | | `apparmor=unconfined` | Avoids Docker's default profile blocking required mount operations. | | `restart_policy = unless-stopped` | Keeps managed sandboxes resumable across daemon or gateway restarts. | -| CDI GPU request | Requests all NVIDIA GPUs when the sandbox spec asks for GPU support and daemon CDI support is detected. | +| CDI GPU request | Uses the sandbox `gpu_device` value when set; otherwise requests all NVIDIA GPUs when the sandbox spec asks for GPU support and daemon CDI support is detected. | The agent child process does not retain these supervisor privileges. +## Supervisor Binary Resolution + +The Docker driver bind-mounts a host-side Linux `openshell-sandbox` binary into +each sandbox container. Resolution order is: + +1. `supervisor_bin` in `[openshell.drivers.docker]`. +2. A sibling `openshell-sandbox` next to the running `openshell-gateway` binary. +3. A local Linux cargo target build for the Docker daemon architecture. +4. `supervisor_image` in `[openshell.drivers.docker]`, or the + release-matched default supervisor image, extracting `/openshell-sandbox`. + +Release and Docker-image gateway builds bake the matching supervisor image tag +into the binary at compile time. 
The default Docker supervisor image is not +`:latest` unless a custom build explicitly sets that tag. + ## Callback and TLS -`OPENSHELL_ENDPOINT` is injected from the gateway's configured gRPC endpoint -without rewriting. Because the container uses host networking, loopback -endpoints such as `http://127.0.0.1:8080` resolve to the host gateway. +`OPENSHELL_ENDPOINT` is injected from the gateway's configured gRPC endpoint. +When no endpoint is configured, the driver uses +`host.openshell.internal:` with the appropriate HTTP or HTTPS +scheme. Set `host_gateway_ip` only when the host has an explicit, locally +assigned address that containers should use for callbacks; package-managed +macOS gateways should leave it unset. For HTTPS endpoints, the server certificate must include the endpoint host as a subject alternative name. Docker sandboxes also need the client TLS bundle diff --git a/crates/openshell-driver-docker/src/lib.rs b/crates/openshell-driver-docker/src/lib.rs index 0eaef3bce..1fdcea9cd 100644 --- a/crates/openshell-driver-docker/src/lib.rs +++ b/crates/openshell-driver-docker/src/lib.rs @@ -9,8 +9,8 @@ use bollard::Docker; use bollard::errors::Error as BollardError; use bollard::models::{ ContainerCreateBody, ContainerSummary, ContainerSummaryStateEnum, DeviceRequest, - EndpointSettings, HostConfig, Mount, MountTypeEnum, NetworkCreateRequest, NetworkingConfig, - RestartPolicy, RestartPolicyNameEnum, SystemInfo, + EndpointSettings, HostConfig, NetworkCreateRequest, NetworkingConfig, RestartPolicy, + RestartPolicyNameEnum, SystemInfo, }; use bollard::query_parameters::{ CreateContainerOptionsBuilder, CreateImageOptions, DownloadFromContainerOptionsBuilder, @@ -18,9 +18,12 @@ use bollard::query_parameters::{ }; use bytes::Bytes; use futures::{Stream, StreamExt}; -use openshell_core::config::{ - CDI_GPU_DEVICE_ALL, DEFAULT_DOCKER_NETWORK_NAME, DEFAULT_STOP_TIMEOUT_SECS, +use openshell_core::config::{DEFAULT_DOCKER_NETWORK_NAME, DEFAULT_STOP_TIMEOUT_SECS}; +use 
openshell_core::driver_utils::{ + LABEL_MANAGED_BY, LABEL_MANAGED_BY_VALUE, LABEL_SANDBOX_ID, LABEL_SANDBOX_NAME, + LABEL_SANDBOX_NAMESPACE, SUPERVISOR_IMAGE_BINARY_PATH, }; +use openshell_core::gpu::cdi_gpu_device_ids; use openshell_core::proto::compute::v1::{ CreateSandboxRequest, CreateSandboxResponse, DeleteSandboxRequest, DeleteSandboxResponse, DriverCondition, DriverSandbox, DriverSandboxStatus, DriverSandboxTemplate, @@ -48,16 +51,11 @@ const WATCH_BUFFER: usize = 128; const WATCH_POLL_INTERVAL: Duration = Duration::from_secs(2); const WATCH_POLL_MAX_BACKOFF: Duration = Duration::from_secs(30); -const MANAGED_BY_LABEL_KEY: &str = "openshell.ai/managed-by"; -const MANAGED_BY_LABEL_VALUE: &str = "openshell"; -const SANDBOX_ID_LABEL_KEY: &str = "openshell.ai/sandbox-id"; -const SANDBOX_NAME_LABEL_KEY: &str = "openshell.ai/sandbox-name"; -const SANDBOX_NAMESPACE_LABEL_KEY: &str = "openshell.ai/sandbox-namespace"; - const SUPERVISOR_MOUNT_PATH: &str = "/opt/openshell/bin/openshell-sandbox"; const TLS_CA_MOUNT_PATH: &str = "/etc/openshell/tls/client/ca.crt"; const TLS_CERT_MOUNT_PATH: &str = "/etc/openshell/tls/client/tls.crt"; const TLS_KEY_MOUNT_PATH: &str = "/etc/openshell/tls/client/tls.key"; +const SANDBOX_TOKEN_MOUNT_PATH: &str = "/etc/openshell/auth/sandbox.jwt"; const SANDBOX_COMMAND: &str = "sleep infinity"; const SUPERVISOR_PATH: &str = "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"; const HOST_OPENSHELL_INTERNAL: &str = "host.openshell.internal"; @@ -66,13 +64,9 @@ const DOCKER_NETWORK_DRIVER: &str = "bridge"; /// Default image holding the Linux `openshell-sandbox` binary. The gateway /// pulls this image and extracts the binary to a host-side cache when no -/// explicit `--docker-supervisor-bin` override or local build is available. +/// explicit `supervisor_bin` override or local build is available. 
const DEFAULT_DOCKER_SUPERVISOR_IMAGE_REPO: &str = "ghcr.io/nvidia/openshell/supervisor"; -/// Path to the supervisor binary inside the `openshell/supervisor` image -/// (a `FROM scratch` image containing only the binary). -const SUPERVISOR_IMAGE_BINARY_PATH: &str = "/openshell-sandbox"; - /// Return the default `ghcr.io/nvidia/openshell/supervisor:` reference /// used when no supervisor binary override is provided. pub fn default_docker_supervisor_image() -> String { @@ -89,7 +83,7 @@ pub fn default_docker_supervisor_image() -> String { /// fallback covers image build wrappers that already tag the gateway and /// supervisor together. Standalone release binaries also patch the Cargo /// package version, so use it when it has been set to a real release value. -fn default_docker_supervisor_image_tag() -> &'static str { +fn default_docker_supervisor_image_tag() -> String { resolve_default_docker_supervisor_image_tag( option_env!("OPENSHELL_IMAGE_TAG"), option_env!("IMAGE_TAG"), @@ -101,8 +95,8 @@ fn resolve_default_docker_supervisor_image_tag( openshell_image_tag: Option<&'static str>, image_tag: Option<&'static str>, cargo_pkg_version: &'static str, -) -> &'static str { - openshell_image_tag +) -> String { + let tag = openshell_image_tag .filter(|tag| !tag.is_empty()) .or_else(|| image_tag.filter(|tag| !tag.is_empty())) .unwrap_or_else(|| { @@ -111,7 +105,9 @@ fn resolve_default_docker_supervisor_image_tag( } else { cargo_pkg_version } - }) + }); + + tag.replace('+', "-") } /// Queried by the Docker driver to decide when a sandbox's supervisor @@ -126,8 +122,21 @@ pub trait SupervisorReadiness: Send + Sync + 'static { } /// Gateway-local configuration for the Docker compute driver. -#[derive(Debug, Clone, Default)] +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +#[serde(default, deny_unknown_fields)] pub struct DockerComputeConfig { + /// Default OCI image for sandboxes. + pub default_image: String, + + /// Image pull policy for sandbox images. 
+ pub image_pull_policy: String, + + /// Namespace label applied to Docker sandboxes. + pub sandbox_namespace: String, + + /// Gateway gRPC endpoint the sandbox connects back to. + pub grpc_endpoint: String, + /// Optional override for the Linux `openshell-sandbox` binary mounted into containers. pub supervisor_bin: Option, @@ -148,6 +157,31 @@ pub struct DockerComputeConfig { /// Docker bridge network that sandbox containers join. pub network_name: String, + + /// Host gateway IP used for sandbox host aliases. + pub host_gateway_ip: String, + + /// Unix socket path the in-container supervisor bridges relay traffic to. + pub ssh_socket_path: String, +} + +impl Default for DockerComputeConfig { + fn default() -> Self { + Self { + default_image: openshell_core::image::default_sandbox_image(), + image_pull_policy: String::new(), + sandbox_namespace: "default".to_string(), + grpc_endpoint: String::new(), + supervisor_bin: None, + supervisor_image: None, + guest_tls_ca: None, + guest_tls_cert: None, + guest_tls_key: None, + network_name: DEFAULT_DOCKER_NETWORK_NAME.to_string(), + host_gateway_ip: String::new(), + ssh_socket_path: "/run/openshell/ssh.sock".to_string(), + } + } } #[derive(Debug, Clone, PartialEq, Eq)] @@ -206,12 +240,6 @@ impl DockerComputeDriver { docker_config: &DockerComputeConfig, supervisor_readiness: Arc, ) -> CoreResult { - if config.grpc_endpoint.trim().is_empty() { - return Err(Error::config( - "grpc_endpoint is required when using the docker compute driver", - )); - } - let docker = Docker::connect_with_local_defaults() .map_err(|err| Error::execution(format!("failed to create Docker client: {err}")))?; let version = docker.version().await.map_err(|err| { @@ -232,28 +260,38 @@ impl DockerComputeDriver { } let network_name = docker_network_name(docker_config); let bridge_gateway_ip = ensure_bridge_network(&docker, &network_name).await?; - let host_gateway_ip = parse_optional_host_gateway_ip(&config.host_gateway_ip)?; + let host_gateway_ip = 
parse_optional_host_gateway_ip(&docker_config.host_gateway_ip)?; let gateway_route = docker_gateway_route(&info, bridge_gateway_ip, gateway_port, host_gateway_ip); + let mut docker_config = docker_config.clone(); + if docker_config.grpc_endpoint.trim().is_empty() { + let scheme = if docker_guest_tls_configured(&docker_config) { + "https" + } else { + "http" + }; + docker_config.grpc_endpoint = + format!("{scheme}://{HOST_OPENSHELL_INTERNAL}:{gateway_port}"); + } let grpc_endpoint = docker_container_openshell_endpoint( - &config.grpc_endpoint, + &docker_config.grpc_endpoint, HOST_OPENSHELL_INTERNAL, gateway_port, ); let daemon_arch = normalize_docker_arch(version.arch.as_deref().unwrap_or_default()); - let supervisor_bin = resolve_supervisor_bin(&docker, docker_config, &daemon_arch).await?; - let guest_tls = docker_guest_tls_paths(config, docker_config)?; + let supervisor_bin = resolve_supervisor_bin(&docker, &docker_config, &daemon_arch).await?; + let guest_tls = docker_guest_tls_paths(&docker_config)?; let driver = Self { docker: Arc::new(docker), config: DockerDriverRuntimeConfig { - default_image: config.sandbox_image.clone(), - image_pull_policy: config.sandbox_image_pull_policy.clone(), - sandbox_namespace: config.sandbox_namespace.clone(), + default_image: docker_config.default_image.clone(), + image_pull_policy: docker_config.image_pull_policy.clone(), + sandbox_namespace: docker_config.sandbox_namespace.clone(), grpc_endpoint, network_name, gateway_route, - ssh_socket_path: config.sandbox_ssh_socket_path.clone(), + ssh_socket_path: docker_config.ssh_socket_path.clone(), stop_timeout_secs: DEFAULT_STOP_TIMEOUT_SECS, log_level: config.log_level.clone(), supervisor_bin, @@ -282,13 +320,12 @@ impl DockerComputeDriver { } fn capabilities(&self) -> GetCapabilitiesResponse { - GetCapabilitiesResponse { - driver_name: "docker".to_string(), - driver_version: self.config.daemon_version.clone(), - default_image: self.config.default_image.clone(), - supports_gpu: 
self.config.supports_gpu, - gpu_count: 0, - } + openshell_core::driver_utils::build_capabilities_response( + "docker", + &self.config.daemon_version, + &self.config.default_image, + self.config.supports_gpu, + ) } fn validate_sandbox( @@ -309,11 +346,7 @@ impl DockerComputeDriver { "docker sandboxes require a template image", )); } - if spec.gpu && !config.supports_gpu { - return Err(Status::failed_precondition( - "docker GPU sandboxes require Docker CDI support. Enable CDI on the Docker daemon, then restart the OpenShell gateway/server so GPU capability is detected.", - )); - } + Self::validate_gpu_request(spec.gpu, config.supports_gpu)?; if !template.agent_socket_path.trim().is_empty() { return Err(Status::failed_precondition( "docker compute driver does not support template.agent_socket_path", @@ -333,6 +366,15 @@ impl DockerComputeDriver { Ok(()) } + fn validate_gpu_request(gpu: bool, supports_gpu: bool) -> Result<(), Status> { + if gpu && !supports_gpu { + return Err(Status::failed_precondition( + "docker GPU sandboxes require Docker CDI support. 
Enable CDI on the Docker daemon, then restart the OpenShell gateway/server so GPU capability is detected.", + )); + } + Ok(()) + } + async fn get_sandbox_snapshot( &self, sandbox_id: &str, @@ -375,6 +417,7 @@ impl DockerComputeDriver { .and_then(|spec| spec.template.as_ref()) .expect("validated sandbox has template"); self.ensure_image_available(&template.image).await?; + let token_file_created = write_sandbox_token_file(sandbox, &self.config).await?; let container_name = container_name_for_sandbox(sandbox); let create_body = build_container_create_body(sandbox, &self.config)?; @@ -389,6 +432,9 @@ impl DockerComputeDriver { ) .await .map_err(|err| { + if token_file_created { + cleanup_sandbox_token_file(sandbox, &self.config); + } create_status_from_docker_error("create docker sandbox container", err) })?; @@ -408,6 +454,9 @@ impl DockerComputeDriver { "Failed to clean up Docker container after start failure" ); } + if token_file_created { + cleanup_sandbox_token_file(sandbox, &self.config); + } return Err(create_status_from_docker_error( "start docker sandbox container", err, @@ -440,8 +489,14 @@ impl DockerComputeDriver { ) .await { - Ok(()) => Ok(true), - Err(err) if is_not_found_error(&err) => Ok(false), + Ok(()) => { + cleanup_sandbox_token_file_by_id(sandbox_id, &self.config); + Ok(true) + } + Err(err) if is_not_found_error(&err) => { + cleanup_sandbox_token_file_by_id(sandbox_id, &self.config); + Ok(false) + } Err(err) => Err(internal_status("delete docker sandbox container", err)), } } @@ -638,9 +693,9 @@ impl DockerComputeDriver { ) -> Result, Status> { let mut label_filter_values = Vec::new(); if !sandbox_id.is_empty() { - label_filter_values.push(format!("{SANDBOX_ID_LABEL_KEY}={sandbox_id}")); + label_filter_values.push(format!("{LABEL_SANDBOX_ID}={sandbox_id}")); } else if !sandbox_name.is_empty() { - label_filter_values.push(format!("{SANDBOX_NAME_LABEL_KEY}={sandbox_name}")); + label_filter_values.push(format!("{LABEL_SANDBOX_NAME}={sandbox_name}")); 
} let filters = @@ -661,15 +716,15 @@ impl DockerComputeDriver { return false; }; let namespace_matches = labels - .get(SANDBOX_NAMESPACE_LABEL_KEY) + .get(LABEL_SANDBOX_NAMESPACE) .is_some_and(|value| value == &self.config.sandbox_namespace); let id_matches = sandbox_id.is_empty() || labels - .get(SANDBOX_ID_LABEL_KEY) + .get(LABEL_SANDBOX_ID) .is_some_and(|value| value == sandbox_id); let name_matches = sandbox_name.is_empty() || labels - .get(SANDBOX_NAME_LABEL_KEY) + .get(LABEL_SANDBOX_NAME) .is_some_and(|value| value == sandbox_name); namespace_matches && id_matches && name_matches })) @@ -688,12 +743,12 @@ impl DockerComputeDriver { "never" => match self.docker.inspect_image(image).await { Ok(_) => Ok(()), Err(err) if is_not_found_error(&err) => Err(Status::failed_precondition(format!( - "docker image '{image}' is not present locally and sandbox_image_pull_policy=Never" + "docker image '{image}' is not present locally and image_pull_policy=Never" ))), Err(err) => Err(internal_status("inspect Docker image", err)), }, other => Err(Status::failed_precondition(format!( - "unsupported docker sandbox_image_pull_policy '{other}'; expected Always, IfNotPresent, or Never", + "unsupported docker image_pull_policy '{other}'; expected Always, IfNotPresent, or Never", ))), } } @@ -863,27 +918,117 @@ impl ComputeDriver for DockerComputeDriver { } } -fn build_mounts(config: &DockerDriverRuntimeConfig) -> Vec { - let mut mounts = vec![bind_mount( - &config.supervisor_bin, - SUPERVISOR_MOUNT_PATH, - true, +fn build_binds( + sandbox: &DriverSandbox, + config: &DockerDriverRuntimeConfig, +) -> Result, Status> { + let mut binds = vec![format!( + "{}:{}:ro,z", + config.supervisor_bin.display(), + SUPERVISOR_MOUNT_PATH )]; if let Some(tls) = &config.guest_tls { - mounts.push(bind_mount(&tls.ca, TLS_CA_MOUNT_PATH, true)); - mounts.push(bind_mount(&tls.cert, TLS_CERT_MOUNT_PATH, true)); - mounts.push(bind_mount(&tls.key, TLS_KEY_MOUNT_PATH, true)); + binds.push(format!("{}:{}:ro,z", 
tls.ca.display(), TLS_CA_MOUNT_PATH)); + binds.push(format!( + "{}:{}:ro,z", + tls.cert.display(), + TLS_CERT_MOUNT_PATH + )); + binds.push(format!("{}:{}:ro,z", tls.key.display(), TLS_KEY_MOUNT_PATH)); + } + if sandbox + .spec + .as_ref() + .is_some_and(|spec| !spec.sandbox_token.is_empty()) + { + binds.push(format!( + "{}:{}:ro,z", + sandbox_token_host_path(sandbox, config)?.display(), + SANDBOX_TOKEN_MOUNT_PATH + )); } - mounts + Ok(binds) } -fn bind_mount(source: &Path, target: &str, read_only: bool) -> Mount { - Mount { - target: Some(target.to_string()), - source: Some(source.display().to_string()), - typ: Some(MountTypeEnum::BIND), - read_only: Some(read_only), - ..Default::default() +fn sandbox_token_host_path( + sandbox: &DriverSandbox, + config: &DockerDriverRuntimeConfig, +) -> Result { + sandbox_token_host_path_by_id(&sandbox.id, config) +} + +fn sandbox_token_host_path_by_id( + sandbox_id: &str, + config: &DockerDriverRuntimeConfig, +) -> Result { + openshell_core::driver_utils::sandbox_token_path( + "docker-sandbox-tokens", + Some(&config.sandbox_namespace), + sandbox_id, + ) + .map_err(|err| { + Status::internal(format!( + "resolve sandbox token state directory failed: {err}" + )) + }) +} + +async fn write_sandbox_token_file( + sandbox: &DriverSandbox, + config: &DockerDriverRuntimeConfig, +) -> Result { + let Some(spec) = sandbox.spec.as_ref() else { + return Ok(false); + }; + if spec.sandbox_token.is_empty() { + return Ok(false); + } + let path = sandbox_token_host_path(sandbox, config)?; + if let Some(parent) = path.parent() { + openshell_core::paths::create_dir_restricted(parent).map_err(|err| { + Status::internal(format!( + "create sandbox token directory {} failed: {err}", + parent.display() + )) + })?; + } + tokio::fs::write(&path, format!("{}\n", spec.sandbox_token)) + .await + .map_err(|err| { + Status::internal(format!( + "write sandbox token file {} failed: {err}", + path.display() + )) + })?; + 
openshell_core::paths::set_file_owner_only(&path).map_err(|err| { + Status::internal(format!( + "restrict sandbox token file {} failed: {err}", + path.display() + )) + })?; + Ok(true) +} + +fn cleanup_sandbox_token_file(sandbox: &DriverSandbox, config: &DockerDriverRuntimeConfig) { + cleanup_sandbox_token_file_by_id(&sandbox.id, config); +} + +fn cleanup_sandbox_token_file_by_id(sandbox_id: &str, config: &DockerDriverRuntimeConfig) { + let Ok(path) = sandbox_token_host_path_by_id(sandbox_id, config) else { + return; + }; + if let Err(err) = std::fs::remove_file(&path) + && err.kind() != std::io::ErrorKind::NotFound + { + warn!( + sandbox_id = %sandbox_id, + path = %path.display(), + error = %err, + "Failed to remove Docker sandbox token file" + ); + } + if let Some(dir) = path.parent() { + let _ = std::fs::remove_dir(dir); } } @@ -894,7 +1039,7 @@ fn build_environment(sandbox: &DriverSandbox, config: &DockerDriverRuntimeConfig ("TERM".to_string(), "xterm".to_string()), ( "OPENSHELL_LOG_LEVEL".to_string(), - sandbox_log_level(sandbox, &config.log_level), + openshell_core::driver_utils::sandbox_log_level(sandbox, &config.log_level), ), ]); @@ -906,17 +1051,23 @@ fn build_environment(sandbox: &DriverSandbox, config: &DockerDriverRuntimeConfig } environment.insert( - "OPENSHELL_ENDPOINT".to_string(), + openshell_core::sandbox_env::ENDPOINT.to_string(), config.grpc_endpoint.clone(), ); - environment.insert("OPENSHELL_SANDBOX_ID".to_string(), sandbox.id.clone()); - environment.insert("OPENSHELL_SANDBOX".to_string(), sandbox.name.clone()); environment.insert( - "OPENSHELL_SSH_SOCKET_PATH".to_string(), + openshell_core::sandbox_env::SANDBOX_ID.to_string(), + sandbox.id.clone(), + ); + environment.insert( + openshell_core::sandbox_env::SANDBOX.to_string(), + sandbox.name.clone(), + ); + environment.insert( + openshell_core::sandbox_env::SSH_SOCKET_PATH.to_string(), config.ssh_socket_path.clone(), ); environment.insert( - "OPENSHELL_SANDBOX_COMMAND".to_string(), + 
openshell_core::sandbox_env::SANDBOX_COMMAND.to_string(), SANDBOX_COMMAND.to_string(), ); // The root supervisor executes namespace helpers during bootstrap; keep @@ -924,19 +1075,33 @@ fn build_environment(sandbox: &DriverSandbox, config: &DockerDriverRuntimeConfig environment.insert("PATH".to_string(), SUPERVISOR_PATH.to_string()); if config.guest_tls.is_some() { environment.insert( - "OPENSHELL_TLS_CA".to_string(), + openshell_core::sandbox_env::TLS_CA.to_string(), TLS_CA_MOUNT_PATH.to_string(), ); environment.insert( - "OPENSHELL_TLS_CERT".to_string(), + openshell_core::sandbox_env::TLS_CERT.to_string(), TLS_CERT_MOUNT_PATH.to_string(), ); environment.insert( - "OPENSHELL_TLS_KEY".to_string(), + openshell_core::sandbox_env::TLS_KEY.to_string(), TLS_KEY_MOUNT_PATH.to_string(), ); } + environment.remove(openshell_core::sandbox_env::SANDBOX_TOKEN); + environment.remove(openshell_core::sandbox_env::SANDBOX_TOKEN_FILE); + + // Gateway-minted sandbox JWT. Keep the raw bearer out of container + // metadata; the supervisor reads it from this driver-owned bind mount. 
+ if let Some(spec) = sandbox.spec.as_ref() + && !spec.sandbox_token.is_empty() + { + environment.insert( + openshell_core::sandbox_env::SANDBOX_TOKEN_FILE.to_string(), + SANDBOX_TOKEN_MOUNT_PATH.to_string(), + ); + } + let mut pairs = environment.into_iter().collect::>(); pairs.sort_by(|left, right| left.0.cmp(&right.0)); pairs @@ -945,11 +1110,11 @@ fn build_environment(sandbox: &DriverSandbox, config: &DockerDriverRuntimeConfig .collect() } -fn docker_gpu_device_requests(gpu: bool) -> Option> { - gpu.then(|| { +fn docker_gpu_device_requests(gpu: bool, gpu_device: &str) -> Option> { + cdi_gpu_device_ids(gpu, gpu_device).map(|device_ids| { vec![DeviceRequest { driver: Some("cdi".to_string()), - device_ids: Some(vec![CDI_GPU_DEVICE_ALL.to_string()]), + device_ids: Some(device_ids), ..Default::default() }] }) @@ -970,17 +1135,17 @@ fn build_container_create_body( let resource_limits = docker_resource_limits(template)?; let mut labels = template.labels.clone(); labels.insert( - MANAGED_BY_LABEL_KEY.to_string(), - MANAGED_BY_LABEL_VALUE.to_string(), + LABEL_MANAGED_BY.to_string(), + LABEL_MANAGED_BY_VALUE.to_string(), ); - labels.insert(SANDBOX_ID_LABEL_KEY.to_string(), sandbox.id.clone()); - labels.insert(SANDBOX_NAME_LABEL_KEY.to_string(), sandbox.name.clone()); + labels.insert(LABEL_SANDBOX_ID.to_string(), sandbox.id.clone()); + labels.insert(LABEL_SANDBOX_NAME.to_string(), sandbox.name.clone()); // The list/get/find paths filter by `config.sandbox_namespace`, so use // the same value here. `DriverSandbox.namespace` is unset on the request // path (the gateway elides it), and using it would produce containers // that the driver itself cannot find afterwards. 
labels.insert( - SANDBOX_NAMESPACE_LABEL_KEY.to_string(), + LABEL_SANDBOX_NAMESPACE.to_string(), config.sandbox_namespace.clone(), ); @@ -996,8 +1161,8 @@ fn build_container_create_body( host_config: Some(HostConfig { nano_cpus: resource_limits.nano_cpus, memory: resource_limits.memory_bytes, - device_requests: docker_gpu_device_requests(spec.gpu), - mounts: Some(build_mounts(config)), + device_requests: docker_gpu_device_requests(spec.gpu, &spec.gpu_device), + binds: Some(build_binds(sandbox, config)?), restart_policy: Some(RestartPolicy { name: Some(RestartPolicyNameEnum::UNLESS_STOPPED), maximum_retry_count: None, @@ -1047,16 +1212,6 @@ fn require_sandbox_identifier(sandbox_id: &str, sandbox_name: &str) -> Result<() Ok(()) } -fn sandbox_log_level(sandbox: &DriverSandbox, default_level: &str) -> String { - sandbox - .spec - .as_ref() - .map(|spec| spec.log_level.as_str()) - .filter(|level| !level.is_empty()) - .unwrap_or(default_level) - .to_string() -} - fn docker_container_openshell_endpoint(endpoint: &str, host: &str, port: u16) -> String { let Ok(mut url) = Url::parse(endpoint) else { return endpoint.to_string(); @@ -1083,11 +1238,10 @@ fn parse_optional_host_gateway_ip(value: &str) -> CoreResult> { return Ok(None); } - trimmed.parse().map(Some).map_err(|err| { - Error::config(format!( - "invalid OPENSHELL_HOST_GATEWAY_IP value '{trimmed}': {err}" - )) - }) + trimmed + .parse() + .map(Some) + .map_err(|err| Error::config(format!("invalid host_gateway_ip value '{trimmed}': {err}"))) } fn docker_gateway_route( @@ -1095,6 +1249,22 @@ fn docker_gateway_route( bridge_gateway_ip: IpAddr, port: u16, host_gateway_ip: Option, +) -> DockerGatewayRoute { + docker_gateway_route_for_host( + info, + bridge_gateway_ip, + port, + host_gateway_ip, + host_runtime_requires_host_gateway_alias(), + ) +} + +fn docker_gateway_route_for_host( + info: &SystemInfo, + bridge_gateway_ip: IpAddr, + port: u16, + host_gateway_ip: Option, + host_requires_host_gateway_alias: bool, ) -> 
DockerGatewayRoute { if let Some(host_alias_ip) = host_gateway_ip { return DockerGatewayRoute::Bridge { @@ -1103,7 +1273,7 @@ fn docker_gateway_route( }; } - if is_docker_desktop(info) { + if host_requires_host_gateway_alias || uses_host_gateway_alias(info) { DockerGatewayRoute::HostGateway } else { DockerGatewayRoute::Bridge { @@ -1113,7 +1283,19 @@ fn docker_gateway_route( } } -fn is_docker_desktop(info: &SystemInfo) -> bool { +fn host_runtime_requires_host_gateway_alias() -> bool { + cfg!(target_os = "macos") +} + +/// Detect Docker Desktop and behaviourally compatible runtimes - Colima, +/// Lima, Rancher Desktop, and `OrbStack` - that share Docker Desktop's routing +/// constraint: the bridge gateway IP is reachable from inside containers but +/// not from the `OpenShell` server process running on the host, so callbacks +/// must traverse `host-gateway`. +/// +/// Each runtime is detected via the daemon's reported OS string or hostname, +/// supplemented by labels where the runtime publishes them. 
+fn uses_host_gateway_alias(info: &SystemInfo) -> bool { let operating_system = info .operating_system .as_deref() @@ -1123,10 +1305,25 @@ fn is_docker_desktop(info: &SystemInfo) -> bool { return true; } + let name = info + .name + .as_deref() + .unwrap_or_default() + .to_ascii_lowercase(); + if name.starts_with("colima") + || name.starts_with("lima-") + || name.starts_with("rancher-desktop") + || name.starts_with("orbstack") + { + return true; + } + info.labels.as_ref().is_some_and(|labels| { - labels - .iter() - .any(|label| label.starts_with("com.docker.desktop.")) + labels.iter().any(|label| { + label.starts_with("com.docker.desktop.") + || label.starts_with("dev.rancherdesktop.") + || label.starts_with("dev.orbstack.") + }) }) } @@ -1136,9 +1333,10 @@ fn docker_extra_hosts(route: &DockerGatewayRoute) -> Vec { format!("{HOST_DOCKER_INTERNAL}:{host_alias_ip}"), format!("{HOST_OPENSHELL_INTERNAL}:{host_alias_ip}"), ], - DockerGatewayRoute::HostGateway => { - vec![format!("{HOST_OPENSHELL_INTERNAL}:host-gateway")] - } + DockerGatewayRoute::HostGateway => vec![ + format!("{HOST_DOCKER_INTERNAL}:host-gateway"), + format!("{HOST_OPENSHELL_INTERNAL}:host-gateway"), + ], } } @@ -1159,8 +1357,8 @@ async fn ensure_bridge_network(docker: &Docker, network_name: &str) -> CoreResul driver: Some(DOCKER_NETWORK_DRIVER.to_string()), attachable: Some(true), labels: Some(HashMap::from([( - MANAGED_BY_LABEL_KEY.to_string(), - MANAGED_BY_LABEL_VALUE.to_string(), + LABEL_MANAGED_BY.to_string(), + LABEL_MANAGED_BY_VALUE.to_string(), )])), ..Default::default() }) @@ -1339,10 +1537,10 @@ fn sandbox_from_container_summary( readiness: &dyn SupervisorReadiness, ) -> Option { let labels = summary.labels.as_ref()?; - let id = labels.get(SANDBOX_ID_LABEL_KEY)?.clone(); - let name = labels.get(SANDBOX_NAME_LABEL_KEY)?.clone(); + let id = labels.get(LABEL_SANDBOX_ID)?.clone(); + let name = labels.get(LABEL_SANDBOX_NAME)?.clone(); let namespace = labels - .get(SANDBOX_NAMESPACE_LABEL_KEY) + 
.get(LABEL_SANDBOX_NAMESPACE) .cloned() .unwrap_or_default(); @@ -1515,8 +1713,8 @@ fn managed_container_label_filters( extra_values: impl IntoIterator, ) -> HashMap> { let mut values = vec![ - format!("{MANAGED_BY_LABEL_KEY}={MANAGED_BY_LABEL_VALUE}"), - format!("{SANDBOX_NAMESPACE_LABEL_KEY}={sandbox_namespace}"), + format!("{LABEL_MANAGED_BY}={LABEL_MANAGED_BY_VALUE}"), + format!("{LABEL_SANDBOX_NAMESPACE}={sandbox_namespace}"), ]; values.extend(extra_values); label_filters(values) @@ -1602,7 +1800,7 @@ pub(crate) async fn resolve_supervisor_bin( docker_config: &DockerComputeConfig, daemon_arch: &str, ) -> CoreResult { - // Tier 1: explicit --docker-supervisor-bin / OPENSHELL_DOCKER_SUPERVISOR_BIN. + // Tier 1: explicit supervisor_bin in [openshell.drivers.docker]. if let Some(path) = docker_config.supervisor_bin.clone() { let path = canonicalize_existing_file(&path, "docker supervisor binary")?; validate_linux_elf_binary(&path)?; @@ -1669,10 +1867,27 @@ fn linux_supervisor_candidates(daemon_arch: &str) -> Vec { /// inside the digest-keyed directory and renamed into place, so concurrent /// gateway starts don't observe a partial file. async fn extract_supervisor_bin_from_image(docker: &Docker, image: &str) -> CoreResult { + let refresh_attempted = if supervisor_image_should_refresh(image) { + info!(image = image, "Refreshing mutable docker supervisor image"); + match pull_supervisor_image(docker, image).await { + Ok(()) => true, + Err(err) => { + warn!( + image = image, + error = %err, + "failed to refresh mutable docker supervisor image; falling back to local image if present", + ); + true + } + } + } else { + false + }; + // Inspect first to see if the image is already present; only pull on miss. 
let inspect = match docker.inspect_image(image).await { Ok(inspect) => inspect, - Err(err) if is_not_found_error(&err) => { + Err(err) if is_not_found_error(&err) && !refresh_attempted => { info!(image = image, "Pulling docker supervisor image"); pull_supervisor_image(docker, image).await?; docker.inspect_image(image).await.map_err(|err| { @@ -1681,6 +1896,11 @@ async fn extract_supervisor_bin_from_image(docker: &Docker, image: &str) -> Core )) })? } + Err(err) if is_not_found_error(&err) => { + return Err(Error::config(format!( + "docker supervisor image '{image}' is not present locally after refresh attempt", + ))); + } Err(err) => { return Err(Error::config(format!( "failed to inspect docker supervisor image '{image}': {err}", @@ -1726,6 +1946,23 @@ async fn extract_supervisor_bin_from_image(docker: &Docker, image: &str) -> Core Ok(cache_path) } +fn supervisor_image_should_refresh(image: &str) -> bool { + matches!(supervisor_image_tag(image), Some("dev" | "latest")) +} + +fn supervisor_image_tag(image: &str) -> Option<&str> { + if image.contains('@') { + return None; + } + + let image_name = image.rsplit('/').next().unwrap_or(image); + image_name + .rsplit_once(':') + .map_or(Some("latest"), |(_, tag)| { + if tag.is_empty() { None } else { Some(tag) } + }) +} + async fn pull_supervisor_image(docker: &Docker, image: &str) -> CoreResult<()> { let mut stream = docker.create_image( Some(CreateImageOptions { @@ -1759,7 +1996,7 @@ async fn extract_supervisor_binary_bytes(docker: &Docker, image: &str) -> CoreRe ), ContainerCreateBody { image: Some(image.to_string()), - entrypoint: Some(vec!["/openshell-sandbox".to_string()]), + entrypoint: Some(vec![SUPERVISOR_IMAGE_BINARY_PATH.to_string()]), cmd: Some(Vec::new()), ..Default::default() }, @@ -1945,19 +2182,24 @@ pub(crate) fn validate_linux_elf_binary(path: &Path) -> CoreResult<()> { Ok(()) } +fn docker_guest_tls_configured(docker_config: &DockerComputeConfig) -> bool { + docker_config.guest_tls_ca.is_some() + && 
docker_config.guest_tls_cert.is_some() + && docker_config.guest_tls_key.is_some() +} + pub(crate) fn docker_guest_tls_paths( - config: &Config, docker_config: &DockerComputeConfig, ) -> CoreResult> { let tls_flags_provided = docker_config.guest_tls_ca.is_some() || docker_config.guest_tls_cert.is_some() || docker_config.guest_tls_key.is_some(); - if !config.grpc_endpoint.starts_with("https://") { + if !docker_config.grpc_endpoint.starts_with("https://") { if tls_flags_provided { return Err(Error::config(format!( - "--docker-tls-ca/--docker-tls-cert/--docker-tls-key were provided but OPENSHELL_GRPC_ENDPOINT is '{}'; TLS materials require an https:// endpoint", - config.grpc_endpoint, + "guest_tls_ca/guest_tls_cert/guest_tls_key were provided but grpc_endpoint is '{}'; TLS materials require an https:// endpoint", + docker_config.grpc_endpoint, ))); } return Ok(None); @@ -1970,23 +2212,23 @@ pub(crate) fn docker_guest_tls_paths( ]; if provided.iter().all(Option::is_none) { return Err(Error::config( - "docker compute driver requires --docker-tls-ca, --docker-tls-cert, and --docker-tls-key when OPENSHELL_GRPC_ENDPOINT uses https://", + "docker compute driver requires guest_tls_ca, guest_tls_cert, and guest_tls_key when grpc_endpoint uses https://", )); } let Some(ca) = docker_config.guest_tls_ca.clone() else { return Err(Error::config( - "--docker-tls-ca is required when Docker sandbox TLS materials are configured", + "guest_tls_ca is required when Docker sandbox TLS materials are configured", )); }; let Some(cert) = docker_config.guest_tls_cert.clone() else { return Err(Error::config( - "--docker-tls-cert is required when Docker sandbox TLS materials are configured", + "guest_tls_cert is required when Docker sandbox TLS materials are configured", )); }; let Some(key) = docker_config.guest_tls_key.clone() else { return Err(Error::config( - "--docker-tls-key is required when Docker sandbox TLS materials are configured", + "guest_tls_key is required when Docker sandbox TLS 
materials are configured", )); }; diff --git a/crates/openshell-driver-docker/src/tests.rs b/crates/openshell-driver-docker/src/tests.rs index e41f2688e..9afec4be4 100644 --- a/crates/openshell-driver-docker/src/tests.rs +++ b/crates/openshell-driver-docker/src/tests.rs @@ -2,7 +2,11 @@ // SPDX-License-Identifier: Apache-2.0 use super::*; -use openshell_core::config::DEFAULT_SERVER_PORT; +use openshell_core::config::{CDI_GPU_DEVICE_ALL, DEFAULT_SERVER_PORT}; +use openshell_core::driver_utils::{ + LABEL_MANAGED_BY, LABEL_MANAGED_BY_VALUE, LABEL_SANDBOX_ID, LABEL_SANDBOX_NAME, + LABEL_SANDBOX_NAMESPACE, +}; use openshell_core::proto::compute::v1::{ DriverResourceRequirements, DriverSandboxSpec, DriverSandboxTemplate, }; @@ -33,6 +37,7 @@ fn test_sandbox() -> DriverSandbox { }), gpu: false, gpu_device: String::new(), + sandbox_token: String::new(), }), status: None, } @@ -74,7 +79,7 @@ fn container_visible_endpoint_rewrites_loopback_hosts() { HOST_OPENSHELL_INTERNAL, DEFAULT_SERVER_PORT, ), - "https://host.openshell.internal:8080/" + "https://host.openshell.internal:17670/" ); assert_eq!( docker_container_openshell_endpoint( @@ -82,7 +87,7 @@ fn container_visible_endpoint_rewrites_loopback_hosts() { HOST_OPENSHELL_INTERNAL, DEFAULT_SERVER_PORT, ), - "http://host.openshell.internal:8080/" + "http://host.openshell.internal:17670/" ); assert_eq!( docker_container_openshell_endpoint( @@ -90,7 +95,7 @@ fn container_visible_endpoint_rewrites_loopback_hosts() { HOST_OPENSHELL_INTERNAL, DEFAULT_SERVER_PORT, ), - "https://host.openshell.internal:8080/" + "https://host.openshell.internal:17670/" ); } @@ -160,7 +165,99 @@ fn docker_gateway_route_uses_host_gateway_for_docker_desktop() { ); assert_eq!( docker_extra_hosts(&DockerGatewayRoute::HostGateway), - vec!["host.openshell.internal:host-gateway".to_string()] + vec![ + "host.docker.internal:host-gateway".to_string(), + "host.openshell.internal:host-gateway".to_string() + ] + ); +} + +#[test] +fn 
docker_gateway_route_uses_host_gateway_for_colima() { + let info = SystemInfo { + name: Some("colima".to_string()), + operating_system: Some("Ubuntu 24.04.4 LTS".to_string()), + ..Default::default() + }; + + assert_eq!( + docker_gateway_route( + &info, + IpAddr::V4(Ipv4Addr::new(172, 20, 0, 1)), + DEFAULT_SERVER_PORT, + None, + ), + DockerGatewayRoute::HostGateway + ); + assert_eq!( + docker_extra_hosts(&DockerGatewayRoute::HostGateway), + vec![ + "host.docker.internal:host-gateway".to_string(), + "host.openshell.internal:host-gateway".to_string() + ] + ); +} + +#[test] +fn docker_gateway_route_uses_host_gateway_for_colima_named_profile() { + let info = SystemInfo { + operating_system: Some("Ubuntu 24.04 LTS".to_string()), + // `colima start --profile ` sets the daemon hostname to + // `colima-`; the prefix match still catches it. + name: Some("colima-default".to_string()), + ..Default::default() + }; + + assert_eq!( + docker_gateway_route( + &info, + IpAddr::V4(Ipv4Addr::new(172, 18, 0, 1)), + DEFAULT_SERVER_PORT, + None, + ), + DockerGatewayRoute::HostGateway + ); +} + +#[test] +fn docker_gateway_route_uses_host_gateway_for_rancher_desktop() { + let info = SystemInfo { + operating_system: Some("Alpine Linux v3.20".to_string()), + name: Some("lima-rancher-desktop".to_string()), + labels: Some(vec![ + "dev.rancherdesktop.profile=Rancher Desktop".to_string(), + ]), + ..Default::default() + }; + + assert_eq!( + docker_gateway_route( + &info, + IpAddr::V4(Ipv4Addr::new(172, 18, 0, 1)), + DEFAULT_SERVER_PORT, + None, + ), + DockerGatewayRoute::HostGateway + ); +} + +#[test] +fn docker_gateway_route_uses_host_gateway_for_orbstack() { + let info = SystemInfo { + operating_system: Some("OrbStack".to_string()), + name: Some("orbstack".to_string()), + labels: Some(vec!["dev.orbstack.machine_type=docker".to_string()]), + ..Default::default() + }; + + assert_eq!( + docker_gateway_route( + &info, + IpAddr::V4(Ipv4Addr::new(172, 18, 0, 1)), + DEFAULT_SERVER_PORT, + None, + ), + 
DockerGatewayRoute::HostGateway ); } @@ -171,17 +268,18 @@ fn docker_gateway_route_uses_bridge_gateway_for_linux_docker() { ..Default::default() }; - let route = docker_gateway_route( + let route = docker_gateway_route_for_host( &info, IpAddr::V4(Ipv4Addr::new(172, 18, 0, 1)), DEFAULT_SERVER_PORT, None, + false, ); assert_eq!( route, DockerGatewayRoute::Bridge { - bind_address: "172.18.0.1:8080".parse().unwrap(), + bind_address: "172.18.0.1:17670".parse().unwrap(), host_alias_ip: IpAddr::V4(Ipv4Addr::new(172, 18, 0, 1)), } ); @@ -194,6 +292,25 @@ fn docker_gateway_route_uses_bridge_gateway_for_linux_docker() { ); } +#[test] +fn docker_gateway_route_uses_host_gateway_when_host_runtime_requires_it() { + let info = SystemInfo { + operating_system: Some("Ubuntu 24.04 LTS".to_string()), + ..Default::default() + }; + + assert_eq!( + docker_gateway_route_for_host( + &info, + IpAddr::V4(Ipv4Addr::new(10, 89, 10, 1)), + DEFAULT_SERVER_PORT, + None, + true, + ), + DockerGatewayRoute::HostGateway + ); +} + #[test] fn docker_gateway_route_prefers_configured_host_gateway_ip() { let info = SystemInfo { @@ -211,7 +328,7 @@ fn docker_gateway_route_prefers_configured_host_gateway_ip() { assert_eq!( route, DockerGatewayRoute::Bridge { - bind_address: "172.20.0.4:8080".parse().unwrap(), + bind_address: "172.20.0.4:17670".parse().unwrap(), host_alias_ip: IpAddr::V4(Ipv4Addr::new(172, 20, 0, 4)), } ); @@ -235,7 +352,7 @@ fn parse_optional_host_gateway_ip_rejects_invalid_values() { parse_optional_host_gateway_ip("not-an-ip") .unwrap_err() .to_string() - .contains("OPENSHELL_HOST_GATEWAY_IP") + .contains("host_gateway_ip") ); } @@ -274,6 +391,26 @@ fn docker_resource_limits_rejects_requests() { assert!(err.message().contains("resources.requests.cpu")); } +#[test] +fn docker_resource_limits_applies_cpu_and_memory_limits() { + let template = DriverSandboxTemplate { + image: "img".to_string(), + agent_socket_path: String::new(), + labels: HashMap::new(), + environment: HashMap::new(), + 
resources: Some(DriverResourceRequirements { + cpu_limit: "500m".to_string(), + memory_limit: "2Gi".to_string(), + ..Default::default() + }), + platform_config: None, + }; + + let limits = docker_resource_limits(&template).unwrap(); + assert_eq!(limits.nano_cpus, Some(500_000_000)); + assert_eq!(limits.memory_bytes, Some(2_147_483_648)); +} + #[test] fn build_environment_sets_docker_tls_paths() { let env = build_environment(&test_sandbox(), &runtime_config()); @@ -283,14 +420,6 @@ fn build_environment_sets_docker_tls_paths() { assert!(env.contains(&"TEMPLATE_ENV=template".to_string())); assert!(env.contains(&"SPEC_ENV=spec".to_string())); assert!(env.contains(&"OPENSHELL_SANDBOX_COMMAND=sleep infinity".to_string())); - assert!( - !env.iter() - .any(|entry| entry.starts_with("OPENSHELL_SSH_HANDSHAKE_SECRET=")) - ); - assert!( - !env.iter() - .any(|entry| entry.starts_with("OPENSHELL_SSH_HANDSHAKE_SKEW_SECS=")) - ); } #[test] @@ -317,11 +446,11 @@ fn build_environment_keeps_path_driver_controlled() { } #[test] -fn build_mounts_uses_docker_tls_directory() { - let mounts = build_mounts(&runtime_config()); - let targets = mounts +fn build_binds_uses_docker_tls_directory() { + let binds = build_binds(&test_sandbox(), &runtime_config()).unwrap(); + let targets = binds .iter() - .filter_map(|mount| mount.target.clone()) + .filter_map(|bind| bind.split(':').nth(1).map(String::from)) .collect::>(); assert!(targets.contains(&SUPERVISOR_MOUNT_PATH.to_string())); assert!(targets.contains(&TLS_CA_MOUNT_PATH.to_string())); @@ -334,15 +463,36 @@ fn build_mounts_uses_docker_tls_directory() { ); } +#[test] +fn build_environment_uses_token_file_without_raw_token_env() { + let mut sandbox = test_sandbox(); + let spec = sandbox.spec.as_mut().unwrap(); + spec.sandbox_token = "secret.jwt.value".to_string(); + spec.environment.insert( + openshell_core::sandbox_env::SANDBOX_TOKEN.to_string(), + "user-provided-token".to_string(), + ); + + let env = build_environment(&sandbox, 
&runtime_config()); + + assert!(!env.iter().any(|entry| { + entry.starts_with(&format!("{}=", openshell_core::sandbox_env::SANDBOX_TOKEN)) + })); + assert!(env.contains(&format!( + "{}={SANDBOX_TOKEN_MOUNT_PATH}", + openshell_core::sandbox_env::SANDBOX_TOKEN_FILE + ))); +} + #[test] fn managed_container_label_filters_include_gateway_namespace() { let filters = - managed_container_label_filters("tenant-a", [format!("{SANDBOX_ID_LABEL_KEY}=sbx-123")]); + managed_container_label_filters("tenant-a", [format!("{LABEL_SANDBOX_ID}=sbx-123")]); let labels = filters.get("label").unwrap(); - assert!(labels.contains(&format!("{MANAGED_BY_LABEL_KEY}={MANAGED_BY_LABEL_VALUE}"))); - assert!(labels.contains(&format!("{SANDBOX_NAMESPACE_LABEL_KEY}=tenant-a"))); - assert!(labels.contains(&format!("{SANDBOX_ID_LABEL_KEY}=sbx-123"))); + assert!(labels.contains(&format!("{LABEL_MANAGED_BY}={LABEL_MANAGED_BY_VALUE}"))); + assert!(labels.contains(&format!("{LABEL_SANDBOX_NAMESPACE}=tenant-a"))); + assert!(labels.contains(&format!("{LABEL_SANDBOX_ID}=sbx-123"))); } #[test] @@ -358,7 +508,7 @@ fn build_container_create_body_clears_inherited_cmd() { create_body .labels .as_ref() - .and_then(|labels| labels.get(SANDBOX_NAMESPACE_LABEL_KEY)), + .and_then(|labels| labels.get(LABEL_SANDBOX_NAMESPACE)), Some(&"default".to_string()) ); let host_config = create_body.host_config.as_ref().unwrap(); @@ -425,6 +575,30 @@ fn build_container_create_body_maps_gpu_to_all_cdi_device() { ); } +#[test] +fn build_container_create_body_passes_explicit_cdi_device_id_through() { + let mut config = runtime_config(); + config.supports_gpu = true; + let mut sandbox = test_sandbox(); + let spec = sandbox.spec.as_mut().unwrap(); + spec.gpu = true; + spec.gpu_device = "nvidia.com/gpu=0".to_string(); + + let create_body = build_container_create_body(&sandbox, &config).unwrap(); + let request = create_body + .host_config + .as_ref() + .and_then(|host_config| host_config.device_requests.as_ref()) + .and_then(|requests| 
requests.first()) + .expect("GPU request should add a Docker device request"); + + assert_eq!(request.driver.as_deref(), Some("cdi")); + assert_eq!( + request.device_ids.as_ref().unwrap(), + &vec!["nvidia.com/gpu=0".to_string()] + ); +} + #[test] fn require_sandbox_identifier_rejects_when_id_and_name_are_empty() { // Regression test: `delete_sandbox` (and the other identifier-keyed @@ -478,7 +652,7 @@ fn build_container_create_body_uses_runtime_namespace_label() { let labels = create_body.labels.expect("labels are populated"); assert_eq!( - labels.get(SANDBOX_NAMESPACE_LABEL_KEY), + labels.get(LABEL_SANDBOX_NAMESPACE), Some(&"tenant-a".to_string()), "namespace label must reflect the driver's runtime config" ); @@ -490,12 +664,9 @@ fn driver_status_keeps_running_sandboxes_provisioning_with_stable_message() { id: Some("cid".to_string()), names: Some(vec!["/openshell-demo".to_string()]), labels: Some(HashMap::from([ - (SANDBOX_ID_LABEL_KEY.to_string(), "sbx-1".to_string()), - (SANDBOX_NAME_LABEL_KEY.to_string(), "demo".to_string()), - ( - SANDBOX_NAMESPACE_LABEL_KEY.to_string(), - "default".to_string(), - ), + (LABEL_SANDBOX_ID.to_string(), "sbx-1".to_string()), + (LABEL_SANDBOX_NAME.to_string(), "demo".to_string()), + (LABEL_SANDBOX_NAMESPACE.to_string(), "default".to_string()), ])), state: Some(ContainerSummaryStateEnum::RUNNING), status: Some("Up 2 seconds".to_string()), @@ -547,12 +718,9 @@ fn driver_status_marks_restarting_sandboxes_as_error() { id: Some("cid".to_string()), names: Some(vec!["/openshell-demo".to_string()]), labels: Some(HashMap::from([ - (SANDBOX_ID_LABEL_KEY.to_string(), "sbx-1".to_string()), - (SANDBOX_NAME_LABEL_KEY.to_string(), "demo".to_string()), - ( - SANDBOX_NAMESPACE_LABEL_KEY.to_string(), - "default".to_string(), - ), + (LABEL_SANDBOX_ID.to_string(), "sbx-1".to_string()), + (LABEL_SANDBOX_NAME.to_string(), "demo".to_string()), + (LABEL_SANDBOX_NAMESPACE.to_string(), "default".to_string()), ])), state: 
Some(ContainerSummaryStateEnum::RESTARTING), status: Some("Restarting (1) 2 seconds ago".to_string()), @@ -580,20 +748,17 @@ fn validate_linux_elf_binary_rejects_non_elf_files() { #[test] fn docker_guest_tls_paths_require_all_files_for_https() { - let config = Config::new(None).with_grpc_endpoint("https://localhost:8443"); let tempdir = TempDir::new().unwrap(); let ca = tempdir.path().join("ca.crt"); fs::write(&ca, b"ca").unwrap(); - let err = docker_guest_tls_paths( - &config, - &DockerComputeConfig { - guest_tls_ca: Some(ca), - ..Default::default() - }, - ) + let err = docker_guest_tls_paths(&DockerComputeConfig { + grpc_endpoint: "https://localhost:8443".to_string(), + guest_tls_ca: Some(ca), + ..Default::default() + }) .unwrap_err(); - assert!(err.to_string().contains("--docker-tls-cert")); + assert!(err.to_string().contains("guest_tls_cert")); } #[test] @@ -670,26 +835,26 @@ fn trim_container_name_tail_strips_separators() { #[test] fn docker_guest_tls_paths_rejects_tls_flags_without_https() { - let config = Config::new(None).with_grpc_endpoint("http://localhost:8080"); let tempdir = TempDir::new().unwrap(); let ca = tempdir.path().join("ca.crt"); fs::write(&ca, b"ca").unwrap(); - let err = docker_guest_tls_paths( - &config, - &DockerComputeConfig { - guest_tls_ca: Some(ca), - ..Default::default() - }, - ) + let err = docker_guest_tls_paths(&DockerComputeConfig { + grpc_endpoint: "http://localhost:8080".to_string(), + guest_tls_ca: Some(ca), + ..Default::default() + }) .unwrap_err(); assert!(err.to_string().contains("https://")); } #[test] fn docker_guest_tls_paths_allows_plain_http_without_tls_flags() { - let config = Config::new(None).with_grpc_endpoint("http://localhost:8080"); - let result = docker_guest_tls_paths(&config, &DockerComputeConfig::default()).unwrap(); + let result = docker_guest_tls_paths(&DockerComputeConfig { + grpc_endpoint: "http://localhost:8080".to_string(), + ..Default::default() + }) + .unwrap(); assert!(result.is_none()); } @@ -722,6 
+887,41 @@ fn docker_supervisor_image_tag_prefers_explicit_build_tags() { ); } +#[test] +fn docker_supervisor_image_tag_sanitizes_build_metadata_for_docker() { + assert_eq!( + resolve_default_docker_supervisor_image_tag(None, None, "0.0.37-dev.156+g1d3b741ee"), + "0.0.37-dev.156-g1d3b741ee", + ); + assert_eq!( + resolve_default_docker_supervisor_image_tag( + Some("0.0.37-dev.156+g1d3b741ee"), + None, + "0.0.0", + ), + "0.0.37-dev.156-g1d3b741ee", + ); +} + +#[test] +fn docker_supervisor_image_refreshes_mutable_tags_only() { + assert!(supervisor_image_should_refresh( + "ghcr.io/nvidia/openshell/supervisor:dev" + )); + assert!(supervisor_image_should_refresh( + "ghcr.io/nvidia/openshell/supervisor:latest" + )); + assert!(supervisor_image_should_refresh( + "ghcr.io/nvidia/openshell/supervisor" + )); + assert!(!supervisor_image_should_refresh( + "ghcr.io/nvidia/openshell/supervisor:0.0.47-dev.13-g57b71c68f" + )); + assert!(!supervisor_image_should_refresh( + "ghcr.io/nvidia/openshell/supervisor@sha256:abc123" + )); +} + #[test] fn supervisor_cache_path_namespaces_by_digest_under_openshell_data_dir() { let base = PathBuf::from("/var/cache/share"); diff --git a/crates/openshell-driver-kubernetes/Cargo.toml b/crates/openshell-driver-kubernetes/Cargo.toml index 5e247dc77..c222c9c31 100644 --- a/crates/openshell-driver-kubernetes/Cargo.toml +++ b/crates/openshell-driver-kubernetes/Cargo.toml @@ -26,6 +26,7 @@ tokio-stream = { workspace = true } kube = { workspace = true } kube-runtime = { workspace = true } k8s-openapi = { workspace = true } +serde = { workspace = true } serde_json = { workspace = true } clap = { workspace = true } tracing = { workspace = true } diff --git a/crates/openshell-driver-kubernetes/README.md b/crates/openshell-driver-kubernetes/README.md index 4a8a8f76b..1d45a1d83 100644 --- a/crates/openshell-driver-kubernetes/README.md +++ b/crates/openshell-driver-kubernetes/README.md @@ -38,6 +38,12 @@ The driver injects gateway callback configuration, 
sandbox identity, TLS client material, and the supervisor SSH socket path into the workload. Driver-owned values must override image-provided environment variables. +Sandbox pods run as `service_account_name` and keep +`automountServiceAccountToken: false`. The only Kubernetes token exposed to the +supervisor is an explicit, audience-bound projected token mounted at +`/var/run/secrets/openshell/token` for the one-shot `IssueSandboxToken` +bootstrap exchange. + The gateway uses the supervisor relay for connect, exec, and file sync. Sandbox pods do not need direct external ingress for SSH. diff --git a/crates/openshell-driver-kubernetes/src/config.rs b/crates/openshell-driver-kubernetes/src/config.rs index 838262c77..d71133465 100644 --- a/crates/openshell-driver-kubernetes/src/config.rs +++ b/crates/openshell-driver-kubernetes/src/config.rs @@ -1,22 +1,175 @@ // SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 -#[derive(Debug, Clone)] +use openshell_core::config::DEFAULT_SUPERVISOR_IMAGE; +use serde::{Deserialize, Serialize}; + +/// Default Kubernetes namespace for sandbox resources. +pub const DEFAULT_K8S_NAMESPACE: &str = "openshell"; + +/// Default Kubernetes `ServiceAccount` assigned to sandbox pods. +pub const DEFAULT_SANDBOX_SERVICE_ACCOUNT_NAME: &str = "default"; + +/// Default storage size for the workspace PVC. +pub const DEFAULT_WORKSPACE_STORAGE_SIZE: &str = "2Gi"; + +/// How the supervisor binary is delivered into sandbox pods. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)] +#[serde(rename_all = "kebab-case")] +pub enum SupervisorSideloadMethod { + /// Mount the supervisor OCI image directly as a read-only volume + /// (requires Kubernetes >= v1.33 with the `ImageVolume` feature gate, + /// or >= v1.36 where it is GA). + #[default] + ImageVolume, + /// Copy the binary via an init container and emptyDir volume. 
+ /// Works on all Kubernetes versions. + InitContainer, +} + +impl std::fmt::Display for SupervisorSideloadMethod { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::ImageVolume => f.write_str("image-volume"), + Self::InitContainer => f.write_str("init-container"), + } + } +} + +impl std::str::FromStr for SupervisorSideloadMethod { + type Err = String; + + fn from_str(s: &str) -> Result { + match s { + "image-volume" => Ok(Self::ImageVolume), + "init-container" => Ok(Self::InitContainer), + other => Err(format!( + "unknown supervisor sideload method '{other}'; expected 'image-volume' or 'init-container'" + )), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(default, deny_unknown_fields)] pub struct KubernetesComputeConfig { pub namespace: String, + /// Kubernetes `ServiceAccount` assigned to sandbox pods and accepted by + /// the gateway's `TokenReview` bootstrap authenticator. + pub service_account_name: String, pub default_image: String, pub image_pull_policy: String, /// Image that provides the `openshell-sandbox` supervisor binary. - /// An init container copies the binary from this image into a shared - /// emptyDir volume before the sandbox container starts. + /// Mounted directly as an image volume, or copied via an init container, + /// depending on `supervisor_sideload_method`. pub supervisor_image: String, - /// Kubernetes `imagePullPolicy` for the supervisor init container. + /// Kubernetes `imagePullPolicy` for the supervisor image. /// Empty string delegates to the Kubernetes default. pub supervisor_image_pull_policy: String, + /// How the supervisor binary is delivered into sandbox pods. 
+ pub supervisor_sideload_method: SupervisorSideloadMethod, pub grpc_endpoint: String, pub ssh_socket_path: String, - pub ssh_handshake_secret: String, - pub ssh_handshake_skew_secs: u64, pub client_tls_secret_name: String, pub host_gateway_ip: String, + pub enable_user_namespaces: bool, + pub workspace_default_storage_size: String, + /// Lifetime (seconds) of the projected `ServiceAccount` token kubelet + /// writes into each sandbox pod. Used only for the one-shot + /// `IssueSandboxToken` bootstrap exchange — the gateway-minted JWT + /// that follows has its own TTL set via `gateway_jwt.ttl_secs`. + /// + /// Kubelet enforces a minimum of 600 seconds; the supervisor uses + /// this token within a few seconds of pod start, so any value at + /// the floor is sufficient. Default 3600. + pub sa_token_ttl_secs: i64, +} + +/// Lower bound enforced by kubelet for projected SA tokens. +pub const MIN_SA_TOKEN_TTL_SECS: i64 = 600; + +/// Cap at 24h — operators who want longer-lived bootstrap tokens are +/// almost certainly misconfigured (the token is consumed seconds after +/// pod start). +pub const MAX_SA_TOKEN_TTL_SECS: i64 = 86_400; + +impl Default for KubernetesComputeConfig { + fn default() -> Self { + Self { + namespace: DEFAULT_K8S_NAMESPACE.to_string(), + service_account_name: DEFAULT_SANDBOX_SERVICE_ACCOUNT_NAME.to_string(), + default_image: openshell_core::image::default_sandbox_image(), + // Default empty so the gateway omits `imagePullPolicy` from pod + // specs and Kubernetes applies its own default (Always for `latest`, + // IfNotPresent otherwise). `DEFAULT_IMAGE_PULL_POLICY` ("missing") + // is Podman vocabulary and is not a valid Kubernetes value. 
+ image_pull_policy: String::new(), + supervisor_image: DEFAULT_SUPERVISOR_IMAGE.to_string(), + supervisor_image_pull_policy: String::new(), + supervisor_sideload_method: SupervisorSideloadMethod::default(), + grpc_endpoint: String::new(), + ssh_socket_path: "/run/openshell/ssh.sock".to_string(), + client_tls_secret_name: String::new(), + host_gateway_ip: String::new(), + enable_user_namespaces: false, + workspace_default_storage_size: DEFAULT_WORKSPACE_STORAGE_SIZE.to_string(), + sa_token_ttl_secs: 3600, + } + } +} + +impl KubernetesComputeConfig { + /// Clamp `sa_token_ttl_secs` into the `[MIN_SA_TOKEN_TTL_SECS, + /// MAX_SA_TOKEN_TTL_SECS]` range used by the projected-volume spec. + /// Invalid (≤0) values fall back to the default 3600. + #[must_use] + pub fn effective_sa_token_ttl_secs(&self) -> i64 { + if self.sa_token_ttl_secs <= 0 { + 3600 + } else { + self.sa_token_ttl_secs + .clamp(MIN_SA_TOKEN_TTL_SECS, MAX_SA_TOKEN_TTL_SECS) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn default_workspace_storage_size_is_2gi() { + let cfg = KubernetesComputeConfig::default(); + assert_eq!( + cfg.workspace_default_storage_size, + DEFAULT_WORKSPACE_STORAGE_SIZE + ); + } + + #[test] + fn default_service_account_name_is_default() { + let cfg = KubernetesComputeConfig::default(); + assert_eq!( + cfg.service_account_name, + DEFAULT_SANDBOX_SERVICE_ACCOUNT_NAME + ); + } + + #[test] + fn serde_override_workspace_storage_size() { + let json = serde_json::json!({ + "workspace_default_storage_size": "10Gi" + }); + let cfg: KubernetesComputeConfig = serde_json::from_value(json).unwrap(); + assert_eq!(cfg.workspace_default_storage_size, "10Gi"); + } + + #[test] + fn serde_override_service_account_name() { + let json = serde_json::json!({ + "service_account_name": "openshell-sandbox" + }); + let cfg: KubernetesComputeConfig = serde_json::from_value(json).unwrap(); + assert_eq!(cfg.service_account_name, "openshell-sandbox"); + } } diff --git 
a/crates/openshell-driver-kubernetes/src/driver.rs b/crates/openshell-driver-kubernetes/src/driver.rs index e2d06044d..0a428f146 100644 --- a/crates/openshell-driver-kubernetes/src/driver.rs +++ b/crates/openshell-driver-kubernetes/src/driver.rs @@ -3,7 +3,10 @@ //! Kubernetes compute driver. -use crate::config::KubernetesComputeConfig; +use crate::config::{ + DEFAULT_SANDBOX_SERVICE_ACCOUNT_NAME, DEFAULT_WORKSPACE_STORAGE_SIZE, KubernetesComputeConfig, + SupervisorSideloadMethod, +}; use futures::{Stream, StreamExt, TryStreamExt}; use k8s_openapi::api::core::v1::{Event as KubeEventObj, Node}; use kube::api::{Api, ApiResource, DeleteParams, ListParams, PostParams}; @@ -11,6 +14,13 @@ use kube::core::gvk::GroupVersionKind; use kube::core::{DynamicObject, ObjectMeta}; use kube::runtime::watcher::{self, Event}; use kube::{Client, Error as KubeError}; +use openshell_core::driver_utils::{ + LABEL_MANAGED_BY, LABEL_MANAGED_BY_VALUE, LABEL_SANDBOX_ID, SUPERVISOR_IMAGE_BINARY_PATH, +}; +use openshell_core::progress::{ + PROGRESS_STEP_PULLING_IMAGE, PROGRESS_STEP_REQUESTING_SANDBOX, PROGRESS_STEP_STARTING_SANDBOX, + mark_progress_active, mark_progress_complete, mark_progress_detail, +}; use openshell_core::proto::compute::v1::{ DriverCondition as SandboxCondition, DriverPlatformEvent as PlatformEvent, DriverSandbox as Sandbox, DriverSandboxSpec as SandboxSpec, @@ -47,6 +57,16 @@ impl KubernetesDriverError { } } +impl From for openshell_core::ComputeDriverError { + fn from(err: KubernetesDriverError) -> Self { + match err { + KubernetesDriverError::AlreadyExists => Self::AlreadyExists, + KubernetesDriverError::Precondition(m) => Self::Precondition(m), + KubernetesDriverError::Message(m) => Self::Message(m), + } + } +} + /// Timeout for individual Kubernetes API calls (create, delete, get). /// This prevents gRPC handlers from blocking indefinitely when the k8s /// API server is unreachable or slow. 
@@ -55,9 +75,7 @@ const KUBE_API_TIMEOUT: Duration = Duration::from_secs(30); const SANDBOX_GROUP: &str = "agents.x-k8s.io"; const SANDBOX_VERSION: &str = "v1alpha1"; pub const SANDBOX_KIND: &str = "Sandbox"; -const SANDBOX_ID_LABEL: &str = "openshell.ai/sandbox-id"; -const SANDBOX_MANAGED_LABEL: &str = "openshell.ai/managed-by"; -const SANDBOX_MANAGED_VALUE: &str = "openshell"; + const GPU_RESOURCE_NAME: &str = "nvidia.com/gpu"; const GPU_RESOURCE_QUANTITY: &str = "1"; @@ -92,9 +110,6 @@ const WORKSPACE_INIT_MOUNT_PATH: &str = "/workspace-pvc"; /// Name of the init container that seeds the workspace PVC. const WORKSPACE_INIT_CONTAINER_NAME: &str = "workspace-init"; -/// Default storage request for the workspace PVC. -const WORKSPACE_DEFAULT_STORAGE: &str = "2Gi"; - /// Sentinel file written by the init container after copying the image's /// `/sandbox` contents. Subsequent pod starts skip the copy. const WORKSPACE_SENTINEL: &str = ".workspace-initialized"; @@ -145,13 +160,12 @@ impl KubernetesComputeDriver { } pub async fn capabilities(&self) -> Result { - Ok(GetCapabilitiesResponse { - driver_name: "kubernetes".to_string(), - driver_version: openshell_core::VERSION.to_string(), - default_image: self.config.default_image.clone(), - supports_gpu: self.has_gpu_capacity().await.unwrap_or(false), - gpu_count: 0, - }) + Ok(openshell_core::driver_utils::build_capabilities_response( + "kubernetes", + openshell_core::VERSION, + &self.config.default_image, + self.has_gpu_capacity().await.unwrap_or(false), + )) } pub fn default_image(&self) -> &str { @@ -166,10 +180,6 @@ impl KubernetesComputeDriver { &self.config.ssh_socket_path } - pub const fn ssh_handshake_skew_secs(&self) -> u64 { - self.config.ssh_handshake_skew_secs - } - fn watch_api(&self) -> Api { let gvk = GroupVersionKind::gvk(SANDBOX_GROUP, SANDBOX_VERSION, SANDBOX_KIND); let resource = ApiResource::from_gvk(&gvk); @@ -286,10 +296,6 @@ impl KubernetesComputeDriver { } } - fn ssh_handshake_secret(&self) -> &str 
{ - &self.config.ssh_handshake_secret - } - pub async fn create_sandbox(&self, sandbox: &Sandbox) -> Result<(), KubernetesDriverError> { let name = sandbox.name.as_str(); info!( @@ -308,21 +314,24 @@ impl KubernetesComputeDriver { labels: Some(sandbox_labels(sandbox)), ..Default::default() }; - obj.data = sandbox_to_k8s_spec( - sandbox.spec.as_ref(), - &self.config.default_image, - &self.config.image_pull_policy, - &self.config.supervisor_image, - &self.config.supervisor_image_pull_policy, - &sandbox.id, - &sandbox.name, - &self.config.grpc_endpoint, - self.ssh_socket_path(), - self.ssh_handshake_secret(), - self.ssh_handshake_skew_secs(), - &self.config.client_tls_secret_name, - &self.config.host_gateway_ip, - ); + let params = SandboxPodParams { + default_image: &self.config.default_image, + image_pull_policy: &self.config.image_pull_policy, + supervisor_image: &self.config.supervisor_image, + supervisor_image_pull_policy: &self.config.supervisor_image_pull_policy, + supervisor_sideload_method: self.config.supervisor_sideload_method, + service_account_name: &self.config.service_account_name, + sandbox_id: &sandbox.id, + sandbox_name: &sandbox.name, + grpc_endpoint: &self.config.grpc_endpoint, + ssh_socket_path: self.ssh_socket_path(), + client_tls_secret_name: &self.config.client_tls_secret_name, + host_gateway_ip: &self.config.host_gateway_ip, + enable_user_namespaces: self.config.enable_user_namespaces, + workspace_default_storage_size: &self.config.workspace_default_storage_size, + sa_token_ttl_secs: self.config.effective_sa_token_ttl_secs(), + }; + obj.data = sandbox_to_k8s_spec(sandbox.spec.as_ref(), &params); let api = self.api(); match tokio::time::timeout(KUBE_API_TIMEOUT, api.create(&PostParams::default(), &obj)).await @@ -545,17 +554,17 @@ impl KubernetesComputeDriver { fn sandbox_labels(sandbox: &Sandbox) -> BTreeMap { let mut labels = BTreeMap::new(); - labels.insert(SANDBOX_ID_LABEL.to_string(), sandbox.id.clone()); + 
labels.insert(LABEL_SANDBOX_ID.to_string(), sandbox.id.clone()); labels.insert( - SANDBOX_MANAGED_LABEL.to_string(), - SANDBOX_MANAGED_VALUE.to_string(), + LABEL_MANAGED_BY.to_string(), + LABEL_MANAGED_BY_VALUE.to_string(), ); labels } fn sandbox_id_from_object(obj: &DynamicObject) -> Result { if let Some(labels) = obj.metadata.labels.as_ref() - && let Some(id) = labels.get(SANDBOX_ID_LABEL) + && let Some(id) = labels.get(LABEL_SANDBOX_ID) { return Ok(id.clone()); } @@ -644,6 +653,11 @@ fn map_kube_event_to_platform( if let Some(count) = obj.count { metadata.insert("count".to_string(), count.to_string()); } + attach_kube_progress_metadata( + &mut metadata, + obj.reason.as_deref().unwrap_or_default(), + obj.message.as_deref().unwrap_or_default(), + ); Some(( sandbox_id, @@ -658,6 +672,76 @@ fn map_kube_event_to_platform( )) } +fn attach_kube_progress_metadata( + metadata: &mut std::collections::HashMap, + reason: &str, + message: &str, +) { + match reason { + "Scheduled" => { + mark_progress_complete( + metadata, + PROGRESS_STEP_REQUESTING_SANDBOX, + "Sandbox allocated", + ); + mark_progress_active(metadata, PROGRESS_STEP_PULLING_IMAGE); + } + "Pulling" => { + mark_progress_active(metadata, PROGRESS_STEP_PULLING_IMAGE); + if let Some(image) = pulling_image_from_kube_message(message) { + mark_progress_detail(metadata, image); + } + } + "Pulled" => { + let label = pulled_image_label(message); + mark_progress_complete(metadata, PROGRESS_STEP_PULLING_IMAGE, label); + mark_progress_active(metadata, PROGRESS_STEP_STARTING_SANDBOX); + } + _ => {} + } +} + +fn pulling_image_from_kube_message(message: &str) -> Option { + let image = message + .strip_prefix("Pulling image ") + .map(str::trim) + .map(|value| value.trim_matches('"'))?; + (!image.is_empty()).then(|| image.to_string()) +} + +fn pulled_image_label(message: &str) -> String { + extract_image_size(message).map_or_else( + || "Image pulled".to_string(), + |bytes| format!("Image pulled ({})", format_bytes(bytes)), + ) 
+} + +fn extract_image_size(message: &str) -> Option { + let size_prefix = "Image size: "; + let start = message.find(size_prefix)? + size_prefix.len(); + let rest = &message[start..]; + let end = rest.find(' ')?; + rest[..end].parse().ok() +} + +fn format_bytes(bytes: u64) -> String { + const KB: u64 = 1024; + const MB: u64 = 1024 * KB; + const GB: u64 = 1024 * MB; + + if bytes >= GB { + #[allow(clippy::cast_precision_loss)] + let gb = bytes as f64 / GB as f64; + format!("{gb:.1} GB") + } else if bytes >= MB { + format!("{} MB", bytes / MB) + } else if bytes >= KB { + format!("{} KB", bytes / KB) + } else { + format!("{bytes} B") + } +} + /// Path where the supervisor binary is mounted inside the agent container. const SUPERVISOR_MOUNT_PATH: &str = "/opt/openshell/bin"; @@ -686,21 +770,35 @@ fn supervisor_volume_mount() -> serde_json::Value { }) } -/// Path of the supervisor binary inside the supervisor image. +/// Build an image volume that mounts the supervisor OCI image directly. /// -/// The supervisor image places the binary at the filesystem root and ships -/// nothing else. We invoke it directly — there is no shell, `cp`, or PATH -/// resolution available inside the image. -const SUPERVISOR_IMAGE_BINARY_PATH: &str = "/openshell-sandbox"; +/// Requires Kubernetes >= v1.33 (`ImageVolume` beta) or >= v1.36 (GA). +/// The entire image filesystem is mounted read-only, making the binary +/// available at `{SUPERVISOR_MOUNT_PATH}/openshell-sandbox`. +fn supervisor_image_volume( + supervisor_image: &str, + supervisor_image_pull_policy: &str, +) -> serde_json::Value { + let mut image_spec = serde_json::json!({ + "reference": supervisor_image, + }); + if !supervisor_image_pull_policy.is_empty() { + image_spec["pullPolicy"] = serde_json::json!(supervisor_image_pull_policy); + } + serde_json::json!({ + "name": SUPERVISOR_VOLUME_NAME, + "image": image_spec + }) +} /// Build the init container that copies the supervisor binary into the emptyDir. 
/// -/// The supervisor image contains only the supervisor binary at -/// `/openshell-sandbox`. We invoke that binary with the `copy-self` -/// subcommand so it copies itself into the shared emptyDir volume, where the -/// agent container then executes it from a fixed, writable path. This pattern -/// (binary self-copy) avoids requiring `sh`/`cp` in the supervisor image and -/// mirrors the approach used by argoexec's emissary executor. +/// The supervisor image contains the supervisor binary at `/openshell-sandbox`. +/// We invoke that binary with the `copy-self` subcommand so it copies itself +/// into the shared emptyDir volume, where the agent container then executes it +/// from a fixed, writable path. This pattern (binary self-copy) avoids requiring +/// `sh`/`cp` in the supervisor image and mirrors the approach used by argoexec's +/// emissary executor. fn supervisor_init_container( supervisor_image: &str, supervisor_image_pull_policy: &str, @@ -729,43 +827,56 @@ fn supervisor_init_container( /// Apply supervisor side-load transforms to an already-built pod template JSON. /// -/// Injects an emptyDir volume, an init container that copies the supervisor -/// binary from the supervisor image into that volume, and a read-only volume -/// mount + command override on the agent container. +/// Depending on the sideload method: +/// - **`ImageVolume`**: mounts the supervisor OCI image directly as a read-only +/// volume (no init container needed, requires K8s >= v1.33). +/// - **`InitContainer`**: injects an emptyDir volume and an init container that +/// copies the supervisor binary from the supervisor image into that volume. /// -/// The `runAsUser: 0` override ensures the supervisor binary runs as root -/// regardless of the image's `USER` directive. The supervisor needs root for -/// network namespace creation, proxy setup, and Landlock/seccomp configuration. 
-/// It drops to the appropriate non-root user for child processes via the -/// policy's `run_as_user`/`run_as_group`. +/// In both cases, the agent container gets a command override to run the +/// side-loaded binary and `runAsUser: 0` so it can create network namespaces, +/// set up the proxy, and configure Landlock/seccomp. fn apply_supervisor_sideload( pod_template: &mut serde_json::Value, supervisor_image: &str, supervisor_image_pull_policy: &str, + method: SupervisorSideloadMethod, ) { let Some(spec) = pod_template.get_mut("spec").and_then(|v| v.as_object_mut()) else { return; }; - // 1. Add the emptyDir volume to spec.volumes + // 1. Add the volume (image source or emptyDir depending on method) let volumes = spec .entry("volumes") .or_insert_with(|| serde_json::json!([])) .as_array_mut(); if let Some(volumes) = volumes { - volumes.push(supervisor_volume()); + match method { + SupervisorSideloadMethod::ImageVolume => { + volumes.push(supervisor_image_volume( + supervisor_image, + supervisor_image_pull_policy, + )); + } + SupervisorSideloadMethod::InitContainer => { + volumes.push(supervisor_volume()); + } + } } - // 2. Add the init container that copies the binary into the emptyDir - let init_containers = spec - .entry("initContainers") - .or_insert_with(|| serde_json::json!([])) - .as_array_mut(); - if let Some(init_containers) = init_containers { - init_containers.push(supervisor_init_container( - supervisor_image, - supervisor_image_pull_policy, - )); + // 2. Add the init container only for the init-container method + if method == SupervisorSideloadMethod::InitContainer { + let init_containers = spec + .entry("initContainers") + .or_insert_with(|| serde_json::json!([])) + .as_array_mut(); + if let Some(init_containers) = init_containers { + init_containers.push(supervisor_init_container( + supervisor_image, + supervisor_image_pull_policy, + )); + } } // 3. 
Find the agent container and add volume mount + command override @@ -910,7 +1021,12 @@ fn apply_workspace_persistence( /// /// Provides a single PVC named "workspace" that backs the `/sandbox` /// directory. The init container seeds it from the image on first use. -fn default_workspace_volume_claim_templates() -> serde_json::Value { +fn default_workspace_volume_claim_templates(storage_size: &str) -> serde_json::Value { + let size = if storage_size.is_empty() { + DEFAULT_WORKSPACE_STORAGE_SIZE + } else { + storage_size + }; serde_json::json!([{ "metadata": { "name": WORKSPACE_VOLUME_NAME @@ -919,28 +1035,70 @@ fn default_workspace_volume_claim_templates() -> serde_json::Value { "accessModes": ["ReadWriteOnce"], "resources": { "requests": { - "storage": WORKSPACE_DEFAULT_STORAGE + "storage": size } } } }]) } -#[allow(clippy::too_many_arguments)] +/// Parameters shared by `sandbox_to_k8s_spec` and `sandbox_template_to_k8s`. +struct SandboxPodParams<'a> { + default_image: &'a str, + image_pull_policy: &'a str, + supervisor_image: &'a str, + supervisor_image_pull_policy: &'a str, + supervisor_sideload_method: SupervisorSideloadMethod, + service_account_name: &'a str, + sandbox_id: &'a str, + sandbox_name: &'a str, + grpc_endpoint: &'a str, + ssh_socket_path: &'a str, + client_tls_secret_name: &'a str, + host_gateway_ip: &'a str, + enable_user_namespaces: bool, + workspace_default_storage_size: &'a str, + /// Lifetime (seconds) of the projected `ServiceAccount` token used + /// for the bootstrap `IssueSandboxToken` exchange. 
+ sa_token_ttl_secs: i64, +} + +impl Default for SandboxPodParams<'_> { + fn default() -> Self { + Self { + default_image: "", + image_pull_policy: "", + supervisor_image: "", + supervisor_image_pull_policy: "", + supervisor_sideload_method: SupervisorSideloadMethod::default(), + service_account_name: DEFAULT_SANDBOX_SERVICE_ACCOUNT_NAME, + sandbox_id: "", + sandbox_name: "", + grpc_endpoint: "", + ssh_socket_path: "", + client_tls_secret_name: "", + host_gateway_ip: "", + enable_user_namespaces: false, + workspace_default_storage_size: DEFAULT_WORKSPACE_STORAGE_SIZE, + sa_token_ttl_secs: 3600, + } + } +} + +fn spec_pod_env(spec: Option<&SandboxSpec>) -> std::collections::HashMap { + let mut env = spec.map_or_else(Default::default, |s| s.environment.clone()); + if let Some(s) = spec.filter(|s| !s.log_level.is_empty()) { + env.insert( + openshell_core::sandbox_env::LOG_LEVEL.to_string(), + s.log_level.clone(), + ); + } + env +} + fn sandbox_to_k8s_spec( spec: Option<&SandboxSpec>, - default_image: &str, - image_pull_policy: &str, - supervisor_image: &str, - supervisor_image_pull_policy: &str, - sandbox_id: &str, - sandbox_name: &str, - grpc_endpoint: &str, - ssh_socket_path: &str, - ssh_handshake_secret: &str, - ssh_handshake_skew_secs: u64, - client_tls_secret_name: &str, - host_gateway_ip: &str, + params: &SandboxPodParams<'_>, ) -> serde_json::Value { let mut root = serde_json::Map::new(); @@ -956,36 +1114,11 @@ fn sandbox_to_k8s_spec( let inject_workspace = !user_has_vct; if let Some(spec) = spec { - if !spec.log_level.is_empty() { - root.insert("logLevel".to_string(), serde_json::json!(spec.log_level)); - } - if !spec.environment.is_empty() { - root.insert( - "environment".to_string(), - serde_json::json!(spec.environment), - ); - } + let pod_env = spec_pod_env(Some(spec)); if let Some(template) = spec.template.as_ref() { root.insert( "podTemplate".to_string(), - sandbox_template_to_k8s( - template, - spec.gpu, - default_image, - image_pull_policy, - 
supervisor_image, - supervisor_image_pull_policy, - sandbox_id, - sandbox_name, - grpc_endpoint, - ssh_socket_path, - ssh_handshake_secret, - ssh_handshake_skew_secs, - &spec.environment, - client_tls_secret_name, - host_gateway_ip, - inject_workspace, - ), + sandbox_template_to_k8s(template, spec.gpu, &pod_env, inject_workspace, params), ); if !template.agent_socket_path.is_empty() { root.insert( @@ -1006,33 +1139,21 @@ fn sandbox_to_k8s_spec( if inject_workspace { root.insert( "volumeClaimTemplates".to_string(), - default_workspace_volume_claim_templates(), + default_workspace_volume_claim_templates(params.workspace_default_storage_size), ); } // podTemplate is required by the Kubernetes CRD - ensure it's always present if !root.contains_key("podTemplate") { - let empty_env = std::collections::HashMap::new(); - let spec_env = spec.as_ref().map_or(&empty_env, |s| &s.environment); + let pod_env = spec_pod_env(spec); root.insert( "podTemplate".to_string(), sandbox_template_to_k8s( &SandboxTemplate::default(), - spec.as_ref().is_some_and(|s| s.gpu), - default_image, - image_pull_policy, - supervisor_image, - supervisor_image_pull_policy, - sandbox_id, - sandbox_name, - grpc_endpoint, - ssh_socket_path, - ssh_handshake_secret, - ssh_handshake_skew_secs, - spec_env, - client_tls_secret_name, - host_gateway_ip, + spec.is_some_and(|s| s.gpu), + &pod_env, inject_workspace, + params, ), ); } @@ -1042,31 +1163,40 @@ fn sandbox_to_k8s_spec( ) } -#[allow(clippy::too_many_arguments)] fn sandbox_template_to_k8s( template: &SandboxTemplate, gpu: bool, - default_image: &str, - image_pull_policy: &str, - supervisor_image: &str, - supervisor_image_pull_policy: &str, - sandbox_id: &str, - sandbox_name: &str, - grpc_endpoint: &str, - ssh_socket_path: &str, - ssh_handshake_secret: &str, - ssh_handshake_skew_secs: u64, spec_environment: &std::collections::HashMap, - client_tls_secret_name: &str, - host_gateway_ip: &str, inject_workspace: bool, + params: &SandboxPodParams<'_>, ) -> 
serde_json::Value { let mut metadata = serde_json::Map::new(); if !template.labels.is_empty() { metadata.insert("labels".to_string(), serde_json::json!(template.labels)); } - if let Some(annotations) = platform_config_struct(template, "annotations") { - metadata.insert("annotations".to_string(), annotations); + // Carry the sandbox UUID as a pod annotation so the gateway can resolve + // a projected SA token claim (pod name + uid) back to a sandbox identity + // when the supervisor calls `IssueSandboxToken` at startup. The gateway + // also verifies the pod's controlling Sandbox ownerReference against the + // live CR before accepting this annotation. Its K8s Role does NOT grant + // `patch pods`, so this annotation is effectively immutable post-create. + let mut pod_annotations = platform_config_struct(template, "annotations") + .and_then(|v| match v { + serde_json::Value::Object(map) => Some(map), + _ => None, + }) + .unwrap_or_default(); + if !params.sandbox_id.is_empty() { + pod_annotations.insert( + "openshell.io/sandbox-id".to_string(), + serde_json::Value::String(params.sandbox_id.to_string()), + ); + } + if !pod_annotations.is_empty() { + metadata.insert( + "annotations".to_string(), + serde_json::Value::Object(pod_annotations), + ); } let mut spec = serde_json::Map::new(); @@ -1076,21 +1206,55 @@ fn sandbox_template_to_k8s( serde_json::json!(runtime_class), ); } + if let Some(node_selector) = platform_config_struct(template, "node_selector") { + spec.insert("nodeSelector".to_string(), node_selector); + } + if let Some(tolerations) = platform_config_struct(template, "tolerations") { + spec.insert("tolerations".to_string(), tolerations); + } + + // Per-sandbox platform_config.host_users overrides the cluster-wide default. 
+ let use_user_namespaces = platform_config_bool(template, "host_users") + .map_or(params.enable_user_namespaces, |host_users| !host_users); + + if use_user_namespaces { + spec.insert("hostUsers".to_string(), serde_json::json!(false)); + if gpu { + warn!( + "GPU sandbox with user namespaces enabled — \ + NVIDIA device plugin compatibility is unverified" + ); + } + } + + if !params.service_account_name.is_empty() { + spec.insert( + "serviceAccountName".to_string(), + serde_json::json!(params.service_account_name), + ); + } + + // Disable service account token auto-mounting for security hardening. + // Sandbox pods should not have access to the Kubernetes API by default. + spec.insert( + "automountServiceAccountToken".to_string(), + serde_json::json!(false), + ); let mut container = serde_json::Map::new(); container.insert("name".to_string(), serde_json::json!("agent")); // Use template image if provided, otherwise fall back to default let image = if template.image.is_empty() { - default_image + params.default_image } else { &template.image }; if !image.is_empty() { container.insert("image".to_string(), serde_json::json!(image)); - if !image_pull_policy.is_empty() { + if !params.image_pull_policy.is_empty() { container.insert( "imagePullPolicy".to_string(), - serde_json::json!(image_pull_policy), + serde_json::json!(params.image_pull_policy), ); } } @@ -1100,43 +1264,52 @@ fn sandbox_template_to_k8s( None, &template.environment, spec_environment, - sandbox_id, - sandbox_name, - grpc_endpoint, - ssh_socket_path, - ssh_handshake_secret, - ssh_handshake_skew_secs, - !client_tls_secret_name.is_empty(), + params.sandbox_id, + params.sandbox_name, + params.grpc_endpoint, + params.ssh_socket_path, + !params.client_tls_secret_name.is_empty(), ); container.insert("env".to_string(), serde_json::Value::Array(env)); - // The sandbox process needs SYS_ADMIN (for seccomp filter installation and - // network namespace creation), NET_ADMIN (for network namespace veth setup), - // 
SYS_PTRACE (for the CONNECT proxy to read /proc//fd/ of sandbox-user - // processes to resolve binary identity for network policy enforcement), - // and SYSLOG (for reading /dev/kmsg to surface bypass detection diagnostics). - // This mirrors the capabilities used by `mise run sandbox`. + let mut capabilities: Vec<&str> = vec!["SYS_ADMIN", "NET_ADMIN", "SYS_PTRACE", "SYSLOG"]; + if use_user_namespaces { + // In a user namespace the bounding set is reset. SETUID/SETGID are + // needed for the supervisor to drop privileges to the sandbox user. + // DAC_READ_SEARCH is needed for cross-UID /proc//fd/ access + // for process identity resolution in network policy enforcement. + capabilities.extend(["SETUID", "SETGID", "DAC_READ_SEARCH"]); + } container.insert( "securityContext".to_string(), serde_json::json!({ "capabilities": { - "add": ["SYS_ADMIN", "NET_ADMIN", "SYS_PTRACE", "SYSLOG"] + "add": capabilities } }), ); - // Mount client TLS secret for mTLS to the server. - if !client_tls_secret_name.is_empty() { - container.insert( - "volumeMounts".to_string(), - serde_json::json!([{ - "name": "openshell-client-tls", - "mountPath": "/etc/openshell-tls/client", - "readOnly": true - }]), - ); - } + // Mount client TLS secret for mTLS to the server, plus the projected + // ServiceAccount token used to bootstrap the sandbox's gateway JWT + // via `IssueSandboxToken`. 
+ let mut volume_mounts: Vec = Vec::new(); + if !params.client_tls_secret_name.is_empty() { + volume_mounts.push(serde_json::json!({ + "name": "openshell-client-tls", + "mountPath": "/etc/openshell-tls/client", + "readOnly": true + })); + } + volume_mounts.push(serde_json::json!({ + "name": "openshell-sa-token", + "mountPath": "/var/run/secrets/openshell", + "readOnly": true, + })); + container.insert( + "volumeMounts".to_string(), + serde_json::Value::Array(volume_mounts), + ); if let Some(resources) = container_resources(template, gpu) { container.insert("resources".to_string(), resources); @@ -1148,22 +1321,38 @@ fn sandbox_template_to_k8s( // Add TLS secret volume. Mode 0400 (owner-read) prevents the // unprivileged sandbox user from reading the mTLS private key. - if !client_tls_secret_name.is_empty() { - spec.insert( - "volumes".to_string(), - serde_json::json!([{ - "name": "openshell-client-tls", - "secret": { "secretName": client_tls_secret_name, "defaultMode": 256 } - }]), - ); - } + let mut volumes: Vec = Vec::new(); + if !params.client_tls_secret_name.is_empty() { + volumes.push(serde_json::json!({ + "name": "openshell-client-tls", + "secret": { "secretName": params.client_tls_secret_name, "defaultMode": 256 } + })); + } + // Projected ServiceAccountToken volume — kubelet writes a short-lived + // audience-bound JWT into /var/run/secrets/openshell/token and rotates + // it automatically. The supervisor exchanges this for a gateway-minted + // JWT via `IssueSandboxToken` once at startup. + volumes.push(serde_json::json!({ + "name": "openshell-sa-token", + "projected": { + "sources": [{ + "serviceAccountToken": { + "audience": "openshell-gateway", + "expirationSeconds": params.sa_token_ttl_secs, + "path": "token" + } + }], + "defaultMode": 256 + } + })); + spec.insert("volumes".to_string(), serde_json::Value::Array(volumes)); // Add hostAliases so sandbox pods can reach the Docker host. 
- if !host_gateway_ip.is_empty() { + if !params.host_gateway_ip.is_empty() { spec.insert( "hostAliases".to_string(), serde_json::json!([{ - "ip": host_gateway_ip, + "ip": params.host_gateway_ip, "hostnames": ["host.docker.internal", "host.openshell.internal"] }]), ); @@ -1177,15 +1366,18 @@ fn sandbox_template_to_k8s( let mut result = serde_json::Value::Object(template_value); - // Side-load the supervisor binary via an init container that copies it - // from the supervisor image into a shared emptyDir volume. - apply_supervisor_sideload(&mut result, supervisor_image, supervisor_image_pull_policy); + apply_supervisor_sideload( + &mut result, + params.supervisor_image, + params.supervisor_image_pull_policy, + params.supervisor_sideload_method, + ); // Inject workspace persistence (init container + PVC volume mount) so // that /sandbox data survives pod rescheduling. Skipped when the user // provides custom volumeClaimTemplates to avoid conflicts. if inject_workspace { - apply_workspace_persistence(&mut result, image, image_pull_policy); + apply_workspace_persistence(&mut result, image, params.image_pull_policy); } result @@ -1207,10 +1399,21 @@ fn container_resources(template: &SandboxTemplate, gpu: bool) -> Option Vec { let mut env = existing_env.cloned().unwrap_or_default(); @@ -1265,8 +1466,6 @@ fn build_env_list( sandbox_name, grpc_endpoint, ssh_socket_path, - ssh_handshake_secret, - ssh_handshake_skew_secs, tls_enabled, ); env @@ -1283,45 +1482,56 @@ fn apply_env_map( // Required env vars are passed individually for clarity at call sites; grouping into a struct // would not improve readability for this internal helper. 
-#[allow(clippy::too_many_arguments)] fn apply_required_env( env: &mut Vec, sandbox_id: &str, sandbox_name: &str, grpc_endpoint: &str, ssh_socket_path: &str, - ssh_handshake_secret: &str, - ssh_handshake_skew_secs: u64, tls_enabled: bool, ) { - upsert_env(env, "OPENSHELL_SANDBOX_ID", sandbox_id); - upsert_env(env, "OPENSHELL_SANDBOX", sandbox_name); - upsert_env(env, "OPENSHELL_ENDPOINT", grpc_endpoint); - upsert_env(env, "OPENSHELL_SANDBOX_COMMAND", "sleep infinity"); - if !ssh_socket_path.is_empty() { - upsert_env(env, "OPENSHELL_SSH_SOCKET_PATH", ssh_socket_path); - } - upsert_env(env, "OPENSHELL_SSH_HANDSHAKE_SECRET", ssh_handshake_secret); + upsert_env(env, openshell_core::sandbox_env::SANDBOX_ID, sandbox_id); + upsert_env(env, openshell_core::sandbox_env::SANDBOX, sandbox_name); + upsert_env(env, openshell_core::sandbox_env::ENDPOINT, grpc_endpoint); upsert_env( env, - "OPENSHELL_SSH_HANDSHAKE_SKEW_SECS", - &ssh_handshake_skew_secs.to_string(), + openshell_core::sandbox_env::SANDBOX_COMMAND, + "sleep infinity", ); + if !ssh_socket_path.is_empty() { + upsert_env( + env, + openshell_core::sandbox_env::SSH_SOCKET_PATH, + ssh_socket_path, + ); + } // TLS cert paths for sandbox-to-server mTLS. Only set when TLS is enabled // and the client TLS secret is mounted into the sandbox pod. if tls_enabled { - upsert_env(env, "OPENSHELL_TLS_CA", "/etc/openshell-tls/client/ca.crt"); upsert_env( env, - "OPENSHELL_TLS_CERT", + openshell_core::sandbox_env::TLS_CA, + "/etc/openshell-tls/client/ca.crt", + ); + upsert_env( + env, + openshell_core::sandbox_env::TLS_CERT, "/etc/openshell-tls/client/tls.crt", ); upsert_env( env, - "OPENSHELL_TLS_KEY", + openshell_core::sandbox_env::TLS_KEY, "/etc/openshell-tls/client/tls.key", ); } + // Projected ServiceAccount token written by kubelet (see the volume + // definition in `sandbox_template_to_k8s`). The supervisor reads this + // and exchanges it for a gateway-minted JWT via `IssueSandboxToken`. 
+ upsert_env( + env, + openshell_core::sandbox_env::K8S_SA_TOKEN_FILE, + "/var/run/secrets/openshell/token", + ); } fn upsert_env(env: &mut Vec, name: &str, value: &str) { @@ -1346,6 +1556,15 @@ fn platform_config_string(template: &SandboxTemplate, key: &str) -> Option Option { + let config = template.platform_config.as_ref()?; + let value = config.fields.get(key)?; + match value.kind.as_ref() { + Some(prost_types::value::Kind::BoolValue(b)) => Some(*b), + _ => None, + } +} + /// Extract a nested Struct value from the template's `platform_config`, /// converting it to `serde_json::Value`. fn platform_config_struct(template: &SandboxTemplate, key: &str) -> Option { @@ -1449,31 +1668,55 @@ fn condition_from_value(value: &serde_json::Value) -> Option { #[cfg(test)] mod tests { use super::*; + use openshell_core::progress::{ + PROGRESS_ACTIVE_DETAIL_KEY, PROGRESS_ACTIVE_STEP_KEY, PROGRESS_COMPLETE_LABEL_KEY, + PROGRESS_COMPLETE_STEP_KEY, + }; use prost_types::{Struct, Value, value::Kind}; #[test] - fn apply_required_env_always_injects_ssh_handshake_secret() { - let mut env = Vec::new(); - apply_required_env( - &mut env, - "sandbox-1", - "my-sandbox", - "https://endpoint:8080", - "0.0.0.0:2222", - "my-secret-value", - 300, - true, + fn kube_pulling_event_adds_image_progress_metadata() { + let mut metadata = std::collections::HashMap::new(); + + attach_kube_progress_metadata( + &mut metadata, + "Pulling", + "Pulling image \"ghcr.io/acme/sandbox:latest\"", ); - let secret_entry = env - .iter() - .find(|e| { - e.get("name").and_then(|v| v.as_str()) == Some("OPENSHELL_SSH_HANDSHAKE_SECRET") - }) - .expect("OPENSHELL_SSH_HANDSHAKE_SECRET must be present in env"); assert_eq!( - secret_entry.get("value").and_then(|v| v.as_str()), - Some("my-secret-value") + metadata.get(PROGRESS_ACTIVE_STEP_KEY).map(String::as_str), + Some(PROGRESS_STEP_PULLING_IMAGE) + ); + assert_eq!( + metadata.get(PROGRESS_ACTIVE_DETAIL_KEY).map(String::as_str), + Some("ghcr.io/acme/sandbox:latest") + ); + 
} + + #[test] + fn kube_pulled_event_adds_completed_image_progress_metadata() { + let mut metadata = std::collections::HashMap::new(); + + attach_kube_progress_metadata( + &mut metadata, + "Pulled", + "Successfully pulled image \"ghcr.io/acme/sandbox:latest\". Image size: 44040192 bytes.", + ); + + assert_eq!( + metadata.get(PROGRESS_COMPLETE_STEP_KEY).map(String::as_str), + Some(PROGRESS_STEP_PULLING_IMAGE) + ); + assert_eq!( + metadata + .get(PROGRESS_COMPLETE_LABEL_KEY) + .map(String::as_str), + Some("Image pulled (42 MB)") + ); + assert_eq!( + metadata.get(PROGRESS_ACTIVE_STEP_KEY).map(String::as_str), + Some(PROGRESS_STEP_STARTING_SANDBOX) ); } @@ -1493,7 +1736,12 @@ mod tests { } }); - apply_supervisor_sideload(&mut pod_template, "custom-image:latest", "IfNotPresent"); + apply_supervisor_sideload( + &mut pod_template, + "custom-image:latest", + "IfNotPresent", + SupervisorSideloadMethod::InitContainer, + ); let sc = &pod_template["spec"]["containers"][0]["securityContext"]; assert_eq!(sc["runAsUser"], 0, "runAsUser must be 0 for supervisor"); @@ -1517,7 +1765,12 @@ mod tests { } }); - apply_supervisor_sideload(&mut pod_template, "supervisor-image:latest", "IfNotPresent"); + apply_supervisor_sideload( + &mut pod_template, + "supervisor-image:latest", + "IfNotPresent", + SupervisorSideloadMethod::InitContainer, + ); let sc = &pod_template["spec"]["containers"][0]["securityContext"]; assert_eq!( @@ -1537,7 +1790,12 @@ mod tests { } }); - apply_supervisor_sideload(&mut pod_template, "supervisor-image:latest", "IfNotPresent"); + apply_supervisor_sideload( + &mut pod_template, + "supervisor-image:latest", + "IfNotPresent", + SupervisorSideloadMethod::InitContainer, + ); // Volume should be an emptyDir let volumes = pod_template["spec"]["volumes"] @@ -1559,8 +1817,8 @@ mod tests { assert_eq!(init_containers[0]["image"], "supervisor-image:latest"); assert_eq!(init_containers[0]["imagePullPolicy"], "IfNotPresent"); - // The supervisor image ships only the binary (no 
shell). The init - // container must invoke the binary directly with `copy-self `. + // The init container must invoke the binary directly with + // `copy-self ` rather than depending on shell utilities. let init_command = init_containers[0]["command"] .as_array() .expect("init container command should be set"); @@ -1573,7 +1831,7 @@ mod tests { ); assert!( !init_command.iter().any(|v| v == "sh"), - "init container must not depend on a shell (supervisor image ships only the binary)" + "init container must not depend on a shell" ); // Agent container command should be overridden to the emptyDir path @@ -1595,75 +1853,145 @@ mod tests { assert_eq!(mounts[0]["readOnly"], true); } - /// Regression test: TLS mount path must match env var paths. - /// The volume is mounted at a specific path and the env vars must point to - /// files within that same path, otherwise the sandbox will fail to start - /// with "No such file or directory" errors. #[test] - fn tls_env_vars_match_volume_mount_path() { - // The mount path used in pod template construction - const TLS_MOUNT_PATH: &str = "/etc/openshell-tls/client"; + fn supervisor_sideload_image_volume_injects_image_source_without_init_container() { + let mut pod_template = serde_json::json!({ + "spec": { + "containers": [{ + "name": "agent", + "image": "custom-image:latest" + }] + } + }); - // Build env with TLS enabled - let mut env = Vec::new(); - apply_required_env( - &mut env, - "sandbox-1", - "my-sandbox", - "https://endpoint:8080", - "0.0.0.0:2222", - "secret", - 300, - true, // tls_enabled + apply_supervisor_sideload( + &mut pod_template, + "supervisor-image:latest", + "IfNotPresent", + SupervisorSideloadMethod::ImageVolume, ); - // Extract the TLS-related env vars - let get_env = |name: &str| -> Option { - env.iter() - .find(|e| e.get("name").and_then(|v| v.as_str()) == Some(name)) - .and_then(|e| e.get("value").and_then(|v| v.as_str()).map(String::from)) - }; - - let tls_ca = 
get_env("OPENSHELL_TLS_CA").expect("OPENSHELL_TLS_CA must be set"); - let tls_cert = get_env("OPENSHELL_TLS_CERT").expect("OPENSHELL_TLS_CERT must be set"); - let tls_key = get_env("OPENSHELL_TLS_KEY").expect("OPENSHELL_TLS_KEY must be set"); - - // All TLS paths must be within the mount path - assert!( - tls_ca.starts_with(TLS_MOUNT_PATH), - "OPENSHELL_TLS_CA path '{tls_ca}' must start with mount path '{TLS_MOUNT_PATH}'" - ); + let volumes = pod_template["spec"]["volumes"] + .as_array() + .expect("volumes should exist"); + assert_eq!(volumes.len(), 1); + assert_eq!(volumes[0]["name"], SUPERVISOR_VOLUME_NAME); + assert_eq!(volumes[0]["image"]["reference"], "supervisor-image:latest"); + assert_eq!(volumes[0]["image"]["pullPolicy"], "IfNotPresent"); assert!( - tls_cert.starts_with(TLS_MOUNT_PATH), - "OPENSHELL_TLS_CERT path '{tls_cert}' must start with mount path '{TLS_MOUNT_PATH}'" + volumes[0]["emptyDir"].is_null(), + "image volume method must not use emptyDir" ); + assert!( - tls_key.starts_with(TLS_MOUNT_PATH), - "OPENSHELL_TLS_KEY path '{tls_key}' must start with mount path '{TLS_MOUNT_PATH}'" + pod_template["spec"]["initContainers"].is_null(), + "image volume method must not inject init containers" ); - } - #[test] - fn gpu_sandbox_adds_runtime_class_and_gpu_limit() { - let pod_template = sandbox_template_to_k8s( - &SandboxTemplate::default(), - true, - "openshell/sandbox:latest", - "", - "openshell/supervisor:latest", + let command = pod_template["spec"]["containers"][0]["command"] + .as_array() + .expect("command should be set"); + assert_eq!( + command[0].as_str().unwrap(), + format!("{SUPERVISOR_MOUNT_PATH}/openshell-sandbox") + ); + + let sc = &pod_template["spec"]["containers"][0]["securityContext"]; + assert_eq!(sc["runAsUser"], 0); + + let mounts = pod_template["spec"]["containers"][0]["volumeMounts"] + .as_array() + .expect("volumeMounts should exist"); + assert_eq!(mounts[0]["name"], SUPERVISOR_VOLUME_NAME); + assert_eq!(mounts[0]["mountPath"], 
SUPERVISOR_MOUNT_PATH); + assert_eq!(mounts[0]["readOnly"], true); + } + + #[test] + fn supervisor_image_volume_omits_pull_policy_when_empty() { + let mut pod_template = serde_json::json!({ + "spec": { + "containers": [{ + "name": "agent", + "image": "custom-image:latest" + }] + } + }); + + apply_supervisor_sideload( + &mut pod_template, + "supervisor-image:latest", "", - "sandbox-id", - "sandbox-name", - "https://gateway.example.com", + SupervisorSideloadMethod::ImageVolume, + ); + + let volume = &pod_template["spec"]["volumes"][0]; + assert_eq!(volume["image"]["reference"], "supervisor-image:latest"); + assert!( + volume["image"].get("pullPolicy").is_none(), + "pullPolicy should be omitted when empty" + ); + } + + /// Regression test: TLS mount path must match env var paths. + /// The volume is mounted at a specific path and the env vars must point to + /// files within that same path, otherwise the sandbox will fail to start + /// with "No such file or directory" errors. + #[test] + fn tls_env_vars_match_volume_mount_path() { + // The mount path used in pod template construction + const TLS_MOUNT_PATH: &str = "/etc/openshell-tls/client"; + + // Build env with TLS enabled + let mut env = Vec::new(); + apply_required_env( + &mut env, + "sandbox-1", + "my-sandbox", + "https://endpoint:8080", "0.0.0.0:2222", - "secret", - 300, - &std::collections::HashMap::new(), - "", - "", - true, + true, // tls_enabled ); + // Extract the TLS-related env vars + let get_env = |name: &str| -> Option { + env.iter() + .find(|e| e.get("name").and_then(|v| v.as_str()) == Some(name)) + .and_then(|e| e.get("value").and_then(|v| v.as_str()).map(String::from)) + }; + + let tls_ca = get_env("OPENSHELL_TLS_CA").expect("OPENSHELL_TLS_CA must be set"); + let tls_cert = get_env("OPENSHELL_TLS_CERT").expect("OPENSHELL_TLS_CERT must be set"); + let tls_key = get_env("OPENSHELL_TLS_KEY").expect("OPENSHELL_TLS_KEY must be set"); + + // All TLS paths must be within the mount path + assert!( + 
tls_ca.starts_with(TLS_MOUNT_PATH), + "OPENSHELL_TLS_CA path '{tls_ca}' must start with mount path '{TLS_MOUNT_PATH}'" + ); + assert!( + tls_cert.starts_with(TLS_MOUNT_PATH), + "OPENSHELL_TLS_CERT path '{tls_cert}' must start with mount path '{TLS_MOUNT_PATH}'" + ); + assert!( + tls_key.starts_with(TLS_MOUNT_PATH), + "OPENSHELL_TLS_KEY path '{tls_key}' must start with mount path '{TLS_MOUNT_PATH}'" + ); + } + + #[test] + fn gpu_sandbox_adds_runtime_class_and_gpu_limit() { + let pod_template = { + let params = SandboxPodParams::default(); + sandbox_template_to_k8s( + &SandboxTemplate::default(), + true, + &std::collections::HashMap::new(), + true, + &params, + ) + }; + assert_eq!( pod_template["spec"]["runtimeClassName"], serde_json::Value::Null @@ -1689,24 +2017,16 @@ mod tests { ..SandboxTemplate::default() }; - let pod_template = sandbox_template_to_k8s( - &template, - true, - "openshell/sandbox:latest", - "", - "openshell/supervisor:latest", - "", - "sandbox-id", - "sandbox-name", - "https://gateway.example.com", - "0.0.0.0:2222", - "secret", - 300, - &std::collections::HashMap::new(), - "", - "", - true, - ); + let pod_template = { + let params = SandboxPodParams::default(); + sandbox_template_to_k8s( + &template, + true, + &std::collections::HashMap::new(), + true, + &params, + ) + }; assert_eq!( pod_template["spec"]["runtimeClassName"], @@ -1729,24 +2049,16 @@ mod tests { ..SandboxTemplate::default() }; - let pod_template = sandbox_template_to_k8s( - &template, - false, - "openshell/sandbox:latest", - "", - "openshell/supervisor:latest", - "", - "sandbox-id", - "sandbox-name", - "https://gateway.example.com", - "0.0.0.0:2222", - "secret", - 300, - &std::collections::HashMap::new(), - "", - "", - true, - ); + let pod_template = { + let params = SandboxPodParams::default(); + sandbox_template_to_k8s( + &template, + false, + &std::collections::HashMap::new(), + true, + &params, + ) + }; assert_eq!( pod_template["spec"]["runtimeClassName"], @@ -1765,24 +2077,16 @@ mod tests
{ ..SandboxTemplate::default() }; - let pod_template = sandbox_template_to_k8s( - &template, - true, - "openshell/sandbox:latest", - "", - "openshell/supervisor:latest", - "", - "sandbox-id", - "sandbox-name", - "https://gateway.example.com", - "0.0.0.0:2222", - "secret", - 300, - &std::collections::HashMap::new(), - "", - "", - true, - ); + let pod_template = { + let params = SandboxPodParams::default(); + sandbox_template_to_k8s( + &template, + true, + &std::collections::HashMap::new(), + true, + &params, + ) + }; let limits = &pod_template["spec"]["containers"][0]["resources"]["limits"]; assert_eq!(limits["cpu"], serde_json::json!("2")); @@ -1792,26 +2096,51 @@ mod tests { ); } + #[test] + fn cpu_and_memory_limits_are_mirrored_to_requests() { + use openshell_core::proto::compute::v1::DriverResourceRequirements; + let template = SandboxTemplate { + resources: Some(DriverResourceRequirements { + cpu_limit: "500m".to_string(), + memory_limit: "2Gi".to_string(), + ..Default::default() + }), + ..SandboxTemplate::default() + }; + + let pod_template = { + let params = SandboxPodParams::default(); + sandbox_template_to_k8s( + &template, + false, + &std::collections::HashMap::new(), + true, + &params, + ) + }; + + let resources = &pod_template["spec"]["containers"][0]["resources"]; + assert_eq!(resources["limits"]["cpu"], serde_json::json!("500m")); + assert_eq!(resources["limits"]["memory"], serde_json::json!("2Gi")); + assert_eq!(resources["requests"]["cpu"], serde_json::json!("500m")); + assert_eq!(resources["requests"]["memory"], serde_json::json!("2Gi")); + } + #[test] fn host_aliases_injected_when_gateway_ip_set() { - let pod_template = sandbox_template_to_k8s( - &SandboxTemplate::default(), - false, - "openshell/sandbox:latest", - "", - "openshell/supervisor:latest", - "", - "sandbox-id", - "sandbox-name", - "https://gateway.example.com", - "0.0.0.0:2222", - "secret", - 300, - &std::collections::HashMap::new(), - "", - "172.17.0.1", - true, - ); + let pod_template = { +
let params = SandboxPodParams { + host_gateway_ip: "172.17.0.1", + ..Default::default() + }; + sandbox_template_to_k8s( + &SandboxTemplate::default(), + false, + &std::collections::HashMap::new(), + true, + &params, + ) + }; let host_aliases = pod_template["spec"]["hostAliases"] .as_array() @@ -1827,24 +2156,16 @@ mod tests { #[test] fn host_aliases_not_injected_when_gateway_ip_empty() { - let pod_template = sandbox_template_to_k8s( - &SandboxTemplate::default(), - false, - "openshell/sandbox:latest", - "", - "openshell/supervisor:latest", - "", - "sandbox-id", - "sandbox-name", - "https://gateway.example.com", - "0.0.0.0:2222", - "secret", - 300, - &std::collections::HashMap::new(), - "", - "", - true, - ); + let pod_template = { + let params = SandboxPodParams::default(); + sandbox_template_to_k8s( + &SandboxTemplate::default(), + false, + &std::collections::HashMap::new(), + true, + &params, + ) + }; assert!( pod_template["spec"]["hostAliases"].is_null(), @@ -1855,24 +2176,19 @@ mod tests { #[test] fn tls_secret_volume_uses_restrictive_default_mode() { let template = SandboxTemplate::default(); - let pod_template = sandbox_template_to_k8s( - &template, - false, - "openshell/sandbox:latest", - "", - "openshell/supervisor:latest", - "", - "sandbox-id", - "sandbox-name", - "https://gateway.example.com", - "0.0.0.0:2222", - "secret", - 300, - &std::collections::HashMap::new(), - "my-tls-secret", - "", - true, - ); + let pod_template = { + let params = SandboxPodParams { + client_tls_secret_name: "my-tls-secret", + ..Default::default() + }; + sandbox_template_to_k8s( + &template, + false, + &std::collections::HashMap::new(), + true, + &params, + ) + }; let volumes = pod_template["spec"]["volumes"] .as_array() @@ -2000,23 +2316,16 @@ mod tests { #[test] fn workspace_persistence_skipped_when_inject_workspace_false() { + let params = SandboxPodParams { + supervisor_sideload_method: SupervisorSideloadMethod::InitContainer, + ..SandboxPodParams::default() + }; let pod_template =
sandbox_template_to_k8s( &SandboxTemplate::default(), false, - "openshell/sandbox:latest", - "", - "openshell/supervisor:latest", - "", - "sandbox-id", - "sandbox-name", - "https://gateway.example.com", - "0.0.0.0:2222", - "secret", - 300, &std::collections::HashMap::new(), - "", - "", false, // user provided custom VCTs + &params, ); // Only the supervisor init container should be present — no workspace init container @@ -2039,4 +2348,356 @@ mod tests { "workspace mount must NOT be present when inject_workspace is false" ); } + + // ----------------------------------------------------------------------- + // User namespace tests + // ----------------------------------------------------------------------- + + fn default_template_to_k8s(enable_user_namespaces: bool) -> serde_json::Value { + let params = SandboxPodParams { + enable_user_namespaces, + ..Default::default() + }; + sandbox_template_to_k8s( + &SandboxTemplate::default(), + false, + &std::collections::HashMap::new(), + true, + &params, + ) + } + + #[test] + fn user_namespaces_disabled_by_default() { + let pod_template = default_template_to_k8s(false); + assert!( + pod_template["spec"]["hostUsers"].is_null(), + "hostUsers must not be set when user namespaces are disabled" + ); + let caps = pod_template["spec"]["containers"][0]["securityContext"]["capabilities"]["add"] + .as_array() + .unwrap(); + assert_eq!(caps.len(), 4); + assert!(!caps.contains(&serde_json::json!("SETUID"))); + } + + #[test] + fn user_namespaces_enabled_by_cluster_default() { + let pod_template = default_template_to_k8s(true); + assert_eq!( + pod_template["spec"]["hostUsers"], + serde_json::json!(false), + "hostUsers must be false when user namespaces are enabled" + ); + } + + #[test] + fn user_namespaces_adds_extra_capabilities() { + let pod_template = default_template_to_k8s(true); + let caps = pod_template["spec"]["containers"][0]["securityContext"]["capabilities"]["add"] + .as_array() + .unwrap(); + 
assert!(caps.contains(&serde_json::json!("SYS_ADMIN"))); + assert!(caps.contains(&serde_json::json!("NET_ADMIN"))); + assert!(caps.contains(&serde_json::json!("SYS_PTRACE"))); + assert!(caps.contains(&serde_json::json!("SYSLOG"))); + assert!(caps.contains(&serde_json::json!("SETUID"))); + assert!(caps.contains(&serde_json::json!("SETGID"))); + assert!(caps.contains(&serde_json::json!("DAC_READ_SEARCH"))); + assert_eq!(caps.len(), 7); + } + + #[test] + fn user_namespaces_per_sandbox_override_enables() { + let template = SandboxTemplate { + platform_config: Some(Struct { + fields: std::iter::once(( + "host_users".to_string(), + Value { + kind: Some(Kind::BoolValue(false)), + }, + )) + .collect(), + }), + ..SandboxTemplate::default() + }; + + let params = SandboxPodParams::default(); // cluster default is off + let pod_template = sandbox_template_to_k8s( + &template, + false, + &std::collections::HashMap::new(), + true, + &params, + ); + + assert_eq!( + pod_template["spec"]["hostUsers"], + serde_json::json!(false), + "per-sandbox host_users: false must enable user namespaces" + ); + let caps = pod_template["spec"]["containers"][0]["securityContext"]["capabilities"]["add"] + .as_array() + .unwrap(); + assert!(caps.contains(&serde_json::json!("SETUID"))); + } + + #[test] + fn user_namespaces_per_sandbox_override_disables() { + let template = SandboxTemplate { + platform_config: Some(Struct { + fields: std::iter::once(( + "host_users".to_string(), + Value { + kind: Some(Kind::BoolValue(true)), + }, + )) + .collect(), + }), + ..SandboxTemplate::default() + }; + + let params = SandboxPodParams { + enable_user_namespaces: true, // cluster default is on + ..Default::default() + }; + let pod_template = sandbox_template_to_k8s( + &template, + false, + &std::collections::HashMap::new(), + true, + &params, + ); + + assert!( + pod_template["spec"]["hostUsers"].is_null(), + "per-sandbox host_users: true must disable user namespaces even when cluster default is on" + ); + let caps = 
pod_template["spec"]["containers"][0]["securityContext"]["capabilities"]["add"] + .as_array() + .unwrap(); + assert_eq!( + caps.len(), + 4, + "extra capabilities must not be added when user namespaces are disabled" + ); + } + + #[test] + fn automount_service_account_token_is_disabled() { + let pod_template = { + let params = SandboxPodParams::default(); + sandbox_template_to_k8s( + &SandboxTemplate::default(), + false, + &std::collections::HashMap::new(), + true, + &params, + ) + }; + + assert_eq!( + pod_template["spec"]["automountServiceAccountToken"], + serde_json::json!(false), + "service account token auto-mounting must be disabled for security hardening" + ); + } + + #[test] + fn sandbox_template_sets_configured_service_account_name() { + let params = SandboxPodParams { + service_account_name: "openshell-sandbox", + ..Default::default() + }; + let pod_template = sandbox_template_to_k8s( + &SandboxTemplate::default(), + false, + &std::collections::HashMap::new(), + true, + &params, + ); + + assert_eq!( + pod_template["spec"]["serviceAccountName"], + serde_json::json!("openshell-sandbox"), + "sandbox pods must run under the configured service account" + ); + assert_eq!( + pod_template["spec"]["automountServiceAccountToken"], + serde_json::json!(false), + "explicit service account selection must not re-enable default token automounting" + ); + } + + #[test] + fn platform_config_bool_extracts_value() { + let template = SandboxTemplate { + platform_config: Some(Struct { + fields: std::iter::once(( + "my_bool".to_string(), + Value { + kind: Some(Kind::BoolValue(true)), + }, + )) + .collect(), + }), + ..SandboxTemplate::default() + }; + + assert_eq!(platform_config_bool(&template, "my_bool"), Some(true)); + assert_eq!(platform_config_bool(&template, "missing"), None); + } + + #[test] + fn platform_config_bool_returns_none_for_non_bool() { + let template = SandboxTemplate { + platform_config: Some(Struct { + fields: std::iter::once(( + "a_string".to_string(), + Value { + 
kind: Some(Kind::StringValue("hello".to_string())), + }, + )) + .collect(), + }), + ..SandboxTemplate::default() + }; + + assert_eq!(platform_config_bool(&template, "a_string"), None); + } + + #[test] + fn log_level_propagates_as_env_var_to_sandbox_pod() { + let spec = SandboxSpec { + log_level: "debug".to_string(), + ..SandboxSpec::default() + }; + let cr = sandbox_to_k8s_spec(Some(&spec), &SandboxPodParams::default()); + let env = cr["spec"]["podTemplate"]["spec"]["containers"][0]["env"] + .as_array() + .unwrap(); + assert!( + env.iter() + .any(|e| e["name"] == "OPENSHELL_LOG_LEVEL" && e["value"] == "debug") + ); + assert!(cr["spec"].get("logLevel").is_none()); + } + + #[test] + fn node_selector_from_platform_config() { + let template = SandboxTemplate { + platform_config: Some(Struct { + fields: std::iter::once(( + "node_selector".to_string(), + Value { + kind: Some(Kind::StructValue(Struct { + fields: std::iter::once(( + "gpu-pool".to_string(), + Value { + kind: Some(Kind::StringValue("true".to_string())), + }, + )) + .collect(), + })), + }, + )) + .collect(), + }), + ..SandboxTemplate::default() + }; + + let pod_template = { + let params = SandboxPodParams::default(); + sandbox_template_to_k8s( + &template, + false, + &std::collections::HashMap::new(), + false, + &params, + ) + }; + + assert_eq!( + pod_template["spec"]["nodeSelector"]["gpu-pool"], + serde_json::json!("true") + ); + } + + #[test] + fn tolerations_from_platform_config() { + let toleration = Struct { + fields: [ + ( + "key".to_string(), + Value { + kind: Some(Kind::StringValue("nvidia.com/gpu".to_string())), + }, + ), + ( + "operator".to_string(), + Value { + kind: Some(Kind::StringValue("Exists".to_string())), + }, + ), + ( + "effect".to_string(), + Value { + kind: Some(Kind::StringValue("NoSchedule".to_string())), + }, + ), + ] + .into_iter() + .collect(), + }; + + let template = SandboxTemplate { + platform_config: Some(Struct { + fields: std::iter::once(( + "tolerations".to_string(), + Value { + 
kind: Some(Kind::ListValue(prost_types::ListValue { + values: vec![Value { + kind: Some(Kind::StructValue(toleration)), + }], + })), + }, + )) + .collect(), + }), + ..SandboxTemplate::default() + }; + + let pod_template = { + let params = SandboxPodParams::default(); + sandbox_template_to_k8s( + &template, + false, + &std::collections::HashMap::new(), + false, + &params, + ) + }; + + let tolerations = pod_template["spec"]["tolerations"] + .as_array() + .expect("tolerations should be an array"); + assert_eq!(tolerations.len(), 1); + assert_eq!(tolerations[0]["key"], "nvidia.com/gpu"); + assert_eq!(tolerations[0]["operator"], "Exists"); + assert_eq!(tolerations[0]["effect"], "NoSchedule"); + } + + #[test] + fn default_workspace_vct_uses_provided_storage_size() { + let vct = default_workspace_volume_claim_templates("5Gi"); + let storage = &vct[0]["spec"]["resources"]["requests"]["storage"]; + assert_eq!(storage, "5Gi"); + } + + #[test] + fn default_workspace_vct_falls_back_to_const_when_empty() { + let vct = default_workspace_volume_claim_templates(""); + let storage = &vct[0]["spec"]["resources"]["requests"]["storage"]; + assert_eq!(storage, DEFAULT_WORKSPACE_STORAGE_SIZE); + } } diff --git a/crates/openshell-driver-kubernetes/src/grpc.rs b/crates/openshell-driver-kubernetes/src/grpc.rs index 51488f694..03f8bed7f 100644 --- a/crates/openshell-driver-kubernetes/src/grpc.rs +++ b/crates/openshell-driver-kubernetes/src/grpc.rs @@ -14,7 +14,7 @@ use openshell_core::proto::compute::v1::{ use std::pin::Pin; use tonic::{Request, Response, Status}; -use crate::{KubernetesComputeDriver, KubernetesDriverError}; +use crate::KubernetesComputeDriver; #[derive(Debug, Clone)] pub struct ComputeDriverService { @@ -103,7 +103,7 @@ impl ComputeDriver for ComputeDriverService { self.driver .create_sandbox(&sandbox) .await - .map_err(status_from_driver_error)?; + .map_err(|e| Status::from(openshell_core::ComputeDriverError::from(e)))?; Ok(Response::new(CreateSandboxResponse {})) } @@ 
-146,23 +146,18 @@ impl ComputeDriver for ComputeDriverService { } } -fn status_from_driver_error(err: KubernetesDriverError) -> Status { - match err { - KubernetesDriverError::AlreadyExists => Status::already_exists("sandbox already exists"), - KubernetesDriverError::Precondition(message) => Status::failed_precondition(message), - KubernetesDriverError::Message(message) => Status::internal(message), - } -} - #[cfg(test)] mod tests { - use super::*; + use crate::KubernetesDriverError; + use openshell_core::ComputeDriverError; + use tonic::Status; #[test] fn precondition_driver_errors_map_to_failed_precondition_status() { - let status = status_from_driver_error(KubernetesDriverError::Precondition( + let status: Status = ComputeDriverError::from(KubernetesDriverError::Precondition( "sandbox agent pod IP is not available".to_string(), - )); + )) + .into(); assert_eq!(status.code(), tonic::Code::FailedPrecondition); assert_eq!(status.message(), "sandbox agent pod IP is not available"); @@ -170,7 +165,7 @@ mod tests { #[test] fn already_exists_driver_errors_map_to_already_exists_status() { - let status = status_from_driver_error(KubernetesDriverError::AlreadyExists); + let status: Status = ComputeDriverError::from(KubernetesDriverError::AlreadyExists).into(); assert_eq!(status.code(), tonic::Code::AlreadyExists); assert_eq!(status.message(), "sandbox already exists"); diff --git a/crates/openshell-driver-kubernetes/src/lib.rs b/crates/openshell-driver-kubernetes/src/lib.rs index 54149fe83..b0a5ca957 100644 --- a/crates/openshell-driver-kubernetes/src/lib.rs +++ b/crates/openshell-driver-kubernetes/src/lib.rs @@ -5,6 +5,9 @@ pub mod config; pub mod driver; pub mod grpc; -pub use config::KubernetesComputeConfig; +pub use config::{ + DEFAULT_SANDBOX_SERVICE_ACCOUNT_NAME, DEFAULT_WORKSPACE_STORAGE_SIZE, KubernetesComputeConfig, + SupervisorSideloadMethod, +}; pub use driver::{KubernetesComputeDriver, KubernetesDriverError}; pub use grpc::ComputeDriverService; diff --git 
a/crates/openshell-driver-kubernetes/src/main.rs b/crates/openshell-driver-kubernetes/src/main.rs index 26d323f56..703659af3 100644 --- a/crates/openshell-driver-kubernetes/src/main.rs +++ b/crates/openshell-driver-kubernetes/src/main.rs @@ -10,7 +10,8 @@ use tracing_subscriber::EnvFilter; use openshell_core::VERSION; use openshell_core::proto::compute::v1::compute_driver_server::ComputeDriverServer; use openshell_driver_kubernetes::{ - ComputeDriverService, KubernetesComputeConfig, KubernetesComputeDriver, + ComputeDriverService, DEFAULT_SANDBOX_SERVICE_ACCOUNT_NAME, KubernetesComputeConfig, + KubernetesComputeDriver, SupervisorSideloadMethod, }; #[derive(Parser, Debug)] @@ -30,6 +31,13 @@ struct Args { #[arg(long, env = "OPENSHELL_SANDBOX_NAMESPACE", default_value = "default")] sandbox_namespace: String, + #[arg( + long, + env = "OPENSHELL_K8S_SANDBOX_SERVICE_ACCOUNT", + default_value = DEFAULT_SANDBOX_SERVICE_ACCOUNT_NAME + )] + sandbox_service_account: String, + #[arg(long, env = "OPENSHELL_SANDBOX_IMAGE")] sandbox_image: Option, @@ -46,12 +54,6 @@ struct Args { )] sandbox_ssh_socket_path: String, - #[arg(long, env = "OPENSHELL_SSH_HANDSHAKE_SECRET")] - ssh_handshake_secret: String, - - #[arg(long, env = "OPENSHELL_SSH_HANDSHAKE_SKEW_SECS", default_value_t = 300)] - ssh_handshake_skew_secs: u64, - #[arg(long, env = "OPENSHELL_CLIENT_TLS_SECRET_NAME")] client_tls_secret_name: Option, @@ -63,6 +65,23 @@ struct Args { #[arg(long, env = "OPENSHELL_SUPERVISOR_IMAGE_PULL_POLICY")] supervisor_image_pull_policy: Option, + + #[arg( + long, + env = "OPENSHELL_SUPERVISOR_SIDELOAD_METHOD", + default_value = "image-volume" + )] + supervisor_sideload_method: SupervisorSideloadMethod, + + #[arg(long, env = "OPENSHELL_ENABLE_USER_NAMESPACES")] + enable_user_namespaces: bool, + + /// Lifetime (seconds) of the projected `ServiceAccount` token + /// kubelet writes into each sandbox pod for the `IssueSandboxToken` + /// bootstrap exchange. 
Kubelet enforces a minimum of 600s; the + /// gateway clamps values outside `[600, 86400]`. Default 3600. + #[arg(long, env = "OPENSHELL_K8S_SA_TOKEN_TTL_SECS", default_value_t = 3600)] + sa_token_ttl_secs: i64, } #[tokio::main] @@ -76,18 +95,26 @@ async fn main() -> Result<()> { let driver = KubernetesComputeDriver::new(KubernetesComputeConfig { namespace: args.sandbox_namespace, + service_account_name: args.sandbox_service_account, default_image: args.sandbox_image.unwrap_or_default(), image_pull_policy: args.sandbox_image_pull_policy.unwrap_or_default(), supervisor_image: args .supervisor_image .unwrap_or_else(|| openshell_core::config::DEFAULT_SUPERVISOR_IMAGE.to_string()), supervisor_image_pull_policy: args.supervisor_image_pull_policy.unwrap_or_default(), + supervisor_sideload_method: args.supervisor_sideload_method, grpc_endpoint: args.grpc_endpoint.unwrap_or_default(), ssh_socket_path: args.sandbox_ssh_socket_path, - ssh_handshake_secret: args.ssh_handshake_secret, - ssh_handshake_skew_secs: args.ssh_handshake_skew_secs, client_tls_secret_name: args.client_tls_secret_name.unwrap_or_default(), host_gateway_ip: args.host_gateway_ip.unwrap_or_default(), + enable_user_namespaces: args.enable_user_namespaces, + workspace_default_storage_size: std::env::var( + "OPENSHELL_K8S_WORKSPACE_DEFAULT_STORAGE_SIZE", + ) + .unwrap_or_else(|_| { + openshell_driver_kubernetes::DEFAULT_WORKSPACE_STORAGE_SIZE.to_string() + }), + sa_token_ttl_secs: args.sa_token_ttl_secs, }) .await .into_diagnostic()?; diff --git a/crates/openshell-driver-podman/Cargo.toml b/crates/openshell-driver-podman/Cargo.toml index 51ac698de..6f2963d92 100644 --- a/crates/openshell-driver-podman/Cargo.toml +++ b/crates/openshell-driver-podman/Cargo.toml @@ -28,6 +28,7 @@ serde = { workspace = true } serde_json = { workspace = true } clap = { workspace = true } nix = { workspace = true } +rustix = { workspace = true } tracing = { workspace = true } tracing-subscriber = { workspace = true } thiserror = { 
workspace = true } diff --git a/crates/openshell-driver-podman/NETWORKING.md b/crates/openshell-driver-podman/NETWORKING.md index 87eef079c..2cb6e35d2 100644 --- a/crates/openshell-driver-podman/NETWORKING.md +++ b/crates/openshell-driver-podman/NETWORKING.md @@ -178,7 +178,7 @@ Namespace 2: Rootless Podman network namespace, managed by pasta Namespace 3: Inner sandbox netns, created by supervisor | veth pair, such as 10.200.0.1 <-> 10.200.0.2 - iptables forces ordinary traffic through proxy + nftables forces ordinary traffic through proxy user workload runs here ``` @@ -270,7 +270,7 @@ Container on the Podman bridge | user code runs here | - iptables rules: + nftables rules: ACCEPT -> proxy TCP ACCEPT -> loopback ACCEPT -> established/related @@ -337,7 +337,7 @@ User code in inner netns HTTP_PROXY points at the local sandbox proxy | 2. TCP connect to proxy - allowed by iptables as the only ordinary egress destination + allowed by nftables as the only ordinary egress destination | 3. HTTP CONNECT api.example.com:443 | @@ -357,9 +357,9 @@ Supervisor proxy in container netns The Podman driver auto-detects the callback endpoint scheme based on whether TLS client certificates are configured. When the RPM's auto-generated PKI is in -place, the endpoint is `https://host.containers.internal:8080` and the +place, the endpoint is `https://host.containers.internal:17670` and the supervisor connects with mTLS. Without TLS configuration, it falls back to -`http://host.containers.internal:8080`. +`http://host.containers.internal:`. ```text Supervisor in container netns @@ -382,10 +382,9 @@ Gateway 9. Same gRPC channel reused for RelayStream calls ``` -The gateway binds to `0.0.0.0` by default in the RPM packaging. mTLS prevents -unauthenticated access even though the gateway is reachable from the network. -Client certificates are auto-generated by `init-pki.sh` on first start and -bind-mounted into sandbox containers by the Podman driver. 
+The gateway binds to `127.0.0.1:17670` by default in the RPM packaging. Client +certificates are auto-generated by `openshell-gateway generate-certs` on first +start and bind-mounted into sandbox containers by the Podman driver. ## Differences from the Kubernetes Driver @@ -398,9 +397,9 @@ bind-mounted into sandbox containers by the Podman driver. | Port publishing | Not needed for relay | Ephemeral host port remains in the container spec for compatibility and debug paths. | | TLS | mTLS via Kubernetes secrets | mTLS via mounted client files, RPM defaults, or explicit configuration. | | DNS | Kubernetes CoreDNS | Podman bridge DNS through aardvark-dns when DNS is enabled. | -| Network policy | Kubernetes network policy for pod ingress plus supervisor policy | iptables inside inner sandbox netns plus supervisor policy. | +| Network policy | Kubernetes network policy for pod ingress plus supervisor policy | nftables inside inner sandbox netns plus supervisor policy. | | Supervisor delivery | Kubernetes driver managed pod image or template | OCI image volume mount. | -| Secrets | Kubernetes Secret volume and env vars | Podman `secret_env` for handshake secret, plus mounted TLS files. | +| Secrets | Kubernetes Secret volume and env vars | Mounted TLS client materials from a Podman secret. | Both drivers use the same reverse gRPC relay for SSH transport. The most important Podman-specific difference is network reachability: in rootless @@ -412,7 +411,7 @@ published ports, or the supervisor relay. | Port | Component | Purpose | |---|---|---| -| `8080` | Gateway | gRPC and HTTP multiplexed default server port. | +| `17670` | Gateway | Default local gRPC and HTTP multiplexed server port. | | `2222` | Sandbox | Container port mapping default for the SSH compatibility port. | | `3128` | Sandbox proxy | HTTP CONNECT proxy inside the sandbox network model. | | `0` | Host | Ephemeral host port requested for the container SSH compatibility port. 
| diff --git a/crates/openshell-driver-podman/README.md b/crates/openshell-driver-podman/README.md index d853bb5ea..dbf508c03 100644 --- a/crates/openshell-driver-podman/README.md +++ b/crates/openshell-driver-podman/README.md @@ -46,6 +46,7 @@ The container spec in `container.rs` sets these security-critical fields: | `no_new_privileges` | `true` | Prevents privilege escalation after exec. | | `seccomp_profile_path` | `unconfined` | The supervisor installs its own policy-aware BPF filter. A container-level profile can block Landlock/seccomp syscalls during setup. | | `mounts` | Private tmpfs at `/run/netns` | Lets the supervisor create named network namespaces in rootless Podman. | +| CDI GPU devices | Sandbox `gpu_device` value when set, otherwise all NVIDIA GPUs | Exposes requested GPUs to GPU-enabled sandbox containers. | The restricted agent child does not retain these supervisor privileges. @@ -54,7 +55,7 @@ The restricted agent child does not retain these supervisor privileges. | Capability | Purpose | |---|---| | `SYS_ADMIN` | seccomp filter installation, namespace creation, and Landlock setup. | -| `NET_ADMIN` | Network namespace veth setup, IP address assignment, routes, and iptables. | +| `NET_ADMIN` | Network namespace veth setup, IP address assignment, routes, and nftables. | | `SYS_PTRACE` | Reading `/proc//exe` and walking process ancestry for binary identity. | | `SYSLOG` | Reading `/dev/kmsg` for bypass-detection diagnostics. | | `DAC_READ_SEARCH` | Reading `/proc//fd/` across UIDs so the proxy can resolve the binary responsible for a connection. | @@ -85,8 +86,8 @@ sequenceDiagram C->>C: entrypoint: /opt/openshell/bin/openshell-sandbox ``` -The `supervisor` target in `deploy/docker/Dockerfile.images` copies the -`openshell-sandbox` binary to `/openshell-sandbox` in the supervisor image. +The supervisor image from `deploy/docker/Dockerfile.supervisor` copies the static +`openshell-sandbox` binary to `/openshell-sandbox`. 
Mounting that image at `/opt/openshell/bin` makes the binary available as `/opt/openshell/bin/openshell-sandbox`. @@ -119,10 +120,10 @@ connection back to the gateway. On SELinux systems, the bind mounts include Podman's shared relabel option so the container process can read the files. The RPM packaging auto-generates a self-signed PKI on first start via -`init-pki.sh`. Client certs are placed in the CLI auto-discovery directory -(`~/.config/openshell/gateways/openshell/mtls/`) so the CLI connects with mTLS -without manual configuration. See `deploy/rpm/CONFIGURATION.md` for the full -RPM configuration reference. +`openshell-gateway generate-certs`. Client certs are placed in the CLI +auto-discovery directory (`~/.config/openshell/gateways/openshell/mtls/`) so +the CLI connects with mTLS without manual configuration. See +`deploy/rpm/CONFIGURATION.md` for the full RPM configuration reference. ## Network Model @@ -133,7 +134,7 @@ the supervisor for sandbox process isolation. ```mermaid graph TB subgraph Host - GW["Gateway Server
127.0.0.1:8080"] + GW["Gateway Server
127.0.0.1:17670"] PS["Podman Socket"] end @@ -196,12 +197,13 @@ The standalone `openshell-driver-podman` binary sets the same struct field from ## Credential Injection -The SSH handshake secret is injected via Podman's `secret_env` API rather than -a plaintext environment variable. +Sandboxes authenticate to the gateway via mTLS using client materials bind- +mounted into the container from a Podman secret. No shared per-request secret +is injected as an environment variable. | Credential | Mechanism | Visible in `inspect`? | Visible in `/proc//environ`? | |---|---|---|---| -| SSH handshake secret | Podman `secret_env`, created via secrets API and referenced by name | No | Yes, supervisor only, scrubbed from children | +| mTLS client cert/key | Bind-mounted file paths (`OPENSHELL_TLS_*` env vars point at them) | Yes (paths only) | Yes (paths only) | | Sandbox identity | Plaintext env var | Yes | Yes | | gRPC endpoint | Plaintext env var, override-protected | Yes | Yes | | Supervisor relay socket path | Plaintext env var, override-protected | Yes | Yes | @@ -214,13 +216,9 @@ via sandbox templates: - `OPENSHELL_SANDBOX_ID` - `OPENSHELL_ENDPOINT` - `OPENSHELL_SSH_SOCKET_PATH` -- `OPENSHELL_SSH_HANDSHAKE_SKEW_SECS` - `OPENSHELL_CONTAINER_IMAGE` - `OPENSHELL_SANDBOX_COMMAND` -The `PodmanComputeConfig::Debug` implementation redacts the handshake secret as -`[REDACTED]`. 
- ## Sandbox Lifecycle ### Creation Flow @@ -238,26 +236,23 @@ sequenceDiagram D->>P: pull_image(supervisor, "missing") D->>P: pull_image(sandbox_image, policy) - D->>P: create_secret(handshake) - Note over D: On failure below, rollback secret - D->>P: create_volume(workspace) - Note over D: On failure below, rollback volume + secret + Note over D: On failure below, rollback volume D->>P: create_container(spec) alt Conflict (409) - D->>P: remove_volume + remove_secret + D->>P: remove_volume D-->>GW: AlreadyExists end - Note over D: On failure below, rollback container + volume + secret + Note over D: On failure below, rollback container + volume D->>P: start_container D-->>GW: Ok ``` Each step rolls back previously-created resources on failure. The Conflict path -cleans up the volume and secret because they are keyed by the new sandbox's ID, -not the conflicting container's ID. +cleans up the volume because it is keyed by the new sandbox's ID, not the +conflicting container's ID. ### Readiness and Health @@ -280,11 +275,9 @@ the socket without the old marker or published-port signal. 4. Force-remove the container. 5. Remove workspace volume derived from the request `sandbox_id`, warning on failure and continuing. -6. Remove handshake secret derived from the request `sandbox_id`, warning on - failure and continuing. If the container is already gone during inspect or remove, the driver still -performs idempotent volume and secret cleanup using the request `sandbox_id` and +performs idempotent volume cleanup using the request `sandbox_id` and returns `Ok(false)` for the container-delete result. This prevents leaked Podman resources after out-of-band container removal or label drift. @@ -296,14 +289,11 @@ Podman resources after out-of-band container removal or label drift. | `OPENSHELL_SANDBOX_IMAGE` | `--sandbox-image` | From gateway config | Default OCI image for sandboxes. 
| | `OPENSHELL_SANDBOX_IMAGE_PULL_POLICY` | `--sandbox-image-pull-policy` | `missing` | Pull policy: `always`, `missing`, `never`, or `newer`. | | `OPENSHELL_GRPC_ENDPOINT` | `--grpc-endpoint` | Auto-detected via `host.containers.internal` | Gateway gRPC endpoint for sandbox callbacks. | -| `OPENSHELL_GATEWAY_PORT` | `--gateway-port` | `8080` | Gateway port used for endpoint auto-detection by the standalone binary. | +| `OPENSHELL_GATEWAY_PORT` | `--gateway-port` | `17670` | Gateway port used for endpoint auto-detection by the standalone binary. | | `OPENSHELL_NETWORK_NAME` | `--network-name` | `openshell` | Podman bridge network name. | -| `OPENSHELL_SANDBOX_SSH_PORT` | `--sandbox-ssh-port` | `2222` | SSH compatibility port inside the container. | -| `OPENSHELL_SSH_HANDSHAKE_SECRET` | `--ssh-handshake-secret` | Required standalone, gateway-generated in-process | Shared secret for the NSSH1 handshake. | -| `OPENSHELL_SSH_HANDSHAKE_SKEW_SECS` | `--ssh-handshake-skew-secs` | `300` | Allowed timestamp skew for SSH handshake validation. | -| `OPENSHELL_SANDBOX_SSH_SOCKET_PATH` | `--sandbox-ssh-socket-path` | `/run/openshell/ssh.sock` | Standalone driver only: supervisor Unix socket path in `PodmanComputeConfig`. In-gateway Podman uses server `config.sandbox_ssh_socket_path`. | +| `OPENSHELL_SANDBOX_SSH_SOCKET_PATH` | `--sandbox-ssh-socket-path` | `/run/openshell/ssh.sock` | Supervisor Unix socket path in `PodmanComputeConfig`. | | `OPENSHELL_STOP_TIMEOUT` | `--stop-timeout` | `10` | Container stop timeout in seconds. | -| `OPENSHELL_SUPERVISOR_IMAGE` | `--supervisor-image` | `openshell/supervisor:latest` through the gateway, required standalone | OCI image containing the supervisor binary. | +| `OPENSHELL_SUPERVISOR_IMAGE` | `--supervisor-image` | `ghcr.io/nvidia/openshell/supervisor:latest` through the gateway, required standalone | OCI image containing the supervisor binary. 
| | `OPENSHELL_PODMAN_TLS_CA` | `--podman-tls-ca` | unset | Host path to the CA certificate mounted for sandbox mTLS. | | `OPENSHELL_PODMAN_TLS_CERT` | `--podman-tls-cert` | unset | Host path to the client certificate mounted for sandbox mTLS. | | `OPENSHELL_PODMAN_TLS_KEY` | `--podman-tls-key` | unset | Host path to the client private key mounted for sandbox mTLS. | @@ -351,4 +341,4 @@ matter compared to cluster or rootful runtimes: netns, proxy, and relay behavior shared by all drivers. - Container engine abstraction: `tasks/scripts/container-engine.sh` for build/deploy support across Docker and Podman. -- Supervisor image build: `deploy/docker/Dockerfile.images`. +- Supervisor image build: `deploy/docker/Dockerfile.supervisor`. diff --git a/crates/openshell-driver-podman/src/client.rs b/crates/openshell-driver-podman/src/client.rs index 69bfd69c0..834fb21a2 100644 --- a/crates/openshell-driver-podman/src/client.rs +++ b/crates/openshell-driver-podman/src/client.rs @@ -381,23 +381,6 @@ impl PodmanClient { } } - /// Perform a versioned HTTP request with a raw byte body (not JSON). - async fn request_raw( - &self, - method: hyper::Method, - path: &str, - content_type: &str, - body: Bytes, - ) -> Result<(hyper::StatusCode, Bytes), PodmanApiError> { - let req = Self::build_request( - method, - &format!("/{API_VERSION}{path}"), - Full::new(body), - Some(content_type), - ); - self.send_request(req, API_TIMEOUT).await - } - /// POST a JSON body and ignore 409 Conflict (resource already exists). async fn create_ignore_conflict(&self, path: &str, body: &Value) -> Result<(), PodmanApiError> { match self @@ -550,64 +533,6 @@ impl PodmanClient { Ok(gateway) } - // ── Secret operations ──────────────────────────────────────────────── - - /// Create a Podman secret with the given name and raw value. - /// - /// Idempotent: if a secret with the same name already exists it is - /// replaced (delete + recreate) so the value is always up-to-date. 
- pub async fn create_secret(&self, name: &str, value: &[u8]) -> Result<(), PodmanApiError> { - validate_name(name)?; - let encoded_name = url_encode(name); - let path = format!("/libpod/secrets/create?name={encoded_name}"); - let (status, bytes) = self - .request_raw( - hyper::Method::POST, - &path, - "application/octet-stream", - Bytes::copy_from_slice(value), - ) - .await?; - - match status.as_u16() { - 200 | 201 => Ok(()), - 409 => { - // Secret already exists — replace it. - self.remove_secret(name).await?; - let (status2, bytes2) = self - .request_raw( - hyper::Method::POST, - &path, - "application/octet-stream", - Bytes::copy_from_slice(value), - ) - .await?; - if status2.is_success() { - Ok(()) - } else { - Err(error_from_response(status2.as_u16(), &bytes2)) - } - } - _ => Err(error_from_response(status.as_u16(), &bytes)), - } - } - - /// Remove a Podman secret by name. Idempotent (not-found is ignored). - pub async fn remove_secret(&self, name: &str) -> Result<(), PodmanApiError> { - validate_name(name)?; - match self - .request_ok( - hyper::Method::DELETE, - &format!("/libpod/secrets/{name}"), - None, - ) - .await - { - Ok(()) | Err(PodmanApiError::NotFound(_)) => Ok(()), - Err(e) => Err(e), - } - } - // ── Image operations ──────────────────────────────────────────────── /// Pull an image if it is not already present locally. diff --git a/crates/openshell-driver-podman/src/config.rs b/crates/openshell-driver-podman/src/config.rs index d82b8d0b0..270804882 100644 --- a/crates/openshell-driver-podman/src/config.rs +++ b/crates/openshell-driver-podman/src/config.rs @@ -1,13 +1,13 @@ // SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
// SPDX-License-Identifier: Apache-2.0 -use openshell_core::config::{ - DEFAULT_NETWORK_NAME, DEFAULT_SSH_HANDSHAKE_SKEW_SECS, DEFAULT_SSH_PORT, - DEFAULT_STOP_TIMEOUT_SECS, DEFAULT_SUPERVISOR_IMAGE, -}; +use openshell_core::config::{DEFAULT_STOP_TIMEOUT_SECS, DEFAULT_SUPERVISOR_IMAGE}; use std::path::PathBuf; use std::str::FromStr; +/// Default Podman bridge network name. +pub const DEFAULT_NETWORK_NAME: &str = "openshell"; + /// Image pull policy for sandbox and supervisor images. /// /// Controls when the Podman driver fetches a newer copy of an OCI image @@ -61,7 +61,8 @@ impl FromStr for ImagePullPolicy { } } -#[derive(Clone)] +#[derive(Clone, serde::Serialize, serde::Deserialize)] +#[serde(default, deny_unknown_fields)] pub struct PodmanComputeConfig { /// Path to the Podman API Unix socket. /// Default: `$XDG_RUNTIME_DIR/podman/podman.sock` (Linux), @@ -88,12 +89,6 @@ pub struct PodmanComputeConfig { /// Name of the Podman bridge network. /// Created automatically if it does not exist. pub network_name: String, - /// SSH port inside the container. - pub ssh_port: u16, - /// Shared secret for the NSSH1 SSH handshake. - pub ssh_handshake_secret: String, - /// Maximum clock skew in seconds for SSH handshake timestamps. - pub ssh_handshake_skew_secs: u64, /// Container stop timeout in seconds (SIGTERM → SIGKILL). pub stop_timeout_secs: u32, /// OCI image containing the openshell-sandbox supervisor binary. 
@@ -172,7 +167,7 @@ impl PodmanComputeConfig { { std::env::var("XDG_RUNTIME_DIR").map_or_else( |_| { - let uid = nix::unistd::getuid(); + let uid = rustix::process::getuid().as_raw(); PathBuf::from(format!("/run/user/{uid}/podman/podman.sock")) }, |xdg| PathBuf::from(xdg).join("podman/podman.sock"), @@ -185,15 +180,12 @@ impl Default for PodmanComputeConfig { fn default() -> Self { Self { socket_path: Self::default_socket_path(), - default_image: String::new(), + default_image: openshell_core::image::default_sandbox_image(), image_pull_policy: ImagePullPolicy::default(), grpc_endpoint: String::new(), gateway_port: openshell_core::config::DEFAULT_SERVER_PORT, sandbox_ssh_socket_path: "/run/openshell/ssh.sock".to_string(), network_name: DEFAULT_NETWORK_NAME.to_string(), - ssh_port: DEFAULT_SSH_PORT, - ssh_handshake_secret: String::new(), - ssh_handshake_skew_secs: DEFAULT_SSH_HANDSHAKE_SKEW_SECS, stop_timeout_secs: DEFAULT_STOP_TIMEOUT_SECS, supervisor_image: DEFAULT_SUPERVISOR_IMAGE.to_string(), guest_tls_ca: None, @@ -213,9 +205,6 @@ impl std::fmt::Debug for PodmanComputeConfig { .field("gateway_port", &self.gateway_port) .field("sandbox_ssh_socket_path", &self.sandbox_ssh_socket_path) .field("network_name", &self.network_name) - .field("ssh_port", &self.ssh_port) - .field("ssh_handshake_secret", &"[REDACTED]") - .field("ssh_handshake_skew_secs", &self.ssh_handshake_skew_secs) .field("stop_timeout_secs", &self.stop_timeout_secs) .field("supervisor_image", &self.supervisor_image) .field("guest_tls_ca", &self.guest_tls_ca) @@ -254,7 +243,7 @@ mod tests { .unwrap_or_else(std::sync::PoisonError::into_inner); temp_env::with_vars([("XDG_RUNTIME_DIR", None::<&str>)], || { let path = PodmanComputeConfig::default_socket_path(); - let uid = nix::unistd::getuid(); + let uid = rustix::process::getuid().as_raw(); assert_eq!( path, PathBuf::from(format!("/run/user/{uid}/podman/podman.sock")) diff --git a/crates/openshell-driver-podman/src/container.rs 
b/crates/openshell-driver-podman/src/container.rs index 3c5df292f..e73619756 100644 --- a/crates/openshell-driver-podman/src/container.rs +++ b/crates/openshell-driver-podman/src/container.rs @@ -4,7 +4,7 @@ //! Container spec construction for the Podman driver. use crate::config::PodmanComputeConfig; -use openshell_core::config::CDI_GPU_DEVICE_ALL; +use openshell_core::gpu::cdi_gpu_device_ids; use openshell_core::proto::compute::v1::DriverSandbox; use serde::Serialize; use serde_json::Value; @@ -50,6 +50,7 @@ const VOLUME_PREFIX: &str = "openshell-sandbox-"; const TLS_CA_MOUNT_PATH: &str = "/etc/openshell/tls/client/ca.crt"; const TLS_CERT_MOUNT_PATH: &str = "/etc/openshell/tls/client/tls.crt"; const TLS_KEY_MOUNT_PATH: &str = "/etc/openshell/tls/client/tls.key"; +const SANDBOX_TOKEN_MOUNT_PATH: &str = "/etc/openshell/auth/sandbox.jwt"; /// Build a Podman container name from the sandbox name. #[must_use] @@ -63,15 +64,6 @@ pub fn volume_name(sandbox_id: &str) -> String { format!("{VOLUME_PREFIX}{sandbox_id}-workspace") } -/// Podman secret name prefix. -const SECRET_PREFIX: &str = "openshell-handshake-"; - -/// Build the Podman secret name for a sandbox's SSH handshake secret. -#[must_use] -pub fn secret_name(sandbox_id: &str) -> String { - format!("{SECRET_PREFIX}{sandbox_id}") -} - /// Truncate a container ID to 12 characters (standard short form). #[must_use] pub fn short_id(id: &str) -> String { @@ -252,7 +244,10 @@ fn build_env( // 1. User-supplied environment (lowest priority). if let Some(s) = spec { if !s.log_level.is_empty() { - env.insert("OPENSHELL_LOG_LEVEL".into(), s.log_level.clone()); + env.insert( + openshell_core::sandbox_env::LOG_LEVEL.into(), + s.log_level.clone(), + ); } for (k, v) in &s.environment { env.insert(k.clone(), v.clone()); @@ -265,30 +260,58 @@ fn build_env( } // 2. Required driver vars (highest priority -- always overwrite). 
- env.insert("OPENSHELL_SANDBOX".into(), sandbox.name.clone()); - env.insert("OPENSHELL_SANDBOX_ID".into(), sandbox.id.clone()); - env.insert("OPENSHELL_ENDPOINT".into(), config.grpc_endpoint.clone()); env.insert( - "OPENSHELL_SSH_SOCKET_PATH".into(), - config.sandbox_ssh_socket_path.clone(), + openshell_core::sandbox_env::SANDBOX.into(), + sandbox.name.clone(), + ); + env.insert( + openshell_core::sandbox_env::SANDBOX_ID.into(), + sandbox.id.clone(), + ); + env.insert( + openshell_core::sandbox_env::ENDPOINT.into(), + config.grpc_endpoint.clone(), ); - // NOTE: The SSH handshake secret is injected via a Podman secret - // (see the "secrets" field below) rather than a plaintext env var. - // This prevents exposure through `podman inspect`. env.insert( - "OPENSHELL_SSH_HANDSHAKE_SKEW_SECS".into(), - config.ssh_handshake_skew_secs.to_string(), + openshell_core::sandbox_env::SSH_SOCKET_PATH.into(), + config.sandbox_ssh_socket_path.clone(), ); env.insert("OPENSHELL_CONTAINER_IMAGE".into(), image.to_string()); - env.insert("OPENSHELL_SANDBOX_COMMAND".into(), "sleep infinity".into()); + env.insert( + openshell_core::sandbox_env::SANDBOX_COMMAND.into(), + "sleep infinity".into(), + ); // 3. TLS client cert paths (when mTLS is enabled). These point to // the container-side mount paths where the cert files are // bind-mounted from the host. 
if config.tls_enabled() { - env.insert("OPENSHELL_TLS_CA".into(), TLS_CA_MOUNT_PATH.into()); - env.insert("OPENSHELL_TLS_CERT".into(), TLS_CERT_MOUNT_PATH.into()); - env.insert("OPENSHELL_TLS_KEY".into(), TLS_KEY_MOUNT_PATH.into()); + env.insert( + openshell_core::sandbox_env::TLS_CA.into(), + TLS_CA_MOUNT_PATH.into(), + ); + env.insert( + openshell_core::sandbox_env::TLS_CERT.into(), + TLS_CERT_MOUNT_PATH.into(), + ); + env.insert( + openshell_core::sandbox_env::TLS_KEY.into(), + TLS_KEY_MOUNT_PATH.into(), + ); + } + + env.remove(openshell_core::sandbox_env::SANDBOX_TOKEN); + env.remove(openshell_core::sandbox_env::SANDBOX_TOKEN_FILE); + + // 4. Gateway-minted sandbox JWT. Keep the raw bearer out of container + // metadata; the supervisor reads it from a driver-owned bind mount. + if let Some(s) = spec + && !s.sandbox_token.is_empty() + { + env.insert( + openshell_core::sandbox_env::SANDBOX_TOKEN_FILE.into(), + SANDBOX_TOKEN_MOUNT_PATH.into(), + ); } env @@ -345,18 +368,28 @@ fn build_resource_limits(sandbox: &DriverSandbox) -> ResourceLimits { /// Build CDI GPU device list if GPU is requested. fn build_devices(sandbox: &DriverSandbox) -> Option<Vec<LinuxDevice>> { - if sandbox.spec.as_ref().is_some_and(|s| s.gpu) { - Some(vec![LinuxDevice { - path: CDI_GPU_DEVICE_ALL.into(), - }]) - } else { - None - } + let spec = sandbox.spec.as_ref()?; + cdi_gpu_device_ids(spec.gpu, &spec.gpu_device).map(|device_ids| { + device_ids + .into_iter() + .map(|path| LinuxDevice { path }) + .collect() + }) } /// Build the Podman container creation JSON spec.
+#[cfg(test)] #[must_use] pub fn build_container_spec(sandbox: &DriverSandbox, config: &PodmanComputeConfig) -> Value { + build_container_spec_with_token(sandbox, config, None) +} + +#[must_use] +pub fn build_container_spec_with_token( + sandbox: &DriverSandbox, + config: &PodmanComputeConfig, + token_host_path: Option<&std::path::Path>, +) -> Value { let image = resolve_image(sandbox, config); let name = container_name(&sandbox.name); let vol = volume_name(&sandbox.id); @@ -482,7 +515,8 @@ pub fn build_container_spec(sandbox: &DriverSandbox, config: &PodmanComputeConfi "CMD-SHELL".into(), format!( "test -e /var/run/openshell-ssh-ready || test -S {} || ss -tlnp | grep -q :{}", - config.sandbox_ssh_socket_path, config.ssh_port + config.sandbox_ssh_socket_path, + openshell_core::config::DEFAULT_SSH_PORT ), ], interval: 3_000_000_000, @@ -491,16 +525,7 @@ pub fn build_container_spec(sandbox: &DriverSandbox, config: &PodmanComputeConfi start_period: 5_000_000_000, }, resource_limits, - // Inject the SSH handshake secret via Podman's secret_env map so it - // does not appear in `podman inspect` output. The libpod SpecGenerator - // uses `secret_env` (map of env_var → secret_name) for env-type secrets, - // distinct from `secrets` which only handles file mounts under /run/secrets/. - // The secret is created by the driver before the container - // (see `PodmanComputeDriver::create_sandbox`). - secret_env: BTreeMap::from([( - "OPENSHELL_SSH_HANDSHAKE_SECRET".into(), - secret_name(&sandbox.id), - )]), + secret_env: BTreeMap::new(), stop_timeout: config.stop_timeout_secs, // Inject stable host aliases into /etc/hosts so sandbox containers can // reach services on the host. 
`host.openshell.internal` is the driver- @@ -563,6 +588,18 @@ pub fn build_container_spec(sandbox: &DriverSandbox, config: &PodmanComputeConfi options: ro, }); } + if let Some(path) = token_host_path { + let mut ro = vec!["ro".into(), "rbind".into()]; + if is_selinux_enabled() { + ro.push("z".into()); + } + m.push(Mount { + kind: "bind".into(), + source: path.display().to_string(), + destination: SANDBOX_TOKEN_MOUNT_PATH.into(), + options: ro, + }); + } m }, // Publish the SSH port with host_port=0 to get an ephemeral host port. @@ -570,7 +607,7 @@ pub fn build_container_spec(sandbox: &DriverSandbox, config: &PodmanComputeConfi // the host, so we must use the published host port on 127.0.0.1 instead. portmappings: vec![PortMapping { host_port: 0, - container_port: config.ssh_port, + container_port: openshell_core::config::DEFAULT_SSH_PORT, protocol: "tcp".into(), }], }; @@ -606,13 +643,18 @@ fn parse_cpu_to_microseconds(quantity: &str) -> Option { /// (decimal), as well as plain byte values. 
fn parse_memory_to_bytes(quantity: &str) -> Option { let suffixes: &[(&str, u64)] = &[ + ("Ei", 1024 * 1024 * 1024 * 1024 * 1024 * 1024), + ("Pi", 1024 * 1024 * 1024 * 1024 * 1024), ("Ti", 1024 * 1024 * 1024 * 1024), ("Gi", 1024 * 1024 * 1024), ("Mi", 1024 * 1024), ("Ki", 1024), + ("E", 1_000_000_000_000_000_000), + ("P", 1_000_000_000_000_000), ("T", 1_000_000_000_000), ("G", 1_000_000_000), ("M", 1_000_000), + ("K", 1_000), ("k", 1_000), ]; @@ -656,6 +698,7 @@ mod tests { fn parse_memory_decimal_suffixes() { assert_eq!(parse_memory_to_bytes("1G"), Some(1_000_000_000)); assert_eq!(parse_memory_to_bytes("500M"), Some(500_000_000)); + assert_eq!(parse_memory_to_bytes("1K"), Some(1_000)); } #[test] @@ -663,6 +706,37 @@ mod tests { assert_eq!(parse_memory_to_bytes("1048576"), Some(1_048_576)); } + #[test] + fn container_spec_applies_cpu_and_memory_limits() { + use openshell_core::proto::compute::v1::{ + DriverResourceRequirements, DriverSandboxSpec, DriverSandboxTemplate, + }; + + let mut sandbox = test_sandbox("test-id", "test-name"); + sandbox.spec = Some(DriverSandboxSpec { + template: Some(DriverSandboxTemplate { + resources: Some(DriverResourceRequirements { + cpu_limit: "500m".to_string(), + memory_limit: "2Gi".to_string(), + ..Default::default() + }), + ..Default::default() + }), + ..Default::default() + }); + let config = test_config(); + let spec = build_container_spec(&sandbox, &config); + + assert_eq!( + spec["resource_limits"]["cpu"]["quota"].as_u64(), + Some(50_000) + ); + assert_eq!( + spec["resource_limits"]["memory"]["limit"].as_u64(), + Some(2 * 1024 * 1024 * 1024) + ); + } + #[test] fn container_name_is_prefixed() { assert_eq!(container_name("my-sandbox"), "openshell-sandbox-my-sandbox"); @@ -676,17 +750,59 @@ mod tests { ); } - #[test] - fn secret_name_uses_id() { - assert_eq!(secret_name("abc-123"), "openshell-handshake-abc-123"); - } - #[test] fn short_id_truncates() { assert_eq!(short_id("abc123def456789"), "abc123def456"); 
assert_eq!(short_id("short"), "short"); } + #[test] + fn container_spec_omits_devices_without_gpu_request() { + let sandbox = test_sandbox("test-id", "test-name"); + let config = test_config(); + let spec = build_container_spec(&sandbox, &config); + + assert!(spec.get("devices").is_none()); + } + + #[test] + fn container_spec_maps_empty_gpu_request_to_all_cdi_device() { + use openshell_core::config::CDI_GPU_DEVICE_ALL; + use openshell_core::proto::compute::v1::DriverSandboxSpec; + + let mut sandbox = test_sandbox("test-id", "test-name"); + sandbox.spec = Some(DriverSandboxSpec { + gpu: true, + ..Default::default() + }); + let config = test_config(); + let spec = build_container_spec(&sandbox, &config); + + assert_eq!( + spec["devices"][0]["path"].as_str(), + Some(CDI_GPU_DEVICE_ALL) + ); + } + + #[test] + fn container_spec_passes_explicit_cdi_device_id_through() { + use openshell_core::proto::compute::v1::DriverSandboxSpec; + + let mut sandbox = test_sandbox("test-id", "test-name"); + sandbox.spec = Some(DriverSandboxSpec { + gpu: true, + gpu_device: "nvidia.com/gpu=0".to_string(), + ..Default::default() + }); + let config = test_config(); + let spec = build_container_spec(&sandbox, &config); + + assert_eq!( + spec["devices"][0]["path"].as_str(), + Some("nvidia.com/gpu=0") + ); + } + #[test] fn container_spec_includes_required_capabilities() { let sandbox = test_sandbox("test-id", "test-name"); @@ -735,34 +851,6 @@ mod tests { ); } - #[test] - fn container_spec_uses_secret_env_not_plaintext() { - let sandbox = test_sandbox("test-id", "test-name"); - let config = test_config(); - let spec = build_container_spec(&sandbox, &config); - - // The handshake secret must NOT appear in the plaintext env map. 
- let env_map = spec["env"].as_object().expect("env should be an object"); - assert!( - !env_map.contains_key("OPENSHELL_SSH_HANDSHAKE_SECRET"), - "handshake secret should not be in plaintext env" - ); - - // It should appear in secret_env (the libpod env-type secret map) instead. - let secret_env = spec["secret_env"] - .as_object() - .expect("secret_env should be an object"); - assert!( - secret_env.contains_key("OPENSHELL_SSH_HANDSHAKE_SECRET"), - "secret_env should map OPENSHELL_SSH_HANDSHAKE_SECRET to its secret name" - ); - assert_eq!( - secret_env["OPENSHELL_SSH_HANDSHAKE_SECRET"].as_str(), - Some("openshell-handshake-test-id"), - "secret_env value should be the Podman secret name for the sandbox" - ); - } - #[test] fn container_spec_sets_sandbox_name_in_env() { let sandbox = test_sandbox("test-id", "my-sandbox"); @@ -771,7 +859,9 @@ mod tests { let env_map = spec["env"].as_object().expect("env should be an object"); assert_eq!( - env_map.get("OPENSHELL_SANDBOX").and_then(|v| v.as_str()), + env_map + .get(openshell_core::sandbox_env::SANDBOX) + .and_then(|v| v.as_str()), Some("my-sandbox"), ); } @@ -961,7 +1051,6 @@ mod tests { default_image: "test-image:latest".to_string(), grpc_endpoint: "http://localhost:50051".to_string(), sandbox_ssh_socket_path: "/run/openshell/test-ssh.sock".to_string(), - ssh_handshake_secret: "test-secret-value".to_string(), ..PodmanComputeConfig::default() } } @@ -984,7 +1073,7 @@ mod tests { let vol = &image_volumes[0]; assert_eq!( vol["source"].as_str(), - Some("openshell/supervisor:latest"), + Some("ghcr.io/nvidia/openshell/supervisor:latest"), "image volume source should be the supervisor image" ); assert_eq!( @@ -1063,6 +1152,43 @@ mod tests { ); } + #[test] + fn container_spec_uses_token_file_mount_without_raw_token_env() { + use openshell_core::proto::compute::v1::DriverSandboxSpec; + + let mut sandbox = test_sandbox("token-id", "token-name"); + sandbox.spec = Some(DriverSandboxSpec { + sandbox_token: 
"secret.jwt.value".to_string(), + ..Default::default() + }); + let config = test_config(); + let token_path = std::path::Path::new("/host/token.jwt"); + + let spec = build_container_spec_with_token(&sandbox, &config, Some(token_path)); + + let env_map = spec["env"].as_object().expect("env should be an object"); + assert_eq!( + env_map + .get(openshell_core::sandbox_env::SANDBOX_TOKEN) + .and_then(|v| v.as_str()), + None + ); + assert_eq!( + env_map + .get(openshell_core::sandbox_env::SANDBOX_TOKEN_FILE) + .and_then(|v| v.as_str()), + Some("/etc/openshell/auth/sandbox.jwt") + ); + let mounts = spec["mounts"] + .as_array() + .expect("mounts should be an array"); + assert!(mounts.iter().any(|m| { + m["type"].as_str() == Some("bind") + && m["source"].as_str() == Some("/host/token.jwt") + && m["destination"].as_str() == Some("/etc/openshell/auth/sandbox.jwt") + })); + } + #[test] fn container_spec_omits_tls_without_config() { let sandbox = test_sandbox("notls-id", "notls-name"); diff --git a/crates/openshell-driver-podman/src/driver.rs b/crates/openshell-driver-podman/src/driver.rs index ad4d7a192..4e493ddc1 100644 --- a/crates/openshell-driver-podman/src/driver.rs +++ b/crates/openshell-driver-podman/src/driver.rs @@ -11,6 +11,8 @@ use crate::watcher::{ }; use openshell_core::ComputeDriverError; use openshell_core::proto::compute::v1::{DriverSandbox, GetCapabilitiesResponse}; +use std::path::PathBuf; +use std::time::Duration; use tracing::{info, warn}; impl From for ComputeDriverError { @@ -54,9 +56,71 @@ fn validated_container_name(sandbox_name: &str) -> Result Result { + openshell_core::driver_utils::sandbox_token_path("podman-sandbox-tokens", None, sandbox_id) + .map_err(|err| ComputeDriverError::Message(format!("resolve state dir failed: {err}"))) +} + +async fn write_sandbox_token_file( + sandbox: &DriverSandbox, +) -> Result, ComputeDriverError> { + let Some(spec) = sandbox.spec.as_ref() else { + return Ok(None); + }; + if spec.sandbox_token.is_empty() { + return 
Ok(None); + } + let path = sandbox_token_host_path(&sandbox.id)?; + if let Some(parent) = path.parent() { + openshell_core::paths::create_dir_restricted(parent).map_err(|err| { + ComputeDriverError::Message(format!( + "create sandbox token directory {} failed: {err}", + parent.display() + )) + })?; + } + tokio::fs::write(&path, format!("{}\n", spec.sandbox_token)) + .await + .map_err(|err| { + ComputeDriverError::Message(format!( + "write sandbox token file {} failed: {err}", + path.display() + )) + })?; + openshell_core::paths::set_file_owner_only(&path).map_err(|err| { + ComputeDriverError::Message(format!( + "restrict sandbox token file {} failed: {err}", + path.display() + )) + })?; + Ok(Some(path)) +} + +fn cleanup_sandbox_token_file(sandbox_id: &str) { + let Ok(path) = sandbox_token_host_path(sandbox_id) else { + return; + }; + if let Err(err) = std::fs::remove_file(&path) + && err.kind() != std::io::ErrorKind::NotFound + { + warn!( + sandbox_id = %sandbox_id, + path = %path.display(), + error = %err, + "Failed to remove Podman sandbox token file" + ); + } + if let Some(dir) = path.parent() { + let _ = std::fs::remove_dir(dir); + } +} + impl PodmanComputeDriver { /// Create a new driver, verifying the Podman socket is reachable. pub async fn new(mut config: PodmanComputeConfig) -> Result { + const MAX_PING_RETRIES: u32 = 5; + const PING_RETRY_DELAY: Duration = Duration::from_secs(2); + if !config.socket_path.exists() { if cfg!(target_os = "macos") { warn!( @@ -80,8 +144,27 @@ impl PodmanComputeDriver { let client = PodmanClient::new(config.socket_path.clone()); - // Verify connectivity. - client.ping().await?; + // Verify connectivity, retrying briefly to tolerate transient socket + // unavailability (e.g. podman.socket restarting after a package + // upgrade). The systemd unit uses Wants=podman.socket (not Requires), + // so the gateway may start while the socket is briefly re-activating. 
+ let mut attempts = 0; + loop { + match client.ping().await { + Ok(()) => break, + Err(e) if attempts < MAX_PING_RETRIES => { + attempts += 1; + warn!( + attempt = attempts, + max_retries = MAX_PING_RETRIES, + error = %e, + "Podman socket not ready, retrying" + ); + tokio::time::sleep(PING_RETRY_DELAY).await; + } + Err(e) => return Err(e), + } + } // Verify cgroups v2, detect rootless mode, and log system info. match client.system_info().await { @@ -111,7 +194,7 @@ impl PodmanComputeDriver { // Rootless pre-flight: warn if subuid/subgid ranges look missing. // Not a hard error because some systems configure these via LDAP or // other mechanisms that /etc/subuid does not reflect. - if nix::unistd::getuid().as_raw() != 0 { + if rustix::process::getuid().as_raw() != 0 { check_subuid_range(); } @@ -169,14 +252,12 @@ impl PodmanComputeDriver { /// Report driver capabilities. pub fn capabilities(&self) -> Result { - let supports_gpu = Self::has_gpu_capacity(); - Ok(GetCapabilitiesResponse { - driver_name: "podman".to_string(), - driver_version: openshell_core::VERSION.to_string(), - default_image: self.config.default_image.clone(), - supports_gpu, - gpu_count: 0, - }) + Ok(openshell_core::driver_utils::build_capabilities_response( + "podman", + openshell_core::VERSION, + &self.config.default_image, + Self::has_gpu_capacity(), + )) } #[must_use] @@ -199,6 +280,10 @@ impl PodmanComputeDriver { sandbox: &DriverSandbox, ) -> Result<(), ComputeDriverError> { let gpu_requested = sandbox.spec.as_ref().is_some_and(|s| s.gpu); + Self::validate_gpu_request(gpu_requested) + } + + fn validate_gpu_request(gpu_requested: bool) -> Result<(), ComputeDriverError> { if gpu_requested && !Self::has_gpu_capacity() { return Err(ComputeDriverError::Precondition( "GPU sandbox requested, but no NVIDIA GPU devices are available.".to_string(), @@ -221,12 +306,11 @@ impl PodmanComputeDriver { } // Validate the composed container name early, before creating any - // resources (secret, volume), so 
we don't leave orphans when the - // name is invalid. + // resources (volume), so we don't leave orphans when the name is + // invalid. let name = validated_container_name(&sandbox.name)?; let vol_name = container::volume_name(&sandbox.id); - let sec_name = container::secret_name(&sandbox.id); info!( sandbox_id = %sandbox.id, @@ -237,15 +321,16 @@ impl PodmanComputeDriver { // 1a. Pull the supervisor image if needed. The supervisor binary // is shipped in a standalone OCI image and mounted into sandbox - // containers via Podman's type=image mount. Using "missing" - // policy so the image is only pulled once and then cached. + // containers via Podman's type=image mount. Refresh mutable tags + // like latest/dev, but avoid registry checks for pinned images. + let supervisor_pull_policy = supervisor_image_pull_policy(&self.config.supervisor_image); info!( image = %self.config.supervisor_image, - policy = "missing", + policy = supervisor_pull_policy, "Ensuring supervisor image" ); self.client - .pull_image(&self.config.supervisor_image, "missing") + .pull_image(&self.config.supervisor_image, supervisor_pull_policy) .await .map_err(ComputeDriverError::from)?; @@ -253,7 +338,7 @@ impl PodmanComputeDriver { let image = container::resolve_image(sandbox, &self.config); if image.is_empty() { return Err(ComputeDriverError::Precondition( - "no sandbox image configured: set --sandbox-image on the server \ + "no sandbox image configured: set default_image in [openshell.drivers.podman] \ or provide an image in the sandbox template" .to_string(), )); @@ -265,35 +350,38 @@ impl PodmanComputeDriver { .await .map_err(ComputeDriverError::from)?; - // 2. Create the SSH handshake secret via the Podman secrets API - // so it is not exposed in `podman inspect` output. - self.client - .create_secret(&sec_name, self.config.ssh_handshake_secret.as_bytes()) - .await - .map_err(ComputeDriverError::from)?; - - // 3. Create workspace volume. + // 2. Create workspace volume. 
if let Err(e) = self.client.create_volume(&vol_name).await { - let _ = self.client.remove_secret(&sec_name).await; return Err(ComputeDriverError::from(e)); } + let token_host_path = match write_sandbox_token_file(sandbox).await { + Ok(path) => path, + Err(e) => { + let _ = self.client.remove_volume(&vol_name).await; + return Err(e); + } + }; - // 4. Create container. - let spec = container::build_container_spec(sandbox, &self.config); + // 3. Create container. + let spec = container::build_container_spec_with_token( + sandbox, + &self.config, + token_host_path.as_deref(), + ); match self.client.create_container(&spec).await { Ok(_) => {} Err(PodmanApiError::Conflict(_)) => { - // Clean up the volume and secret we just created. They are - // keyed by *this* sandbox's ID, not the conflicting - // container's ID (which has the same name but a different - // ID), so they would be orphaned otherwise. + // Clean up the volume we just created. It is keyed by *this* + // sandbox's ID, not the conflicting container's ID (which + // has the same name but a different ID), so it would be + // orphaned otherwise. let _ = self.client.remove_volume(&vol_name).await; - let _ = self.client.remove_secret(&sec_name).await; + cleanup_sandbox_token_file(&sandbox.id); return Err(ComputeDriverError::AlreadyExists); } Err(e) => { let _ = self.client.remove_volume(&vol_name).await; - let _ = self.client.remove_secret(&sec_name).await; + cleanup_sandbox_token_file(&sandbox.id); return Err(ComputeDriverError::from(e)); } } @@ -307,7 +395,7 @@ impl PodmanComputeDriver { ); let _ = self.client.remove_container(&name).await; let _ = self.client.remove_volume(&vol_name).await; - let _ = self.client.remove_secret(&sec_name).await; + cleanup_sandbox_token_file(&sandbox.id); return Err(ComputeDriverError::from(e)); } @@ -386,15 +474,15 @@ impl PodmanComputeDriver { .await; // Remove container. 
If NotFound, the container was removed between - // inspect and here (TOCTOU race); proceed with volume/secret cleanup - // since those resources are idempotent to remove. + // inspect and here (TOCTOU race); proceed with volume cleanup + // since the workspace volume is idempotent to remove. let container_existed = match self.client.remove_container(&name).await { Ok(()) => true, Err(PodmanApiError::NotFound(_)) => false, Err(e) => return Err(ComputeDriverError::from(e)), }; - // Remove workspace volume and handshake secret. + // Remove workspace volume. let vol = container::volume_name(sandbox_id); if let Err(e) = self.client.remove_volume(&vol).await { warn!( @@ -405,16 +493,7 @@ impl PodmanComputeDriver { "Failed to remove workspace volume" ); } - let sec = container::secret_name(sandbox_id); - if let Err(e) = self.client.remove_secret(&sec).await { - warn!( - sandbox_id = %sandbox_id, - sandbox_name = %sandbox_name, - secret = %sec, - error = %e, - "Failed to remove handshake secret" - ); - } + cleanup_sandbox_token_file(sandbox_id); Ok(container_existed) } @@ -504,6 +583,31 @@ impl PodmanComputeDriver { } } +fn supervisor_image_pull_policy(image: &str) -> &'static str { + if supervisor_image_should_refresh(image) { + "newer" + } else { + "missing" + } +} + +fn supervisor_image_should_refresh(image: &str) -> bool { + matches!(supervisor_image_tag(image), Some("dev" | "latest")) +} + +fn supervisor_image_tag(image: &str) -> Option<&str> { + if image.contains('@') { + return None; + } + + let image_name = image.rsplit('/').next().unwrap_or(image); + image_name + .rsplit_once(':') + .map_or(Some("latest"), |(_, tag)| { + if tag.is_empty() { None } else { Some(tag) } + }) +} + /// Check whether the current user has subuid/subgid ranges configured. 
/// /// Rootless Podman requires entries in `/etc/subuid` and `/etc/subgid` for @@ -546,18 +650,9 @@ fn check_subuid_range() { #[cfg(test)] mod tests { use super::*; - use http_body_util::Full; - use hyper::body::Bytes; - use hyper::server::conn::http1; - use hyper::service::service_fn; - use hyper::{Response, StatusCode}; - use hyper_util::rt::TokioIo; - use std::collections::VecDeque; - use std::convert::Infallible; + use crate::test_utils::{StubResponse, spawn_podman_stub}; + use hyper::StatusCode; use std::path::PathBuf; - use std::sync::{Arc, Mutex}; - use std::time::{SystemTime, UNIX_EPOCH}; - use tokio::net::UnixListener; #[test] fn podman_driver_error_from_conflict() { @@ -643,30 +738,30 @@ mod tests { assert_eq!(cfg.grpc_endpoint, "https://gateway.internal:9000"); } - #[derive(Clone)] - struct StubResponse { - status: StatusCode, - body: String, - } - - impl StubResponse { - fn new(status: StatusCode, body: impl Into) -> Self { - Self { - status, - body: body.into(), - } - } - } - - fn unique_socket_path(test_name: &str) -> PathBuf { - let nanos = SystemTime::now() - .duration_since(UNIX_EPOCH) - .expect("clock should be after unix epoch") - .as_nanos(); - PathBuf::from(format!( - "/tmp/openshell-podman-{test_name}-{}-{nanos}.sock", - std::process::id() - )) + #[test] + fn supervisor_pull_policy_refreshes_mutable_tags_only() { + assert_eq!( + supervisor_image_pull_policy("ghcr.io/nvidia/openshell/supervisor:dev"), + "newer" + ); + assert_eq!( + supervisor_image_pull_policy("ghcr.io/nvidia/openshell/supervisor:latest"), + "newer" + ); + assert_eq!( + supervisor_image_pull_policy("ghcr.io/nvidia/openshell/supervisor"), + "newer" + ); + assert_eq!( + supervisor_image_pull_policy( + "ghcr.io/nvidia/openshell/supervisor:0.0.47-dev.13-g57b71c68f" + ), + "missing" + ); + assert_eq!( + supervisor_image_pull_policy("ghcr.io/nvidia/openshell/supervisor@sha256:abc123"), + "missing" + ); } fn test_driver(socket_path: PathBuf) -> PodmanComputeDriver { @@ -682,78 
+777,12 @@ mod tests { format!("/v5.0.0{path}") } - fn spawn_podman_stub( - test_name: &str, - responses: Vec<StubResponse>, - ) -> ( - PathBuf, - Arc<Mutex<Vec<String>>>, - tokio::task::JoinHandle<()>, - ) { - let socket_path = unique_socket_path(test_name); - let _ = std::fs::remove_file(&socket_path); - let listener = UnixListener::bind(&socket_path).expect("test socket should bind"); - let request_log = Arc::new(Mutex::new(Vec::new())); - let response_queue = Arc::new(Mutex::new(VecDeque::from(responses))); - let expected = response_queue - .lock() - .expect("response queue lock should not be poisoned") - .len(); - let socket_path_for_task = socket_path.clone(); - let log_for_task = request_log.clone(); - let queue_for_task = response_queue; - let handle = tokio::spawn(async move { - for _ in 0..expected { - let (stream, _) = listener.accept().await.expect("test stub should accept"); - let log = log_for_task.clone(); - let queue = queue_for_task.clone(); - let result = http1::Builder::new() - .serve_connection( - TokioIo::new(stream), - service_fn(move |req| { - let log = log.clone(); - let queue = queue.clone(); - async move { - let path = req.uri().path_and_query().map_or_else( - || req.uri().path().to_string(), - |pq| pq.as_str().to_string(), - ); - log.lock() - .expect("request log lock should not be poisoned") - .push(format!("{} {}", req.method(), path)); - let response = queue - .lock() - .expect("response queue lock should not be poisoned") - .pop_front() - .expect("stub response should exist"); - Ok::<_, Infallible>( - Response::builder() - .status(response.status) - .body(Full::new(Bytes::from(response.body))) - .expect("stub response should build"), - ) - } - }), - ) - .await; - // The one-shot test client can close the Unix socket after the - // response, which Hyper reports as a shutdown error. Let the - // request log assertions below decide whether the stub served - // the expected API calls.
- let _ = result; - } - let _ = std::fs::remove_file(&socket_path_for_task); - }); - (socket_path, request_log, handle) - } - #[tokio::test] async fn delete_sandbox_cleans_up_with_request_id_when_container_is_already_gone() { let sandbox_id = "sandbox-123"; let sandbox_name = "demo"; let container_name = container::container_name(sandbox_name); let volume_name = container::volume_name(sandbox_id); - let secret_name = container::secret_name(sandbox_id); let (socket_path, request_log, handle) = spawn_podman_stub( "delete-not-found", vec![ @@ -761,7 +790,6 @@ mod tests { StubResponse::new(StatusCode::NOT_FOUND, r#"{"message":"gone"}"#), StubResponse::new(StatusCode::NOT_FOUND, r#"{"message":"gone"}"#), StubResponse::new(StatusCode::NO_CONTENT, ""), - StubResponse::new(StatusCode::NO_CONTENT, ""), ], ); let driver = test_driver(socket_path.clone()); @@ -800,10 +828,6 @@ mod tests { "DELETE {}", api_path(&format!("/libpod/volumes/{volume_name}")) ), - format!( - "DELETE {}", - api_path(&format!("/libpod/secrets/{secret_name}")) - ), ] ); let _ = std::fs::remove_file(socket_path); @@ -815,7 +839,6 @@ mod tests { let sandbox_name = "demo"; let container_name = container::container_name(sandbox_name); let volume_name = container::volume_name(sandbox_id); - let secret_name = container::secret_name(sandbox_id); let inspect_body = serde_json::json!({ "Id": "container-id", "Name": format!("/{container_name}"), @@ -837,7 +860,6 @@ mod tests { StubResponse::new(StatusCode::NO_CONTENT, ""), StubResponse::new(StatusCode::NO_CONTENT, ""), StubResponse::new(StatusCode::NO_CONTENT, ""), - StubResponse::new(StatusCode::NO_CONTENT, ""), ], ); let driver = test_driver(socket_path.clone()); @@ -855,16 +877,10 @@ mod tests { .clone(); assert_eq!( requests[3..], - [ - format!( - "DELETE {}", - api_path(&format!("/libpod/volumes/{volume_name}")) - ), - format!( - "DELETE {}", - api_path(&format!("/libpod/secrets/{secret_name}")) - ), - ] + [format!( + "DELETE {}", + 
api_path(&format!("/libpod/volumes/{volume_name}")) + )] ); let _ = std::fs::remove_file(socket_path); } diff --git a/crates/openshell-driver-podman/src/grpc.rs b/crates/openshell-driver-podman/src/grpc.rs index df4c90d13..0c6015776 100644 --- a/crates/openshell-driver-podman/src/grpc.rs +++ b/crates/openshell-driver-podman/src/grpc.rs @@ -15,7 +15,6 @@ use std::pin::Pin; use tonic::{Request, Response, Status}; use crate::PodmanComputeDriver; -use openshell_core::ComputeDriverError; #[derive(Debug, Clone)] pub struct ComputeDriverService { @@ -38,7 +37,7 @@ impl ComputeDriver for ComputeDriverService { self.driver .capabilities() .map(Response::new) - .map_err(status_from_driver_error) + .map_err(Status::from) } async fn validate_sandbox_create( @@ -51,7 +50,7 @@ impl ComputeDriver for ComputeDriverService { .ok_or_else(|| Status::invalid_argument("sandbox is required"))?; self.driver .validate_sandbox_create(&sandbox) - .map_err(status_from_driver_error)?; + .map_err(Status::from)?; Ok(Response::new(ValidateSandboxCreateResponse {})) } @@ -68,7 +67,7 @@ impl ComputeDriver for ComputeDriverService { .driver .get_sandbox(&request.sandbox_name) .await - .map_err(status_from_driver_error)? + .map_err(Status::from)? 
.ok_or_else(|| Status::not_found("sandbox not found"))?; if !request.sandbox_id.is_empty() && request.sandbox_id != sandbox.id { @@ -86,11 +85,7 @@ impl ComputeDriver for ComputeDriverService { &self, _request: Request, ) -> Result, Status> { - let sandboxes = self - .driver - .list_sandboxes() - .await - .map_err(status_from_driver_error)?; + let sandboxes = self.driver.list_sandboxes().await.map_err(Status::from)?; Ok(Response::new(ListSandboxesResponse { sandboxes })) } @@ -105,7 +100,7 @@ impl ComputeDriver for ComputeDriverService { self.driver .create_sandbox(&sandbox) .await - .map_err(status_from_driver_error)?; + .map_err(Status::from)?; Ok(Response::new(CreateSandboxResponse {})) } @@ -120,7 +115,7 @@ impl ComputeDriver for ComputeDriverService { self.driver .stop_sandbox(&request.sandbox_name) .await - .map_err(status_from_driver_error)?; + .map_err(Status::from)?; Ok(Response::new(StopSandboxResponse {})) } @@ -139,7 +134,7 @@ impl ComputeDriver for ComputeDriverService { .driver .delete_sandbox(&request.sandbox_id, &request.sandbox_name) .await - .map_err(status_from_driver_error)?; + .map_err(Status::from)?; Ok(Response::new(DeleteSandboxResponse { deleted })) } @@ -150,46 +145,26 @@ impl ComputeDriver for ComputeDriverService { &self, _request: Request, ) -> Result, Status> { - let stream = self - .driver - .watch_sandboxes() - .await - .map_err(status_from_driver_error)?; + let stream = self.driver.watch_sandboxes().await.map_err(Status::from)?; let stream = stream.map(|item| item.map_err(|err| Status::internal(err.to_string()))); Ok(Response::new(Box::pin(stream))) } } -fn status_from_driver_error(err: ComputeDriverError) -> Status { - match err { - ComputeDriverError::AlreadyExists => Status::already_exists("sandbox already exists"), - ComputeDriverError::Precondition(message) => Status::failed_precondition(message), - ComputeDriverError::Message(message) => Status::internal(message), - } -} - #[cfg(test)] mod tests { use super::*; use 
crate::config::PodmanComputeConfig; use crate::container; - use http_body_util::Full; - use hyper::body::Bytes; - use hyper::server::conn::http1; - use hyper::service::service_fn; - use hyper::{Response as HyperResponse, StatusCode}; - use hyper_util::rt::TokioIo; - use std::collections::VecDeque; - use std::convert::Infallible; + use crate::test_utils::{StubResponse, spawn_podman_stub, unique_socket_path}; + use hyper::StatusCode; + use openshell_core::ComputeDriverError; use std::path::PathBuf; - use std::sync::{Arc, Mutex}; - use std::time::{SystemTime, UNIX_EPOCH}; #[test] fn precondition_driver_errors_map_to_failed_precondition_status() { - let status = status_from_driver_error(ComputeDriverError::Precondition( - "sandbox container is not running".to_string(), - )); + let status: Status = + ComputeDriverError::Precondition("sandbox container is not running".to_string()).into(); assert_eq!(status.code(), tonic::Code::FailedPrecondition); assert_eq!(status.message(), "sandbox container is not running"); @@ -197,102 +172,10 @@ mod tests { #[test] fn already_exists_driver_errors_map_to_already_exists_status() { - let status = status_from_driver_error(ComputeDriverError::AlreadyExists); + let status: Status = ComputeDriverError::AlreadyExists.into(); assert_eq!(status.code(), tonic::Code::AlreadyExists); } - #[derive(Clone)] - struct StubResponse { - status: StatusCode, - body: String, - } - - impl StubResponse { - fn new(status: StatusCode, body: impl Into) -> Self { - Self { - status, - body: body.into(), - } - } - } - - fn unique_socket_path(test_name: &str) -> PathBuf { - let nanos = SystemTime::now() - .duration_since(UNIX_EPOCH) - .expect("clock should be after unix epoch") - .as_nanos(); - PathBuf::from(format!( - "/tmp/openshell-podman-grpc-{test_name}-{}-{nanos}.sock", - std::process::id() - )) - } - - fn spawn_podman_stub( - test_name: &str, - responses: Vec, - ) -> ( - PathBuf, - Arc>>, - tokio::task::JoinHandle<()>, - ) { - let socket_path = 
unique_socket_path(test_name); - let _ = std::fs::remove_file(&socket_path); - let listener = - tokio::net::UnixListener::bind(&socket_path).expect("test socket should bind"); - let request_log = Arc::new(Mutex::new(Vec::new())); - let response_queue = Arc::new(Mutex::new(VecDeque::from(responses))); - let expected = response_queue - .lock() - .expect("response queue lock should not be poisoned") - .len(); - let socket_path_for_task = socket_path.clone(); - let log_for_task = request_log.clone(); - let queue_for_task = response_queue; - let handle = tokio::spawn(async move { - for _ in 0..expected { - let (stream, _) = listener.accept().await.expect("test stub should accept"); - let log = log_for_task.clone(); - let queue = queue_for_task.clone(); - let result = http1::Builder::new() - .serve_connection( - TokioIo::new(stream), - service_fn(move |req| { - let log = log.clone(); - let queue = queue.clone(); - async move { - let path = req.uri().path_and_query().map_or_else( - || req.uri().path().to_string(), - |pq| pq.as_str().to_string(), - ); - log.lock() - .expect("request log lock should not be poisoned") - .push(format!("{} {}", req.method(), path)); - let response = queue - .lock() - .expect("response queue lock should not be poisoned") - .pop_front() - .expect("stub response should exist"); - Ok::<_, Infallible>( - HyperResponse::builder() - .status(response.status) - .body(Full::new(Bytes::from(response.body))) - .expect("stub response should build"), - ) - } - }), - ) - .await; - // The one-shot test client can close the Unix socket after the - // response, which Hyper reports as a shutdown error. Let the - // request log assertions below decide whether the stub served - // the expected API calls. 
- let _ = result; - } - let _ = std::fs::remove_file(&socket_path_for_task); - }); - (socket_path, request_log, handle) - } - fn test_service(socket_path: PathBuf) -> ComputeDriverService { let config = PodmanComputeConfig { socket_path, @@ -348,7 +231,6 @@ mod tests { let sandbox_name = "demo"; let container_name = container::container_name(sandbox_name); let volume_name = container::volume_name(sandbox_id); - let secret_name = container::secret_name(sandbox_id); let (socket_path, request_log, handle) = spawn_podman_stub( "forward-id", vec![ @@ -356,7 +238,6 @@ mod tests { StubResponse::new(StatusCode::NOT_FOUND, r#"{"message":"gone"}"#), StubResponse::new(StatusCode::NOT_FOUND, r#"{"message":"gone"}"#), StubResponse::new(StatusCode::NO_CONTENT, ""), - StubResponse::new(StatusCode::NO_CONTENT, ""), ], ); let service = test_service(socket_path.clone()); @@ -404,10 +285,6 @@ mod tests { "DELETE {}", api_path(&format!("/libpod/volumes/{volume_name}")) ), - format!( - "DELETE {}", - api_path(&format!("/libpod/secrets/{secret_name}")) - ), ] ); let _ = std::fs::remove_file(socket_path); diff --git a/crates/openshell-driver-podman/src/lib.rs b/crates/openshell-driver-podman/src/lib.rs index 630deaee1..5847a10ea 100644 --- a/crates/openshell-driver-podman/src/lib.rs +++ b/crates/openshell-driver-podman/src/lib.rs @@ -6,6 +6,8 @@ pub mod config; pub(crate) mod container; pub mod driver; pub mod grpc; +#[cfg(test)] +pub(crate) mod test_utils; pub(crate) mod watcher; pub use config::PodmanComputeConfig; diff --git a/crates/openshell-driver-podman/src/main.rs b/crates/openshell-driver-podman/src/main.rs index 9095915dd..5a0227ef6 100644 --- a/crates/openshell-driver-podman/src/main.rs +++ b/crates/openshell-driver-podman/src/main.rs @@ -9,12 +9,9 @@ use tracing::info; use tracing_subscriber::EnvFilter; use openshell_core::VERSION; -use openshell_core::config::{ - DEFAULT_NETWORK_NAME, DEFAULT_SSH_HANDSHAKE_SKEW_SECS, DEFAULT_SSH_PORT, - DEFAULT_STOP_TIMEOUT_SECS, -}; +use 
openshell_core::config::DEFAULT_STOP_TIMEOUT_SECS; use openshell_core::proto::compute::v1::compute_driver_server::ComputeDriverServer; -use openshell_driver_podman::config::ImagePullPolicy; +use openshell_driver_podman::config::{DEFAULT_NETWORK_NAME, ImagePullPolicy}; use openshell_driver_podman::{ComputeDriverService, PodmanComputeConfig, PodmanComputeDriver}; #[derive(Parser)] @@ -70,15 +67,6 @@ struct Args { #[arg(long, env = "OPENSHELL_NETWORK_NAME", default_value = DEFAULT_NETWORK_NAME)] network_name: String, - #[arg(long, env = "OPENSHELL_SANDBOX_SSH_PORT", default_value_t = DEFAULT_SSH_PORT)] - sandbox_ssh_port: u16, - - #[arg(long, env = "OPENSHELL_SSH_HANDSHAKE_SECRET")] - ssh_handshake_secret: String, - - #[arg(long, env = "OPENSHELL_SSH_HANDSHAKE_SKEW_SECS", default_value_t = DEFAULT_SSH_HANDSHAKE_SKEW_SECS)] - ssh_handshake_skew_secs: u64, - /// Container stop timeout in seconds (SIGTERM → SIGKILL). #[arg(long, env = "OPENSHELL_STOP_TIMEOUT", default_value_t = DEFAULT_STOP_TIMEOUT_SECS)] stop_timeout: u32, @@ -121,9 +109,6 @@ async fn main() -> Result<()> { gateway_port: args.gateway_port, sandbox_ssh_socket_path: args.sandbox_ssh_socket_path, network_name: args.network_name, - ssh_port: args.sandbox_ssh_port, - ssh_handshake_secret: args.ssh_handshake_secret, - ssh_handshake_skew_secs: args.ssh_handshake_skew_secs, stop_timeout_secs: args.stop_timeout, supervisor_image: args.supervisor_image, guest_tls_ca: args.podman_tls_ca, diff --git a/crates/openshell-driver-podman/src/test_utils.rs b/crates/openshell-driver-podman/src/test_utils.rs new file mode 100644 index 000000000..94794bc22 --- /dev/null +++ b/crates/openshell-driver-podman/src/test_utils.rs @@ -0,0 +1,119 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Shared test helpers for openshell-driver-podman unit tests. 
+ +use http_body_util::Full; +use hyper::StatusCode; +use hyper::body::Bytes; +use hyper::server::conn::http1; +use hyper::service::service_fn; +use hyper_util::rt::TokioIo; +use std::collections::VecDeque; +use std::convert::Infallible; +use std::path::PathBuf; +use std::sync::{Arc, Mutex}; +use std::time::{SystemTime, UNIX_EPOCH}; +use tokio::net::UnixListener; + +/// A canned HTTP response for the Podman stub server. +#[derive(Clone)] +pub struct StubResponse { + pub status: StatusCode, + pub body: String, +} + +impl StubResponse { + pub fn new(status: StatusCode, body: impl Into<String>) -> Self { + Self { + status, + body: body.into(), + } + } +} + +/// Generate a unique Unix socket path for a test. +/// +/// Uses the current PID and nanosecond timestamp to avoid collisions between +/// concurrent test runs. +pub fn unique_socket_path(test_name: &str) -> PathBuf { + let nanos = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("clock should be after unix epoch") + .as_nanos(); + PathBuf::from(format!( + "/tmp/openshell-podman-{test_name}-{}-{nanos}.sock", + std::process::id() + )) +} + +/// Spawn a Unix-socket HTTP stub that serves the given `responses` in order.
+/// +/// Returns: +/// - the socket path (already bound and listening) +/// - a shared log of `"METHOD /path"` strings, one per request received +/// - a join handle that resolves once all expected requests have been served +pub fn spawn_podman_stub( + test_name: &str, + responses: Vec<StubResponse>, +) -> ( + PathBuf, + Arc<Mutex<Vec<String>>>, + tokio::task::JoinHandle<()>, +) { + let socket_path = unique_socket_path(test_name); + let _ = std::fs::remove_file(&socket_path); + let listener = UnixListener::bind(&socket_path).expect("test socket should bind"); + let request_log = Arc::new(Mutex::new(Vec::new())); + let response_queue = Arc::new(Mutex::new(VecDeque::from(responses))); + let expected = response_queue + .lock() + .expect("response queue lock should not be poisoned") + .len(); + let socket_path_for_task = socket_path.clone(); + let log_for_task = request_log.clone(); + let queue_for_task = response_queue; + let handle = tokio::spawn(async move { + for _ in 0..expected { + let (stream, _) = listener.accept().await.expect("test stub should accept"); + let log = log_for_task.clone(); + let queue = queue_for_task.clone(); + let result = http1::Builder::new() + .serve_connection( + TokioIo::new(stream), + service_fn(move |req| { + let log = log.clone(); + let queue = queue.clone(); + async move { + let path = req.uri().path_and_query().map_or_else( + || req.uri().path().to_string(), + |pq| pq.as_str().to_string(), + ); + log.lock() + .expect("request log lock should not be poisoned") + .push(format!("{} {}", req.method(), path)); + let response = queue + .lock() + .expect("response queue lock should not be poisoned") + .pop_front() + .expect("stub response should exist"); + Ok::<_, Infallible>( + hyper::Response::builder() + .status(response.status) + .body(Full::new(Bytes::from(response.body))) + .expect("stub response should build"), + ) + } + }), + ) + .await; + // The one-shot test client can close the Unix socket after the + // response, which Hyper reports as a shutdown error.
Let the + // request log assertions below decide whether the stub served + // the expected API calls. + let _ = result; + } + let _ = std::fs::remove_file(&socket_path_for_task); + }); + (socket_path, request_log, handle) +} diff --git a/crates/openshell-driver-vm/Cargo.toml b/crates/openshell-driver-vm/Cargo.toml index c13d904a6..0006f1f35 100644 --- a/crates/openshell-driver-vm/Cargo.toml +++ b/crates/openshell-driver-vm/Cargo.toml @@ -25,6 +25,7 @@ openshell-vfio = { path = "../openshell-vfio" } bollard = { version = "0.20", features = ["ssh"] } tokio = { workspace = true } tonic = { workspace = true, features = ["transport"] } +prost = { workspace = true } prost-types = { workspace = true } futures = { workspace = true } tokio-stream = { workspace = true, features = ["net"] } @@ -38,6 +39,7 @@ serde = { workspace = true } serde_json = { workspace = true } oci-client = "0.16" libc = "0.2" +rustix = { workspace = true } libloading = "0.8" tar = "0.4" flate2 = "1" diff --git a/crates/openshell-driver-vm/README.md b/crates/openshell-driver-vm/README.md index a3bdf9822..8da0b96a4 100644 --- a/crates/openshell-driver-vm/README.md +++ b/crates/openshell-driver-vm/README.md @@ -2,7 +2,7 @@ > Status: Experimental. The VM compute driver is under active development and the interface still has VM-specific plumbing that will be generalized. -Standalone libkrun-backed [`ComputeDriver`](../../proto/compute_driver.proto) for OpenShell. The gateway spawns this binary as a subprocess, talks to it over a Unix domain socket with the `openshell.compute.v1.ComputeDriver` gRPC surface, and lets it manage per-sandbox microVMs. The runtime (libkrun + libkrunfw + gvproxy) and the sandbox supervisor are embedded directly in the binary; each sandbox guest rootfs is derived from a configured container image at create time. +Standalone libkrun-backed [`ComputeDriver`](../../proto/compute_driver.proto) for OpenShell. 
The gateway spawns this binary as a subprocess, talks to it over a Unix domain socket with the `openshell.compute.v1.ComputeDriver` gRPC surface, and lets it manage per-sandbox microVMs. The runtime (libkrun + libkrunfw + gvproxy), guest OCI unpacker, and sandbox supervisor are embedded directly in the binary; each sandbox boots from a cached immutable bootstrap ext4 root disk plus a per-sandbox writable overlay disk. When the requested sandbox image differs from the bootstrap image, the driver prepares a read-only image ext4 disk inside a bootstrap VM and mounts that unpacked rootfs as the sandbox lowerdir. ## How it fits together @@ -35,15 +35,15 @@ Sandbox guests execute `/opt/openshell/bin/openshell-sandbox` as PID 1 inside th mise run gateway:vm ``` -First run takes a few minutes while `mise run vm:setup` stages libkrun/libkrunfw/gvproxy and `mise run vm:supervisor` builds the bundled guest supervisor. Subsequent runs are cached. +First run takes a few minutes while `mise run vm:setup` stages libkrun/libkrunfw/gvproxy/umoci and `mise run vm:supervisor` builds the bundled guest supervisor. Subsequent runs are cached. By default `mise run gateway:vm`: - Listens on plaintext HTTP at `127.0.0.1:18081`. - Registers the CLI gateway `vm-dev` by writing `~/.config/openshell/gateways/vm-dev/metadata.json`. It does not modify the workspace `.env`. - Persists the gateway SQLite DB under `.cache/gateway-vm/gateway.db`. -- Places the VM driver state (per-sandbox rootfs + `compute-driver.sock`) under `/tmp/openshell-vm-driver-$USER-vm-dev/` so the AF_UNIX socket path stays under macOS `SUN_LEN`. -- Passes `--driver-dir $PWD/target/debug` so the freshly built `openshell-driver-vm` is used instead of an older installed copy from `~/.local/libexec/openshell`, `/usr/libexec/openshell`, or `/usr/local/libexec`. 
+- Places the VM driver state (per-sandbox `overlay.ext4`, image cache, and `run/compute-driver.sock`) under `/tmp/openshell-vm-driver-$USER-vm-dev/` so the AF_UNIX socket path stays under macOS `SUN_LEN`. +- Writes `.cache/gateway-vm/gateway.toml` with `[openshell.drivers.vm].driver_dir = "$PWD/target/debug"` so the freshly built `openshell-driver-vm` is used instead of an older installed copy from `~/.local/libexec/openshell`, `/usr/libexec/openshell`, or `/usr/local/libexec`. For GPU passthrough (VFIO), pass `-- --gpu` and run with root privileges: @@ -75,6 +75,9 @@ mise run gateway:vm # custom sandbox image OPENSHELL_SANDBOX_IMAGE=ghcr.io/example/sandbox:latest mise run gateway:vm + +# custom bootstrap image for the VM runtime used to prepare/boot target images +OPENSHELL_VM_BOOTSTRAP_IMAGE=ghcr.io/example/bootstrap:latest mise run gateway:vm ``` Teardown: @@ -104,36 +107,51 @@ codesign \ # 4. Start the gateway with the VM driver mkdir -p /tmp/openshell-vm-driver-$USER-vm-dev .cache/gateway-vm +cat > .cache/gateway-vm/gateway.toml < \ - --grpc-endpoint http://host.containers.internal:18081 \ - --port 18081 \ - --vm-driver-state-dir /tmp/openshell-vm-driver-$USER-vm-dev + --port 18081 ``` -The gateway resolves `openshell-driver-vm` in this order: `--driver-dir`, conventional install locations (`~/.local/libexec/openshell`, `/usr/libexec/openshell`, `/usr/local/libexec/openshell`, `/usr/local/libexec`), then a sibling of the gateway binary. +The gateway resolves `openshell-driver-vm` in this order: `[openshell.drivers.vm].driver_dir`, conventional install locations (`~/.local/libexec/openshell`, `/usr/libexec/openshell`, `/usr/local/libexec/openshell`, `/usr/local/libexec`), then a sibling of the gateway binary. -## Flags +## Gateway And Driver Configuration -| Flag | Env var | Default | Purpose | -|---|---|---|---| -| `--drivers vm` | `OPENSHELL_DRIVERS` | `kubernetes` | Select the VM compute driver. 
| -| `--grpc-endpoint URL` | `OPENSHELL_GRPC_ENDPOINT` | — | Required. URL the sandbox guest dials to reach the gateway. Use `http://host.containers.internal:` (or `host.docker.internal` / `host.openshell.internal`) so traffic flows through gvproxy's host-loopback NAT (HostIP `192.168.127.254` → host `127.0.0.1`). Loopback URLs like `http://127.0.0.1:` are rewritten automatically by the driver. The bare gateway IP (`192.168.127.1`) only carries gvproxy's own services and will not reach host-bound ports. | -| `--vm-driver-state-dir DIR` | `OPENSHELL_VM_DRIVER_STATE_DIR` | `target/openshell-vm-driver` | Per-sandbox rootfs, console logs, and the `compute-driver.sock` UDS. | -| `--driver-dir DIR` | `OPENSHELL_DRIVER_DIR` | unset | Override the directory searched for `openshell-driver-vm`. | -| `--vm-driver-vcpus N` | `OPENSHELL_VM_DRIVER_VCPUS` | `2` | vCPUs per sandbox. | -| `--vm-driver-mem-mib N` | `OPENSHELL_VM_DRIVER_MEM_MIB` | `2048` | Memory per sandbox, in MiB. | -| `--vm-krun-log-level N` | `OPENSHELL_VM_KRUN_LOG_LEVEL` | `1` | libkrun verbosity (0–5). | -| `--vm-tls-ca PATH` | `OPENSHELL_VM_TLS_CA` | — | CA cert for the guest's mTLS client bundle. Required when `--grpc-endpoint` uses `https://`. | -| `--vm-tls-cert PATH` | `OPENSHELL_VM_TLS_CERT` | — | Guest client certificate. | -| `--vm-tls-key PATH` | `OPENSHELL_VM_TLS_KEY` | — | Guest client private key. | +Select the VM driver with `--drivers vm`, `OPENSHELL_DRIVERS=vm`, or `compute_drivers = ["vm"]` in `[openshell.gateway]`. Configure VM-specific settings in `[openshell.drivers.vm]`. -See [`openshell-gateway --help`](../openshell-server/src/cli.rs) for the full flag surface shared with the Kubernetes driver. +| Configuration key | Default | Purpose | +|---|---|---| +| `grpc_endpoint` | empty | Required. URL the sandbox guest dials to reach the gateway. 
Use `http://host.containers.internal:` (or `host.docker.internal` / `host.openshell.internal`) so traffic flows through gvproxy's host-loopback NAT (HostIP `192.168.127.254` → host `127.0.0.1`). Loopback URLs like `http://127.0.0.1:` are rewritten automatically by the driver. The bare gateway IP (`192.168.127.1`) only carries gvproxy's own services and will not reach host-bound ports. | +| `state_dir` | `target/openshell-vm-driver` | Per-sandbox overlay disks, console logs, image cache, and private `run/compute-driver.sock` UDS. | +| `driver_dir` | unset | Override the directory searched for `openshell-driver-vm`. | +| `default_image` | OpenShell base image | Sandbox image used when a create request omits one. | +| `bootstrap_image` | unset | VM runtime image used as the immutable bootstrap root disk. Defaults to the sandbox image when unset. | +| `vcpus` | `2` | vCPUs per sandbox. | +| `mem_mib` | `2048` | Memory per sandbox, in MiB. | +| `overlay_disk_mib` | `4096` | Sparse writable overlay disk size per sandbox, in MiB. | +| `krun_log_level` | `1` | libkrun verbosity (0-5). | +| `guest_tls_ca` | unset | CA cert for the guest's mTLS client bundle. Required when `grpc_endpoint` uses `https://`. | +| `guest_tls_cert` | unset | Guest client certificate. | +| `guest_tls_key` | unset | Guest client private key. | + +See [`openshell-gateway --help`](../openshell-server/src/cli.rs) for the gateway process flag surface. ## Verifying the gateway @@ -145,7 +163,38 @@ The gateway is auto-registered by `mise run gateway:vm`. In another terminal: ./scripts/bin/openshell sandbox connect demo ``` -First sandbox takes 10–30 seconds to boot (image fetch/prepare/cache + libkrun + guest init). If `--from` is omitted, the VM driver uses the gateway's configured default sandbox image. Without either `--from` or `--sandbox-image`, VM sandbox creation fails. Subsequent creates reuse the prepared sandbox rootfs. 
+First sandbox takes 10–30 seconds to boot (image fetch/prepare/cache + libkrun + guest init). If `--from` is omitted, the VM driver uses the gateway's configured default sandbox image. Without either `--from` or `--sandbox-image`, VM sandbox creation fails. Subsequent creates reuse the prepared image cache and create only a sparse per-sandbox `overlay.ext4` before boot. + +`CreateSandbox` accepts the sandbox quickly and continues VM provisioning in the +background. The driver publishes platform events for image resolution, cache +hits/misses, layer pulls, rootfs preparation, overlay creation, and VM launcher +startup so the CLI can show progress through the existing sandbox watch stream. + +The VM driver keeps two image caches. The bootstrap cache is a controlled +`rootfs.ext4` used to boot the guest init and OpenShell supervisor. The prepared +image cache is used when the requested sandbox image differs from the bootstrap +image: the host downloads registry layers into a valid OCI layout, attaches that +payload to a temporary bootstrap VM, and guest init runs `umoci raw unpack` onto +Linux-owned ext4 storage. The resulting disk is cached under +`/images//rootfs.ext4` and attached read-only to later +sandboxes. Local Docker images are still exported as rootfs tar archives and +prepared inside the bootstrap VM. Set `OPENSHELL_VM_IMAGE_PULL_CONCURRENCY` to +tune registry layer download parallelism (default `4`, maximum `16`). + +Each sandbox gets its own sparse writable +`/sandboxes//overlay.ext4`. Guest init mounts overlayfs as `/` +with the prepared image rootfs as lowerdir when present, otherwise the bootstrap +rootfs is used directly. Writes to `/sandbox` and other mutable paths land in +the overlay while cached image disks remain unchanged. The overlay disk must be +large enough to hold the compressed payload, unpacked rootfs, and sandbox writes +during the first prepare. 
+ +The driver also writes the accepted `DriverSandbox` launch request to +`/sandboxes//sandbox.pb`. If the gateway restarts, it starts a +new VM driver process; that process scans the sandbox state directories, +restarts each persisted VM launcher, and preserves any existing `overlay.ext4` +instead of cloning a fresh overlay template. If a restart happened before the +overlay was created, the driver creates it during the resume attempt. ## Logs and debugging @@ -156,19 +205,42 @@ RUST_LOG=openshell_server=debug,openshell_driver_vm=debug \ mise run gateway:vm ``` -The VM guest's serial console is appended to `//console.log`. The `compute-driver.sock` lives at `/compute-driver.sock`; the gateway removes it on clean shutdown via `ManagedDriverProcess::drop`. +The VM guest's serial console is appended to `//console.log`. Sandbox IDs must match `[A-Za-z0-9._-]{1,128}` before the driver uses them in host paths. The gateway-owned compute-driver socket lives at `/run/compute-driver.sock`; OpenShell creates `run/` with owner-only permissions, removes same-owner stale sockets, and the gateway removes the socket on clean shutdown via `ManagedDriverProcess::drop`. UDS clients must match the driver UID and provide the expected gateway process PID by default. Standalone same-UID UDS mode requires the explicit `--allow-same-uid-peer` development flag. TCP mode is disabled by default because it is unauthenticated; use `--allow-unauthenticated-tcp --bind-address 127.0.0.1:50061` only for local development. + +## Host-side nftables rules + +The VM driver creates a per-VM nftables table on the host (`openshell_vm_vmtap_`) with three chains. These rules serve two purposes: NAT infrastructure (required for VM connectivity) and defense-in-depth host isolation. Primary security enforcement — proxy-only egress and bypass detection — is handled by the sandbox supervisor's own nftables rules inside the VM guest. 
+ +**`postrouting` (NAT):** Masquerades outbound VM traffic so it can be routed from the VM's private subnet to the external network. This chain handles forwarded traffic (VM → internet), not traffic destined for the host. + +**`forward` (defense-in-depth):** Accepts all outbound traffic from the VM (security enforcement happens guest-side) and accepts established/related response traffic back to the VM. Drops unsolicited inbound connections to the VM from the broader network. This chain handles forwarded traffic only — packets transiting the host between the TAP interface and other interfaces. + +**`input` (defense-in-depth):** Accepts traffic from the VM to the gateway port on the host. Drops all other traffic from the VM destined for the host itself. This limits what a compromised guest can reach on the host to the gateway service only. + +The `input` and `postrouting` chains handle different traffic paths: `input` covers packets addressed to the host (VM → host), while `postrouting` covers packets the host is forwarding on behalf of the VM (VM → internet). A packet from the VM goes through one path or the other, never both. + +All chains use `policy accept`, so non-TAP traffic is unaffected. Because nftables evaluates multiple base chains on the same hook independently, host firewalls interact with these rules as follows: + +- **Open host (no other firewall):** Our chains are the only filter. The defense-in-depth drop rules block unsolicited inbound and non-gateway host access. Non-TAP traffic passes through. +- **Restrictive host firewall (e.g. firewalld):** The host firewall's chains may additionally drop TAP traffic that our chains accept. A `drop` verdict from any chain is final — our `accept` cannot override it. If VM connectivity fails, verify that the host firewall allows forwarding and input for `vmtap-*` interfaces. + +Each table is created atomically via `nft -f` on VM start and torn down atomically via `nft delete table` when the VM is destroyed. 
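
The three-chain layout described above can be sketched as a small Rust function that renders an nftables ruleset for one VM. This is an illustrative sketch under stated assumptions, not the driver's actual code: the function name `render_vm_ruleset`, its parameters, the `inet` family, the hook priorities, and the exact match expressions are all hypothetical. Only the `openshell_vm_vmtap_` table-name prefix, the three chain names, the `vmtap-*` interface naming, and `policy accept` come from the text.

```rust
/// Render a per-VM nftables ruleset as text suitable for `nft -f`.
///
/// Hypothetical helper: the family (`inet`), priorities, and match
/// expressions are assumptions chosen to mirror the prose description.
fn render_vm_ruleset(
    table_suffix: &str, // assumed per-VM identifier appended to the table name
    tap_iface: &str,    // assumed TAP interface name, e.g. "vmtap-demo"
    vm_subnet: &str,    // assumed private subnet of the VM, in CIDR form
    gateway_port: u16,  // host port the gateway listens on
) -> String {
    format!(
        r#"table inet openshell_vm_vmtap_{table_suffix} {{
    chain postrouting {{
        type nat hook postrouting priority srcnat; policy accept;
        # NAT: masquerade forwarded VM traffic leaving via another interface
        ip saddr {vm_subnet} oifname != "{tap_iface}" masquerade
    }}
    chain forward {{
        type filter hook forward priority filter; policy accept;
        # Outbound from the VM is accepted; enforcement happens guest-side
        iifname "{tap_iface}" accept
        # Replies to connections the VM opened
        oifname "{tap_iface}" ct state established,related accept
        # Unsolicited inbound connections to the VM
        oifname "{tap_iface}" drop
    }}
    chain input {{
        type filter hook input priority filter; policy accept;
        # The VM may reach only the gateway port on the host
        iifname "{tap_iface}" tcp dport {gateway_port} accept
        iifname "{tap_iface}" drop
    }}
}}
"#
    )
}

fn main() {
    // Example values are placeholders, not defaults taken from the driver.
    let ruleset = render_vm_ruleset("demo", "vmtap-demo", "10.0.0.0/30", 18081);
    print!("{ruleset}");
}
```

Because every rule matches on the TAP interface and each chain keeps `policy accept`, traffic on other interfaces falls through untouched, which matches the behaviour described for non-TAP traffic. Keeping the whole table in one text blob is what allows it to be loaded in a single `nft -f` transaction and removed with a single `nft delete table`.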
## Prerequisites - macOS on Apple Silicon, or Linux on aarch64/x86_64 with KVM - Rust toolchain +- e2fsprogs (`mke2fs` or `mkfs.ext4`, plus `debugfs`) for root and overlay disk image creation and QEMU environment injection - Guest-supervisor cross-compile toolchain (needed on macOS, and on Linux when host arch ≠ guest arch): - Matching rustup target: `rustup target add aarch64-unknown-linux-gnu` (or `x86_64-unknown-linux-gnu` for an amd64 guest) - `cargo install --locked cargo-zigbuild` and `brew install zig` (or distro equivalent). `vm:supervisor` uses `cargo zigbuild` to cross-compile the in-VM `openshell-sandbox` supervisor binary. - [mise](https://mise.jdx.dev/) task runner -- Docker-compatible socket on the local CLI/gateway host when using +- Docker or Podman socket on the local CLI/gateway host when using `openshell sandbox create --from ./Dockerfile` or `--from ./dir`; the CLI - builds the image and the VM driver exports it via the local Docker daemon + builds the image and the VM driver exports it via the local container engine. + Docker is tried first; if unavailable, the driver falls back to the Podman + socket. On Linux, enable the Podman API socket with + `systemctl --user start podman.socket` - `gh` CLI (used by `mise run vm:setup` to download pre-built runtime artifacts) ## Releases @@ -180,13 +252,17 @@ The VM guest's serial console is appended to `//console.l - runtime tarballs: the rolling `vm-runtime` release, rebuilt on demand by `release-vm-kernel.yml` -On Linux amd64 and arm64, `install-dev.sh` installs the Debian package from the -selected `OPENSHELL_VERSION` release tag. That package includes -`openshell-gateway` and `openshell-driver-vm`, but leaves +On Debian-family Linux amd64 and arm64 systems, `install.sh` installs the +Debian package from the selected `OPENSHELL_VERSION` release tag. 
That package +includes `openshell-gateway` and `openshell-driver-vm`, but leaves `OPENSHELL_DRIVERS` unset so the gateway uses its normal runtime auto-detection. Set `OPENSHELL_DRIVERS=vm` to force the VM driver. -On Apple Silicon macOS, `install-dev.sh` stages the generated `openshell.rb` +On RPM-family Linux x86_64 and aarch64 systems, `install.sh` installs the +`openshell` and `openshell-gateway` RPM packages from the selected release tag. +The RPM gateway package is configured for the Podman driver. + +On Apple Silicon macOS, `install.sh` stages the generated `openshell.rb` formula from the selected release in the `nvidia/openshell` Homebrew tap. Homebrew installs `openshell`, `openshell-gateway`, and `openshell-driver-vm`, ad-hoc signs the driver with the Hypervisor entitlement diff --git a/crates/openshell-driver-vm/build.rs b/crates/openshell-driver-vm/build.rs index 6ea845dc9..176359054 100644 --- a/crates/openshell-driver-vm/build.rs +++ b/crates/openshell-driver-vm/build.rs @@ -21,6 +21,7 @@ fn main() { "libkrunfw.5.dylib.zst", "gvproxy.zst", "openshell-sandbox.zst", + "umoci.zst", ] { println!("cargo:rerun-if-changed={dir}/{name}"); } @@ -35,7 +36,10 @@ fn main() { "linux" => ("libkrun.so", "libkrunfw.so.5"), _ => { println!("cargo:warning=VM runtime not available for {target_os}-{target_arch}"); - generate_stub_resources(&out_dir, &["libkrun", "libkrunfw", "openshell-sandbox.zst"]); + generate_stub_resources( + &out_dir, + &["libkrun", "libkrunfw", "openshell-sandbox.zst", "umoci.zst"], + ); return; } }; @@ -52,6 +56,7 @@ fn main() { &format!("{libkrunfw_name}.zst"), "gvproxy.zst", "openshell-sandbox.zst", + "umoci.zst", ], ); return; @@ -70,6 +75,7 @@ fn main() { &format!("{libkrunfw_name}.zst"), "gvproxy.zst", "openshell-sandbox.zst", + "umoci.zst", ], ); return; @@ -86,6 +92,7 @@ fn main() { "openshell-sandbox.zst".to_string(), "openshell-sandbox.zst".to_string(), ), + ("umoci.zst".to_string(), "umoci.zst".to_string()), ]; let mut all_found = true; 
@@ -128,6 +135,7 @@ fn main() { &format!("{libkrunfw_name}.zst"), "gvproxy.zst", "openshell-sandbox.zst", + "umoci.zst", ], ); } diff --git a/crates/openshell-driver-vm/runtime/README.md b/crates/openshell-driver-vm/runtime/README.md index 17dc8dab7..11aab67f4 100644 --- a/crates/openshell-driver-vm/runtime/README.md +++ b/crates/openshell-driver-vm/runtime/README.md @@ -11,8 +11,8 @@ runtime/ openshell.kconfig ``` -`openshell-driver-vm` embeds libkrun, libkrunfw, gvproxy, and the bundled -`openshell-sandbox` supervisor. +`openshell-driver-vm` embeds libkrun, libkrunfw, gvproxy, umoci for guest-side +OCI image unpacking, and the bundled `openshell-sandbox` supervisor. ## Why @@ -27,7 +27,7 @@ VM sandboxes can run the same supervisor enforcement path as other backends. |---|---|---| | `tasks/scripts/vm/build-libkrun.sh` | Linux | Builds libkrunfw and libkrun from source with the custom kernel config | | `tasks/scripts/vm/build-libkrun-macos.sh` | macOS | Builds portable libkrunfw and libkrun from a prebuilt `kernel.c` | -| `tasks/scripts/vm/package-vm-runtime.sh` | Any | Packages `vm-runtime-.tar.zst` with libraries, gvproxy, and provenance | +| `tasks/scripts/vm/package-vm-runtime.sh` | Any | Packages `vm-runtime-.tar.zst` with libraries, gvproxy, umoci, and provenance | | `tasks/scripts/vm/download-kernel-runtime.sh` | Any | Downloads runtime tarballs from the `vm-runtime` release and stages compressed files | ## Local Flow @@ -62,8 +62,9 @@ publish the driver binary next to `openshell-gateway`. ## Provenance `package-vm-runtime.sh` writes `provenance.json` into each runtime tarball with -the platform, libkrunfw commit, kernel version, GitHub SHA, and build time. The -driver logs this metadata when it extracts and loads a runtime bundle. +the platform, libkrunfw commit, kernel version, gvproxy and umoci versions, +GitHub SHA, and build time. The driver logs this metadata when it extracts and +loads a runtime bundle. 
The release workflow also publishes GitHub artifact attestations for each runtime tarball. Verify a downloaded runtime with: diff --git a/crates/openshell-driver-vm/runtime/kernel/openshell.kconfig b/crates/openshell-driver-vm/runtime/kernel/openshell.kconfig index b5f0330af..e8d826c53 100644 --- a/crates/openshell-driver-vm/runtime/kernel/openshell.kconfig +++ b/crates/openshell-driver-vm/runtime/kernel/openshell.kconfig @@ -8,6 +8,13 @@ # # See also: check-vm-capabilities.sh for runtime verification. +# ── Root disk transport and filesystem ───────────────────────────────── +CONFIG_BLOCK=y +CONFIG_BLK_DEV=y +CONFIG_VIRTIO_BLK=y +CONFIG_EXT4_FS=y +CONFIG_EXT4_USE_FOR_EXT2=y + # ── Network Namespaces (required for pod isolation) ───────────────────── CONFIG_NET_NS=y CONFIG_NAMESPACES=y @@ -79,6 +86,7 @@ CONFIG_NFT_NUMGEN=y CONFIG_NFT_FIB_IPV4=y CONFIG_NFT_FIB_IPV6=y CONFIG_NFT_LIMIT=y +CONFIG_NFT_LOG=y CONFIG_NFT_REDIR=y CONFIG_NFT_TPROXY=y diff --git a/crates/openshell-driver-vm/runtime/pins.env b/crates/openshell-driver-vm/runtime/pins.env index b526947df..d05c66e83 100644 --- a/crates/openshell-driver-vm/runtime/pins.env +++ b/crates/openshell-driver-vm/runtime/pins.env @@ -33,6 +33,10 @@ COMMUNITY_SANDBOX_IMAGE="${COMMUNITY_SANDBOX_IMAGE:-ghcr.io/nvidia/openshell-com # Repo: https://github.com/containers/gvisor-tap-vsock GVPROXY_VERSION="${GVPROXY_VERSION:-v0.8.8}" +# ── umoci (guest OCI unpacker) ────────────────────────────────────────── +# Repo: https://github.com/opencontainers/umoci +UMOCI_VERSION="${UMOCI_VERSION:-v0.6.0}" + # ── libkrunfw upstream (commit-pinned) ───────────────────────────────── # Repo: https://github.com/containers/libkrunfw # Pinned: 2026-03-27 (main branch HEAD at time of pinning) diff --git a/crates/openshell-driver-vm/scripts/openshell-vm-sandbox-init.sh b/crates/openshell-driver-vm/scripts/openshell-vm-sandbox-init.sh index b61fd4900..8725984f9 100644 --- a/crates/openshell-driver-vm/scripts/openshell-vm-sandbox-init.sh +++ 
b/crates/openshell-driver-vm/scripts/openshell-vm-sandbox-init.sh @@ -9,12 +9,6 @@ set -euo pipefail -# Source QEMU-injected environment variables if present. -if [ -f /srv/openshell-env.sh ]; then - # shellcheck source=/dev/null - source /srv/openshell-env.sh -fi - BOOT_START=$(date +%s%3N 2>/dev/null || date +%s) # gvisor-tap-vsock subnet layout: # 192.168.127.1 — gateway: gvproxy's DNS / DHCP / HTTP API. Does NOT @@ -23,14 +17,15 @@ BOOT_START=$(date +%s%3N 2>/dev/null || date +%s) # gvproxy's TCP/UDP/ICMP forwarder. Use this address # (or any of the host.* hostnames below) to reach a # service the host is listening on. -# The host.containers.internal / host.docker.internal DNS records served -# by gvproxy's embedded resolver point at 192.168.127.254. We mirror that -# in /etc/hosts so the supervisor can reach the gateway even when -# gvproxy's DNS is not in resolv.conf (e.g. DHCP failed and we fell -# back to 8.8.8.8). +# The host.openshell.internal / host.containers.internal / +# host.docker.internal DNS records served by gvproxy's embedded resolver +# point at 192.168.127.254. We mirror that in /etc/hosts so the supervisor +# can reach the gateway even when gvproxy's DNS is not in resolv.conf +# (e.g. DHCP failed and we fell back to 8.8.8.8). 
GVPROXY_GATEWAY_IP="192.168.127.1" GVPROXY_HOST_LOOPBACK_IP="192.168.127.254" GATEWAY_IP="$GVPROXY_GATEWAY_IP" +SANDBOX_OWNER_NORMALIZED_MARKER="/opt/openshell/.sandbox-owner-normalized" GPU_ENABLED="${GPU_ENABLED:-false}" VM_NET_IP="${VM_NET_IP:-}" @@ -44,6 +39,251 @@ ts() { printf "[%d.%03ds] %s\n" $((elapsed / 1000)) $((elapsed % 1000)) "$*" } +mount_initial_fs() { + mount -t proc proc /proc 2>/dev/null || true + mount -t sysfs sysfs /sys 2>/dev/null || true + mount -t tmpfs tmpfs /tmp 2>/dev/null || true + mount -t tmpfs tmpfs /run 2>/dev/null || true + mount -t devtmpfs devtmpfs /dev 2>/dev/null || true +} + +bind_mount_into_newroot() { + local source="$1" + local target="/newroot${source}" + + mkdir -p "$target" 2>/dev/null || true + mount --rbind "$source" "$target" 2>/dev/null \ + || mount --bind "$source" "$target" 2>/dev/null \ + || true +} + +root_path() { + local path="$1" + printf '%s%s\n' "${ROOT_PREFIX:-}" "$path" +} + +sandbox_owner() { + sandbox_owner_from_passwd "$(root_path /etc/passwd)" +} + +sandbox_owner_for_root() { + local root="$1" + sandbox_owner_from_passwd "$root/etc/passwd" +} + +sandbox_owner_from_passwd() { + local passwd_path name uid gid rest + passwd_path="$1" + if [ -f "$passwd_path" ]; then + while IFS=: read -r name _ uid gid rest; do + _="${rest:-}" + if [ "$name" = "sandbox" ] \ + && [[ "$uid" =~ ^[0-9]+$ ]] \ + && [[ "$gid" =~ ^[0-9]+$ ]]; then + printf '%s:%s\n' "$uid" "$gid" + return + fi + done < "$passwd_path" + fi + + printf '10001:10001\n' +} + +source_overlay_env_if_present() { + local env_file="/overlay/upper/srv/openshell-env.sh" + if [ -f "$env_file" ]; then + # shellcheck source=/dev/null + source "$env_file" + fi +} + +ensure_target_runtime() { + local image_root="$1" + + mkdir -p \ + "$image_root/srv" \ + "$image_root/opt/openshell/bin" \ + "$image_root/sandbox" \ + "$image_root/etc" + + cp /srv/openshell-vm-sandbox-init.sh "$image_root/srv/openshell-vm-sandbox-init.sh" + chmod 0755 
"$image_root/srv/openshell-vm-sandbox-init.sh" + + if [ -x /opt/openshell/bin/openshell-sandbox ]; then + cp /opt/openshell/bin/openshell-sandbox "$image_root/opt/openshell/bin/openshell-sandbox" + chmod 0755 "$image_root/opt/openshell/bin/openshell-sandbox" + fi + + touch "$image_root/etc/passwd" "$image_root/etc/group" "$image_root/etc/shadow" "$image_root/etc/gshadow" + if ! grep -q '^sandbox:' "$image_root/etc/group" 2>/dev/null; then + printf 'sandbox:x:10001:\n' >> "$image_root/etc/group" + fi + if ! grep -q '^sandbox:' "$image_root/etc/gshadow" 2>/dev/null; then + printf 'sandbox:!::\n' >> "$image_root/etc/gshadow" + fi + if ! grep -q '^sandbox:' "$image_root/etc/passwd" 2>/dev/null; then + printf 'sandbox:x:10001:10001:OpenShell Sandbox:/sandbox:/bin/sh\n' >> "$image_root/etc/passwd" + fi + if ! grep -q '^sandbox:' "$image_root/etc/shadow" 2>/dev/null; then + printf 'sandbox:!:20123:0:99999:7:::\n' >> "$image_root/etc/shadow" + fi + local owner + local owner_normalized=0 + owner="$(sandbox_owner_for_root "$image_root")" + if chown -R "$owner" "$image_root/sandbox" 2>/dev/null; then + owner_normalized=1 + elif chown -R 10001:10001 "$image_root/sandbox" 2>/dev/null; then + owner_normalized=1 + fi + chmod 0755 "$image_root/sandbox" + if [ "$owner_normalized" -eq 1 ]; then + mkdir -p "$image_root/opt/openshell" + printf '1\n' > "$image_root${SANDBOX_OWNER_NORMALIZED_MARKER}" + fi +} + +prepare_guest_image_rootfs() { + local payload_dir="/overlay/config/openshell-image" + local image_root="/overlay/image-rootfs" + local partial_root="/overlay/image-rootfs.partial" + local source + + [ -d "$payload_dir" ] || return 0 + + source="$(cat "$payload_dir/source" 2>/dev/null || true)" + ts "preparing sandbox image rootfs in guest (${source:-unknown})" + + rm -rf "$image_root" "$partial_root" + + case "$source" in + local-docker) + mkdir -p "$image_root" + tar -xpf "$payload_dir/source-rootfs.tar" -C "$image_root" + ;; + oci-layout) + if [ ! 
-x /opt/openshell/bin/umoci ]; then + ts "FATAL: umoci not found in VM bootstrap image" + exit 1 + fi + /opt/openshell/bin/umoci raw unpack \ + --image "$payload_dir/oci:openshell" \ + "$partial_root" + if [ ! -d "$partial_root/rootfs" ]; then + ts "FATAL: umoci unpack did not produce rootfs directory" + exit 1 + fi + mv "$partial_root/rootfs" "$image_root" + rm -rf "$partial_root" + ;; + *) + ts "FATAL: unknown guest image payload source: ${source:-missing}" + exit 1 + ;; + esac + + ensure_target_runtime "$image_root" + if [ -f "$payload_dir/identity" ]; then + cp "$payload_dir/identity" "$image_root/.openshell-rootfs-variant" + fi + rm -rf "$payload_dir" +} + +exec_supervisor_in_newroot() { + local chroot_bin + local bootstrap="/.openshell-bootstrap" + local supervisor="${bootstrap}/opt/openshell/bin/openshell-sandbox" + local loader + local lib_path + + for chroot_bin in /usr/sbin/chroot /usr/bin/chroot /sbin/chroot /bin/chroot; do + [ -x "$chroot_bin" ] || continue + + if [ -x "/newroot${supervisor}" ]; then + for loader in \ + "${bootstrap}/lib/x86_64-linux-gnu/ld-linux-x86-64.so.2" \ + "${bootstrap}/usr/lib/x86_64-linux-gnu/ld-linux-x86-64.so.2" \ + "${bootstrap}/lib64/ld-linux-x86-64.so.2" \ + "${bootstrap}/lib/ld-linux-x86-64.so.2" \ + "${bootstrap}/lib/aarch64-linux-gnu/ld-linux-aarch64.so.1" \ + "${bootstrap}/usr/lib/aarch64-linux-gnu/ld-linux-aarch64.so.1" \ + "${bootstrap}/lib/ld-linux-aarch64.so.1" \ + "${bootstrap}/lib64/ld-linux-aarch64.so.1"; do + if [ -x "/newroot${loader}" ]; then + lib_path="${bootstrap}/lib:${bootstrap}/lib64:${bootstrap}/usr/lib:${bootstrap}/usr/lib64:${bootstrap}/lib/aarch64-linux-gnu:${bootstrap}/lib/x86_64-linux-gnu:${bootstrap}/usr/lib/aarch64-linux-gnu:${bootstrap}/usr/lib/x86_64-linux-gnu" + exec "$chroot_bin" /newroot "$loader" --library-path "$lib_path" "$supervisor" --workdir /sandbox + fi + done + exec "$chroot_bin" /newroot "$supervisor" --workdir /sandbox + fi + + if [ -x /newroot/opt/openshell/bin/openshell-sandbox 
]; then + exec "$chroot_bin" /newroot /opt/openshell/bin/openshell-sandbox --workdir /sandbox + fi + done + + ts "FATAL: unable to exec openshell-sandbox in guest rootfs" + exit 1 +} + +setup_overlay_root() { + ts "setting up writable overlay root" + mount_initial_fs + + if [ ! -b /dev/vdb ]; then + ts "FATAL: writable overlay disk /dev/vdb not found" + exit 1 + fi + + mkdir -p /overlay /lower /newroot /image-cache + mount -o remount,ro / 2>/dev/null || true + mount -t ext4 -o rw /dev/vdb /overlay + mkdir -p /overlay/upper /overlay/work + source_overlay_env_if_present + + if [ "${OPENSHELL_VM_INIT_MODE:-sandbox}" = "image-prep" ]; then + prepare_guest_image_rootfs + sync + ts "image-prep complete" + exit 0 + fi + + mount --bind / /lower + mount -o remount,bind,ro /lower 2>/dev/null || true + + local lower_root="/lower" + if [ -b /dev/vdc ]; then + mount -t ext4 -o ro /dev/vdc /image-cache + if [ -d /image-cache/image-rootfs ]; then + lower_root="/image-cache/image-rootfs" + ts "using prepared image rootfs lowerdir" + else + ts "FATAL: prepared image disk missing /image-rootfs" + exit 1 + fi + fi + + mount -t overlay overlay \ + -o lowerdir="$lower_root",upperdir=/overlay/upper,workdir=/overlay/work \ + /newroot + mkdir -p /newroot/.openshell-bootstrap + mount --bind /lower /newroot/.openshell-bootstrap + + # GPU setup runs against the bootstrap runtime and its mounted /dev, /proc, + # and /run before those filesystems are mirrored into the target root. 
+ if [ "${GPU_ENABLED}" = "true" ]; then + setup_gpu || ts "WARNING: GPU init failed; continuing without GPU" + fi + + bind_mount_into_newroot /proc + bind_mount_into_newroot /sys + bind_mount_into_newroot /tmp + bind_mount_into_newroot /dev + bind_mount_into_newroot /run + + ROOT_PREFIX="/newroot" + run_post_overlay_setup +} + parse_endpoint() { local endpoint="$1" local scheme rest authority path host port @@ -110,13 +350,23 @@ ensure_host_gateway_aliases() { # gateway IP only listens on gvproxy's own service ports (DNS:53, DHCP, # HTTP API:80). Pinning host.containers.internal to the gateway IP # silently breaks guest→host port reachability for arbitrary ports. - local hosts_tmp="/tmp/openshell-hosts.$$" local host_aliases="host.openshell.internal host.containers.internal host.docker.internal" local gateway_aliases="gateway.containers.internal" local filter='(^|[[:space:]])(host\.openshell\.internal|host\.containers\.internal|host\.docker\.internal|gateway\.containers\.internal)([[:space:]]|$)' - if [ -f /etc/hosts ]; then - grep -vE "$filter" /etc/hosts > "$hosts_tmp" || true + write_host_gateway_aliases "$(root_path /etc/hosts)" "$(root_path "/tmp/openshell-hosts.$$.tmp")" || true + if [ -n "${ROOT_PREFIX:-}" ]; then + write_host_gateway_aliases "/etc/hosts" "/tmp/openshell-hosts.$$.tmp" || true + fi +} + +write_host_gateway_aliases() { + local hosts_path="$1" + local hosts_tmp="$2" + mkdir -p "$(dirname "$hosts_path")" 2>/dev/null || true + mkdir -p "$(dirname "$hosts_tmp")" 2>/dev/null || true + if [ -f "$hosts_path" ]; then + grep -vE "$filter" "$hosts_path" > "$hosts_tmp" || true else : > "$hosts_tmp" fi @@ -133,7 +383,11 @@ ensure_host_gateway_aliases() { # TAP networking: gateway and host are both reachable at GATEWAY_IP. printf '%s %s %s\n' "$GATEWAY_IP" "$host_aliases" "$gateway_aliases" >> "$hosts_tmp" fi - cat "$hosts_tmp" > /etc/hosts + if ! 
cat "$hosts_tmp" > "$hosts_path" 2>/dev/null; then + rm -f "$hosts_tmp" + ts "WARNING: could not update ${hosts_path}" + return 1 + fi rm -f "$hosts_tmp" } @@ -165,7 +419,12 @@ rewrite_openshell_endpoint_if_needed() { if [ "${GATEWAY_IP}" != "${GVPROXY_GATEWAY_IP}" ]; then fallback_ip="$GATEWAY_IP" fi - for candidate in host.openshell.internal host.containers.internal host.docker.internal "$fallback_ip"; do + local candidates="host.openshell.internal host.containers.internal host.docker.internal" + if [ "$scheme" != "https" ]; then + candidates="${candidates} ${fallback_ip}" + fi + + for candidate in $candidates; do if [ "$candidate" = "$host" ]; then continue fi @@ -181,6 +440,11 @@ rewrite_openshell_endpoint_if_needed() { fi done + if [ "$scheme" = "https" ]; then + ts "WARNING: could not preflight HTTPS OpenShell endpoint ${host}:${port}; preserving hostname for TLS verification" + return 0 + fi + ts "WARNING: could not reach OpenShell endpoint ${host}:${port}" } @@ -239,7 +503,8 @@ setup_gpu() { return 1 fi - # Stage GSP firmware from virtiofs to tmpfs to avoid slow FUSE reads + # Stage GSP firmware to tmpfs so module loading reads it from a stable + # early-boot path. 
if [ -d /lib/firmware/nvidia ]; then ts "staging GPU firmware to tmpfs" mkdir -p /run/firmware/nvidia @@ -273,26 +538,74 @@ setup_gpu() { fi } -mount -t proc proc /proc 2>/dev/null & -mount -t sysfs sysfs /sys 2>/dev/null & -mount -t tmpfs tmpfs /tmp 2>/dev/null & -mount -t tmpfs tmpfs /run 2>/dev/null & -mount -t devtmpfs devtmpfs /dev 2>/dev/null & -wait - -mkdir -p /dev/pts /dev/shm /sys/fs/cgroup -mount -t devpts devpts /dev/pts 2>/dev/null & -mount -t tmpfs tmpfs /dev/shm 2>/dev/null & -mount -t cgroup2 cgroup2 /sys/fs/cgroup 2>/dev/null & -wait - -hostname openshell-sandbox-vm 2>/dev/null || true -ip link set lo up 2>/dev/null || true - -# GPU initialization (before networking so nvidia-smi output is visible early) -if [ "${GPU_ENABLED}" = "true" ]; then - setup_gpu || ts "WARNING: GPU init failed; continuing without GPU" -fi +setup_sandbox_workdir() { + local sandbox_dir + local owner + local current_owner + sandbox_dir="$(root_path /sandbox)" + owner="$(sandbox_owner)" + mkdir -p "$sandbox_dir" + current_owner="$(stat -c '%u:%g' "$sandbox_dir" 2>/dev/null || true)" + if [ "$current_owner" != "$owner" ] \ + || [ ! -f "$(root_path "$SANDBOX_OWNER_NORMALIZED_MARKER")" ]; then + if ! 
chown -R "$owner" "$sandbox_dir" 2>/dev/null; then + chown -R 10001:10001 "$sandbox_dir" + fi + fi + chmod 0755 "$sandbox_dir" + ts "prepared /sandbox ownership (${owner})" +} + +configure_hostname() { + local sandbox_hostname="${OPENSHELL_SANDBOX:-openshell-sandbox-vm}" + sandbox_hostname="$(printf '%s' "$sandbox_hostname" | tr -c 'A-Za-z0-9.-' '-')" + sandbox_hostname="$(printf '%s' "$sandbox_hostname" | sed 's/^[.-][.-]*//; s/[.-][.-]*$//')" + sandbox_hostname="$(printf '%.63s' "$sandbox_hostname")" + if [ -z "$sandbox_hostname" ]; then + sandbox_hostname="openshell-sandbox-vm" + fi + + hostname "$sandbox_hostname" 2>/dev/null || true + printf '%s\n' "$sandbox_hostname" >"$(root_path /etc/hostname)" 2>/dev/null || true + ts "hostname=${sandbox_hostname}" +} + +run_post_overlay_setup() { + # Source QEMU-injected environment variables if present. The file lives in + # the overlay upperdir so the cached bootstrap rootfs remains immutable. + local env_file + env_file="$(root_path /srv/openshell-env.sh)" + if [ -f "$env_file" ]; then + # shellcheck source=/dev/null + source "$env_file" + fi + + if [ -z "${ROOT_PREFIX:-}" ]; then + mount -t proc proc /proc 2>/dev/null & + mount -t sysfs sysfs /sys 2>/dev/null & + mount -t tmpfs tmpfs /tmp 2>/dev/null & + mount -t tmpfs tmpfs /run 2>/dev/null & + mount -t devtmpfs devtmpfs /dev 2>/dev/null & + wait + fi + + mkdir -p "$(root_path /dev/pts)" "$(root_path /dev/shm)" "$(root_path /sys/fs/cgroup)" + mount -t devpts devpts "$(root_path /dev/pts)" 2>/dev/null & + mount -t tmpfs tmpfs "$(root_path /dev/shm)" 2>/dev/null & + mount -t cgroup2 cgroup2 "$(root_path /sys/fs/cgroup)" 2>/dev/null & + wait + + # Allow nftables LOG rules to work in non-init network namespaces. + # Without this, the kernel's nf_log_syslog silently suppresses output + # from the sandbox's network namespace. 
+ if [ -f /proc/sys/net/netfilter/nf_log_all_netns ]; then + echo 1 > /proc/sys/net/netfilter/nf_log_all_netns 2>/dev/null || true + fi + + setup_sandbox_workdir + + configure_hostname + ip link set lo up 2>/dev/null || true # Networking: use TAP static config if VM_NET_IP is set (QEMU path), # otherwise fall back to gvproxy DHCP on eth0 (libkrun path). @@ -335,10 +648,10 @@ if [ -n "${VM_NET_IP}" ] && [ -n "${VM_NET_GW}" ]; then fi if [ -n "${VM_NET_DNS}" ]; then - echo "nameserver ${VM_NET_DNS}" > /etc/resolv.conf - elif [ ! -s /etc/resolv.conf ]; then - echo "nameserver 8.8.8.8" > /etc/resolv.conf - echo "nameserver 8.8.4.4" >> /etc/resolv.conf + echo "nameserver ${VM_NET_DNS}" > "$(root_path /etc/resolv.conf)" + elif [ ! -s "$(root_path /etc/resolv.conf)" ]; then + echo "nameserver 8.8.8.8" > "$(root_path /etc/resolv.conf)" + echo "nameserver 8.8.4.4" >> "$(root_path /etc/resolv.conf)" fi ensure_host_gateway_aliases @@ -347,10 +660,9 @@ elif ip link show eth0 >/dev/null 2>&1; then ip link set eth0 up 2>/dev/null || true if command -v udhcpc >/dev/null 2>&1; then - UDHCPC_SCRIPT="/usr/share/udhcpc/default.script" - if [ ! -f "$UDHCPC_SCRIPT" ]; then - UDHCPC_SCRIPT="/run/openshell-udhcpc.script" - cat > "$UDHCPC_SCRIPT" <<'DHCP_SCRIPT' + UDHCPC_SCRIPT="$(root_path /run/openshell-udhcpc.script)" + mkdir -p "$(dirname "$UDHCPC_SCRIPT")" + cat > "$UDHCPC_SCRIPT" <<'DHCP_SCRIPT' #!/bin/sh case "$1" in bound|renew) @@ -360,18 +672,20 @@ case "$1" in ip route add default via "$router" dev "$interface" fi if [ -n "$dns" ]; then - : > /etc/resolv.conf + resolv_conf="${OPENSHELL_RESOLV_CONF:-/etc/resolv.conf}" + mkdir -p "$(dirname "$resolv_conf")" 2>/dev/null || true + : > "$resolv_conf" 2>/dev/null || true for d in $dns; do - echo "nameserver $d" >> /etc/resolv.conf + echo "nameserver $d" >> "$resolv_conf" 2>/dev/null || true done fi ;; esac DHCP_SCRIPT - chmod +x "$UDHCPC_SCRIPT" - fi + chmod +x "$UDHCPC_SCRIPT" - if ! 
udhcpc -i eth0 -f -q -n -T 1 -t 3 -A 1 -s "$UDHCPC_SCRIPT" 2>&1; then + if ! OPENSHELL_RESOLV_CONF="$(root_path /etc/resolv.conf)" \ + udhcpc -i eth0 -f -q -n -T 1 -t 3 -A 1 -s "$UDHCPC_SCRIPT" 2>&1; then ts "WARNING: DHCP failed, falling back to static config" ip addr add 192.168.127.2/24 dev eth0 2>/dev/null || true ip route add default via "$GVPROXY_GATEWAY_IP" 2>/dev/null || true @@ -382,9 +696,9 @@ DHCP_SCRIPT ip route add default via "$GVPROXY_GATEWAY_IP" 2>/dev/null || true fi - if [ ! -s /etc/resolv.conf ]; then - echo "nameserver 8.8.8.8" > /etc/resolv.conf - echo "nameserver 8.8.4.4" >> /etc/resolv.conf + if [ ! -s "$(root_path /etc/resolv.conf)" ]; then + echo "nameserver 8.8.8.8" > "$(root_path /etc/resolv.conf)" + echo "nameserver 8.8.4.4" >> "$(root_path /etc/resolv.conf)" fi ensure_host_gateway_aliases @@ -395,6 +709,26 @@ fi export HOME=/sandbox export USER=sandbox +# Fix /sandbox ownership. The host-side CLI extracts OCI layers as a non-root +# user (e.g. UID 501 on macOS), so /sandbox may be owned by the host UID. +# +# On macOS (Hypervisor.framework), guest root has real root privileges and +# chown succeeds. On Linux non-root hosts with virtiofs, guest root maps to +# the host user, so chown is denied — this is non-fatal because the +# supervisor's own filesystem preparation handles the paths that matter. 
+if [ -d /sandbox ]; then + _sb_uid=$(id -u sandbox 2>/dev/null || true) + _sb_gid=$(id -g sandbox 2>/dev/null || true) + if [ -n "$_sb_uid" ] && [ -n "$_sb_gid" ]; then + _cur_uid=$(stat -c '%u' /sandbox 2>/dev/null || true) + if [ -n "$_cur_uid" ] && [ "$_cur_uid" != "$_sb_uid" ]; then + ts "fixing /sandbox ownership (was uid=${_cur_uid}, setting to sandbox=${_sb_uid}:${_sb_gid})" + chown -R "${_sb_uid}:${_sb_gid}" /sandbox 2>/dev/null || \ + ts "chown /sandbox denied (virtiofs rootless host), continuing" + fi + fi +fi + rewrite_openshell_endpoint_if_needed # Log supervisor connectivity state for debugging stuck-in-Provisioning issues @@ -416,4 +750,15 @@ if [ -n "${OPENSHELL_SANDBOX_ID:-}" ]; then fi ts "starting openshell-sandbox supervisor" +if [ "${ROOT_PREFIX:-}" = "/newroot" ]; then + exec_supervisor_in_newroot +fi exec /opt/openshell/bin/openshell-sandbox --workdir /sandbox +} + +if [ "${1:-}" != "--post-overlay" ]; then + setup_overlay_root +fi + +shift || true +run_post_overlay_setup diff --git a/crates/openshell-driver-vm/src/driver.rs b/crates/openshell-driver-vm/src/driver.rs index 92cab23af..f09f1ebc3 100644 --- a/crates/openshell-driver-vm/src/driver.rs +++ b/crates/openshell-driver-vm/src/driver.rs @@ -5,22 +5,29 @@ use crate::gpu::{ GpuInventory, SubnetAllocator, allocate_vsock_cid, mac_from_sandbox_id, tap_device_name, }; use crate::rootfs::{ - create_rootfs_archive_from_dir, extract_rootfs_archive_to, - prepare_sandbox_rootfs_from_image_root, sandbox_guest_init_path, + clone_or_copy_sparse_file, create_ext4_image_from_dir_with_size, create_rootfs_image_from_dir, + extract_rootfs_archive_to, prepare_sandbox_rootfs_from_image_root, sandbox_guest_init_path, + set_rootfs_image_file_mode, write_rootfs_image_file, }; use bollard::Docker; use bollard::errors::Error as BollardError; use bollard::models::ContainerCreateBody; use bollard::query_parameters::{CreateContainerOptionsBuilder, RemoveContainerOptionsBuilder}; use flate2::read::GzDecoder; -use 
futures::{Stream, StreamExt}; +use futures::{Stream, StreamExt, TryStreamExt}; use nix::errno::Errno; use nix::sys::signal::{Signal, kill}; use nix::unistd::Pid; use oci_client::client::{Client as OciClient, ClientConfig}; -use oci_client::manifest::{ImageIndexEntry, OciDescriptor}; +use oci_client::manifest::{ + ImageIndexEntry, OCI_IMAGE_MEDIA_TYPE, OciDescriptor, OciImageManifest, +}; use oci_client::secrets::RegistryAuth; use oci_client::{Reference, RegistryOperation}; +use openshell_core::progress::{ + PROGRESS_STEP_PULLING_IMAGE, PROGRESS_STEP_REQUESTING_SANDBOX, PROGRESS_STEP_STARTING_SANDBOX, + mark_progress_active, mark_progress_complete, mark_progress_detail, +}; use openshell_core::proto::compute::v1::{ CreateSandboxRequest, CreateSandboxResponse, DeleteSandboxRequest, DeleteSandboxResponse, DriverCondition as SandboxCondition, DriverPlatformEvent as PlatformEvent, @@ -32,13 +39,15 @@ use openshell_core::proto::compute::v1::{ compute_driver_server::ComputeDriver, watch_sandboxes_event, }; use openshell_vfio::SysfsRoot; +use prost::Message; use sha2::{Digest, Sha256}; use std::collections::{HashMap, HashSet}; use std::fs; use std::io::Read; use std::net::Ipv4Addr; +#[cfg(unix)] use std::os::unix::fs::PermissionsExt; -use std::path::{Path, PathBuf}; +use std::path::{Component, Path, PathBuf}; use std::pin::Pin; use std::process::Stdio; use std::sync::Arc; @@ -47,6 +56,7 @@ use std::time::Duration; use tokio::io::AsyncWriteExt; use tokio::process::{Child, Command}; use tokio::sync::{Mutex, broadcast, mpsc}; +use tokio::task::JoinHandle; use tokio_stream::wrappers::ReceiverStream; use tonic::{Request, Response, Status}; use tracing::{info, warn}; @@ -56,6 +66,9 @@ const DRIVER_NAME: &str = "openshell-driver-vm"; const WATCH_BUFFER: usize = 256; const DEFAULT_VCPUS: u8 = 2; const DEFAULT_MEM_MIB: u32 = 2048; +const DEFAULT_OVERLAY_DISK_MIB: u64 = 4096; +const DEFAULT_REGISTRY_LAYER_DOWNLOAD_CONCURRENCY: usize = 4; +const 
MAX_REGISTRY_LAYER_DOWNLOAD_CONCURRENCY: usize = 16; /// gvproxy host-loopback IP — gvproxy's TCP/UDP/ICMP forwarder NAT-rewrites /// this destination to the host's `127.0.0.1` and dials out from the host /// process. This is the only address that transparently reaches host-bound @@ -78,21 +91,31 @@ const OPENSHELL_HOST_GATEWAY_ALIAS: &str = "host.openshell.internal"; /// resolves even when gvproxy's DNS is not in resolv.conf; /// * keeping a recognisable hostname makes log messages clearer than a bare /// 192.168.127.254 reference; -/// * `host.docker.internal` works the same way for Docker-flavoured tooling. +/// * package-managed gateway certificates include this SAN for guest mTLS. /// /// Both names ultimately route through the gvproxy NAT path on /// `GVPROXY_HOST_LOOPBACK_IP` — they do **not** go through the gateway IP. -const GVPROXY_HOST_LOOPBACK_ALIAS: &str = "host.containers.internal"; +const GVPROXY_HOST_LOOPBACK_ALIAS: &str = OPENSHELL_HOST_GATEWAY_ALIAS; const GUEST_SSH_SOCKET_PATH: &str = "/run/openshell/ssh.sock"; -const GUEST_TLS_DIR: &str = "/opt/openshell/tls"; const GUEST_TLS_CA_PATH: &str = "/opt/openshell/tls/ca.crt"; const GUEST_TLS_CERT_PATH: &str = "/opt/openshell/tls/tls.crt"; const GUEST_TLS_KEY_PATH: &str = "/opt/openshell/tls/tls.key"; +const GUEST_SANDBOX_TOKEN_PATH: &str = "/opt/openshell/auth/sandbox.jwt"; const IMAGE_CACHE_ROOT_DIR: &str = "images"; -const IMAGE_CACHE_ROOTFS_ARCHIVE: &str = "rootfs.tar"; +const IMAGE_CACHE_ROOTFS_IMAGE: &str = "rootfs.ext4"; +const OVERLAY_TEMPLATE_CACHE_DIR: &str = "overlay-templates"; +const OVERLAY_TEMPLATE_CACHE_LAYOUT_VERSION: &str = "sandbox-overlay-ext4-v1"; +const SANDBOX_OVERLAY_IMAGE: &str = "overlay.ext4"; +const SANDBOX_REQUEST_FILE: &str = "sandbox.pb"; +const GUEST_IMAGE_CONFIG_DIR: &str = "openshell-image"; +const GUEST_IMAGE_OCI_LAYOUT_DIR: &str = "oci"; +const GUEST_IMAGE_OCI_REF: &str = "openshell"; const IMAGE_EXPORT_ROOTFS_ARCHIVE: &str = "source-rootfs.tar"; +const 
BOOTSTRAP_IMAGE_CACHE_LAYOUT_VERSION: &str = "sandbox-bootstrap-rootfs-ext4-v2"; +const PREPARED_IMAGE_CACHE_LAYOUT_VERSION: &str = "sandbox-prepared-rootfs-ext4-umoci-v2"; const IMAGE_IDENTITY_FILE: &str = "image-identity"; const IMAGE_REFERENCE_FILE: &str = "image-reference"; +const IMAGE_PREP_INIT_MODE: &str = "image-prep"; static IMAGE_CACHE_BUILD_COUNTER: AtomicU64 = AtomicU64::new(0); #[derive(Debug, Clone)] @@ -102,18 +125,45 @@ struct VmDriverTlsPaths { key: PathBuf, } +#[derive(Debug, Clone)] +struct RuntimeImagePlan { + root_disk: PathBuf, + image_disk: Option, + image_identity: String, + bootstrap_image_identity: String, +} + +#[derive(Debug, Clone)] +struct PreparedImageDisk { + image_identity: String, + disk_path: PathBuf, +} + +#[derive(Debug, Clone)] +struct GuestImagePayload { + image_ref: String, + image_identity: String, + source: GuestImagePayloadSource, +} + +#[derive(Debug, Clone)] +enum GuestImagePayloadSource { + RegistryOciLayout { layout_dir: PathBuf }, + LocalDocker { rootfs_archive: PathBuf }, +} + #[derive(Debug, Clone)] pub struct VmDriverConfig { pub openshell_endpoint: String, pub state_dir: PathBuf, pub launcher_bin: Option, pub default_image: String, - pub ssh_handshake_secret: String, - pub ssh_handshake_skew_secs: u64, + pub bootstrap_image: String, pub log_level: String, pub krun_log_level: u32, pub vcpus: u8, pub mem_mib: u32, + pub overlay_disk_mib: u64, pub guest_tls_ca: Option, pub guest_tls_cert: Option, pub guest_tls_key: Option, @@ -129,12 +179,12 @@ impl Default for VmDriverConfig { state_dir: PathBuf::from("target/openshell-vm-driver"), launcher_bin: None, default_image: String::new(), - ssh_handshake_secret: String::new(), - ssh_handshake_skew_secs: 300, + bootstrap_image: String::new(), log_level: "info".to_string(), krun_log_level: 1, vcpus: DEFAULT_VCPUS, mem_mib: DEFAULT_MEM_MIB, + overlay_disk_mib: DEFAULT_OVERLAY_DISK_MIB, guest_tls_ca: None, guest_tls_cert: None, guest_tls_key: None, @@ -224,12 +274,19 @@ struct 
VmProcess { deleting: bool, } -#[derive(Debug)] struct SandboxRecord { snapshot: Sandbox, state_dir: PathBuf, - process: Arc>, + process: Option>>, + provisioning_task: Option>, gpu_bdf: Option, + deleting: bool, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum OverlayPreparation { + Fresh, + PreserveExisting, } #[derive(Clone)] @@ -260,14 +317,12 @@ impl VmDriver { } let state_root = sandboxes_root_dir(&config.state_dir); - tokio::fs::create_dir_all(&state_root) - .await - .map_err(|err| { - format!( - "failed to create state dir '{}': {err}", - state_root.display() - ) - })?; + create_private_dir_all(&state_root).await.map_err(|err| { + format!( + "failed to create state dir '{}': {err}", + state_root.display() + ) + })?; let image_cache_root = image_cache_root_dir(&config.state_dir); tokio::fs::create_dir_all(&image_cache_root) .await @@ -303,7 +358,7 @@ impl VmDriver { ))); let (events, _) = broadcast::channel(WATCH_BUFFER); - Ok(Self { + let driver = Self { config, launcher_bin, registry: Arc::new(Mutex::new(HashMap::new())), @@ -311,7 +366,9 @@ impl VmDriver { events, gpu_inventory, subnet_allocator, - }) + }; + driver.restore_persisted_sandboxes().await; + Ok(driver) } #[must_use] @@ -354,16 +411,7 @@ impl VmDriver { ); validate_vm_sandbox(sandbox, self.config.gpu_enabled)?; - if self.registry.lock().await.contains_key(&sandbox.id) { - return Err(Status::already_exists("sandbox already exists")); - } - - let spec = sandbox.spec.as_ref(); - let is_gpu = spec.is_some_and(|s| s.gpu); - let gpu_device = spec.map_or("", |s| s.gpu_device.as_str()); - - let state_dir = sandbox_state_dir(&self.config.state_dir, &sandbox.id); - let rootfs = state_dir.join("rootfs"); + let state_dir = sandbox_state_dir(&self.config.state_dir, &sandbox.id)?; let image_ref = self.resolved_sandbox_image(sandbox).ok_or_else(|| { Status::failed_precondition( "vm sandboxes require template.image or a configured default sandbox image", @@ -373,20 +421,52 @@ impl VmDriver { sandbox_id 
= %sandbox.id, image_ref = %image_ref, state_dir = %state_dir.display(), - "vm driver: resolved image ref, preparing rootfs" + "vm driver: resolved image ref, preparing disks" ); - tokio::fs::create_dir_all(&state_dir) - .await - .map_err(|err| Status::internal(format!("create state dir failed: {err}")))?; + let snapshot = sandbox_snapshot(sandbox, provisioning_condition(), false); + { + let mut registry = self.registry.lock().await; + if registry.contains_key(&sandbox.id) { + return Err(Status::already_exists("sandbox already exists")); + } + registry.insert( + sandbox.id.clone(), + SandboxRecord { + snapshot: snapshot.clone(), + state_dir: state_dir.clone(), + process: None, + provisioning_task: None, + gpu_bdf: None, + deleting: false, + }, + ); + } + + let tls_paths = match self.config.tls_paths() { + Ok(paths) => paths, + Err(err) => { + let mut registry = self.registry.lock().await; + registry.remove(&sandbox.id); + return Err(Status::failed_precondition(err)); + } + }; + + if let Err(err) = create_private_dir_all(&state_dir).await { + let mut registry = self.registry.lock().await; + registry.remove(&sandbox.id); + return Err(Status::internal(format!("create state dir failed: {err}"))); + } + + if let Err(err) = write_sandbox_request(&state_dir, sandbox).await { + let mut registry = self.registry.lock().await; + registry.remove(&sandbox.id); + let _ = tokio::fs::remove_dir_all(&state_dir).await; + return Err(Status::internal(format!( + "write sandbox resume metadata failed: {err}" + ))); + } - let tls_paths = self - .config - .tls_paths() - .map_err(Status::failed_precondition)?; - // Mirror the K8s `Scheduled` event so the CLI can complete the - // "Requesting sandbox" step and switch the spinner over to the - // image-pull phase before we block on the registry. 
self.publish_platform_event( sandbox.id.clone(), platform_event( @@ -396,85 +476,173 @@ impl VmDriver { format!("Sandbox accepted by vm driver to image \"{image_ref}\""), ), ); + self.publish_snapshot(snapshot); + + let driver = self.clone(); + let sandbox_for_task = sandbox.clone(); + let sandbox_id = sandbox.id.clone(); + let image_ref_for_task = image_ref.clone(); + let state_dir_for_task = state_dir.clone(); + let task = tokio::spawn(async move { + driver + .provision_sandbox( + sandbox_for_task, + image_ref_for_task, + state_dir_for_task, + tls_paths, + OverlayPreparation::Fresh, + ) + .await; + }); + + let mut registry = self.registry.lock().await; + if let Some(record) = registry.get_mut(&sandbox_id) { + if record.deleting { + task.abort(); + } else { + record.provisioning_task = Some(task); + } + } else { + task.abort(); + } + + Ok(CreateSandboxResponse {}) + } - let image_identity = match self - .prepare_runtime_rootfs(&sandbox.id, &image_ref, &rootfs) + async fn provision_sandbox( + &self, + sandbox: Sandbox, + image_ref: String, + state_dir: PathBuf, + tls_paths: Option, + overlay_preparation: OverlayPreparation, + ) { + let sandbox_id = sandbox.id.clone(); + if let Err(err) = self + .provision_sandbox_inner( + sandbox, + image_ref, + state_dir.clone(), + tls_paths, + overlay_preparation, + ) .await { - Ok(image_identity) => { - info!( - sandbox_id = %sandbox.id, - image_identity = %image_identity, - "vm driver: rootfs prepared" - ); - image_identity - } - Err(err) => { - warn!( - sandbox_id = %sandbox.id, - error = %err.message(), - "vm driver: rootfs preparation failed" - ); - let _ = tokio::fs::remove_dir_all(&state_dir).await; - return Err(err); + if err.code() == tonic::Code::Cancelled { + if overlay_preparation == OverlayPreparation::Fresh { + let _ = tokio::fs::remove_dir_all(&state_dir).await; + } + return; } - }; - if let Some(tls_paths) = tls_paths.as_ref() - && let Err(err) = prepare_guest_tls_materials(&rootfs, tls_paths).await + + warn!( + 
sandbox_id = %sandbox_id, + error = %err.message(), + "vm driver: sandbox provisioning failed" + ); + self.fail_provisioning( + &sandbox_id, + &state_dir, + "ProvisioningFailed", + err.message(), + overlay_preparation == OverlayPreparation::Fresh, + ) + .await; + } + } + + #[allow(clippy::result_large_err)] + async fn provision_sandbox_inner( + &self, + sandbox: Sandbox, + image_ref: String, + state_dir: PathBuf, + tls_paths: Option, + overlay_preparation: OverlayPreparation, + ) -> Result<(), Status> { + self.ensure_provisioning_active(&sandbox.id).await?; + self.publish_platform_event( + sandbox.id.clone(), + platform_event( + "vm", + "Normal", + "ResolvingImage", + format!("Resolving VM sandbox image \"{image_ref}\""), + ), + ); + + let image_plan = self.prepare_runtime_images(&sandbox.id, &image_ref).await?; + let image_identity = image_plan.image_identity.clone(); + self.ensure_provisioning_active(&sandbox.id).await?; + info!( + sandbox_id = %sandbox.id, + image_identity = %image_identity, + bootstrap_image_identity = %image_plan.bootstrap_image_identity, + image_disk = image_plan.image_disk.as_ref().map(|path| path.display().to_string()).unwrap_or_default(), + "vm driver: sandbox root disk plan resolved" + ); + let disk_paths = sandbox_runtime_disk_paths(&state_dir); + let root_disk = image_plan.root_disk; + let image_disk = image_plan.image_disk; + let overlay_disk = disk_paths.overlay_disk; + + self.publish_platform_event( + sandbox.id.clone(), + platform_event( + "vm", + "Normal", + "PreparingOverlay", + "Preparing writable VM overlay disk".to_string(), + ), + ); + if let Err(err) = self + .prepare_runtime_overlay( + &overlay_disk, + tls_paths.as_ref(), + sandbox + .spec + .as_ref() + .map(|spec| spec.sandbox_token.as_str()) + .filter(|token| !token.is_empty()), + overlay_preparation, + ) + .await { - let _ = tokio::fs::remove_dir_all(&state_dir).await; return Err(Status::internal(format!( - "prepare guest TLS materials failed: {err}" + "prepare guest 
overlay disk failed: {err}" ))); } + self.ensure_provisioning_active(&sandbox.id).await?; if let Err(err) = write_sandbox_image_metadata(&state_dir, &image_ref, &image_identity).await { - let _ = tokio::fs::remove_dir_all(&state_dir).await; return Err(Status::internal(format!( "write sandbox image metadata failed: {err}" ))); } + let spec = sandbox.spec.as_ref(); + let is_gpu = spec.is_some_and(|s| s.gpu); + let gpu_device = spec.map_or("", |s| s.gpu_device.as_str()); let gpu_bdf = if is_gpu { - let inventory = self - .gpu_inventory - .as_ref() - .ok_or_else(|| Status::internal("GPU inventory not initialized"))?; - match inventory - .lock() - .map_err(|e| Status::internal(format!("GPU inventory lock poisoned: {e}"))) - .and_then(|mut inv| { - inv.assign(&sandbox.id, gpu_device) - .map_err(Status::failed_precondition) - }) { - Ok(assignment) => { - tracing::info!( - sandbox_id = %sandbox.id, - bdf = %assignment.bdf, - gpu_name = %assignment.name, - iommu_group = assignment.iommu_group, - "assigned GPU to sandbox" - ); - Some(assignment.bdf) - } - Err(err) => { - let _ = tokio::fs::remove_dir_all(&state_dir).await; - return Err(err); - } - } + Some(self.assign_gpu_to_record(&sandbox.id, gpu_device).await?) 
} else { None }; let console_output = state_dir.join("rootfs-console.log"); let mut command = Command::new(&self.launcher_bin); + command.kill_on_drop(true); command.stdin(Stdio::null()); command.stdout(Stdio::inherit()); command.stderr(Stdio::inherit()); command.arg("--internal-run-vm"); - command.arg("--vm-rootfs").arg(&rootfs); + command.arg("--vm-root-disk").arg(&root_disk); + command.arg("--vm-overlay-disk").arg(&overlay_disk); + if let Some(image_disk) = &image_disk { + command.arg("--vm-image-disk").arg(image_disk); + } command.arg("--vm-exec").arg(sandbox_guest_init_path()); command.arg("--vm-workdir").arg("/"); command.arg("--vm-console-output").arg(&console_output); @@ -494,7 +662,6 @@ impl VmDriver { Ok(s) => s, Err(err) => { self.release_gpu_and_subnet(&sandbox.id); - let _ = tokio::fs::remove_dir_all(&state_dir).await; return Err(err); } }; @@ -539,12 +706,13 @@ impl VmDriver { .arg(self.config.mem_mib.to_string()); None }; + self.ensure_provisioning_active(&sandbox.id).await?; command .arg("--vm-krun-log-level") .arg(self.config.krun_log_level.to_string()); - for env in build_guest_environment(sandbox, &self.config, endpoint_override.as_deref()) { + for env in build_guest_environment(&sandbox, &self.config, endpoint_override.as_deref()) { command.arg("--vm-env").arg(env); } @@ -565,7 +733,6 @@ impl VmDriver { if gpu_bdf.is_some() { self.release_gpu_and_subnet(&sandbox.id); } - let _ = tokio::fs::remove_dir_all(&state_dir).await; return Err(Status::internal(format!( "failed to launch vm helper '{}': {err}", self.launcher_bin.display() @@ -575,35 +742,49 @@ impl VmDriver { info!( sandbox_id = %sandbox.id, launcher_pid = child.id().unwrap_or(0), - "vm driver: launcher spawned" - ); - // Mirror the K8s `Started` event so the CLI can complete the - // "Starting sandbox" step. The supervisor-ready transition still - // promotes the sandbox to `Ready` separately. 
- self.publish_platform_event( - sandbox.id.clone(), - platform_event("vm", "Normal", "Started", "Started VM launcher".to_string()), + "vm driver: launcher spawned" ); - let snapshot = sandbox_snapshot(sandbox, provisioning_condition(), false); let process = Arc::new(Mutex::new(VmProcess { child, deleting: false, })); + let mut process_to_stop = None; + let mut snapshot_to_publish = None; { let mut registry = self.registry.lock().await; - registry.insert( - sandbox.id.clone(), - SandboxRecord { - snapshot: snapshot.clone(), - state_dir: state_dir.clone(), - process: process.clone(), - gpu_bdf: gpu_bdf.clone(), - }, - ); + match registry.get_mut(&sandbox.id) { + Some(record) if !record.deleting => { + record.process = Some(process.clone()); + record.gpu_bdf.clone_from(&gpu_bdf); + record.provisioning_task = None; + snapshot_to_publish = Some(record.snapshot.clone()); + } + _ => { + process_to_stop = Some(process.clone()); + } + } + } + + if let Some(process) = process_to_stop { + { + let mut process = process.lock().await; + process.deleting = true; + terminate_vm_process(&mut process.child) + .await + .map_err(|err| Status::internal(format!("failed to stop vm: {err}")))?; + } + self.release_gpu_and_subnet(&sandbox.id); + return Err(Status::cancelled("sandbox provisioning cancelled")); } - self.publish_snapshot(snapshot.clone()); + self.publish_platform_event( + sandbox.id.clone(), + platform_event("vm", "Normal", "Started", "Started VM launcher".to_string()), + ); + if let Some(snapshot) = snapshot_to_publish { + self.publish_snapshot(snapshot); + } tokio::spawn({ let driver = self.clone(); let sandbox_id = sandbox.id.clone(); @@ -612,7 +793,7 @@ impl VmDriver { } }); - Ok(CreateSandboxResponse {}) + Ok(()) } pub async fn delete_sandbox( @@ -620,37 +801,40 @@ impl VmDriver { sandbox_id: &str, sandbox_name: &str, ) -> Result { - let record = { + if !sandbox_id.is_empty() { + validate_sandbox_id(sandbox_id)?; + } + + let record_id = { let registry = 
self.registry.lock().await; - if let Some((id, record)) = registry.get_key_value(sandbox_id) { - Some(( - id.clone(), - record.state_dir.clone(), - record.process.clone(), - record.gpu_bdf.clone(), - )) + if let Some((id, _record)) = registry.get_key_value(sandbox_id) { + Some(id.clone()) } else { - let matched_id = registry + registry .iter() .find(|(_, record)| record.snapshot.name == sandbox_name) - .map(|(id, _)| id.clone()); - matched_id.and_then(|id| { - registry.get(&id).map(|record| { - ( - id, - record.state_dir.clone(), - record.process.clone(), - record.gpu_bdf.clone(), - ) - }) - }) + .map(|(id, _)| id.clone()) } }; - let Some((record_id, state_dir, process, gpu_bdf)) = record else { + let Some(record_id) = record_id else { return Ok(DeleteSandboxResponse { deleted: false }); }; + let (state_dir, process, gpu_bdf, provisioning_task) = { + let mut registry = self.registry.lock().await; + let Some(record) = registry.get_mut(&record_id) else { + return Ok(DeleteSandboxResponse { deleted: false }); + }; + record.deleting = true; + ( + record.state_dir.clone(), + record.process.clone(), + record.gpu_bdf.clone(), + record.provisioning_task.take(), + ) + }; + if let Some(snapshot) = self .set_snapshot_condition(&record_id, deleting_condition(), true) .await @@ -658,7 +842,11 @@ impl VmDriver { self.publish_snapshot(snapshot); } - { + if let Some(task) = provisioning_task { + task.abort(); + } + + if let Some(process) = process { let mut process = process.lock().await; process.deleting = true; terminate_vm_process(&mut process.child) @@ -670,13 +858,7 @@ impl VmDriver { self.release_gpu_and_subnet(&record_id); } - if let Err(err) = tokio::fs::remove_dir_all(&state_dir).await - && err.kind() != std::io::ErrorKind::NotFound - { - return Err(Status::internal(format!( - "failed to remove state dir: {err}" - ))); - } + remove_sandbox_state_dir(&self.config.state_dir, &state_dir).await?; { let mut registry = self.registry.lock().await; @@ -692,6 +874,10 @@ impl 
VmDriver { sandbox_id: &str, sandbox_name: &str, ) -> Result, Status> { + if !sandbox_id.is_empty() { + validate_sandbox_id(sandbox_id)?; + } + let registry = self.registry.lock().await; let sandbox = if sandbox_id.is_empty() { registry @@ -716,66 +902,390 @@ impl VmDriver { snapshots } - fn release_gpu_and_subnet(&self, sandbox_id: &str) { - if let Some(inventory) = self.gpu_inventory.as_ref() - && let Ok(mut inv) = inventory.lock() - { - inv.release(sandbox_id); - } - if let Ok(mut alloc) = self.subnet_allocator.lock() { - alloc.release(sandbox_id); - } - } + async fn restore_persisted_sandboxes(&self) { + let state_root = sandboxes_root_dir(&self.config.state_dir); + let mut entries = match tokio::fs::read_dir(&state_root).await { + Ok(entries) => entries, + Err(err) if err.kind() == std::io::ErrorKind::NotFound => return, + Err(err) => { + warn!( + state_root = %state_root.display(), + error = %err, + "vm driver: failed to scan persisted sandboxes" + ); + return; + } + }; - async fn prepare_runtime_rootfs( - &self, - sandbox_id: &str, - image_ref: &str, - rootfs: &Path, - ) -> Result { - let image_identity = self - .ensure_cached_image_rootfs_archive(sandbox_id, image_ref) - .await?; - let archive_path = image_cache_rootfs_archive(&self.config.state_dir, &image_identity); - let rootfs_dest = rootfs.to_path_buf(); - tokio::task::spawn_blocking(move || extract_rootfs_archive_to(&archive_path, &rootfs_dest)) - .await - .map_err(|err| Status::internal(format!("sandbox rootfs extraction panicked: {err}")))? 
- .map_err(|err| Status::internal(format!("extract sandbox rootfs failed: {err}")))?; + loop { + let entry = match entries.next_entry().await { + Ok(Some(entry)) => entry, + Ok(None) => break, + Err(err) => { + warn!( + state_root = %state_root.display(), + error = %err, + "vm driver: failed to continue scanning persisted sandboxes" + ); + break; + } + }; + let state_dir = entry.path(); + let is_dir = match entry.file_type().await { + Ok(file_type) => file_type.is_dir(), + Err(err) => { + warn!( + state_dir = %state_dir.display(), + error = %err, + "vm driver: failed to inspect persisted sandbox state dir" + ); + continue; + } + }; + if !is_dir { + continue; + } - Ok(image_identity) - } + let request_path = state_dir.join(SANDBOX_REQUEST_FILE); + let sandbox = match read_sandbox_request(&request_path).await { + Ok(sandbox) => sandbox, + Err(err) if err.kind() == std::io::ErrorKind::NotFound => continue, + Err(err) => { + warn!( + state_dir = %state_dir.display(), + error = %err, + "vm driver: failed to read persisted sandbox request" + ); + continue; + } + }; - fn resolved_sandbox_image(&self, sandbox: &Sandbox) -> Option { - requested_sandbox_image(sandbox) - .map(ToOwned::to_owned) - .or_else(|| { - let image = self.config.default_image.trim(); - (!image.is_empty()).then(|| image.to_string()) - }) + if let Err(status) = + validate_restored_sandbox_state(&self.config.state_dir, &state_dir, &sandbox) + { + warn!( + sandbox_id = %sandbox.id, + state_dir = %state_dir.display(), + error = %status.message(), + "vm driver: ignoring invalid persisted sandbox state" + ); + continue; + } + + self.restore_persisted_sandbox(sandbox, state_dir).await; + } } - async fn ensure_cached_image_rootfs_archive( - &self, - sandbox_id: &str, - image_ref: &str, - ) -> Result { - if let Some((docker, image_identity)) = self.resolve_local_docker_image(image_ref).await? 
{ + async fn restore_persisted_sandbox(&self, sandbox: Sandbox, state_dir: PathBuf) { + let Some(image_ref) = self.resolved_sandbox_image(&sandbox) else { + warn!( + sandbox_id = %sandbox.id, + sandbox_name = %sandbox.name, + "vm driver: cannot restore persisted sandbox without image" + ); + return; + }; + let tls_paths = match self.config.tls_paths() { + Ok(paths) => paths, + Err(err) => { + warn!( + sandbox_id = %sandbox.id, + sandbox_name = %sandbox.name, + error = %err, + "vm driver: cannot restore persisted sandbox TLS configuration" + ); + return; + } + }; + + let snapshot = sandbox_snapshot(&sandbox, provisioning_condition(), false); + { + let mut registry = self.registry.lock().await; + if registry.contains_key(&sandbox.id) { + return; + } + registry.insert( + sandbox.id.clone(), + SandboxRecord { + snapshot: snapshot.clone(), + state_dir: state_dir.clone(), + process: None, + provisioning_task: None, + gpu_bdf: None, + deleting: false, + }, + ); + } + + self.publish_platform_event( + sandbox.id.clone(), + platform_event( + "vm", + "Normal", + "Restoring", + "Restoring persisted VM sandbox after driver restart".to_string(), + ), + ); + self.publish_snapshot(snapshot); + + let driver = self.clone(); + let sandbox_id = sandbox.id.clone(); + let task = tokio::spawn(async move { + driver + .provision_sandbox( + sandbox, + image_ref, + state_dir, + tls_paths, + OverlayPreparation::PreserveExisting, + ) + .await; + }); + + let mut registry = self.registry.lock().await; + if let Some(record) = registry.get_mut(&sandbox_id) { + if record.deleting { + task.abort(); + } else { + record.provisioning_task = Some(task); + } + } else { + task.abort(); + } + } + + fn release_gpu_and_subnet(&self, sandbox_id: &str) { + if let Some(inventory) = self.gpu_inventory.as_ref() + && let Ok(mut inv) = inventory.lock() + { + inv.release(sandbox_id); + } + if let Ok(mut alloc) = self.subnet_allocator.lock() { + alloc.release(sandbox_id); + } + } + + async fn 
ensure_provisioning_active(&self, sandbox_id: &str) -> Result<(), Status> { + let registry = self.registry.lock().await; + match registry.get(sandbox_id) { + Some(record) if !record.deleting => Ok(()), + _ => Err(Status::cancelled("sandbox provisioning cancelled")), + } + } + + async fn assign_gpu_to_record( + &self, + sandbox_id: &str, + gpu_device: &str, + ) -> Result { + let mut registry = self.registry.lock().await; + match registry.get_mut(sandbox_id) { + Some(record) if !record.deleting => {} + _ => return Err(Status::cancelled("sandbox provisioning cancelled")), + } + + let inventory = self + .gpu_inventory + .as_ref() + .ok_or_else(|| Status::internal("GPU inventory not initialized"))?; + let assignment = inventory + .lock() + .map_err(|e| Status::internal(format!("GPU inventory lock poisoned: {e}")))? + .assign(sandbox_id, gpu_device) + .map_err(Status::failed_precondition)?; + + let record = registry + .get_mut(sandbox_id) + .expect("sandbox record exists while registry lock is held"); + record.gpu_bdf = Some(assignment.bdf.clone()); + tracing::info!( + sandbox_id = %sandbox_id, + bdf = %assignment.bdf, + gpu_name = %assignment.name, + iommu_group = assignment.iommu_group, + "assigned GPU to sandbox" + ); + Ok(assignment.bdf) + } + + async fn fail_provisioning( + &self, + sandbox_id: &str, + state_dir: &Path, + reason: &str, + message: &str, + remove_state: bool, + ) { + self.release_gpu_and_subnet(sandbox_id); + let snapshot = { + let mut registry = self.registry.lock().await; + let Some(record) = registry.get_mut(sandbox_id) else { + return; + }; + if record.deleting { + return; + } + record.process = None; + record.provisioning_task = None; + record.gpu_bdf = None; + record.snapshot.status = Some(status_with_condition( + &record.snapshot, + error_condition(reason, message), + false, + )); + Some(record.snapshot.clone()) + }; + + if remove_state { + let _ = tokio::fs::remove_dir_all(state_dir).await; + } + self.publish_platform_event( + 
sandbox_id.to_string(), + platform_event( + "vm", + "Warning", + reason, + format!("VM provisioning failed: {message}"), + ), + ); + if let Some(snapshot) = snapshot { + self.publish_snapshot(snapshot); + } + } + + async fn prepare_runtime_images( + &self, + sandbox_id: &str, + image_ref: &str, + ) -> Result { + let bootstrap_image_ref = self.bootstrap_image_ref(image_ref); + let bootstrap_image_identity = self + .ensure_cached_bootstrap_rootfs_image(sandbox_id, &bootstrap_image_ref) + .await?; + let root_disk = image_cache_rootfs_image(&self.config.state_dir, &bootstrap_image_identity); + + if image_ref.trim() == bootstrap_image_ref.trim() { + return Ok(RuntimeImagePlan { + root_disk, + image_disk: None, + image_identity: bootstrap_image_identity.clone(), + bootstrap_image_identity, + }); + } + + let prepared = self + .ensure_prepared_image_disk(sandbox_id, image_ref, &root_disk) + .await?; + Ok(RuntimeImagePlan { + root_disk, + image_disk: Some(prepared.disk_path), + image_identity: prepared.image_identity, + bootstrap_image_identity, + }) + } + + fn bootstrap_image_ref(&self, sandbox_image_ref: &str) -> String { + let configured = self.config.bootstrap_image.trim(); + if !configured.is_empty() { + return configured.to_string(); + } + let default = self.config.default_image.trim(); + if !default.is_empty() { + return default.to_string(); + } + sandbox_image_ref.to_string() + } + + async fn prepare_runtime_overlay( + &self, + overlay_disk: &Path, + tls_paths: Option<&VmDriverTlsPaths>, + sandbox_token: Option<&str>, + preparation: OverlayPreparation, + ) -> Result<(), String> { + let tls_materials = match tls_paths { + Some(paths) => Some(read_guest_tls_materials(paths).await?), + None => None, + }; + let sandbox_token = sandbox_token.map(str::to_string); + let overlay_disk = overlay_disk.to_path_buf(); + let overlay_size_bytes = self + .config + .overlay_disk_mib + .checked_mul(1024 * 1024) + .ok_or_else(|| { + format!( + "overlay disk size {} MiB is too large", 
+ self.config.overlay_disk_mib + ) + })?; + + let template_path = overlay_template_image(&self.config.state_dir, overlay_size_bytes); + if !overlay_template_image_ready(&template_path, overlay_size_bytes).await? { + let _cache_guard = self.image_cache_lock.lock().await; + let template_path = template_path.clone(); + tokio::task::spawn_blocking(move || { + ensure_sandbox_overlay_template_image(&template_path, overlay_size_bytes) + }) + .await + .map_err(|err| format!("overlay template preparation panicked: {err}"))??; + } + + tokio::task::spawn_blocking(move || { + prepare_sandbox_overlay_image( + &template_path, + &overlay_disk, + tls_materials.as_ref(), + sandbox_token.as_deref(), + preparation, + overlay_size_bytes, + ) + }) + .await + .map_err(|err| format!("overlay image preparation panicked: {err}"))? + } + + fn resolved_sandbox_image(&self, sandbox: &Sandbox) -> Option { + requested_sandbox_image(sandbox) + .map(ToOwned::to_owned) + .or_else(|| { + let image = self.config.default_image.trim(); + (!image.is_empty()).then(|| image.to_string()) + }) + } + + async fn ensure_cached_bootstrap_rootfs_image( + &self, + sandbox_id: &str, + image_ref: &str, + ) -> Result { + if let Some((engine, image_identity)) = + self.resolve_local_container_image(image_ref).await? 
+ { return self - .ensure_cached_local_image_rootfs_archive( + .ensure_cached_local_image_rootfs_image( sandbox_id, image_ref, - &docker, + &engine, &image_identity, ) .await; } - info!(image_ref = %image_ref, "vm driver: ensuring cached image rootfs archive (registry)"); + info!(image_ref = %image_ref, "vm driver: ensuring cached root disk image (registry)"); let reference = parse_registry_reference(image_ref)?; let client = registry_client(); let auth = registry_auth(image_ref)?; info!(image_ref = %image_ref, "vm driver: authenticating with registry"); + self.publish_vm_progress( + sandbox_id, + "AuthenticatingRegistry", + format!("Authenticating registry access for image \"{image_ref}\""), + HashMap::from([ + ("image_ref".to_string(), image_ref.to_string()), + ("image_source".to_string(), "registry".to_string()), + ]), + ); client .auth(&reference, &auth, RegistryOperation::Pull) .await @@ -785,7 +1295,16 @@ impl VmDriver { )) })?; info!(image_ref = %image_ref, "vm driver: fetching manifest digest"); - let image_identity = client + self.publish_vm_progress( + sandbox_id, + "FetchingManifest", + format!("Fetching manifest for image \"{image_ref}\""), + HashMap::from([ + ("image_ref".to_string(), image_ref.to_string()), + ("image_source".to_string(), "registry".to_string()), + ]), + ); + let source_image_identity = client .fetch_manifest_digest(&reference, &auth) .await .map_err(|err| { @@ -795,15 +1314,14 @@ impl VmDriver { })?; info!( image_ref = %image_ref, - image_identity = %image_identity, + image_identity = %source_image_identity, "vm driver: manifest digest resolved" ); - let archive_path = image_cache_rootfs_archive(&self.config.state_dir, &image_identity); + let image_identity = bootstrap_image_cache_identity(&source_image_identity); + let image_path = image_cache_rootfs_image(&self.config.state_dir, &image_identity); - // Mirror the K8s `Pulling` event so the CLI flips to the - // image-pull spinner with the image name as detail. 
We emit it - // for cache hits too and immediately follow with `Pulled` so the - // spinner step still advances cleanly. + // Emit a driver progress hint for cache hits too and immediately + // follow with `Pulled` so the image step still advances cleanly. self.publish_platform_event( sandbox_id.to_string(), platform_event( @@ -814,37 +1332,79 @@ impl VmDriver { ), ); - if tokio::fs::metadata(&archive_path).await.is_ok() { + if tokio::fs::metadata(&image_path).await.is_ok() { info!( image_identity = %image_identity, - archive_path = %archive_path.display(), - "vm driver: image rootfs archive cache hit (no build needed)" + image_path = %image_path.display(), + "vm driver: root disk image cache hit (no build needed)" + ); + self.publish_vm_progress( + sandbox_id, + "CacheHit", + format!("Using cached VM root disk for image \"{image_ref}\""), + HashMap::from([ + ("image_ref".to_string(), image_ref.to_string()), + ("image_source".to_string(), "registry".to_string()), + ("cache_hit".to_string(), "true".to_string()), + ("image_identity".to_string(), image_identity.clone()), + ]), ); - self.publish_pulled_event(sandbox_id, image_ref, &archive_path) + self.publish_pulled_event(sandbox_id, image_ref, &image_path) .await; return Ok(image_identity); } info!( image_identity = %image_identity, - "vm driver: image rootfs archive cache miss, acquiring build lock" + "vm driver: root disk image cache miss, acquiring build lock" + ); + self.publish_vm_progress( + sandbox_id, + "CacheMiss", + format!("Preparing VM root disk cache for image \"{image_ref}\""), + HashMap::from([ + ("image_ref".to_string(), image_ref.to_string()), + ("image_source".to_string(), "registry".to_string()), + ("cache_hit".to_string(), "false".to_string()), + ("image_identity".to_string(), image_identity.clone()), + ]), + ); + self.publish_vm_progress( + sandbox_id, + "WaitingForImageCacheLock", + "Waiting for VM image cache build lock".to_string(), + HashMap::from([ + ("image_ref".to_string(), 
image_ref.to_string()), + ("image_identity".to_string(), image_identity.clone()), + ]), ); let _cache_guard = self.image_cache_lock.lock().await; info!( image_identity = %image_identity, "vm driver: build lock acquired" ); - if tokio::fs::metadata(&archive_path).await.is_ok() { + if tokio::fs::metadata(&image_path).await.is_ok() { info!( image_identity = %image_identity, - "vm driver: image rootfs archive cache hit after lock (built by another task)" + "vm driver: root disk image cache hit after lock (built by another task)" + ); + self.publish_vm_progress( + sandbox_id, + "CacheHit", + format!("Using cached VM root disk for image \"{image_ref}\""), + HashMap::from([ + ("image_ref".to_string(), image_ref.to_string()), + ("image_source".to_string(), "registry".to_string()), + ("cache_hit".to_string(), "true".to_string()), + ("image_identity".to_string(), image_identity.clone()), + ]), ); - self.publish_pulled_event(sandbox_id, image_ref, &archive_path) + self.publish_pulled_event(sandbox_id, image_ref, &image_path) .await; return Ok(image_identity); } - self.build_cached_registry_image_rootfs_archive( + self.build_cached_registry_image_rootfs_image( sandbox_id, &client, &reference, @@ -853,36 +1413,35 @@ impl VmDriver { &image_identity, ) .await?; - self.publish_pulled_event(sandbox_id, image_ref, &archive_path) + self.publish_pulled_event(sandbox_id, image_ref, &image_path) .await; Ok(image_identity) } - async fn resolve_local_docker_image( + async fn resolve_local_container_image( &self, image_ref: &str, ) -> Result, Status> { let required_local_image = is_openshell_local_build_image_ref(image_ref); - let docker = match Docker::connect_with_local_defaults() { - Ok(docker) => docker, - Err(err) if required_local_image => { + let engine = match connect_local_container_engine().await { + Some(engine) => engine, + None if required_local_image => { return Err(Status::failed_precondition(format!( - "failed to connect to local Docker daemon for locally built sandbox 
image '{image_ref}': {err}" + "no container engine (Docker/Podman) available for locally built sandbox image '{image_ref}'" ))); } - Err(err) => { + None => { warn!( image_ref = %image_ref, - error = %err, - "vm driver: local Docker daemon unavailable, falling back to registry" + "vm driver: no local container engine available, falling back to registry" ); return Ok(None); } }; - match docker.inspect_image(image_ref).await { + match engine.inspect_image(image_ref).await { Ok(inspect) => { - if let Some(message) = local_docker_image_platform_mismatch( + if let Some(message) = local_image_platform_mismatch( image_ref, inspect.os.as_deref(), inspect.architecture.as_deref(), @@ -893,98 +1452,570 @@ impl VmDriver { warn!( image_ref = %image_ref, %message, - "vm driver: local Docker image platform mismatch, falling back to registry" + "vm driver: local container image platform mismatch, falling back to registry" ); return Ok(None); } - let image_identity = - inspect - .id - .filter(|id| !id.trim().is_empty()) - .ok_or_else(|| { - Status::failed_precondition(format!( - "local Docker image '{image_ref}' inspect response has no image ID" - )) - })?; - info!( - image_ref = %image_ref, - image_identity = %image_identity, - "vm driver: resolved image from local Docker daemon" - ); - Ok(Some((docker, image_identity))) - } - Err(err) if is_docker_not_found_error(&err) && required_local_image => { - Err(Status::failed_precondition(format!( - "locally built sandbox image '{image_ref}' is not present in the local Docker daemon" - ))) - } - Err(err) if is_docker_not_found_error(&err) => Ok(None), - Err(err) if required_local_image => Err(Status::failed_precondition(format!( - "failed to inspect locally built sandbox image '{image_ref}': {err}" - ))), - Err(err) => { - warn!( - image_ref = %image_ref, - error = %err, - "vm driver: local Docker image inspection failed, falling back to registry" - ); - Ok(None) - } + let image_identity = inspect.id.filter(|id| 
!id.trim().is_empty()).ok_or_else( + || { + Status::failed_precondition(format!( + "local container image '{image_ref}' inspect response has no image ID" + )) + }, + )?; + info!( + image_ref = %image_ref, + image_identity = %image_identity, + "vm driver: resolved image from local container engine" + ); + Ok(Some((engine, image_identity))) + } + Err(err) if is_docker_not_found_error(&err) && required_local_image => { + Err(Status::failed_precondition(format!( + "locally built sandbox image '{image_ref}' is not present in the local container engine" + ))) + } + Err(err) if is_docker_not_found_error(&err) => Ok(None), + Err(err) if required_local_image => Err(Status::failed_precondition(format!( + "failed to inspect locally built sandbox image '{image_ref}': {err}" + ))), + Err(err) => { + warn!( + image_ref = %image_ref, + error = %err, + "vm driver: local container image inspection failed, falling back to registry" + ); + Ok(None) + } + } + } + + async fn ensure_cached_local_image_rootfs_image( + &self, + sandbox_id: &str, + image_ref: &str, + docker: &Docker, + image_identity: &str, + ) -> Result { + let cache_identity = bootstrap_image_cache_identity(image_identity); + let image_path = image_cache_rootfs_image(&self.config.state_dir, &cache_identity); + + self.publish_platform_event( + sandbox_id.to_string(), + platform_event( + "vm", + "Normal", + "Pulling", + format!("Pulling image \"{image_ref}\""), + ), + ); + + if tokio::fs::metadata(&image_path).await.is_ok() { + self.publish_vm_progress( + sandbox_id, + "CacheHit", + format!("Using cached VM root disk for local image \"{image_ref}\""), + HashMap::from([ + ("image_ref".to_string(), image_ref.to_string()), + ("image_source".to_string(), "local_docker".to_string()), + ("cache_hit".to_string(), "true".to_string()), + ("image_identity".to_string(), cache_identity.clone()), + ]), + ); + self.publish_pulled_event(sandbox_id, image_ref, &image_path) + .await; + return Ok(cache_identity); + } + + 
self.publish_vm_progress( + sandbox_id, + "CacheMiss", + format!("Preparing VM root disk cache for local image \"{image_ref}\""), + HashMap::from([ + ("image_ref".to_string(), image_ref.to_string()), + ("image_source".to_string(), "local_docker".to_string()), + ("cache_hit".to_string(), "false".to_string()), + ("image_identity".to_string(), cache_identity.clone()), + ]), + ); + self.publish_vm_progress( + sandbox_id, + "WaitingForImageCacheLock", + "Waiting for VM image cache build lock".to_string(), + HashMap::from([ + ("image_ref".to_string(), image_ref.to_string()), + ("image_identity".to_string(), cache_identity.clone()), + ]), + ); + let _cache_guard = self.image_cache_lock.lock().await; + if tokio::fs::metadata(&image_path).await.is_ok() { + self.publish_vm_progress( + sandbox_id, + "CacheHit", + format!("Using cached VM root disk for local image \"{image_ref}\""), + HashMap::from([ + ("image_ref".to_string(), image_ref.to_string()), + ("image_source".to_string(), "local_docker".to_string()), + ("cache_hit".to_string(), "true".to_string()), + ("image_identity".to_string(), cache_identity.clone()), + ]), + ); + self.publish_pulled_event(sandbox_id, image_ref, &image_path) + .await; + return Ok(cache_identity); + } + + self.build_cached_local_image_rootfs_image(sandbox_id, docker, image_ref, &cache_identity) + .await?; + self.publish_pulled_event(sandbox_id, image_ref, &image_path) + .await; + Ok(cache_identity) + } + + async fn ensure_prepared_image_disk( + &self, + sandbox_id: &str, + image_ref: &str, + bootstrap_root_disk: &Path, + ) -> Result { + if let Some((docker, image_identity)) = + self.resolve_local_container_image(image_ref).await? 
+ { + return self + .ensure_prepared_local_image_disk( + sandbox_id, + image_ref, + &docker, + &image_identity, + bootstrap_root_disk, + ) + .await; + } + + self.ensure_prepared_registry_image_disk(sandbox_id, image_ref, bootstrap_root_disk) + .await + } + + async fn ensure_prepared_local_image_disk( + &self, + sandbox_id: &str, + image_ref: &str, + docker: &Docker, + image_identity: &str, + bootstrap_root_disk: &Path, + ) -> Result { + let cache_identity = prepared_image_cache_identity(image_identity); + let image_path = image_cache_rootfs_image(&self.config.state_dir, &cache_identity); + + if tokio::fs::metadata(&image_path).await.is_ok() { + self.publish_prepared_cache_hit(sandbox_id, image_ref, "local_docker", &cache_identity); + return Ok(PreparedImageDisk { + image_identity: cache_identity, + disk_path: image_path, + }); + } + + self.publish_prepared_cache_miss(sandbox_id, image_ref, "local_docker", &cache_identity); + let _cache_guard = self.image_cache_lock.lock().await; + if tokio::fs::metadata(&image_path).await.is_ok() { + self.publish_prepared_cache_hit(sandbox_id, image_ref, "local_docker", &cache_identity); + return Ok(PreparedImageDisk { + image_identity: cache_identity, + disk_path: image_path, + }); + } + + let staging_dir = image_cache_staging_dir(&self.config.state_dir, &cache_identity); + let rootfs_archive = staging_dir.join(IMAGE_EXPORT_ROOTFS_ARCHIVE); + self.reset_image_staging_dir(&staging_dir).await?; + + self.publish_vm_progress( + sandbox_id, + "ExportingRootfs", + format!("Exporting rootfs from local image \"{image_ref}\""), + HashMap::from([ + ("image_ref".to_string(), image_ref.to_string()), + ("image_source".to_string(), "local_docker".to_string()), + ("image_identity".to_string(), cache_identity.clone()), + ]), + ); + if let Err(err) = + export_local_image_rootfs_to_path(docker, image_ref, &rootfs_archive).await + { + let _ = tokio::fs::remove_dir_all(&staging_dir).await; + return Err(err); + } + + let payload = GuestImagePayload { 
+ image_ref: image_ref.to_string(), + image_identity: cache_identity.clone(), + source: GuestImagePayloadSource::LocalDocker { rootfs_archive }, + }; + self.build_prepared_image_disk( + sandbox_id, + image_ref, + "local_docker", + &cache_identity, + bootstrap_root_disk, + &staging_dir, + &payload, + ) + .await?; + + Ok(PreparedImageDisk { + image_identity: cache_identity, + disk_path: image_path, + }) + } + + async fn ensure_prepared_registry_image_disk( + &self, + sandbox_id: &str, + image_ref: &str, + bootstrap_root_disk: &Path, + ) -> Result { + let reference = parse_registry_reference(image_ref)?; + let client = registry_client(); + let auth = registry_auth(image_ref)?; + + self.publish_vm_progress( + sandbox_id, + "AuthenticatingRegistry", + format!("Authenticating registry access for image \"{image_ref}\""), + HashMap::from([ + ("image_ref".to_string(), image_ref.to_string()), + ("image_source".to_string(), "registry".to_string()), + ]), + ); + client + .auth(&reference, &auth, RegistryOperation::Pull) + .await + .map_err(|err| { + Status::failed_precondition(format!( + "failed to authenticate registry access for vm sandbox image '{image_ref}': {err}" + )) + })?; + + self.publish_vm_progress( + sandbox_id, + "FetchingManifest", + format!("Fetching manifest for image \"{image_ref}\""), + HashMap::from([ + ("image_ref".to_string(), image_ref.to_string()), + ("image_source".to_string(), "registry".to_string()), + ]), + ); + let source_image_identity = client + .fetch_manifest_digest(&reference, &auth) + .await + .map_err(|err| { + Status::failed_precondition(format!( + "failed to resolve vm sandbox image '{image_ref}': {err}" + )) + })?; + let cache_identity = prepared_image_cache_identity(&source_image_identity); + let image_path = image_cache_rootfs_image(&self.config.state_dir, &cache_identity); + + if tokio::fs::metadata(&image_path).await.is_ok() { + self.publish_prepared_cache_hit(sandbox_id, image_ref, "registry", &cache_identity); + return 
Ok(PreparedImageDisk { + image_identity: cache_identity, + disk_path: image_path, + }); + } + + self.publish_prepared_cache_miss(sandbox_id, image_ref, "registry", &cache_identity); + let _cache_guard = self.image_cache_lock.lock().await; + if tokio::fs::metadata(&image_path).await.is_ok() { + self.publish_prepared_cache_hit(sandbox_id, image_ref, "registry", &cache_identity); + return Ok(PreparedImageDisk { + image_identity: cache_identity, + disk_path: image_path, + }); + } + + let staging_dir = image_cache_staging_dir(&self.config.state_dir, &cache_identity); + self.reset_image_staging_dir(&staging_dir).await?; + let layout_dir = staging_dir.join(GUEST_IMAGE_OCI_LAYOUT_DIR); + + let (manifest, _) = client + .pull_image_manifest(&reference, &auth) + .await + .map_err(|err| { + Status::failed_precondition(format!( + "failed to pull vm sandbox image manifest '{image_ref}': {err}" + )) + })?; + tokio::fs::create_dir_all(oci_layout_blobs_dir(&layout_dir)) + .await + .map_err(|err| Status::internal(format!("create guest OCI layout failed: {err}")))?; + + download_registry_descriptor_blob_file( + &client, + &reference, + image_ref, + &layout_dir, + &manifest.config, + "config", + ) + .await?; + + let total_layers = manifest.layers.len(); + let total_bytes: i64 = manifest.layers.iter().map(|layer| layer.size.max(0)).sum(); + futures::stream::iter(manifest.layers.iter().cloned().enumerate()) + .map(|(index, layer)| { + let client = client.clone(); + let reference = reference.clone(); + let layout_dir = layout_dir.clone(); + async move { + self.publish_registry_layer_progress( + sandbox_id, + image_ref, + &layer, + index, + total_layers, + total_bytes, + ); + download_registry_descriptor_blob_file( + &client, + &reference, + image_ref, + &layout_dir, + &layer, + &format!("layer {}", index + 1), + ) + .await + } + }) + .buffer_unordered(registry_layer_download_concurrency()) + .try_collect::>() + .await?; + + write_oci_layout_for_manifest(&layout_dir, GUEST_IMAGE_OCI_REF, 
&manifest) + .map_err(|err| Status::internal(format!("write OCI layout failed: {err}")))?; + + let payload = GuestImagePayload { + image_ref: image_ref.to_string(), + image_identity: cache_identity.clone(), + source: GuestImagePayloadSource::RegistryOciLayout { layout_dir }, + }; + self.build_prepared_image_disk( + sandbox_id, + image_ref, + "registry", + &cache_identity, + bootstrap_root_disk, + &staging_dir, + &payload, + ) + .await?; + + Ok(PreparedImageDisk { + image_identity: cache_identity, + disk_path: image_path, + }) + } + + async fn reset_image_staging_dir(&self, staging_dir: &Path) -> Result<(), Status> { + tokio::fs::create_dir_all(image_cache_root_dir(&self.config.state_dir)) + .await + .map_err(|err| Status::internal(format!("create image cache dir failed: {err}")))?; + if tokio::fs::metadata(staging_dir).await.is_ok() { + tokio::fs::remove_dir_all(staging_dir) + .await + .map_err(|err| { + Status::internal(format!( + "remove stale image cache staging dir failed: {err}" + )) + })?; + } + tokio::fs::create_dir_all(staging_dir).await.map_err(|err| { + Status::internal(format!("create image cache staging dir failed: {err}")) + }) + } + + #[allow(clippy::too_many_arguments)] + async fn build_prepared_image_disk( + &self, + sandbox_id: &str, + image_ref: &str, + image_source: &str, + image_identity: &str, + bootstrap_root_disk: &Path, + staging_dir: &Path, + payload: &GuestImagePayload, + ) -> Result<(), Status> { + let cache_dir = image_cache_dir(&self.config.state_dir, image_identity); + let image_path = image_cache_rootfs_image(&self.config.state_dir, image_identity); + let prepared_image = staging_dir.join(IMAGE_CACHE_ROOTFS_IMAGE); + tokio::fs::create_dir_all(&cache_dir).await.map_err(|err| { + Status::internal(format!("create prepared image cache dir failed: {err}")) + })?; + + let payload_for_size = payload.clone(); + let min_size = self + .config + .overlay_disk_mib + .checked_mul(1024 * 1024) + .ok_or_else(|| Status::internal("prepared image disk 
size overflow"))?; + let image_size = tokio::task::spawn_blocking(move || { + prepared_image_disk_size_bytes(&payload_for_size, min_size) + }) + .await + .map_err(|err| { + Status::internal(format!("prepared image size calculation panicked: {err}")) + })? + .map_err(Status::internal)?; + + let payload_for_disk = payload.clone(); + let prepared_image_for_disk = prepared_image.clone(); + self.publish_vm_progress( + sandbox_id, + "CreatingRootDisk", + "Formatting prepared VM image disk".to_string(), + HashMap::from([ + ("image_ref".to_string(), image_ref.to_string()), + ("image_source".to_string(), image_source.to_string()), + ("image_identity".to_string(), image_identity.to_string()), + ]), + ); + tokio::task::spawn_blocking(move || { + create_image_prep_disk(&prepared_image_for_disk, image_size, &payload_for_disk) + }) + .await + .map_err(|err| Status::internal(format!("prepared image disk build panicked: {err}")))? + .map_err(Status::failed_precondition)?; + + self.publish_vm_progress( + sandbox_id, + "PreparingRootfs", + format!("Preparing VM image rootfs for \"{image_ref}\""), + HashMap::from([ + ("image_ref".to_string(), image_ref.to_string()), + ("image_source".to_string(), image_source.to_string()), + ("image_identity".to_string(), image_identity.to_string()), + ]), + ); + if let Err(err) = self + .run_image_prep_vm(bootstrap_root_disk, &prepared_image, staging_dir) + .await + { + let _ = tokio::fs::remove_dir_all(staging_dir).await; + return Err(err); + } + + if tokio::fs::metadata(&image_path).await.is_ok() { + let _ = tokio::fs::remove_dir_all(staging_dir).await; + return Ok(()); + } + tokio::fs::rename(&prepared_image, &image_path) + .await + .map_err(|err| Status::internal(format!("store prepared image disk failed: {err}")))?; + let _ = tokio::fs::remove_dir_all(staging_dir).await; + Ok(()) + } + + async fn run_image_prep_vm( + &self, + bootstrap_root_disk: &Path, + prep_disk: &Path, + run_dir: &Path, + ) -> Result<(), Status> { + let console_output = 
run_dir.join("image-prep-console.log"); + let mut command = Command::new(&self.launcher_bin); + command.kill_on_drop(true); + command.stdin(Stdio::null()); + command.stdout(Stdio::inherit()); + command.stderr(Stdio::inherit()); + command.arg("--internal-run-vm"); + command.arg("--vm-root-disk").arg(bootstrap_root_disk); + command.arg("--vm-overlay-disk").arg(prep_disk); + command.arg("--vm-exec").arg(sandbox_guest_init_path()); + command.arg("--vm-workdir").arg("/"); + command.arg("--vm-console-output").arg(&console_output); + command.arg("--vm-vcpus").arg(self.config.vcpus.to_string()); + command + .arg("--vm-mem-mib") + .arg(self.config.mem_mib.to_string()); + command + .arg("--vm-krun-log-level") + .arg(self.config.krun_log_level.to_string()); + command + .arg("--vm-env") + .arg(format!("OPENSHELL_VM_INIT_MODE={IMAGE_PREP_INIT_MODE}")); + + let mut child = command + .spawn() + .map_err(|err| Status::internal(format!("failed to run image-prep vm: {err}")))?; + let status = child + .wait() + .await + .map_err(|err| Status::internal(format!("failed to wait for image-prep vm: {err}")))?; + if status.success() { + return Ok(()); } + let console = tokio::fs::read_to_string(&console_output) + .await + .unwrap_or_default(); + Err(Status::failed_precondition(format!( + "image-prep vm exited with status {status}: {console}" + ))) } - async fn ensure_cached_local_image_rootfs_archive( + fn publish_prepared_cache_hit( &self, sandbox_id: &str, image_ref: &str, - docker: &Docker, + image_source: &str, image_identity: &str, - ) -> Result { - let archive_path = image_cache_rootfs_archive(&self.config.state_dir, image_identity); - - self.publish_platform_event( - sandbox_id.to_string(), - platform_event( - "vm", - "Normal", - "Pulling", - format!("Pulling image \"{image_ref}\""), - ), + ) { + self.publish_vm_progress( + sandbox_id, + "CacheHit", + format!("Using cached prepared VM image disk for \"{image_ref}\""), + HashMap::from([ + ("image_ref".to_string(), 
image_ref.to_string()), + ("image_source".to_string(), image_source.to_string()), + ("cache_hit".to_string(), "true".to_string()), + ("image_identity".to_string(), image_identity.to_string()), + ]), ); + } - if tokio::fs::metadata(&archive_path).await.is_ok() { - self.publish_pulled_event(sandbox_id, image_ref, &archive_path) - .await; - return Ok(image_identity.to_string()); - } - - let _cache_guard = self.image_cache_lock.lock().await; - if tokio::fs::metadata(&archive_path).await.is_ok() { - self.publish_pulled_event(sandbox_id, image_ref, &archive_path) - .await; - return Ok(image_identity.to_string()); - } - - self.build_cached_local_image_rootfs_archive(docker, image_ref, image_identity) - .await?; - self.publish_pulled_event(sandbox_id, image_ref, &archive_path) - .await; - Ok(image_identity.to_string()) + fn publish_prepared_cache_miss( + &self, + sandbox_id: &str, + image_ref: &str, + image_source: &str, + image_identity: &str, + ) { + self.publish_vm_progress( + sandbox_id, + "CacheMiss", + format!("Preparing VM image disk cache for \"{image_ref}\""), + HashMap::from([ + ("image_ref".to_string(), image_ref.to_string()), + ("image_source".to_string(), image_source.to_string()), + ("cache_hit".to_string(), "false".to_string()), + ("image_identity".to_string(), image_identity.to_string()), + ]), + ); } - async fn build_cached_local_image_rootfs_archive( + async fn build_cached_local_image_rootfs_image( &self, + sandbox_id: &str, docker: &Docker, image_ref: &str, image_identity: &str, ) -> Result<(), Status> { let cache_dir = image_cache_dir(&self.config.state_dir, image_identity); - let archive_path = image_cache_rootfs_archive(&self.config.state_dir, image_identity); + let image_path = image_cache_rootfs_image(&self.config.state_dir, image_identity); let staging_dir = image_cache_staging_dir(&self.config.state_dir, image_identity); let exported_rootfs = staging_dir.join(IMAGE_EXPORT_ROOTFS_ARCHIVE); let prepared_rootfs = staging_dir.join("rootfs"); - let 
prepared_archive = staging_dir.join(IMAGE_CACHE_ROOTFS_ARCHIVE); + let prepared_image = staging_dir.join(IMAGE_CACHE_ROOTFS_IMAGE); tokio::fs::create_dir_all(image_cache_root_dir(&self.config.state_dir)) .await @@ -1008,6 +2039,16 @@ impl VmDriver { Status::internal(format!("create image cache staging dir failed: {err}")) })?; + self.publish_vm_progress( + sandbox_id, + "ExportingRootfs", + format!("Exporting rootfs from local image \"{image_ref}\""), + HashMap::from([ + ("image_ref".to_string(), image_ref.to_string()), + ("image_source".to_string(), "local_docker".to_string()), + ("image_identity".to_string(), image_identity.to_string()), + ]), + ); if let Err(err) = export_local_image_rootfs_to_path(docker, image_ref, &exported_rootfs).await { @@ -1019,37 +2060,70 @@ impl VmDriver { let image_identity_owned = image_identity.to_string(); let exported_rootfs_for_build = exported_rootfs.clone(); let prepared_rootfs_for_build = prepared_rootfs.clone(); - let prepared_archive_for_build = prepared_archive.clone(); - let build_result = tokio::task::spawn_blocking(move || { - prepare_exported_rootfs_archive( - &image_ref_owned, - &image_identity_owned, - &exported_rootfs_for_build, + self.publish_vm_progress( + sandbox_id, + "PreparingRootfs", + format!("Preparing VM rootfs for local image \"{image_ref}\""), + HashMap::from([ + ("image_ref".to_string(), image_ref.to_string()), + ("image_source".to_string(), "local_docker".to_string()), + ("image_identity".to_string(), image_identity.to_string()), + ]), + ); + let prepare_result = tokio::task::spawn_blocking(move || { + extract_rootfs_archive_to(&exported_rootfs_for_build, &prepared_rootfs_for_build)?; + prepare_sandbox_rootfs_from_image_root( &prepared_rootfs_for_build, - &prepared_archive_for_build, + &image_identity_owned, ) + .map_err(|err| { + format!("vm sandbox image '{image_ref_owned}' is not base-compatible: {err}") + }) }) .await .map_err(|err| Status::internal(format!("local image preparation panicked: 
{err}")))?; + if let Err(err) = prepare_result { + let _ = tokio::fs::remove_dir_all(&staging_dir).await; + return Err(Status::failed_precondition(err)); + } + + self.publish_vm_progress( + sandbox_id, + "CreatingRootDisk", + "Formatting VM root disk image".to_string(), + HashMap::from([ + ("image_ref".to_string(), image_ref.to_string()), + ("image_source".to_string(), "local_docker".to_string()), + ("image_identity".to_string(), image_identity.to_string()), + ]), + ); + let prepared_rootfs_for_build = prepared_rootfs.clone(); + let prepared_image_for_build = prepared_image.clone(); + let build_result = tokio::task::spawn_blocking(move || { + create_rootfs_image_from_dir(&prepared_rootfs_for_build, &prepared_image_for_build) + }) + .await + .map_err(|err| Status::internal(format!("rootfs image build panicked: {err}")))?; + if let Err(err) = build_result { let _ = tokio::fs::remove_dir_all(&staging_dir).await; return Err(Status::failed_precondition(err)); } - if tokio::fs::metadata(&archive_path).await.is_ok() { + if tokio::fs::metadata(&image_path).await.is_ok() { let _ = tokio::fs::remove_dir_all(&staging_dir).await; return Ok(()); } - tokio::fs::rename(&prepared_archive, &archive_path) + tokio::fs::rename(&prepared_image, &image_path) .await - .map_err(|err| Status::internal(format!("store cached image rootfs failed: {err}")))?; + .map_err(|err| Status::internal(format!("store cached rootfs image failed: {err}")))?; let _ = tokio::fs::remove_dir_all(&staging_dir).await; Ok(()) } - async fn build_cached_registry_image_rootfs_archive( + async fn build_cached_registry_image_rootfs_image( &self, sandbox_id: &str, client: &OciClient, @@ -1059,10 +2133,10 @@ impl VmDriver { image_identity: &str, ) -> Result<(), Status> { let cache_dir = image_cache_dir(&self.config.state_dir, image_identity); - let archive_path = image_cache_rootfs_archive(&self.config.state_dir, image_identity); + let image_path = image_cache_rootfs_image(&self.config.state_dir, image_identity); let 
staging_dir = image_cache_staging_dir(&self.config.state_dir, image_identity); let prepared_rootfs = staging_dir.join("rootfs"); - let prepared_archive = staging_dir.join(IMAGE_CACHE_ROOTFS_ARCHIVE); + let prepared_image = staging_dir.join(IMAGE_CACHE_ROOTFS_IMAGE); tokio::fs::create_dir_all(image_cache_root_dir(&self.config.state_dir)) .await @@ -1113,52 +2187,88 @@ impl VmDriver { } info!( image_ref = %image_ref, - "vm driver: image layers pulled, preparing rootfs archive" + "vm driver: image layers pulled, preparing rootfs image" ); let image_ref_owned = image_ref.to_string(); let image_identity_owned = image_identity.to_string(); let prepared_rootfs_for_build = prepared_rootfs.clone(); - let prepared_archive_for_build = prepared_archive.clone(); - let build_result = tokio::task::spawn_blocking(move || { + self.publish_vm_progress( + sandbox_id, + "PreparingRootfs", + format!("Preparing VM rootfs for image \"{image_ref}\""), + HashMap::from([ + ("image_ref".to_string(), image_ref.to_string()), + ("image_source".to_string(), "registry".to_string()), + ("image_identity".to_string(), image_identity.to_string()), + ]), + ); + let prepare_result = tokio::task::spawn_blocking(move || { prepare_sandbox_rootfs_from_image_root( &prepared_rootfs_for_build, &image_identity_owned, ) .map_err(|err| { format!("vm sandbox image '{image_ref_owned}' is not base-compatible: {err}") - })?; - create_rootfs_archive_from_dir(&prepared_rootfs_for_build, &prepared_archive_for_build) + }) }) .await .map_err(|err| Status::internal(format!("image rootfs preparation panicked: {err}")))?; + if let Err(err) = prepare_result { + warn!( + image_ref = %image_ref, + error = %err, + "vm driver: rootfs preparation failed" + ); + let _ = tokio::fs::remove_dir_all(&staging_dir).await; + return Err(Status::failed_precondition(err)); + } + + self.publish_vm_progress( + sandbox_id, + "CreatingRootDisk", + "Formatting VM root disk image".to_string(), + HashMap::from([ + ("image_ref".to_string(), 
image_ref.to_string()), + ("image_source".to_string(), "registry".to_string()), + ("image_identity".to_string(), image_identity.to_string()), + ]), + ); + let prepared_rootfs_for_build = prepared_rootfs.clone(); + let prepared_image_for_build = prepared_image.clone(); + let build_result = tokio::task::spawn_blocking(move || { + create_rootfs_image_from_dir(&prepared_rootfs_for_build, &prepared_image_for_build) + }) + .await + .map_err(|err| Status::internal(format!("image rootfs build panicked: {err}")))?; + if let Err(err) = build_result { warn!( image_ref = %image_ref, error = %err, - "vm driver: rootfs archive build failed" + "vm driver: rootfs image build failed" ); let _ = tokio::fs::remove_dir_all(&staging_dir).await; return Err(Status::failed_precondition(err)); } - if tokio::fs::metadata(&archive_path).await.is_ok() { + if tokio::fs::metadata(&image_path).await.is_ok() { info!( image_identity = %image_identity, - "vm driver: another task wrote archive while we were building, discarding ours" + "vm driver: another task wrote image while we were building, discarding ours" ); let _ = tokio::fs::remove_dir_all(&staging_dir).await; return Ok(()); } - tokio::fs::rename(&prepared_archive, &archive_path) + tokio::fs::rename(&prepared_image, &image_path) .await - .map_err(|err| Status::internal(format!("store cached image rootfs failed: {err}")))?; + .map_err(|err| Status::internal(format!("store cached rootfs image failed: {err}")))?; info!( image_identity = %image_identity, - archive_path = %archive_path.display(), - "vm driver: image rootfs archive committed to cache" + image_path = %image_path.display(), + "vm driver: root disk image committed to cache" ); let _ = tokio::fs::remove_dir_all(&staging_dir).await; Ok(()) @@ -1180,7 +2290,10 @@ impl VmDriver { let Some(record) = registry.get(&sandbox_id) else { return; }; - record.process.clone() + let Some(process) = record.process.as_ref() else { + return; + }; + process.clone() }; let exit_status = { @@ -1291,6 
+2404,19 @@ impl VmDriver { )), }); } + + fn publish_vm_progress( + &self, + sandbox_id: &str, + reason: &str, + message: String, + metadata: HashMap<String, String>, + ) { + let mut event = platform_event("vm", "Normal", reason, message); + event.metadata = metadata; + attach_vm_progress_metadata(&mut event); + self.publish_platform_event(sandbox_id.to_string(), event); + } } #[tonic::async_trait] @@ -1436,9 +2562,8 @@ impl ComputeDriver for VmDriver { } #[cfg(target_os = "linux")] -#[allow(unsafe_code)] // libc::geteuid is a thin syscall wrapper fn check_gpu_privileges() -> Result<(), String> { - if unsafe { libc::geteuid() } != 0 { + if !rustix::process::geteuid().is_root() { return Err( "GPU support requires root privileges for VFIO bind/unbind and TAP networking. \ Run with sudo or ensure CAP_SYS_ADMIN + CAP_NET_ADMIN capabilities are set." @@ -1452,6 +2577,8 @@ fn check_gpu_privileges() -> Result<(), String> { // gRPC API surface, so boxing here would diverge from every other handler. #[allow(clippy::result_large_err)] fn validate_vm_sandbox(sandbox: &Sandbox, gpu_enabled: bool) -> Result<(), Status> { + validate_sandbox_id(&sandbox.id)?; + let spec = sandbox .spec .as_ref() @@ -1478,11 +2605,32 @@ fn validate_vm_sandbox(sandbox: &Sandbox, gpu_enabled: bool) -> Result<(), Statu "vm sandboxes do not support template.platform_config", )); } - if template.resources.is_some() { - return Err(Status::failed_precondition( - "vm sandboxes do not support template.resources", - )); - } + } + Ok(()) +} + +#[allow(clippy::result_large_err)] +fn validate_sandbox_id(sandbox_id: &str) -> Result<(), Status> { + if sandbox_id.is_empty() { + return Err(Status::invalid_argument("sandbox id is required")); + } + if sandbox_id.len() > 128 { + return Err(Status::invalid_argument( + "sandbox id exceeds maximum length (128 bytes)", + )); + } + if matches!(sandbox_id, "."
| "..") { + return Err(Status::invalid_argument( + "sandbox id must match [A-Za-z0-9._-]{1,128}", + )); + } + if !sandbox_id + .bytes() + .all(|b| b.is_ascii_alphanumeric() || matches!(b, b'.' | b'_' | b'-')) + { + return Err(Status::invalid_argument( + "sandbox id must match [A-Za-z0-9._-]{1,128}", + )); } Ok(()) } @@ -1496,11 +2644,58 @@ fn parse_registry_reference(image_ref: &str) -> Result { }) } +/// Try to connect to a local container engine (Docker or Podman). +/// +/// Tries Docker first (`connect_with_local_defaults`, which respects +/// `DOCKER_HOST`). If Docker is unavailable, falls back to the Podman +/// socket, which exposes a Docker-compatible API. +async fn connect_local_container_engine() -> Option<Docker> { + if let Ok(docker) = Docker::connect_with_local_defaults() + && docker.ping().await.is_ok() + { + return Some(docker); + } + + let podman_socket = podman_socket_path(); + if podman_socket.exists() + && let Ok(docker) = + Docker::connect_with_unix(podman_socket.to_str()?, 120, bollard::API_DEFAULT_VERSION) + && docker.ping().await.is_ok() + { + info!( + socket = %podman_socket.display(), + "vm driver: connected to Podman (Docker-compatible API)" + ); + return Some(docker); + } + + None +} + +/// Podman user socket path for the current platform.
+fn podman_socket_path() -> PathBuf { + #[cfg(target_os = "macos")] + { + let home = std::env::var("HOME").unwrap_or_default(); + PathBuf::from(home).join(".local/share/containers/podman/machine/podman.sock") + } + #[cfg(target_os = "linux")] + { + std::env::var("XDG_RUNTIME_DIR").map_or_else( + |_| { + let uid = nix::unistd::getuid(); + PathBuf::from(format!("/run/user/{uid}/podman/podman.sock")) + }, + |xdg| PathBuf::from(xdg).join("podman/podman.sock"), + ) + } +} + fn is_openshell_local_build_image_ref(image_ref: &str) -> bool { image_ref.starts_with("openshell/sandbox-from:") } -fn local_docker_image_platform_mismatch( +fn local_image_platform_mismatch( image_ref: &str, actual_os: Option<&str>, actual_arch: Option<&str>, @@ -1603,19 +2798,6 @@ async fn export_local_image_rootfs_to_path( } } -fn prepare_exported_rootfs_archive( - image_ref: &str, - image_identity: &str, - exported_rootfs: &Path, - prepared_rootfs: &Path, - prepared_archive: &Path, -) -> Result<(), String> { - extract_rootfs_archive_to(exported_rootfs, prepared_rootfs)?; - prepare_sandbox_rootfs_from_image_root(prepared_rootfs, image_identity) - .map_err(|err| format!("vm sandbox image '{image_ref}' is not base-compatible: {err}"))?; - create_rootfs_archive_from_dir(prepared_rootfs, prepared_archive) -} - fn registry_client() -> OciClient { OciClient::new(ClientConfig { platform_resolver: Some(Box::new(linux_platform_resolver)), @@ -1733,76 +2915,105 @@ impl VmDriver { let total_layers = manifest.layers.len(); let total_bytes: i64 = manifest.layers.iter().map(|layer| layer.size.max(0)).sum(); - for (index, layer) in manifest.layers.iter().enumerate() { - // Emit a per-layer progress event so the CLI can show - // "Layer 3/8 (12.4 MB)" as detail under the spinner. 
- let mut metadata = HashMap::new(); - metadata.insert("layer_index".to_string(), (index + 1).to_string()); - metadata.insert("layer_total".to_string(), total_layers.to_string()); - metadata.insert("layer_digest".to_string(), layer.digest.clone()); - metadata.insert("layer_size_bytes".to_string(), layer.size.to_string()); - metadata.insert("image_ref".to_string(), image_ref.to_string()); - if total_bytes > 0 { - metadata.insert("image_size_bytes".to_string(), total_bytes.to_string()); - } - let mut event = platform_event( - "vm", - "Normal", - "PullingLayer", - format!( - "Pulling layer {}/{} ({} bytes) for image \"{image_ref}\"", - index + 1, + let mut layers = futures::stream::iter(manifest.layers.iter().cloned().enumerate()) + .map(|(index, layer)| async move { + self.publish_registry_layer_progress( + sandbox_id, + image_ref, + &layer, + index, total_layers, - layer.size - ), - ); - event.metadata = metadata; - self.publish_platform_event(sandbox_id.to_string(), event); - - pull_registry_layer( - client, - reference, - image_ref, - staging_dir, - rootfs, - layer, - index, - ) + total_bytes, + ); + download_registry_layer_blob( + client, + reference, + image_ref, + staging_dir, + layer, + index, + ) + .await + }) + .buffer_unordered(registry_layer_download_concurrency()) + .try_collect::>() .await?; + layers.sort_by_key(|layer| layer.index); + + for layer in &layers { + apply_registry_layer_blob(image_ref, rootfs, layer).await?; } Ok(()) } - /// Emit a `Pulled` platform event with a message that mirrors the - /// kubelet's `Successfully pulled image ... Image size: N bytes.` - /// format so the CLI's `extract_image_size` parser works unchanged. 
- async fn publish_pulled_event(&self, sandbox_id: &str, image_ref: &str, archive_path: &Path) { - let size_suffix = tokio::fs::metadata(archive_path).await.map_or_else( + fn publish_registry_layer_progress( + &self, + sandbox_id: &str, + image_ref: &str, + layer: &OciDescriptor, + index: usize, + total_layers: usize, + total_bytes: i64, + ) { + let mut metadata = HashMap::new(); + metadata.insert("layer_index".to_string(), (index + 1).to_string()); + metadata.insert("layer_total".to_string(), total_layers.to_string()); + metadata.insert("layer_digest".to_string(), layer.digest.clone()); + metadata.insert("layer_size_bytes".to_string(), layer.size.to_string()); + metadata.insert("image_ref".to_string(), image_ref.to_string()); + if total_bytes > 0 { + metadata.insert("image_size_bytes".to_string(), total_bytes.to_string()); + } + let mut event = platform_event( + "vm", + "Normal", + "PullingLayer", + format!( + "Pulling layer {}/{} ({} bytes) for image \"{image_ref}\"", + index + 1, + total_layers, + layer.size + ), + ); + event.metadata = metadata; + attach_vm_progress_metadata(&mut event); + self.publish_platform_event(sandbox_id.to_string(), event); + } + + /// Emit a `Pulled` platform event with progress metadata for the CLI. 
+ async fn publish_pulled_event(&self, sandbox_id: &str, image_ref: &str, image_path: &Path) { + let mut metadata = HashMap::from([("image_ref".to_string(), image_ref.to_string())]); + let size_suffix = tokio::fs::metadata(image_path).await.map_or_else( |_| String::new(), - |meta| format!(" Image size: {} bytes.", meta.len()), + |meta| { + metadata.insert("image_size_bytes".to_string(), meta.len().to_string()); + format!(" Image size: {} bytes.", meta.len()) + }, ); - self.publish_platform_event( - sandbox_id.to_string(), - platform_event( - "vm", - "Normal", - "Pulled", - format!("Successfully pulled image \"{image_ref}\".{size_suffix}"), - ), + self.publish_vm_progress( + sandbox_id, + "Pulled", + format!("Successfully pulled image \"{image_ref}\".{size_suffix}"), + metadata, ); } } -async fn pull_registry_layer( +struct DownloadedRegistryLayer { + index: usize, + digest: String, + layer_root: PathBuf, +} + +async fn download_registry_layer_blob( client: &OciClient, reference: &Reference, image_ref: &str, staging_dir: &Path, - rootfs: &Path, - layer: &OciDescriptor, + layer: OciDescriptor, index: usize, -) -> Result<(), Status> { +) -> Result { let digest_component = sanitize_image_identity(&layer.digest); let blob_path = staging_dir .join("layers") @@ -1815,7 +3026,7 @@ async fn pull_registry_layer( .await .map_err(|err| Status::internal(format!("create layer blob failed: {err}")))?; client - .pull_blob(reference, layer, &mut file) + .pull_blob(reference, &layer, &mut file) .await .map_err(|err| { Status::failed_precondition(format!( @@ -1828,33 +3039,104 @@ async fn pull_registry_layer( .map_err(|err| Status::internal(format!("flush layer blob failed: {err}")))?; let blob_path_for_digest = blob_path.clone(); - let expected_digest = layer.digest.clone(); + let expected_digest = layer.digest.clone(); + tokio::task::spawn_blocking(move || { + verify_descriptor_digest(&blob_path_for_digest, &expected_digest) + }) + .await + .map_err(|err| 
Status::internal(format!("layer digest verification panicked: {err}")))? + .map_err(|err| { + Status::failed_precondition(format!( + "vm sandbox image layer verification failed for '{}': {err}", + layer.digest + )) + })?; + + let blob_path_for_unpack = blob_path.clone(); + let layer_root_for_unpack = layer_root.clone(); + let media_type = layer.media_type.clone(); + tokio::task::spawn_blocking(move || { + extract_layer_blob_to_dir(&blob_path_for_unpack, &media_type, &layer_root_for_unpack) + }) + .await + .map_err(|err| Status::internal(format!("layer extraction panicked: {err}")))? + .map_err(|err| { + Status::failed_precondition(format!( + "failed to extract layer '{}' for vm sandbox image '{image_ref}': {err}", + layer.digest + )) + })?; + + Ok(DownloadedRegistryLayer { + index, + digest: layer.digest, + layer_root, + }) +} + +async fn apply_registry_layer_blob( + image_ref: &str, + rootfs: &Path, + layer: &DownloadedRegistryLayer, +) -> Result<(), Status> { + let layer_root_for_unpack = layer.layer_root.clone(); + let rootfs_for_unpack = rootfs.to_path_buf(); + tokio::task::spawn_blocking(move || { + apply_layer_dir_to_rootfs(&layer_root_for_unpack, &rootfs_for_unpack) + }) + .await + .map_err(|err| Status::internal(format!("layer application panicked: {err}")))? 
+ .map_err(|err| { + Status::failed_precondition(format!( + "failed to apply layer '{}' for vm sandbox image '{image_ref}': {err}", + layer.digest + )) + }) +} + +async fn download_registry_descriptor_blob_file( + client: &OciClient, + reference: &Reference, + image_ref: &str, + layout_dir: &Path, + descriptor: &OciDescriptor, + kind: &str, +) -> Result<(), Status> { + let blob_path = oci_layout_blob_path(layout_dir, &descriptor.digest) + .map_err(|err| Status::failed_precondition(format!("invalid {kind} digest: {err}")))?; + if let Some(parent) = blob_path.parent() { + tokio::fs::create_dir_all(parent) + .await + .map_err(|err| Status::internal(format!("create OCI blob dir failed: {err}")))?; + } + + let mut file = tokio::fs::File::create(&blob_path) + .await + .map_err(|err| Status::internal(format!("create OCI {kind} blob failed: {err}")))?; + client + .pull_blob(reference, descriptor, &mut file) + .await + .map_err(|err| { + Status::failed_precondition(format!( + "failed to download {kind} '{}' for vm sandbox image '{image_ref}': {err}", + descriptor.digest + )) + })?; + file.flush() + .await + .map_err(|err| Status::internal(format!("flush OCI {kind} blob failed: {err}")))?; + + let blob_path_for_digest = blob_path.clone(); + let expected_digest = descriptor.digest.clone(); tokio::task::spawn_blocking(move || { verify_descriptor_digest(&blob_path_for_digest, &expected_digest) }) .await - .map_err(|err| Status::internal(format!("layer digest verification panicked: {err}")))? 
- .map_err(|err| { - Status::failed_precondition(format!( - "vm sandbox image layer verification failed for '{}': {err}", - layer.digest - )) - })?; - - let blob_path_for_unpack = blob_path.clone(); - let layer_root_for_unpack = layer_root.clone(); - let rootfs_for_unpack = rootfs.to_path_buf(); - let media_type = layer.media_type.clone(); - tokio::task::spawn_blocking(move || { - extract_layer_blob_to_dir(&blob_path_for_unpack, &media_type, &layer_root_for_unpack)?; - apply_layer_dir_to_rootfs(&layer_root_for_unpack, &rootfs_for_unpack) - }) - .await - .map_err(|err| Status::internal(format!("layer extraction panicked: {err}")))? + .map_err(|err| Status::internal(format!("OCI {kind} digest verification panicked: {err}")))? .map_err(|err| { Status::failed_precondition(format!( - "failed to apply layer '{}' for vm sandbox image '{image_ref}': {err}", - layer.digest + "vm sandbox image {kind} verification failed for '{}': {err}", + descriptor.digest )) }) } @@ -1890,6 +3172,12 @@ fn compute_file_sha256_hex(path: &Path) -> Result { Ok(format!("{:x}", hasher.finalize())) } +fn compute_bytes_sha256_hex(bytes: &[u8]) -> String { + let mut hasher = Sha256::new(); + hasher.update(bytes); + format!("{:x}", hasher.finalize()) +} + fn extract_layer_blob_to_dir( blob_path: &Path, media_type: &str, @@ -2116,7 +3404,7 @@ fn merged_environment(sandbox: &Sandbox) -> HashMap { /// not the host's. Inside the guest we need a name that gvproxy will translate /// into the host's loopback address. /// -/// We rewrite to `host.containers.internal`, which gvproxy's embedded DNS resolves +/// We rewrite to `host.openshell.internal`, which gvproxy's embedded DNS resolves /// to the host-loopback IP `192.168.127.254`. gvproxy installs a default NAT entry /// rewriting that destination to the host's `127.0.0.1` and dialing out from the /// host process, so any port the host is listening on becomes reachable. 
The @@ -2176,148 +3464,760 @@ fn build_guest_environment( "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin".to_string(), ), ("TERM".to_string(), "xterm".to_string()), - ("OPENSHELL_ENDPOINT".to_string(), openshell_endpoint), - ("OPENSHELL_SANDBOX_ID".to_string(), sandbox.id.clone()), - ("OPENSHELL_SANDBOX".to_string(), sandbox.name.clone()), ( - "OPENSHELL_SSH_SOCKET_PATH".to_string(), - GUEST_SSH_SOCKET_PATH.to_string(), + openshell_core::sandbox_env::ENDPOINT.to_string(), + openshell_endpoint, ), ( - "OPENSHELL_SANDBOX_COMMAND".to_string(), - "tail -f /dev/null".to_string(), + openshell_core::sandbox_env::SANDBOX_ID.to_string(), + sandbox.id.clone(), + ), + ( + openshell_core::sandbox_env::SANDBOX.to_string(), + sandbox.name.clone(), ), ( - "OPENSHELL_LOG_LEVEL".to_string(), - sandbox_log_level(sandbox, &config.log_level), + openshell_core::sandbox_env::SSH_SOCKET_PATH.to_string(), + GUEST_SSH_SOCKET_PATH.to_string(), + ), + ( + openshell_core::sandbox_env::SANDBOX_COMMAND.to_string(), + "tail -f /dev/null".to_string(), ), ( - "OPENSHELL_SSH_HANDSHAKE_SECRET".to_string(), - config.ssh_handshake_secret.clone(), + openshell_core::sandbox_env::LOG_LEVEL.to_string(), + openshell_core::driver_utils::sandbox_log_level(sandbox, &config.log_level), ), ]); if config.requires_tls_materials() { environment.extend(HashMap::from([ ( - "OPENSHELL_TLS_CA".to_string(), + openshell_core::sandbox_env::TLS_CA.to_string(), GUEST_TLS_CA_PATH.to_string(), ), ( - "OPENSHELL_TLS_CERT".to_string(), + openshell_core::sandbox_env::TLS_CERT.to_string(), GUEST_TLS_CERT_PATH.to_string(), ), ( - "OPENSHELL_TLS_KEY".to_string(), + openshell_core::sandbox_env::TLS_KEY.to_string(), GUEST_TLS_KEY_PATH.to_string(), ), ])); } - environment.extend(merged_environment(sandbox)); + environment.extend(merged_environment(sandbox)); + environment.remove(openshell_core::sandbox_env::SANDBOX_TOKEN); + environment.remove(openshell_core::sandbox_env::SANDBOX_TOKEN_FILE); + if sandbox + .spec + 
.as_ref() + .is_some_and(|spec| !spec.sandbox_token.is_empty()) + { + environment.insert( + openshell_core::sandbox_env::SANDBOX_TOKEN_FILE.to_string(), + GUEST_SANDBOX_TOKEN_PATH.to_string(), + ); + } + + let mut pairs = environment.into_iter().collect::<Vec<_>>(); + pairs.sort_by(|left, right| left.0.cmp(&right.0)); + pairs + .into_iter() + .map(|(key, value)| format!("{key}={value}")) + .collect() +} + +fn sandboxes_root_dir(root: &Path) -> PathBuf { + root.join("sandboxes") +} + +async fn create_private_dir_all(path: &Path) -> Result<(), std::io::Error> { + tokio::fs::create_dir_all(path).await?; + restrict_owner_only_dir(path).await +} + +#[cfg(unix)] +async fn restrict_owner_only_dir(path: &Path) -> Result<(), std::io::Error> { + tokio::fs::set_permissions(path, fs::Permissions::from_mode(0o700)).await +} + +#[cfg(not(unix))] +async fn restrict_owner_only_dir(_path: &Path) -> Result<(), std::io::Error> { + Ok(()) +} + +#[allow(clippy::result_large_err)] +fn sandbox_state_dir(root: &Path, sandbox_id: &str) -> Result<PathBuf, Status> { + validate_sandbox_id(sandbox_id)?; + Ok(sandboxes_root_dir(root).join(sandbox_id)) +} + +fn sandbox_overlay_image(state_dir: &Path) -> PathBuf { + state_dir.join(SANDBOX_OVERLAY_IMAGE) +} + +fn overlay_template_image(root: &Path, size_bytes: u64) -> PathBuf { + image_cache_root_dir(root) + .join(OVERLAY_TEMPLATE_CACHE_DIR) + .join(OVERLAY_TEMPLATE_CACHE_LAYOUT_VERSION) + .join(format!("{size_bytes}.ext4")) +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct SandboxRuntimeDiskPaths { + overlay_disk: PathBuf, +} + +fn sandbox_runtime_disk_paths(state_dir: &Path) -> SandboxRuntimeDiskPaths { + SandboxRuntimeDiskPaths { + overlay_disk: sandbox_overlay_image(state_dir), + } +} + +#[allow(clippy::result_large_err)] +fn validate_sandbox_state_dir(root: &Path, state_dir: &Path) -> Result<(), Status> { + let sandboxes_root = sandboxes_root_dir(root); + let relative = state_dir.strip_prefix(&sandboxes_root).map_err(|_| { + Status::internal(format!( + "refusing 
to use sandbox state path outside vm state root: {}", + state_dir.display() + )) + })?; + + let mut components = relative.components(); + match components.next() { + Some(Component::Normal(_)) => {} + _ => { + return Err(Status::internal(format!( + "refusing to use malformed sandbox state path: {}", + state_dir.display() + ))); + } + } + if components.next().is_some() { + return Err(Status::internal(format!( + "refusing to use nested sandbox state path: {}", + state_dir.display() + ))); + } + + Ok(()) +} + +async fn remove_sandbox_state_dir(root: &Path, state_dir: &Path) -> Result<(), Status> { + validate_sandbox_state_dir(root, state_dir)?; + + let metadata = match tokio::fs::symlink_metadata(state_dir).await { + Ok(metadata) => metadata, + Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(()), + Err(err) => { + return Err(Status::internal(format!( + "failed to stat sandbox state dir: {err}" + ))); + } + }; + let file_type = metadata.file_type(); + if file_type.is_symlink() { + return Err(Status::internal(format!( + "refusing to remove symlinked sandbox state dir: {}", + state_dir.display() + ))); + } + if !file_type.is_dir() { + return Err(Status::internal(format!( + "sandbox state path is not a directory: {}", + state_dir.display() + ))); + } + + tokio::fs::remove_dir_all(state_dir) + .await + .map_err(|err| Status::internal(format!("failed to remove state dir: {err}"))) +} + +fn image_cache_root_dir(root: &Path) -> PathBuf { + root.join(IMAGE_CACHE_ROOT_DIR) +} + +fn image_cache_dir(root: &Path, image_identity: &str) -> PathBuf { + image_cache_root_dir(root).join(sanitize_image_identity(image_identity)) +} + +fn image_cache_rootfs_image(root: &Path, image_identity: &str) -> PathBuf { + image_cache_dir(root, image_identity).join(IMAGE_CACHE_ROOTFS_IMAGE) +} + +fn image_cache_staging_dir(root: &Path, image_identity: &str) -> PathBuf { + image_cache_root_dir(root).join(format!( + "{}.staging-{}", + sanitize_image_identity(image_identity), + 
unique_image_cache_suffix() + )) +} + +fn oci_layout_blobs_dir(layout_dir: &Path) -> PathBuf { + layout_dir.join("blobs").join("sha256") +} + +fn oci_layout_blob_path(layout_dir: &Path, digest: &str) -> Result<PathBuf, String> { + let hex = sha256_digest_hex(digest)?; + Ok(oci_layout_blobs_dir(layout_dir).join(hex)) +} + +fn sha256_digest_hex(digest: &str) -> Result<&str, String> { + let Some((algorithm, hex)) = digest.split_once(':') else { + return Err(format!("digest '{digest}' is missing an algorithm")); + }; + if algorithm != "sha256" { + return Err(format!("unsupported digest algorithm '{algorithm}'")); + } + if hex.is_empty() || !hex.chars().all(|ch| ch.is_ascii_hexdigit()) { + return Err(format!("digest '{digest}' is not a valid sha256 digest")); + } + Ok(hex) +} + +fn write_oci_layout_for_manifest( + layout_dir: &Path, + ref_name: &str, + manifest: &OciImageManifest, +) -> Result<(), String> { + fs::create_dir_all(oci_layout_blobs_dir(layout_dir)) + .map_err(|err| format!("create OCI layout blobs dir failed: {err}"))?; + + fs::write( + layout_dir.join("oci-layout"), + br#"{"imageLayoutVersion":"1.0.0"}"#, + ) + .map_err(|err| format!("write OCI layout marker failed: {err}"))?; + + let manifest_bytes = serde_json::to_vec(manifest) + .map_err(|err| format!("serialize OCI manifest failed: {err}"))?; + let manifest_digest = format!("sha256:{}", compute_bytes_sha256_hex(&manifest_bytes)); + let manifest_blob = oci_layout_blob_path(layout_dir, &manifest_digest) + .map_err(|err| format!("compute OCI manifest blob path failed: {err}"))?; + fs::write(&manifest_blob, &manifest_bytes) + .map_err(|err| format!("write OCI manifest blob failed: {err}"))?; + + let media_type = manifest + .media_type + .clone() + .unwrap_or_else(|| OCI_IMAGE_MEDIA_TYPE.to_string()); + let index = serde_json::json!({ + "schemaVersion": 2, + "manifests": [ + { + "mediaType": media_type, + "digest": manifest_digest, + "size": manifest_bytes.len(), + "annotations": { + "org.opencontainers.image.ref.name": 
ref_name + } + } + ] + }); + let index_bytes = serde_json::to_vec_pretty(&index) + .map_err(|err| format!("serialize OCI index failed: {err}"))?; + fs::write(layout_dir.join("index.json"), index_bytes) + .map_err(|err| format!("write OCI index failed: {err}"))?; + + Ok(()) +} + +fn bootstrap_image_cache_identity(image_identity: &str) -> String { + format!("{BOOTSTRAP_IMAGE_CACHE_LAYOUT_VERSION}:{image_identity}") +} + +fn prepared_image_cache_identity(image_identity: &str) -> String { + format!("{PREPARED_IMAGE_CACHE_LAYOUT_VERSION}:{image_identity}") +} + +fn registry_layer_download_concurrency() -> usize { + let value = std::env::var("OPENSHELL_VM_IMAGE_PULL_CONCURRENCY").ok(); + registry_layer_download_concurrency_value(value.as_deref()) +} + +fn registry_layer_download_concurrency_value(value: Option<&str>) -> usize { + value + .and_then(|value| value.parse::<usize>().ok()) + .filter(|value| *value > 0) + .map_or(DEFAULT_REGISTRY_LAYER_DOWNLOAD_CONCURRENCY, |value| { + value.min(MAX_REGISTRY_LAYER_DOWNLOAD_CONCURRENCY) + }) +} + +fn sanitize_image_identity(image_identity: &str) -> String { + image_identity + .chars() + .map(|ch| { + if ch.is_ascii_alphanumeric() || ch == '-' || ch == '_' || ch == '.' 
{ + ch + } else { + '-' + } + }) + .collect() +} + +fn unique_image_cache_suffix() -> String { + let counter = IMAGE_CACHE_BUILD_COUNTER.fetch_add(1, Ordering::Relaxed); + format!("{}-{counter}", openshell_core::time::now_ms()) +} + +async fn write_sandbox_image_metadata( + state_dir: &Path, + image_ref: &str, + image_identity: &str, +) -> Result<(), std::io::Error> { + tokio::fs::write( + state_dir.join(IMAGE_IDENTITY_FILE), + format!("{image_identity}\n"), + ) + .await?; + tokio::fs::write( + state_dir.join(IMAGE_REFERENCE_FILE), + format!("{image_ref}\n"), + ) + .await?; + + Ok(()) +} + +async fn write_sandbox_request(state_dir: &Path, sandbox: &Sandbox) -> Result<(), std::io::Error> { + restrict_owner_only_dir(state_dir).await?; + write_private_file( + &state_dir.join(SANDBOX_REQUEST_FILE), + sandbox.encode_to_vec(), + ) + .await +} + +async fn read_sandbox_request(path: &Path) -> Result<Sandbox, std::io::Error> { + let bytes = tokio::fs::read(path).await?; + Sandbox::decode(bytes.as_slice()).map_err(|err| { + std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("decode persisted sandbox request: {err}"), + ) + }) +} + +async fn write_private_file(path: &Path, bytes: Vec<u8>) -> Result<(), std::io::Error> { + tokio::fs::write(path, bytes).await?; + restrict_owner_read_write(path).await +} + +#[cfg(unix)] +async fn restrict_owner_read_write(path: &Path) -> Result<(), std::io::Error> { + tokio::fs::set_permissions(path, fs::Permissions::from_mode(0o600)).await +} + +#[cfg(not(unix))] +async fn restrict_owner_read_write(_path: &Path) -> Result<(), std::io::Error> { + Ok(()) +} + +#[allow(clippy::result_large_err)] +fn validate_restored_sandbox_state( + root: &Path, + state_dir: &Path, + sandbox: &Sandbox, +) -> Result<(), Status> { + validate_sandbox_id(&sandbox.id)?; + validate_sandbox_state_dir(root, state_dir)?; + let Some(dir_name) = state_dir.file_name().and_then(|name| name.to_str()) else { + return Err(Status::internal(format!( + "sandbox state path has no valid directory 
name: {}", + state_dir.display() + ))); + }; + if dir_name != sandbox.id { + return Err(Status::internal(format!( + "sandbox state dir '{}' does not match persisted sandbox id '{}'", + dir_name, sandbox.id + ))); + } + Ok(()) +} + +#[derive(Debug, Clone)] +struct GuestTlsMaterials { + ca: Vec<u8>, + cert: Vec<u8>, + key: Vec<u8>, +} + +async fn read_guest_tls_materials(paths: &VmDriverTlsPaths) -> Result<GuestTlsMaterials, String> { + let ca = tokio::fs::read(&paths.ca) + .await + .map_err(|err| format!("read {}: {err}", paths.ca.display()))?; + let cert = tokio::fs::read(&paths.cert) + .await + .map_err(|err| format!("read {}: {err}", paths.cert.display()))?; + let key = tokio::fs::read(&paths.key) + .await + .map_err(|err| format!("read {}: {err}", paths.key.display()))?; + Ok(GuestTlsMaterials { ca, cert, key }) +} + +async fn overlay_template_image_ready(path: &Path, size_bytes: u64) -> Result<bool, String> { + match tokio::fs::metadata(path).await { + Ok(metadata) => Ok(metadata.is_file() && metadata.len() == size_bytes), + Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(false), + Err(err) => Err(format!("stat overlay template {}: {err}", path.display())), + } +} + +fn ensure_sandbox_overlay_template_image( + template_path: &Path, + size_bytes: u64, +) -> Result<(), String> { + if let Ok(metadata) = fs::metadata(template_path) + && metadata.is_file() + && metadata.len() == size_bytes + { + return Ok(()); + } + + let parent = template_path.parent().ok_or_else(|| { + format!( + "overlay template path has no parent: {}", + template_path.display() + ) + })?; + fs::create_dir_all(parent).map_err(|err| { + format!( + "create overlay template cache dir {}: {err}", + parent.display() + ) + })?; + + let staging_image = parent.join(format!( + ".{}.staging-{}-{}", + template_path + .file_name() + .and_then(|name| name.to_str()) + .unwrap_or("overlay-template.ext4"), + std::process::id(), + openshell_core::time::now_ms() + )); + + let result = (|| { + create_empty_sandbox_overlay_image(&staging_image, 
size_bytes)?; + fs::rename(&staging_image, template_path).map_err(|err| { + format!( + "move overlay template {} to {}: {err}", + staging_image.display(), + template_path.display() + ) + }) + })(); + + if result.is_err() { + let _ = fs::remove_file(&staging_image); + } + result +} + +fn create_empty_sandbox_overlay_image(overlay_disk: &Path, size_bytes: u64) -> Result<(), String> { + let staging_dir = overlay_staging_dir(overlay_disk); + if staging_dir.exists() { + fs::remove_dir_all(&staging_dir) + .map_err(|err| format!("remove stale overlay staging dir: {err}"))?; + } + + let result = (|| { + fs::create_dir_all(staging_dir.join("upper")) + .map_err(|err| format!("create overlay upper dir: {err}"))?; + fs::create_dir_all(staging_dir.join("work")) + .map_err(|err| format!("create overlay work dir: {err}"))?; + fs::create_dir_all(staging_dir.join("config")) + .map_err(|err| format!("create overlay config dir: {err}"))?; + + create_ext4_image_from_dir_with_size(&staging_dir, overlay_disk, size_bytes) + })(); + + let _ = fs::remove_dir_all(&staging_dir); + result +} + +fn create_sandbox_overlay_image_from_template( + template_path: &Path, + overlay_disk: &Path, + tls_materials: Option<&GuestTlsMaterials>, + sandbox_token: Option<&str>, +) -> Result<(), String> { + clone_or_copy_sparse_file(template_path, overlay_disk)?; + if let Some(tls) = tls_materials { + inject_guest_tls_materials(overlay_disk, tls)?; + } + if let Some(token) = sandbox_token { + inject_guest_sandbox_token(overlay_disk, token)?; + } + Ok(()) +} + +fn prepare_sandbox_overlay_image( + template_path: &Path, + overlay_disk: &Path, + tls_materials: Option<&GuestTlsMaterials>, + sandbox_token: Option<&str>, + preparation: OverlayPreparation, + expected_size_bytes: u64, +) -> Result<(), String> { + if preparation == OverlayPreparation::PreserveExisting { + match fs::metadata(overlay_disk) { + Ok(metadata) if metadata.is_file() && metadata.len() == expected_size_bytes => { + if let Some(tls) = 
tls_materials { + inject_guest_tls_materials(overlay_disk, tls)?; + } + if let Some(token) = sandbox_token { + inject_guest_sandbox_token(overlay_disk, token)?; + } + return Ok(()); + } + Ok(metadata) if metadata.is_file() => { + return Err(format!( + "existing overlay disk '{}' has size {}, expected {}", + overlay_disk.display(), + metadata.len(), + expected_size_bytes + )); + } + Ok(_) => { + return Err(format!( + "existing overlay path '{}' is not a file", + overlay_disk.display() + )); + } + Err(err) if err.kind() == std::io::ErrorKind::NotFound => {} + Err(err) => { + return Err(format!( + "stat overlay disk {}: {err}", + overlay_disk.display() + )); + } + } + } - let mut pairs = environment.into_iter().collect::>(); - pairs.sort_by(|left, right| left.0.cmp(&right.0)); - pairs - .into_iter() - .map(|(key, value)| format!("{key}={value}")) - .collect() + create_sandbox_overlay_image_from_template( + template_path, + overlay_disk, + tls_materials, + sandbox_token, + ) } -fn sandbox_log_level(sandbox: &Sandbox, default_level: &str) -> String { - sandbox - .spec - .as_ref() - .map(|spec| spec.log_level.as_str()) - .filter(|level| !level.is_empty()) - .unwrap_or(default_level) - .to_string() +fn inject_guest_tls_materials( + overlay_disk: &Path, + materials: &GuestTlsMaterials, +) -> Result<(), String> { + write_rootfs_image_file( + overlay_disk, + &overlay_upper_path(GUEST_TLS_CA_PATH), + &materials.ca, + )?; + write_rootfs_image_file( + overlay_disk, + &overlay_upper_path(GUEST_TLS_CERT_PATH), + &materials.cert, + )?; + let key_path = overlay_upper_path(GUEST_TLS_KEY_PATH); + write_rootfs_image_file(overlay_disk, &key_path, &materials.key)?; + set_rootfs_image_file_mode(overlay_disk, &key_path, 0o600) } -fn sandboxes_root_dir(root: &Path) -> PathBuf { - root.join("sandboxes") +fn inject_guest_sandbox_token(overlay_disk: &Path, token: &str) -> Result<(), String> { + let token_path = overlay_upper_path(GUEST_SANDBOX_TOKEN_PATH); + 
write_rootfs_image_file(overlay_disk, &token_path, format!("{token}\n").as_bytes())?; + set_rootfs_image_file_mode(overlay_disk, &token_path, 0o600) } -fn sandbox_state_dir(root: &Path, sandbox_id: &str) -> PathBuf { - sandboxes_root_dir(root).join(sandbox_id) +fn overlay_upper_path(guest_path: &str) -> String { + format!("/upper/{}", guest_path.trim_start_matches('/')) } -fn image_cache_root_dir(root: &Path) -> PathBuf { - root.join(IMAGE_CACHE_ROOT_DIR) -} +fn create_image_prep_disk( + image_path: &Path, + size_bytes: u64, + payload: &GuestImagePayload, +) -> Result<(), String> { + let staging_dir = overlay_staging_dir(image_path); + if staging_dir.exists() { + fs::remove_dir_all(&staging_dir) + .map_err(|err| format!("remove stale image-prep staging dir: {err}"))?; + } -fn image_cache_dir(root: &Path, image_identity: &str) -> PathBuf { - image_cache_root_dir(root).join(sanitize_image_identity(image_identity)) + let result = (|| { + fs::create_dir_all(staging_dir.join("upper").join("srv")) + .map_err(|err| format!("create image-prep env dir: {err}"))?; + fs::create_dir_all(staging_dir.join("work")) + .map_err(|err| format!("create image-prep work dir: {err}"))?; + fs::create_dir_all(staging_dir.join("config")) + .map_err(|err| format!("create image-prep config dir: {err}"))?; + stage_guest_image_payload(&staging_dir, payload)?; + create_ext4_image_from_dir_with_size(&staging_dir, image_path, size_bytes) + })(); + + let _ = fs::remove_dir_all(&staging_dir); + result } -fn image_cache_rootfs_archive(root: &Path, image_identity: &str) -> PathBuf { - image_cache_dir(root, image_identity).join(IMAGE_CACHE_ROOTFS_ARCHIVE) -} +fn stage_guest_image_payload( + staging_dir: &Path, + payload: &GuestImagePayload, +) -> Result<(), String> { + let image_dir = staging_dir.join("config").join(GUEST_IMAGE_CONFIG_DIR); + fs::create_dir_all(&image_dir).map_err(|err| { + format!( + "create guest image config dir {}: {err}", + image_dir.display() + ) + })?; + 
fs::write(image_dir.join("ref"), payload.image_ref.as_bytes()) + .map_err(|err| format!("write guest image ref: {err}"))?; + fs::write( + image_dir.join("identity"), + payload.image_identity.as_bytes(), + ) + .map_err(|err| format!("write guest image identity: {err}"))?; -fn image_cache_staging_dir(root: &Path, image_identity: &str) -> PathBuf { - image_cache_root_dir(root).join(format!( - "{}.staging-{}", - sanitize_image_identity(image_identity), - unique_image_cache_suffix() - )) + match &payload.source { + GuestImagePayloadSource::RegistryOciLayout { layout_dir } => { + fs::write(image_dir.join("source"), b"oci-layout") + .map_err(|err| format!("write guest image source: {err}"))?; + copy_dir_recursive(layout_dir, &image_dir.join(GUEST_IMAGE_OCI_LAYOUT_DIR))?; + } + GuestImagePayloadSource::LocalDocker { rootfs_archive } => { + fs::write(image_dir.join("source"), b"local-docker") + .map_err(|err| format!("write guest image source: {err}"))?; + let dest = image_dir.join(IMAGE_EXPORT_ROOTFS_ARCHIVE); + fs::copy(rootfs_archive, &dest).map_err(|err| { + format!( + "copy guest image rootfs archive {} to {}: {err}", + rootfs_archive.display(), + dest.display() + ) + })?; + } + } + + Ok(()) } -fn sanitize_image_identity(image_identity: &str) -> String { - image_identity - .chars() - .map(|ch| { - if ch.is_ascii_alphanumeric() || ch == '-' || ch == '_' || ch == '.' { - ch - } else { - '-' +fn copy_dir_recursive(source: &Path, dest: &Path) -> Result<(), String> { + fs::create_dir_all(dest).map_err(|err| format!("create {}: {err}", dest.display()))?; + for entry in fs::read_dir(source).map_err(|err| format!("read {}: {err}", source.display()))? 
{ + let entry = entry.map_err(|err| format!("read {}: {err}", source.display()))?; + let source_path = entry.path(); + let dest_path = dest.join(entry.file_name()); + let metadata = fs::symlink_metadata(&source_path) + .map_err(|err| format!("stat {}: {err}", source_path.display()))?; + if metadata.file_type().is_dir() { + copy_dir_recursive(&source_path, &dest_path)?; + } else if metadata.file_type().is_file() { + if let Some(parent) = dest_path.parent() { + fs::create_dir_all(parent) + .map_err(|err| format!("create {}: {err}", parent.display()))?; } - }) - .collect() + fs::copy(&source_path, &dest_path).map_err(|err| { + format!( + "copy {} to {}: {err}", + source_path.display(), + dest_path.display() + ) + })?; + } else { + return Err(format!( + "unsupported payload entry type at {}", + source_path.display() + )); + } + } + Ok(()) } -fn unique_image_cache_suffix() -> String { - let counter = IMAGE_CACHE_BUILD_COUNTER.fetch_add(1, Ordering::Relaxed); - format!("{}-{counter}", current_time_ms()) +fn prepared_image_disk_size_bytes( + payload: &GuestImagePayload, + minimum_size_bytes: u64, +) -> Result<u64, String> { + let payload_size = match &payload.source { + GuestImagePayloadSource::RegistryOciLayout { layout_dir } => dir_size_bytes(layout_dir)?, + GuestImagePayloadSource::LocalDocker { rootfs_archive } => fs::metadata(rootfs_archive) + .map_err(|err| format!("stat {}: {err}", rootfs_archive.display()))? 
+ .len(), + }; + let requested = payload_size + .saturating_mul(3) + .saturating_add(512 * 1024 * 1024); + Ok(minimum_size_bytes.max(requested)) } -async fn write_sandbox_image_metadata( - state_dir: &Path, - image_ref: &str, - image_identity: &str, -) -> Result<(), std::io::Error> { - tokio::fs::write( - state_dir.join(IMAGE_IDENTITY_FILE), - format!("{image_identity}\n"), - ) - .await?; - tokio::fs::write( - state_dir.join(IMAGE_REFERENCE_FILE), - format!("{image_ref}\n"), - ) - .await?; - - Ok(()) +fn dir_size_bytes(path: &Path) -> Result<u64, String> { + let metadata = + fs::symlink_metadata(path).map_err(|err| format!("stat {}: {err}", path.display()))?; + if metadata.file_type().is_file() { + return Ok(metadata.len()); + } + if metadata.file_type().is_symlink() { + return Ok(0); + } + let mut total = 0_u64; + for entry in fs::read_dir(path).map_err(|err| format!("read {}: {err}", path.display()))? { + let entry = entry.map_err(|err| format!("read {}: {err}", path.display()))?; + total = total.saturating_add(dir_size_bytes(&entry.path())?); + } + Ok(total) } -async fn prepare_guest_tls_materials( - rootfs: &Path, - paths: &VmDriverTlsPaths, -) -> Result<(), std::io::Error> { - let guest_tls_dir = rootfs.join(GUEST_TLS_DIR.trim_start_matches('/')); - tokio::fs::create_dir_all(&guest_tls_dir).await?; +#[cfg(test)] +fn stage_guest_tls_materials( + staging_dir: &Path, + materials: &GuestTlsMaterials, +) -> Result<(), String> { + let tls_dir = staging_dir + .join("upper") + .join(GUEST_TLS_CA_PATH.trim_start_matches('/')) + .parent() + .ok_or_else(|| "guest TLS CA path has no parent".to_string())? 
+ .to_path_buf(); + fs::create_dir_all(&tls_dir) + .map_err(|err| format!("create guest TLS dir {}: {err}", tls_dir.display()))?; + + let ca_path = staging_dir + .join("upper") + .join(GUEST_TLS_CA_PATH.trim_start_matches('/')); + let cert_path = staging_dir + .join("upper") + .join(GUEST_TLS_CERT_PATH.trim_start_matches('/')); + let key_path = staging_dir + .join("upper") + .join(GUEST_TLS_KEY_PATH.trim_start_matches('/')); + fs::write(&ca_path, &materials.ca) + .map_err(|err| format!("write guest TLS CA {}: {err}", ca_path.display()))?; + fs::write(&cert_path, &materials.cert) + .map_err(|err| format!("write guest TLS cert {}: {err}", cert_path.display()))?; + fs::write(&key_path, &materials.key) + .map_err(|err| format!("write guest TLS key {}: {err}", key_path.display()))?; + + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt as _; + + fs::set_permissions(&key_path, fs::Permissions::from_mode(0o600)) + .map_err(|err| format!("chmod guest TLS key {}: {err}", key_path.display()))?; + } - copy_guest_tls_material(&paths.ca, &guest_tls_dir.join("ca.crt"), 0o644).await?; - copy_guest_tls_material(&paths.cert, &guest_tls_dir.join("tls.crt"), 0o644).await?; - copy_guest_tls_material(&paths.key, &guest_tls_dir.join("tls.key"), 0o600).await?; Ok(()) } -async fn copy_guest_tls_material( - source: &Path, - dest: &Path, - mode: u32, -) -> Result<(), std::io::Error> { - tokio::fs::copy(source, dest).await?; - tokio::fs::set_permissions(dest, fs::Permissions::from_mode(mode)).await?; - Ok(()) +fn overlay_staging_dir(overlay_disk: &Path) -> PathBuf { + let parent = overlay_disk.parent().unwrap_or_else(|| Path::new(".")); + parent.join(format!( + ".openshell-overlay-staging-{}-{}", + std::process::id(), + openshell_core::time::now_ms() + )) } async fn terminate_vm_process(child: &mut Child) -> Result<(), std::io::Error> { @@ -2403,40 +4303,205 @@ fn error_condition(reason: &str, message: &str) -> SandboxCondition { } fn platform_event(source: &str, event_type: &str, 
reason: &str, message: String) -> PlatformEvent { - PlatformEvent { - timestamp_ms: current_time_ms(), + let mut event = PlatformEvent { + timestamp_ms: openshell_core::time::now_ms(), source: source.to_string(), r#type: event_type.to_string(), reason: reason.to_string(), message, metadata: HashMap::new(), + }; + attach_vm_progress_metadata(&mut event); + event +} + +fn attach_vm_progress_metadata(event: &mut PlatformEvent) { + if event.source != "vm" { + return; + } + + match event.reason.as_str() { + "Scheduled" => { + mark_progress_complete( + &mut event.metadata, + PROGRESS_STEP_REQUESTING_SANDBOX, + "Sandbox allocated", + ); + mark_progress_active(&mut event.metadata, PROGRESS_STEP_PULLING_IMAGE); + } + "Pulling" => { + mark_progress_active(&mut event.metadata, PROGRESS_STEP_PULLING_IMAGE); + if let Some(image_ref) = event.metadata.get("image_ref").cloned() { + mark_progress_detail(&mut event.metadata, image_ref); + } else if let Some(image_ref) = pulling_image_from_message(&event.message) { + mark_progress_detail(&mut event.metadata, image_ref); + } + } + "Pulled" => { + let label = pulled_label(event); + mark_progress_complete(&mut event.metadata, PROGRESS_STEP_PULLING_IMAGE, label); + mark_progress_active(&mut event.metadata, PROGRESS_STEP_STARTING_SANDBOX); + } + "PullingLayer" => { + if let Some(detail) = pulling_layer_detail(&event.metadata) { + mark_progress_detail(&mut event.metadata, detail); + } + } + "ResolvingImage" => mark_progress_detail(&mut event.metadata, "Resolving image"), + "AuthenticatingRegistry" => { + mark_progress_detail(&mut event.metadata, "Authenticating registry"); + } + "FetchingManifest" => mark_progress_detail(&mut event.metadata, "Fetching image manifest"), + "CacheHit" => mark_progress_detail(&mut event.metadata, "Using cached root disk"), + "CacheMiss" => mark_progress_detail(&mut event.metadata, "Preparing image cache"), + "WaitingForImageCacheLock" => { + mark_progress_detail(&mut event.metadata, "Waiting for image cache 
lock"); + } + "ExportingRootfs" => { + mark_progress_detail(&mut event.metadata, "Exporting local image rootfs"); + } + "PreparingRootfs" => mark_progress_detail(&mut event.metadata, "Preparing rootfs"), + "CreatingRootDisk" => mark_progress_detail(&mut event.metadata, "Formatting root disk"), + "PreparingOverlay" => mark_progress_detail(&mut event.metadata, "Preparing overlay disk"), + "Started" => mark_progress_detail(&mut event.metadata, "Waiting for VM supervisor"), + _ => {} } } -fn current_time_ms() -> i64 { - std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .map_or(0, |duration| { - i64::try_from(duration.as_millis()).unwrap_or(i64::MAX) - }) +fn pulling_image_from_message(message: &str) -> Option<String> { + let image = message + .strip_prefix("Pulling image ") + .map(str::trim) + .map(|value| value.trim_matches('"'))?; + (!image.is_empty()).then(|| image.to_string()) +} + +fn pulled_label(event: &PlatformEvent) -> String { + event + .metadata + .get("image_size_bytes") + .and_then(|value| value.parse::<u64>().ok()) + .map_or_else( + || "Image pulled".to_string(), + |bytes| format!("Image pulled ({})", format_bytes(bytes)), + ) +} + +fn pulling_layer_detail(metadata: &HashMap<String, String>) -> Option<String> { + let index = metadata.get("layer_index")?; + let total = metadata.get("layer_total")?; + let size = metadata + .get("layer_size_bytes") + .and_then(|value| value.parse::<u64>().ok()) + .map(format_bytes); + Some(size.map_or_else( + || format!("Layer {index}/{total}"), + |size| format!("Layer {index}/{total} ({size})"), + )) +} + +fn format_bytes(bytes: u64) -> String { + const KB: u64 = 1024; + const MB: u64 = 1024 * KB; + const GB: u64 = 1024 * MB; + + if bytes >= GB { + #[allow(clippy::cast_precision_loss)] + let gb = bytes as f64 / GB as f64; + format!("{gb:.1} GB") + } else if bytes >= MB { + format!("{} MB", bytes / MB) + } else if bytes >= KB { + format!("{} KB", bytes / KB) + } else { + format!("{bytes} B") + } } #[cfg(test)] mod tests { use super::*; use 
crate::gpu::{SubnetAllocator, allocate_vsock_cid, mac_from_sandbox_id, tap_device_name}; + use openshell_core::progress::{ + PROGRESS_ACTIVE_DETAIL_KEY, PROGRESS_ACTIVE_STEP_KEY, PROGRESS_COMPLETE_LABEL_KEY, + PROGRESS_COMPLETE_STEP_KEY, + }; use openshell_core::proto::compute::v1::{ DriverSandboxSpec as SandboxSpec, DriverSandboxTemplate as SandboxTemplate, }; use prost_types::{Struct, Value, value::Kind}; use std::fs; + use std::path::Path; use std::sync::atomic::{AtomicU64, Ordering}; use std::time::{SystemTime, UNIX_EPOCH}; use tonic::Code; + #[test] + fn vm_pulling_layer_event_adds_progress_detail_metadata() { + let mut event = platform_event( + "vm", + "Normal", + "PullingLayer", + "Pulling layer 3/8 for image".to_string(), + ); + event.metadata = HashMap::from([ + ("layer_index".to_string(), "3".to_string()), + ("layer_total".to_string(), "8".to_string()), + ("layer_size_bytes".to_string(), "44040192".to_string()), + ]); + + attach_vm_progress_metadata(&mut event); + + assert_eq!( + event + .metadata + .get(PROGRESS_ACTIVE_DETAIL_KEY) + .map(String::as_str), + Some("Layer 3/8 (42 MB)") + ); + } + + #[test] + fn vm_pulled_event_adds_completed_image_progress_metadata() { + let mut event = platform_event( + "vm", + "Normal", + "Pulled", + "Successfully pulled image".to_string(), + ); + event + .metadata + .insert("image_size_bytes".to_string(), "44040192".to_string()); + + attach_vm_progress_metadata(&mut event); + + assert_eq!( + event + .metadata + .get(PROGRESS_COMPLETE_STEP_KEY) + .map(String::as_str), + Some(PROGRESS_STEP_PULLING_IMAGE) + ); + assert_eq!( + event + .metadata + .get(PROGRESS_COMPLETE_LABEL_KEY) + .map(String::as_str), + Some("Image pulled (42 MB)") + ); + assert_eq!( + event + .metadata + .get(PROGRESS_ACTIVE_STEP_KEY) + .map(String::as_str), + Some(PROGRESS_STEP_STARTING_SANDBOX) + ); + } + #[test] fn validate_vm_sandbox_rejects_gpu_when_not_enabled() { let sandbox = Sandbox { + id: "sandbox-123".to_string(), spec: Some(SandboxSpec { gpu: 
true, ..Default::default() @@ -2452,6 +4517,7 @@ mod tests { #[test] fn validate_vm_sandbox_accepts_gpu_when_enabled() { let sandbox = Sandbox { + id: "sandbox-123".to_string(), spec: Some(SandboxSpec { gpu: true, ..Default::default() @@ -2464,6 +4530,7 @@ mod tests { #[test] fn validate_vm_sandbox_rejects_gpu_device_without_gpu() { let sandbox = Sandbox { + id: "sandbox-123".to_string(), spec: Some(SandboxSpec { gpu: false, gpu_device: "0000:2d:00.0".to_string(), @@ -2480,6 +4547,7 @@ mod tests { #[test] fn validate_vm_sandbox_rejects_platform_config() { let sandbox = Sandbox { + id: "sandbox-123".to_string(), spec: Some(SandboxSpec { template: Some(SandboxTemplate { platform_config: Some(Struct { @@ -2504,18 +4572,213 @@ mod tests { } #[test] - fn validate_vm_sandbox_accepts_template_image() { - let sandbox = Sandbox { - spec: Some(SandboxSpec { - template: Some(SandboxTemplate { - image: "ghcr.io/example/sandbox:latest".to_string(), - ..Default::default() - }), - ..Default::default() - }), - ..Default::default() - }; - validate_vm_sandbox(&sandbox, false).expect("template.image should be accepted"); + fn validate_vm_sandbox_accepts_template_image() { + let sandbox = Sandbox { + id: "sandbox-123".to_string(), + spec: Some(SandboxSpec { + template: Some(SandboxTemplate { + image: "ghcr.io/example/sandbox:latest".to_string(), + ..Default::default() + }), + ..Default::default() + }), + ..Default::default() + }; + validate_vm_sandbox(&sandbox, false).expect("template.image should be accepted"); + } + + #[test] + fn validate_vm_sandbox_accepts_template_resources_as_noop() { + use openshell_core::proto::compute::v1::DriverResourceRequirements; + + let sandbox = Sandbox { + id: "sandbox-123".to_string(), + spec: Some(SandboxSpec { + template: Some(SandboxTemplate { + resources: Some(DriverResourceRequirements { + cpu_limit: "2".to_string(), + memory_limit: "4Gi".to_string(), + ..Default::default() + }), + ..Default::default() + }), + ..Default::default() + }), + 
..Default::default() + }; + validate_vm_sandbox(&sandbox, false) + .expect("template.resources should be accepted and ignored"); + } + + #[test] + fn validate_vm_sandbox_rejects_path_unsafe_ids() { + let mut unsafe_ids = [ + "", + ".", + "..", + "../escape", + "/tmp/escape", + "nested/path", + "nested\\path", + "bad\nid", + "bad id", + "unicodé", + ] + .into_iter() + .map(str::to_string) + .collect::<Vec<_>>(); + unsafe_ids.push("a".repeat(129)); + + for sandbox_id in unsafe_ids { + let sandbox = Sandbox { + id: sandbox_id.clone(), + spec: Some(SandboxSpec { + template: Some(SandboxTemplate { + image: "ghcr.io/example/sandbox:latest".to_string(), + ..Default::default() + }), + ..Default::default() + }), + ..Default::default() + }; + let err = validate_vm_sandbox(&sandbox, false) + .expect_err("path-unsafe sandbox id should be rejected"); + assert_eq!(err.code(), Code::InvalidArgument, "id={sandbox_id:?}"); + assert!(err.message().contains("sandbox id"), "id={sandbox_id:?}"); + } + } + + #[test] + fn sandbox_state_dir_rejects_path_unsafe_ids() { + let err = sandbox_state_dir(Path::new("/tmp/openshell-vm"), "../escape") + .expect_err("path traversal should be rejected"); + assert_eq!(err.code(), Code::InvalidArgument); + } + + #[test] + fn sandbox_runtime_disk_paths_use_per_sandbox_overlay() { + let driver_state = Path::new("/tmp/openshell-vm"); + let state_dir = driver_state.join("sandboxes").join("sandbox-123"); + + let disks = sandbox_runtime_disk_paths(&state_dir); + + assert_eq!(disks.overlay_disk, state_dir.join(SANDBOX_OVERLAY_IMAGE)); + } + + #[test] + fn overlay_template_image_is_keyed_by_size_and_layout() { + let path = overlay_template_image(Path::new("/tmp/openshell-vm"), 4 * 1024 * 1024); + + assert_eq!( + path, + Path::new("/tmp/openshell-vm") + .join(IMAGE_CACHE_ROOT_DIR) + .join(OVERLAY_TEMPLATE_CACHE_DIR) + .join(OVERLAY_TEMPLATE_CACHE_LAYOUT_VERSION) + .join("4194304.ext4") + ); + } + + #[tokio::test] + async fn 
sandbox_request_metadata_round_trips_for_resume() { + let base = unique_temp_dir(); + let state_dir = base.join("sandboxes").join("sandbox-123"); + std::fs::create_dir_all(&state_dir).unwrap(); + let sandbox = Sandbox { + id: "sandbox-123".to_string(), + name: "resume-sandbox".to_string(), + namespace: "vm-dev".to_string(), + spec: Some(SandboxSpec { + environment: HashMap::from([("KEY".to_string(), "value".to_string())]), + template: Some(SandboxTemplate { + image: "ghcr.io/example/sandbox:latest".to_string(), + ..Default::default() + }), + ..Default::default() + }), + ..Default::default() + }; + + write_sandbox_request(&state_dir, &sandbox) + .await + .expect("write sandbox request"); + let restored = read_sandbox_request(&state_dir.join(SANDBOX_REQUEST_FILE)) + .await + .expect("read sandbox request"); + + assert_eq!(restored, sandbox); + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt as _; + + let dir_mode = std::fs::metadata(&state_dir).unwrap().permissions().mode() & 0o777; + let file_mode = std::fs::metadata(state_dir.join(SANDBOX_REQUEST_FILE)) + .unwrap() + .permissions() + .mode() + & 0o777; + assert_eq!(dir_mode, 0o700); + assert_eq!(file_mode, 0o600); + } + validate_restored_sandbox_state(&base, &state_dir, &restored) + .expect("restored state should validate"); + + let _ = std::fs::remove_dir_all(base); + } + + #[test] + fn prepare_sandbox_overlay_preserves_existing_overlay_on_resume() { + let base = unique_temp_dir(); + std::fs::create_dir_all(&base).unwrap(); + let template = base.join("template.ext4"); + let overlay = base.join("overlay.ext4"); + std::fs::write(&template, b"fresh-overlay").unwrap(); + std::fs::write(&overlay, b"saved-overlay").unwrap(); + + prepare_sandbox_overlay_image( + &template, + &overlay, + None, + None, + OverlayPreparation::PreserveExisting, + "saved-overlay".len() as u64, + ) + .expect("preserve existing overlay"); + + assert_eq!(std::fs::read(&overlay).unwrap(), b"saved-overlay"); + + let _ = 
std::fs::remove_dir_all(base); + } + + #[test] + fn prepare_sandbox_overlay_creates_missing_overlay_on_resume() { + let base = unique_temp_dir(); + std::fs::create_dir_all(&base).unwrap(); + let template = base.join("template.ext4"); + let overlay = base.join("overlay.ext4"); + std::fs::write(&template, b"fresh-overlay").unwrap(); + + prepare_sandbox_overlay_image( + &template, + &overlay, + None, + None, + OverlayPreparation::PreserveExisting, + "fresh-overlay".len() as u64, + ) + .expect("create missing overlay"); + + assert_eq!(std::fs::read(&overlay).unwrap(), b"fresh-overlay"); + + let _ = std::fs::remove_dir_all(base); + } + + #[test] + fn overlay_upper_path_targets_overlay_upperdir() { + assert_eq!( + overlay_upper_path(GUEST_TLS_KEY_PATH), + "/upper/opt/openshell/tls/tls.key" + ); } #[test] @@ -2629,6 +4892,76 @@ mod tests { assert!(driver.resolved_sandbox_image(&sandbox).is_none()); } + #[test] + fn bootstrap_image_ref_prefers_explicit_bootstrap_image() { + let driver = VmDriver { + config: VmDriverConfig { + default_image: "openshell/sandbox:default".to_string(), + bootstrap_image: "openshell/sandbox-bootstrap:latest".to_string(), + ..Default::default() + }, + launcher_bin: PathBuf::from("/tmp/openshell-driver-vm"), + registry: Arc::new(Mutex::new(HashMap::new())), + image_cache_lock: Arc::new(Mutex::new(())), + events: broadcast::channel(WATCH_BUFFER).0, + gpu_inventory: None, + subnet_allocator: Arc::new(std::sync::Mutex::new(SubnetAllocator::new( + Ipv4Addr::new(10, 0, 128, 0), + 17, + ))), + }; + + assert_eq!( + driver.bootstrap_image_ref("ghcr.io/example/app:latest"), + "openshell/sandbox-bootstrap:latest" + ); + } + + #[test] + fn bootstrap_image_ref_falls_back_to_default_image() { + let driver = VmDriver { + config: VmDriverConfig { + default_image: "openshell/sandbox:default".to_string(), + ..Default::default() + }, + launcher_bin: PathBuf::from("/tmp/openshell-driver-vm"), + registry: Arc::new(Mutex::new(HashMap::new())), + image_cache_lock: 
Arc::new(Mutex::new(())), + events: broadcast::channel(WATCH_BUFFER).0, + gpu_inventory: None, + subnet_allocator: Arc::new(std::sync::Mutex::new(SubnetAllocator::new( + Ipv4Addr::new(10, 0, 128, 0), + 17, + ))), + }; + + assert_eq!( + driver.bootstrap_image_ref("ghcr.io/example/app:latest"), + "openshell/sandbox:default" + ); + } + + #[test] + fn bootstrap_image_ref_falls_back_to_requested_image() { + let driver = VmDriver { + config: VmDriverConfig::default(), + launcher_bin: PathBuf::from("/tmp/openshell-driver-vm"), + registry: Arc::new(Mutex::new(HashMap::new())), + image_cache_lock: Arc::new(Mutex::new(())), + events: broadcast::channel(WATCH_BUFFER).0, + gpu_inventory: None, + subnet_allocator: Arc::new(std::sync::Mutex::new(SubnetAllocator::new( + Ipv4Addr::new(10, 0, 128, 0), + 17, + ))), + }; + + assert_eq!( + driver.bootstrap_image_ref("ghcr.io/example/app:latest"), + "ghcr.io/example/app:latest" + ); + } + #[test] fn merged_environment_prefers_spec_values() { let sandbox = Sandbox { @@ -2654,12 +4987,11 @@ mod tests { fn build_guest_environment_sets_supervisor_defaults() { let config = VmDriverConfig { openshell_endpoint: "http://127.0.0.1:8080".to_string(), - ssh_handshake_secret: "secret".to_string(), ..Default::default() }; let sandbox = Sandbox { id: "sandbox-123".to_string(), - name: "sandbox-123".to_string(), + name: "breezy-rhinoceros".to_string(), spec: Some(SandboxSpec::default()), ..Default::default() }; @@ -2670,20 +5002,48 @@ mod tests { "OPENSHELL_ENDPOINT=http://{GVPROXY_HOST_LOOPBACK_ALIAS}:8080/" ))); assert!(env.contains(&"OPENSHELL_SANDBOX_ID=sandbox-123".to_string())); + assert!(env.contains(&"OPENSHELL_SANDBOX=breezy-rhinoceros".to_string())); assert!(env.contains(&format!( "OPENSHELL_SSH_SOCKET_PATH={GUEST_SSH_SOCKET_PATH}" ))); - assert!( - env.contains(&"OPENSHELL_SSH_HANDSHAKE_SECRET=secret".to_string()), - "SSH handshake secret must be passed to the guest" - ); + } + + #[test] + fn 
build_guest_environment_uses_token_file_without_raw_token_env() { + let config = VmDriverConfig { + openshell_endpoint: "http://127.0.0.1:8080".to_string(), + ..Default::default() + }; + let sandbox = Sandbox { + id: "sandbox-123".to_string(), + name: "sandbox-123".to_string(), + spec: Some(SandboxSpec { + sandbox_token: "secret.jwt.value".to_string(), + environment: HashMap::from([( + openshell_core::sandbox_env::SANDBOX_TOKEN.to_string(), + "user-provided-token".to_string(), + )]), + ..Default::default() + }), + ..Default::default() + }; + + let env = build_guest_environment(&sandbox, &config, None); + + assert!(!env.iter().any(|v| v.starts_with(&format!( + "{}=", + openshell_core::sandbox_env::SANDBOX_TOKEN + )))); + assert!(env.contains(&format!( + "{}={GUEST_SANDBOX_TOKEN_PATH}", + openshell_core::sandbox_env::SANDBOX_TOKEN_FILE + ))); } #[test] fn build_guest_environment_uses_endpoint_override_for_tap() { let config = VmDriverConfig { openshell_endpoint: "http://127.0.0.1:8080".to_string(), - ssh_handshake_secret: "secret".to_string(), ..Default::default() }; let sandbox = Sandbox { @@ -2779,9 +5139,9 @@ mod tests { } #[test] - fn local_docker_image_platform_mismatch_checks_guest_platform() { + fn local_image_platform_mismatch_checks_guest_platform() { assert!( - local_docker_image_platform_mismatch( + local_image_platform_mismatch( "openshell/sandbox-from:123", Some("linux"), Some(linux_oci_arch()), @@ -2789,7 +5149,7 @@ mod tests { .is_none() ); - let err = local_docker_image_platform_mismatch( + let err = local_image_platform_mismatch( "openshell/sandbox-from:123", Some("linux"), Some("wrong-arch"), @@ -2798,7 +5158,7 @@ mod tests { assert!(err.contains("wrong-arch")); assert!(err.contains(linux_oci_arch())); - let err = local_docker_image_platform_mismatch("openshell/sandbox-from:123", None, None) + let err = local_image_platform_mismatch("openshell/sandbox-from:123", None, None) .expect("unknown platform should be reported"); 
assert!(err.contains("unknown/unknown")); } @@ -2886,7 +5246,6 @@ mod tests { fn build_guest_environment_includes_tls_paths_for_https_endpoint() { let config = VmDriverConfig { openshell_endpoint: "https://127.0.0.1:8443".to_string(), - ssh_handshake_secret: "secret".to_string(), guest_tls_ca: Some(PathBuf::from("/host/ca.crt")), guest_tls_cert: Some(PathBuf::from("/host/tls.crt")), guest_tls_key: Some(PathBuf::from("/host/tls.key")), @@ -2919,9 +5278,14 @@ mod tests { #[tokio::test] async fn delete_sandbox_keeps_registry_entry_when_cleanup_fails() { + let base = unique_temp_dir(); + let driver_state = base.join("driver-state"); let (events, _) = broadcast::channel(WATCH_BUFFER); let driver = VmDriver { - config: VmDriverConfig::default(), + config: VmDriverConfig { + state_dir: driver_state.clone(), + ..Default::default() + }, launcher_bin: PathBuf::from("openshell-driver-vm"), registry: Arc::new(Mutex::new(HashMap::new())), image_cache_lock: Arc::new(Mutex::new(())), @@ -2933,9 +5297,8 @@ mod tests { ))), }; - let base = unique_temp_dir(); - std::fs::create_dir_all(&base).unwrap(); - let state_file = base.join("state-file"); + let state_file = sandbox_state_dir(&driver_state, "sandbox-123").unwrap(); + std::fs::create_dir_all(state_file.parent().unwrap()).unwrap(); std::fs::write(&state_file, "not a directory").unwrap(); insert_test_record( @@ -2950,19 +5313,20 @@ mod tests { .delete_sandbox("sandbox-123", "sandbox-123") .await .expect_err("state dir cleanup should fail for a file path"); - assert!(err.message().contains("failed to remove state dir")); + assert!(err.message().contains("not a directory")); assert!(driver.registry.lock().await.contains_key("sandbox-123")); - let retry_state_dir = base.join("state-dir"); + std::fs::remove_file(&state_file).unwrap(); + let retry_state_dir = sandbox_state_dir(&driver_state, "sandbox-123").unwrap(); std::fs::create_dir_all(&retry_state_dir).unwrap(); { let mut registry = driver.registry.lock().await; let record = 
registry.get_mut("sandbox-123").unwrap(); record.state_dir = retry_state_dir; - record.process = Arc::new(Mutex::new(VmProcess { + record.process = Some(Arc::new(Mutex::new(VmProcess { child: spawn_exited_child(), deleting: false, - })); + }))); } let response = driver @@ -2975,6 +5339,154 @@ mod tests { let _ = std::fs::remove_dir_all(base); } + #[tokio::test] + async fn delete_sandbox_cleans_provisioning_record_without_process() { + let base = unique_temp_dir(); + let driver_state = base.join("driver-state"); + let (events, _) = broadcast::channel(WATCH_BUFFER); + let driver = VmDriver { + config: VmDriverConfig { + state_dir: driver_state.clone(), + ..Default::default() + }, + launcher_bin: PathBuf::from("openshell-driver-vm"), + registry: Arc::new(Mutex::new(HashMap::new())), + image_cache_lock: Arc::new(Mutex::new(())), + events, + gpu_inventory: None, + subnet_allocator: Arc::new(std::sync::Mutex::new(SubnetAllocator::new( + Ipv4Addr::new(10, 0, 128, 0), + 17, + ))), + }; + + let state_dir = sandbox_state_dir(&driver_state, "sandbox-123").unwrap(); + std::fs::create_dir_all(&state_dir).unwrap(); + { + let mut registry = driver.registry.lock().await; + registry.insert( + "sandbox-123".to_string(), + SandboxRecord { + snapshot: Sandbox { + id: "sandbox-123".to_string(), + name: "sandbox-123".to_string(), + ..Default::default() + }, + state_dir: state_dir.clone(), + process: None, + provisioning_task: None, + gpu_bdf: None, + deleting: false, + }, + ); + } + + let response = driver + .delete_sandbox("sandbox-123", "sandbox-123") + .await + .expect("delete should handle accepted-but-not-started sandboxes"); + assert!(response.deleted); + assert!(!driver.registry.lock().await.contains_key("sandbox-123")); + assert!(!state_dir.exists()); + + let _ = std::fs::remove_dir_all(base); + } + + #[tokio::test] + async fn duplicate_create_keeps_existing_state_dir() { + let base = unique_temp_dir(); + let driver_state = base.join("driver-state"); + let (events, _) = 
broadcast::channel(WATCH_BUFFER); + let driver = VmDriver { + config: VmDriverConfig { + state_dir: driver_state.clone(), + default_image: "ghcr.io/example/sandbox:latest".to_string(), + ..Default::default() + }, + launcher_bin: PathBuf::from("openshell-driver-vm"), + registry: Arc::new(Mutex::new(HashMap::new())), + image_cache_lock: Arc::new(Mutex::new(())), + events, + gpu_inventory: None, + subnet_allocator: Arc::new(std::sync::Mutex::new(SubnetAllocator::new( + Ipv4Addr::new(10, 0, 128, 0), + 17, + ))), + }; + + let state_dir = sandbox_state_dir(&driver_state, "sandbox-123").unwrap(); + std::fs::create_dir_all(&state_dir).unwrap(); + std::fs::write(state_dir.join("overlay.ext4"), b"live overlay").unwrap(); + { + let mut registry = driver.registry.lock().await; + registry.insert( + "sandbox-123".to_string(), + SandboxRecord { + snapshot: Sandbox { + id: "sandbox-123".to_string(), + name: "sandbox-123".to_string(), + ..Default::default() + }, + state_dir: state_dir.clone(), + process: None, + provisioning_task: None, + gpu_bdf: None, + deleting: false, + }, + ); + } + + let err = driver + .create_sandbox(&Sandbox { + id: "sandbox-123".to_string(), + name: "sandbox-123".to_string(), + spec: Some(SandboxSpec::default()), + ..Default::default() + }) + .await + .expect_err("duplicate create should fail"); + + assert_eq!(err.code(), Code::AlreadyExists); + assert!(state_dir.join("overlay.ext4").exists()); + assert!(driver.registry.lock().await.contains_key("sandbox-123")); + + let _ = std::fs::remove_dir_all(base); + } + + #[tokio::test] + async fn remove_sandbox_state_dir_rejects_paths_outside_state_root() { + let base = unique_temp_dir(); + let state_root = base.join("driver-state"); + let outside = base.join("outside"); + std::fs::create_dir_all(&outside).unwrap(); + + let err = remove_sandbox_state_dir(&state_root, &outside) + .await + .expect_err("outside state paths should be rejected"); + assert!(err.message().contains("outside vm state root")); + + let _ = 
std::fs::remove_dir_all(base); + } + + #[cfg(unix)] + #[tokio::test] + async fn remove_sandbox_state_dir_rejects_symlinked_state_dir() { + let base = unique_temp_dir(); + let state_root = base.join("driver-state"); + let target = base.join("target"); + let state_dir = sandbox_state_dir(&state_root, "sandbox-123").unwrap(); + std::fs::create_dir_all(&target).unwrap(); + std::fs::create_dir_all(state_dir.parent().unwrap()).unwrap(); + std::os::unix::fs::symlink(&target, &state_dir).unwrap(); + + let err = remove_sandbox_state_dir(&state_root, &state_dir) + .await + .expect_err("symlinked state dir should be rejected"); + assert!(err.message().contains("symlinked sandbox state dir")); + + let _ = std::fs::remove_dir_all(base); + } + #[test] fn validate_openshell_endpoint_accepts_loopback_hosts() { validate_openshell_endpoint("http://127.0.0.1:8080") @@ -3005,54 +5517,85 @@ mod tests { } #[test] - fn prepare_exported_rootfs_archive_rewrites_docker_exported_rootfs() { + fn prepared_image_cache_identity_includes_rootfs_layout_version() { + assert_eq!( + prepared_image_cache_identity("sha256:local-image"), + "sandbox-prepared-rootfs-ext4-umoci-v2:sha256:local-image" + ); + } + + #[test] + fn bootstrap_image_cache_identity_includes_rootfs_layout_version() { + assert_eq!( + bootstrap_image_cache_identity("sha256:bootstrap-image"), + "sandbox-bootstrap-rootfs-ext4-v2:sha256:bootstrap-image" + ); + } + + #[test] + fn stage_guest_image_payload_copies_registry_oci_layout() { let base = unique_temp_dir(); - let source_rootfs = base.join("source-rootfs"); - let exported_rootfs = base.join("exported-rootfs.tar"); - let prepared_rootfs = base.join("prepared-rootfs"); - let prepared_archive = base.join("prepared-rootfs.tar"); - let extracted = base.join("extracted"); - - for path in [ - "bin/bash", - "bin/mount", - "bin/sed", - "sbin/ip", - "opt/openshell/bin/openshell-sandbox", - ] { - let path = source_rootfs.join(path); - fs::create_dir_all(path.parent().unwrap()).unwrap(); - 
fs::write(path, "").unwrap(); - } - - create_rootfs_archive_from_dir(&source_rootfs, &exported_rootfs).unwrap(); - prepare_exported_rootfs_archive( - "openshell/sandbox-from:123", - "sha256:local-image", - &exported_rootfs, - &prepared_rootfs, - &prepared_archive, + let staging_dir = base.join("staging"); + let layout_dir = base.join("layout"); + let blob_dir = layout_dir.join("blobs").join("sha256"); + fs::create_dir_all(&blob_dir).unwrap(); + fs::write( + layout_dir.join("oci-layout"), + r#"{"imageLayoutVersion":"1.0.0"}"#, + ) + .unwrap(); + fs::write(layout_dir.join("index.json"), "{}").unwrap(); + fs::write(blob_dir.join("abc"), "blob").unwrap(); + + stage_guest_image_payload( + &staging_dir, + &GuestImagePayload { + image_ref: "ghcr.io/example/app:latest".to_string(), + image_identity: prepared_image_cache_identity("sha256:abc"), + source: GuestImagePayloadSource::RegistryOciLayout { layout_dir }, + }, ) .unwrap(); - extract_rootfs_archive_to(&prepared_archive, &extracted).unwrap(); - assert!(extracted.join("srv/openshell-vm-sandbox-init.sh").is_file()); - assert!( - extracted - .join("opt/openshell/bin/openshell-sandbox") - .is_file() + let image_dir = staging_dir.join("config").join(GUEST_IMAGE_CONFIG_DIR); + assert_eq!( + fs::read_to_string(image_dir.join("source")).unwrap(), + "oci-layout" ); assert_eq!( - fs::read_to_string(extracted.join("opt/openshell/.rootfs-type")).unwrap(), - "sandbox\n" + fs::read_to_string(image_dir.join("ref")).unwrap(), + "ghcr.io/example/app:latest" ); - assert!( - fs::read_to_string(extracted.join(".openshell-rootfs-variant")) - .unwrap() - .contains("sha256:local-image") + assert_eq!( + fs::read_to_string( + image_dir + .join(GUEST_IMAGE_OCI_LAYOUT_DIR) + .join("blobs") + .join("sha256") + .join("abc") + ) + .unwrap(), + "blob" ); - let _ = fs::remove_dir_all(base); + let _ = std::fs::remove_dir_all(base); + } + + #[test] + fn registry_layer_download_concurrency_is_bounded() { + assert_eq!( + 
registry_layer_download_concurrency_value(None), + DEFAULT_REGISTRY_LAYER_DOWNLOAD_CONCURRENCY + ); + assert_eq!( + registry_layer_download_concurrency_value(Some("0")), + DEFAULT_REGISTRY_LAYER_DOWNLOAD_CONCURRENCY + ); + assert_eq!(registry_layer_download_concurrency_value(Some("8")), 8); + assert_eq!( + registry_layer_download_concurrency_value(Some("999")), + MAX_REGISTRY_LAYER_DOWNLOAD_CONCURRENCY + ); } #[test] @@ -3064,50 +5607,61 @@ mod tests { } #[tokio::test] - async fn prepare_guest_tls_materials_copies_bundle_into_rootfs() { + async fn read_guest_tls_materials_reports_missing_input() { let base = unique_temp_dir(); - let source_dir = base.join("source"); - let rootfs = base.join("rootfs"); - std::fs::create_dir_all(&source_dir).unwrap(); - std::fs::create_dir_all(&rootfs).unwrap(); - - let ca = source_dir.join("ca.crt"); - let cert = source_dir.join("tls.crt"); - let key = source_dir.join("tls.key"); - std::fs::write(&ca, "ca").unwrap(); - std::fs::write(&cert, "cert").unwrap(); - std::fs::write(&key, "key").unwrap(); - - prepare_guest_tls_materials( - &rootfs, - &VmDriverTlsPaths { - ca: ca.clone(), - cert: cert.clone(), - key: key.clone(), - }, - ) + let source_dir = base.join("missing-source"); + + let err = read_guest_tls_materials(&VmDriverTlsPaths { + ca: source_dir.join("ca.crt"), + cert: source_dir.join("tls.crt"), + key: source_dir.join("tls.key"), + }) .await - .unwrap(); + .expect_err("missing TLS materials should fail before image injection"); + + assert!(err.contains("ca.crt")); + + let _ = std::fs::remove_dir_all(base); + } + + #[cfg(unix)] + #[test] + fn stage_guest_tls_materials_places_files_in_overlay_upper_with_private_key_mode() { + use std::os::unix::fs::PermissionsExt as _; + + let base = unique_temp_dir(); + let materials = GuestTlsMaterials { + ca: b"ca".to_vec(), + cert: b"cert".to_vec(), + key: b"key".to_vec(), + }; + + stage_guest_tls_materials(&base, &materials).expect("stage TLS materials"); - let guest_dir = 
rootfs.join(GUEST_TLS_DIR.trim_start_matches('/')); assert_eq!( - std::fs::read_to_string(guest_dir.join("ca.crt")).unwrap(), - "ca" + fs::read( + base.join("upper") + .join(GUEST_TLS_CA_PATH.trim_start_matches('/')) + ) + .unwrap(), + b"ca" ); assert_eq!( - std::fs::read_to_string(guest_dir.join("tls.crt")).unwrap(), - "cert" + fs::read( + base.join("upper") + .join(GUEST_TLS_CERT_PATH.trim_start_matches('/')) + ) + .unwrap(), + b"cert" ); + let key_path = base + .join("upper") + .join(GUEST_TLS_KEY_PATH.trim_start_matches('/')); + assert_eq!(fs::read(&key_path).unwrap(), b"key"); assert_eq!( - std::fs::read_to_string(guest_dir.join("tls.key")).unwrap(), - "key" + fs::metadata(&key_path).unwrap().permissions().mode() & 0o777, + 0o600 ); - let key_mode = std::fs::metadata(guest_dir.join("tls.key")) - .unwrap() - .permissions() - .mode() - & 0o777; - assert_eq!(key_mode, 0o600); let _ = std::fs::remove_dir_all(base); } @@ -3195,8 +5749,10 @@ mod tests { SandboxRecord { snapshot: sandbox, state_dir, - process, + process: Some(process), + provisioning_task: None, gpu_bdf: None, + deleting: false, }, ); } diff --git a/crates/openshell-driver-vm/src/ffi.rs b/crates/openshell-driver-vm/src/ffi.rs index db5d3ec10..423ad6f05 100644 --- a/crates/openshell-driver-vm/src/ffi.rs +++ b/crates/openshell-driver-vm/src/ffi.rs @@ -29,7 +29,18 @@ type KrunInitLog = type KrunCreateCtx = unsafe extern "C" fn() -> i32; type KrunFreeCtx = unsafe extern "C" fn(ctx_id: u32) -> i32; type KrunSetVmConfig = unsafe extern "C" fn(ctx_id: u32, num_vcpus: u8, ram_mib: u32) -> i32; -type KrunSetRoot = unsafe extern "C" fn(ctx_id: u32, root_path: *const c_char) -> i32; +type KrunAddDisk = unsafe extern "C" fn( + ctx_id: u32, + block_id: *const c_char, + disk_path: *const c_char, + read_only: bool, +) -> i32; +type KrunSetRootDiskRemount = unsafe extern "C" fn( + ctx_id: u32, + device: *const c_char, + fstype: *const c_char, + options: *const c_char, +) -> i32; type KrunSetWorkdir = unsafe extern 
"C" fn(ctx_id: u32, workdir_path: *const c_char) -> i32; type KrunSetExec = unsafe extern "C" fn( ctx_id: u32, @@ -67,7 +78,8 @@ pub struct LibKrun { pub krun_create_ctx: KrunCreateCtx, pub krun_free_ctx: KrunFreeCtx, pub krun_set_vm_config: KrunSetVmConfig, - pub krun_set_root: KrunSetRoot, + pub krun_add_disk: KrunAddDisk, + pub krun_set_root_disk_remount: KrunSetRootDiskRemount, pub krun_set_workdir: KrunSetWorkdir, pub krun_set_exec: KrunSetExec, pub krun_set_console_output: KrunSetConsoleOutput, @@ -119,7 +131,12 @@ impl LibKrun { krun_create_ctx: load_symbol(library, b"krun_create_ctx\0", &libkrun_path)?, krun_free_ctx: load_symbol(library, b"krun_free_ctx\0", &libkrun_path)?, krun_set_vm_config: load_symbol(library, b"krun_set_vm_config\0", &libkrun_path)?, - krun_set_root: load_symbol(library, b"krun_set_root\0", &libkrun_path)?, + krun_add_disk: load_symbol(library, b"krun_add_disk\0", &libkrun_path)?, + krun_set_root_disk_remount: load_symbol( + library, + b"krun_set_root_disk_remount\0", + &libkrun_path, + )?, krun_set_workdir: load_symbol(library, b"krun_set_workdir\0", &libkrun_path)?, krun_set_exec: load_symbol(library, b"krun_set_exec\0", &libkrun_path)?, krun_set_console_output: load_symbol( diff --git a/crates/openshell-driver-vm/src/lib.rs b/crates/openshell-driver-vm/src/lib.rs index 194dde43c..5b2ddc2bc 100644 --- a/crates/openshell-driver-vm/src/lib.rs +++ b/crates/openshell-driver-vm/src/lib.rs @@ -5,6 +5,7 @@ pub mod driver; mod embedded_runtime; mod ffi; pub mod gpu; +mod nft_ruleset; pub mod procguard; mod rootfs; mod runtime; diff --git a/crates/openshell-driver-vm/src/main.rs b/crates/openshell-driver-vm/src/main.rs index 596e6c88d..8f662dc76 100644 --- a/crates/openshell-driver-vm/src/main.rs +++ b/crates/openshell-driver-vm/src/main.rs @@ -2,28 +2,39 @@ // SPDX-License-Identifier: Apache-2.0 use clap::Parser; +use futures::Stream; use miette::{IntoDiagnostic, Result}; use openshell_core::VERSION; use 
openshell_core::proto::compute::v1::compute_driver_server::ComputeDriverServer; #[cfg(target_os = "macos")] use openshell_driver_vm::{VM_RUNTIME_DIR_ENV, configured_runtime_dir}; use openshell_driver_vm::{VmBackend, VmDriver, VmDriverConfig, VmLaunchConfig, procguard, run_vm}; +use std::io; use std::net::SocketAddr; -use std::path::PathBuf; -use tokio::net::UnixListener; -use tokio_stream::wrappers::UnixListenerStream; +use std::os::unix::fs::{FileTypeExt, MetadataExt, PermissionsExt}; +use std::path::{Path, PathBuf}; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::net::{UnixListener, UnixStream}; use tracing::info; use tracing_subscriber::EnvFilter; #[derive(Parser, Debug)] #[command(name = "openshell-driver-vm")] #[command(version = VERSION)] +#[allow(clippy::struct_excessive_bools)] struct Args { #[arg(long, hide = true, default_value_t = false)] internal_run_vm: bool, - #[arg(long, hide = true)] - vm_rootfs: Option, + #[arg(long = "vm-root-disk", hide = true, alias = "vm-rootfs")] + vm_root_disk: Option, + + #[arg(long = "vm-overlay-disk", hide = true)] + vm_overlay_disk: Option, + + #[arg(long = "vm-image-disk", hide = true)] + vm_image_disk: Option, #[arg(long, hide = true)] vm_exec: Option, @@ -46,15 +57,28 @@ struct Args { #[arg(long, hide = true, default_value_t = 1)] vm_krun_log_level: u32, + #[arg(long, env = "OPENSHELL_COMPUTE_DRIVER_BIND")] + bind_address: Option, + + #[arg(long, env = "OPENSHELL_COMPUTE_DRIVER_SOCKET")] + bind_socket: Option, + + #[arg(long, hide = true)] + expected_peer_pid: Option, + #[arg( long, - env = "OPENSHELL_COMPUTE_DRIVER_BIND", - default_value = "127.0.0.1:50061" + env = "OPENSHELL_COMPUTE_DRIVER_ALLOW_UNAUTHENTICATED_TCP", + default_value_t = false )] - bind_address: SocketAddr, + allow_unauthenticated_tcp: bool, - #[arg(long, env = "OPENSHELL_COMPUTE_DRIVER_SOCKET")] - bind_socket: Option, + #[arg( + long, + env = "OPENSHELL_COMPUTE_DRIVER_ALLOW_SAME_UID_PEER", + default_value_t = false + )] + 
allow_same_uid_peer: bool, #[arg(long, env = "OPENSHELL_LOG_LEVEL", default_value = "info")] log_level: String, @@ -65,6 +89,9 @@ struct Args { #[arg(long, env = "OPENSHELL_SANDBOX_IMAGE", default_value = "")] default_image: String, + #[arg(long, env = "OPENSHELL_VM_BOOTSTRAP_IMAGE", default_value = "")] + bootstrap_image: String, + #[arg( long, env = "OPENSHELL_VM_DRIVER_STATE_DIR", @@ -72,12 +99,6 @@ struct Args { )] state_dir: PathBuf, - #[arg(long, env = "OPENSHELL_SSH_HANDSHAKE_SECRET")] - ssh_handshake_secret: Option, - - #[arg(long, env = "OPENSHELL_SSH_HANDSHAKE_SKEW_SECS", default_value_t = 300)] - ssh_handshake_skew_secs: u64, - #[arg(long = "guest-tls-ca", env = "OPENSHELL_VM_TLS_CA")] guest_tls_ca: Option, @@ -96,6 +117,9 @@ struct Args { #[arg(long, env = "OPENSHELL_VM_DRIVER_MEM_MIB", default_value_t = 2048)] mem_mib: u32, + #[arg(long, env = "OPENSHELL_VM_OVERLAY_DISK_MIB", default_value_t = 4096)] + overlay_disk_mib: u64, + #[arg(long, env = "OPENSHELL_VM_GPU")] gpu: bool, @@ -154,6 +178,8 @@ async fn main() -> Result<()> { ) .init(); + let listen_mode = compute_driver_listen_mode(&args).map_err(|err| miette::miette!("{err}"))?; + // Arm procguard so that if the gateway is killed (SIGKILL or crash) // we also die. Without this the driver is reparented to init and // keeps its per-sandbox VM launchers alive forever. 
Launchers have @@ -170,18 +196,18 @@ async fn main() -> Result<()> { openshell_endpoint: args .openshell_endpoint .ok_or_else(|| miette::miette!("OPENSHELL_GRPC_ENDPOINT is required"))?, - state_dir: args.state_dir, + state_dir: args.state_dir.clone(), launcher_bin: None, - default_image: args.default_image, - ssh_handshake_secret: args.ssh_handshake_secret.unwrap_or_default(), - ssh_handshake_skew_secs: args.ssh_handshake_skew_secs, - log_level: args.log_level, + default_image: args.default_image.clone(), + bootstrap_image: args.bootstrap_image.clone(), + log_level: args.log_level.clone(), krun_log_level: args.krun_log_level, vcpus: args.vcpus, mem_mib: args.mem_mib, - guest_tls_ca: args.guest_tls_ca, - guest_tls_cert: args.guest_tls_cert, - guest_tls_key: args.guest_tls_key, + overlay_disk_mib: args.overlay_disk_mib, + guest_tls_ca: args.guest_tls_ca.clone(), + guest_tls_cert: args.guest_tls_cert.clone(), + guest_tls_key: args.guest_tls_key.clone(), gpu_enabled: args.gpu, gpu_mem_mib: args.gpu_mem_mib, gpu_vcpus: args.gpu_vcpus, @@ -189,40 +215,254 @@ async fn main() -> Result<()> { .await .map_err(|err| miette::miette!("{err}"))?; - if let Some(socket_path) = args.bind_socket { - if let Some(parent) = socket_path.parent() { - std::fs::create_dir_all(parent).into_diagnostic()?; + match listen_mode { + ComputeDriverListenMode::Unix { + socket_path, + expected_peer_pid, + } => { + prepare_compute_driver_socket(&socket_path).map_err(|err| miette::miette!("{err}"))?; + + info!(socket = %socket_path.display(), "Starting vm compute driver"); + let listener = UnixListener::bind(&socket_path).into_diagnostic()?; + restrict_socket_permissions(&socket_path).map_err(|err| miette::miette!("{err}"))?; + let result = tonic::transport::Server::builder() + .add_service(ComputeDriverServer::new(driver)) + .serve_with_incoming(AuthenticatedUnixIncoming::new(listener, expected_peer_pid)) + .await + .into_diagnostic(); + let _ = std::fs::remove_file(&socket_path); + result + } + 
ComputeDriverListenMode::Tcp(bind_address) => { + info!(address = %bind_address, "Starting unauthenticated dev vm compute driver"); + tonic::transport::Server::builder() + .add_service(ComputeDriverServer::new(driver)) + .serve(bind_address) + .await + .into_diagnostic() } - match std::fs::remove_file(&socket_path) { - Ok(()) => {} - Err(err) if err.kind() == std::io::ErrorKind::NotFound => {} - Err(err) => return Err(err).into_diagnostic(), + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +enum ComputeDriverListenMode { + Unix { + socket_path: PathBuf, + expected_peer_pid: Option, + }, + Tcp(SocketAddr), +} + +fn compute_driver_listen_mode(args: &Args) -> std::result::Result { + if let Some(socket_path) = args.bind_socket.clone() { + if args.expected_peer_pid.is_none() && !args.allow_same_uid_peer { + return Err( + "--expected-peer-pid is required with --bind-socket; use --allow-same-uid-peer only for local development" + .to_string(), + ); + } + return Ok(ComputeDriverListenMode::Unix { + socket_path, + expected_peer_pid: args.expected_peer_pid, + }); + } + + if !args.allow_unauthenticated_tcp { + return Err( + "--bind-socket is required; unauthenticated TCP mode is disabled unless --allow-unauthenticated-tcp is set for local development" + .to_string(), + ); + } + + let Some(bind_address) = args.bind_address else { + return Err("--bind-address is required with --allow-unauthenticated-tcp".to_string()); + }; + + Ok(ComputeDriverListenMode::Tcp(bind_address)) +} + +fn prepare_compute_driver_socket(socket_path: &Path) -> std::result::Result<(), String> { + let Some(parent) = socket_path.parent() else { + return Err(format!( + "vm compute driver socket path '{}' has no parent directory", + socket_path.display() + )); + }; + let expected_uid = current_euid(); + prepare_private_socket_dir(parent, expected_uid)?; + remove_stale_socket(socket_path, expected_uid) +} + +fn current_euid() -> u32 { + rustix::process::geteuid().as_raw() +} + +fn prepare_private_socket_dir( + 
socket_dir: &Path, + expected_uid: u32, +) -> std::result::Result<(), String> { + std::fs::create_dir_all(socket_dir) + .map_err(|err| format!("create socket dir {}: {err}", socket_dir.display()))?; + let metadata = std::fs::symlink_metadata(socket_dir) + .map_err(|err| format!("stat socket dir {}: {err}", socket_dir.display()))?; + let file_type = metadata.file_type(); + if file_type.is_symlink() { + return Err(format!( + "socket dir {} is a symlink; refusing to use it", + socket_dir.display() + )); + } + if !file_type.is_dir() { + return Err(format!( + "socket dir {} is not a directory", + socket_dir.display() + )); + } + if metadata.uid() != expected_uid { + return Err(format!( + "socket dir {} is owned by uid {} but current euid is {}", + socket_dir.display(), + metadata.uid(), + expected_uid + )); + } + std::fs::set_permissions(socket_dir, std::fs::Permissions::from_mode(0o700)) + .map_err(|err| format!("chmod socket dir {}: {err}", socket_dir.display())) +} + +fn remove_stale_socket(socket_path: &Path, expected_uid: u32) -> std::result::Result<(), String> { + let metadata = match std::fs::symlink_metadata(socket_path) { + Ok(metadata) => metadata, + Err(err) if err.kind() == io::ErrorKind::NotFound => return Ok(()), + Err(err) => return Err(format!("stat socket {}: {err}", socket_path.display())), + }; + let file_type = metadata.file_type(); + if file_type.is_symlink() { + return Err(format!( + "socket {} is a symlink; refusing to remove it", + socket_path.display() + )); + } + if metadata.uid() != expected_uid { + return Err(format!( + "socket {} is owned by uid {} but current euid is {}", + socket_path.display(), + metadata.uid(), + expected_uid + )); + } + if !file_type.is_socket() { + return Err(format!( + "socket path {} exists but is not a Unix socket", + socket_path.display() + )); + } + std::fs::remove_file(socket_path) + .map_err(|err| format!("remove stale socket {}: {err}", socket_path.display())) +} + +fn restrict_socket_permissions(socket_path: 
&Path) -> std::result::Result<(), String> { + std::fs::set_permissions(socket_path, std::fs::Permissions::from_mode(0o600)) + .map_err(|err| format!("chmod socket {}: {err}", socket_path.display())) +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +struct PeerCredentials { + uid: u32, + pid: Option, +} + +fn peer_credentials(stream: &UnixStream) -> std::result::Result { + let credentials = stream + .peer_cred() + .map_err(|err| format!("read peer credentials: {err}"))?; + Ok(PeerCredentials { + uid: credentials.uid(), + pid: credentials.pid(), + }) +} + +fn authorize_peer_credentials( + peer: PeerCredentials, + driver_uid: u32, + gateway_pid: Option, +) -> std::result::Result<(), String> { + if peer.uid != driver_uid { + return Err(format!( + "peer uid {} does not match current euid {}", + peer.uid, driver_uid + )); + } + let Some(gateway_pid) = gateway_pid else { + return Ok(()); + }; + let Some(peer_process_id) = peer.pid.and_then(|pid| u32::try_from(pid).ok()) else { + return Err(format!( + "peer pid is unavailable; expected gateway pid {gateway_pid}" + )); + }; + if peer_process_id != gateway_pid { + return Err(format!( + "peer pid {peer_process_id} does not match expected gateway pid {gateway_pid}" + )); + } + Ok(()) +} + +struct AuthenticatedUnixIncoming { + listener: UnixListener, + expected_uid: u32, + expected_peer_pid: Option, +} + +impl AuthenticatedUnixIncoming { + fn new(listener: UnixListener, expected_peer_pid: Option) -> Self { + Self { + listener, + expected_uid: current_euid(), + expected_peer_pid, } + } +} - info!(socket = %socket_path.display(), "Starting vm compute driver"); - let listener = UnixListener::bind(&socket_path).into_diagnostic()?; - let result = tonic::transport::Server::builder() - .add_service(ComputeDriverServer::new(driver)) - .serve_with_incoming(UnixListenerStream::new(listener)) - .await - .into_diagnostic(); - let _ = std::fs::remove_file(&socket_path); - result - } else { - info!(address = %args.bind_address, "Starting 
vm compute driver"); - tonic::transport::Server::builder() - .add_service(ComputeDriverServer::new(driver)) - .serve(args.bind_address) - .await - .into_diagnostic() +impl Stream for AuthenticatedUnixIncoming { + type Item = io::Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + loop { + match this.listener.poll_accept(cx) { + Poll::Ready(Ok((stream, _addr))) => { + let authorized = peer_credentials(&stream).and_then(|peer| { + authorize_peer_credentials(peer, this.expected_uid, this.expected_peer_pid) + }); + match authorized { + Ok(()) => return Poll::Ready(Some(Ok(stream))), + Err(err) => { + tracing::warn!( + error = %err, + "rejected vm compute driver UDS client" + ); + } + } + } + Poll::Ready(Err(err)) => return Poll::Ready(Some(Err(err))), + Poll::Pending => return Poll::Pending, + } + } } } fn build_vm_launch_config(args: &Args) -> std::result::Result { - let rootfs = args - .vm_rootfs + let root_disk = args + .vm_root_disk .clone() - .ok_or_else(|| "--vm-rootfs is required in internal VM mode".to_string())?; + .ok_or_else(|| "--vm-root-disk is required in internal VM mode".to_string())?; + let overlay_disk = args + .vm_overlay_disk + .clone() + .ok_or_else(|| "--vm-overlay-disk is required in internal VM mode".to_string())?; + let image_disk = args.vm_image_disk.clone(); let exec_path = args .vm_exec .clone() @@ -239,7 +479,9 @@ fn build_vm_launch_config(args: &Args) -> std::result::Result Result<()> { fn maybe_reexec_internal_vm_with_runtime_env() -> Result<()> { Ok(()) } + +#[cfg(test)] +mod tests { + use super::{ + Args, ComputeDriverListenMode, PeerCredentials, authorize_peer_credentials, + compute_driver_listen_mode, + }; + use clap::Parser; + use std::path::PathBuf; + + #[test] + fn peer_authorization_accepts_matching_uid_and_pid() { + authorize_peer_credentials( + PeerCredentials { + uid: 1000, + pid: Some(42), + }, + 1000, + Some(42), + ) + .unwrap(); + } + + #[test] + fn 
peer_authorization_rejects_wrong_pid() { + let err = authorize_peer_credentials( + PeerCredentials { + uid: 1000, + pid: Some(7), + }, + 1000, + Some(42), + ) + .expect_err("wrong pid should be rejected"); + assert!(err.contains("does not match expected gateway pid")); + } + + #[test] + fn peer_authorization_rejects_wrong_uid() { + let err = authorize_peer_credentials( + PeerCredentials { + uid: 1001, + pid: Some(42), + }, + 1000, + Some(42), + ) + .expect_err("wrong uid should be rejected"); + assert!(err.contains("does not match current euid")); + } + + #[test] + fn peer_authorization_rejects_missing_pid_when_expected() { + let err = authorize_peer_credentials( + PeerCredentials { + uid: 1000, + pid: None, + }, + 1000, + Some(42), + ) + .expect_err("missing pid should be rejected"); + assert!(err.contains("peer pid is unavailable")); + } + + #[test] + fn peer_authorization_accepts_matching_uid_without_expected_pid() { + authorize_peer_credentials( + PeerCredentials { + uid: 1000, + pid: None, + }, + 1000, + None, + ) + .unwrap(); + } + + #[test] + fn listen_mode_rejects_default_tcp() { + let args = Args::parse_from(["openshell-driver-vm"]); + let err = compute_driver_listen_mode(&args).expect_err("default TCP should be disabled"); + assert!(err.contains("--bind-socket is required")); + } + + #[test] + fn listen_mode_rejects_bind_address_without_tcp_opt_in() { + let args = Args::parse_from(["openshell-driver-vm", "--bind-address", "127.0.0.1:50061"]); + let err = + compute_driver_listen_mode(&args).expect_err("TCP bind should require explicit opt-in"); + assert!(err.contains("--allow-unauthenticated-tcp")); + } + + #[test] + fn listen_mode_requires_bind_address_with_tcp_opt_in() { + let args = Args::parse_from(["openshell-driver-vm", "--allow-unauthenticated-tcp"]); + let err = + compute_driver_listen_mode(&args).expect_err("TCP opt-in should require an address"); + assert!(err.contains("--bind-address is required")); + } + + #[test] + fn 
listen_mode_accepts_explicit_unauthenticated_tcp() { + let args = Args::parse_from([ + "openshell-driver-vm", + "--allow-unauthenticated-tcp", + "--bind-address", + "127.0.0.1:50061", + ]); + assert_eq!( + compute_driver_listen_mode(&args).unwrap(), + ComputeDriverListenMode::Tcp("127.0.0.1:50061".parse().unwrap()) + ); + } + + #[test] + fn listen_mode_requires_expected_peer_pid_for_uds() { + let args = Args::parse_from([ + "openshell-driver-vm", + "--bind-socket", + "/tmp/compute-driver.sock", + ]); + let err = compute_driver_listen_mode(&args) + .expect_err("UDS should require gateway peer pid by default"); + assert!(err.contains("--expected-peer-pid is required")); + } + + #[test] + fn listen_mode_accepts_uds_with_expected_peer_pid() { + let args = Args::parse_from([ + "openshell-driver-vm", + "--bind-socket", + "/tmp/compute-driver.sock", + "--expected-peer-pid", + "42", + ]); + assert_eq!( + compute_driver_listen_mode(&args).unwrap(), + ComputeDriverListenMode::Unix { + socket_path: PathBuf::from("/tmp/compute-driver.sock"), + expected_peer_pid: Some(42), + } + ); + } + + #[test] + fn listen_mode_accepts_explicit_same_uid_uds_dev_mode() { + let args = Args::parse_from([ + "openshell-driver-vm", + "--bind-socket", + "/tmp/compute-driver.sock", + "--allow-same-uid-peer", + ]); + assert_eq!( + compute_driver_listen_mode(&args).unwrap(), + ComputeDriverListenMode::Unix { + socket_path: PathBuf::from("/tmp/compute-driver.sock"), + expected_peer_pid: None, + } + ); + } +} diff --git a/crates/openshell-driver-vm/src/nft_ruleset.rs b/crates/openshell-driver-vm/src/nft_ruleset.rs new file mode 100644 index 000000000..fe3e86c90 --- /dev/null +++ b/crates/openshell-driver-vm/src/nft_ruleset.rs @@ -0,0 +1,92 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use std::fmt::Write; + +/// Sanitize a TAP device name for use as an nftables table name suffix. 
+/// Assumes device names match `vmtap-[a-f0-9]+` (driver-controlled). +fn sanitize_table_name(device: &str) -> String { + device.replace('-', "_") +} + +/// Return the nftables table name for a TAP device. +pub fn teardown_table_name(device: &str) -> String { + format!("openshell_vm_{}", sanitize_table_name(device)) +} + +/// Generate the nftables ruleset for VM TAP networking. +pub fn generate_tap_ruleset(tap_device: &str, subnet: &str, gateway_port: u16) -> String { + let table_name = teardown_table_name(tap_device); + let mut ruleset = String::with_capacity(512); + + writeln!(ruleset, "table ip {table_name} {{").unwrap(); + writeln!(ruleset, " chain postrouting {{").unwrap(); + writeln!( + ruleset, + " type nat hook postrouting priority 100; policy accept;" + ) + .unwrap(); + writeln!(ruleset, " ip saddr {subnet} masquerade").unwrap(); + writeln!(ruleset, " }}").unwrap(); + writeln!(ruleset, " chain forward {{").unwrap(); + writeln!( + ruleset, + " type filter hook forward priority 0; policy accept;" + ) + .unwrap(); + writeln!(ruleset, " iifname \"{tap_device}\" accept").unwrap(); + writeln!( + ruleset, + " oifname \"{tap_device}\" ct state related,established accept" + ) + .unwrap(); + writeln!(ruleset, " oifname \"{tap_device}\" drop").unwrap(); + writeln!(ruleset, " }}").unwrap(); + writeln!(ruleset, " chain input {{").unwrap(); + writeln!( + ruleset, + " type filter hook input priority 0; policy accept;" + ) + .unwrap(); + writeln!( + ruleset, + " iifname \"{tap_device}\" tcp dport {gateway_port} accept" + ) + .unwrap(); + writeln!(ruleset, " iifname \"{tap_device}\" drop").unwrap(); + writeln!(ruleset, " }}").unwrap(); + writeln!(ruleset, "}}").unwrap(); + + ruleset +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn generates_tap_setup_ruleset() { + let ruleset = generate_tap_ruleset("vmtap-abcd", "10.0.128.0/30", 8080); + assert!(ruleset.contains("table ip openshell_vm_vmtap_abcd {")); + assert!(ruleset.contains("type nat hook postrouting 
priority 100; policy accept;")); + assert!(ruleset.contains("ip saddr 10.0.128.0/30 masquerade")); + assert!(ruleset.contains("type filter hook forward priority 0; policy accept;")); + assert!(ruleset.contains("iifname \"vmtap-abcd\" accept")); + assert!(ruleset.contains("oifname \"vmtap-abcd\" ct state related,established accept")); + assert!(ruleset.contains("oifname \"vmtap-abcd\" drop")); + assert!(ruleset.contains("type filter hook input priority 0; policy accept;")); + assert!(ruleset.contains("iifname \"vmtap-abcd\" tcp dport 8080 accept")); + } + + #[test] + fn table_name_sanitizes_device_name() { + let ruleset = generate_tap_ruleset("vmtap-abc-123", "10.0.128.0/30", 8080); + assert!(ruleset.contains("table ip openshell_vm_vmtap_abc_123 {")); + } + + #[test] + fn teardown_command_targets_correct_table() { + let cmd = teardown_table_name("vmtap-abcd"); + assert_eq!(cmd, "openshell_vm_vmtap_abcd"); + } +} diff --git a/crates/openshell-driver-vm/src/rootfs.rs b/crates/openshell-driver-vm/src/rootfs.rs index e498bd779..904ed8cd3 100644 --- a/crates/openshell-driver-vm/src/rootfs.rs +++ b/crates/openshell-driver-vm/src/rootfs.rs @@ -3,13 +3,24 @@ use std::fs; use std::fs::File; -use std::io::{BufWriter, Cursor}; -use std::path::Path; +#[cfg(test)] +use std::io::BufWriter; +use std::io::{Cursor, Read, Seek, SeekFrom, Write}; +use std::path::{Path, PathBuf}; +use std::process::Command; +use std::sync::atomic::{AtomicU64, Ordering}; const SUPERVISOR: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/openshell-sandbox.zst")); +const UMOCI: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/umoci.zst")); const ROOTFS_VARIANT_MARKER: &str = ".openshell-rootfs-variant"; const SANDBOX_GUEST_INIT_PATH: &str = "/srv/openshell-vm-sandbox-init.sh"; const SANDBOX_SUPERVISOR_PATH: &str = "/opt/openshell/bin/openshell-sandbox"; +const SANDBOX_UMOCI_PATH: &str = "/opt/openshell/bin/umoci"; +const SANDBOX_OWNER_NORMALIZED_MARKER: &str = "/opt/openshell/.sandbox-owner-normalized"; 
+const ROOTFS_IMAGE_MIN_SIZE_BYTES: u64 = 512 * 1024 * 1024; +const ROOTFS_IMAGE_MIN_HEADROOM_BYTES: u64 = 256 * 1024 * 1024; +const EXT4_IMAGE_MIN_HEADROOM_BYTES: u64 = 16 * 1024 * 1024; +static INJECTION_COUNTER: AtomicU64 = AtomicU64::new(0); pub const fn sandbox_guest_init_path() -> &'static str { SANDBOX_GUEST_INIT_PATH @@ -44,6 +55,7 @@ pub fn extract_rootfs_archive_to(archive_path: &Path, dest: &Path) -> Result<(), .map_err(|e| format!("extract rootfs tarball into {}: {e}", dest.display())) } +#[cfg(test)] pub fn create_rootfs_archive_from_dir(source: &Path, archive_path: &Path) -> Result<(), String> { if let Some(parent) = archive_path.parent() { fs::create_dir_all(parent).map_err(|e| format!("create {}: {e}", parent.display()))?; @@ -65,6 +77,205 @@ pub fn create_rootfs_archive_from_dir(source: &Path, archive_path: &Path) -> Res .map_err(|e| format!("finalize {}: {e}", archive_path.display())) } +pub fn create_rootfs_image_from_dir(source: &Path, image_path: &Path) -> Result<(), String> { + let image_size = rootfs_image_size_bytes(source)?; + create_ext4_image_from_dir_with_size(source, image_path, image_size)?; + if let Err(err) = normalize_sandbox_owner_in_rootfs_image(source, image_path) { + let _ = fs::remove_file(image_path); + return Err(err); + } + Ok(()) +} + +pub fn create_ext4_image_from_dir_with_size( + source: &Path, + image_path: &Path, + image_size: u64, +) -> Result<(), String> { + if let Some(parent) = image_path.parent() { + fs::create_dir_all(parent).map_err(|e| format!("create {}: {e}", parent.display()))?; + } + if image_path.exists() { + fs::remove_file(image_path) + .map_err(|e| format!("remove old rootfs image {}: {e}", image_path.display()))?; + } + + let required_size = ext4_image_min_size_bytes(source)?; + if image_size < required_size { + return Err(format!( + "ext4 image size {} bytes is too small for {} (requires at least {} bytes)", + image_size, + source.display(), + required_size + )); + } + + let image = 
File::create(image_path) + .map_err(|e| format!("create rootfs image {}: {e}", image_path.display()))?; + image + .set_len(image_size) + .map_err(|e| format!("size rootfs image {}: {e}", image_path.display()))?; + drop(image); + + if let Err(err) = format_ext4_image_from_dir(source, image_path) { + let _ = fs::remove_file(image_path); + return Err(err); + } + + Ok(()) +} + +pub fn clone_or_copy_sparse_file(source: &Path, dest: &Path) -> Result<(), String> { + if let Some(parent) = dest.parent() { + fs::create_dir_all(parent).map_err(|e| format!("create {}: {e}", parent.display()))?; + } + if dest.exists() { + fs::remove_file(dest).map_err(|e| format!("remove old file {}: {e}", dest.display()))?; + } + + let clone_error = match try_clone_file(source, dest) { + Ok(()) => return Ok(()), + Err(err) => { + let _ = fs::remove_file(dest); + err + } + }; + + copy_sparse_file(source, dest).map_err(|copy_error| { + format!( + "clone {} to {} failed ({clone_error}); sparse copy failed: {copy_error}", + source.display(), + dest.display() + ) + }) +} + +pub fn write_rootfs_image_file( + image_path: &Path, + guest_path: &str, + contents: &[u8], +) -> Result<(), String> { + ensure_rootfs_image_parent_dirs(image_path, guest_path); + + let tmp_path = temporary_injection_path(image_path); + fs::write(&tmp_path, contents).map_err(|e| format!("write {}: {e}", tmp_path.display()))?; + let Some(quoted_guest_path) = debugfs_quote_absolute_path(guest_path) else { + let _ = fs::remove_file(&tmp_path); + return Err(format!("invalid debugfs guest path '{guest_path}'")); + }; + let Some(quoted_tmp_path) = debugfs_quote_argument(&tmp_path.to_string_lossy()) else { + let _ = fs::remove_file(&tmp_path); + return Err(format!( + "invalid debugfs injection path '{}'", + tmp_path.display() + )); + }; + let _ = run_debugfs(image_path, &format!("rm {quoted_guest_path}")); + let result = run_debugfs( + image_path, + &format!("write {quoted_tmp_path} {quoted_guest_path}"), + ); + let _ = 
fs::remove_file(&tmp_path); + result +} + +pub fn set_rootfs_image_file_mode( + image_path: &Path, + guest_path: &str, + mode: u32, +) -> Result<(), String> { + let regular_file_mode = 0o100_000 | (mode & 0o7777); + let Some(quoted_guest_path) = debugfs_quote_absolute_path(guest_path) else { + return Err(format!("invalid debugfs guest path '{guest_path}'")); + }; + run_debugfs( + image_path, + &format!("set_inode_field {quoted_guest_path} mode 0{regular_file_mode:o}"), + ) +} + +#[cfg(target_os = "macos")] +fn try_clone_file(source: &Path, dest: &Path) -> Result<(), String> { + let output = Command::new("cp") + .arg("-c") + .arg(source) + .arg(dest) + .output() + .map_err(|e| format!("run cp -c: {e}"))?; + if output.status.success() { + return Ok(()); + } + Err(format!( + "cp -c failed with status {}\nstdout: {}\nstderr: {}", + output.status, + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + )) +} + +#[cfg(target_os = "linux")] +fn try_clone_file(source: &Path, dest: &Path) -> Result<(), String> { + let output = Command::new("cp") + .arg("--reflink=auto") + .arg("--sparse=always") + .arg(source) + .arg(dest) + .output() + .map_err(|e| format!("run cp --reflink=auto: {e}"))?; + if output.status.success() { + return Ok(()); + } + Err(format!( + "cp --reflink=auto --sparse=always failed with status {}\nstdout: {}\nstderr: {}", + output.status, + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + )) +} + +#[cfg(not(any(target_os = "macos", target_os = "linux")))] +fn try_clone_file(_source: &Path, _dest: &Path) -> Result<(), String> { + Err("no platform clone command available".to_string()) +} + +fn copy_sparse_file(source: &Path, dest: &Path) -> Result<(), String> { + const BUFFER_SIZE: usize = 1024 * 1024; + + let mut source_file = + File::open(source).map_err(|e| format!("open {}: {e}", source.display()))?; + let mut dest_file = + File::create(dest).map_err(|e| format!("create {}: {e}", 
dest.display()))?; + let mut buffer = vec![0_u8; BUFFER_SIZE]; + let mut size = 0_u64; + + loop { + let read = source_file + .read(&mut buffer) + .map_err(|e| format!("read {}: {e}", source.display()))?; + if read == 0 { + break; + } + + if buffer[..read].iter().all(|byte| *byte == 0) { + let skip = + i64::try_from(read).map_err(|_| format!("sparse copy chunk too large: {read}"))?; + dest_file + .seek(SeekFrom::Current(skip)) + .map_err(|e| format!("seek {}: {e}", dest.display()))?; + } else { + dest_file + .write_all(&buffer[..read]) + .map_err(|e| format!("write {}: {e}", dest.display()))?; + } + size += read as u64; + } + + dest_file + .set_len(size) + .map_err(|e| format!("size {}: {e}", dest.display())) +} + +#[cfg(test)] fn append_rootfs_tree_to_archive( builder: &mut tar::Builder>, source: &Path, @@ -119,6 +330,7 @@ fn append_rootfs_tree_to_archive( Ok(()) } +#[cfg(test)] fn append_symlink_to_archive( builder: &mut tar::Builder>, source_path: &Path, @@ -159,21 +371,37 @@ fn prepare_sandbox_rootfs(rootfs: &Path) -> Result<(), String> { } ensure_supervisor_binary(rootfs)?; + ensure_umoci_binary(rootfs)?; let opt_dir = rootfs.join("opt/openshell"); fs::create_dir_all(&opt_dir).map_err(|e| format!("create {}: {e}", opt_dir.display()))?; fs::write(opt_dir.join(".rootfs-type"), "sandbox\n") .map_err(|e| format!("write sandbox rootfs marker: {e}"))?; ensure_sandbox_guest_user(rootfs)?; + create_sandbox_mountpoint(&rootfs.join("sandbox"))?; + create_sandbox_mountpoint(&rootfs.join("image-cache"))?; + create_sandbox_mountpoint(&rootfs.join("lower"))?; + create_sandbox_mountpoint(&rootfs.join("overlay"))?; + create_sandbox_mountpoint(&rootfs.join("newroot"))?; Ok(()) } pub fn validate_sandbox_rootfs(rootfs: &Path) -> Result<(), String> { require_rootfs_path(rootfs, SANDBOX_GUEST_INIT_PATH)?; - require_rootfs_path(rootfs, "/opt/openshell/bin/openshell-sandbox")?; + require_rootfs_path(rootfs, SANDBOX_SUPERVISOR_PATH)?; + require_rootfs_path(rootfs, 
SANDBOX_UMOCI_PATH)?; require_any_rootfs_path(rootfs, &["/bin/bash"])?; require_any_rootfs_path(rootfs, &["/bin/mount", "/usr/bin/mount"])?; + require_any_rootfs_path( + rootfs, + &[ + "/usr/sbin/chroot", + "/usr/bin/chroot", + "/sbin/chroot", + "/bin/chroot", + ], + )?; require_any_rootfs_path( rootfs, &["/sbin/ip", "/usr/sbin/ip", "/bin/ip", "/usr/bin/ip"], @@ -182,6 +410,348 @@ pub fn validate_sandbox_rootfs(rootfs: &Path) -> Result<(), String> { Ok(()) } +fn create_sandbox_mountpoint(path: &Path) -> Result<(), String> { + fs::create_dir_all(path).map_err(|e| format!("create {}: {e}", path.display()))?; + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt as _; + + fs::set_permissions(path, fs::Permissions::from_mode(0o755)) + .map_err(|e| format!("chmod {}: {e}", path.display()))?; + } + Ok(()) +} + +fn rootfs_image_size_bytes(source: &Path) -> Result { + let used = directory_size_bytes(source)?; + let headroom = (used / 4).max(ROOTFS_IMAGE_MIN_HEADROOM_BYTES); + let size = (used + headroom).max(ROOTFS_IMAGE_MIN_SIZE_BYTES); + Ok(round_up_to_mib(size)) +} + +fn ext4_image_min_size_bytes(source: &Path) -> Result { + let used = directory_size_bytes(source)?; + Ok(round_up_to_mib(used + EXT4_IMAGE_MIN_HEADROOM_BYTES)) +} + +fn directory_size_bytes(path: &Path) -> Result { + let metadata = + fs::symlink_metadata(path).map_err(|e| format!("stat {}: {e}", path.display()))?; + if metadata.file_type().is_file() || metadata.file_type().is_symlink() { + return Ok(metadata.len()); + } + if !metadata.file_type().is_dir() { + return Ok(0); + } + + let mut size = 4096; + for entry in fs::read_dir(path).map_err(|e| format!("read {}: {e}", path.display()))? 
{ + let entry = entry.map_err(|e| format!("read {}: {e}", path.display()))?; + size += directory_size_bytes(&entry.path())?; + } + Ok(size) +} + +fn round_up_to_mib(bytes: u64) -> u64 { + const MIB: u64 = 1024 * 1024; + bytes.div_ceil(MIB) * MIB +} + +fn format_ext4_image_from_dir(source: &Path, image_path: &Path) -> Result<(), String> { + let mut last_error = None; + for tool in ["mke2fs", "mkfs.ext4"] { + for candidate in e2fs_tool_candidates(tool) { + let label = candidate.display().to_string(); + let output = Command::new(&candidate) + .arg("-q") + .arg("-F") + .arg("-t") + .arg("ext4") + .arg("-E") + .arg("root_owner=0:0") + .arg("-d") + .arg(source) + .arg(image_path) + .output(); + match output { + Ok(output) if output.status.success() => return Ok(()), + Ok(output) => { + last_error = Some(format!( + "{label} failed with status {}\nstdout: {}\nstderr: {}", + output.status, + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + )); + } + Err(err) if err.kind() == std::io::ErrorKind::NotFound => { + last_error = Some(format!("{label} not found")); + } + Err(err) => { + last_error = Some(format!("run {label}: {err}")); + } + } + } + } + Err(format!( + "failed to create ext4 rootfs image from {}: {}. 
Install e2fsprogs (mke2fs/mkfs.ext4) and retry", + source.display(), + last_error.unwrap_or_else(|| "no ext4 formatter found".to_string()) + )) +} + +fn ensure_rootfs_image_parent_dirs(image_path: &Path, guest_path: &str) { + let Some(parent) = Path::new(guest_path).parent() else { + return; + }; + let mut current = String::new(); + for component in parent.components() { + let part = component.as_os_str().to_string_lossy(); + if part == "/" || part.is_empty() { + continue; + } + current.push('/'); + current.push_str(&part); + let _ = run_debugfs(image_path, &format!("mkdir {current}")); + } +} + +fn normalize_sandbox_owner_in_rootfs_image(source: &Path, image_path: &Path) -> Result<(), String> { + let sandbox_dir = source.join("sandbox"); + if !sandbox_dir.exists() { + return Ok(()); + } + + let Some((uid, gid)) = sandbox_guest_user_ids(source)? else { + return Ok(()); + }; + + let mut commands = Vec::new(); + if !collect_sandbox_owner_commands(&sandbox_dir, "/sandbox", uid, gid, &mut commands)? 
{ + return Ok(()); + } + if commands.is_empty() { + return Ok(()); + } + + run_debugfs_batch(image_path, &commands)?; + write_rootfs_image_file(image_path, SANDBOX_OWNER_NORMALIZED_MARKER, b"1\n") +} + +fn collect_sandbox_owner_commands( + source_path: &Path, + guest_path: &str, + uid: u32, + gid: u32, + commands: &mut Vec, +) -> Result { + let metadata = fs::symlink_metadata(source_path).map_err(|e| { + format!( + "stat {} for rootfs ownership normalization: {e}", + source_path.display() + ) + })?; + if metadata.file_type().is_symlink() { + return Ok(true); + } + + let Some(quoted_guest_path) = debugfs_quote_absolute_path(guest_path) else { + return Ok(false); + }; + commands.push(format!("set_inode_field {quoted_guest_path} uid {uid}")); + commands.push(format!("set_inode_field {quoted_guest_path} gid {gid}")); + + if !metadata.is_dir() { + return Ok(true); + } + + let mut entries = fs::read_dir(source_path) + .map_err(|e| { + format!( + "read {} for rootfs ownership normalization: {e}", + source_path.display() + ) + })? + .collect::, _>>() + .map_err(|e| { + format!( + "read {} entry for rootfs ownership normalization: {e}", + source_path.display() + ) + })?; + entries.sort_by_key(fs::DirEntry::file_name); + + for entry in entries { + let file_name = entry.file_name(); + let Some(file_name) = file_name.to_str() else { + return Ok(false); + }; + let child_guest_path = format!("{guest_path}/{file_name}"); + if !collect_sandbox_owner_commands(&entry.path(), &child_guest_path, uid, gid, commands)? 
{ + return Ok(false); + } + } + + Ok(true) +} + +fn debugfs_quote_absolute_path(path: &str) -> Option<String> { + if path.is_empty() || !path.starts_with('/') { + return None; + } + + debugfs_quote_argument(path) +} + +fn debugfs_quote_argument(argument: &str) -> Option<String> { + if argument.is_empty() { + return None; + } + + let mut quoted = String::with_capacity(argument.len() + 2); + quoted.push('"'); + for ch in argument.chars() { + match ch { + '\0' | '\n' | '\r' => return None, + '\\' => quoted.push_str("\\\\"), + '"' => quoted.push_str("\\\""), + _ => quoted.push(ch), + } + } + quoted.push('"'); + Some(quoted) +} + +fn sandbox_guest_user_ids(rootfs: &Path) -> Result<Option<(u32, u32)>, String> { + let passwd_path = rootfs.join("etc/passwd"); + if !passwd_path.exists() { + return Ok(None); + } + + let passwd = fs::read_to_string(&passwd_path) + .map_err(|e| format!("read {}: {e}", passwd_path.display()))?; + for line in passwd.lines() { + let mut parts = line.split(':'); + if parts.next() != Some("sandbox") { + continue; + } + let _password = parts.next(); + let uid = parts + .next() + .ok_or_else(|| format!("sandbox entry in {} is missing uid", passwd_path.display()))? + .parse::<u32>() + .map_err(|e| format!("sandbox uid in {} is invalid: {e}", passwd_path.display()))?; + let gid = parts + .next() + .ok_or_else(|| format!("sandbox entry in {} is missing gid", passwd_path.display()))?
+ .parse::<u32>() + .map_err(|e| format!("sandbox gid in {} is invalid: {e}", passwd_path.display()))?; + return Ok(Some((uid, gid))); + } + + Ok(None) +} + +fn run_debugfs_batch(image_path: &Path, commands: &[String]) -> Result<(), String> { + let command_path = temporary_injection_path(image_path); + let mut contents = commands.join("\n"); + contents.push('\n'); + fs::write(&command_path, contents) + .map_err(|e| format!("write {}: {e}", command_path.display()))?; + + let result = run_debugfs_batch_file(image_path, &command_path); + let _ = fs::remove_file(&command_path); + result +} + +fn run_debugfs_batch_file(image_path: &Path, command_path: &Path) -> Result<(), String> { + let mut last_error = None; + for candidate in e2fs_tool_candidates("debugfs") { + let label = candidate.display().to_string(); + let output = Command::new(&candidate) + .arg("-w") + .arg("-f") + .arg(command_path) + .arg(image_path) + .output(); + match output { + Ok(output) if output.status.success() => return Ok(()), + Ok(output) => { + last_error = Some(format!( + "{label} failed with status {}\nstdout: {}\nstderr: {}", + output.status, + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + )); + } + Err(err) if err.kind() == std::io::ErrorKind::NotFound => { + last_error = Some(format!("{label} not found")); + } + Err(err) => { + last_error = Some(format!("run {label}: {err}")); + } + } + } + Err(format!( + "debugfs batch {} failed for {}: {}.
Install e2fsprogs (debugfs) and retry", + command_path.display(), + image_path.display(), + last_error.unwrap_or_else(|| "debugfs not found".to_string()) + )) +} + +fn run_debugfs(image_path: &Path, command: &str) -> Result<(), String> { + let mut last_error = None; + for candidate in e2fs_tool_candidates("debugfs") { + let label = candidate.display().to_string(); + let output = Command::new(&candidate) + .arg("-w") + .arg("-R") + .arg(command) + .arg(image_path) + .output(); + match output { + Ok(output) if output.status.success() => return Ok(()), + Ok(output) => { + last_error = Some(format!( + "{label} failed with status {}\nstdout: {}\nstderr: {}", + output.status, + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + )); + } + Err(err) if err.kind() == std::io::ErrorKind::NotFound => { + last_error = Some(format!("{label} not found")); + } + Err(err) => { + last_error = Some(format!("run {label}: {err}")); + } + } + } + Err(format!( + "debugfs command '{command}' failed for {}: {}. 
Install e2fsprogs (debugfs) and retry", + image_path.display(), + last_error.unwrap_or_else(|| "debugfs not found".to_string()) + )) +} + +fn e2fs_tool_candidates(tool: &str) -> Vec<PathBuf> { + let mut candidates = vec![PathBuf::from(tool)]; + for root in ["/opt/homebrew/opt/e2fsprogs", "/usr/local/opt/e2fsprogs"] { + candidates.push(Path::new(root).join("sbin").join(tool)); + candidates.push(Path::new(root).join("bin").join(tool)); + } + candidates +} + +fn temporary_injection_path(image_path: &Path) -> PathBuf { + let n = INJECTION_COUNTER.fetch_add(1, Ordering::Relaxed); + let parent = image_path.parent().unwrap_or_else(|| Path::new(".")); + parent.join(format!( + ".openshell-rootfs-inject-{}-{n}", + std::process::id() + )) +} + fn ensure_sandbox_guest_user(rootfs: &Path) -> Result<(), String> { const SANDBOX_UID: u32 = 10001; const SANDBOX_GID: u32 = 10001; @@ -265,6 +835,36 @@ fn ensure_supervisor_binary(rootfs: &Path) -> Result<(), String> { Ok(()) } +fn ensure_umoci_binary(rootfs: &Path) -> Result<(), String> { + let path = rootfs.join(SANDBOX_UMOCI_PATH.trim_start_matches('/')); + if UMOCI.is_empty() { + if !path.exists() { + return Err( + "umoci not embedded.
Build openshell-driver-vm with OPENSHELL_VM_RUNTIME_COMPRESSED_DIR set and run `mise run vm:setup` first" + .to_string(), + ); + } + } else { + if let Some(parent) = path.parent() { + fs::create_dir_all(parent).map_err(|e| format!("create {}: {e}", parent.display()))?; + } + + let umoci = + zstd::decode_all(Cursor::new(UMOCI)).map_err(|e| format!("decompress umoci: {e}"))?; + fs::write(&path, umoci).map_err(|e| format!("write {}: {e}", path.display()))?; + } + + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt as _; + + fs::set_permissions(&path, fs::Permissions::from_mode(0o755)) + .map_err(|e| format!("chmod {}: {e}", path.display()))?; + } + + Ok(()) +} + fn require_rootfs_path(rootfs: &Path, relative: &str) -> Result<(), String> { let candidate = rootfs.join(relative.trim_start_matches('/')); if candidate.exists() { @@ -320,11 +920,7 @@ mod tests { fs::create_dir_all(rootfs.join("etc")).expect("create etc"); fs::create_dir_all(rootfs.join("opt/openshell/bin")).expect("create openshell bin"); fs::write(rootfs.join("opt/openshell/.initialized"), b"yes").expect("write initialized"); - fs::write( - rootfs.join("opt/openshell/bin/openshell-sandbox"), - b"sandbox", - ) - .expect("write openshell-sandbox"); + write_fake_runtime_binaries(&rootfs); fs::write( rootfs.join("etc/passwd"), "root:x:0:0:root:/root:/bin/bash\n", @@ -336,6 +932,7 @@ mod tests { fs::create_dir_all(rootfs.join("sbin")).expect("create sbin"); fs::write(rootfs.join("bin/bash"), b"bash").expect("write bash"); fs::write(rootfs.join("bin/mount"), b"mount").expect("write mount"); + fs::write(rootfs.join("bin/chroot"), b"chroot").expect("write chroot"); fs::write(rootfs.join("bin/sed"), b"sed").expect("write sed"); fs::write(rootfs.join("sbin/ip"), b"ip").expect("write ip"); @@ -343,7 +940,18 @@ mod tests { validate_sandbox_rootfs(&rootfs).expect("validate sandbox rootfs"); assert!(rootfs.join("srv/openshell-vm-sandbox-init.sh").is_file()); - assert!(!rootfs.join("sandbox").exists()); + 
assert!(rootfs.join("opt/openshell/bin/umoci").is_file()); + assert!(rootfs.join("sandbox").is_dir()); + assert!(rootfs.join("image-cache").is_dir()); + assert!(rootfs.join("lower").is_dir()); + assert!(rootfs.join("overlay").is_dir()); + assert!(rootfs.join("newroot").is_dir()); + assert!( + fs::read_dir(rootfs.join("sandbox")) + .expect("read sandbox") + .next() + .is_none() + ); assert!( fs::read_to_string(rootfs.join("etc/passwd")) .expect("read passwd") @@ -363,21 +971,18 @@ mod tests { } #[test] - fn prepare_sandbox_rootfs_preserves_image_workdir_contents() { + fn prepare_sandbox_rootfs_preserves_image_workdir_contents_in_rootfs() { let dir = unique_temp_dir(); let rootfs = dir.join("rootfs"); fs::create_dir_all(rootfs.join("opt/openshell/bin")).expect("create openshell bin"); - fs::write( - rootfs.join("opt/openshell/bin/openshell-sandbox"), - b"sandbox", - ) - .expect("write openshell-sandbox"); + write_fake_runtime_binaries(&rootfs); fs::create_dir_all(rootfs.join("sandbox")).expect("create sandbox workdir"); fs::write(rootfs.join("sandbox/app.py"), "print('hello')\n").expect("write app"); prepare_sandbox_rootfs(&rootfs).expect("prepare sandbox rootfs"); + assert!(rootfs.join("sandbox").is_dir()); assert_eq!( fs::read_to_string(rootfs.join("sandbox/app.py")).expect("read app"), "print('hello')\n" @@ -417,6 +1022,105 @@ mod tests { let _ = fs::remove_dir_all(&dir); } + #[test] + fn clone_or_copy_sparse_file_preserves_size_and_contents() { + let dir = unique_temp_dir(); + fs::create_dir_all(&dir).expect("create temp dir"); + let source = dir.join("source.bin"); + let dest = dir.join("dest.bin"); + + let mut source_file = File::create(&source).expect("create source"); + source_file.write_all(b"head").expect("write head"); + source_file + .seek(SeekFrom::Start(1024 * 1024 + 7)) + .expect("seek source"); + source_file.write_all(b"tail").expect("write tail"); + source_file + .set_len(2 * 1024 * 1024 + 3) + .expect("size source"); + drop(source_file); + + 
clone_or_copy_sparse_file(&source, &dest).expect("copy sparse file"); + + assert_eq!( + fs::metadata(&dest).expect("stat dest").len(), + 2 * 1024 * 1024 + 3 + ); + let mut dest_file = File::open(&dest).expect("open dest"); + let mut head = [0_u8; 4]; + dest_file.read_exact(&mut head).expect("read head"); + assert_eq!(&head, b"head"); + dest_file + .seek(SeekFrom::Start(1024 * 1024 + 7)) + .expect("seek dest"); + let mut tail = [0_u8; 4]; + dest_file.read_exact(&mut tail).expect("read tail"); + assert_eq!(&tail, b"tail"); + + let _ = fs::remove_dir_all(&dir); + } + + #[test] + fn sandbox_guest_user_ids_reads_existing_sandbox_user() { + let dir = unique_temp_dir(); + let rootfs = dir.join("rootfs"); + fs::create_dir_all(rootfs.join("etc")).expect("create etc"); + fs::write( + rootfs.join("etc/passwd"), + "root:x:0:0:root:/root:/bin/bash\nsandbox:x:998:997:Sandbox:/sandbox:/bin/sh\n", + ) + .expect("write passwd"); + + assert_eq!( + sandbox_guest_user_ids(&rootfs).expect("read sandbox user"), + Some((998, 997)) + ); + + let _ = fs::remove_dir_all(&dir); + } + + #[test] + fn collect_sandbox_owner_commands_quotes_guest_paths() { + let dir = unique_temp_dir(); + let sandbox_dir = dir.join("sandbox"); + fs::create_dir_all(sandbox_dir.join("dir with space")).expect("create sandbox tree"); + fs::write(sandbox_dir.join("dir with space/file.txt"), "hello\n").expect("write file"); + + let mut commands = Vec::new(); + assert!( + collect_sandbox_owner_commands(&sandbox_dir, "/sandbox", 998, 997, &mut commands) + .expect("collect commands") + ); + + assert!(commands.contains(&"set_inode_field \"/sandbox\" uid 998".to_string())); + assert!(commands.contains(&"set_inode_field \"/sandbox\" gid 997".to_string())); + assert!( + commands.contains( + &"set_inode_field \"/sandbox/dir with space/file.txt\" uid 998".to_string() + ) + ); + assert!( + commands.contains( + &"set_inode_field \"/sandbox/dir with space/file.txt\" gid 997".to_string() + ) + ); + + let _ = 
fs::remove_dir_all(&dir); + } + + #[test] + fn debugfs_quote_argument_quotes_source_paths_with_spaces() { + assert_eq!( + debugfs_quote_argument("/tmp/openshell state/.openshell-rootfs-inject-123-0"), + Some("\"/tmp/openshell state/.openshell-rootfs-inject-123-0\"".to_string()) + ); + assert_eq!( + debugfs_quote_argument("/tmp/path/with\\backslash/and\"quote"), + Some("\"/tmp/path/with\\\\backslash/and\\\"quote\"".to_string()) + ); + assert_eq!(debugfs_quote_argument("/tmp/bad\npath"), None); + } + fn unique_temp_dir() -> PathBuf { static COUNTER: AtomicU64 = AtomicU64::new(0); let nanos = SystemTime::now() @@ -429,4 +1133,13 @@ mod tests { std::process::id() )) } + + fn write_fake_runtime_binaries(rootfs: &Path) { + fs::write( + rootfs.join("opt/openshell/bin/openshell-sandbox"), + b"sandbox", + ) + .expect("write openshell-sandbox"); + fs::write(rootfs.join("opt/openshell/bin/umoci"), b"umoci").expect("write umoci"); + } } diff --git a/crates/openshell-driver-vm/src/runtime.rs b/crates/openshell-driver-vm/src/runtime.rs index 758808c8e..1ce6fb26b 100644 --- a/crates/openshell-driver-vm/src/runtime.rs +++ b/crates/openshell-driver-vm/src/runtime.rs @@ -10,7 +10,7 @@ use std::ptr; use std::sync::atomic::{AtomicI32, Ordering}; use std::time::{Duration, Instant}; -use crate::{embedded_runtime, ffi, procguard}; +use crate::{embedded_runtime, ffi, nft_ruleset, procguard, rootfs}; pub const VM_RUNTIME_DIR_ENV: &str = "OPENSHELL_VM_RUNTIME_DIR"; @@ -18,7 +18,7 @@ pub const VM_RUNTIME_DIR_ENV: &str = "OPENSHELL_VM_RUNTIME_DIR"; /// Used by the SIGTERM/SIGINT handler to forward signals to the VM. static CHILD_PID: AtomicI32 = AtomicI32::new(0); -/// PID of the helper process (gvproxy for libkrun, virtiofsd for QEMU). +/// PID of the helper process (gvproxy for libkrun; zero for QEMU). /// Zero when not running. 
Used by the SIGTERM/SIGINT handler and /// procguard cleanup callback to ensure the helper doesn't outlive the /// launcher (especially on macOS where `PR_SET_PDEATHSIG` is absent). @@ -45,7 +45,9 @@ const COMPAT_NET_FEATURES: u32 = NET_FEATURE_CSUM | NET_FEATURE_HOST_UFO; pub struct VmLaunchConfig { - pub rootfs: PathBuf, + pub root_disk: PathBuf, + pub overlay_disk: PathBuf, + pub image_disk: Option, pub vcpus: u8, pub mem_mib: u32, pub exec_path: String, @@ -96,12 +98,23 @@ fn run_qemu_vm(config: &VmLaunchConfig) -> Result<(), String> { .as_deref() .ok_or("host_ip is required for QEMU backend")?; - if !config.rootfs.is_dir() { + if !config.root_disk.is_file() { return Err(format!( - "rootfs directory not found: {}", - config.rootfs.display() + "root disk image not found: {}", + config.root_disk.display() )); } + if !config.overlay_disk.is_file() { + return Err(format!( + "overlay disk image not found: {}", + config.overlay_disk.display() + )); + } + if let Some(image_disk) = &config.image_disk + && !image_disk.is_file() + { + return Err(format!("image disk not found: {}", image_disk.display())); + } if let Err(err) = procguard::die_with_parent_cleanup(procguard_kill_children) { return Err(format!("procguard arm failed: {err}")); @@ -111,70 +124,13 @@ fn run_qemu_vm(config: &VmLaunchConfig) -> Result<(), String> { check_kvm_access()?; let guest_env = qemu_guest_env_vars(config, host_dns_server()); - write_guest_env_file(&config.rootfs, &guest_env)?; - - let rootfs_str = config.rootfs.to_str().ok_or("rootfs path not UTF-8")?; - let sandbox_dir = config.rootfs.parent().unwrap_or(&config.rootfs); - let sock_prefix = tap_device.trim_start_matches("vmtap-"); - let virtiofsd_sock_dir = PathBuf::from(format!("/tmp/ovm-qemu-{sock_prefix}")); - std::fs::create_dir_all(&virtiofsd_sock_dir) - .map_err(|e| format!("create virtiofsd sock dir: {e}"))?; - let virtiofsd_sock = virtiofsd_sock_dir.join("virtiofsd.sock"); - let shm_path = format!("/dev/shm/ovm-qemu-{sock_prefix}"); 
- - std::fs::create_dir_all(&shm_path).map_err(|e| format!("create shm dir: {e}"))?; + write_guest_env_file(&config.overlay_disk, &guest_env)?; let runtime_dir = qemu_runtime_dir()?; - let gw_port = config.gateway_port.unwrap_or(0); setup_tap_networking(tap_device, host_ip, gw_port)?; let mut tap_guard = TapGuard::new(tap_device.to_string(), host_ip.to_string(), gw_port); - let virtiofsd_log = sandbox_dir.join("virtiofsd.log"); - let virtiofsd_log_file = - std::fs::File::create(&virtiofsd_log).map_err(|e| format!("create virtiofsd log: {e}"))?; - - let virtiofsd_bin = { - let runtime_virtiofsd = runtime_dir.join("virtiofsd"); - if runtime_virtiofsd.is_file() { - runtime_virtiofsd - } else { - PathBuf::from("virtiofsd") - } - }; - - let mut virtiofsd_cmd = StdCommand::new(&virtiofsd_bin); - virtiofsd_cmd - .arg("--socket-path") - .arg(&virtiofsd_sock) - .arg("--shared-dir") - .arg(rootfs_str) - .arg("--cache=auto") - .stdin(Stdio::null()) - .stdout(Stdio::null()) - .stderr(virtiofsd_log_file); - - #[cfg(target_os = "linux")] - { - use nix::sys::signal::Signal; - use std::os::unix::process::CommandExt as _; - unsafe { - virtiofsd_cmd.pre_exec(|| { - nix::sys::prctl::set_pdeathsig(Signal::SIGKILL) - .map_err(|err| std::io::Error::other(format!("pdeathsig: {err}"))) - }); - } - } - - let virtiofsd_child = virtiofsd_cmd - .spawn() - .map_err(|e| format!("failed to start virtiofsd: {e}"))?; - let virtiofsd_pid = virtiofsd_child.id().cast_signed(); - GVPROXY_PID.store(virtiofsd_pid, Ordering::Relaxed); - let mut virtiofsd_guard = GvproxyGuard::new(virtiofsd_child); - - wait_for_path(&virtiofsd_sock, Duration::from_secs(5), "virtiofsd socket")?; - let vmlinux = runtime_dir.join("vmlinux"); if !vmlinux.is_file() { return Err(format!("VM kernel not found: {}", vmlinux.display())); @@ -198,20 +154,7 @@ fn run_qemu_vm(config: &VmLaunchConfig) -> Result<(), String> { .arg(&vmlinux) .arg("-append") .arg(&kernel_cmdline) - .arg("-chardev") - .arg(format!( - 
"socket,id=virtiofs,path={}", - virtiofsd_sock.display() - )) - .arg("-device") - .arg("vhost-user-fs-pci,chardev=virtiofs,tag=rootfs") - .arg("-object") - .arg(format!( - "memory-backend-memfd,id=mem,size={}M,share=on", - config.mem_mib - )) - .arg("-numa") - .arg("node,memdev=mem") + .args(qemu_disk_args(config)) .arg("-netdev") .arg(format!( "tap,id=net0,ifname={tap_device},script=no,downscript=no" @@ -263,15 +206,8 @@ fn run_qemu_vm(config: &VmLaunchConfig) -> Result<(), String> { .map_err(|e| format!("failed to wait for QEMU: {e}"))?; CHILD_PID.store(0, Ordering::Relaxed); - unsafe { - libc::kill(virtiofsd_pid, libc::SIGTERM); - } - virtiofsd_guard.disarm(); - GVPROXY_PID.store(0, Ordering::Relaxed); teardown_tap_networking(tap_device, host_ip, gw_port); tap_guard.disarm(); - let _ = std::fs::remove_dir_all(&shm_path); - let _ = std::fs::remove_dir_all(&virtiofsd_sock_dir); if status.success() { Ok(()) @@ -280,12 +216,42 @@ fn run_qemu_vm(config: &VmLaunchConfig) -> Result<(), String> { } } -/// Write environment variables into the rootfs so the guest init script -/// can source them. virtiofs shares the host rootfs directory into the guest. 
-fn write_guest_env_file(rootfs: &Path, env_vars: &[String]) -> Result<(), String> { - let srv_dir = rootfs.join("srv"); - std::fs::create_dir_all(&srv_dir).map_err(|e| format!("create /srv in rootfs: {e}"))?; - let env_file = srv_dir.join("openshell-env.sh"); +fn qemu_disk_args(config: &VmLaunchConfig) -> Vec<String> { + let mut args = vec![ + "-drive".to_string(), + format!( + "file={},if=none,format=raw,id=rootfs,readonly=on", + config.root_disk.display() + ), + "-device".to_string(), + "virtio-blk-pci,drive=rootfs".to_string(), + "-drive".to_string(), + format!( + "file={},if=none,format=raw,id=overlay", + config.overlay_disk.display() + ), + "-device".to_string(), + "virtio-blk-pci,drive=overlay".to_string(), + ]; + if let Some(image_disk) = &config.image_disk { + args.extend([ + "-drive".to_string(), + format!( + "file={},if=none,format=raw,id=image,readonly=on", + image_disk.display() + ), + "-device".to_string(), + "virtio-blk-pci,drive=image".to_string(), + ]); + } + args +} + +/// Write environment variables into the overlay disk so the guest init script +/// can source them after the overlay root is mounted. QEMU does not provide a +/// `krun_set_exec` equivalent, so the launcher injects this small per-sandbox +/// file into the overlay upperdir before boot.
+fn write_guest_env_file(overlay_disk: &Path, env_vars: &[String]) -> Result<(), String> { let mut content = String::new(); for var in env_vars { if let Some((key, value)) = var.split_once('=') { @@ -293,8 +259,11 @@ fn write_guest_env_file(rootfs: &Path, env_vars: &[String]) -> Result<(), String let _ = writeln!(content, "export {key}=\"{}\"", shell_escape(value)); } } - std::fs::write(&env_file, &content).map_err(|e| format!("write guest env file: {e}"))?; - Ok(()) + rootfs::write_rootfs_image_file( + overlay_disk, + "/upper/srv/openshell-env.sh", + content.as_bytes(), + ) } fn qemu_guest_env_vars(config: &VmLaunchConfig, dns_server: Option) -> Vec { @@ -331,9 +300,9 @@ fn shell_escape(s: &str) -> String { fn build_kernel_cmdline(config: &VmLaunchConfig) -> String { let mut parts = vec![ "console=ttyS0".to_string(), - "root=rootfs".to_string(), - "rootfstype=virtiofs".to_string(), - "rw".to_string(), + "root=/dev/vda".to_string(), + "rootfstype=ext4".to_string(), + "ro".to_string(), "panic=-1".to_string(), format!("init={}", config.exec_path), ]; @@ -444,6 +413,12 @@ fn setup_tap_networking(tap_device: &str, host_ip: &str, gateway_port: u16) -> R enable_ip_forwarding()?; let subnet = tap_subnet_from_host_ip(host_ip); + let table_name = nft_ruleset::teardown_table_name(tap_device); + + // Delete any stale nftables table from a previous driver run. + let _ = run_cmd("nft", &["delete", "table", "ip", &table_name]); + + // Clean up legacy iptables rules from older driver versions. 
let _ = run_cmd( "iptables", &[ @@ -457,27 +432,10 @@ fn setup_tap_networking(tap_device: &str, host_ip: &str, gateway_port: u16) -> R "MASQUERADE", ], ); - run_cmd( - "iptables", - &[ - "-t", - "nat", - "-A", - "POSTROUTING", - "-s", - &subnet, - "-j", - "MASQUERADE", - ], - )?; let _ = run_cmd( "iptables", &["-D", "FORWARD", "-i", tap_device, "-j", "ACCEPT"], ); - run_cmd( - "iptables", - &["-A", "FORWARD", "-i", tap_device, "-j", "ACCEPT"], - )?; let _ = run_cmd( "iptables", &[ @@ -493,25 +451,6 @@ fn setup_tap_networking(tap_device: &str, host_ip: &str, gateway_port: u16) -> R "ACCEPT", ], ); - run_cmd( - "iptables", - &[ - "-A", - "FORWARD", - "-o", - tap_device, - "-m", - "state", - "--state", - "RELATED,ESTABLISHED", - "-j", - "ACCEPT", - ], - )?; - // Allow guest → host traffic only to the gateway gRPC port. - // Previous versions accepted ALL inbound traffic from the TAP - // interface; scope to the specific port so the guest cannot reach - // other host services. let port_str = gateway_port.to_string(); let _ = run_cmd( "iptables", @@ -519,17 +458,24 @@ fn setup_tap_networking(tap_device: &str, host_ip: &str, gateway_port: u16) -> R "-D", "INPUT", "-i", tap_device, "-p", "tcp", "--dport", &port_str, "-j", "ACCEPT", ], ); - run_cmd( + let _ = run_cmd( "iptables", - &[ - "-A", "INPUT", "-i", tap_device, "-p", "tcp", "--dport", &port_str, "-j", "ACCEPT", - ], - )?; + &["-D", "INPUT", "-i", tap_device, "-j", "ACCEPT"], + ); + + // Load nftables ruleset atomically. + let ruleset = nft_ruleset::generate_tap_ruleset(tap_device, &subnet, gateway_port); + run_nft_stdin(&ruleset)?; Ok(()) } fn teardown_tap_networking(tap_device: &str, host_ip: &str, gateway_port: u16) { + // Delete the entire nftables table — single atomic operation. + let table_name = nft_ruleset::teardown_table_name(tap_device); + let _ = run_cmd("nft", &["delete", "table", "ip", &table_name]); + + // Clean up legacy iptables rules from older driver versions. 
let subnet = tap_subnet_from_host_ip(host_ip); let _ = run_cmd( "iptables", @@ -550,8 +496,6 @@ fn teardown_tap_networking(tap_device: &str, host_ip: &str, gateway_port: u16) { "iptables", &["-D", "FORWARD", "-i", tap_device, "-j", "ACCEPT"], ); - // Remove the port-scoped INPUT rule. Also try the legacy blanket - // rule so stale rules from older driver versions are cleaned up. if gateway_port > 0 { let port_str = gateway_port.to_string(); let _ = run_cmd( @@ -578,6 +522,7 @@ fn teardown_tap_networking(tap_device: &str, host_ip: &str, gateway_port: u16) { "MASQUERADE", ], ); + let _ = run_cmd("ip", &["link", "set", tap_device, "down"]); let _ = run_cmd("ip", &["tuntap", "del", "dev", tap_device, "mode", "tap"]); } @@ -614,6 +559,35 @@ fn run_cmd(cmd: &str, args: &[&str]) -> Result<(), String> { } } +fn run_nft_stdin(ruleset: &str) -> Result<(), String> { + use std::io::Write; + + let mut child = StdCommand::new("nft") + .args(["-f", "-"]) + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn() + .map_err(|e| format!("failed to run nft: {e}"))?; + + if let Some(mut stdin) = child.stdin.take() { + stdin + .write_all(ruleset.as_bytes()) + .map_err(|e| format!("failed to write nft ruleset: {e}"))?; + } + + let output = child + .wait_with_output() + .map_err(|e| format!("failed to wait for nft: {e}"))?; + + if output.status.success() { + Ok(()) + } else { + let stderr = String::from_utf8_lossy(&output.stderr); + Err(format!("nft -f - failed: {stderr}")) + } +} + /// RAII guard that tears down TAP networking on drop. 
struct TapGuard { tap_device: String, @@ -674,12 +648,23 @@ fn procguard_kill_children() { } fn run_libkrun_vm(config: &VmLaunchConfig) -> Result<(), String> { - if !config.rootfs.is_dir() { + if !config.root_disk.is_file() { return Err(format!( - "rootfs directory not found: {}", - config.rootfs.display() + "root disk image not found: {}", + config.root_disk.display() )); } + if !config.overlay_disk.is_file() { + return Err(format!( + "overlay disk image not found: {}", + config.overlay_disk.display() + )); + } + if let Some(image_disk) = &config.image_disk + && !image_disk.is_file() + { + return Err(format!("image disk not found: {}", image_disk.display())); + } // Arm procguard first, BEFORE we spawn gvproxy or fork libkrun, so // that the launcher can't be orphaned during setup. The cleanup @@ -702,7 +687,11 @@ fn run_libkrun_vm(config: &VmLaunchConfig) -> Result<(), String> { let vm = VmContext::create(&runtime_dir, config.log_level)?; vm.set_vm_config(config.vcpus, config.mem_mib)?; - vm.set_root(&config.rootfs)?; + vm.set_disks( + &config.root_disk, + &config.overlay_disk, + config.image_disk.as_deref(), + )?; vm.set_workdir(&config.workdir)?; // Run gvproxy strictly as the guest's virtual NIC / DHCP / router. @@ -731,7 +720,7 @@ fn run_libkrun_vm(config: &VmLaunchConfig) -> Result<(), String> { // on its own service ports (DNS:53, DHCP, HTTP API:80). // // That network plane is also what the sandbox supervisor's - // per-sandbox netns (veth pair + iptables, see + // per-sandbox netns (veth pair + nftables, see // `openshell-sandbox/src/sandbox/linux/netns.rs`) branches off of; // libkrun's built-in TSI socket impersonation would not satisfy // those kernel-level primitives. 
@@ -749,12 +738,12 @@ fn run_libkrun_vm(config: &VmLaunchConfig) -> Result<(), String> { )); } - let sock_base = gvproxy_socket_base(&config.rootfs)?; + let sock_base = gvproxy_socket_base(&config.overlay_disk)?; let net_sock = sock_base.with_extension("v"); let _ = std::fs::remove_file(&net_sock); let _ = std::fs::remove_file(sock_base.with_extension("v-krun.sock")); - let run_dir = config.rootfs.parent().unwrap_or(&config.rootfs); + let run_dir = config.overlay_disk.parent().unwrap_or(&config.overlay_disk); let gvproxy_log = run_dir.join("gvproxy.log"); let gvproxy_log_file = std::fs::File::create(&gvproxy_log) .map_err(|e| format!("create gvproxy log {}: {e}", gvproxy_log.display()))?; @@ -1013,11 +1002,74 @@ impl VmContext { ) } - fn set_root(&self, rootfs: &Path) -> Result<(), String> { - let rootfs_c = path_to_cstring(rootfs)?; + fn set_disks( + &self, + root_disk: &Path, + overlay_disk: &Path, + image_disk: Option<&Path>, + ) -> Result<(), String> { + let root_disk_c = path_to_cstring(root_disk)?; + let block_id_c = CString::new("root").map_err(|e| format!("invalid block id: {e}"))?; check( - unsafe { (self.krun.krun_set_root)(self.ctx_id, rootfs_c.as_ptr()) }, - "krun_set_root", + unsafe { + (self.krun.krun_add_disk)( + self.ctx_id, + block_id_c.as_ptr(), + root_disk_c.as_ptr(), + true, + ) + }, + "krun_add_disk", + )?; + + let overlay_disk_c = path_to_cstring(overlay_disk)?; + let overlay_block_id_c = + CString::new("overlay").map_err(|e| format!("invalid block id: {e}"))?; + check( + unsafe { + (self.krun.krun_add_disk)( + self.ctx_id, + overlay_block_id_c.as_ptr(), + overlay_disk_c.as_ptr(), + false, + ) + }, + "krun_add_disk", + )?; + + if let Some(image_disk) = image_disk { + let image_disk_c = path_to_cstring(image_disk)?; + let image_block_id_c = + CString::new("image").map_err(|e| format!("invalid image block id: {e}"))?; + check( + unsafe { + (self.krun.krun_add_disk)( + self.ctx_id, + image_block_id_c.as_ptr(), + image_disk_c.as_ptr(), + true, + ) 
+ }, + "krun_add_disk", + )?; + } + + let device_c = + CString::new("/dev/vda").map_err(|e| format!("invalid root disk device: {e}"))?; + let fstype_c = + CString::new("ext4").map_err(|e| format!("invalid root disk fstype: {e}"))?; + let options_c = + CString::new("ro").map_err(|e| format!("invalid root disk options: {e}"))?; + check( + unsafe { + (self.krun.krun_set_root_disk_remount)( + self.ctx_id, + device_c.as_ptr(), + fstype_c.as_ptr(), + options_c.as_ptr(), + ) + }, + "krun_set_root_disk_remount", ) } @@ -1234,8 +1286,8 @@ fn secure_socket_base(subdir: &str) -> Result { Ok(dir) } -fn gvproxy_socket_base(rootfs: &Path) -> Result { - Ok(secure_socket_base("osd-gv")?.join(hash_path_id(rootfs))) +fn gvproxy_socket_base(overlay_disk: &Path) -> Result { + Ok(secure_socket_base("osd-gv")?.join(hash_path_id(overlay_disk))) } fn install_signal_forwarding(pid: i32) { @@ -1342,7 +1394,9 @@ mod tests { fn qemu_config() -> VmLaunchConfig { VmLaunchConfig { - rootfs: PathBuf::from("/rootfs"), + root_disk: PathBuf::from("/rootfs.ext4"), + overlay_disk: PathBuf::from("/overlay.ext4"), + image_disk: None, vcpus: 2, mem_mib: 2048, exec_path: "/srv/openshell-vm-sandbox-init.sh".to_string(), @@ -1377,6 +1431,9 @@ mod tests { fn kernel_cmdline_keeps_guest_init_metadata_out_of_proc_cmdline() { let cmdline = build_kernel_cmdline(&qemu_config()); + assert!(cmdline.contains("root=/dev/vda")); + assert!(cmdline.contains("rootfstype=ext4")); + assert!(cmdline.contains(" ro")); assert!(cmdline.contains("ip=10.0.128.2::10.0.128.1:255.255.255.252:sandbox::off")); assert!(cmdline.contains("firmware_class.path=/lib/firmware")); assert!(!cmdline.contains("VM_NET_IP=")); @@ -1384,4 +1441,62 @@ mod tests { assert!(!cmdline.contains("VM_NET_DNS=")); assert!(!cmdline.contains("GPU_ENABLED=")); } + + #[test] + fn qemu_disk_args_attach_base_readonly_and_overlay_readwrite() { + let args = qemu_disk_args(&qemu_config()); + + assert!(args.contains(&"-drive".to_string())); + assert!( + args.contains( 
+ &"file=/rootfs.ext4,if=none,format=raw,id=rootfs,readonly=on".to_string() + ) + ); + assert!(args.contains(&"virtio-blk-pci,drive=rootfs".to_string())); + assert!(args.contains(&"file=/overlay.ext4,if=none,format=raw,id=overlay".to_string())); + assert!( + !args + .iter() + .any(|arg| arg.contains("id=overlay,readonly=on")) + ); + assert!(args.contains(&"virtio-blk-pci,drive=overlay".to_string())); + } + + #[test] + fn qemu_disk_args_attach_prepared_image_readonly_when_present() { + let mut config = qemu_config(); + config.image_disk = Some(PathBuf::from("/image-rootfs.ext4")); + + let args = qemu_disk_args(&config); + + assert!(args.contains( + &"file=/image-rootfs.ext4,if=none,format=raw,id=image,readonly=on".to_string() + )); + assert!(args.contains(&"virtio-blk-pci,drive=image".to_string())); + } + + #[test] + fn gvproxy_socket_base_is_per_sandbox_overlay_path() { + let first = + gvproxy_socket_base(Path::new("/tmp/openshell-vm/sandboxes/first/overlay.ext4")) + .expect("first socket base"); + let second = + gvproxy_socket_base(Path::new("/tmp/openshell-vm/sandboxes/second/overlay.ext4")) + .expect("second socket base"); + + assert_ne!(first, second); + } + + #[test] + fn tap_subnet_from_host_ip_calculates_slash30_base() { + assert_eq!(tap_subnet_from_host_ip("10.0.128.1"), "10.0.128.0/30"); + assert_eq!(tap_subnet_from_host_ip("10.0.128.2"), "10.0.128.0/30"); + assert_eq!(tap_subnet_from_host_ip("10.0.128.5"), "10.0.128.4/30"); + } + + #[test] + fn tap_subnet_from_host_ip_handles_invalid_ip() { + let result = tap_subnet_from_host_ip("not-an-ip"); + assert_eq!(result, "not-an-ip/30"); + } } diff --git a/crates/openshell-ocsf/src/builders/base.rs b/crates/openshell-ocsf/src/builders/base.rs index 791e5a094..ac51f7dd7 100644 --- a/crates/openshell-ocsf/src/builders/base.rs +++ b/crates/openshell-ocsf/src/builders/base.rs @@ -31,21 +31,6 @@ impl<'a> BaseEventBuilder<'a> { } } - #[must_use] - pub fn severity(mut self, id: SeverityId) -> Self { - self.severity = 
id; - self - } - #[must_use] - pub fn status(mut self, id: StatusId) -> Self { - self.status = Some(id); - self - } - #[must_use] - pub fn message(mut self, msg: impl Into) -> Self { - self.message = Some(msg.into()); - self - } #[must_use] pub fn activity_name(mut self, name: impl Into) -> Self { self.activity_name = Some(name.into()); @@ -72,22 +57,18 @@ impl<'a> BaseEventBuilder<'a> { self.severity, self.ctx.metadata(&["container", "host"]), ); - if let Some(status) = self.status { - base.set_status(status); - } - if let Some(msg) = self.message { - base.set_message(msg); - } - base.set_device(self.ctx.device()); - base.set_container(self.ctx.container()); if !self.unmapped.is_empty() { base.unmapped = Some(serde_json::Value::Object(self.unmapped)); } + self.ctx + .apply_common_fields(&mut base, self.status, self.message); OcsfEvent::Base(BaseEvent { base }) } } +impl_builder_setters!(BaseEventBuilder); + #[cfg(test)] mod tests { use super::*; diff --git a/crates/openshell-ocsf/src/builders/config.rs b/crates/openshell-ocsf/src/builders/config.rs index 8ff9cae2f..d8d40dd24 100644 --- a/crates/openshell-ocsf/src/builders/config.rs +++ b/crates/openshell-ocsf/src/builders/config.rs @@ -37,22 +37,6 @@ impl<'a> ConfigStateChangeBuilder<'a> { } } - #[must_use] - pub fn severity(mut self, id: SeverityId) -> Self { - self.severity = id; - self - } - #[must_use] - pub fn status(mut self, id: StatusId) -> Self { - self.status = Some(id); - self - } - #[must_use] - pub fn message(mut self, msg: impl Into) -> Self { - self.message = Some(msg.into()); - self - } - /// Set state with a custom label (OCSF `state_id` + display label). 
#[must_use] pub fn state(mut self, id: StateId, label: &str) -> Self { @@ -92,17 +76,11 @@ impl<'a> ConfigStateChangeBuilder<'a> { self.ctx .metadata(&["security_control", "container", "host"]), ); - if let Some(status) = self.status { - base.set_status(status); - } - if let Some(msg) = self.message { - base.set_message(msg); - } - base.set_device(self.ctx.device()); - base.set_container(self.ctx.container()); if !self.unmapped.is_empty() { base.unmapped = Some(serde_json::Value::Object(self.unmapped)); } + self.ctx + .apply_common_fields(&mut base, self.status, self.message); OcsfEvent::DeviceConfigStateChange(DeviceConfigStateChangeEvent { base, @@ -114,6 +92,8 @@ impl<'a> ConfigStateChangeBuilder<'a> { } } +impl_builder_setters!(ConfigStateChangeBuilder); + #[cfg(test)] mod tests { use super::*; diff --git a/crates/openshell-ocsf/src/builders/finding.rs b/crates/openshell-ocsf/src/builders/finding.rs index 10d770f46..ed7290ea3 100644 --- a/crates/openshell-ocsf/src/builders/finding.rs +++ b/crates/openshell-ocsf/src/builders/finding.rs @@ -54,11 +54,6 @@ impl<'a> DetectionFindingBuilder<'a> { self } #[must_use] - pub fn severity(mut self, id: SeverityId) -> Self { - self.severity = id; - self - } - #[must_use] pub fn action(mut self, id: ActionId) -> Self { self.action = Some(id); self @@ -89,11 +84,6 @@ impl<'a> DetectionFindingBuilder<'a> { self } #[must_use] - pub fn message(mut self, msg: impl Into) -> Self { - self.message = Some(msg.into()); - self - } - #[must_use] pub fn log_source(mut self, source: impl Into) -> Self { self.log_source = Some(source.into()); self @@ -147,11 +137,7 @@ impl<'a> DetectionFindingBuilder<'a> { self.severity, metadata, ); - if let Some(msg) = self.message { - base.set_message(msg); - } - base.set_device(self.ctx.device()); - base.set_container(self.ctx.container()); + self.ctx.apply_common_fields(&mut base, None, self.message); OcsfEvent::DetectionFinding(DetectionFindingEvent { base, @@ -178,6 +164,8 @@ impl<'a> 
DetectionFindingBuilder<'a> { } } +impl_builder_setters!(DetectionFindingBuilder, no_status); + #[cfg(test)] mod tests { use super::*; diff --git a/crates/openshell-ocsf/src/builders/http.rs b/crates/openshell-ocsf/src/builders/http.rs index 96919f281..0530d26c6 100644 --- a/crates/openshell-ocsf/src/builders/http.rs +++ b/crates/openshell-ocsf/src/builders/http.rs @@ -64,16 +64,6 @@ impl<'a> HttpActivityBuilder<'a> { self } #[must_use] - pub fn severity(mut self, id: SeverityId) -> Self { - self.severity = id; - self - } - #[must_use] - pub fn status(mut self, id: StatusId) -> Self { - self.status = Some(id); - self - } - #[must_use] pub fn http_request(mut self, req: HttpRequest) -> Self { self.http_request = Some(req); self @@ -104,11 +94,6 @@ impl<'a> HttpActivityBuilder<'a> { self } #[must_use] - pub fn message(mut self, msg: impl Into) -> Self { - self.message = Some(msg.into()); - self - } - #[must_use] pub fn status_detail(mut self, detail: impl Into) -> Self { self.status_detail = Some(detail.into()); self @@ -128,17 +113,11 @@ impl<'a> HttpActivityBuilder<'a> { self.ctx .metadata(&["security_control", "network_proxy", "container", "host"]), ); - if let Some(status) = self.status { - base.set_status(status); - } - if let Some(msg) = self.message { - base.set_message(msg); - } if let Some(detail) = self.status_detail { base.set_status_detail(detail); } - base.set_device(self.ctx.device()); - base.set_container(self.ctx.container()); + self.ctx + .apply_common_fields(&mut base, self.status, self.message); OcsfEvent::HttpActivity(HttpActivityEvent { base, @@ -157,6 +136,8 @@ impl<'a> HttpActivityBuilder<'a> { } } +impl_builder_setters!(HttpActivityBuilder); + #[cfg(test)] mod tests { use super::*; diff --git a/crates/openshell-ocsf/src/builders/lifecycle.rs b/crates/openshell-ocsf/src/builders/lifecycle.rs index b0d3a6007..9b5095a7a 100644 --- a/crates/openshell-ocsf/src/builders/lifecycle.rs +++ b/crates/openshell-ocsf/src/builders/lifecycle.rs @@ -35,21 
+35,6 @@ impl<'a> AppLifecycleBuilder<'a> { self.activity = id; self } - #[must_use] - pub fn severity(mut self, id: SeverityId) -> Self { - self.severity = id; - self - } - #[must_use] - pub fn status(mut self, id: StatusId) -> Self { - self.status = Some(id); - self - } - #[must_use] - pub fn message(mut self, msg: impl Into) -> Self { - self.message = Some(msg.into()); - self - } #[must_use] pub fn build(self) -> OcsfEvent { @@ -64,14 +49,8 @@ impl<'a> AppLifecycleBuilder<'a> { self.severity, self.ctx.metadata(&["container", "host"]), ); - if let Some(status) = self.status { - base.set_status(status); - } - if let Some(msg) = self.message { - base.set_message(msg); - } - base.set_device(self.ctx.device()); - base.set_container(self.ctx.container()); + self.ctx + .apply_common_fields(&mut base, self.status, self.message); OcsfEvent::ApplicationLifecycle(ApplicationLifecycleEvent { base, @@ -80,6 +59,8 @@ impl<'a> AppLifecycleBuilder<'a> { } } +impl_builder_setters!(AppLifecycleBuilder); + #[cfg(test)] mod tests { use super::*; diff --git a/crates/openshell-ocsf/src/builders/mod.rs b/crates/openshell-ocsf/src/builders/mod.rs index 77004da4a..7f5cc4f9b 100644 --- a/crates/openshell-ocsf/src/builders/mod.rs +++ b/crates/openshell-ocsf/src/builders/mod.rs @@ -6,6 +6,55 @@ //! Each event class has a builder that takes a `SandboxContext` reference //! and provides chainable methods for setting event fields. +/// Generate the shared `severity`, `status`, and `message` setter methods that +/// appear identically on every OCSF event builder in this crate. +/// +/// # Usage +/// ```ignore +/// impl_builder_setters!(MyBuilder); // severity + status + message +/// impl_builder_setters!(MyBuilder, no_status); // severity + message only +/// ``` +macro_rules! impl_builder_setters { + ($builder:ident) => { + impl<'a> $builder<'a> { + /// Set the event severity. 
+ #[must_use] + pub fn severity(mut self, id: $crate::enums::SeverityId) -> Self { + self.severity = id; + self + } + /// Set the overall event status. + #[must_use] + pub fn status(mut self, id: $crate::enums::StatusId) -> Self { + self.status = Some(id); + self + } + /// Set a human-readable event message. + #[must_use] + pub fn message(mut self, msg: impl Into) -> Self { + self.message = Some(msg.into()); + self + } + } + }; + ($builder:ident, no_status) => { + impl<'a> $builder<'a> { + /// Set the event severity. + #[must_use] + pub fn severity(mut self, id: $crate::enums::SeverityId) -> Self { + self.severity = id; + self + } + /// Set a human-readable event message. + #[must_use] + pub fn message(mut self, msg: impl Into) -> Self { + self.message = Some(msg.into()); + self + } + } + }; +} + mod base; mod config; mod finding; @@ -27,6 +76,8 @@ pub use ssh::SshActivityBuilder; use std::net::IpAddr; use crate::OCSF_VERSION; +use crate::enums::StatusId; +use crate::events::base_event::BaseEventData; use crate::objects::{Container, Device, Endpoint, Image, Metadata, Product}; /// Immutable context created once at sandbox startup. @@ -87,6 +138,24 @@ impl SandboxContext { pub fn proxy_endpoint(&self) -> Endpoint { Endpoint::from_ip(self.proxy_ip, self.proxy_port) } + + /// Apply the fields common to every builder's `build()` method onto `base`: + /// optional status, optional message, device info, and container info. 
+ pub fn apply_common_fields( + &self, + base: &mut BaseEventData, + status: Option, + message: Option, + ) { + if let Some(s) = status { + base.set_status(s); + } + if let Some(m) = message { + base.set_message(m); + } + base.set_device(self.device()); + base.set_container(self.container()); + } } #[cfg(test)] diff --git a/crates/openshell-ocsf/src/builders/network.rs b/crates/openshell-ocsf/src/builders/network.rs index d0a79925b..42b9bc488 100644 --- a/crates/openshell-ocsf/src/builders/network.rs +++ b/crates/openshell-ocsf/src/builders/network.rs @@ -78,16 +78,6 @@ impl<'a> NetworkActivityBuilder<'a> { self } #[must_use] - pub fn severity(mut self, id: SeverityId) -> Self { - self.severity = id; - self - } - #[must_use] - pub fn status(mut self, id: StatusId) -> Self { - self.status = Some(id); - self - } - #[must_use] pub fn src_endpoint_addr(mut self, ip: IpAddr, port: u16) -> Self { self.src_endpoint = Some(Endpoint::from_ip(ip, port)); self @@ -118,11 +108,6 @@ impl<'a> NetworkActivityBuilder<'a> { self } #[must_use] - pub fn message(mut self, msg: impl Into) -> Self { - self.message = Some(msg.into()); - self - } - #[must_use] pub fn status_detail(mut self, detail: impl Into) -> Self { self.status_detail = Some(detail.into()); self @@ -166,20 +151,14 @@ impl<'a> NetworkActivityBuilder<'a> { metadata, ); - if let Some(status) = self.status { - base.set_status(status); - } - if let Some(msg) = self.message { - base.set_message(msg); - } if let Some(detail) = self.status_detail { base.set_status_detail(detail); } - base.set_device(self.ctx.device()); - base.set_container(self.ctx.container()); if let Some(unmapped) = self.unmapped { base.unmapped = Some(serde_json::Value::Object(unmapped)); } + self.ctx + .apply_common_fields(&mut base, self.status, self.message); OcsfEvent::NetworkActivity(NetworkActivityEvent { base, @@ -197,6 +176,8 @@ impl<'a> NetworkActivityBuilder<'a> { } } +impl_builder_setters!(NetworkActivityBuilder); + #[cfg(test)] mod tests { use 
super::*; diff --git a/crates/openshell-ocsf/src/builders/process.rs b/crates/openshell-ocsf/src/builders/process.rs index 8ede8012c..d66790eae 100644 --- a/crates/openshell-ocsf/src/builders/process.rs +++ b/crates/openshell-ocsf/src/builders/process.rs @@ -48,16 +48,6 @@ impl<'a> ProcessActivityBuilder<'a> { self } #[must_use] - pub fn severity(mut self, id: SeverityId) -> Self { - self.severity = id; - self - } - #[must_use] - pub fn status(mut self, id: StatusId) -> Self { - self.status = Some(id); - self - } - #[must_use] pub fn action(mut self, id: ActionId) -> Self { self.action = Some(id); self @@ -87,11 +77,6 @@ impl<'a> ProcessActivityBuilder<'a> { self.exit_code = Some(code); self } - #[must_use] - pub fn message(mut self, msg: impl Into) -> Self { - self.message = Some(msg.into()); - self - } #[must_use] pub fn build(self) -> OcsfEvent { @@ -107,14 +92,8 @@ impl<'a> ProcessActivityBuilder<'a> { self.ctx .metadata(&["security_control", "container", "host"]), ); - if let Some(status) = self.status { - base.set_status(status); - } - if let Some(msg) = self.message { - base.set_message(msg); - } - base.set_device(self.ctx.device()); - base.set_container(self.ctx.container()); + self.ctx + .apply_common_fields(&mut base, self.status, self.message); OcsfEvent::ProcessActivity(ProcessActivityEvent { base, @@ -128,6 +107,8 @@ impl<'a> ProcessActivityBuilder<'a> { } } +impl_builder_setters!(ProcessActivityBuilder); + #[cfg(test)] mod tests { use super::*; diff --git a/crates/openshell-ocsf/src/builders/ssh.rs b/crates/openshell-ocsf/src/builders/ssh.rs index 6df01f3d1..0df704f31 100644 --- a/crates/openshell-ocsf/src/builders/ssh.rs +++ b/crates/openshell-ocsf/src/builders/ssh.rs @@ -64,16 +64,6 @@ impl<'a> SshActivityBuilder<'a> { self } #[must_use] - pub fn severity(mut self, id: SeverityId) -> Self { - self.severity = id; - self - } - #[must_use] - pub fn status(mut self, id: StatusId) -> Self { - self.status = Some(id); - self - } - #[must_use] pub fn 
src_endpoint_addr(mut self, ip: IpAddr, port: u16) -> Self { self.src_endpoint = Some(Endpoint::from_ip(ip, port)); self @@ -88,11 +78,6 @@ impl<'a> SshActivityBuilder<'a> { self.actor = Some(Actor { process }); self } - #[must_use] - pub fn message(mut self, msg: impl Into) -> Self { - self.message = Some(msg.into()); - self - } /// Set auth type with a custom label (e.g., "NSSH1"). #[must_use] @@ -122,14 +107,8 @@ impl<'a> SshActivityBuilder<'a> { self.ctx .metadata(&["security_control", "container", "host"]), ); - if let Some(status) = self.status { - base.set_status(status); - } - if let Some(msg) = self.message { - base.set_message(msg); - } - base.set_device(self.ctx.device()); - base.set_container(self.ctx.container()); + self.ctx + .apply_common_fields(&mut base, self.status, self.message); OcsfEvent::SshActivity(SshActivityEvent { base, @@ -145,6 +124,8 @@ impl<'a> SshActivityBuilder<'a> { } } +impl_builder_setters!(SshActivityBuilder); + #[cfg(test)] mod tests { use super::*; diff --git a/crates/openshell-ocsf/src/events/network_activity.rs b/crates/openshell-ocsf/src/events/network_activity.rs index 6cd125fdc..92450bbe8 100644 --- a/crates/openshell-ocsf/src/events/network_activity.rs +++ b/crates/openshell-ocsf/src/events/network_activity.rs @@ -11,7 +11,7 @@ use crate::objects::{Actor, ConnectionInfo, Endpoint, FirewallRule}; /// OCSF Network Activity Event [4001]. /// -/// Proxy CONNECT tunnel events and iptables-level bypass detection. +/// Proxy CONNECT tunnel events and nftables bypass detection. #[derive(Debug, Clone, PartialEq, Eq, Deserialize)] pub struct NetworkActivityEvent { /// Common base event fields. 
diff --git a/crates/openshell-ocsf/src/format/shorthand.rs b/crates/openshell-ocsf/src/format/shorthand.rs index 42b30fbae..0e50fc6c5 100644 --- a/crates/openshell-ocsf/src/format/shorthand.rs +++ b/crates/openshell-ocsf/src/format/shorthand.rs @@ -61,6 +61,7 @@ pub fn severity_tag(severity_id: u8) -> &'static str { /// Max length for the reason text in `[reason:...]` before truncation. const MAX_REASON_LEN: usize = 80; +const MAX_MESSAGE_LEN: usize = 120; /// Format a `[reason:...]` tag from `status_detail` (or `message` fallback) /// for denied events. Returns an empty string if neither field is set. @@ -80,6 +81,19 @@ fn reason_tag(base: &BaseEventData) -> String { } } +fn message_tag(base: &BaseEventData) -> String { + let text = base.message.as_deref().unwrap_or(""); + if text.is_empty() { + return String::new(); + } + let text = text.replace(['\n', '\r'], " "); + if text.len() > MAX_MESSAGE_LEN { + format!(" [msg:{}...]", &text[..MAX_MESSAGE_LEN]) + } else { + format!(" [msg:{text}]") + } +} + impl OcsfEvent { /// Produce the single-line shorthand for `openshell.log` and gRPC log push. 
/// @@ -140,7 +154,13 @@ impl OcsfEvent { (false, true) => format!(" {action}"), (false, false) => format!(" {action}{arrow}"), }; - format!("NET:{activity} {sev}{detail}{rule_ctx}{reason_ctx}") + let message_ctx = + if detail.is_empty() && rule_ctx.is_empty() && reason_ctx.is_empty() { + message_tag(&e.base) + } else { + String::new() + }; + format!("NET:{activity} {sev}{detail}{rule_ctx}{reason_ctx}{message_ctx}") } Self::HttpActivity(e) => { @@ -436,7 +456,7 @@ mod tests { actor: Some(Actor { process: Process::new("node", 1234), }), - firewall_rule: Some(FirewallRule::new("bypass-detect", "iptables")), + firewall_rule: Some(FirewallRule::new("bypass-detect", "nftables")), connection_info: Some(ConnectionInfo::new("tcp")), action: Some(ActionId::Denied), disposition: Some(DispositionId::Blocked), @@ -447,7 +467,7 @@ mod tests { let shorthand = event.format_shorthand(); assert_eq!( shorthand, - "NET:REFUSE [MED] DENIED node(1234) -> 93.184.216.34:443/tcp [policy:bypass-detect engine:iptables]" + "NET:REFUSE [MED] DENIED node(1234) -> 93.184.216.34:443/tcp [policy:bypass-detect engine:nftables]" ); } @@ -541,6 +561,33 @@ mod tests { ); } + #[test] + fn test_network_activity_shorthand_shows_message_when_no_key_fields() { + let event = OcsfEvent::NetworkActivity(NetworkActivityEvent { + base: { + let mut b = base(4001, "Network Activity", 4, "Network Activity", 1, "Open"); + b.set_message("relay open (channel_id=ch-42)"); + b + }, + src_endpoint: None, + dst_endpoint: None, + proxy_endpoint: None, + actor: None, + firewall_rule: None, + connection_info: None, + action: None, + disposition: None, + observation_point_id: None, + is_src_dst_assignment_known: None, + }); + + let shorthand = event.format_shorthand(); + assert_eq!( + shorthand, + "NET:OPEN [INFO] [msg:relay open (channel_id=ch-42)]" + ); + } + #[test] fn test_http_activity_shorthand_denied_shows_reason() { let mut b = base(4002, "HTTP Activity", 4, "Network Activity", 99, "Other"); diff --git 
a/crates/openshell-ocsf/src/objects/firewall_rule.rs b/crates/openshell-ocsf/src/objects/firewall_rule.rs index fa8829275..2e242225b 100644 --- a/crates/openshell-ocsf/src/objects/firewall_rule.rs +++ b/crates/openshell-ocsf/src/objects/firewall_rule.rs @@ -11,7 +11,7 @@ pub struct FirewallRule { /// Rule name (e.g., "default-egress", "bypass-detect"). pub name: String, - /// Rule type / engine (e.g., "mechanistic", "opa", "iptables"). + /// Rule type / engine (e.g., "mechanistic", "opa", "nftables"). /// /// Kept as `String` because this is a project-specific extension field /// (not OCSF-enumerated) with runtime-dynamic values from the policy engine. diff --git a/crates/openshell-policy/Cargo.toml b/crates/openshell-policy/Cargo.toml index f26136c6b..8936b85be 100644 --- a/crates/openshell-policy/Cargo.toml +++ b/crates/openshell-policy/Cargo.toml @@ -13,6 +13,7 @@ repository.workspace = true [dependencies] openshell-core = { path = "../openshell-core" } serde = { workspace = true } +serde_json = { workspace = true } serde_yml = { workspace = true } miette = { workspace = true } diff --git a/crates/openshell-policy/src/lib.rs b/crates/openshell-policy/src/lib.rs index 61df0aadb..8dbaf077c 100644 --- a/crates/openshell-policy/src/lib.rs +++ b/crates/openshell-policy/src/lib.rs @@ -27,7 +27,7 @@ use serde::{Deserialize, Serialize}; pub use compose::{ProviderPolicyLayer, compose_effective_policy, provider_rule_name}; pub use merge::{ PolicyMergeError, PolicyMergeOp, PolicyMergeResult, PolicyMergeWarning, generated_rule_name, - merge_policy, + merge_policy, policy_covers_rule, }; // --------------------------------------------------------------------------- @@ -120,6 +120,15 @@ struct NetworkEndpointDef { /// Defaults to false (strict). 
#[serde(default, skip_serializing_if = "std::ops::Not::not")] allow_encoded_slash: bool, + /// When true, client-to-server WebSocket text messages on this REST + /// endpoint rewrite credential placeholders after an allowed 101 upgrade. + /// Defaults to false. + #[serde(default, skip_serializing_if = "std::ops::Not::not")] + websocket_credential_rewrite: bool, + /// When true, supported textual REST request bodies rewrite credential + /// placeholders before forwarding upstream. Defaults to false. + #[serde(default, skip_serializing_if = "std::ops::Not::not")] + request_body_credential_rewrite: bool, #[serde(default, skip_serializing_if = "String::is_empty")] persisted_queries: String, #[serde(default, skip_serializing_if = "BTreeMap::is_empty")] @@ -317,6 +326,8 @@ fn to_proto(raw: PolicyFile) -> SandboxPolicy { }) .collect(), allow_encoded_slash: e.allow_encoded_slash, + websocket_credential_rewrite: e.websocket_credential_rewrite, + request_body_credential_rewrite: e.request_body_credential_rewrite, persisted_queries: e.persisted_queries, graphql_persisted_queries: e .graphql_persisted_queries @@ -480,6 +491,8 @@ fn from_proto(policy: &SandboxPolicy) -> PolicyFile { }) .collect(), allow_encoded_slash: e.allow_encoded_slash, + websocket_credential_rewrite: e.websocket_credential_rewrite, + request_body_credential_rewrite: e.request_body_credential_rewrite, persisted_queries: e.persisted_queries.clone(), graphql_persisted_queries: e .graphql_persisted_queries @@ -545,6 +558,25 @@ pub fn serialize_sandbox_policy(policy: &SandboxPolicy) -> Result { .wrap_err("failed to serialize policy to YAML") } +/// Convert a proto sandbox policy into the canonical policy JSON representation. +/// +/// The shape mirrors the YAML schema used by [`serialize_sandbox_policy`], so +/// automation can use the same documented field names in either format. 
+pub fn sandbox_policy_to_json_value(policy: &SandboxPolicy) -> Result { + let json_repr = from_proto(policy); + serde_json::to_value(&json_repr) + .into_diagnostic() + .wrap_err("failed to serialize policy to JSON") +} + +/// Serialize a proto sandbox policy to a pretty-printed JSON string. +pub fn serialize_sandbox_policy_json(policy: &SandboxPolicy) -> Result { + let json_repr = sandbox_policy_to_json_value(policy)?; + serde_json::to_string_pretty(&json_repr) + .into_diagnostic() + .wrap_err("failed to serialize policy to JSON") +} + /// Load a sandbox policy from an explicit source. /// /// Resolution order: @@ -868,6 +900,30 @@ mod tests { ); } + /// Verify that JSON serialization uses the same canonical schema keys as YAML. + #[test] + fn serialized_json_uses_policy_schema_keys() { + let proto = parse_sandbox_policy( + r" +version: 1 +network_policies: + github: + endpoints: + - host: api.github.com + port: 443 + protocol: https + binaries: + - path: /usr/bin/curl +", + ) + .expect("parse failed"); + let json = sandbox_policy_to_json_value(&proto).expect("serialize failed"); + + assert_eq!(json["version"], serde_json::json!(1)); + assert!(json.get("filesystem").is_none()); + assert!(json.get("network_policies").is_some()); + } + /// Verify that `allowed_ips` survives the round-trip. 
#[test] fn round_trip_preserves_allowed_ips() { @@ -1656,6 +1712,80 @@ network_policies: assert_eq!(ep.deny_rules[0].fields, vec!["deleteRepository"]); } + #[test] + fn round_trip_preserves_websocket_credential_rewrite() { + let yaml = r" +version: 1 +network_policies: + discord_gateway: + name: discord_gateway + endpoints: + - host: gateway.example.com + port: 443 + protocol: rest + enforcement: enforce + access: full + websocket_credential_rewrite: true + binaries: + - path: /usr/bin/node +"; + let proto1 = parse_sandbox_policy(yaml).expect("parse failed"); + let yaml_out = serialize_sandbox_policy(&proto1).expect("serialize failed"); + let proto2 = parse_sandbox_policy(&yaml_out).expect("re-parse failed"); + + let ep = &proto2.network_policies["discord_gateway"].endpoints[0]; + assert_eq!(ep.protocol, "rest"); + assert!(ep.websocket_credential_rewrite); + assert!(yaml_out.contains("websocket_credential_rewrite: true")); + } + + #[test] + fn round_trip_preserves_request_body_credential_rewrite() { + let yaml = r" +version: 1 +network_policies: + slack_api: + name: slack_api + endpoints: + - host: slack.com + port: 443 + protocol: rest + enforcement: enforce + access: read-write + request_body_credential_rewrite: true + binaries: + - path: /usr/bin/node +"; + let proto1 = parse_sandbox_policy(yaml).expect("parse failed"); + let yaml_out = serialize_sandbox_policy(&proto1).expect("serialize failed"); + let proto2 = parse_sandbox_policy(&yaml_out).expect("re-parse failed"); + + let ep = &proto2.network_policies["slack_api"].endpoints[0]; + assert_eq!(ep.protocol, "rest"); + assert!(ep.request_body_credential_rewrite); + assert!(yaml_out.contains("request_body_credential_rewrite: true")); + } + + #[test] + fn websocket_credential_rewrite_defaults_false() { + let yaml = r" +version: 1 +network_policies: + gateway: + endpoints: + - host: gateway.example.com + port: 443 + protocol: rest + access: full + binaries: + - path: /usr/bin/node +"; + let proto = 
parse_sandbox_policy(yaml).expect("parse failed"); + let ep = &proto.network_policies["gateway"].endpoints[0]; + assert!(!ep.websocket_credential_rewrite); + assert!(!ep.request_body_credential_rewrite); + } + #[test] fn parse_rejects_unknown_fields_in_deny_rule() { let yaml = r" diff --git a/crates/openshell-policy/src/merge.rs b/crates/openshell-policy/src/merge.rs index 7a5dec916..c01445b11 100644 --- a/crates/openshell-policy/src/merge.rs +++ b/crates/openshell-policy/src/merge.rs @@ -184,7 +184,7 @@ impl std::fmt::Display for PolicyMergeError { protocol, } => write!( f, - "endpoint {host}:{port} uses unsupported protocol '{protocol}'; this operation currently supports only protocol 'rest'" + "endpoint {host}:{port} uses unsupported protocol '{protocol}'; this operation currently supports only protocol 'rest' or 'websocket'" ), Self::EndpointHasNoAllowBase { host, port } => write!( f, @@ -207,6 +207,78 @@ pub struct PolicyMergeResult { pub changed: bool, } +/// Returns true iff `policy` semantically contains the rule an `AddRule` +/// merge of `proposed` would produce. +/// +/// "Contains" means: for every endpoint in `proposed`, some rule in +/// `policy.network_policies` has an endpoint with overlapping +/// host/path/port set AND containing every L7 allow (method/path) the +/// proposed endpoint requested, and that rule's binaries cover every +/// binary in `proposed`. +/// +/// The sandbox's `policy.local /wait` long-poll uses this to decide when +/// the local supervisor has actually loaded a policy that includes the +/// chunk the agent just had approved. A whole-policy hash compare is wrong +/// in both directions: it can wake the wait on unrelated reloads (false +/// wakeup) and can fail to wake when the supervisor reloaded between two +/// `/wait` calls (false sleep). This check is the property the agent +/// actually cares about — "is my rule in effect right now?". 
+/// +/// L4-vs-L7 split: endpoint overlap reuses `endpoints_overlap` so the +/// L4 surface (host/path/port) lines up with the `add_rule` merge — if +/// the gateway folded the chunk into an existing rule under a different +/// key, this check still returns true. The L7 layer is checked +/// separately because `endpoints_overlap` is intentionally L4-only: +/// without the L7 check, coverage would return true the instant the +/// supervisor reloaded *any* change to an overlapping endpoint, even +/// before the new method/path actually landed — exactly the false-wakeup +/// mode this fix exists to prevent, just one layer down. +pub fn policy_covers_rule(policy: &SandboxPolicy, proposed: &NetworkPolicyRule) -> bool { + if proposed.endpoints.is_empty() { + return false; + } + proposed.endpoints.iter().all(|target_endpoint| { + policy.network_policies.values().any(|rule| { + rule.endpoints.iter().any(|endpoint| { + endpoints_overlap(endpoint, target_endpoint) + && endpoint_l7_covers(endpoint, target_endpoint) + }) && proposed.binaries.iter().all(|target_binary| { + rule.binaries + .iter() + .any(|binary| binary.path == target_binary.path) + }) + }) + }) +} + +/// L7 coverage for a single endpoint match. If the proposed endpoint +/// declared explicit L7 allow rules (method+path), every one of them must +/// be present in the merged endpoint's `rules`. An empty `proposed.rules` +/// is treated as "L4-only" and returns true (the endpoint match alone is +/// sufficient). +/// +/// Conservative on access presets: if a merged endpoint uses +/// `access: read-write` instead of explicit rules, this returns false +/// even though the preset would permit the method at runtime. That +/// produces a one-cycle re-issue on the agent's side — preferable to a +/// false-positive coverage signal that lets the agent retry too early. 
+fn endpoint_l7_covers(merged: &NetworkEndpoint, proposed: &NetworkEndpoint) -> bool { + if proposed.rules.is_empty() { + return true; + } + proposed.rules.iter().all(|proposed_rule| { + let Some(proposed_allow) = proposed_rule.allow.as_ref() else { + return true; + }; + merged.rules.iter().any(|existing| { + existing.allow.as_ref().is_some_and(|existing_allow| { + existing_allow.method == proposed_allow.method + && existing_allow.path == proposed_allow.path + }) + }) + }) +} + pub fn merge_policy( policy: SandboxPolicy, operations: &[PolicyMergeOp], @@ -265,7 +337,7 @@ fn apply_operation( port: *port, } })?; - ensure_rest_endpoint(endpoint, host, *port)?; + ensure_method_path_endpoint(endpoint, host, *port)?; if endpoint.access.is_empty() && endpoint.rules.is_empty() { return Err(PolicyMergeError::EndpointHasNoAllowBase { host: host.clone(), @@ -281,7 +353,7 @@ fn apply_operation( port: *port, } })?; - ensure_rest_endpoint(endpoint, host, *port)?; + ensure_method_path_endpoint(endpoint, host, *port)?; expand_existing_access(endpoint, host, *port, warnings)?; append_unique_l7_rules(&mut endpoint.rules, rules); } @@ -462,6 +534,9 @@ fn merge_endpoint( append_unique_deny_rules(&mut existing.deny_rules, &incoming.deny_rules); append_unique_strings(&mut existing.allowed_ips, &incoming.allowed_ips); + existing.allow_encoded_slash |= incoming.allow_encoded_slash; + existing.websocket_credential_rewrite |= incoming.websocket_credential_rewrite; + existing.request_body_credential_rewrite |= incoming.request_body_credential_rewrite; normalize_endpoint(existing); Ok(()) } @@ -568,7 +643,7 @@ fn endpoint_matches_host_port(endpoint: &NetworkEndpoint, host: &str, port: u32) endpoint.host.eq_ignore_ascii_case(host) && canonical_ports(endpoint).contains(&port) } -fn ensure_rest_endpoint( +fn ensure_method_path_endpoint( endpoint: &NetworkEndpoint, host: &str, port: u32, @@ -579,7 +654,7 @@ fn ensure_rest_endpoint( port, }); } - if endpoint.protocol != "rest" { + if 
!matches!(endpoint.protocol.as_str(), "rest" | "websocket") { return Err(PolicyMergeError::UnsupportedEndpointProtocol { host: host.to_string(), port, @@ -600,12 +675,13 @@ fn expand_existing_access( } let access = endpoint.access.clone(); - let expanded = - expand_access_preset(&access).ok_or_else(|| PolicyMergeError::UnsupportedAccessPreset { + let expanded = expand_access_preset(&endpoint.protocol, &access).ok_or_else(|| { + PolicyMergeError::UnsupportedAccessPreset { host: host.to_string(), port, access: access.clone(), - })?; + } + })?; endpoint.access.clear(); append_unique_l7_rules(&mut endpoint.rules, &expanded); warnings.push(PolicyMergeWarning::ExpandedAccessPreset { @@ -616,11 +692,13 @@ fn expand_existing_access( Ok(()) } -fn expand_access_preset(access: &str) -> Option> { - let methods = match access { - "read-only" => vec!["GET", "HEAD", "OPTIONS"], - "read-write" => vec!["GET", "HEAD", "OPTIONS", "POST", "PUT", "PATCH"], - "full" => vec!["*"], +fn expand_access_preset(protocol: &str, access: &str) -> Option> { + let methods = match (protocol, access) { + (_, "full") => vec!["*"], + ("websocket", "read-only") => vec!["GET"], + ("websocket", "read-write") => vec!["GET", "WEBSOCKET_TEXT"], + (_, "read-only") => vec!["GET", "HEAD", "OPTIONS"], + (_, "read-write") => vec!["GET", "HEAD", "OPTIONS", "POST", "PUT", "PATCH"], _ => return None, }; @@ -776,6 +854,7 @@ mod tests { use super::{ PolicyMergeError, PolicyMergeOp, PolicyMergeWarning, generated_rule_name, merge_policy, + policy_covers_rule, }; use crate::restrictive_default_policy; use openshell_core::proto::{ @@ -870,6 +949,96 @@ mod tests { assert_eq!(rule.binaries.len(), 2); } + #[test] + fn add_rule_merges_websocket_credential_rewrite_flag() { + let mut policy = restrictive_default_policy(); + policy.network_policies.insert( + "existing".to_string(), + NetworkPolicyRule { + name: "existing".to_string(), + endpoints: vec![NetworkEndpoint { + host: "realtime.example.com".to_string(), + port: 443, + 
ports: vec![443], + protocol: "websocket".to_string(), + access: "read-write".to_string(), + ..Default::default() + }], + ..Default::default() + }, + ); + + let incoming = NetworkPolicyRule { + name: "incoming".to_string(), + endpoints: vec![NetworkEndpoint { + host: "realtime.example.com".to_string(), + port: 443, + ports: vec![443], + protocol: "websocket".to_string(), + websocket_credential_rewrite: true, + ..Default::default() + }], + ..Default::default() + }; + + let result = merge_policy( + policy, + &[PolicyMergeOp::AddRule { + rule_name: "allow_realtime_example_com_443".to_string(), + rule: incoming, + }], + ) + .expect("merge should succeed"); + + let endpoint = &result.policy.network_policies["existing"].endpoints[0]; + assert!(endpoint.websocket_credential_rewrite); + } + + #[test] + fn add_rule_merges_request_body_credential_rewrite_flag() { + let mut policy = restrictive_default_policy(); + policy.network_policies.insert( + "existing".to_string(), + NetworkPolicyRule { + name: "existing".to_string(), + endpoints: vec![NetworkEndpoint { + host: "slack.com".to_string(), + port: 443, + ports: vec![443], + protocol: "rest".to_string(), + access: "read-write".to_string(), + ..Default::default() + }], + ..Default::default() + }, + ); + + let incoming = NetworkPolicyRule { + name: "incoming".to_string(), + endpoints: vec![NetworkEndpoint { + host: "slack.com".to_string(), + port: 443, + ports: vec![443], + protocol: "rest".to_string(), + request_body_credential_rewrite: true, + ..Default::default() + }], + ..Default::default() + }; + + let result = merge_policy( + policy, + &[PolicyMergeOp::AddRule { + rule_name: "allow_slack_com_443".to_string(), + rule: incoming, + }], + ) + .expect("merge should succeed"); + + let endpoint = &result.policy.network_policies["existing"].endpoints[0]; + assert!(endpoint.request_body_credential_rewrite); + } + #[test] fn add_allow_expands_access_preset() { let mut policy = restrictive_default_policy(); @@ -909,7 +1078,92 @@ 
mod tests { } #[test] - fn add_deny_requires_rest_protocol() { + fn add_allow_expands_websocket_access_preset() { + let mut policy = restrictive_default_policy(); + policy.network_policies.insert( + "realtime".to_string(), + NetworkPolicyRule { + name: "realtime".to_string(), + endpoints: vec![NetworkEndpoint { + host: "realtime.example.com".to_string(), + port: 443, + ports: vec![443], + protocol: "websocket".to_string(), + access: "read-write".to_string(), + ..Default::default() + }], + ..Default::default() + }, + ); + + let result = merge_policy( + policy, + &[PolicyMergeOp::AddAllowRules { + host: "realtime.example.com".to_string(), + port: 443, + rules: vec![rest_rule("WEBSOCKET_TEXT", "/rooms/private/**")], + }], + ) + .expect("merge should succeed"); + + let endpoint = &result.policy.network_policies["realtime"].endpoints[0]; + assert!(endpoint.access.is_empty()); + assert_eq!(endpoint.rules.len(), 3); + assert!(endpoint.rules.contains(&rest_rule("GET", "**"))); + assert!(endpoint.rules.contains(&rest_rule("WEBSOCKET_TEXT", "**"))); + assert!( + endpoint + .rules + .contains(&rest_rule("WEBSOCKET_TEXT", "/rooms/private/**")) + ); + assert!(!endpoint.rules.contains(&rest_rule("POST", "**"))); + assert!(result.warnings.iter().any(|warning| matches!( + warning, + PolicyMergeWarning::ExpandedAccessPreset { access, .. 
} if access == "read-write" + ))); + } + + #[test] + fn add_deny_accepts_websocket_protocol() { + let mut policy = restrictive_default_policy(); + policy.network_policies.insert( + "realtime".to_string(), + NetworkPolicyRule { + name: "realtime".to_string(), + endpoints: vec![NetworkEndpoint { + host: "realtime.example.com".to_string(), + port: 443, + ports: vec![443], + protocol: "websocket".to_string(), + access: "read-write".to_string(), + ..Default::default() + }], + ..Default::default() + }, + ); + + let result = merge_policy( + policy, + &[PolicyMergeOp::AddDenyRules { + host: "realtime.example.com".to_string(), + port: 443, + deny_rules: vec![L7DenyRule { + method: "WEBSOCKET_TEXT".to_string(), + path: "/admin/**".to_string(), + ..Default::default() + }], + }], + ) + .expect("merge should succeed"); + + let endpoint = &result.policy.network_policies["realtime"].endpoints[0]; + assert_eq!(endpoint.deny_rules.len(), 1); + assert_eq!(endpoint.deny_rules[0].method, "WEBSOCKET_TEXT"); + assert_eq!(endpoint.deny_rules[0].path, "/admin/**"); + } + + #[test] + fn add_deny_rejects_unsupported_protocol() { let mut policy = restrictive_default_policy(); policy.network_policies.insert( "db".to_string(), @@ -1006,6 +1260,298 @@ mod tests { assert!(!result.policy.network_policies.contains_key("github")); } + #[test] + fn policy_covers_rule_returns_true_when_merged_rule_present() { + let proposed = NetworkPolicyRule { + name: "agent_proposed".to_string(), + endpoints: vec![endpoint("api.github.com", 443)], + binaries: vec![NetworkBinary { + path: "/usr/bin/curl".to_string(), + ..Default::default() + }], + }; + + let merged = merge_policy( + restrictive_default_policy(), + &[PolicyMergeOp::AddRule { + rule_name: "allow_api_github_com_443".to_string(), + rule: proposed.clone(), + }], + ) + .expect("merge should succeed"); + + assert!(policy_covers_rule(&merged.policy, &proposed)); + } + + #[test] + fn policy_covers_rule_returns_false_when_unrelated_rule_present() { + let 
proposed = NetworkPolicyRule { + name: "agent_proposed".to_string(), + endpoints: vec![endpoint("api.github.com", 443)], + binaries: vec![NetworkBinary { + path: "/usr/bin/curl".to_string(), + ..Default::default() + }], + }; + + // Merge an *unrelated* rule for a different host. The proposed rule + // for api.github.com is still not present — this is John's + // "false-wakeup" case: an unrelated policy reload must not signal + // that the agent's rule is loaded. + let merged = merge_policy( + restrictive_default_policy(), + &[PolicyMergeOp::AddRule { + rule_name: "allow_api_example_com_443".to_string(), + rule: rule_with_endpoint("unrelated", "api.example.com", 443), + }], + ) + .expect("merge should succeed"); + + assert!(!policy_covers_rule(&merged.policy, &proposed)); + } + + #[test] + fn policy_covers_rule_handles_merge_into_existing_endpoint() { + // The merge logic folds a new rule into an existing rule when their + // endpoints overlap, even under a different network_policies key. + // Coverage must survive that fold — name-keyed checks would miss it. 
+ let proposed = NetworkPolicyRule { + name: "agent_proposed".to_string(), + endpoints: vec![endpoint("api.github.com", 443)], + binaries: vec![NetworkBinary { + path: "/usr/bin/curl".to_string(), + ..Default::default() + }], + }; + + let mut policy = restrictive_default_policy(); + policy.network_policies.insert( + "preexisting_github".to_string(), + NetworkPolicyRule { + name: "preexisting_github".to_string(), + endpoints: vec![endpoint("api.github.com", 443)], + binaries: vec![NetworkBinary { + path: "/usr/bin/git".to_string(), + ..Default::default() + }], + }, + ); + + let merged = merge_policy( + policy, + &[PolicyMergeOp::AddRule { + rule_name: "allow_api_github_com_443".to_string(), + rule: proposed.clone(), + }], + ) + .expect("merge should succeed"); + + assert!( + !merged + .policy + .network_policies + .contains_key("allow_api_github_com_443"), + "proposed rule should have been folded into the existing key" + ); + assert!(policy_covers_rule(&merged.policy, &proposed)); + } + + #[test] + fn policy_covers_rule_returns_false_when_binary_missing() { + let proposed = NetworkPolicyRule { + name: "agent_proposed".to_string(), + endpoints: vec![endpoint("api.github.com", 443)], + binaries: vec![NetworkBinary { + path: "/usr/bin/curl".to_string(), + ..Default::default() + }], + }; + + // Endpoint exists in the policy but with a *different* binary. The + // agent's retry would still be denied; reload coverage should + // reflect that. 
+ let mut policy = restrictive_default_policy(); + policy.network_policies.insert( + "existing".to_string(), + NetworkPolicyRule { + name: "existing".to_string(), + endpoints: vec![endpoint("api.github.com", 443)], + binaries: vec![NetworkBinary { + path: "/usr/bin/git".to_string(), + ..Default::default() + }], + }, + ); + + assert!(!policy_covers_rule(&policy, &proposed)); + } + + #[test] + fn policy_covers_rule_returns_false_for_empty_proposed_endpoints() { + // Defensive: a rule with no endpoints carries no signal we can match + // on, so coverage is never true. + let proposed = NetworkPolicyRule::default(); + let policy = restrictive_default_policy(); + assert!(!policy_covers_rule(&policy, &proposed)); + } + + #[test] + fn policy_covers_rule_returns_false_when_proposed_l7_method_not_loaded() { + // John's false-wakeup mode at L7: the supervisor has an + // overlapping endpoint loaded (e.g. read-only GET), but the + // chunk's proposed PUT method is not in the merged endpoint's + // rules yet. Coverage must NOT return true here, or the agent + // retries the PUT and hits another policy_denied. 
+ let proposed = NetworkPolicyRule { + name: "agent_put".to_string(), + endpoints: vec![NetworkEndpoint { + host: "api.github.com".to_string(), + port: 443, + ports: vec![443], + protocol: "rest".to_string(), + rules: vec![rest_rule("PUT", "/repos/foo/bar/contents/x.md")], + ..Default::default() + }], + binaries: vec![NetworkBinary { + path: "/usr/bin/curl".to_string(), + ..Default::default() + }], + }; + + let mut policy = restrictive_default_policy(); + policy.network_policies.insert( + "existing_readonly".to_string(), + NetworkPolicyRule { + name: "existing_readonly".to_string(), + endpoints: vec![NetworkEndpoint { + host: "api.github.com".to_string(), + port: 443, + ports: vec![443], + protocol: "rest".to_string(), + rules: vec![rest_rule("GET", "/repos/foo/bar/contents/x.md")], + ..Default::default() + }], + binaries: vec![NetworkBinary { + path: "/usr/bin/curl".to_string(), + ..Default::default() + }], + }, + ); + + assert!( + !policy_covers_rule(&policy, &proposed), + "endpoint overlaps but L7 PUT not loaded yet; must not signal coverage" + ); + } + + #[test] + fn policy_covers_rule_returns_true_after_l7_merge_lands() { + // Same setup as above, but with the proposed L7 rule merged in. + // Coverage must now return true. 
+ let proposed = NetworkPolicyRule { + name: "agent_put".to_string(), + endpoints: vec![NetworkEndpoint { + host: "api.github.com".to_string(), + port: 443, + ports: vec![443], + protocol: "rest".to_string(), + rules: vec![rest_rule("PUT", "/repos/foo/bar/contents/x.md")], + ..Default::default() + }], + binaries: vec![NetworkBinary { + path: "/usr/bin/curl".to_string(), + ..Default::default() + }], + }; + + let mut policy = restrictive_default_policy(); + policy.network_policies.insert( + "existing".to_string(), + NetworkPolicyRule { + name: "existing".to_string(), + endpoints: vec![NetworkEndpoint { + host: "api.github.com".to_string(), + port: 443, + ports: vec![443], + protocol: "rest".to_string(), + rules: vec![ + rest_rule("GET", "/repos/foo/bar/contents/x.md"), + rest_rule("PUT", "/repos/foo/bar/contents/x.md"), + ], + ..Default::default() + }], + binaries: vec![NetworkBinary { + path: "/usr/bin/curl".to_string(), + ..Default::default() + }], + }, + ); + + assert!(policy_covers_rule(&policy, &proposed)); + } + + #[test] + fn policy_covers_rule_returns_true_for_l4_only_proposed_when_endpoint_present() { + // A chunk that targets a non-REST surface (no L7 rules) needs + // only the L4 endpoint match to be considered covered. Empty + // proposed.rules must not be treated as "no method matches". 
+ let proposed = NetworkPolicyRule { + name: "ssh_clone".to_string(), + endpoints: vec![NetworkEndpoint { + host: "github.com".to_string(), + port: 22, + ports: vec![22], + ..Default::default() + }], + binaries: vec![NetworkBinary { + path: "/usr/bin/git".to_string(), + ..Default::default() + }], + }; + + let merged = merge_policy( + restrictive_default_policy(), + &[PolicyMergeOp::AddRule { + rule_name: "allow_github_com_22".to_string(), + rule: proposed.clone(), + }], + ) + .expect("merge should succeed"); + + assert!(policy_covers_rule(&merged.policy, &proposed)); + } + + #[test] + fn policy_covers_rule_treats_empty_proposed_binaries_as_any_binary() { + // A proposed rule with no binaries is the "any binary" shape. + // The merged rule keeps its own binaries; coverage holds iff + // endpoint and (vacuously satisfied) binary set match. Document + // the semantics so a future reader doesn't flip it accidentally. + let proposed = NetworkPolicyRule { + name: "any_binary_rule".to_string(), + endpoints: vec![endpoint("api.github.com", 443)], + binaries: vec![], + }; + + let mut policy = restrictive_default_policy(); + policy.network_policies.insert( + "existing".to_string(), + NetworkPolicyRule { + name: "existing".to_string(), + endpoints: vec![endpoint("api.github.com", 443)], + binaries: vec![NetworkBinary { + path: "/usr/bin/curl".to_string(), + ..Default::default() + }], + }, + ); + + assert!( + policy_covers_rule(&policy, &proposed), + "empty proposed binaries should match any merged binary set" + ); + } + #[test] fn add_rule_without_existing_match_inserts_requested_key() { let policy = restrictive_default_policy(); diff --git a/crates/openshell-providers/src/discovery.rs b/crates/openshell-providers/src/discovery.rs index 8c10bbf7e..79d6fb091 100644 --- a/crates/openshell-providers/src/discovery.rs +++ b/crates/openshell-providers/src/discovery.rs @@ -1,7 +1,10 @@ // SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. // SPDX-License-Identifier: Apache-2.0 -use crate::{DiscoveredProvider, DiscoveryContext, ProviderDiscoverySpec, ProviderError}; +use crate::{ + DiscoveredProvider, DiscoveryContext, ProviderDiscoverySpec, ProviderError, ProviderTypeProfile, +}; +use std::collections::HashSet; pub fn discover_with_spec( spec: &ProviderDiscoverySpec, @@ -26,3 +29,132 @@ pub fn discover_with_spec( Ok(Some(discovered)) } } + +pub fn discover_from_profile( + profile: &ProviderTypeProfile, + context: &dyn DiscoveryContext, +) -> Result, ProviderError> { + let mut discovered = DiscoveredProvider::default(); + let mut scanned_env_vars = HashSet::new(); + + for credential_name in &profile.discovery.credentials { + let credential_name = credential_name.trim(); + let Some(credential) = profile + .credentials + .iter() + .find(|credential| credential.name.trim() == credential_name) + else { + return Err(ProviderError::UnknownDiscoveryCredential { + profile_id: profile.id.clone(), + credential_name: credential_name.to_string(), + }); + }; + + for env_var in &credential.env_vars { + let env_var = env_var.trim(); + if env_var.is_empty() || !scanned_env_vars.insert(env_var.to_string()) { + continue; + } + if let Some(value) = context.env_var(env_var) + && !value.trim().is_empty() + { + discovered + .credentials + .entry(env_var.to_string()) + .or_insert(value); + } + } + } + + if discovered.is_empty() { + Ok(None) + } else { + Ok(Some(discovered)) + } +} + +#[cfg(test)] +mod tests { + use super::discover_from_profile; + use crate::profiles::{CredentialProfile, DiscoveryProfile}; + use crate::test_helpers::MockDiscoveryContext; + use crate::{ProviderError, ProviderTypeProfile}; + + fn profile() -> ProviderTypeProfile { + ProviderTypeProfile { + id: "custom".to_string(), + display_name: "Custom".to_string(), + description: String::new(), + category: openshell_core::proto::ProviderProfileCategory::Other, + credentials: vec![ + CredentialProfile { + name: "api_key".to_string(), + 
env_vars: vec!["CUSTOM_API_KEY".to_string(), "CUSTOM_API_TOKEN".to_string()], + required: true, + description: String::new(), + auth_style: String::new(), + header_name: String::new(), + query_param: String::new(), + refresh: None, + }, + CredentialProfile { + name: "secondary".to_string(), + env_vars: vec!["CUSTOM_API_KEY".to_string()], + required: false, + description: String::new(), + auth_style: String::new(), + header_name: String::new(), + query_param: String::new(), + refresh: None, + }, + ], + endpoints: Vec::new(), + binaries: Vec::new(), + inference_capable: false, + discovery: DiscoveryProfile { + credentials: vec!["api_key".to_string(), "secondary".to_string()], + }, + } + } + + #[test] + fn profile_discovery_scans_referenced_credential_env_vars() { + let ctx = MockDiscoveryContext::new().with_env("CUSTOM_API_TOKEN", "secret-token"); + + let discovered = discover_from_profile(&profile(), &ctx) + .expect("discovery should succeed") + .expect("provider should be discovered"); + + assert_eq!( + discovered.credentials.get("CUSTOM_API_TOKEN"), + Some(&"secret-token".to_string()) + ); + assert!(!discovered.credentials.contains_key("CUSTOM_API_KEY")); + } + + #[test] + fn profile_discovery_ignores_empty_values_and_returns_none_when_empty() { + let ctx = MockDiscoveryContext::new().with_env("CUSTOM_API_KEY", " "); + + let discovered = discover_from_profile(&profile(), &ctx).expect("discovery should succeed"); + + assert!(discovered.is_none()); + } + + #[test] + fn profile_discovery_rejects_unknown_credential_references() { + let mut profile = profile(); + profile.discovery.credentials = vec!["missing".to_string()]; + + let err = discover_from_profile(&profile, &MockDiscoveryContext::new()) + .expect_err("unknown discovery credential should fail"); + + assert!(matches!( + err, + ProviderError::UnknownDiscoveryCredential { + profile_id, + credential_name + } if profile_id == "custom" && credential_name == "missing" + )); + } +} diff --git 
a/crates/openshell-providers/src/lib.rs b/crates/openshell-providers/src/lib.rs index 3b28030ca..21a1750ab 100644 --- a/crates/openshell-providers/src/lib.rs +++ b/crates/openshell-providers/src/lib.rs @@ -16,17 +16,25 @@ use std::path::Path; pub use openshell_core::proto::Provider; pub use context::{DiscoveryContext, RealDiscoveryContext}; -pub use discovery::discover_with_spec; +pub use discovery::{discover_from_profile, discover_with_spec}; pub use profiles::{ - ProfileError, ProfileValidationDiagnostic, ProviderTypeProfile, default_profiles, - get_default_profile, normalize_profile_id, parse_profile_json, parse_profile_yaml, - profile_to_json, profile_to_yaml, profiles_to_json, profiles_to_yaml, validate_profile_set, + CredentialRefreshProfile, ProfileError, ProfileValidationDiagnostic, ProviderTypeProfile, + default_profiles, get_default_profile, normalize_profile_id, parse_profile_json, + parse_profile_yaml, profile_to_json, profile_to_yaml, profiles_to_json, profiles_to_yaml, + validate_profile_set, }; #[derive(Debug, thiserror::Error)] pub enum ProviderError { #[error("unsupported provider type: {0}")] UnsupportedProvider(String), + #[error( + "provider profile '{profile_id}' discovery references unknown credential '{credential_name}'" + )] + UnknownDiscoveryCredential { + profile_id: String, + credential_name: String, + }, } #[derive(Debug, Clone, Default, PartialEq, Eq)] @@ -72,6 +80,25 @@ pub trait ProviderPlugin: Send + Sync { } } +/// Blanket implementation of [`ProviderPlugin`] for [`ProviderDiscoverySpec`]. +/// +/// Providers that only need standard env-var discovery can register their +/// `SPEC` constant directly, instead of defining a dedicated struct and +/// repeating the same three-method delegation. 
+impl ProviderPlugin for ProviderDiscoverySpec { + fn id(&self) -> &'static str { + self.id + } + + fn discover_existing(&self) -> Result, ProviderError> { + discover_with_spec(self, &RealDiscoveryContext) + } + + fn credential_env_vars(&self) -> &'static [&'static str] { + self.credential_env_vars + } +} + #[derive(Default)] pub struct ProviderRegistry { plugins: HashMap<&'static str, Box>, @@ -81,16 +108,16 @@ impl ProviderRegistry { #[must_use] pub fn new() -> Self { let mut registry = Self::default(); - registry.register(providers::claude::ClaudeProvider); - registry.register(providers::codex::CodexProvider); - registry.register(providers::copilot::CopilotProvider); + registry.register(providers::claude::SPEC); + registry.register(providers::codex::SPEC); + registry.register(providers::copilot::SPEC); registry.register(providers::opencode::OpencodeProvider); registry.register(providers::generic::GenericProvider); - registry.register(providers::openai::OpenaiProvider); - registry.register(providers::anthropic::AnthropicProvider); - registry.register(providers::nvidia::NvidiaProvider); - registry.register(providers::gitlab::GitlabProvider); - registry.register(providers::github::GithubProvider); + registry.register(providers::openai::SPEC); + registry.register(providers::anthropic::SPEC); + registry.register(providers::nvidia::SPEC); + registry.register(providers::gitlab::SPEC); + registry.register(providers::github::SPEC); registry.register(providers::outlook::OutlookProvider); registry } @@ -143,7 +170,7 @@ impl ProviderRegistry { pub fn normalize_provider_type(input: &str) -> Option<&'static str> { let normalized = input.trim().to_ascii_lowercase(); match normalized.as_str() { - "claude" => Some("claude"), + "claude" | "claude-code" | "claude_code" => Some("claude-code"), "codex" => Some("codex"), "copilot" => Some("copilot"), "opencode" => Some("opencode"), @@ -177,7 +204,8 @@ mod tests { assert_eq!(normalize_provider_type("gitlab"), Some("gitlab")); 
assert_eq!(normalize_provider_type("glab"), Some("gitlab")); assert_eq!(normalize_provider_type("gh"), Some("github")); - assert_eq!(normalize_provider_type("CLAUDE"), Some("claude")); + assert_eq!(normalize_provider_type("CLAUDE"), Some("claude-code")); + assert_eq!(normalize_provider_type("claude-code"), Some("claude-code")); assert_eq!(normalize_provider_type("generic"), Some("generic")); assert_eq!(normalize_provider_type("openai"), Some("openai")); assert_eq!(normalize_provider_type("anthropic"), Some("anthropic")); @@ -190,7 +218,7 @@ mod tests { fn detects_provider_from_command_token() { assert_eq!( detect_provider_from_command(&["claude".to_string()]), - Some("claude") + Some("claude-code") ); assert_eq!( detect_provider_from_command(&["/usr/bin/glab".to_string()]), diff --git a/crates/openshell-providers/src/profiles.rs b/crates/openshell-providers/src/profiles.rs index 8c3f247cf..25c750e63 100644 --- a/crates/openshell-providers/src/profiles.rs +++ b/crates/openshell-providers/src/profiles.rs @@ -7,7 +7,9 @@ use openshell_core::proto::{ GraphqlOperation, L7Allow, L7DenyRule, L7QueryMatcher, L7Rule, NetworkBinary, NetworkEndpoint, - NetworkPolicyRule, ProviderProfile, ProviderProfileCategory, ProviderProfileCredential, + NetworkPolicyRule, ProviderCredentialRefresh, ProviderCredentialRefreshMaterial, + ProviderCredentialRefreshStrategy, ProviderProfile, ProviderProfileCategory, + ProviderProfileCredential, ProviderProfileDiscovery, }; use serde::ser::SerializeStruct; use serde::{Deserialize, Deserializer, Serialize, Serializer, de}; @@ -15,16 +17,9 @@ use std::collections::{HashMap, HashSet}; use std::sync::OnceLock; const BUILT_IN_PROFILE_YAMLS: &[&str] = &[ - include_str!("../../../providers/anthropic.yaml"), - include_str!("../../../providers/claude.yaml"), - include_str!("../../../providers/codex.yaml"), - include_str!("../../../providers/copilot.yaml"), + include_str!("../../../providers/claude-code.yaml"), 
include_str!("../../../providers/github.yaml"), - include_str!("../../../providers/gitlab.yaml"), include_str!("../../../providers/nvidia.yaml"), - include_str!("../../../providers/openai.yaml"), - include_str!("../../../providers/opencode.yaml"), - include_str!("../../../providers/outlook.yaml"), ]; #[derive(Debug, thiserror::Error)] @@ -84,6 +79,45 @@ pub struct CredentialProfile { pub header_name: String, #[serde(default)] pub query_param: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub refresh: Option, +} + +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] +pub struct CredentialRefreshProfile { + #[serde( + default = "default_refresh_strategy", + deserialize_with = "deserialize_refresh_strategy", + serialize_with = "serialize_refresh_strategy" + )] + pub strategy: ProviderCredentialRefreshStrategy, + #[serde(default, skip_serializing_if = "String::is_empty")] + pub token_url: String, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub scopes: Vec, + #[serde(default, skip_serializing_if = "is_zero_i64")] + pub refresh_before_seconds: i64, + #[serde(default, skip_serializing_if = "is_zero_i64")] + pub max_lifetime_seconds: i64, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub material: Vec, +} + +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] +pub struct CredentialRefreshMaterialProfile { + pub name: String, + #[serde(default, skip_serializing_if = "String::is_empty")] + pub description: String, + #[serde(default)] + pub required: bool, + #[serde(default)] + pub secret: bool, +} + +#[derive(Debug, Clone, Default, Deserialize, Serialize, PartialEq, Eq)] +pub struct DiscoveryProfile { + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub credentials: Vec, } // These YAML/JSON DTOs mirror the network policy protos intentionally. 
Keep @@ -114,6 +148,10 @@ pub struct EndpointProfile { pub deny_rules: Vec, #[serde(default, skip_serializing_if = "is_false")] pub allow_encoded_slash: bool, + #[serde(default, skip_serializing_if = "is_false")] + pub websocket_credential_rewrite: bool, + #[serde(default, skip_serializing_if = "is_false")] + pub request_body_credential_rewrite: bool, #[serde(default, skip_serializing_if = "String::is_empty")] pub persisted_queries: String, #[serde(default, skip_serializing_if = "HashMap::is_empty")] @@ -210,6 +248,8 @@ pub struct ProviderTypeProfile { pub binaries: Vec, #[serde(default)] pub inference_capable: bool, + #[serde(default, skip_serializing_if = "discovery_is_empty")] + pub discovery: DiscoveryProfile, } // Provider profile import/export is expected to be lossless for the network @@ -236,11 +276,20 @@ impl ProviderTypeProfile { auth_style: credential.auth_style.clone(), header_name: credential.header_name.clone(), query_param: credential.query_param.clone(), + refresh: credential + .refresh + .as_ref() + .map(credential_refresh_from_proto), }) .collect(), endpoints: profile.endpoints.iter().map(endpoint_from_proto).collect(), binaries: profile.binaries.iter().map(binary_from_proto).collect(), inference_capable: profile.inference_capable, + discovery: profile + .discovery + .as_ref() + .map(discovery_from_proto) + .unwrap_or_default(), } } @@ -275,11 +324,14 @@ impl ProviderTypeProfile { auth_style: credential.auth_style.clone(), header_name: credential.header_name.clone(), query_param: credential.query_param.clone(), + refresh: credential.refresh.as_ref().map(credential_refresh_to_proto), }) .collect(), endpoints: self.endpoints.iter().map(endpoint_to_proto).collect(), binaries: self.binaries.iter().map(binary_to_proto).collect(), inference_capable: self.inference_capable, + discovery: (!discovery_is_empty(&self.discovery)) + .then(|| discovery_to_proto(&self.discovery)), } } @@ -293,6 +345,10 @@ impl ProviderTypeProfile { } } +fn 
discovery_is_empty(discovery: &DiscoveryProfile) -> bool { + discovery.credentials.is_empty() +} + impl Serialize for BinaryProfile { fn serialize(&self, serializer: S) -> Result where @@ -354,6 +410,15 @@ fn is_zero(value: &u32) -> bool { *value == 0 } +#[allow(clippy::trivially_copy_pass_by_ref)] +fn is_zero_i64(value: &i64) -> bool { + *value == 0 +} + +fn default_refresh_strategy() -> ProviderCredentialRefreshStrategy { + ProviderCredentialRefreshStrategy::Unspecified +} + fn deserialize_category<'de, D>(deserializer: D) -> Result where D: Deserializer<'de>, @@ -374,6 +439,28 @@ where serializer.serialize_str(provider_profile_category_to_yaml(*category)) } +fn deserialize_refresh_strategy<'de, D>( + deserializer: D, +) -> Result +where + D: Deserializer<'de>, +{ + let raw = String::deserialize(deserializer)?; + provider_refresh_strategy_from_yaml(&raw) + .ok_or_else(|| de::Error::custom(format!("unsupported provider refresh strategy: {raw}"))) +} + +#[allow(clippy::trivially_copy_pass_by_ref)] +fn serialize_refresh_strategy( + strategy: &ProviderCredentialRefreshStrategy, + serializer: S, +) -> Result +where + S: Serializer, +{ + serializer.serialize_str(provider_refresh_strategy_to_yaml(*strategy)) +} + #[must_use] pub fn provider_profile_category_from_yaml(raw: &str) -> Option { match raw.trim().to_ascii_lowercase().replace('-', "_").as_str() { @@ -401,6 +488,90 @@ pub fn provider_profile_category_to_yaml(category: ProviderProfileCategory) -> & } } +#[must_use] +pub fn provider_refresh_strategy_from_yaml(raw: &str) -> Option { + match raw.trim().to_ascii_lowercase().replace('-', "_").as_str() { + "" => Some(ProviderCredentialRefreshStrategy::Unspecified), + "static" => Some(ProviderCredentialRefreshStrategy::Static), + "external" => Some(ProviderCredentialRefreshStrategy::External), + "oauth2_refresh_token" => Some(ProviderCredentialRefreshStrategy::Oauth2RefreshToken), + "oauth2_client_credentials" => { + 
Some(ProviderCredentialRefreshStrategy::Oauth2ClientCredentials) + } + "google_service_account_jwt" => { + Some(ProviderCredentialRefreshStrategy::GoogleServiceAccountJwt) + } + _ => None, + } +} + +#[must_use] +pub fn provider_refresh_strategy_to_yaml( + strategy: ProviderCredentialRefreshStrategy, +) -> &'static str { + match strategy { + ProviderCredentialRefreshStrategy::Static => "static", + ProviderCredentialRefreshStrategy::External => "external", + ProviderCredentialRefreshStrategy::Oauth2RefreshToken => "oauth2_refresh_token", + ProviderCredentialRefreshStrategy::Oauth2ClientCredentials => "oauth2_client_credentials", + ProviderCredentialRefreshStrategy::GoogleServiceAccountJwt => "google_service_account_jwt", + ProviderCredentialRefreshStrategy::Unspecified => "unspecified", + } +} + +fn credential_refresh_from_proto(refresh: &ProviderCredentialRefresh) -> CredentialRefreshProfile { + CredentialRefreshProfile { + strategy: ProviderCredentialRefreshStrategy::try_from(refresh.strategy) + .unwrap_or(ProviderCredentialRefreshStrategy::Unspecified), + token_url: refresh.token_url.clone(), + scopes: refresh.scopes.clone(), + refresh_before_seconds: refresh.refresh_before_seconds, + max_lifetime_seconds: refresh.max_lifetime_seconds, + material: refresh + .material + .iter() + .map(|material| CredentialRefreshMaterialProfile { + name: material.name.clone(), + description: material.description.clone(), + required: material.required, + secret: material.secret, + }) + .collect(), + } +} + +fn credential_refresh_to_proto(refresh: &CredentialRefreshProfile) -> ProviderCredentialRefresh { + ProviderCredentialRefresh { + strategy: refresh.strategy as i32, + token_url: refresh.token_url.clone(), + scopes: refresh.scopes.clone(), + refresh_before_seconds: refresh.refresh_before_seconds, + max_lifetime_seconds: refresh.max_lifetime_seconds, + material: refresh + .material + .iter() + .map(|material| ProviderCredentialRefreshMaterial { + name: material.name.clone(), + 
description: material.description.clone(), + required: material.required, + secret: material.secret, + }) + .collect(), + } +} + +fn discovery_from_proto(discovery: &ProviderProfileDiscovery) -> DiscoveryProfile { + DiscoveryProfile { + credentials: discovery.credentials.clone(), + } +} + +fn discovery_to_proto(discovery: &DiscoveryProfile) -> ProviderProfileDiscovery { + ProviderProfileDiscovery { + credentials: discovery.credentials.clone(), + } +} + fn endpoint_to_proto(endpoint: &EndpointProfile) -> NetworkEndpoint { NetworkEndpoint { host: endpoint.host.clone(), @@ -414,6 +585,8 @@ fn endpoint_to_proto(endpoint: &EndpointProfile) -> NetworkEndpoint { ports: endpoint.ports.clone(), deny_rules: endpoint.deny_rules.iter().map(deny_rule_to_proto).collect(), allow_encoded_slash: endpoint.allow_encoded_slash, + websocket_credential_rewrite: endpoint.websocket_credential_rewrite, + request_body_credential_rewrite: endpoint.request_body_credential_rewrite, persisted_queries: endpoint.persisted_queries.clone(), graphql_persisted_queries: endpoint .graphql_persisted_queries @@ -442,6 +615,8 @@ fn endpoint_from_proto(endpoint: &NetworkEndpoint) -> EndpointProfile { .map(deny_rule_from_proto) .collect(), allow_encoded_slash: endpoint.allow_encoded_slash, + websocket_credential_rewrite: endpoint.websocket_credential_rewrite, + request_body_credential_rewrite: endpoint.request_body_credential_rewrite, persisted_queries: endpoint.persisted_queries.clone(), graphql_persisted_queries: endpoint .graphql_persisted_queries @@ -731,6 +906,33 @@ pub fn validate_profile_set( } } + let mut discovery_credentials = HashSet::new(); + for (index, credential_name) in profile.discovery.credentials.iter().enumerate() { + let credential_name = credential_name.trim(); + if credential_name.is_empty() { + diagnostics.push(ProfileValidationDiagnostic::error( + source, + profile_id, + format!("discovery.credentials[{index}]"), + "discovery credential name must not be empty", + )); + } else if 
!discovery_credentials.insert(credential_name.to_string()) { + diagnostics.push(ProfileValidationDiagnostic::error( + source, + profile_id, + format!("discovery.credentials[{index}]"), + format!("duplicate discovery credential: {credential_name}"), + )); + } else if !credential_names.contains(credential_name) { + diagnostics.push(ProfileValidationDiagnostic::error( + source, + profile_id, + format!("discovery.credentials[{index}]"), + format!("unknown discovery credential: {credential_name}"), + )); + } + } + let mut env_vars = HashSet::new(); for credential in &profile.credentials { for env_var in &credential.env_vars { @@ -781,6 +983,52 @@ pub fn validate_profile_set( format!("unsupported auth_style: {}", credential.auth_style), )), } + + if let Some(refresh) = credential.refresh.as_ref() { + if refresh.strategy == ProviderCredentialRefreshStrategy::Unspecified { + diagnostics.push(ProfileValidationDiagnostic::error( + source, + profile_id, + "credentials.refresh.strategy", + "refresh strategy is required", + )); + } + if refresh.refresh_before_seconds < 0 { + diagnostics.push(ProfileValidationDiagnostic::error( + source, + profile_id, + "credentials.refresh.refresh_before_seconds", + "refresh_before_seconds must be greater than or equal to 0", + )); + } + if refresh.max_lifetime_seconds < 0 { + diagnostics.push(ProfileValidationDiagnostic::error( + source, + profile_id, + "credentials.refresh.max_lifetime_seconds", + "max_lifetime_seconds must be greater than or equal to 0", + )); + } + let mut material_names = HashSet::new(); + for material in &refresh.material { + let name = material.name.trim(); + if name.is_empty() { + diagnostics.push(ProfileValidationDiagnostic::error( + source, + profile_id, + "credentials.refresh.material.name", + "refresh material name is required", + )); + } else if !material_names.insert(name.to_string()) { + diagnostics.push(ProfileValidationDiagnostic::error( + source, + profile_id, + "credentials.refresh.material.name", + 
format!("duplicate refresh material name: {name}"), + )); + } + } + } } for (index, endpoint) in profile.endpoints.iter().enumerate() { @@ -845,7 +1093,7 @@ mod tests { use openshell_core::proto::ProviderProfileCategory; use super::{ - ProfileError, ProviderTypeProfile, default_profiles, get_default_profile, + DiscoveryProfile, ProfileError, ProviderTypeProfile, default_profiles, get_default_profile, normalize_profile_id, parse_profile_catalog_yamls, parse_profile_json, parse_profile_yaml, profile_to_json, profile_to_yaml, validate_profile_set, }; @@ -871,16 +1119,32 @@ mod tests { proto.category, ProviderProfileCategory::SourceControl as i32 ); - assert_eq!(proto.endpoints.len(), 2); + assert_eq!(proto.endpoints.len(), 3); + assert!( + proto.endpoints.iter().any(|endpoint| { + endpoint.host == "api.github.com" + && endpoint.protocol == "graphql" + && endpoint.path == "/graphql" + && endpoint.access == "read-only" + }), + "github profile should include read-only GraphQL endpoint" + ); + assert!( + proto + .endpoints + .iter() + .all(|endpoint| endpoint.access == "read-only"), + "github profile endpoints should all be read-only" + ); assert_eq!(proto.binaries.len(), 4); } #[test] fn credential_env_vars_are_deduplicated_in_profile_order() { - let profile = get_default_profile("copilot").expect("copilot profile"); + let profile = get_default_profile("claude-code").expect("claude-code profile"); assert_eq!( profile.credential_env_vars(), - vec!["COPILOT_GITHUB_TOKEN", "GH_TOKEN", "GITHUB_TOKEN"] + vec!["ANTHROPIC_API_KEY", "CLAUDE_API_KEY"] ); } @@ -902,6 +1166,71 @@ credentials: assert_eq!(profile.credential_env_vars(), vec!["EXAMPLE_API_KEY"]); } + #[test] + fn profile_discovery_metadata_round_trips_through_proto_and_yaml() { + let profile = parse_profile_yaml( + r" +id: example +display_name: Example +credentials: + - name: api_key + env_vars: [EXAMPLE_API_KEY] +discovery: + credentials: [api_key] +", + ) + .expect("profile should parse"); + + 
assert_eq!(profile.discovery.credentials, vec!["api_key"]); + let from_proto = ProviderTypeProfile::from_proto(&profile.to_proto()); + assert_eq!(from_proto.discovery.credentials, vec!["api_key"]); + let exported = profile_to_yaml(&from_proto).expect("yaml"); + assert!(exported.contains("discovery:")); + assert!(exported.contains("api_key")); + } + + #[test] + fn profile_refresh_metadata_round_trips_through_proto_and_yaml() { + let profile = parse_profile_yaml( + r" +id: ms-graph +display_name: Microsoft Graph +credentials: + - name: access_token + env_vars: [MS_GRAPH_ACCESS_TOKEN] + refresh: + strategy: oauth2_client_credentials + token_url: https://login.microsoftonline.com/common/oauth2/v2.0/token + scopes: [https://graph.microsoft.com/.default] + refresh_before_seconds: 300 + material: + - name: tenant_id + required: true + - name: client_secret + required: true + secret: true +", + ) + .expect("profile should parse"); + + let refresh = profile.credentials[0].refresh.as_ref().expect("refresh"); + assert_eq!( + refresh.token_url, + "https://login.microsoftonline.com/common/oauth2/v2.0/token" + ); + assert_eq!(refresh.material.len(), 2); + + let from_proto = ProviderTypeProfile::from_proto(&profile.to_proto()); + assert_eq!( + from_proto.credentials[0].refresh, + profile.credentials[0].refresh + ); + + let exported = profile_to_yaml(&from_proto).expect("yaml"); + assert!(exported.contains("oauth2_client_credentials")); + assert!(exported.contains("client_secret")); + } + #[test] fn profile_json_round_trip_preserves_compact_dto_shape() { let profile = get_default_profile("github").expect("github profile"); @@ -1009,6 +1338,8 @@ credentials: - name: api_key env_vars: [BROKEN_TOKEN, ""] auth_style: unknown +discovery: + credentials: [api_key, missing_key] endpoints: - host: "" port: 0 @@ -1028,6 +1359,7 @@ binaries: ["", /usr/bin/broken] assert!(messages.contains(&"credential env var must not be empty")); assert!(messages.contains(&"query_param is required for query 
auth")); assert!(messages.contains(&"unsupported auth_style: unknown")); + assert!(messages.contains(&"unknown discovery credential: missing_key")); assert!( messages .iter() @@ -1050,6 +1382,7 @@ binaries: ["", /usr/bin/broken] endpoints: Vec::new(), binaries: Vec::new(), inference_capable: false, + discovery: DiscoveryProfile::default(), }, ), ( @@ -1063,6 +1396,7 @@ binaries: ["", /usr/bin/broken] endpoints: Vec::new(), binaries: Vec::new(), inference_capable: false, + discovery: DiscoveryProfile::default(), }, ), ( @@ -1076,6 +1410,7 @@ binaries: ["", /usr/bin/broken] endpoints: Vec::new(), binaries: Vec::new(), inference_capable: false, + discovery: DiscoveryProfile::default(), }, ), ]; diff --git a/crates/openshell-providers/src/providers/anthropic.rs b/crates/openshell-providers/src/providers/anthropic.rs index f4851dad4..8e392d457 100644 --- a/crates/openshell-providers/src/providers/anthropic.rs +++ b/crates/openshell-providers/src/providers/anthropic.rs @@ -1,46 +1,15 @@ // SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
// SPDX-License-Identifier: Apache-2.0 -use crate::{ - ProviderDiscoverySpec, ProviderError, ProviderPlugin, RealDiscoveryContext, discover_with_spec, -}; - -pub struct AnthropicProvider; +use crate::ProviderDiscoverySpec; pub const SPEC: ProviderDiscoverySpec = ProviderDiscoverySpec { id: "anthropic", credential_env_vars: &["ANTHROPIC_API_KEY"], }; -impl ProviderPlugin for AnthropicProvider { - fn id(&self) -> &'static str { - SPEC.id - } - - fn discover_existing(&self) -> Result, ProviderError> { - discover_with_spec(&SPEC, &RealDiscoveryContext) - } - - fn credential_env_vars(&self) -> &'static [&'static str] { - SPEC.credential_env_vars - } -} - -#[cfg(test)] -mod tests { - use super::SPEC; - use crate::discover_with_spec; - use crate::test_helpers::MockDiscoveryContext; - - #[test] - fn discovers_anthropic_env_credentials() { - let ctx = MockDiscoveryContext::new().with_env("ANTHROPIC_API_KEY", "sk-ant-test"); - let discovered = discover_with_spec(&SPEC, &ctx) - .expect("discovery") - .expect("provider"); - assert_eq!( - discovered.credentials.get("ANTHROPIC_API_KEY"), - Some(&"sk-ant-test".to_string()) - ); - } -} +test_discovers_env_credential!( + discovers_anthropic_env_credentials, + "ANTHROPIC_API_KEY", + "sk-ant-test" +); diff --git a/crates/openshell-providers/src/providers/claude.rs b/crates/openshell-providers/src/providers/claude.rs index 576b30e38..ec8457a81 100644 --- a/crates/openshell-providers/src/providers/claude.rs +++ b/crates/openshell-providers/src/providers/claude.rs @@ -1,46 +1,15 @@ // SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
// SPDX-License-Identifier: Apache-2.0 -use crate::{ - ProviderDiscoverySpec, ProviderError, ProviderPlugin, RealDiscoveryContext, discover_with_spec, -}; - -pub struct ClaudeProvider; +use crate::ProviderDiscoverySpec; pub const SPEC: ProviderDiscoverySpec = ProviderDiscoverySpec { - id: "claude", + id: "claude-code", credential_env_vars: &["ANTHROPIC_API_KEY", "CLAUDE_API_KEY"], }; -impl ProviderPlugin for ClaudeProvider { - fn id(&self) -> &'static str { - SPEC.id - } - - fn discover_existing(&self) -> Result, ProviderError> { - discover_with_spec(&SPEC, &RealDiscoveryContext) - } - - fn credential_env_vars(&self) -> &'static [&'static str] { - SPEC.credential_env_vars - } -} - -#[cfg(test)] -mod tests { - use super::SPEC; - use crate::discover_with_spec; - use crate::test_helpers::MockDiscoveryContext; - - #[test] - fn discovers_claude_env_credentials() { - let ctx = MockDiscoveryContext::new().with_env("ANTHROPIC_API_KEY", "test-key"); - let discovered = discover_with_spec(&SPEC, &ctx) - .expect("discovery") - .expect("provider"); - assert_eq!( - discovered.credentials.get("ANTHROPIC_API_KEY"), - Some(&"test-key".to_string()) - ); - } -} +test_discovers_env_credential!( + discovers_claude_env_credentials, + "ANTHROPIC_API_KEY", + "test-key" +); diff --git a/crates/openshell-providers/src/providers/codex.rs b/crates/openshell-providers/src/providers/codex.rs index d9d43264f..a75f35a89 100644 --- a/crates/openshell-providers/src/providers/codex.rs +++ b/crates/openshell-providers/src/providers/codex.rs @@ -1,46 +1,15 @@ // SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
// SPDX-License-Identifier: Apache-2.0 -use crate::{ - ProviderDiscoverySpec, ProviderError, ProviderPlugin, RealDiscoveryContext, discover_with_spec, -}; - -pub struct CodexProvider; +use crate::ProviderDiscoverySpec; pub const SPEC: ProviderDiscoverySpec = ProviderDiscoverySpec { id: "codex", credential_env_vars: &["OPENAI_API_KEY"], }; -impl ProviderPlugin for CodexProvider { - fn id(&self) -> &'static str { - SPEC.id - } - - fn discover_existing(&self) -> Result, ProviderError> { - discover_with_spec(&SPEC, &RealDiscoveryContext) - } - - fn credential_env_vars(&self) -> &'static [&'static str] { - SPEC.credential_env_vars - } -} - -#[cfg(test)] -mod tests { - use super::SPEC; - use crate::discover_with_spec; - use crate::test_helpers::MockDiscoveryContext; - - #[test] - fn discovers_codex_env_credentials() { - let ctx = MockDiscoveryContext::new().with_env("OPENAI_API_KEY", "openai-key"); - let discovered = discover_with_spec(&SPEC, &ctx) - .expect("discovery") - .expect("provider"); - assert_eq!( - discovered.credentials.get("OPENAI_API_KEY"), - Some(&"openai-key".to_string()) - ); - } -} +test_discovers_env_credential!( + discovers_codex_env_credentials, + "OPENAI_API_KEY", + "openai-key" +); diff --git a/crates/openshell-providers/src/providers/copilot.rs b/crates/openshell-providers/src/providers/copilot.rs index fff74cc3b..0facde617 100644 --- a/crates/openshell-providers/src/providers/copilot.rs +++ b/crates/openshell-providers/src/providers/copilot.rs @@ -1,46 +1,15 @@ // SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
// SPDX-License-Identifier: Apache-2.0 -use crate::{ - ProviderDiscoverySpec, ProviderError, ProviderPlugin, RealDiscoveryContext, discover_with_spec, -}; - -pub struct CopilotProvider; +use crate::ProviderDiscoverySpec; pub const SPEC: ProviderDiscoverySpec = ProviderDiscoverySpec { id: "copilot", credential_env_vars: &["COPILOT_GITHUB_TOKEN", "GH_TOKEN", "GITHUB_TOKEN"], }; -impl ProviderPlugin for CopilotProvider { - fn id(&self) -> &'static str { - SPEC.id - } - - fn discover_existing(&self) -> Result, ProviderError> { - discover_with_spec(&SPEC, &RealDiscoveryContext) - } - - fn credential_env_vars(&self) -> &'static [&'static str] { - SPEC.credential_env_vars - } -} - -#[cfg(test)] -mod tests { - use super::SPEC; - use crate::discover_with_spec; - use crate::test_helpers::MockDiscoveryContext; - - #[test] - fn discovers_copilot_env_credentials() { - let ctx = MockDiscoveryContext::new().with_env("COPILOT_GITHUB_TOKEN", "ghp-copilot-token"); - let discovered = discover_with_spec(&SPEC, &ctx) - .expect("discovery") - .expect("provider"); - assert_eq!( - discovered.credentials.get("COPILOT_GITHUB_TOKEN"), - Some(&"ghp-copilot-token".to_string()) - ); - } -} +test_discovers_env_credential!( + discovers_copilot_env_credentials, + "COPILOT_GITHUB_TOKEN", + "ghp-copilot-token" +); diff --git a/crates/openshell-providers/src/providers/github.rs b/crates/openshell-providers/src/providers/github.rs index 4ca25d6d2..8e14aa363 100644 --- a/crates/openshell-providers/src/providers/github.rs +++ b/crates/openshell-providers/src/providers/github.rs @@ -1,46 +1,11 @@ // SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
// SPDX-License-Identifier: Apache-2.0 -use crate::{ - ProviderDiscoverySpec, ProviderError, ProviderPlugin, RealDiscoveryContext, discover_with_spec, -}; - -pub struct GithubProvider; +use crate::ProviderDiscoverySpec; pub const SPEC: ProviderDiscoverySpec = ProviderDiscoverySpec { id: "github", credential_env_vars: &["GITHUB_TOKEN", "GH_TOKEN"], }; -impl ProviderPlugin for GithubProvider { - fn id(&self) -> &'static str { - SPEC.id - } - - fn discover_existing(&self) -> Result, ProviderError> { - discover_with_spec(&SPEC, &RealDiscoveryContext) - } - - fn credential_env_vars(&self) -> &'static [&'static str] { - SPEC.credential_env_vars - } -} - -#[cfg(test)] -mod tests { - use super::SPEC; - use crate::discover_with_spec; - use crate::test_helpers::MockDiscoveryContext; - - #[test] - fn discovers_github_env_credentials() { - let ctx = MockDiscoveryContext::new().with_env("GH_TOKEN", "gh-token"); - let discovered = discover_with_spec(&SPEC, &ctx) - .expect("discovery") - .expect("provider"); - assert_eq!( - discovered.credentials.get("GH_TOKEN"), - Some(&"gh-token".to_string()) - ); - } -} +test_discovers_env_credential!(discovers_github_env_credentials, "GH_TOKEN", "gh-token"); diff --git a/crates/openshell-providers/src/providers/gitlab.rs b/crates/openshell-providers/src/providers/gitlab.rs index 0f944e09f..8fc4973a4 100644 --- a/crates/openshell-providers/src/providers/gitlab.rs +++ b/crates/openshell-providers/src/providers/gitlab.rs @@ -1,46 +1,11 @@ // SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
// SPDX-License-Identifier: Apache-2.0 -use crate::{ - ProviderDiscoverySpec, ProviderError, ProviderPlugin, RealDiscoveryContext, discover_with_spec, -}; - -pub struct GitlabProvider; +use crate::ProviderDiscoverySpec; pub const SPEC: ProviderDiscoverySpec = ProviderDiscoverySpec { id: "gitlab", credential_env_vars: &["GITLAB_TOKEN", "GLAB_TOKEN", "CI_JOB_TOKEN"], }; -impl ProviderPlugin for GitlabProvider { - fn id(&self) -> &'static str { - SPEC.id - } - - fn discover_existing(&self) -> Result, ProviderError> { - discover_with_spec(&SPEC, &RealDiscoveryContext) - } - - fn credential_env_vars(&self) -> &'static [&'static str] { - SPEC.credential_env_vars - } -} - -#[cfg(test)] -mod tests { - use super::SPEC; - use crate::discover_with_spec; - use crate::test_helpers::MockDiscoveryContext; - - #[test] - fn discovers_gitlab_env_credentials() { - let ctx = MockDiscoveryContext::new().with_env("GLAB_TOKEN", "glab-token"); - let discovered = discover_with_spec(&SPEC, &ctx) - .expect("discovery") - .expect("provider"); - assert_eq!( - discovered.credentials.get("GLAB_TOKEN"), - Some(&"glab-token".to_string()) - ); - } -} +test_discovers_env_credential!(discovers_gitlab_env_credentials, "GLAB_TOKEN", "glab-token"); diff --git a/crates/openshell-providers/src/providers/mod.rs b/crates/openshell-providers/src/providers/mod.rs index 6fe395135..dfe5935a1 100644 --- a/crates/openshell-providers/src/providers/mod.rs +++ b/crates/openshell-providers/src/providers/mod.rs @@ -1,6 +1,35 @@ // SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 +/// Generate a standard discovery smoke-test for a provider whose only test is +/// checking that an env-var credential is picked up by `discover_with_spec`. +/// +/// # Usage +/// ```ignore +/// test_discovers_env_credential!(discovers_openai_env_credentials, "OPENAI_API_KEY", "sk-test"); +/// ``` +macro_rules! 
test_discovers_env_credential { + ($test_name:ident, $env_var:expr, $env_value:expr) => { + #[cfg(test)] + mod tests { + use super::SPEC; + use crate::discover_with_spec; + use crate::test_helpers::MockDiscoveryContext; + + #[test] + fn $test_name() { + let ctx = MockDiscoveryContext::new().with_env($env_var, $env_value); + let discovered = discover_with_spec(&SPEC, &ctx) + .expect("discovery") + .expect("provider"); + assert_eq!( + discovered.credentials.get($env_var), + Some(&$env_value.to_string()) + ); + } + } + }; +} pub mod anthropic; pub mod claude; pub mod codex; diff --git a/crates/openshell-providers/src/providers/nvidia.rs b/crates/openshell-providers/src/providers/nvidia.rs index 62985c1dd..463d5b2cf 100644 --- a/crates/openshell-providers/src/providers/nvidia.rs +++ b/crates/openshell-providers/src/providers/nvidia.rs @@ -1,46 +1,15 @@ // SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 -use crate::{ - ProviderDiscoverySpec, ProviderError, ProviderPlugin, RealDiscoveryContext, discover_with_spec, -}; - -pub struct NvidiaProvider; +use crate::ProviderDiscoverySpec; pub const SPEC: ProviderDiscoverySpec = ProviderDiscoverySpec { id: "nvidia", credential_env_vars: &["NVIDIA_API_KEY"], }; -impl ProviderPlugin for NvidiaProvider { - fn id(&self) -> &'static str { - SPEC.id - } - - fn discover_existing(&self) -> Result, ProviderError> { - discover_with_spec(&SPEC, &RealDiscoveryContext) - } - - fn credential_env_vars(&self) -> &'static [&'static str] { - SPEC.credential_env_vars - } -} - -#[cfg(test)] -mod tests { - use super::SPEC; - use crate::discover_with_spec; - use crate::test_helpers::MockDiscoveryContext; - - #[test] - fn discovers_nvidia_env_credentials() { - let ctx = MockDiscoveryContext::new().with_env("NVIDIA_API_KEY", "nvapi-123"); - let discovered = discover_with_spec(&SPEC, &ctx) - .expect("discovery") - .expect("provider"); - assert_eq!( - 
discovered.credentials.get("NVIDIA_API_KEY"), - Some(&"nvapi-123".to_string()) - ); - } -} +test_discovers_env_credential!( + discovers_nvidia_env_credentials, + "NVIDIA_API_KEY", + "nvapi-123" +); diff --git a/crates/openshell-providers/src/providers/openai.rs b/crates/openshell-providers/src/providers/openai.rs index 0dbe39414..92d8817eb 100644 --- a/crates/openshell-providers/src/providers/openai.rs +++ b/crates/openshell-providers/src/providers/openai.rs @@ -1,46 +1,15 @@ // SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 -use crate::{ - ProviderDiscoverySpec, ProviderError, ProviderPlugin, RealDiscoveryContext, discover_with_spec, -}; - -pub struct OpenaiProvider; +use crate::ProviderDiscoverySpec; pub const SPEC: ProviderDiscoverySpec = ProviderDiscoverySpec { id: "openai", credential_env_vars: &["OPENAI_API_KEY"], }; -impl ProviderPlugin for OpenaiProvider { - fn id(&self) -> &'static str { - SPEC.id - } - - fn discover_existing(&self) -> Result, ProviderError> { - discover_with_spec(&SPEC, &RealDiscoveryContext) - } - - fn credential_env_vars(&self) -> &'static [&'static str] { - SPEC.credential_env_vars - } -} - -#[cfg(test)] -mod tests { - use super::SPEC; - use crate::discover_with_spec; - use crate::test_helpers::MockDiscoveryContext; - - #[test] - fn discovers_openai_env_credentials() { - let ctx = MockDiscoveryContext::new().with_env("OPENAI_API_KEY", "sk-openai-test"); - let discovered = discover_with_spec(&SPEC, &ctx) - .expect("discovery") - .expect("provider"); - assert_eq!( - discovered.credentials.get("OPENAI_API_KEY"), - Some(&"sk-openai-test".to_string()) - ); - } -} +test_discovers_env_credential!( + discovers_openai_env_credentials, + "OPENAI_API_KEY", + "sk-openai-test" +); diff --git a/crates/openshell-providers/src/providers/opencode.rs b/crates/openshell-providers/src/providers/opencode.rs index 417bdb6c2..feb707fd0 100644 --- 
a/crates/openshell-providers/src/providers/opencode.rs +++ b/crates/openshell-providers/src/providers/opencode.rs @@ -1,8 +1,12 @@ // SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 +use std::collections::HashMap; +use std::path::{Path, PathBuf}; + use crate::{ - ProviderDiscoverySpec, ProviderError, ProviderPlugin, RealDiscoveryContext, discover_with_spec, + DiscoveredProvider, ProviderDiscoverySpec, ProviderError, ProviderPlugin, RealDiscoveryContext, + discover_with_spec, }; pub struct OpencodeProvider; @@ -12,13 +16,80 @@ pub const SPEC: ProviderDiscoverySpec = ProviderDiscoverySpec { credential_env_vars: &["OPENCODE_API_KEY", "OPENROUTER_API_KEY", "OPENAI_API_KEY"], }; +/// Return the path to the opencode config file, respecting `XDG_CONFIG_HOME`. +fn opencode_config_path() -> Option<PathBuf> { + let config_home = std::env::var("XDG_CONFIG_HOME") + .ok() + .map(PathBuf::from) + .or_else(|| { + std::env::var("HOME") + .ok() + .map(|h| PathBuf::from(h).join(".config")) + })?; + Some(config_home.join("opencode").join("opencode.json")) +} + +/// Extract API key credentials from the contents of an opencode config file. +/// +/// opencode stores per-provider API keys at `provider..options.apiKey`. +/// Each key is surfaced as `_API_KEY` so that it can be injected +/// as an environment variable into the sandbox and picked up by opencode at runtime. 
+fn extract_credentials_from_opencode_config(content: &str) -> HashMap<String, String> { + let Ok(json) = serde_json::from_str::<serde_json::Value>(content) else { + return HashMap::new(); + }; + let Some(providers) = json.get("provider").and_then(|p| p.as_object()) else { + return HashMap::new(); + }; + + let mut creds = HashMap::new(); + for (provider_name, provider_cfg) in providers { + if let Some(api_key) = provider_cfg + .get("options") + .and_then(|o| o.get("apiKey")) + .and_then(|k| k.as_str()) + .filter(|k| !k.trim().is_empty()) + { + let env_var = format!("{}_API_KEY", provider_name.to_ascii_uppercase()); + creds.insert(env_var, api_key.to_string()); + } + } + creds +} + +/// Read opencode credentials from `path`, returning `None` if the file is absent, unreadable, or contains no API keys. +fn read_opencode_config_file(path: &Path) -> Option<HashMap<String, String>> { + let content = std::fs::read_to_string(path).ok()?; + let creds = extract_credentials_from_opencode_config(&content); + if creds.is_empty() { None } else { Some(creds) } +} + impl ProviderPlugin for OpencodeProvider { fn id(&self) -> &'static str { SPEC.id } - fn discover_existing(&self) -> Result, ProviderError> { - discover_with_spec(&SPEC, &RealDiscoveryContext) + fn discover_existing(&self) -> Result<Option<DiscoveredProvider>, ProviderError> { + let mut discovered = discover_with_spec(&SPEC, &RealDiscoveryContext)?.unwrap_or_default(); + + // Supplement env-var discovery with credentials stored in the opencode config file. + // opencode's native config lives at $XDG_CONFIG_HOME/opencode/opencode.json and stores + // API keys under `provider..options.apiKey`. If the user configured opencode + // normally (i.e. no env vars set), this is the only place the keys exist. + if let Some(path) = opencode_config_path() + && let Some(file_creds) = read_opencode_config_file(&path) + { + for (key, value) in file_creds { + // Env vars already set take priority; config file fills the gaps. 
+ discovered.credentials.entry(key).or_insert(value); + } + } + + if discovered.is_empty() { + Ok(None) + } else { + Ok(Some(discovered)) + } } fn credential_env_vars(&self) -> &'static [&'static str] { @@ -28,7 +99,7 @@ impl ProviderPlugin for OpencodeProvider { #[cfg(test)] mod tests { - use super::SPEC; + use super::{SPEC, extract_credentials_from_opencode_config}; use crate::discover_with_spec; use crate::test_helpers::MockDiscoveryContext; @@ -43,4 +114,61 @@ mod tests { Some(&"op-key".to_string()) ); } + + #[test] + fn extracts_credentials_from_config_file() { + let config = r#"{ + "provider": { + "anthropic": { "options": { "apiKey": "sk-ant-key" } }, + "openai": { "options": { "apiKey": "sk-openai-key" } } + } + }"#; + let creds = extract_credentials_from_opencode_config(config); + assert_eq!( + creds.get("ANTHROPIC_API_KEY"), + Some(&"sk-ant-key".to_string()) + ); + assert_eq!( + creds.get("OPENAI_API_KEY"), + Some(&"sk-openai-key".to_string()) + ); + } + + #[test] + fn skips_providers_without_api_key() { + let config = r#"{ + "provider": { + "ollama": { "options": { "baseUrl": "http://localhost:11434" } } + } + }"#; + let creds = extract_credentials_from_opencode_config(config); + assert!( + creds.is_empty(), + "no credentials expected for keyless provider" + ); + } + + #[test] + fn skips_empty_api_keys() { + let config = r#"{ + "provider": { + "anthropic": { "options": { "apiKey": "" } } + } + }"#; + let creds = extract_credentials_from_opencode_config(config); + assert!(creds.is_empty()); + } + + #[test] + fn tolerates_malformed_json() { + let creds = extract_credentials_from_opencode_config("not json at all"); + assert!(creds.is_empty()); + } + + #[test] + fn tolerates_missing_provider_section() { + let config = r#"{ "theme": "dark" }"#; + let creds = extract_credentials_from_opencode_config(config); + assert!(creds.is_empty()); + } } diff --git a/crates/openshell-sandbox/Cargo.toml b/crates/openshell-sandbox/Cargo.toml index 4e07521ce..6d527bc53 
100644 --- a/crates/openshell-sandbox/Cargo.toml +++ b/crates/openshell-sandbox/Cargo.toml @@ -58,6 +58,8 @@ uuid = { workspace = true } # Encoding base64 = { workspace = true } +flate2 = "1" +sha1 = "0.10" # IP network / CIDR parsing ipnet = "2" @@ -79,10 +81,12 @@ nix = { workspace = true } [target.'cfg(unix)'.dependencies] libc = "0.2" +rustix = { workspace = true } [target.'cfg(target_os = "linux")'.dependencies] landlock = "0.4" seccompiler = "0.5" +tempfile = "3" uuid = { version = "1", features = ["v4"] } [dev-dependencies] diff --git a/crates/openshell-sandbox/data/sandbox-policy.rego b/crates/openshell-sandbox/data/sandbox-policy.rego index 9fa820627..0fa1e6be7 100644 --- a/crates/openshell-sandbox/data/sandbox-policy.rego +++ b/crates/openshell-sandbox/data/sandbox-policy.rego @@ -260,6 +260,16 @@ request_denied_for_endpoint(request, endpoint) if { not graphql_request_allowed(request, endpoint) } +# The same authority applies when a WebSocket endpoint opts into GraphQL +# operation policy. Once the relay classifies a client text message as a +# GraphQL-over-WebSocket operation, generic WEBSOCKET_TEXT rules must not bypass +# operation_type / operation_name / fields policy. +request_denied_for_endpoint(request, endpoint) if { + endpoint.protocol == "websocket" + is_object(request.graphql) + not graphql_request_allowed(request, endpoint) +} + # Deny query matching: fail-closed semantics. # If no query rules on the deny rule, match unconditionally (any query params). # If query rules present, trigger the deny if ANY value for a configured key @@ -516,6 +526,7 @@ graphql_field_matches_any(field, patterns) if { } # Wildcard "*" matches any method; otherwise case-insensitive exact match. +# RFC 9110 §9.3.2: HEAD is semantically identical to GET except no response body. 
method_matches(_, "*") if true method_matches(actual, expected) if { @@ -523,6 +534,11 @@ method_matches(actual, expected) if { upper(actual) == upper(expected) } +method_matches(actual, expected) if { + upper(actual) == "HEAD" + upper(expected) == "GET" +} + # Path matching: "**" matches everything; otherwise glob.match with "/" delimiter. # # INVARIANT: `input.request.path` is canonicalized by the sandbox before diff --git a/crates/openshell-sandbox/src/bypass_monitor.rs b/crates/openshell-sandbox/src/bypass_monitor.rs index 1a7ec5f99..9e37ef27c 100644 --- a/crates/openshell-sandbox/src/bypass_monitor.rs +++ b/crates/openshell-sandbox/src/bypass_monitor.rs @@ -5,15 +5,15 @@ //! detect and report direct connection attempts that bypass the HTTP CONNECT //! proxy. //! -//! When the sandbox network namespace has iptables LOG rules installed (see +//! When the sandbox network namespace has nftables log rules installed (see //! `NetworkNamespace::install_bypass_rules`), the kernel writes a log line for -//! each dropped packet. This module reads those messages, parses the iptables +//! each dropped packet. This module reads those messages, parses the nftables //! LOG format, and emits structured tracing events + denial aggregator entries. //! //! ## Graceful degradation //! //! If `/dev/kmsg` cannot be opened (e.g., restricted container environment), -//! the monitor logs a one-time warning and returns. The iptables REJECT rules +//! the monitor logs a one-time warning and returns. The nftables reject rules //! still provide fast-fail UX — the monitor only adds diagnostic visibility. use crate::denial_aggregator::DenialEvent; @@ -26,7 +26,7 @@ use std::sync::atomic::{AtomicU32, Ordering}; use tokio::sync::mpsc; use tracing::debug; -/// A parsed iptables LOG entry from `/dev/kmsg`. +/// A parsed nftables log entry from `/dev/kmsg`. #[derive(Debug, Clone, PartialEq, Eq)] pub struct BypassEvent { /// Destination IP address. 
@@ -41,7 +41,7 @@ pub struct BypassEvent { pub uid: Option, } -/// Parse an iptables LOG line from `/dev/kmsg`. +/// Parse a nftables log line from `/dev/kmsg`. /// /// Expected format (from the kernel LOG target): /// ```text @@ -74,7 +74,7 @@ pub fn parse_kmsg_line(line: &str, namespace_prefix: &str) -> Option &'static str { /// Spawn the bypass monitor as a background tokio task. /// -/// Uses `dmesg --follow` to tail the kernel ring buffer for iptables LOG +/// Uses `dmesg --follow` to tail the kernel ring buffer for nftables log /// entries matching the given namespace. Falls back gracefully if `dmesg` /// is not available. /// @@ -221,7 +221,7 @@ pub fn spawn( .severity(SeverityId::Medium) .dst_endpoint(dst_ep.clone()) .actor_process(Process::from_bypass(&binary, &binary_pid, &ancestors)) - .firewall_rule("bypass-detect", "iptables") + .firewall_rule("bypass-detect", "nftables") .observation_point(3) .message(format!( "BYPASS_DETECT {}:{} proto={} binary={binary} action=reject reason={reason}", diff --git a/crates/openshell-sandbox/src/child_env.rs b/crates/openshell-sandbox/src/child_env.rs index e764afdfe..32eecbee3 100644 --- a/crates/openshell-sandbox/src/child_env.rs +++ b/crates/openshell-sandbox/src/child_env.rs @@ -24,11 +24,12 @@ pub fn proxy_env_vars(proxy_url: &str) -> [(&'static str, String); 9] { pub fn tls_env_vars( ca_cert_path: &Path, combined_bundle_path: &Path, -) -> [(&'static str, String); 5] { +) -> [(&'static str, String); 6] { let ca_cert_path = ca_cert_path.display().to_string(); let combined_bundle_path = combined_bundle_path.display().to_string(); [ - ("NODE_EXTRA_CA_CERTS", ca_cert_path), + ("NODE_EXTRA_CA_CERTS", ca_cert_path.clone()), + ("DENO_CERT", ca_cert_path), ("SSL_CERT_FILE", combined_bundle_path.clone()), ("REQUESTS_CA_BUNDLE", combined_bundle_path.clone()), ("CURL_CA_BUNDLE", combined_bundle_path.clone()), @@ -81,6 +82,7 @@ mod tests { let stdout = String::from_utf8(output.stdout).expect("utf8"); 
assert!(stdout.contains("NODE_EXTRA_CA_CERTS=/etc/openshell-tls/openshell-ca.pem")); + assert!(stdout.contains("DENO_CERT=/etc/openshell-tls/openshell-ca.pem")); assert!(stdout.contains("SSL_CERT_FILE=/etc/openshell-tls/ca-bundle.pem")); assert!(stdout.contains("REQUESTS_CA_BUNDLE=/etc/openshell-tls/ca-bundle.pem")); assert!(stdout.contains("CURL_CA_BUNDLE=/etc/openshell-tls/ca-bundle.pem")); diff --git a/crates/openshell-sandbox/src/debug_rpc.rs b/crates/openshell-sandbox/src/debug_rpc.rs new file mode 100644 index 000000000..af22b7450 --- /dev/null +++ b/crates/openshell-sandbox/src/debug_rpc.rs @@ -0,0 +1,271 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! One-shot debug RPCs exposed via `openshell-sandbox debug-rpc`. +//! +//! Designed for end-to-end verification of the per-sandbox identity +//! flow (issue #1354). A `docker exec` (or `kubectl exec`) into a +//! running sandbox can issue raw sandbox-class gRPC calls without +//! standing up a custom binary inside the sandbox image — useful for +//! confirming the cross-sandbox IDOR guard and renewal semantics. +//! +//! Subcommands: +//! - `get-sandbox-config --sandbox-id ` — call `GetSandboxConfig` +//! - `refresh` — call `RefreshSandboxToken` +//! - `show-token` — print a token fingerprint and expiry, never the bearer +//! - `show-principal` — pretty-print the decoded JWT claims +//! (no signature verification — the supervisor already trusts the +//! token's origin) + +use base64::Engine as _; +use miette::{IntoDiagnostic, Result, WrapErr}; +use openshell_core::proto::{ + GetSandboxConfigRequest, RefreshSandboxTokenRequest, open_shell_client::OpenShellClient, +}; +use sha2::{Digest, Sha256}; + +use crate::grpc_client::{AuthedChannel, connect_channel_pub}; + +/// Entry point for the `debug-rpc` subcommand. Returns the process exit +/// code; `main` propagates it. 
+pub async fn run(args: &[String]) -> Result { + let cmd = args + .first() + .map(String::as_str) + .ok_or_else(|| miette::miette!("{}", USAGE))?; + + match cmd { + "get-sandbox-config" => run_get_sandbox_config(&args[1..]).await, + "refresh" => run_refresh().await, + "show-token" => run_show_token(), + "show-principal" => run_show_principal(), + "--help" | "-h" => { + println!("{USAGE}"); + Ok(0) + } + other => Err(miette::miette!( + "unknown debug-rpc command '{other}'\n\n{USAGE}" + )), + } +} + +const USAGE: &str = "\ +usage: openshell-sandbox debug-rpc [options] + +commands: + get-sandbox-config --sandbox-id call GetSandboxConfig + refresh renew the gateway JWT + show-token print JWT fingerprint and expiry + show-principal print decoded JWT claims + +requires: OPENSHELL_ENDPOINT in env, plus one of OPENSHELL_SANDBOX_TOKEN, +OPENSHELL_SANDBOX_TOKEN_FILE, or OPENSHELL_K8S_SA_TOKEN_FILE so the +supervisor's normal token-acquisition path can resolve a JWT."; + +async fn open_client() -> Result> { + let endpoint = std::env::var(openshell_core::sandbox_env::ENDPOINT) + .into_diagnostic() + .wrap_err("OPENSHELL_ENDPOINT must be set")?; + let channel = connect_channel_pub(&endpoint).await?; + Ok(OpenShellClient::new(channel)) +} + +async fn run_get_sandbox_config(args: &[String]) -> Result { + let sandbox_id = parse_flag(args, "--sandbox-id") + .ok_or_else(|| miette::miette!("get-sandbox-config: --sandbox-id is required"))?; + let mut client = open_client().await?; + let resp = client + .get_sandbox_config(GetSandboxConfigRequest { + sandbox_id: sandbox_id.to_string(), + }) + .await; + match resp { + Ok(r) => { + let inner = r.into_inner(); + println!( + "version={} policy_hash={} config_revision={}", + inner.version, inner.policy_hash, inner.config_revision + ); + Ok(0) + } + Err(status) => { + eprintln!("{}: {}", code_name(status.code()), status.message()); + // Map gRPC status to a non-zero exit so callers can branch + // (e.g. 
expect-permission-denied in a shell test). + Ok(match status.code() { + tonic::Code::PermissionDenied => 7, + tonic::Code::Unauthenticated => 16, + tonic::Code::NotFound => 5, + _ => 1, + }) + } + } +} + +async fn run_refresh() -> Result { + let mut client = open_client().await?; + let resp = client + .refresh_sandbox_token(RefreshSandboxTokenRequest {}) + .await; + match resp { + Ok(r) => { + let inner = r.into_inner(); + print_token_summary(&inner.token, Some(inner.expires_at_ms)); + Ok(0) + } + Err(status) => { + eprintln!("{}: {}", code_name(status.code()), status.message()); + Ok(1) + } + } +} + +fn run_show_token() -> Result { + let token = read_local_token()?; + print_token_summary(&token, None); + Ok(0) +} + +fn run_show_principal() -> Result { + let token = read_local_token()?; + let claims = decode_token_claims(&token)?; + println!( + "{}", + serde_json::to_string_pretty(&claims).into_diagnostic()? + ); + Ok(0) +} + +fn decode_token_claims(token: &str) -> Result { + let payload_b64 = token + .split('.') + .nth(1) + .ok_or_else(|| miette::miette!("token has no payload segment"))?; + let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD + .decode(payload_b64) + .into_diagnostic() + .wrap_err("failed to base64-decode token payload")?; + serde_json::from_slice(&payload) + .into_diagnostic() + .wrap_err("failed to parse token payload as JSON") +} + +fn print_token_summary(token: &str, expires_at_ms: Option) { + let claims = decode_token_claims(token).unwrap_or(serde_json::Value::Null); + let fingerprint = token_fingerprint(token); + let expires_at_ms = expires_at_ms + .or_else(|| { + claims + .get("exp") + .and_then(serde_json::Value::as_i64) + .map(|s| s.saturating_mul(1000)) + }) + .unwrap_or_default(); + let sandbox_id = claims + .get("sandbox_id") + .and_then(serde_json::Value::as_str) + .unwrap_or(""); + let subject = claims + .get("sub") + .and_then(serde_json::Value::as_str) + .unwrap_or(""); + let issuer = claims + .get("iss") + 
.and_then(serde_json::Value::as_str) + .unwrap_or(""); + println!( + "fingerprint={fingerprint}\nexpires_at_ms={expires_at_ms}\nsandbox_id={sandbox_id}\nsubject={subject}\nissuer={issuer}" + ); +} + +fn token_fingerprint(token: &str) -> String { + let digest = Sha256::digest(token.as_bytes()); + format!("sha256:{}", &hex::encode(digest)[..16]) +} + +/// Read the token from the env/file/SA-bootstrap chain, but only the +/// "already a gateway JWT" paths — show-token / show-principal don't +/// want to actually exchange an SA token. +fn read_local_token() -> Result { + if let Ok(t) = std::env::var(openshell_core::sandbox_env::SANDBOX_TOKEN) + && !t.is_empty() + { + return Ok(t); + } + if let Ok(path) = std::env::var(openshell_core::sandbox_env::SANDBOX_TOKEN_FILE) + && !path.is_empty() + { + return Ok(std::fs::read_to_string(&path) + .into_diagnostic() + .wrap_err_with(|| format!("failed to read sandbox token from {path}"))? + .trim() + .to_string()); + } + Err(miette::miette!( + "no in-process gateway JWT available — set OPENSHELL_SANDBOX_TOKEN or \ + OPENSHELL_SANDBOX_TOKEN_FILE. The K8s SA-bootstrap path is intentionally \ + excluded from `show-token` / `show-principal` to avoid issuing a fresh \ + token just for inspection." 
+ )) +} + +fn parse_flag<'a>(args: &'a [String], name: &str) -> Option<&'a str> { + let mut iter = args.iter(); + while let Some(a) = iter.next() { + if a == name { + return iter.next().map(String::as_str); + } + if let Some(rest) = a.strip_prefix(&format!("{name}=")) { + return Some(rest); + } + } + None +} + +fn code_name(c: tonic::Code) -> &'static str { + match c { + tonic::Code::Ok => "OK", + tonic::Code::Cancelled => "Cancelled", + tonic::Code::Unknown => "Unknown", + tonic::Code::InvalidArgument => "InvalidArgument", + tonic::Code::DeadlineExceeded => "DeadlineExceeded", + tonic::Code::NotFound => "NotFound", + tonic::Code::AlreadyExists => "AlreadyExists", + tonic::Code::PermissionDenied => "PermissionDenied", + tonic::Code::ResourceExhausted => "ResourceExhausted", + tonic::Code::FailedPrecondition => "FailedPrecondition", + tonic::Code::Aborted => "Aborted", + tonic::Code::OutOfRange => "OutOfRange", + tonic::Code::Unimplemented => "Unimplemented", + tonic::Code::Internal => "Internal", + tonic::Code::Unavailable => "Unavailable", + tonic::Code::DataLoss => "DataLoss", + tonic::Code::Unauthenticated => "Unauthenticated", + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_flag_handles_space_separated() { + let args: Vec<String> = ["--sandbox-id", "abc-123"] + .iter() + .map(ToString::to_string) + .collect(); + assert_eq!(parse_flag(&args, "--sandbox-id"), Some("abc-123")); + } + + #[test] + fn parse_flag_handles_equals_separated() { + let args: Vec<String> = ["--sandbox-id=abc-123".to_string()].to_vec(); + assert_eq!(parse_flag(&args, "--sandbox-id"), Some("abc-123")); + } + + #[test] + fn parse_flag_returns_none_when_missing() { + let args: Vec<String> = ["--other".to_string(), "x".to_string()].to_vec(); + assert!(parse_flag(&args, "--sandbox-id").is_none()); + } +} diff --git a/crates/openshell-sandbox/src/denial_aggregator.rs b/crates/openshell-sandbox/src/denial_aggregator.rs index d64be7f1d..5d41adffd 100644 ---
a/crates/openshell-sandbox/src/denial_aggregator.rs +++ b/crates/openshell-sandbox/src/denial_aggregator.rs @@ -11,7 +11,6 @@ use std::collections::HashMap; use std::future::Future; -use std::time::{SystemTime, UNIX_EPOCH}; use tokio::sync::mpsc; use tracing::debug; @@ -124,7 +123,7 @@ impl DenialAggregator { /// Ingest a single denial event, merging into existing summary or creating /// a new one. fn ingest(&mut self, event: DenialEvent) { - let now_ms = current_time_ms(); + let now_ms = openshell_core::time::now_ms(); let key = (event.host.clone(), event.port, event.binary.clone()); let entry = self @@ -217,9 +216,3 @@ pub struct FlushableL7Sample { pub path: String, pub count: u32, } - -fn current_time_ms() -> i64 { - SystemTime::now() - .duration_since(UNIX_EPOCH) - .map_or(0, |d| i64::try_from(d.as_millis()).unwrap_or(i64::MAX)) -} diff --git a/crates/openshell-sandbox/src/grpc_client.rs b/crates/openshell-sandbox/src/grpc_client.rs index 44f372355..14a6808c1 100644 --- a/crates/openshell-sandbox/src/grpc_client.rs +++ b/crates/openshell-sandbox/src/grpc_client.rs @@ -3,22 +3,120 @@ //! gRPC client for fetching sandbox policy, provider environment, and inference //! route bundles from `OpenShell` server. +//! +//! Every request carries a gateway-minted JWT in the `Authorization` header. +//! The token is resolved at startup from one of three sources: +//! +//! 1. `OPENSHELL_SANDBOX_TOKEN` — raw JWT in the env (test harness path). +//! 2. `OPENSHELL_SANDBOX_TOKEN_FILE` — file containing the JWT (Docker / +//! Podman / VM drivers write this to a bundle file at sandbox-create +//! time). +//! 3. `OPENSHELL_K8S_SA_TOKEN_FILE` — projected `ServiceAccount` JWT; the +//! supervisor exchanges it for a gateway JWT via `IssueSandboxToken` +//! once at startup. +//! +//! The resolved gateway JWT is held in process memory thereafter and +//! injected on every outbound call by [`AuthInterceptor`]. 
use std::collections::HashMap; -use std::time::Duration; +use std::sync::{Arc, OnceLock, RwLock}; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; use miette::{IntoDiagnostic, Result, WrapErr}; use openshell_core::proto::{ - DenialSummary, GetInferenceBundleRequest, GetInferenceBundleResponse, GetSandboxConfigRequest, - GetSandboxProviderEnvironmentRequest, PolicySource, PolicyStatus, ReportPolicyStatusRequest, - SandboxPolicy as ProtoSandboxPolicy, SubmitPolicyAnalysisRequest, UpdateConfigRequest, - inference_client::InferenceClient, open_shell_client::OpenShellClient, + DenialSummary, GetDraftPolicyRequest, GetInferenceBundleRequest, GetInferenceBundleResponse, + GetSandboxConfigRequest, GetSandboxProviderEnvironmentRequest, IssueSandboxTokenRequest, + PolicyChunk, PolicySource, PolicyStatus, RefreshSandboxTokenRequest, ReportPolicyStatusRequest, + SandboxPolicy as ProtoSandboxPolicy, SubmitPolicyAnalysisRequest, SubmitPolicyAnalysisResponse, + UpdateConfigRequest, inference_client::InferenceClient, open_shell_client::OpenShellClient, }; +use openshell_core::sandbox_env; +use tonic::Status; +use tonic::metadata::AsciiMetadataValue; use tonic::service::interceptor::InterceptedService; use tonic::transport::{Certificate, Channel, ClientTlsConfig, Endpoint, Identity}; -use tracing::debug; +use tracing::{debug, info, warn}; + +/// Channel type after the [`AuthInterceptor`] is applied. Aliased so the +/// generated client type signatures stay readable. +pub type AuthedChannel = InterceptedService; + +/// Shared, refreshable Bearer header. All [`AuthInterceptor`] clones read +/// the same slot, so the renewal task can replace the token in place without +/// rebuilding the channel. +type TokenSlot = Arc>; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum TokenSource { + Env, + File, + K8sServiceAccount, +} + +#[derive(Debug)] +struct AcquiredToken { + token: String, + source: TokenSource, +} + +/// Process-wide token slot. 
Initialized by the first [`connect_channel`] +/// call and shared with every subsequent client and the renewal loop. +static TOKEN_SLOT: OnceLock = OnceLock::new(); -/// Create a channel to the `OpenShell` server. +/// Source used to acquire the process-wide token slot. +static TOKEN_SOURCE: OnceLock = OnceLock::new(); + +/// Serializes the first token acquisition. Several supervisor subsystems +/// connect during startup; without this guard they can all observe an empty +/// [`TOKEN_SLOT`] and perform duplicate K8s bootstrap exchanges. +static TOKEN_INIT_LOCK: tokio::sync::Mutex<()> = tokio::sync::Mutex::const_new(()); + +/// One-shot guard so the renewal loop spawns at most once per process. +static REFRESH_SPAWNED: OnceLock<()> = OnceLock::new(); + +fn install_token_slot(token: &str) -> Result { + let bearer = AsciiMetadataValue::try_from(format!("Bearer {token}")) + .into_diagnostic() + .wrap_err("sandbox JWT contained characters not valid for a header value")?; + if let Some(existing) = TOKEN_SLOT.get() { + *existing.write().expect("token slot poisoned") = bearer; + return Ok(existing.clone()); + } + let slot: TokenSlot = Arc::new(RwLock::new(bearer)); + let _ = TOKEN_SLOT.set(slot.clone()); + Ok(TOKEN_SLOT.get().cloned().unwrap_or(slot)) +} + +/// gRPC interceptor that injects `authorization: Bearer ` on every +/// outbound request. The token lives in a shared [`TokenSlot`] so the renewal +/// task can replace it without rebuilding clients. 
+#[derive(Clone)] +pub struct AuthInterceptor { + bearer: TokenSlot, +} + +impl AuthInterceptor { + fn new(bearer: TokenSlot) -> Self { + Self { bearer } + } +} + +impl tonic::service::Interceptor for AuthInterceptor { + fn call( + &mut self, + mut req: tonic::Request<()>, + ) -> std::result::Result, Status> { + let bearer = self + .bearer + .read() + .expect("auth interceptor token slot poisoned") + .clone(); + req.metadata_mut().insert("authorization", bearer); + Ok(req) + } +} + +/// Build the plain (un-intercepted) gRPC channel. /// /// When the endpoint uses `https://`, mTLS is configured using these env vars: /// - `OPENSHELL_TLS_CA` -- path to the CA certificate @@ -27,7 +125,7 @@ use tracing::debug; /// /// When the endpoint uses `http://`, a plaintext connection is used (for /// deployments where TLS is disabled, e.g. behind a Cloudflare Tunnel). -async fn connect_channel(endpoint: &str) -> Result { +async fn build_plain_channel(endpoint: &str) -> Result { let mut ep = Endpoint::from_shared(endpoint.to_string()) .into_diagnostic() .wrap_err("invalid gRPC endpoint")? @@ -43,13 +141,13 @@ async fn connect_channel(endpoint: &str) -> Result { let tls_enabled = endpoint.starts_with("https://"); if tls_enabled { - let ca_path = std::env::var("OPENSHELL_TLS_CA") + let ca_path = std::env::var(sandbox_env::TLS_CA) .into_diagnostic() .wrap_err("OPENSHELL_TLS_CA is required")?; - let cert_path = std::env::var("OPENSHELL_TLS_CERT") + let cert_path = std::env::var(sandbox_env::TLS_CERT) .into_diagnostic() .wrap_err("OPENSHELL_TLS_CERT is required")?; - let key_path = std::env::var("OPENSHELL_TLS_KEY") + let key_path = std::env::var(sandbox_env::TLS_KEY) .into_diagnostic() .wrap_err("OPENSHELL_TLS_KEY is required")?; @@ -79,59 +177,353 @@ async fn connect_channel(endpoint: &str) -> Result { .wrap_err("failed to connect to OpenShell server") } -/// Create a channel to the `OpenShell` server (public for use by `supervisor_session`). 
-pub async fn connect_channel_pub(endpoint: &str) -> Result { - connect_channel(endpoint).await +/// Build a Bearer-authenticated channel to the gateway. +/// +/// First call per process resolves the sandbox JWT via the three-step +/// lookup (env → file → K8s SA bootstrap exchange) and installs it into +/// the process-wide [`TOKEN_SLOT`]. Subsequent calls reuse the cached +/// slot — the renewal loop keeps the value fresh, so re-running the +/// bootstrap is both unnecessary and (on the K8s SA path) expensive +/// (one apiserver round-trip per call). The renewal loop itself is +/// spawned once per process via [`REFRESH_SPAWNED`]. +async fn connect_channel(endpoint: &str) -> Result { + let channel = build_plain_channel(endpoint).await?; + let (slot, source) = token_slot(endpoint, &channel).await?; + let plain_channel = channel.clone(); + let intercepted = InterceptedService::new(channel, AuthInterceptor::new(slot.clone())); + if REFRESH_SPAWNED.set(()).is_ok() { + let refresh_channel = intercepted.clone(); + let endpoint = endpoint.to_string(); + tokio::spawn(async move { + refresh_token_loop(refresh_channel, slot, source, endpoint, plain_channel).await; + }); + } + Ok(intercepted) +} + +async fn token_slot(endpoint: &str, plain_channel: &Channel) -> Result<(TokenSlot, TokenSource)> { + if let Some(existing) = TOKEN_SLOT.get() { + let source = TOKEN_SOURCE.get().copied().unwrap_or(TokenSource::Env); + return Ok((existing.clone(), source)); + } + + let _guard = TOKEN_INIT_LOCK.lock().await; + + if let Some(existing) = TOKEN_SLOT.get() { + let source = TOKEN_SOURCE.get().copied().unwrap_or(TokenSource::Env); + return Ok((existing.clone(), source)); + } + + let acquired = acquire_sandbox_token(endpoint, plain_channel).await?; + let slot = install_token_slot(&acquired.token)?; + let _ = TOKEN_SOURCE.set(acquired.source); + Ok((slot, acquired.source)) } -/// Interceptor that injects the sandbox shared secret into every gRPC request. 
+/// Resolve the sandbox JWT used to authenticate every outbound RPC. /// -/// The server validates this header on sandbox-to-server RPCs (`GetSandboxConfig`, -/// `GetSandboxProviderEnvironment`, etc.) instead of requiring an OIDC Bearer token. -#[derive(Clone)] -pub struct SandboxSecretInterceptor { - secret: Option>, +/// `endpoint` is logged on errors but never used for transport here; the +/// actual network call lives inside this function only on the K8s +/// bootstrap path, which uses `plain_channel` to call `IssueSandboxToken` +/// once before the steady-state Bearer-authenticated channel is built. +async fn acquire_sandbox_token(endpoint: &str, plain_channel: &Channel) -> Result { + if let Ok(t) = std::env::var(sandbox_env::SANDBOX_TOKEN) + && !t.is_empty() + { + debug!(source = "env", "loaded sandbox token"); + return Ok(AcquiredToken { + token: t, + source: TokenSource::Env, + }); + } + + if let Ok(path) = std::env::var(sandbox_env::SANDBOX_TOKEN_FILE) + && !path.is_empty() + { + let contents = std::fs::read_to_string(&path) + .into_diagnostic() + .wrap_err_with(|| format!("failed to read sandbox token from {path}"))?; + debug!(source = "file", path = %path, "loaded sandbox token"); + return Ok(AcquiredToken { + token: contents.trim().to_string(), + source: TokenSource::File, + }); + } + + if let Ok(sa_path) = std::env::var(sandbox_env::K8S_SA_TOKEN_FILE) + && !sa_path.is_empty() + { + return Ok(AcquiredToken { + token: acquire_k8s_sandbox_token(endpoint, plain_channel, &sa_path).await?, + source: TokenSource::K8sServiceAccount, + }); + } + + Err(miette::miette!( + "no sandbox token source available — set one of {}, {}, or {}", + sandbox_env::SANDBOX_TOKEN, + sandbox_env::SANDBOX_TOKEN_FILE, + sandbox_env::K8S_SA_TOKEN_FILE, + )) } -impl tonic::service::Interceptor for SandboxSecretInterceptor { - fn call( - &mut self, - mut req: tonic::Request<()>, - ) -> std::result::Result, tonic::Status> { - if let Some(ref val) = self.secret { - 
req.metadata_mut().insert("x-sandbox-secret", val.clone()); +async fn acquire_k8s_sandbox_token( + endpoint: &str, + plain_channel: &Channel, + sa_path: &str, +) -> Result { + let sa_token = std::fs::read_to_string(sa_path) + .into_diagnostic() + .wrap_err_with(|| format!("failed to read K8s SA token from {sa_path}"))? + .trim() + .to_string(); + info!(endpoint = %endpoint, "exchanging K8s ServiceAccount token for sandbox JWT"); + // The bootstrap exchange uses a one-off interceptor pinned to the + // SA token; the resulting gateway JWT becomes the value in the + // shared `TOKEN_SLOT` once `connect_channel` returns. + let bootstrap_slot: TokenSlot = Arc::new(RwLock::new( + AsciiMetadataValue::try_from(format!("Bearer {sa_token}")) + .into_diagnostic() + .wrap_err("SA token contained characters not valid for a header value")?, + )); + let interceptor = AuthInterceptor::new(bootstrap_slot); + let bootstrap = InterceptedService::new(plain_channel.clone(), interceptor); + let mut client = OpenShellClient::new(bootstrap); + let resp = client + .issue_sandbox_token(IssueSandboxTokenRequest {}) + .await + .into_diagnostic() + .wrap_err("IssueSandboxToken bootstrap exchange failed")?; + Ok(resp.into_inner().token) +} + +/// Build an authenticated channel for direct external use (e.g. the +/// long-lived `supervisor_session` control stream). +pub async fn connect_channel_pub(endpoint: &str) -> Result { + connect_channel(endpoint).await +} + +/// Background task that renews the sandbox JWT at ~80% of its remaining +/// lifetime. The new token replaces the value in [`TOKEN_SLOT`], so all +/// in-flight and future clients pick it up on their next request. The +/// loop never panics: every failure is logged and re-attempted after a +/// bounded backoff. 
+async fn refresh_token_loop( + channel: AuthedChannel, + slot: TokenSlot, + source: TokenSource, + endpoint: String, + plain_channel: Channel, +) { + let mut client = OpenShellClient::new(channel); + loop { + let sleep = compute_refresh_delay(&slot); + tokio::time::sleep(sleep).await; + match client + .refresh_sandbox_token(RefreshSandboxTokenRequest {}) + .await + { + Ok(resp) => { + let new_token = resp.into_inner().token; + match AsciiMetadataValue::try_from(format!("Bearer {new_token}")) { + Ok(value) => { + if let Ok(mut guard) = slot.write() { + *guard = value; + info!("renewed gateway sandbox JWT in-place"); + } + } + Err(e) => warn!(error = %e, "refreshed JWT contained invalid header bytes"), + } + } + Err(status) => { + if status.code() == tonic::Code::Unauthenticated + && source == TokenSource::K8sServiceAccount + { + if let Some(sa_path) = std::env::var(sandbox_env::K8S_SA_TOKEN_FILE) + .ok() + .filter(|p| !p.is_empty()) + { + match acquire_k8s_sandbox_token(&endpoint, &plain_channel, &sa_path).await { + Ok(new_token) => { + match AsciiMetadataValue::try_from(format!("Bearer {new_token}")) { + Ok(value) => { + if let Ok(mut guard) = slot.write() { + *guard = value; + info!( + "rebootstrapped gateway sandbox JWT after refresh authentication failure" + ); + continue; + } + } + Err(e) => warn!( + error = %e, + "rebootstrapped JWT contained invalid header bytes" + ), + } + } + Err(e) => warn!( + error = %e, + "K8s ServiceAccount bootstrap retry failed after refresh authentication failure" + ), + } + } else { + warn!( + "RefreshSandboxToken returned Unauthenticated and K8s SA token file is unavailable" + ); + } + } else if status.code() == tonic::Code::Unauthenticated { + warn!( + source = ?source, + "RefreshSandboxToken returned Unauthenticated; static token sources cannot rebootstrap automatically" + ); + } + warn!(error = %status, "RefreshSandboxToken failed; will retry"); + // Backoff so we don't spin against a sustained failure. 
+ tokio::time::sleep(Duration::from_secs(10)).await; + } } - Ok(req) } } -type AuthenticatedClient = OpenShellClient>; -type AuthenticatedInferenceClient = - InferenceClient>; - -fn sandbox_secret_interceptor() -> SandboxSecretInterceptor { - let secret = std::env::var("OPENSHELL_SSH_HANDSHAKE_SECRET") +/// Compute the next refresh delay: 80 % of the time remaining until the +/// current token's `exp`, plus up to 10 % jitter, with a small lower bound +/// for already-expired tokens and capped at 12 h. If the token can't be parsed +/// (legacy/non-JWT bearer) +/// default to 6 h. +fn compute_refresh_delay(slot: &TokenSlot) -> Duration { + let token = slot + .read() .ok() - .and_then(|s| s.parse().ok()); - SandboxSecretInterceptor { secret } + .and_then(|v| v.to_str().ok().map(str::to_string)) + .unwrap_or_default(); + let bearer = token.strip_prefix("Bearer ").unwrap_or(&token); + let now_ms = i64::try_from( + SystemTime::now() + .duration_since(UNIX_EPOCH) + .map_or(0, |d| d.as_millis()), + ) + .unwrap_or(i64::MAX); + let remaining_ms = parse_jwt_exp_ms(bearer).map_or(21_600_000, |exp| exp - now_ms); // 6 h fallback + let mut delay_ms = if remaining_ms <= 0 { + 1_000 + } else { + (remaining_ms * 8 / 10).clamp(1_000, 43_200_000) + }; + // Up to 10 % jitter, derived deterministically from token bytes so + // unit tests are reproducible without injecting an RNG. + let jitter_pct = (token.len() % 10) as u64; + let jitter_ms = (u64::try_from(delay_ms).unwrap_or(0) * jitter_pct) / 100; + delay_ms = delay_ms.saturating_add(i64::try_from(jitter_ms).unwrap_or(0)); + Duration::from_millis(u64::try_from(delay_ms).unwrap_or(0)) +} + +/// Decode the `exp` claim from a JWT without verifying its signature. +/// Returns the expiry in milliseconds since the Unix epoch, or `None` if +/// the token is not a parseable JWT. 
+fn parse_jwt_exp_ms(jwt: &str) -> Option { + use base64::Engine; + let mut parts = jwt.splitn(3, '.'); + let _header = parts.next()?; + let payload_b64 = parts.next()?; + let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD + .decode(payload_b64) + .ok()?; + let value: serde_json::Value = serde_json::from_slice(&decoded).ok()?; + let exp_secs = value.get("exp")?.as_i64()?; + exp_secs.checked_mul(1000) } -/// Connect to the `OpenShell` server with sandbox secret authentication. -async fn connect(endpoint: &str) -> Result { +#[cfg(test)] +mod auth_tests { + use super::*; + + #[test] + fn parse_jwt_exp_reads_unsigned_payload() { + use base64::Engine as _; + let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD + .encode(br#"{"exp":1234567890,"sandbox_id":"sb-1"}"#); + let token = format!("h.{payload}.sig"); + assert_eq!(parse_jwt_exp_ms(&token), Some(1_234_567_890_000)); + } + + #[test] + fn parse_jwt_exp_returns_none_for_malformed_token() { + assert!(parse_jwt_exp_ms("not-a-jwt").is_none()); + assert!(parse_jwt_exp_ms("only.two").is_none()); + assert!(parse_jwt_exp_ms("a.!!!.c").is_none()); + } + + #[test] + fn compute_refresh_delay_uses_80_percent_when_token_present() { + // Build a JWT whose exp is 1000 seconds in the future. With 0-jitter + // the delay should be roughly 800 seconds. + use base64::Engine as _; + let now_s = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + let exp = now_s + 1000; + let payload_json = format!(r#"{{"exp":{exp}}}"#); + let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(payload_json); + let token = format!("h.{payload}.s"); + let bearer = AsciiMetadataValue::try_from(format!("Bearer {token}")).unwrap(); + let slot: TokenSlot = Arc::new(RwLock::new(bearer)); + let delay = compute_refresh_delay(&slot); + // 800 s baseline + up to 10 % jitter → 800..=880 s, with some slack + // for the 1-second resolution of the exp claim. 
+ let secs = delay.as_secs(); + assert!( + (700..=900).contains(&secs), + "expected 80%-of-1000s delay, got {secs}s" + ); + } + + #[test] + fn compute_refresh_delay_uses_short_delay_for_expired_token() { + // Already-expired token still produces a small positive delay so the + // loop doesn't busy-spin. + use base64::Engine as _; + let exp = 1; // past + let payload = + base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(format!(r#"{{"exp":{exp}}}"#)); + let token = format!("h.{payload}.s"); + let bearer = AsciiMetadataValue::try_from(format!("Bearer {token}")).unwrap(); + let slot: TokenSlot = Arc::new(RwLock::new(bearer)); + let delay = compute_refresh_delay(&slot); + assert!((1..60).contains(&delay.as_secs())); + } + + #[test] + fn compute_refresh_delay_supports_short_token_ttl() { + use base64::Engine as _; + let now_s = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + let exp = now_s + 30; + let payload_json = format!(r#"{{"exp":{exp}}}"#); + let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(payload_json); + let token = format!("h.{payload}.s"); + let bearer = AsciiMetadataValue::try_from(format!("Bearer {token}")).unwrap(); + let slot: TokenSlot = Arc::new(RwLock::new(bearer)); + let delay = compute_refresh_delay(&slot); + assert!( + delay.as_secs() < 30, + "expected refresh before 30s expiry, got {delay:?}", + ); + } +} + +/// Connect to the `OpenShell` server. +async fn connect(endpoint: &str) -> Result> { let channel = connect_channel(endpoint).await?; - Ok(OpenShellClient::with_interceptor( - channel, - sandbox_secret_interceptor(), - )) + Ok(OpenShellClient::new(channel)) } -/// Connect to the inference service with sandbox secret authentication. -async fn connect_inference(endpoint: &str) -> Result { +/// Connect to the inference service. 
+async fn connect_inference(endpoint: &str) -> Result> { let channel = connect_channel(endpoint).await?; - Ok(InferenceClient::with_interceptor( - channel, - sandbox_secret_interceptor(), - )) + Ok(InferenceClient::new(channel)) } /// Fetch sandbox policy from `OpenShell` server via gRPC. @@ -151,7 +543,7 @@ pub async fn fetch_policy(endpoint: &str, sandbox_id: &str) -> Result, sandbox_id: &str, ) -> Result> { let response = client @@ -175,7 +567,7 @@ async fn fetch_policy_with_client( /// Sync a locally-discovered policy using an existing client connection. async fn sync_policy_with_client( - client: &mut AuthenticatedClient, + client: &mut OpenShellClient, sandbox: &str, policy: &ProtoSandboxPolicy, ) -> Result<()> { @@ -188,6 +580,7 @@ async fn sync_policy_with_client( delete_setting: false, global: false, merge_operations: vec![], + expected_resource_version: 0, }) .await .into_diagnostic() @@ -244,7 +637,7 @@ pub async fn sync_policy(endpoint: &str, sandbox: &str, policy: &ProtoSandboxPol pub async fn fetch_provider_environment( endpoint: &str, sandbox_id: &str, -) -> Result> { +) -> Result { debug!(endpoint = %endpoint, sandbox_id = %sandbox_id, "Fetching provider environment"); let mut client = connect(endpoint).await?; @@ -256,7 +649,12 @@ pub async fn fetch_provider_environment( .await .into_diagnostic()?; - Ok(response.into_inner().environment) + let inner = response.into_inner(); + Ok(ProviderEnvironmentResult { + environment: inner.environment, + provider_env_revision: inner.provider_env_revision, + credential_expires_at_ms: inner.credential_expires_at_ms, + }) } /// A reusable gRPC client for the `OpenShell` service. @@ -265,7 +663,7 @@ pub async fn fetch_provider_environment( /// and status reporting, avoiding per-request TLS handshake overhead. #[derive(Clone)] pub struct CachedOpenShellClient { - client: AuthenticatedClient, + client: OpenShellClient, } /// Settings poll result returned by [`CachedOpenShellClient::poll_settings`]. 
@@ -279,6 +677,13 @@ pub struct SettingsPollResult { pub settings: HashMap, /// When `policy_source` is `Global`, the version of the global policy revision. pub global_policy_version: u32, + pub provider_env_revision: u64, +} + +pub struct ProviderEnvironmentResult { + pub environment: HashMap, + pub provider_env_revision: u64, + pub credential_expires_at_ms: HashMap, } impl CachedOpenShellClient { @@ -289,7 +694,7 @@ impl CachedOpenShellClient { } /// Get a clone of the underlying tonic client for direct RPC calls. - pub fn raw_client(&self) -> AuthenticatedClient { + pub fn raw_client(&self) -> OpenShellClient { self.client.clone() } @@ -315,18 +720,25 @@ impl CachedOpenShellClient { .unwrap_or(PolicySource::Unspecified), settings: inner.settings, global_policy_version: inner.global_policy_version, + provider_env_revision: inner.provider_env_revision, }) } - /// Submit denial summaries for policy analysis. + /// Submit denial summaries and/or agent-authored proposals for policy analysis. + /// + /// Returns the gateway response so callers can surface accepted/rejected + /// counts, rejection reasons, and server-assigned `accepted_chunk_ids` + /// (e.g., the `policy.local` API forwards these to the in-sandbox agent + /// so it can watch proposal state via `GET /v1/proposals/{id}`). pub async fn submit_policy_analysis( &self, sandbox_name: &str, summaries: Vec, - proposed_chunks: Vec, + proposed_chunks: Vec, analysis_mode: &str, - ) -> Result<()> { - self.client + ) -> Result { + let response = self + .client .clone() .submit_policy_analysis(SubmitPolicyAnalysisRequest { name: sandbox_name.to_string(), @@ -337,7 +749,28 @@ impl CachedOpenShellClient { .await .into_diagnostic()?; - Ok(()) + Ok(response.into_inner()) + } + + /// Fetch the current draft chunks for a sandbox. `status_filter` may be + /// `"pending"`, `"approved"`, `"rejected"`, or empty for all. Used by + /// `policy.local`'s `GET /v1/proposals/{id}` and `/wait` routes to + /// inspect proposal state. 
+ pub async fn get_draft_policy( + &self, + sandbox_name: &str, + status_filter: &str, + ) -> Result> { + let response = self + .client + .clone() + .get_draft_policy(GetDraftPolicyRequest { + name: sandbox_name.to_string(), + status_filter: status_filter.to_string(), + }) + .await + .into_diagnostic()?; + Ok(response.into_inner().chunks) } /// Report policy load status back to the server. @@ -382,32 +815,3 @@ pub async fn fetch_inference_bundle(endpoint: &str) -> Result GraphqlRequestInfo } } +pub fn classify_json_envelope_value(value: &Value) -> GraphqlRequestInfo { + match classify_json_envelope(value) { + Ok(operations) => GraphqlRequestInfo { + operations, + error: None, + }, + Err(err) => GraphqlRequestInfo { + operations: Vec::new(), + error: Some(err), + }, + } +} + fn classify_request_inner( request: &L7Request, body: &[u8], diff --git a/crates/openshell-sandbox/src/l7/mod.rs b/crates/openshell-sandbox/src/l7/mod.rs index 5301ac4d5..703aafae4 100644 --- a/crates/openshell-sandbox/src/l7/mod.rs +++ b/crates/openshell-sandbox/src/l7/mod.rs @@ -15,11 +15,13 @@ pub mod provider; pub mod relay; pub mod rest; pub mod tls; +pub(crate) mod websocket; /// Application-layer protocol for L7 inspection. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum L7Protocol { Rest, + Websocket, Graphql, Sql, } @@ -28,6 +30,7 @@ impl L7Protocol { pub fn parse(s: &str) -> Option { match s.to_ascii_lowercase().as_str() { "rest" => Some(Self::Rest), + "websocket" => Some(Self::Websocket), "graphql" => Some(Self::Graphql), "sql" => Some(Self::Sql), _ => None, @@ -58,6 +61,10 @@ pub enum EnforcementMode { } /// L7 configuration for an endpoint, extracted from policy data. +#[allow( + clippy::struct_excessive_bools, + reason = "Endpoint config mirrors independent policy schema toggles." +)] #[derive(Debug, Clone)] pub struct L7EndpointConfig { pub protocol: L7Protocol, @@ -72,6 +79,15 @@ pub struct L7EndpointConfig { /// rather than rejected at the parser. 
Needed by upstreams like GitLab /// that embed `%2F` in namespaced project paths. Defaults to false. pub allow_encoded_slash: bool, + /// Opt-in rewrite of credential placeholders in client-to-server + /// WebSocket text messages after an allowed HTTP 101 upgrade. + pub websocket_credential_rewrite: bool, + /// Opt-in rewrite of credential placeholders in supported textual REST + /// request bodies before forwarding upstream. + pub request_body_credential_rewrite: bool, + /// When true, client-to-server GraphQL-over-WebSocket operation messages + /// are classified with the same operation policy used by GraphQL-over-HTTP. + pub websocket_graphql_policy: bool, } /// Result of an L7 policy decision for a single request. @@ -138,6 +154,12 @@ pub fn parse_l7_config(val: &regorus::Value) -> Option { }; let allow_encoded_slash = get_object_bool(val, "allow_encoded_slash").unwrap_or(false); + let websocket_credential_rewrite = + get_object_bool(val, "websocket_credential_rewrite").unwrap_or(false); + let request_body_credential_rewrite = + get_object_bool(val, "request_body_credential_rewrite").unwrap_or(false); + let websocket_graphql_policy = + protocol == L7Protocol::Websocket && endpoint_has_graphql_policy(val); let graphql_max_body_bytes = get_object_u64(val, "graphql_max_body_bytes") .and_then(|v| usize::try_from(v).ok()) .filter(|v| *v > 0) @@ -150,6 +172,9 @@ pub fn parse_l7_config(val: &regorus::Value) -> Option { enforcement, graphql_max_body_bytes, allow_encoded_slash, + websocket_credential_rewrite, + request_body_credential_rewrite, + websocket_graphql_policy, }) } @@ -231,6 +256,60 @@ fn get_object_str(val: &regorus::Value, key: &str) -> Option { } } +fn endpoint_has_graphql_policy(val: &regorus::Value) -> bool { + has_non_empty_object_field(val, "graphql_persisted_queries") + || has_graphql_persisted_query_mode(val) + || rules_have_graphql_policy(val, "rules", true) + || rules_have_graphql_policy(val, "deny_rules", false) +} + +fn rules_have_graphql_policy(val:
&regorus::Value, key: &str, allow_wrapped: bool) -> bool { + let Some(regorus::Value::Array(rules)) = get_object_value(val, key) else { + return false; + }; + rules.iter().any(|rule| { + let rule = if allow_wrapped { + get_object_value(rule, "allow").unwrap_or(rule) + } else { + rule + }; + has_graphql_rule_fields(rule) + }) +} + +fn has_graphql_rule_fields(val: &regorus::Value) -> bool { + has_non_empty_string_field(val, "operation_type") + || has_non_empty_string_field(val, "operation_name") + || has_non_empty_array_field(val, "fields") +} + +fn has_non_empty_string_field(val: &regorus::Value, key: &str) -> bool { + matches!(get_object_value(val, key), Some(regorus::Value::String(s)) if !s.is_empty()) +} + +fn has_non_empty_array_field(val: &regorus::Value, key: &str) -> bool { + matches!(get_object_value(val, key), Some(regorus::Value::Array(values)) if !values.is_empty()) +} + +fn has_non_empty_object_field(val: &regorus::Value, key: &str) -> bool { + matches!(get_object_value(val, key), Some(regorus::Value::Object(values)) if !values.is_empty()) +} + +fn has_graphql_persisted_query_mode(val: &regorus::Value) -> bool { + matches!( + get_object_value(val, "persisted_queries"), + Some(regorus::Value::String(mode)) if !mode.is_empty() && mode.as_ref() != "deny" + ) +} + +fn get_object_value<'a>(val: &'a regorus::Value, key: &str) -> Option<&'a regorus::Value> { + let key_val = regorus::Value::String(key.into()); + match val { + regorus::Value::Object(map) => map.get(&key_val), + _ => None, + } +} + /// Check a glob pattern for obvious syntax issues. /// /// Returns `Some(warning_message)` if the pattern looks malformed.
@@ -274,6 +353,43 @@ fn check_glob_syntax(pattern: &str) -> Option { None } +fn validate_host_wildcard(errors: &mut Vec, loc: &str, host: &str) { + if !host.contains('*') { + return; + } + + if host == "*" || host == "**" { + errors.push(format!( + "{loc}: host wildcard '{host}' matches all hosts; use specific patterns like '*.example.com'" + )); + return; + } + + let labels: Vec<&str> = host.split('.').collect(); + let first_label = labels.first().copied().unwrap_or_default(); + if labels.iter().skip(1).any(|label| label.contains('*')) { + errors.push(format!( + "{loc}: host wildcard may only appear in the first DNS label, got '{host}'" + )); + return; + } + if first_label.contains("**") && first_label != "**" { + errors.push(format!( + "{loc}: recursive host wildcard '**' is only allowed as the entire first DNS label, got '{host}'" + )); + return; + } + + // Reject TLD or single-label wildcards. They are accepted by the policy + // engine but silently fail at the proxy layer (see #787). 
+ if labels.len() <= 2 { + errors.push(format!( + "{loc}: TLD wildcard '{host}' is not allowed; \ + use subdomain wildcards like '*.example.com' instead" + )); + } +} + fn validate_graphql_operation_type( errors: &mut Vec, loc: &str, @@ -353,6 +469,45 @@ fn validate_graphql_rule( validate_graphql_fields(errors, warnings, loc, rule.get("fields")); } +fn json_rule_has_graphql_fields(rule: &serde_json::Value) -> bool { + rule.get("operation_type") + .and_then(|v| v.as_str()) + .is_some_and(|v| !v.is_empty()) + || rule + .get("operation_name") + .and_then(|v| v.as_str()) + .is_some_and(|v| !v.is_empty()) + || rule.get("fields").is_some() +} + +fn json_rule_has_transport_fields(rule: &serde_json::Value) -> bool { + rule.get("method").is_some() || rule.get("path").is_some() || rule.get("query").is_some() +} + +fn json_endpoint_has_graphql_policy(ep: &serde_json::Value) -> bool { + ep.get("graphql_persisted_queries") + .and_then(|v| v.as_object()) + .is_some_and(|v| !v.is_empty()) + || ep + .get("persisted_queries") + .and_then(|v| v.as_str()) + .is_some_and(|v| !v.is_empty() && v != "deny") + || ep + .get("rules") + .and_then(|v| v.as_array()) + .is_some_and(|rules| { + rules.iter().any(|rule| { + rule.get("allow") + .or(Some(rule)) + .is_some_and(json_rule_has_graphql_fields) + }) + }) + || ep + .get("deny_rules") + .and_then(|v| v.as_array()) + .is_some_and(|rules| rules.iter().any(json_rule_has_graphql_fields)) +} + /// Validate L7 policy configuration in the loaded OPA data. /// /// Returns a list of errors and warnings. 
Errors should prevent sandbox startup; @@ -382,6 +537,8 @@ pub fn validate_l7_policies(data_json: &serde_json::Value) -> (Vec, Vec< .get("rules") .and_then(|v| v.as_array()) .is_some_and(|a| !a.is_empty()); + let websocket_has_graphql_policy = + protocol == "websocket" && json_endpoint_has_graphql_policy(ep); let host = ep.get("host").and_then(|v| v.as_str()).unwrap_or(""); let endpoint_path = ep.get("path").and_then(|v| v.as_str()).unwrap_or(""); @@ -409,29 +566,7 @@ pub fn validate_l7_policies(data_json: &serde_json::Value) -> (Vec, Vec< } } - // Validate host wildcard patterns. - if host.contains('*') { - if host == "*" || host == "**" { - errors.push(format!( - "{loc}: host wildcard '{host}' matches all hosts; use specific patterns like '*.example.com'" - )); - } else if !host.starts_with("*.") && !host.starts_with("**.") { - errors.push(format!( - "{loc}: host wildcard must start with '*.' or '**.' (e.g., '*.example.com'), got '{host}'" - )); - } else { - // Reject TLD wildcards like *.com (2 labels) — they are - // accepted by the policy engine but silently fail at the - // proxy layer (see #787). 
- let label_count = host.split('.').count(); - if label_count <= 2 { - errors.push(format!( - "{loc}: TLD wildcard '{host}' is not allowed; \ - use subdomain wildcards like '*.example.com' instead" - )); - } - } - } + validate_host_wildcard(&mut errors, &loc, host); // port + ports mutual exclusion let has_scalar_port = ep @@ -462,7 +597,7 @@ pub fn validate_l7_policies(data_json: &serde_json::Value) -> (Vec, Vec< if !protocol.is_empty() && L7Protocol::parse(protocol).is_none() { errors.push(format!( - "{loc}: unknown protocol '{protocol}' (expected rest, graphql, or sql)" + "{loc}: unknown protocol '{protocol}' (expected rest, websocket, graphql, or sql)" )); } @@ -489,12 +624,36 @@ pub fn validate_l7_policies(data_json: &serde_json::Value) -> (Vec, Vec< } if protocol != "graphql" + && protocol != "websocket" && (ep.get("persisted_queries").is_some() || ep.get("graphql_persisted_queries").is_some() || ep.get("graphql_max_body_bytes").is_some()) { warnings.push(format!( - "{loc}: GraphQL-specific endpoint fields are ignored unless protocol is graphql" + "{loc}: GraphQL-specific endpoint fields are ignored unless protocol is graphql or websocket" + )); + } + + if ep + .get("websocket_credential_rewrite") + .and_then(serde_json::Value::as_bool) + .unwrap_or(false) + && protocol != "rest" + && protocol != "websocket" + { + warnings.push(format!( + "{loc}: websocket_credential_rewrite is ignored unless protocol is rest or websocket" + )); + } + + if ep + .get("request_body_credential_rewrite") + .and_then(serde_json::Value::as_bool) + .unwrap_or(false) + && protocol != "rest" + { + warnings.push(format!( + "{loc}: request_body_credential_rewrite is ignored unless protocol is rest" )); } @@ -574,14 +733,13 @@ pub fn validate_l7_policies(data_json: &serde_json::Value) -> (Vec, Vec< // Validate method if let Some(method) = deny_rule.get("method").and_then(|m| m.as_str()) && !method.is_empty() - && protocol == "rest" + && (protocol == "rest" || protocol == "websocket") { - 
let valid_methods = [ - "GET", "HEAD", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "*", - ]; + let valid_methods = valid_methods_for_protocol(protocol); if !valid_methods.contains(&method.to_ascii_uppercase().as_str()) { warnings.push(format!( - "{deny_loc}: Unknown HTTP method '{method}'. Standard methods: GET, HEAD, POST, PUT, DELETE, PATCH, OPTIONS." + "{deny_loc}: Unknown HTTP/WebSocket method '{method}'. Standard methods: {}." + , valid_methods.join(", ") )); } } @@ -701,7 +859,17 @@ pub fn validate_l7_policies(data_json: &serde_json::Value) -> (Vec, Vec< .push(format!("{deny_loc}: command is for SQL protocol, not REST")); } - if protocol == "graphql" { + let deny_has_graphql = json_rule_has_graphql_fields(deny_rule); + if protocol == "websocket" + && deny_has_graphql + && json_rule_has_transport_fields(deny_rule) + { + errors.push(format!( + "{deny_loc}: WebSocket GraphQL deny rules must not combine method/path/query with operation_type/operation_name/fields" + )); + } + + if protocol == "graphql" || (protocol == "websocket" && deny_has_graphql) { validate_graphql_rule( &mut errors, &mut warnings, @@ -709,12 +877,9 @@ pub fn validate_l7_policies(data_json: &serde_json::Value) -> (Vec, Vec< deny_rule, true, ); - } else if deny_rule.get("operation_type").is_some() - || deny_rule.get("operation_name").is_some() - || deny_rule.get("fields").is_some() - { + } else if deny_has_graphql { warnings.push(format!( - "{deny_loc}: GraphQL rule fields are ignored unless protocol is graphql" + "{deny_loc}: GraphQL rule fields are ignored unless protocol is graphql or websocket" )); } } @@ -733,10 +898,8 @@ pub fn validate_l7_policies(data_json: &serde_json::Value) -> (Vec, Vec< } // Validate HTTP methods in rules - if has_rules && protocol == "rest" { - let valid_methods = [ - "GET", "HEAD", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "*", - ]; + if has_rules && (protocol == "rest" || protocol == "websocket") { + let valid_methods = valid_methods_for_protocol(protocol); 
if let Some(rules) = ep.get("rules").and_then(|v| v.as_array()) { for (rule_idx, rule) in rules.iter().enumerate() { if let Some(method) = rule @@ -747,7 +910,8 @@ pub fn validate_l7_policies(data_json: &serde_json::Value) -> (Vec, Vec< && !valid_methods.contains(&method.to_ascii_uppercase().as_str()) { warnings.push(format!( - "{loc}: Unknown HTTP method '{method}'. Standard methods: GET, HEAD, POST, PUT, DELETE, PATCH, OPTIONS." + "{loc}: Unknown HTTP/WebSocket method '{method}'. Standard methods: {}." + , valid_methods.join(", ") )); } @@ -858,14 +1022,36 @@ pub fn validate_l7_policies(data_json: &serde_json::Value) -> (Vec, Vec< } } - if has_rules - && protocol == "graphql" - && let Some(rules) = ep.get("rules").and_then(|v| v.as_array()) - { + if has_rules && let Some(rules) = ep.get("rules").and_then(|v| v.as_array()) { for (rule_idx, rule) in rules.iter().enumerate() { let allow = rule.get("allow").unwrap_or(rule); let rule_loc = format!("{loc}.rules[{rule_idx}].allow"); - validate_graphql_rule(&mut errors, &mut warnings, &rule_loc, allow, true); + let allow_has_graphql = json_rule_has_graphql_fields(allow); + if websocket_has_graphql_policy + && allow + .get("method") + .and_then(|m| m.as_str()) + .is_some_and(|method| method.eq_ignore_ascii_case("WEBSOCKET_TEXT")) + { + errors.push(format!( + "{rule_loc}: WebSocket endpoints with GraphQL operation policy must use operation_type/operation_name/fields rules for client messages instead of WEBSOCKET_TEXT" + )); + } + if protocol == "websocket" + && allow_has_graphql + && json_rule_has_transport_fields(allow) + { + errors.push(format!( + "{rule_loc}: WebSocket GraphQL allow rules must not combine method/path/query with operation_type/operation_name/fields" + )); + } + if protocol == "graphql" || (protocol == "websocket" && allow_has_graphql) { + validate_graphql_rule(&mut errors, &mut warnings, &rule_loc, allow, true); + } else if allow_has_graphql { + warnings.push(format!( + "{rule_loc}: GraphQL rule fields 
are ignored unless protocol is graphql or websocket" + )); + } } } } @@ -921,6 +1107,13 @@ pub fn expand_access_presets(data: &mut serde_json::Value) { "full" => vec![graphql_rule_json("*")], _ => continue, } + } else if protocol == "websocket" { + match access.as_str() { + "read-only" => vec![rule_json("GET", "**")], + "read-write" => vec![rule_json("GET", "**"), rule_json("WEBSOCKET_TEXT", "**")], + "full" => vec![rule_json("*", "**")], + _ => continue, + } } else { match access.as_str() { "read-only" => vec![ @@ -957,6 +1150,15 @@ fn rule_json(method: &str, path: &str) -> serde_json::Value { }) } +fn valid_methods_for_protocol(protocol: &str) -> &'static [&'static str] { + match protocol { + "websocket" => &["GET", "WEBSOCKET_TEXT", "*"], + _ => &[ + "GET", "HEAD", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "*", + ], + } +} + fn graphql_rule_json(operation_type: &str) -> serde_json::Value { serde_json::json!({ "allow": { @@ -994,6 +1196,16 @@ mod tests { assert_eq!(config.enforcement, EnforcementMode::Audit); } + #[test] + fn parse_l7_config_websocket_protocol() { + let val = regorus::Value::from_json_str( + r#"{"protocol": "websocket", "host": "gateway.example.com", "port": 443}"#, + ) + .unwrap(); + let config = parse_l7_config(&val).unwrap(); + assert_eq!(config.protocol, L7Protocol::Websocket); + } + #[test] fn parse_l7_config_skip() { let val = regorus::Value::from_json_str( @@ -1031,6 +1243,242 @@ mod tests { assert!(config.allow_encoded_slash); } + #[test] + fn parse_l7_config_websocket_credential_rewrite_defaults_false() { + let val = regorus::Value::from_json_str( + r#"{"protocol": "rest", "host": "gateway.example.com", "port": 443}"#, + ) + .unwrap(); + let config = parse_l7_config(&val).unwrap(); + assert!(!config.websocket_credential_rewrite); + } + + #[test] + fn parse_l7_config_websocket_credential_rewrite_opt_in() { + let val = regorus::Value::from_json_str( + r#"{"protocol": "rest", "host": "gateway.example.com", "port": 443, 
"websocket_credential_rewrite": true}"#, + ) + .unwrap(); + let config = parse_l7_config(&val).unwrap(); + assert!(config.websocket_credential_rewrite); + } + + #[test] + fn parse_l7_config_request_body_credential_rewrite_defaults_false() { + let val = regorus::Value::from_json_str( + r#"{"protocol": "rest", "host": "slack.com", "port": 443}"#, + ) + .unwrap(); + let config = parse_l7_config(&val).unwrap(); + assert!(!config.request_body_credential_rewrite); + } + + #[test] + fn parse_l7_config_request_body_credential_rewrite_opt_in() { + let val = regorus::Value::from_json_str( + r#"{"protocol": "rest", "host": "slack.com", "port": 443, "request_body_credential_rewrite": true}"#, + ) + .unwrap(); + let config = parse_l7_config(&val).unwrap(); + assert!(config.request_body_credential_rewrite); + } + + #[test] + fn parse_l7_config_websocket_graphql_policy_defaults_false() { + let val = regorus::Value::from_json_str( + r#"{"protocol": "websocket", "host": "gateway.example.com", "port": 443, "rules": [{"allow": {"method": "GET", "path": "/graphql"}}, {"allow": {"method": "WEBSOCKET_TEXT", "path": "/graphql"}}]}"#, + ) + .unwrap(); + let config = parse_l7_config(&val).unwrap(); + assert!(!config.websocket_graphql_policy); + } + + #[test] + fn parse_l7_config_websocket_graphql_policy_detects_operation_rules() { + let val = regorus::Value::from_json_str( + r#"{"protocol": "websocket", "host": "gateway.example.com", "port": 443, "rules": [{"allow": {"method": "GET", "path": "/graphql"}}, {"allow": {"operation_type": "subscription", "fields": ["messageAdded"]}}]}"#, + ) + .unwrap(); + let config = parse_l7_config(&val).unwrap(); + assert!(config.websocket_graphql_policy); + } + + #[test] + fn validate_websocket_credential_rewrite_warns_unless_rest_or_websocket() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "gateway.example.com", + "port": 443, + "websocket_credential_rewrite": true + }], + "binaries": [] + } + } + }); + 
let (_errors, warnings) = validate_l7_policies(&data); + assert!( + warnings + .iter() + .any(|w| w.contains("websocket_credential_rewrite is ignored")), + "expected websocket_credential_rewrite warning: {warnings:?}" + ); + } + + #[test] + fn validate_request_body_credential_rewrite_warns_unless_rest() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "gateway.example.com", + "port": 443, + "protocol": "websocket", + "request_body_credential_rewrite": true + }], + "binaries": [] + } + } + }); + let (_errors, warnings) = validate_l7_policies(&data); + assert!( + warnings + .iter() + .any(|w| w.contains("request_body_credential_rewrite is ignored")), + "expected request_body_credential_rewrite warning: {warnings:?}" + ); + } + + #[test] + fn expand_websocket_read_write_access_includes_text_messages() { + let mut data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "gateway.example.com", + "port": 443, + "protocol": "websocket", + "access": "read-write" + }], + "binaries": [] + } + } + }); + + expand_access_presets(&mut data); + let rules = data["network_policies"]["test"]["endpoints"][0]["rules"] + .as_array() + .unwrap(); + let methods: Vec<&str> = rules + .iter() + .map(|r| r["allow"]["method"].as_str().unwrap()) + .collect(); + assert!(methods.contains(&"GET")); + assert!(methods.contains(&"WEBSOCKET_TEXT")); + } + + #[test] + fn validate_websocket_accepts_graphql_operation_rules() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "gateway.example.com", + "port": 443, + "protocol": "websocket", + "rules": [ + {"allow": {"method": "GET", "path": "/graphql"}}, + {"allow": {"operation_type": "subscription", "fields": ["messageAdded"]}} + ] + }], + "binaries": [] + } + } + }); + let (errors, warnings) = validate_l7_policies(&data); + assert!(errors.is_empty(), "expected no errors: {errors:?}"); + assert!(warnings.is_empty(), 
"expected no warnings: {warnings:?}"); + } + + #[test] + fn validate_websocket_graphql_rule_requires_operation_type() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "gateway.example.com", + "port": 443, + "protocol": "websocket", + "rules": [ + {"allow": {"method": "GET", "path": "/graphql"}}, + {"allow": {"fields": ["messageAdded"]}} + ] + }], + "binaries": [] + } + } + }); + let (errors, _warnings) = validate_l7_policies(&data); + assert!( + errors.iter().any(|e| e.contains("operation_type")), + "expected missing operation_type error: {errors:?}" + ); + } + + #[test] + fn validate_websocket_graphql_rule_rejects_mixed_transport_fields() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "gateway.example.com", + "port": 443, + "protocol": "websocket", + "rules": [ + {"allow": {"method": "GET", "path": "/graphql"}}, + {"allow": {"method": "WEBSOCKET_TEXT", "path": "/graphql", "operation_type": "subscription"}} + ] + }], + "binaries": [] + } + } + }); + let (errors, _warnings) = validate_l7_policies(&data); + assert!( + errors.iter().any(|e| e.contains("must not combine")), + "expected mixed-field error: {errors:?}" + ); + } + + #[test] + fn validate_websocket_graphql_policy_rejects_raw_text_message_rule() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "gateway.example.com", + "port": 443, + "protocol": "websocket", + "rules": [ + {"allow": {"method": "GET", "path": "/graphql"}}, + {"allow": {"method": "WEBSOCKET_TEXT", "path": "/graphql"}}, + {"allow": {"operation_type": "query"}} + ] + }], + "binaries": [] + } + } + }); + let (errors, _warnings) = validate_l7_policies(&data); + assert!( + errors + .iter() + .any(|e| e.contains("instead of WEBSOCKET_TEXT")), + "expected raw WEBSOCKET_TEXT rejection: {errors:?}" + ); + } + #[test] fn validate_rules_and_access_mutual_exclusion() { let data = serde_json::json!({ @@ 
-1360,7 +1808,27 @@ mod tests { } #[test] - fn validate_wildcard_host_no_star_dot_error() { + fn validate_wildcard_host_mid_label_error() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "foo.*.example.com", + "port": 443 + }], + "binaries": [] + } + } + }); + let (errors, _warnings) = validate_l7_policies(&data); + assert!( + errors.iter().any(|e| e.contains("first DNS label")), + "Mid-label wildcard should be rejected, got errors: {errors:?}" + ); + } + + #[test] + fn validate_wildcard_host_single_label_error() { let data = serde_json::json!({ "network_policies": { "test": { @@ -1374,8 +1842,28 @@ mod tests { }); let (errors, _warnings) = validate_l7_policies(&data); assert!( - errors.iter().any(|e| e.contains("must start with")), - "Malformed wildcard should be rejected, got errors: {errors:?}" + errors.iter().any(|e| e.contains("TLD wildcard")), + "Single-label wildcard should be rejected, got errors: {errors:?}" + ); + } + + #[test] + fn validate_wildcard_host_recursive_intra_label_error() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "foo**.example.com", + "port": 443 + }], + "binaries": [] + } + } + }); + let (errors, _warnings) = validate_l7_policies(&data); + assert!( + errors.iter().any(|e| e.contains("recursive host wildcard")), + "Recursive intra-label wildcard should be rejected, got errors: {errors:?}" ); } @@ -1443,6 +1931,54 @@ mod tests { ); } + #[test] + fn validate_wildcard_host_double_star_valid_no_error() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "**.example.com", + "port": 443 + }], + "binaries": [] + } + } + }); + let (errors, warnings) = validate_l7_policies(&data); + assert!( + errors.is_empty(), + "**.example.com should be valid, got errors: {errors:?}" + ); + assert!( + warnings.is_empty(), + "**.example.com should not warn, got warnings: {warnings:?}" + ); + } + + #[test] + fn 
validate_wildcard_host_intra_label_valid_no_error() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "*-aiplatform.googleapis.com", + "port": 443 + }], + "binaries": [] + } + } + }); + let (errors, warnings) = validate_l7_policies(&data); + assert!( + errors.is_empty(), + "*-aiplatform.googleapis.com should be valid, got errors: {errors:?}" + ); + assert!( + warnings.is_empty(), + "*-aiplatform.googleapis.com should not warn, got warnings: {warnings:?}" + ); + } + #[test] fn validate_port_and_ports_mutually_exclusive() { let data = serde_json::json!({ diff --git a/crates/openshell-sandbox/src/l7/provider.rs b/crates/openshell-sandbox/src/l7/provider.rs index 7516aa85c..864d94ad2 100644 --- a/crates/openshell-sandbox/src/l7/provider.rs +++ b/crates/openshell-sandbox/src/l7/provider.rs @@ -27,7 +27,10 @@ pub enum RelayOutcome { /// Contains any overflow bytes read from upstream past the 101 response /// headers that belong to the upgraded protocol. The 101 headers /// themselves have already been forwarded to the client. - Upgraded { overflow: Vec }, + Upgraded { + overflow: Vec, + websocket_permessage_deflate: bool, + }, } /// Body framing for HTTP requests/responses. diff --git a/crates/openshell-sandbox/src/l7/relay.rs b/crates/openshell-sandbox/src/l7/relay.rs index d0599ea99..6d271af21 100644 --- a/crates/openshell-sandbox/src/l7/relay.rs +++ b/crates/openshell-sandbox/src/l7/relay.rs @@ -8,6 +8,7 @@ //! and either forwards or denies the request. 
use crate::l7::provider::{L7Provider, RelayOutcome}; +use crate::l7::rest::WebSocketExtensionMode; use crate::l7::{EnforcementMode, L7EndpointConfig, L7Protocol, L7RequestInfo}; use crate::opa::{PolicyGenerationGuard, TunnelPolicyEngine}; use crate::secrets::{self, SecretResolver}; @@ -38,6 +39,44 @@ pub struct L7EvalContext { pub(crate) secret_resolver: Option>, } +#[derive(Default)] +pub(crate) struct UpgradeRelayOptions<'a> { + pub(crate) websocket_request: bool, + pub(crate) websocket: WebSocketUpgradeBehavior, + pub(crate) secret_resolver: Option>, + pub(crate) engine: Option<&'a TunnelPolicyEngine>, + pub(crate) ctx: Option<&'a L7EvalContext>, + pub(crate) enforcement: EnforcementMode, + pub(crate) target: String, + pub(crate) query_params: std::collections::HashMap>, + pub(crate) policy_name: String, +} + +#[derive(Default)] +pub(crate) struct WebSocketUpgradeBehavior { + pub(crate) credential_rewrite: bool, + pub(crate) message_policy: WebSocketMessagePolicy, + pub(crate) permessage_deflate: bool, +} + +#[derive(Clone, Copy, Default, PartialEq, Eq)] +pub(crate) enum WebSocketMessagePolicy { + #[default] + None, + Transport, + Graphql, +} + +impl WebSocketMessagePolicy { + fn inspects_messages(self) -> bool { + self != Self::None + } + + fn is_graphql(self) -> bool { + self == Self::Graphql + } +} + #[derive(Debug, Clone, Copy)] enum ParseRejectionMode { L7Endpoint, @@ -101,7 +140,9 @@ where U: AsyncRead + AsyncWrite + Unpin + Send, { match config.protocol { - L7Protocol::Rest => relay_rest(config, &engine, client, upstream, ctx).await, + L7Protocol::Rest | L7Protocol::Websocket => { + relay_rest(config, &engine, client, upstream, ctx).await + } L7Protocol::Graphql => relay_graphql(config, &engine, client, upstream, ctx).await, L7Protocol::Sql => { if close_if_stale(engine.generation_guard(), ctx) { @@ -242,6 +283,24 @@ where query_params: req.query_params.clone(), graphql: graphql_info.clone(), }; + let websocket_request = 
crate::l7::rest::request_is_websocket_upgrade(&req.raw_header); + if config.protocol == L7Protocol::Websocket && !websocket_request { + crate::l7::rest::RestProvider::default() + .deny_with_redacted_target( + &req, + &ctx.policy_name, + "websocket endpoint requires a valid WebSocket upgrade request", + client, + Some(&redacted_target), + Some(crate::l7::rest::DenyResponseContext { + host: Some(&ctx.host), + port: Some(ctx.port), + binary: Some(&ctx.binary_path), + }), + ) + .await?; + return Ok(()); + } let parse_error_reason = graphql_info .as_ref() @@ -264,10 +323,10 @@ where (false, EnforcementMode::Audit) => "audit", (false, EnforcementMode::Enforce) => "deny", }; - let engine_type = if config.protocol == L7Protocol::Graphql { - "l7-graphql" - } else { - "l7" + let engine_type = match config.protocol { + L7Protocol::Graphql => "l7-graphql", + L7Protocol::Websocket => "l7-websocket", + L7Protocol::Rest | L7Protocol::Sql => "l7", }; emit_l7_request_log( ctx, @@ -282,19 +341,39 @@ where let _ = &eval_target; if allowed || (config.enforcement == EnforcementMode::Audit && !force_deny) { - let outcome = crate::l7::rest::relay_http_request_with_resolver_guarded( + let outcome = crate::l7::rest::relay_http_request_with_options_guarded( &req, client, upstream, - ctx.secret_resolver.as_deref(), - Some(engine.generation_guard()), + crate::l7::rest::RelayRequestOptions { + resolver: ctx.secret_resolver.as_deref(), + generation_guard: Some(engine.generation_guard()), + websocket_extensions: websocket_extension_mode(config), + request_body_credential_rewrite: config.protocol == L7Protocol::Rest + && config.request_body_credential_rewrite, + }, ) .await?; match outcome { RelayOutcome::Reusable => {} RelayOutcome::Consumed => return Ok(()), - RelayOutcome::Upgraded { overflow } => { - return handle_upgrade(client, upstream, overflow, &ctx.host, ctx.port).await; + RelayOutcome::Upgraded { + overflow, + websocket_permessage_deflate, + } => { + let mut options = upgrade_options( 
+ config, + ctx, + websocket_request, + &redacted_target, + &req.query_params, + Some(&engine), + ); + options.websocket.permessage_deflate = websocket_permessage_deflate; + return handle_upgrade( + client, upstream, overflow, &ctx.host, ctx.port, options, + ) + .await; } } } else { @@ -305,6 +384,11 @@ where &reason, client, Some(&redacted_target), + Some(crate::l7::rest::DenyResponseContext { + host: Some(&ctx.host), + port: Some(ctx.port), + binary: Some(&ctx.binary_path), + }), ) .await?; return Ok(()); @@ -369,20 +453,29 @@ fn emit_l7_request_log( /// Handle an upgraded connection (101 Switching Protocols). /// /// Forwards any overflow bytes from the upgrade response to the client, then -/// switches to raw bidirectional TCP copy for the upgraded protocol (WebSocket, -/// HTTP/2, etc.). L7 policy enforcement does not apply after the upgrade — -/// the initial HTTP request was already evaluated. +/// either switches to a parsed WebSocket relay for opted-in message policy / +/// credential rewriting or to raw bidirectional TCP copy for other upgrades. 
pub(crate) async fn handle_upgrade( client: &mut C, upstream: &mut U, overflow: Vec, host: &str, port: u16, + options: UpgradeRelayOptions<'_>, ) -> Result<()> where C: AsyncRead + AsyncWrite + Unpin + Send, U: AsyncRead + AsyncWrite + Unpin + Send, { + let use_websocket_relay = options.websocket_request + && (options.websocket.message_policy.inspects_messages() + || options.websocket.permessage_deflate + || (options.websocket.credential_rewrite && options.secret_resolver.is_some())); + let relay_mode = if use_websocket_relay { + "websocket parsed relay" + } else { + "raw bidirectional relay (L7 enforcement no longer active)" + }; ocsf_emit!( NetworkActivityBuilder::new(crate::ocsf_ctx()) .activity(ActivityId::Other) @@ -390,12 +483,56 @@ where .severity(SeverityId::Informational) .dst_endpoint(Endpoint::from_domain(host, port)) .message(format!( - "101 Switching Protocols — raw bidirectional relay (L7 enforcement no longer active) \ - [host:{host} port:{port} overflow_bytes:{}]", + "101 Switching Protocols — {relay_mode} [host:{host} port:{port} overflow_bytes:{}]", overflow.len() )) .build() ); + if use_websocket_relay { + let resolver = if options.websocket.credential_rewrite { + options.secret_resolver.as_deref() + } else { + None + }; + let inspector = if options.websocket.message_policy.inspects_messages() { + match (options.engine, options.ctx) { + (Some(engine), Some(ctx)) => Some(crate::l7::websocket::InspectionOptions { + engine, + ctx, + enforcement: options.enforcement, + target: options.target.clone(), + query_params: options.query_params.clone(), + graphql_policy: options.websocket.message_policy.is_graphql(), + }), + _ => { + return Err(miette!( + "websocket message inspection missing policy context" + )); + } + } + } else { + None + }; + let compression = if options.websocket.permessage_deflate { + crate::l7::websocket::WebSocketCompression::PermessageDeflate + } else { + crate::l7::websocket::WebSocketCompression::None + }; + return 
crate::l7::websocket::relay_with_options( + client, + upstream, + overflow, + host, + port, + crate::l7::websocket::RelayOptions { + policy_name: &options.policy_name, + resolver, + inspector, + compression, + }, + ) + .await; + } if !overflow.is_empty() { client.write_all(&overflow).await.into_diagnostic()?; client.flush().await.into_diagnostic()?; @@ -406,6 +543,57 @@ where Ok(()) } +pub(crate) fn upgrade_options<'a>( + config: &L7EndpointConfig, + ctx: &'a L7EvalContext, + websocket_request: bool, + target: &str, + query_params: &std::collections::HashMap>, + engine: Option<&'a TunnelPolicyEngine>, +) -> UpgradeRelayOptions<'a> { + let websocket_credential_rewrite = + matches!(config.protocol, L7Protocol::Rest | L7Protocol::Websocket) + && config.websocket_credential_rewrite; + let websocket_message_policy = if config.protocol == L7Protocol::Websocket { + if config.websocket_graphql_policy { + WebSocketMessagePolicy::Graphql + } else { + WebSocketMessagePolicy::Transport + } + } else { + WebSocketMessagePolicy::None + }; + UpgradeRelayOptions { + websocket_request, + websocket: WebSocketUpgradeBehavior { + credential_rewrite: websocket_credential_rewrite, + message_policy: websocket_message_policy, + permessage_deflate: false, + }, + secret_resolver: if websocket_credential_rewrite { + ctx.secret_resolver.clone() + } else { + None + }, + engine, + ctx: engine.map(|_| ctx), + enforcement: config.enforcement, + target: target.to_string(), + query_params: query_params.clone(), + policy_name: ctx.policy_name.clone(), + } +} + +pub(crate) fn websocket_extension_mode(config: &L7EndpointConfig) -> WebSocketExtensionMode { + if config.protocol == L7Protocol::Websocket + || (config.protocol == L7Protocol::Rest && config.websocket_credential_rewrite) + { + WebSocketExtensionMode::PermessageDeflate + } else { + WebSocketExtensionMode::Preserve + } +} + /// REST relay loop: parse request -> evaluate -> allow/deny -> relay response -> repeat. 
async fn relay_rest( config: &L7EndpointConfig, @@ -485,6 +673,24 @@ where query_params: req.query_params.clone(), graphql: None, }; + let websocket_request = crate::l7::rest::request_is_websocket_upgrade(&req.raw_header); + if config.protocol == L7Protocol::Websocket && !websocket_request { + provider + .deny_with_redacted_target( + &req, + &ctx.policy_name, + "websocket endpoint requires a valid WebSocket upgrade request", + client, + Some(&redacted_target), + Some(crate::l7::rest::DenyResponseContext { + host: Some(&ctx.host), + port: Some(ctx.port), + binary: Some(&ctx.binary_path), + }), + ) + .await?; + return Ok(()); + } // Evaluate L7 policy via Rego (using redacted target) let (allowed, reason) = evaluate_l7_request(engine, ctx, &request_info)?; @@ -553,12 +759,17 @@ where if allowed || config.enforcement == EnforcementMode::Audit { // Forward request to upstream and relay response - let outcome = crate::l7::rest::relay_http_request_with_resolver_guarded( + let outcome = crate::l7::rest::relay_http_request_with_options_guarded( &req, client, upstream, - ctx.secret_resolver.as_deref(), - Some(engine.generation_guard()), + crate::l7::rest::RelayRequestOptions { + resolver: ctx.secret_resolver.as_deref(), + generation_guard: Some(engine.generation_guard()), + websocket_extensions: websocket_extension_mode(config), + request_body_credential_rewrite: config.protocol == L7Protocol::Rest + && config.request_body_credential_rewrite, + }, ) .await?; match outcome { @@ -571,8 +782,23 @@ where ); return Ok(()); } - RelayOutcome::Upgraded { overflow } => { - return handle_upgrade(client, upstream, overflow, &ctx.host, ctx.port).await; + RelayOutcome::Upgraded { + overflow, + websocket_permessage_deflate, + } => { + let mut options = upgrade_options( + config, + ctx, + websocket_request, + &redacted_target, + &req.query_params, + Some(engine), + ); + options.websocket.permessage_deflate = websocket_permessage_deflate; + return handle_upgrade( + client, upstream, 
overflow, &ctx.host, ctx.port, options, + ) + .await; } } } else { @@ -584,6 +810,11 @@ where &reason, client, Some(&redacted_target), + Some(crate::l7::rest::DenyResponseContext { + host: Some(&ctx.host), + port: Some(ctx.port), + binary: Some(&ctx.binary_path), + }), ) .await?; return Ok(()); @@ -777,8 +1008,21 @@ where ); return Ok(()); } - RelayOutcome::Upgraded { overflow } => { - return handle_upgrade(client, upstream, overflow, &ctx.host, ctx.port).await; + RelayOutcome::Upgraded { + overflow, + websocket_permessage_deflate, + } => { + let options = UpgradeRelayOptions { + websocket: WebSocketUpgradeBehavior { + permessage_deflate: websocket_permessage_deflate, + ..Default::default() + }, + ..Default::default() + }; + return handle_upgrade( + client, upstream, overflow, &ctx.host, ctx.port, options, + ) + .await; } } } else { @@ -789,6 +1033,11 @@ where &reason, client, Some(&redacted_target), + Some(crate::l7::rest::DenyResponseContext { + host: Some(&ctx.host), + port: Some(ctx.port), + binary: Some(&ctx.binary_path), + }), ) .await?; return Ok(()); @@ -1001,20 +1250,31 @@ where // Forward request with credential rewriting and relay the response. // relay_http_request_with_resolver handles both directions: it sends // the request upstream and reads the response back to the client. - let outcome = crate::l7::rest::relay_http_request_with_resolver_guarded( + let outcome = crate::l7::rest::relay_http_request_with_options_guarded( &req, client, upstream, - resolver, - Some(generation_guard), + crate::l7::rest::RelayRequestOptions { + resolver, + generation_guard: Some(generation_guard), + ..Default::default() + }, ) .await?; match outcome { RelayOutcome::Reusable => {} // continue loop RelayOutcome::Consumed => break, - RelayOutcome::Upgraded { overflow } => { - return handle_upgrade(client, upstream, overflow, &ctx.host, ctx.port).await; + RelayOutcome::Upgraded { overflow, .. 
} => { + return handle_upgrade( + client, + upstream, + overflow, + &ctx.host, + ctx.port, + UpgradeRelayOptions::default(), + ) + .await; } } } @@ -1034,7 +1294,7 @@ mod tests { use super::*; use crate::opa::{NetworkInput, OpaEngine}; use std::path::PathBuf; - use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt}; const TEST_POLICY: &str = include_str!("../../data/sandbox-policy.rego"); @@ -1071,6 +1331,436 @@ mod tests { ); } + #[test] + fn websocket_text_policy_requires_explicit_message_rule() { + let data = r#" +network_policies: + ws_api: + name: ws_api + endpoints: + - host: gateway.example.test + port: 443 + protocol: websocket + enforcement: enforce + rules: + - allow: + method: GET + path: "/ws" + binaries: + - { path: /usr/bin/node } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + let input = NetworkInput { + host: "gateway.example.test".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/node"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let generation = engine + .evaluate_network_action_with_generation(&input) + .unwrap() + .1; + let tunnel_engine = engine.clone_engine_for_tunnel(generation).unwrap(); + let ctx = L7EvalContext { + host: "gateway.example.test".into(), + port: 443, + policy_name: "ws_api".into(), + binary_path: "/usr/bin/node".into(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: None, + }; + let request = L7RequestInfo { + action: "WEBSOCKET_TEXT".into(), + target: "/ws".into(), + query_params: std::collections::HashMap::new(), + graphql: None, + }; + + let (allowed, reason) = evaluate_l7_request(&tunnel_engine, &ctx, &request).unwrap(); + + assert!(!allowed); + assert!(reason.contains("WEBSOCKET_TEXT /ws not permitted")); + } + + #[tokio::test] + async fn route_selected_websocket_upgrade_rejects_invalid_accept_without_forwarding_101() { + let data = r#" +network_policies: + route_api: + 
name: route_api + endpoints: + - host: gateway.example.test + port: 443 + protocol: rest + enforcement: enforce + rules: + - allow: + method: GET + path: "/ws" + binaries: + - { path: /usr/bin/node } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + let tunnel_engine = engine + .clone_engine_for_tunnel(engine.current_generation()) + .unwrap(); + let configs = vec![L7EndpointConfig { + protocol: L7Protocol::Rest, + path: "/ws".into(), + tls: crate::l7::TlsMode::Auto, + enforcement: EnforcementMode::Enforce, + graphql_max_body_bytes: 0, + allow_encoded_slash: false, + websocket_credential_rewrite: true, + request_body_credential_rewrite: false, + websocket_graphql_policy: false, + }]; + let ctx = L7EvalContext { + host: "gateway.example.test".into(), + port: 443, + policy_name: "route_api".into(), + binary_path: "/usr/bin/node".into(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: None, + }; + + let (mut app, mut relay_client) = tokio::io::duplex(8192); + let (mut relay_upstream, mut upstream) = tokio::io::duplex(8192); + let relay = tokio::spawn(async move { + relay_with_route_selection( + &configs, + tunnel_engine, + &mut relay_client, + &mut relay_upstream, + &ctx, + ) + .await + }); + + app.write_all( + b"GET /ws HTTP/1.1\r\nHost: gateway.example.test\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n", + ) + .await + .unwrap(); + + let mut forwarded = [0u8; 1024]; + let n = tokio::time::timeout( + std::time::Duration::from_secs(1), + upstream.read(&mut forwarded), + ) + .await + .expect("upgrade request should reach upstream") + .unwrap(); + let forwarded = String::from_utf8_lossy(&forwarded[..n]); + assert!(forwarded.contains("Upgrade: websocket\r\n")); + assert!(forwarded.contains("Connection: Upgrade\r\n")); + + upstream + .write_all( + b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: 
invalid\r\n\r\n", + ) + .await + .unwrap(); + + let err = tokio::time::timeout(std::time::Duration::from_secs(1), relay) + .await + .expect("relay should fail closed on invalid accept") + .unwrap() + .expect_err("invalid accept must fail the route-selected relay"); + assert!(err.to_string().contains("Sec-WebSocket-Accept")); + + let mut response = [0u8; 1]; + let n = tokio::time::timeout(std::time::Duration::from_secs(1), app.read(&mut response)) + .await + .expect("client side should close without 101") + .unwrap(); + assert_eq!(n, 0, "invalid response must not forward 101 headers"); + } + + #[tokio::test] + async fn route_selected_websocket_rewrites_text_credentials_after_upgrade() { + let data = r#" +network_policies: + route_api: + name: route_api + endpoints: + - host: gateway.example.test + port: 443 + protocol: websocket + enforcement: enforce + rules: + - allow: + method: GET + path: "/ws" + - allow: + method: WEBSOCKET_TEXT + path: "/ws" + websocket_credential_rewrite: true + binaries: + - { path: /usr/bin/node } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + let tunnel_engine = engine + .clone_engine_for_tunnel(engine.current_generation()) + .unwrap(); + let configs = vec![L7EndpointConfig { + protocol: L7Protocol::Websocket, + path: "/ws".into(), + tls: crate::l7::TlsMode::Auto, + enforcement: EnforcementMode::Enforce, + graphql_max_body_bytes: 0, + allow_encoded_slash: false, + websocket_credential_rewrite: true, + request_body_credential_rewrite: false, + websocket_graphql_policy: false, + }]; + let (child_env, resolver) = SecretResolver::from_provider_env( + std::iter::once(("DISCORD_BOT_TOKEN".to_string(), "real-token".to_string())).collect(), + ); + let placeholder = child_env.get("DISCORD_BOT_TOKEN").expect("placeholder env"); + let ctx = L7EvalContext { + host: "gateway.example.test".into(), + port: 443, + policy_name: "route_api".into(), + binary_path: "/usr/bin/node".into(), + ancestors: vec![], + cmdline_paths: 
vec![], + secret_resolver: resolver.map(Arc::new), + }; + + let (mut app, mut relay_client) = tokio::io::duplex(8192); + let (mut relay_upstream, mut upstream) = tokio::io::duplex(8192); + let relay = tokio::spawn(async move { + relay_with_route_selection( + &configs, + tunnel_engine, + &mut relay_client, + &mut relay_upstream, + &ctx, + ) + .await + }); + + app.write_all( + b"GET /ws HTTP/1.1\r\nHost: gateway.example.test\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n", + ) + .await + .unwrap(); + + let mut forwarded = [0u8; 1024]; + let n = tokio::time::timeout( + std::time::Duration::from_secs(1), + upstream.read(&mut forwarded), + ) + .await + .expect("upgrade request should reach upstream") + .unwrap(); + let forwarded = String::from_utf8_lossy(&forwarded[..n]); + assert!(forwarded.contains("Upgrade: websocket\r\n")); + + upstream + .write_all( + b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\r\n", + ) + .await + .unwrap(); + + let mut response = [0u8; 1024]; + let n = tokio::time::timeout(std::time::Duration::from_secs(1), app.read(&mut response)) + .await + .expect("client should receive upgrade response") + .unwrap(); + assert!(String::from_utf8_lossy(&response[..n]).contains("101 Switching Protocols")); + + let payload = format!(r#"{{"op":2,"d":{{"token":"{placeholder}"}}}}"#); + app.write_all(&masked_text_frame(payload.as_bytes())) + .await + .unwrap(); + + let (masked, rewritten) = tokio::time::timeout( + std::time::Duration::from_secs(1), + read_text_frame(&mut upstream), + ) + .await + .expect("rewritten websocket text should reach upstream") + .unwrap(); + assert!(masked, "client-to-server frame must remain masked"); + assert_eq!(rewritten, r#"{"op":2,"d":{"token":"real-token"}}"#); + assert!(!rewritten.contains(placeholder)); + + drop(app); + drop(upstream); + let _ = 
tokio::time::timeout(std::time::Duration::from_secs(1), relay).await; + } + + #[tokio::test] + async fn route_selected_graphql_websocket_rewrites_connection_init_credentials_after_upgrade() { + let data = r#" +network_policies: + route_api: + name: route_api + endpoints: + - host: gateway.example.test + port: 443 + path: "/graphql" + protocol: websocket + enforcement: enforce + rules: + - allow: + method: GET + path: "/graphql" + - allow: + operation_type: query + fields: [viewer] + websocket_credential_rewrite: true + binaries: + - { path: /usr/bin/node } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + let tunnel_engine = engine + .clone_engine_for_tunnel(engine.current_generation()) + .unwrap(); + let configs = vec![L7EndpointConfig { + protocol: L7Protocol::Websocket, + path: "/graphql".into(), + tls: crate::l7::TlsMode::Auto, + enforcement: EnforcementMode::Enforce, + graphql_max_body_bytes: 0, + allow_encoded_slash: false, + websocket_credential_rewrite: true, + request_body_credential_rewrite: false, + websocket_graphql_policy: true, + }]; + let (child_env, resolver) = SecretResolver::from_provider_env( + std::iter::once(("T".to_string(), "real-token".to_string())).collect(), + ); + let placeholder = child_env.get("T").expect("placeholder env"); + let ctx = L7EvalContext { + host: "gateway.example.test".into(), + port: 443, + policy_name: "route_api".into(), + binary_path: "/usr/bin/node".into(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: resolver.map(Arc::new), + }; + + let (mut app, mut relay_client) = tokio::io::duplex(8192); + let (mut relay_upstream, mut upstream) = tokio::io::duplex(8192); + let relay = tokio::spawn(async move { + relay_with_route_selection( + &configs, + tunnel_engine, + &mut relay_client, + &mut relay_upstream, + &ctx, + ) + .await + }); + + app.write_all( + b"GET /graphql HTTP/1.1\r\nHost: gateway.example.test\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: 
dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n", + ) + .await + .unwrap(); + + let mut forwarded = [0u8; 1024]; + let n = tokio::time::timeout( + std::time::Duration::from_secs(1), + upstream.read(&mut forwarded), + ) + .await + .expect("upgrade request should reach upstream") + .unwrap(); + let forwarded = String::from_utf8_lossy(&forwarded[..n]); + assert!(forwarded.contains("GET /graphql HTTP/1.1")); + assert!(forwarded.contains("Upgrade: websocket\r\n")); + + upstream + .write_all( + b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\r\n", + ) + .await + .unwrap(); + + let mut response = [0u8; 1024]; + let n = tokio::time::timeout(std::time::Duration::from_secs(1), app.read(&mut response)) + .await + .expect("client should receive upgrade response") + .unwrap(); + assert!(String::from_utf8_lossy(&response[..n]).contains("101 Switching Protocols")); + + let payload = format!( + r#"{{"type":"connection_init","payload":{{"authorization":"{placeholder}"}}}}"# + ); + app.write_all(&masked_text_frame(payload.as_bytes())) + .await + .unwrap(); + + let (masked, rewritten) = tokio::time::timeout( + std::time::Duration::from_secs(1), + read_text_frame(&mut upstream), + ) + .await + .expect("rewritten GraphQL WebSocket control message should reach upstream") + .unwrap(); + assert!(masked, "client-to-server frame must remain masked"); + assert_eq!( + rewritten, + r#"{"type":"connection_init","payload":{"authorization":"real-token"}}"# + ); + assert!(!rewritten.contains(placeholder)); + + drop(app); + drop(upstream); + let _ = tokio::time::timeout(std::time::Duration::from_secs(1), relay).await; + } + + fn masked_text_frame(payload: &[u8]) -> Vec<u8> { + let mask = [0x11, 0x22, 0x33, 0x44]; + assert!( + payload.len() <= 125, + "test helper only supports small frames" + ); + let payload_len = u8::try_from(payload.len()).expect("small frame length"); + let mut frame = vec![0x81,
0x80 | payload_len]; + frame.extend_from_slice(&mask); + frame.extend( + payload + .iter() + .enumerate() + .map(|(idx, byte)| byte ^ mask[idx % 4]), + ); + frame + } + + async fn read_text_frame<R: AsyncRead + Unpin>( + reader: &mut R, + ) -> std::io::Result<(bool, String)> { + let mut header = [0u8; 2]; + reader.read_exact(&mut header).await?; + assert_eq!(header[0] & 0x0f, 0x1, "expected text frame"); + let masked = header[1] & 0x80 != 0; + let payload_len = usize::from(header[1] & 0x7f); + assert!(payload_len <= 125, "test helper only supports small frames"); + let mut mask = [0u8; 4]; + if masked { + reader.read_exact(&mut mask).await?; + } + let mut payload = vec![0u8; payload_len]; + reader.read_exact(&mut payload).await?; + if masked { + for (idx, byte) in payload.iter_mut().enumerate() { + *byte ^= mask[idx % 4]; + } + } + Ok((masked, String::from_utf8(payload).expect("text payload"))) + } + #[tokio::test] async fn l7_relay_closes_keep_alive_tunnel_after_policy_generation_change() { let initial_data = r#" diff --git a/crates/openshell-sandbox/src/l7/rest.rs b/crates/openshell-sandbox/src/l7/rest.rs index 19acdbf32..c513499f4 100644 --- a/crates/openshell-sandbox/src/l7/rest.rs +++ b/crates/openshell-sandbox/src/l7/rest.rs @@ -9,13 +9,18 @@ use crate::l7::provider::{BodyLength, L7Provider, L7Request, RelayOutcome}; use crate::opa::PolicyGenerationGuard; -use crate::secrets::rewrite_http_header_block; +use crate::secrets::{ + SecretResolver, contains_reserved_credential_marker, rewrite_http_header_block, +}; +use base64::Engine as _; use miette::{IntoDiagnostic, Result, miette}; -use std::collections::HashMap; +use sha1::{Digest, Sha1}; +use std::collections::{HashMap, HashSet}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tracing::debug; const MAX_HEADER_BYTES: usize = 16384; // 16 KiB for HTTP headers +const MAX_REWRITE_BODY_BYTES: usize = 256 * 1024; const RELAY_BUF_SIZE: usize = 8192; /// Idle timeout for `relay_until_eof`.
If no data arrives within this window /// the body is considered complete. Prevents blocking on servers that keep @@ -72,10 +77,19 @@ impl L7Provider for RestProvider { reason: &str, client: &mut C, ) -> Result<()> { - send_deny_response(req, policy_name, reason, client, None).await + send_deny_response(req, policy_name, reason, client, None, None).await } } +/// Extra sandbox-side context included in agent-readable deny responses when +/// the relay has it available. +#[derive(Debug, Clone, Copy, Default)] +pub(crate) struct DenyResponseContext<'a> { + pub(crate) host: Option<&'a str>, + pub(crate) port: Option, + pub(crate) binary: Option<&'a str>, +} + impl RestProvider { /// Deny with a redacted target for the response body. pub(crate) async fn deny_with_redacted_target( @@ -85,8 +99,9 @@ impl RestProvider { reason: &str, client: &mut C, redacted_target: Option<&str>, + context: Option>, ) -> Result<()> { - send_deny_response(req, policy_name, reason, client, redacted_target).await + send_deny_response(req, policy_name, reason, client, redacted_target, context).await } } @@ -333,7 +348,7 @@ pub(crate) async fn relay_http_request_with_resolver( req: &L7Request, client: &mut C, upstream: &mut U, - resolver: Option<&crate::secrets::SecretResolver>, + resolver: Option<&SecretResolver>, ) -> Result where C: AsyncRead + AsyncWrite + Unpin, @@ -346,9 +361,48 @@ pub(crate) async fn relay_http_request_with_resolver_guarded( req: &L7Request, client: &mut C, upstream: &mut U, - resolver: Option<&crate::secrets::SecretResolver>, + resolver: Option<&SecretResolver>, generation_guard: Option<&PolicyGenerationGuard>, ) -> Result +where + C: AsyncRead + AsyncWrite + Unpin, + U: AsyncRead + AsyncWrite + Unpin, +{ + relay_http_request_with_options_guarded( + req, + client, + upstream, + RelayRequestOptions { + resolver, + generation_guard, + websocket_extensions: WebSocketExtensionMode::Preserve, + request_body_credential_rewrite: false, + }, + ) + .await +} + +#[derive(Clone, 
Copy, Debug, Default, PartialEq, Eq)] +pub(crate) enum WebSocketExtensionMode { + #[default] + Preserve, + PermessageDeflate, +} + +#[derive(Clone, Copy, Default)] +pub(crate) struct RelayRequestOptions<'a> { + pub(crate) resolver: Option<&'a SecretResolver>, + pub(crate) generation_guard: Option<&'a PolicyGenerationGuard>, + pub(crate) websocket_extensions: WebSocketExtensionMode, + pub(crate) request_body_credential_rewrite: bool, +} + +pub(crate) async fn relay_http_request_with_options_guarded( + req: &L7Request, + client: &mut C, + upstream: &mut U, + options: RelayRequestOptions<'_>, +) -> Result where C: AsyncRead + AsyncWrite + Unpin, U: AsyncRead + AsyncWrite + Unpin, @@ -358,252 +412,365 @@ where .windows(4) .position(|w| w == b"\r\n\r\n") .map_or(req.raw_header.len(), |p| p + 4); + let header_str = std::str::from_utf8(&req.raw_header[..header_end]) + .map_err(|_| miette!("HTTP headers contain invalid UTF-8"))?; + let client_requested_upgrade = client_requested_upgrade(header_str); + let websocket_request = if options.websocket_extensions == WebSocketExtensionMode::Preserve { + None + } else { + parse_websocket_upgrade_request(&req.raw_header[..header_end])? 
+ }; - let rewrite_result = rewrite_http_header_block(&req.raw_header[..header_end], resolver) + let (header_bytes, expected_websocket_extension) = rewrite_websocket_extensions_for_mode( + &req.raw_header[..header_end], + options.websocket_extensions, + websocket_request.is_some(), + )?; + let websocket_response = + websocket_request + .as_ref() + .map(|request| WebSocketResponseValidation { + expected_accept: websocket_accept_for_key(&request.sec_key), + expected_extension: expected_websocket_extension.clone(), + offered_subprotocols: request.subprotocols.clone(), + }); + + let rewrite_result = rewrite_http_header_block(&header_bytes, options.resolver) .map_err(|e| miette!("credential injection failed: {e}"))?; - if let Some(guard) = generation_guard { + if let Some(guard) = options.generation_guard { guard.ensure_current()?; } - upstream - .write_all(&rewrite_result.rewritten) - .await - .into_diagnostic()?; - - let overflow = &req.raw_header[header_end..]; - if !overflow.is_empty() { - if let Some(guard) = generation_guard { - guard.ensure_current()?; + if options.request_body_credential_rewrite { + let body = collect_and_rewrite_request_body( + req, + client, + &rewrite_result.rewritten, + header_str, + &req.raw_header[header_end..], + options.resolver, + options.generation_guard, + ) + .await?; + upstream.write_all(&body.headers).await.into_diagnostic()?; + if !body.body.is_empty() { + upstream.write_all(&body.body).await.into_diagnostic()?; } - upstream.write_all(overflow).await.into_diagnostic()?; - } - let overflow_len = overflow.len() as u64; + } else { + upstream + .write_all(&rewrite_result.rewritten) + .await + .into_diagnostic()?; - match req.body_length { - BodyLength::ContentLength(len) => { - let remaining = len.saturating_sub(overflow_len); - if remaining > 0 { - relay_fixed(client, upstream, remaining, generation_guard).await?; + let overflow = &req.raw_header[header_end..]; + if !overflow.is_empty() { + if let Some(guard) = 
options.generation_guard { + guard.ensure_current()?; } + upstream.write_all(overflow).await.into_diagnostic()?; } - BodyLength::Chunked => { - relay_chunked( - client, - upstream, - &req.raw_header[header_end..], - generation_guard, - ) - .await?; + let overflow_len = overflow.len() as u64; + + match req.body_length { + BodyLength::ContentLength(len) => { + let remaining = len.saturating_sub(overflow_len); + if remaining > 0 { + relay_fixed(client, upstream, remaining, options.generation_guard).await?; + } + } + BodyLength::Chunked => { + relay_chunked( + client, + upstream, + &req.raw_header[header_end..], + options.generation_guard, + ) + .await?; + } + BodyLength::None => {} } - BodyLength::None => {} } upstream.flush().await.into_diagnostic()?; - let outcome = relay_response(&req.action, upstream, client).await?; - - // Validate that the client actually requested an upgrade before accepting - // a 101 from upstream. Per RFC 9110 Section 7.8, the server MUST NOT send - // 101 unless the client sent Upgrade + Connection: Upgrade headers. A - // non-compliant or malicious upstream could send an unsolicited 101 to - // bypass L7 inspection. - if matches!(outcome, RelayOutcome::Upgraded { .. }) { - let header_str = String::from_utf8_lossy(&req.raw_header[..header_end]); - if !client_requested_upgrade(&header_str) { - openshell_ocsf::ocsf_emit!( - openshell_ocsf::DetectionFindingBuilder::new(crate::ocsf_ctx()) - .activity(openshell_ocsf::ActivityId::Open) - .action(openshell_ocsf::ActionId::Denied) - .disposition(openshell_ocsf::DispositionId::Blocked) - .severity(openshell_ocsf::SeverityId::High) - .confidence(openshell_ocsf::ConfidenceId::High) - .is_alert(true) - .finding_info( - openshell_ocsf::FindingInfo::new( - "unsolicited-101-upgrade", - "Unsolicited 101 Switching Protocols", - ) - .with_desc(&format!( - "Upstream sent 101 without client Upgrade request for {} {} — \ - possible L7 inspection bypass. 
Connection closed.", - req.action, req.target, - )), - ) - .message(format!( - "Unsolicited 101 upgrade blocked: {} {}", - req.action, req.target, - )) - .build() - ); - return Ok(RelayOutcome::Consumed); - } - } + let outcome = relay_response( + &req.action, + upstream, + client, + RelayResponseOptions { + websocket_extensions: options.websocket_extensions, + websocket: websocket_response, + client_requested_upgrade, + }, + ) + .await?; Ok(outcome) } -/// Send a 403 Forbidden JSON deny response. -/// -/// When `redacted_target` is provided, it is used instead of `req.target` -/// in the response body to avoid leaking resolved credential values. -async fn send_deny_response( - req: &L7Request, - policy_name: &str, - reason: &str, - client: &mut C, - redacted_target: Option<&str>, -) -> Result<()> { - let target = redacted_target.unwrap_or(&req.target); - let body = serde_json::json!({ - "error": "policy_denied", - "policy": policy_name, - "rule": format!("{} {}", req.action, target), - "detail": reason - }); - let body_bytes = body.to_string(); - let response = format!( - "HTTP/1.1 403 Forbidden\r\n\ - Content-Type: application/json\r\n\ - Content-Length: {}\r\n\ - X-OpenShell-Policy: {}\r\n\ - Connection: close\r\n\ - \r\n\ - {}", - body_bytes.len(), - policy_name, - body_bytes, - ); - client - .write_all(response.as_bytes()) - .await - .into_diagnostic()?; - client.flush().await.into_diagnostic()?; - Ok(()) +struct PreparedRequestBody { + headers: Vec, + body: Vec, } -/// Parse Content-Length or Transfer-Encoding from HTTP headers. -/// -/// Per RFC 7230 Section 3.3.3, rejects requests containing both -/// `Content-Length` and `Transfer-Encoding` headers to prevent request -/// smuggling via CL/TE ambiguity. 
-pub(crate) fn parse_body_length(headers: &str) -> Result { - let mut has_te_chunked = false; - let mut cl_value: Option = None; - - for line in headers.lines().skip(1) { - let lower = line.to_ascii_lowercase(); - if lower.starts_with("transfer-encoding:") { - let val = lower.split_once(':').map_or("", |(_, v)| v.trim()); - if val.split(',').any(|enc| enc.trim() == "chunked") { - has_te_chunked = true; +async fn collect_and_rewrite_request_body( + req: &L7Request, + client: &mut C, + rewritten_headers: &[u8], + original_header_str: &str, + already_read: &[u8], + resolver: Option<&SecretResolver>, + generation_guard: Option<&PolicyGenerationGuard>, +) -> Result { + match req.body_length { + BodyLength::None => { + if body_bytes_contain_reserved_marker(already_read) { + return Err(miette!( + "request body credential rewrite cannot resolve placeholders without explicit body framing" + )); } + Ok(PreparedRequestBody { + headers: rewritten_headers.to_vec(), + body: already_read.to_vec(), + }) } - if lower.starts_with("content-length:") { - let val = lower.split_once(':').map_or("", |(_, v)| v.trim()); - let len: u64 = val - .parse() - .map_err(|_| miette!("Request contains invalid Content-Length value"))?; - if let Some(prev) = cl_value - && prev != len - { + BodyLength::ContentLength(len) => { + let len = usize::try_from(len) + .map_err(|_| miette!("request body is too large for credential rewrite"))?; + if len > MAX_REWRITE_BODY_BYTES { return Err(miette!( - "Request contains multiple Content-Length headers with differing values ({prev} vs {len})" + "request body credential rewrite buffers at most {MAX_REWRITE_BODY_BYTES} bytes" )); } - cl_value = Some(len); + let mut body = Vec::with_capacity(len); + let initial_len = already_read.len().min(len); + body.extend_from_slice(&already_read[..initial_len]); + let mut remaining = len.saturating_sub(initial_len); + let mut buf = [0u8; RELAY_BUF_SIZE]; + while remaining > 0 { + let to_read = remaining.min(buf.len()); + let n 
= client.read(&mut buf[..to_read]).await.into_diagnostic()?; + if n == 0 { + return Err(miette!( + "Connection closed with {remaining} body bytes remaining" + )); + } + if let Some(guard) = generation_guard { + guard.ensure_current()?; + } + body.extend_from_slice(&buf[..n]); + remaining -= n; + } + let (headers, body) = + rewrite_buffered_body(rewritten_headers, original_header_str, body, resolver)?; + Ok(PreparedRequestBody { headers, body }) + } + BodyLength::Chunked => { + let body = collect_chunked_body(client, already_read, generation_guard).await?; + if body_bytes_contain_reserved_marker(&body) { + return Err(miette!( + "request body credential rewrite does not support chunked bodies containing credential placeholders" + )); + } + Ok(PreparedRequestBody { + headers: rewritten_headers.to_vec(), + body, + }) } } +} - if has_te_chunked && cl_value.is_some() { - return Err(miette!( - "Request contains both Transfer-Encoding and Content-Length headers" - )); +fn rewrite_buffered_body( + headers: &[u8], + original_header_str: &str, + body: Vec, + resolver: Option<&SecretResolver>, +) -> Result<(Vec, Vec)> { + if body.is_empty() { + return Ok((headers.to_vec(), body)); } - if has_te_chunked { - return Ok(BodyLength::Chunked); + let content_type = content_type(original_header_str); + if !is_rewritable_content_type(content_type.as_deref()) { + if body_bytes_contain_reserved_marker(&body) { + return Err(miette!( + "request body credential rewrite found placeholders in an unsupported content type" + )); + } + return Ok((headers.to_vec(), body)); } - if let Some(len) = cl_value { - return Ok(BodyLength::ContentLength(len)); + + let mut text = String::from_utf8(body) + .map_err(|_| miette!("request body credential rewrite requires UTF-8 text bodies"))?; + if !contains_reserved_credential_marker(&text) { + return Ok((headers.to_vec(), text.into_bytes())); } - Ok(BodyLength::None) + + let Some(resolver) = resolver else { + return Err(miette!( + "request body credential 
rewrite found placeholders but no resolver is available" + )); + }; + + let replacements = if content_type.as_deref() == Some("application/x-www-form-urlencoded") { + let (rewritten, replacements) = rewrite_form_urlencoded_body(&text, resolver)?; + text = rewritten; + replacements + } else { + resolver + .rewrite_text_placeholders(&mut text, "request_body") + .map_err(|e| miette!("credential injection failed: {e}"))? + }; + if replacements == 0 || contains_reserved_credential_marker(&text) { + return Err(miette!( + "request body credential rewrite left unresolved credential placeholders" + )); + } + + let body = text.into_bytes(); + let headers = set_content_length(headers, body.len())?; + Ok((headers, body)) } -/// Relay exactly `len` bytes from reader to writer. -async fn relay_fixed( - reader: &mut R, - writer: &mut W, - len: u64, - generation_guard: Option<&PolicyGenerationGuard>, -) -> Result<()> -where - R: AsyncRead + Unpin, - W: AsyncWrite + Unpin, -{ - let mut remaining = len; - let mut buf = [0u8; RELAY_BUF_SIZE]; - while remaining > 0 { - let to_read = usize::try_from(remaining) - .unwrap_or(buf.len()) - .min(buf.len()); - let n = reader.read(&mut buf[..to_read]).await.into_diagnostic()?; - if n == 0 { +fn rewrite_form_urlencoded_body(body: &str, resolver: &SecretResolver) -> Result<(String, usize)> { + let mut rewritten = String::with_capacity(body.len()); + let mut replacements = 0usize; + + for (idx, field) in body.split('&').enumerate() { + if idx > 0 { + rewritten.push('&'); + } + + let (name, value) = field + .split_once('=') + .map_or((field, None), |(name, value)| (name, Some(value))); + let decoded_name = form_url_decode(name)?; + if contains_reserved_credential_marker(&decoded_name) { return Err(miette!( - "Connection closed with {remaining} bytes remaining" + "request body credential rewrite does not support placeholders in form field names" )); } - if let Some(guard) = generation_guard { - guard.ensure_current()?; + + 
rewritten.push_str(name); + let Some(value) = value else { + continue; + }; + + rewritten.push('='); + let decoded_value = form_url_decode(value)?; + if !contains_reserved_credential_marker(&decoded_value) { + rewritten.push_str(value); + continue; } - writer.write_all(&buf[..n]).await.into_diagnostic()?; - remaining -= n as u64; + + let mut rewritten_value = decoded_value; + let field_replacements = resolver + .rewrite_text_placeholders(&mut rewritten_value, "request_body") + .map_err(|e| miette!("credential injection failed: {e}"))?; + if field_replacements == 0 || contains_reserved_credential_marker(&rewritten_value) { + return Err(miette!( + "request body credential rewrite left unresolved credential placeholders" + )); + } + replacements += field_replacements; + rewritten.push_str(&form_url_encode(&rewritten_value)); } - Ok(()) + + Ok((rewritten, replacements)) } -/// Relay chunked transfer encoding from reader to writer. -/// -/// Copies bytes verbatim (preserving chunk framing) while parsing the stream -/// boundaries so we can stop exactly at the end of the current message body. -/// Handles chunk extensions and trailers per RFC 7230. -/// -/// `already_forwarded` are overflow bytes that were already written to the -/// writer during header parsing. They are seeded into the parser buffer so -/// termination can still be detected when boundaries span reads. 
-async fn relay_chunked( - reader: &mut R, - writer: &mut W, - already_forwarded: &[u8], +fn form_url_decode(input: &str) -> Result { + let bytes = input.as_bytes(); + let mut decoded = Vec::with_capacity(bytes.len()); + let mut pos = 0usize; + + while pos < bytes.len() { + match bytes[pos] { + b'+' => { + decoded.push(b' '); + pos += 1; + } + b'%' if pos + 2 < bytes.len() => { + if let (Some(hi), Some(lo)) = (hex_value(bytes[pos + 1]), hex_value(bytes[pos + 2])) + { + decoded.push((hi << 4) | lo); + pos += 3; + } else { + decoded.push(bytes[pos]); + pos += 1; + } + } + byte => { + decoded.push(byte); + pos += 1; + } + } + } + + String::from_utf8(decoded).map_err(|_| { + miette!("request body credential rewrite requires UTF-8 form-url-encoded fields") + }) +} + +fn form_url_encode(input: &str) -> String { + let mut encoded = String::with_capacity(input.len()); + for byte in input.bytes() { + match byte { + b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'.' | b'_' | b'*' => { + encoded.push(byte as char); + } + b' ' => encoded.push('+'), + _ => { + use std::fmt::Write as _; + let _ = write!(encoded, "%{byte:02X}"); + } + } + } + encoded +} + +fn hex_value(byte: u8) -> Option { + match byte { + b'0'..=b'9' => Some(byte - b'0'), + b'a'..=b'f' => Some(byte - b'a' + 10), + b'A'..=b'F' => Some(byte - b'A' + 10), + _ => None, + } +} + +async fn collect_chunked_body( + client: &mut C, + already_read: &[u8], generation_guard: Option<&PolicyGenerationGuard>, -) -> Result<()> -where - R: AsyncRead + Unpin, - W: AsyncWrite + Unpin, -{ - let started_at = std::time::Instant::now(); +) -> Result> { let mut read_buf = [0u8; RELAY_BUF_SIZE]; - let mut parse_buf = Vec::from(already_forwarded); + let mut parse_buf = Vec::from(already_read); let mut pos = 0usize; - let mut chunk_count = 0usize; - let mut chunk_payload_bytes = 0usize; - // Parse chunk-size lines + chunk payloads until final 0-size chunk, then - // parse trailers until the terminating empty trailer line. 
loop { - // Parse one chunk size line: "[;extensions]\r\n" + if parse_buf.len() > MAX_REWRITE_BODY_BYTES { + return Err(miette!( + "request body credential rewrite buffers at most {MAX_REWRITE_BODY_BYTES} bytes" + )); + } + let size_line_end = loop { if let Some(end) = find_crlf(&parse_buf, pos) { break end; } - let n = reader.read(&mut read_buf).await.into_diagnostic()?; + let n = client.read(&mut read_buf).await.into_diagnostic()?; if n == 0 { return Err(miette!("Chunked body ended before chunk-size line")); } if let Some(guard) = generation_guard { guard.ensure_current()?; } - writer.write_all(&read_buf[..n]).await.into_diagnostic()?; parse_buf.extend_from_slice(&read_buf[..n]); + if parse_buf.len() > MAX_REWRITE_BODY_BYTES { + return Err(miette!( + "request body credential rewrite buffers at most {MAX_REWRITE_BODY_BYTES} bytes" + )); + } }; let size_line = std::str::from_utf8(&parse_buf[pos..size_line_end]) @@ -620,288 +787,1147 @@ where pos = size_line_end + 2; if chunk_size == 0 { - // Parse trailers (if any). Terminates on empty trailer line. 
- let mut trailer_count = 0usize; loop { let trailer_end = loop { if let Some(end) = find_crlf(&parse_buf, pos) { break end; } - let n = reader.read(&mut read_buf).await.into_diagnostic()?; + let n = client.read(&mut read_buf).await.into_diagnostic()?; if n == 0 { return Err(miette!("Chunked body ended before trailer terminator")); } if let Some(guard) = generation_guard { guard.ensure_current()?; } - writer.write_all(&read_buf[..n]).await.into_diagnostic()?; parse_buf.extend_from_slice(&read_buf[..n]); + if parse_buf.len() > MAX_REWRITE_BODY_BYTES { + return Err(miette!( + "request body credential rewrite buffers at most {MAX_REWRITE_BODY_BYTES} bytes" + )); + } }; - let trailer_line = &parse_buf[pos..trailer_end]; pos = trailer_end + 2; if trailer_line.is_empty() { - debug!( - chunk_count, - chunk_payload_bytes, - trailer_count, - elapsed_ms = started_at.elapsed().as_millis(), - "relay_chunked complete" - ); - return Ok(()); + return Ok(parse_buf); } - trailer_count += 1; } } - // Ensure the full chunk payload + trailing CRLF is available. 
let chunk_end = pos .checked_add(chunk_size) .ok_or_else(|| miette!("Chunk size overflow"))?; let chunk_with_crlf_end = chunk_end .checked_add(2) .ok_or_else(|| miette!("Chunk size overflow"))?; - while parse_buf.len() < chunk_with_crlf_end { - let n = reader.read(&mut read_buf).await.into_diagnostic()?; + let n = client.read(&mut read_buf).await.into_diagnostic()?; if n == 0 { return Err(miette!("Chunked body ended mid-chunk")); } if let Some(guard) = generation_guard { guard.ensure_current()?; } - writer.write_all(&read_buf[..n]).await.into_diagnostic()?; parse_buf.extend_from_slice(&read_buf[..n]); + if parse_buf.len() > MAX_REWRITE_BODY_BYTES { + return Err(miette!( + "request body credential rewrite buffers at most {MAX_REWRITE_BODY_BYTES} bytes" + )); + } } if &parse_buf[chunk_end..chunk_with_crlf_end] != b"\r\n" { return Err(miette!("Chunk missing terminating CRLF")); } pos = chunk_with_crlf_end; - chunk_count += 1; - chunk_payload_bytes = chunk_payload_bytes.saturating_add(chunk_size); - - // Keep parser memory bounded for long streams. - if pos > RELAY_BUF_SIZE * 4 { - parse_buf.drain(..pos); - pos = 0; - } } } -fn find_crlf(buf: &[u8], start: usize) -> Option<usize> { - buf.get(start..)?
- .windows(2) - .position(|w| w == b"\r\n") - .map(|offset| start + offset) +fn content_type(headers: &str) -> Option { + headers.lines().skip(1).find_map(|line| { + let (name, value) = line.split_once(':')?; + name.trim().eq_ignore_ascii_case("content-type").then(|| { + value + .split(';') + .next() + .unwrap_or("") + .trim() + .to_ascii_lowercase() + }) + }) } -async fn relay_response( - request_method: &str, - upstream: &mut U, - client: &mut C, -) -> Result -where - U: AsyncRead + Unpin, - C: AsyncWrite + Unpin, -{ - let started_at = std::time::Instant::now(); - let mut buf = Vec::with_capacity(4096); - let mut tmp = [0u8; 1024]; +fn is_rewritable_content_type(content_type: Option<&str>) -> bool { + let Some(content_type) = content_type else { + return false; + }; + content_type == "application/json" + || content_type == "application/x-www-form-urlencoded" + || content_type.starts_with("text/") +} - // Read response headers - loop { - if buf.len() > MAX_HEADER_BYTES { - return Err(miette!("HTTP response headers exceed limit")); - } +fn body_bytes_contain_reserved_marker(body: &[u8]) -> bool { + if body.is_empty() { + return false; + } + String::from_utf8_lossy(body) + .split('\0') + .any(contains_reserved_credential_marker) +} - let n = upstream.read(&mut tmp).await.into_diagnostic()?; - if n == 0 { - // Upstream closed — forward whatever we have - if !buf.is_empty() { - client.write_all(&buf).await.into_diagnostic()?; +fn set_content_length(headers: &[u8], len: usize) -> Result> { + use std::fmt::Write as _; + + let header_str = + std::str::from_utf8(headers).map_err(|_| miette!("HTTP headers contain invalid UTF-8"))?; + let mut out = String::with_capacity(header_str.len() + 32); + let mut inserted = false; + for line in header_str.split("\r\n") { + if line.is_empty() { + if !inserted { + let _ = write!(out, "Content-Length: {len}\r\n"); } - return Ok(RelayOutcome::Consumed); - } - buf.extend_from_slice(&tmp[..n]); - - if buf.windows(4).any(|w| w == 
b"\r\n\r\n") { + out.push_str("\r\n"); break; } + if line + .split_once(':') + .is_some_and(|(name, _)| name.trim().eq_ignore_ascii_case("content-length")) + { + if !inserted { + let _ = write!(out, "Content-Length: {len}\r\n"); + inserted = true; + } + continue; + } + out.push_str(line); + out.push_str("\r\n"); } + Ok(out.into_bytes()) +} - let header_end = buf.windows(4).position(|w| w == b"\r\n\r\n").unwrap() + 4; - - // Parse response framing - let header_str = String::from_utf8_lossy(&buf[..header_end]); - let status_code = parse_status_code(&header_str).unwrap_or(200); - let server_wants_close = parse_connection_close(&header_str); - let body_length = parse_body_length(&header_str)?; - - debug!( - status_code, - ?body_length, - server_wants_close, - request_method, - overflow_bytes = buf.len() - header_end, - "relay_response framing" - ); +pub(crate) fn request_is_websocket_upgrade(raw_header: &[u8]) -> bool { + let header_end = raw_header + .windows(4) + .position(|w| w == b"\r\n\r\n") + .map_or(raw_header.len(), |p| p + 4); + validate_websocket_upgrade_request(&raw_header[..header_end]).unwrap_or(false) +} - // 101 Switching Protocols: the connection has been upgraded (e.g. to - // WebSocket). Forward the 101 headers to the client and signal the - // caller to switch to raw bidirectional TCP relay. Any bytes read - // from upstream beyond the headers are overflow that belong to the - // upgraded protocol and must be forwarded before switching. 
- if status_code == 101 { - client - .write_all(&buf[..header_end]) - .await - .into_diagnostic()?; - client.flush().await.into_diagnostic()?; - let overflow = buf[header_end..].to_vec(); - debug!( - request_method, - overflow_bytes = overflow.len(), - "101 Switching Protocols — signaling protocol upgrade" - ); - return Ok(RelayOutcome::Upgraded { overflow }); +fn rewrite_websocket_extensions_for_mode( + raw_header: &[u8], + mode: WebSocketExtensionMode, + websocket_request: bool, +) -> Result<(Vec, Option)> { + if !websocket_request || mode == WebSocketExtensionMode::Preserve { + return Ok((raw_header.to_vec(), None)); } - - // Bodiless responses (HEAD, 1xx, 204, 304): forward headers only, skip body - if is_bodiless_response(request_method, status_code) { - client - .write_all(&buf[..header_end]) - .await - .into_diagnostic()?; - client.flush().await.into_diagnostic()?; - return if server_wants_close { - Ok(RelayOutcome::Consumed) - } else { - Ok(RelayOutcome::Reusable) - }; + match mode { + WebSocketExtensionMode::Preserve => Ok((raw_header.to_vec(), None)), + WebSocketExtensionMode::PermessageDeflate => { + rewrite_websocket_extensions_for_permessage_deflate(raw_header) + } } +} - // No explicit framing (no Content-Length, no Transfer-Encoding). - // Per RFC 7230 §3.3.3 the body is delimited by connection close. - if matches!(body_length, BodyLength::None) { - if server_wants_close { - // Server indicated it will close — read until EOF. 
- let before_end = &buf[..header_end - 2]; - client.write_all(before_end).await.into_diagnostic()?; - client - .write_all(b"Connection: close\r\n\r\n") - .await - .into_diagnostic()?; - let overflow = &buf[header_end..]; - if !overflow.is_empty() { - client.write_all(overflow).await.into_diagnostic()?; +fn rewrite_websocket_extensions_for_permessage_deflate( + raw_header: &[u8], +) -> Result<(Vec, Option)> { + let header_str = std::str::from_utf8(raw_header) + .map_err(|_| miette!("HTTP headers contain invalid UTF-8"))?; + let safe_offer = supported_permessage_deflate_offer(header_str)?; + let mut out = Vec::with_capacity(raw_header.len()); + let mut inserted = false; + + for line in header_str.split_inclusive("\r\n") { + let bare = line.strip_suffix("\r\n").unwrap_or(line); + if bare + .to_ascii_lowercase() + .starts_with("sec-websocket-extensions:") + { + continue; + } + if bare.is_empty() && !inserted { + if let Some(offer) = safe_offer.as_deref() { + out.extend_from_slice(b"Sec-WebSocket-Extensions: "); + out.extend_from_slice(offer.as_bytes()); + out.extend_from_slice(b"\r\n"); } - relay_until_eof(upstream, client).await?; - client.flush().await.into_diagnostic()?; - return Ok(RelayOutcome::Consumed); + inserted = true; } - // No Connection: close — an HTTP/1.1 keep-alive server that omits - // framing headers has an empty body. Forward headers and continue - // the relay loop instead of blocking on relay_until_eof. 
- debug!("BodyLength::None without Connection: close — treating body as empty"); - client - .write_all(&buf[..header_end]) - .await - .into_diagnostic()?; - client.flush().await.into_diagnostic()?; - return Ok(RelayOutcome::Reusable); + out.extend_from_slice(line.as_bytes()); } + Ok((out, safe_offer)) +} - // Forward response headers + any overflow body bytes - client.write_all(&buf).await.into_diagnostic()?; - let overflow_len = (buf.len() - header_end) as u64; - - // Forward remaining response body - match body_length { - BodyLength::ContentLength(len) => { - let remaining = len.saturating_sub(overflow_len); - if remaining > 0 { - relay_fixed(upstream, client, remaining, None).await?; +fn supported_permessage_deflate_offer(header_str: &str) -> Result> { + for offer in websocket_extension_offers(header_str)? { + if !offer.name.eq_ignore_ascii_case("permessage-deflate") { + continue; + } + let mut client_no_context_takeover = false; + let mut server_no_context_takeover = false; + let mut unsupported = false; + let mut seen = HashSet::new(); + for param in &offer.params { + let name = param.name.to_ascii_lowercase(); + if param.value.is_some() || !seen.insert(name.clone()) { + unsupported = true; + break; + } + if name == "client_no_context_takeover" { + client_no_context_takeover = true; + } else if name == "server_no_context_takeover" { + server_no_context_takeover = true; + } else { + unsupported = true; + break; } } - BodyLength::Chunked => { - relay_chunked(upstream, client, &buf[header_end..], None).await?; + if client_no_context_takeover && !unsupported { + let mut offer = "permessage-deflate; client_no_context_takeover".to_string(); + if server_no_context_takeover { + offer.push_str("; server_no_context_takeover"); + } + return Ok(Some(offer)); } - BodyLength::None => unreachable!(), } - client.flush().await.into_diagnostic()?; - debug!( - request_method, - elapsed_ms = started_at.elapsed().as_millis(), - "relay_response complete (explicit framing)" - ); + 
Ok(None) +} - // When body framing is explicit (Content-Length / Chunked), always report - // the connection as reusable so the relay loop continues. If the server - // sent `Connection: close`, the *next* upstream write will fail and the - // loop will exit via the normal error path. Exiting early here would - // tear down the CONNECT tunnel before the client can detect the close, - // causing ~30 s retry delays in clients like `gh`. - Ok(RelayOutcome::Reusable) +#[derive(Debug, Clone, PartialEq, Eq)] +struct WebSocketExtensionOffer { + name: String, + params: Vec<WebSocketExtensionParam>, } -/// Parse the HTTP status code from a response status line. -/// -/// Expects the first line to look like `HTTP/1.1 200 OK`. -fn parse_status_code(headers: &str) -> Option<u16> { - let status_line = headers.lines().next()?; - let code_str = status_line.split_whitespace().nth(1)?; - code_str.parse().ok() +#[derive(Debug, Clone, PartialEq, Eq)] +struct WebSocketExtensionParam { + name: String, + value: Option<String>, } -/// Check if the response headers contain `Connection: close`.
-fn parse_connection_close(headers: &str) -> bool { - for line in headers.lines().skip(1) { - let lower = line.to_ascii_lowercase(); - if lower.starts_with("connection:") { - let val = lower.split_once(':').map_or("", |(_, v)| v.trim()); - return val.contains("close"); +fn websocket_extension_offers(header_str: &str) -> Result<Vec<WebSocketExtensionOffer>> { + let mut offers = Vec::new(); + for line in header_str.lines().skip(1) { + let Some((name, value)) = line.split_once(':') else { + continue; + }; + if !name.trim().eq_ignore_ascii_case("sec-websocket-extensions") { + continue; + } + for extension in value.split(',') { + let mut parts = extension.split(';').map(str::trim); + let Some(extension_name) = parts.next().filter(|name| !name.is_empty()) else { + return Err(miette!("invalid WebSocket extension offer")); + }; + if !is_http_token(extension_name) { + return Err(miette!("invalid WebSocket extension token")); + } + let mut params = Vec::new(); + for param in parts { + if param.is_empty() { + return Err(miette!("invalid WebSocket extension parameter")); + } + let (param_name, param_value) = match param.split_once('=') { + Some((name, value)) => { + let value = value.trim(); + if value.is_empty() || value.starts_with('"') || !is_http_token(value) { + return Err(miette!("unsupported WebSocket extension parameter value")); + } + (name.trim(), Some(value.to_string())) + } + None => (param, None), + }; + if param_name.is_empty() || !is_http_token(param_name) { + return Err(miette!("invalid WebSocket extension parameter")); + } + params.push(WebSocketExtensionParam { + name: param_name.to_string(), + value: param_value, + }); + } + offers.push(WebSocketExtensionOffer { + name: extension_name.to_string(), + params, + }); } } - false + Ok(offers) } -/// Check if the client request headers contain both `Upgrade` and -/// `Connection: Upgrade` headers, indicating the client requested a -/// protocol upgrade (e.g. WebSocket).
-/// -/// Per RFC 9110 Section 7.8, a server MUST NOT send 101 Switching Protocols -/// unless the client sent these headers. -fn client_requested_upgrade(headers: &str) -> bool { - let mut has_upgrade_header = false; - let mut connection_contains_upgrade = false; +#[derive(Debug, Clone, PartialEq, Eq)] +struct WebSocketUpgradeRequest { + sec_key: String, + subprotocols: Vec, +} - for line in headers.lines().skip(1) { - let lower = line.to_ascii_lowercase(); - if lower.starts_with("upgrade:") { - has_upgrade_header = true; +#[derive(Debug, Clone, PartialEq, Eq)] +struct WebSocketResponseValidation { + expected_accept: String, + expected_extension: Option, + offered_subprotocols: Vec, +} + +fn validate_websocket_upgrade_request(raw_header: &[u8]) -> Result { + parse_websocket_upgrade_request(raw_header).map(|request| request.is_some()) +} + +fn parse_websocket_upgrade_request(raw_header: &[u8]) -> Result> { + let header_str = std::str::from_utf8(raw_header) + .map_err(|_| miette!("HTTP headers contain invalid UTF-8"))?; + let mut lines = header_str.lines(); + let Some(request_line) = lines.next() else { + return Ok(None); + }; + let method = request_line.split_whitespace().next().unwrap_or_default(); + let mut headers = WebSocketUpgradeHeaders::default(); + + for line in lines { + if line.is_empty() { + break; } - if lower.starts_with("connection:") { - let val = lower.split_once(':').map_or("", |(_, v)| v.trim()); - // Connection header can have comma-separated values - if val.split(',').any(|tok| tok.trim() == "upgrade") { - connection_contains_upgrade = true; + let Some((name, value)) = line.split_once(':') else { + continue; + }; + let name = name.trim().to_ascii_lowercase(); + let value = value.trim(); + match name.as_str() { + "upgrade" if header_value_contains_token(value, "websocket") => { + headers.upgrade_websocket = true; + } + "connection" if header_value_contains_token(value, "upgrade") => { + headers.connection_upgrade = true; + } + "sec-websocket-key" 
=> { + headers.sec_key_count += 1; + headers.sec_key = Some(value.to_string()); + } + "sec-websocket-version" => { + headers.version_count += 1; + headers.version = Some(value.to_string()); } + "sec-websocket-protocol" => { + headers.subprotocols.extend(parse_http_token_list(value)?); + } + _ => {} } } - has_upgrade_header && connection_contains_upgrade + if !headers.is_attempt() { + return Ok(None); + } + if !method.eq_ignore_ascii_case("GET") { + return Err(miette!("websocket upgrade request must use GET")); + } + if !headers.upgrade_websocket { + return Err(miette!( + "websocket upgrade request missing Upgrade: websocket" + )); + } + if !headers.connection_upgrade { + return Err(miette!( + "websocket upgrade request missing Connection: Upgrade" + )); + } + if headers.sec_key_count != 1 { + return Err(miette!( + "websocket upgrade request must include exactly one Sec-WebSocket-Key" + )); + } + let key = headers.sec_key.as_deref().unwrap_or_default(); + let decoded_key = base64::engine::general_purpose::STANDARD + .decode(key.as_bytes()) + .map_err(|_| miette!("websocket upgrade request has invalid Sec-WebSocket-Key"))?; + if decoded_key.len() != 16 { + return Err(miette!( + "websocket upgrade request has invalid Sec-WebSocket-Key length" + )); + } + if headers.version_count != 1 || headers.version.as_deref() != Some("13") { + return Err(miette!( + "websocket upgrade request must use Sec-WebSocket-Version: 13" + )); + } + Ok(Some(WebSocketUpgradeRequest { + sec_key: key.to_string(), + subprotocols: headers.subprotocols, + })) } -/// Returns true for responses that MUST NOT contain a message body per RFC 7230 §3.3.3: -/// HEAD responses, 1xx informational, 204 No Content, 304 Not Modified. 
-fn is_bodiless_response(request_method: &str, status_code: u16) -> bool { - request_method.eq_ignore_ascii_case("HEAD") - || (100..200).contains(&status_code) - || status_code == 204 - || status_code == 304 +fn websocket_accept_for_key(sec_key: &str) -> String { + const WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + let mut hasher = Sha1::new(); + hasher.update(sec_key.as_bytes()); + hasher.update(WEBSOCKET_GUID.as_bytes()); + base64::engine::general_purpose::STANDARD.encode(hasher.finalize()) } -/// Relay all bytes from reader to writer until EOF or idle timeout. +fn header_value_contains_token(value: &str, expected: &str) -> bool { + value + .split(',') + .any(|token| token.trim().eq_ignore_ascii_case(expected)) +} + +fn parse_http_token_list(value: &str) -> Result<Vec<String>> { + let mut tokens = Vec::new(); + for token in value.split(',') { + let token = token.trim(); + if token.is_empty() || !is_http_token(token) { + return Err(miette!("invalid HTTP token list")); + } + tokens.push(token.to_string()); + } + Ok(tokens) +} + +fn is_http_token(value: &str) -> bool { + !value.is_empty() + && value.as_bytes().iter().all(|byte| { + matches!( + byte, + b'!' | b'#' + | b'$' + | b'%' + | b'&' + | b'\'' + | b'*' + | b'+' + | b'-' + | b'.' + | b'^' + | b'_' + | b'`' + | b'|' + | b'~' + | b'0'..=b'9' + | b'A'..=b'Z' + | b'a'..=b'z' + ) + }) +} + +#[derive(Default)] +struct WebSocketUpgradeHeaders { + upgrade_websocket: bool, + connection_upgrade: bool, + sec_key: Option<String>, + sec_key_count: usize, + version: Option<String>, + version_count: usize, + subprotocols: Vec<String>, +} + +impl WebSocketUpgradeHeaders { + fn is_attempt(&self) -> bool { + self.upgrade_websocket || self.sec_key.is_some() || self.version.is_some() + } +} + +/// Send a 403 Forbidden JSON deny response. +/// +/// When `redacted_target` is provided, it is used instead of `req.target` +/// in the response body to avoid leaking resolved credential values.
+async fn send_deny_response( + req: &L7Request, + policy_name: &str, + reason: &str, + client: &mut C, + redacted_target: Option<&str>, + context: Option>, +) -> Result<()> { + let body = deny_response_body(req, policy_name, reason, redacted_target, context); + let body_bytes = body.to_string(); + let response = format!( + "HTTP/1.1 403 Forbidden\r\n\ + Content-Type: application/json\r\n\ + Content-Length: {}\r\n\ + X-OpenShell-Policy: {}\r\n\ + Connection: close\r\n\ + \r\n\ + {}", + body_bytes.len(), + policy_name, + body_bytes, + ); + client + .write_all(response.as_bytes()) + .await + .into_diagnostic()?; + client.flush().await.into_diagnostic()?; + Ok(()) +} + +fn deny_response_body( + req: &L7Request, + policy_name: &str, + reason: &str, + redacted_target: Option<&str>, + context: Option>, +) -> serde_json::Value { + let target = redacted_target.unwrap_or(&req.target); + let context = context.unwrap_or_default(); + let host = non_empty(context.host); + let binary = non_empty(context.binary); + + let mut rule_missing = serde_json::Map::new(); + rule_missing.insert("type".to_string(), serde_json::json!("rest_allow")); + rule_missing.insert("layer".to_string(), serde_json::json!("l7")); + rule_missing.insert("method".to_string(), serde_json::json!(req.action)); + rule_missing.insert("path".to_string(), serde_json::json!(target)); + if let Some(host) = host { + rule_missing.insert("host".to_string(), serde_json::json!(host)); + } + if let Some(port) = context.port { + rule_missing.insert("port".to_string(), serde_json::json!(port)); + } + if let Some(binary) = binary { + rule_missing.insert("binary".to_string(), serde_json::json!(binary)); + } + + let mut body = serde_json::Map::new(); + body.insert("error".to_string(), serde_json::json!("policy_denied")); + body.insert("policy".to_string(), serde_json::json!(policy_name)); + body.insert( + "rule".to_string(), + serde_json::json!(format!("{} {}", req.action, target)), + ); + body.insert("detail".to_string(), 
serde_json::json!(reason)); + body.insert("layer".to_string(), serde_json::json!("l7")); + body.insert("protocol".to_string(), serde_json::json!("rest")); + body.insert("method".to_string(), serde_json::json!(req.action)); + body.insert("path".to_string(), serde_json::json!(target)); + if let Some(host) = host { + body.insert("host".to_string(), serde_json::json!(host)); + } + if let Some(port) = context.port { + body.insert("port".to_string(), serde_json::json!(port)); + } + if let Some(binary) = binary { + body.insert("binary".to_string(), serde_json::json!(binary)); + } + body.insert( + "rule_missing".to_string(), + serde_json::Value::Object(rule_missing), + ); + // `next_steps` is generated by the policy_local module so the wire URLs + // and the on-disk skill path stay in sync with the route table. Adding + // or renaming a route only requires touching the constants in that + // module; this side picks up the change automatically. + body.insert( + "next_steps".to_string(), + crate::policy_local::agent_next_steps(), + ); + + serde_json::Value::Object(body) +} + +fn non_empty(value: Option<&str>) -> Option<&str> { + value.map(str::trim).filter(|value| !value.is_empty()) +} + +/// Parse Content-Length or Transfer-Encoding from HTTP headers. +/// +/// Per RFC 7230 Section 3.3.3, rejects requests containing both +/// `Content-Length` and `Transfer-Encoding` headers to prevent request +/// smuggling via CL/TE ambiguity. 
+pub(crate) fn parse_body_length(headers: &str) -> Result { + let mut has_te_chunked = false; + let mut cl_value: Option = None; + + for line in headers.lines().skip(1) { + let lower = line.to_ascii_lowercase(); + if lower.starts_with("transfer-encoding:") { + let val = lower.split_once(':').map_or("", |(_, v)| v.trim()); + if val.split(',').any(|enc| enc.trim() == "chunked") { + has_te_chunked = true; + } + } + if lower.starts_with("content-length:") { + let val = lower.split_once(':').map_or("", |(_, v)| v.trim()); + let len: u64 = val + .parse() + .map_err(|_| miette!("Request contains invalid Content-Length value"))?; + if let Some(prev) = cl_value + && prev != len + { + return Err(miette!( + "Request contains multiple Content-Length headers with differing values ({prev} vs {len})" + )); + } + cl_value = Some(len); + } + } + + if has_te_chunked && cl_value.is_some() { + return Err(miette!( + "Request contains both Transfer-Encoding and Content-Length headers" + )); + } + + if has_te_chunked { + return Ok(BodyLength::Chunked); + } + if let Some(len) = cl_value { + return Ok(BodyLength::ContentLength(len)); + } + Ok(BodyLength::None) +} + +/// Relay exactly `len` bytes from reader to writer. +async fn relay_fixed( + reader: &mut R, + writer: &mut W, + len: u64, + generation_guard: Option<&PolicyGenerationGuard>, +) -> Result<()> +where + R: AsyncRead + Unpin, + W: AsyncWrite + Unpin, +{ + let mut remaining = len; + let mut buf = [0u8; RELAY_BUF_SIZE]; + while remaining > 0 { + let to_read = usize::try_from(remaining) + .unwrap_or(buf.len()) + .min(buf.len()); + let n = reader.read(&mut buf[..to_read]).await.into_diagnostic()?; + if n == 0 { + return Err(miette!( + "Connection closed with {remaining} bytes remaining" + )); + } + if let Some(guard) = generation_guard { + guard.ensure_current()?; + } + writer.write_all(&buf[..n]).await.into_diagnostic()?; + remaining -= n as u64; + } + Ok(()) +} + +/// Relay chunked transfer encoding from reader to writer. 
+/// +/// Copies bytes verbatim (preserving chunk framing) while parsing the stream +/// boundaries so we can stop exactly at the end of the current message body. +/// Handles chunk extensions and trailers per RFC 7230. +/// +/// `already_forwarded` are overflow bytes that were already written to the +/// writer during header parsing. They are seeded into the parser buffer so +/// termination can still be detected when boundaries span reads. +async fn relay_chunked( + reader: &mut R, + writer: &mut W, + already_forwarded: &[u8], + generation_guard: Option<&PolicyGenerationGuard>, +) -> Result<()> +where + R: AsyncRead + Unpin, + W: AsyncWrite + Unpin, +{ + let started_at = std::time::Instant::now(); + let mut read_buf = [0u8; RELAY_BUF_SIZE]; + let mut parse_buf = Vec::from(already_forwarded); + let mut pos = 0usize; + let mut chunk_count = 0usize; + let mut chunk_payload_bytes = 0usize; + + // Parse chunk-size lines + chunk payloads until final 0-size chunk, then + // parse trailers until the terminating empty trailer line. 
+ loop { + // Parse one chunk size line: "[;extensions]\r\n" + let size_line_end = loop { + if let Some(end) = find_crlf(&parse_buf, pos) { + break end; + } + let n = reader.read(&mut read_buf).await.into_diagnostic()?; + if n == 0 { + return Err(miette!("Chunked body ended before chunk-size line")); + } + if let Some(guard) = generation_guard { + guard.ensure_current()?; + } + writer.write_all(&read_buf[..n]).await.into_diagnostic()?; + parse_buf.extend_from_slice(&read_buf[..n]); + }; + + let size_line = std::str::from_utf8(&parse_buf[pos..size_line_end]) + .into_diagnostic() + .map_err(|_| miette!("Invalid UTF-8 in chunk-size line"))?; + let size_token = size_line + .split(';') + .next() + .map(str::trim) + .unwrap_or_default(); + let chunk_size = usize::from_str_radix(size_token, 16) + .into_diagnostic() + .map_err(|_| miette!("Invalid chunk size token: {size_token:?}"))?; + pos = size_line_end + 2; + + if chunk_size == 0 { + // Parse trailers (if any). Terminates on empty trailer line. + let mut trailer_count = 0usize; + loop { + let trailer_end = loop { + if let Some(end) = find_crlf(&parse_buf, pos) { + break end; + } + let n = reader.read(&mut read_buf).await.into_diagnostic()?; + if n == 0 { + return Err(miette!("Chunked body ended before trailer terminator")); + } + if let Some(guard) = generation_guard { + guard.ensure_current()?; + } + writer.write_all(&read_buf[..n]).await.into_diagnostic()?; + parse_buf.extend_from_slice(&read_buf[..n]); + }; + + let trailer_line = &parse_buf[pos..trailer_end]; + pos = trailer_end + 2; + if trailer_line.is_empty() { + debug!( + chunk_count, + chunk_payload_bytes, + trailer_count, + elapsed_ms = started_at.elapsed().as_millis(), + "relay_chunked complete" + ); + return Ok(()); + } + trailer_count += 1; + } + } + + // Ensure the full chunk payload + trailing CRLF is available. 
+ let chunk_end = pos + .checked_add(chunk_size) + .ok_or_else(|| miette!("Chunk size overflow"))?; + let chunk_with_crlf_end = chunk_end + .checked_add(2) + .ok_or_else(|| miette!("Chunk size overflow"))?; + + while parse_buf.len() < chunk_with_crlf_end { + let n = reader.read(&mut read_buf).await.into_diagnostic()?; + if n == 0 { + return Err(miette!("Chunked body ended mid-chunk")); + } + if let Some(guard) = generation_guard { + guard.ensure_current()?; + } + writer.write_all(&read_buf[..n]).await.into_diagnostic()?; + parse_buf.extend_from_slice(&read_buf[..n]); + } + if &parse_buf[chunk_end..chunk_with_crlf_end] != b"\r\n" { + return Err(miette!("Chunk missing terminating CRLF")); + } + pos = chunk_with_crlf_end; + chunk_count += 1; + chunk_payload_bytes = chunk_payload_bytes.saturating_add(chunk_size); + + // Keep parser memory bounded for long streams. + if pos > RELAY_BUF_SIZE * 4 { + parse_buf.drain(..pos); + pos = 0; + } + } +} + +fn find_crlf(buf: &[u8], start: usize) -> Option<usize> { + buf.get(start..)?
+ .windows(2) + .position(|w| w == b"\r\n") + .map(|offset| start + offset) +} + +#[derive(Clone)] +struct RelayResponseOptions { + websocket_extensions: WebSocketExtensionMode, + client_requested_upgrade: bool, + websocket: Option<WebSocketResponseValidation>, +} + +impl Default for RelayResponseOptions { + fn default() -> Self { + Self { + websocket_extensions: WebSocketExtensionMode::Preserve, + client_requested_upgrade: true, + websocket: None, + } + } +} + +async fn relay_response<U, C>( + request_method: &str, + upstream: &mut U, + client: &mut C, + options: RelayResponseOptions, +) -> Result<RelayOutcome> +where + U: AsyncRead + Unpin, + C: AsyncWrite + Unpin, +{ + let started_at = std::time::Instant::now(); + let mut buf = Vec::with_capacity(4096); + let mut tmp = [0u8; 1024]; + + // Read response headers + loop { + if buf.len() > MAX_HEADER_BYTES { + return Err(miette!("HTTP response headers exceed limit")); + } + + let n = upstream.read(&mut tmp).await.into_diagnostic()?; + if n == 0 { + // Upstream closed — forward whatever we have + if !buf.is_empty() { + client.write_all(&buf).await.into_diagnostic()?; + } + return Ok(RelayOutcome::Consumed); + } + buf.extend_from_slice(&tmp[..n]); + + if buf.windows(4).any(|w| w == b"\r\n\r\n") { + break; + } + } + + let header_end = buf.windows(4).position(|w| w == b"\r\n\r\n").unwrap() + 4; + + // Parse response framing + let header_str = String::from_utf8_lossy(&buf[..header_end]); + let status_code = parse_status_code(&header_str).unwrap_or(200); + let server_wants_close = parse_connection_close(&header_str); + let body_length = parse_body_length(&header_str)?; + + debug!( + status_code, + ?body_length, + server_wants_close, + request_method, + overflow_bytes = buf.len() - header_end, + "relay_response framing" + ); + + // 101 Switching Protocols: the connection has been upgraded (e.g. to + // WebSocket). Forward the 101 headers to the client and signal the + // caller to switch to raw bidirectional TCP relay.
Any bytes read + // from upstream beyond the headers are overflow that belong to the + // upgraded protocol and must be forwarded before switching. + if status_code == 101 { + if !options.client_requested_upgrade { + return Ok(RelayOutcome::Consumed); + } + let websocket_permessage_deflate = validate_websocket_response( + &header_str, + options.websocket_extensions, + options.websocket.as_ref(), + )?; + client + .write_all(&buf[..header_end]) + .await + .into_diagnostic()?; + client.flush().await.into_diagnostic()?; + let overflow = buf[header_end..].to_vec(); + debug!( + request_method, + overflow_bytes = overflow.len(), + "101 Switching Protocols — signaling protocol upgrade" + ); + return Ok(RelayOutcome::Upgraded { + overflow, + websocket_permessage_deflate, + }); + } + + // Bodiless responses (HEAD, 1xx, 204, 304): forward headers only, skip body + if is_bodiless_response(request_method, status_code) { + client + .write_all(&buf[..header_end]) + .await + .into_diagnostic()?; + client.flush().await.into_diagnostic()?; + return if server_wants_close { + Ok(RelayOutcome::Consumed) + } else { + Ok(RelayOutcome::Reusable) + }; + } + + // No explicit framing (no Content-Length, no Transfer-Encoding). + // Per RFC 7230 §3.3.3 the body is delimited by connection close. + if matches!(body_length, BodyLength::None) { + if server_wants_close { + // Server indicated it will close — read until EOF. + let before_end = &buf[..header_end - 2]; + client.write_all(before_end).await.into_diagnostic()?; + client + .write_all(b"Connection: close\r\n\r\n") + .await + .into_diagnostic()?; + let overflow = &buf[header_end..]; + if !overflow.is_empty() { + client.write_all(overflow).await.into_diagnostic()?; + } + relay_until_eof(upstream, client).await?; + client.flush().await.into_diagnostic()?; + return Ok(RelayOutcome::Consumed); + } + // No Connection: close — an HTTP/1.1 keep-alive server that omits + // framing headers has an empty body. 
Forward headers and continue + // the relay loop instead of blocking on relay_until_eof. + debug!("BodyLength::None without Connection: close — treating body as empty"); + client + .write_all(&buf[..header_end]) + .await + .into_diagnostic()?; + client.flush().await.into_diagnostic()?; + return Ok(RelayOutcome::Reusable); + } + + // Forward response headers + any overflow body bytes + client.write_all(&buf).await.into_diagnostic()?; + let overflow_len = (buf.len() - header_end) as u64; + + // Forward remaining response body + match body_length { + BodyLength::ContentLength(len) => { + let remaining = len.saturating_sub(overflow_len); + if remaining > 0 { + relay_fixed(upstream, client, remaining, None).await?; + } + } + BodyLength::Chunked => { + relay_chunked(upstream, client, &buf[header_end..], None).await?; + } + BodyLength::None => unreachable!(), + } + client.flush().await.into_diagnostic()?; + debug!( + request_method, + elapsed_ms = started_at.elapsed().as_millis(), + "relay_response complete (explicit framing)" + ); + + // When body framing is explicit (Content-Length / Chunked), always report + // the connection as reusable so the relay loop continues. If the server + // sent `Connection: close`, the *next* upstream write will fail and the + // loop will exit via the normal error path. Exiting early here would + // tear down the CONNECT tunnel before the client can detect the close, + // causing ~30 s retry delays in clients like `gh`. + Ok(RelayOutcome::Reusable) +} + +/// Parse the HTTP status code from a response status line. +/// +/// Expects the first line to look like `HTTP/1.1 200 OK`. +fn parse_status_code(headers: &str) -> Option<u16> { + let status_line = headers.lines().next()?; + let code_str = status_line.split_whitespace().nth(1)?; + code_str.parse().ok() +} + +/// Check if the response headers contain `Connection: close`.
+fn parse_connection_close(headers: &str) -> bool { + for line in headers.lines().skip(1) { + let lower = line.to_ascii_lowercase(); + if lower.starts_with("connection:") { + let val = lower.split_once(':').map_or("", |(_, v)| v.trim()); + return val.contains("close"); + } + } + false +} + +fn validate_websocket_response( + headers: &str, + mode: WebSocketExtensionMode, + websocket: Option<&WebSocketResponseValidation>, +) -> Result<bool> { + let Some(validation) = websocket else { + return validate_websocket_response_extensions_preserved(headers, mode); + }; + + let mut upgrade_websocket = false; + let mut connection_upgrade = false; + let mut accept_count = 0usize; + let mut accept_matches = false; + let mut subprotocol_count = 0usize; + let mut selected_subprotocol = None; + + for line in headers.lines().skip(1) { + let Some((name, value)) = line.split_once(':') else { + continue; + }; + let name = name.trim().to_ascii_lowercase(); + let value = value.trim(); + match name.as_str() { + "upgrade" if header_value_contains_token(value, "websocket") => { + upgrade_websocket = true; + } + "connection" if header_value_contains_token(value, "upgrade") => { + connection_upgrade = true; + } + "sec-websocket-accept" => { + accept_count += 1; + accept_matches = value == validation.expected_accept; + } + "sec-websocket-protocol" => { + subprotocol_count += 1; + if !is_http_token(value) { + return Err(miette!( + "websocket upgrade response has invalid Sec-WebSocket-Protocol" + )); + } + selected_subprotocol = Some(value.to_string()); + } + _ => {} + } + } + + if !upgrade_websocket { + return Err(miette!( + "websocket upgrade response missing Upgrade: websocket" + )); + } + if !connection_upgrade { + return Err(miette!( + "websocket upgrade response missing Connection: Upgrade" + )); + } + if accept_count != 1 || !accept_matches { + return Err(miette!( + "websocket upgrade response has invalid Sec-WebSocket-Accept" + )); + } + if subprotocol_count > 1 { + return Err(miette!( +
"websocket upgrade response has multiple Sec-WebSocket-Protocol headers" + )); + } + if let Some(protocol) = selected_subprotocol + && !validation + .offered_subprotocols + .iter() + .any(|offered| offered == &protocol) + { + return Err(miette!( + "upstream selected WebSocket subprotocol that was not offered" + )); + } + + let actual_extension = normalized_websocket_extension(headers)?; + match (&validation.expected_extension, actual_extension.as_deref()) { + (None, Some(_)) => Err(miette!( + "upstream negotiated WebSocket extension that was not offered" + )), + (None | Some(_), None) => Ok(false), + (Some(expected), Some(actual)) if expected.eq_ignore_ascii_case(actual) => Ok(true), + (Some(_), Some(_)) => Err(miette!( + "upstream negotiated WebSocket extension that does not match the safe offer" + )), + } +} + +fn validate_websocket_response_extensions_preserved( + headers: &str, + mode: WebSocketExtensionMode, +) -> Result { + match mode { + WebSocketExtensionMode::Preserve => Ok(false), + WebSocketExtensionMode::PermessageDeflate => { + let offers = websocket_extension_offers(headers)?; + if offers.is_empty() { + Ok(false) + } else { + Err(miette!( + "upstream negotiated WebSocket extension that was not offered" + )) + } + } + } +} + +fn normalized_websocket_extension(headers: &str) -> Result> { + let offers = websocket_extension_offers(headers)?; + if offers.is_empty() { + return Ok(None); + } + if offers.len() != 1 { + return Err(miette!("upstream negotiated multiple WebSocket extensions")); + } + let offer = &offers[0]; + if !offer.name.eq_ignore_ascii_case("permessage-deflate") { + return Err(miette!( + "upstream negotiated unsupported WebSocket extension" + )); + } + let mut client_no_context_takeover = false; + let mut server_no_context_takeover = false; + let mut seen = HashSet::new(); + for param in &offer.params { + let name = param.name.to_ascii_lowercase(); + if param.value.is_some() || !seen.insert(name.clone()) { + return Err(miette!( + "upstream 
negotiated unsupported permessage-deflate parameter" + )); + } + if name == "client_no_context_takeover" { + client_no_context_takeover = true; + } else if name == "server_no_context_takeover" { + server_no_context_takeover = true; + } else { + return Err(miette!( + "upstream negotiated unsupported permessage-deflate parameter" + )); + } + } + let mut normalized = String::from("permessage-deflate"); + if client_no_context_takeover { + normalized.push_str("; client_no_context_takeover"); + } + if server_no_context_takeover { + normalized.push_str("; server_no_context_takeover"); + } + Ok(Some(normalized)) +} + +/// Check if the client request headers contain both `Upgrade` and +/// `Connection: Upgrade` headers, indicating the client requested a +/// protocol upgrade (e.g. WebSocket). +/// +/// Per RFC 9110 Section 7.8, a server MUST NOT send 101 Switching Protocols +/// unless the client sent these headers. +fn client_requested_upgrade(headers: &str) -> bool { + let mut has_upgrade_header = false; + let mut connection_contains_upgrade = false; + + for line in headers.lines().skip(1) { + let lower = line.to_ascii_lowercase(); + if lower.starts_with("upgrade:") { + has_upgrade_header = true; + } + if lower.starts_with("connection:") { + let val = lower.split_once(':').map_or("", |(_, v)| v.trim()); + // Connection header can have comma-separated values + if val.split(',').any(|tok| tok.trim() == "upgrade") { + connection_contains_upgrade = true; + } + } + } + + has_upgrade_header && connection_contains_upgrade +} + +/// Returns true for responses that MUST NOT contain a message body per RFC 7230 §3.3.3: +/// HEAD responses, 1xx informational, 204 No Content, 304 Not Modified. +fn is_bodiless_response(request_method: &str, status_code: u16) -> bool { + request_method.eq_ignore_ascii_case("HEAD") + || (100..200).contains(&status_code) + || status_code == 204 + || status_code == 304 +} + +/// Relay all bytes from reader to writer until EOF or idle timeout. 
/// /// Used for HTTP responses with no explicit framing (no Content-Length, /// no Transfer-Encoding) where the body is delimited by connection close. @@ -961,21 +1987,404 @@ fn is_benign_close(err: &std::io::Error) -> bool { ) } -#[cfg(test)] -#[allow( - clippy::iter_on_single_items, - clippy::manual_string_new, - clippy::collapsible_if, - clippy::cast_possible_truncation, - reason = "Test code: test fixtures and explicit value-shape assertions are idiomatic in tests." -)] -mod tests { - use super::*; - use crate::opa::OpaEngine; - use crate::secrets::SecretResolver; - use base64::Engine as _; +#[cfg(test)] +#[allow( + clippy::iter_on_single_items, + clippy::manual_string_new, + clippy::collapsible_if, + clippy::cast_possible_truncation, + reason = "Test code: test fixtures and explicit value-shape assertions are idiomatic in tests." +)] +mod tests { + use super::*; + use crate::opa::OpaEngine; + use crate::secrets::SecretResolver; + use flate2::{Compress, Compression, Decompress, FlushCompress, FlushDecompress, Status}; + use std::sync::Arc; + + const TEST_POLICY: &str = include_str!("../../data/sandbox-policy.rego"); + const VALID_WS_KEY: &str = "dGhlIHNhbXBsZSBub25jZQ=="; + const VALID_WS_ACCEPT: &str = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="; + const TEXT_OPCODE: u8 = 0x1; + + #[derive(Debug)] + struct CapturedFrame { + fin_opcode: u8, + masked: bool, + payload: Vec, + } + + async fn read_http_header_block(reader: &mut R) -> Vec { + tokio::time::timeout(std::time::Duration::from_secs(2), async { + let mut header = Vec::new(); + let mut byte = [0u8; 1]; + loop { + reader.read_exact(&mut byte).await.unwrap(); + header.push(byte[0]); + if header.ends_with(b"\r\n\r\n") { + break; + } + } + header + }) + .await + .expect("HTTP header block should arrive") + } + + async fn read_websocket_frame(reader: &mut R) -> CapturedFrame { + tokio::time::timeout(std::time::Duration::from_secs(2), async { + let mut prefix = [0u8; 2]; + reader.read_exact(&mut prefix).await.unwrap(); + 
let masked = prefix[1] & 0x80 != 0; + let mut payload_len = u64::from(prefix[1] & 0x7f); + if payload_len == 126 { + let mut extended = [0u8; 2]; + reader.read_exact(&mut extended).await.unwrap(); + payload_len = u64::from(u16::from_be_bytes(extended)); + } else if payload_len == 127 { + let mut extended = [0u8; 8]; + reader.read_exact(&mut extended).await.unwrap(); + payload_len = u64::from_be_bytes(extended); + } + let mut mask_key = [0u8; 4]; + if masked { + reader.read_exact(&mut mask_key).await.unwrap(); + } + let payload_len = usize::try_from(payload_len).unwrap(); + let mut payload = vec![0u8; payload_len]; + reader.read_exact(&mut payload).await.unwrap(); + if masked { + apply_test_mask(&mut payload, mask_key); + } + CapturedFrame { + fin_opcode: prefix[0], + masked, + payload, + } + }) + .await + .expect("WebSocket frame should arrive") + } + + fn masked_frame_with_rsv(opcode: u8, rsv: u8, payload: &[u8]) -> Vec { + let mask_key = [0x37, 0xfa, 0x21, 0x3d]; + let mut frame = Vec::new(); + frame.push(0x80 | rsv | opcode); + write_test_payload_len(&mut frame, 0x80, payload.len()); + frame.extend_from_slice(&mask_key); + let mut masked = payload.to_vec(); + apply_test_mask(&mut masked, mask_key); + frame.extend_from_slice(&masked); + frame + } + + fn unmasked_frame(opcode: u8, payload: &[u8]) -> Vec { + let mut frame = Vec::new(); + frame.push(0x80 | opcode); + write_test_payload_len(&mut frame, 0, payload.len()); + frame.extend_from_slice(payload); + frame + } + + fn write_test_payload_len(frame: &mut Vec, mask_bit: u8, payload_len: usize) { + if payload_len < 126 { + frame.push(mask_bit | payload_len as u8); + } else if u16::try_from(payload_len).is_ok() { + frame.push(mask_bit | 0x7e); + frame.extend_from_slice(&(payload_len as u16).to_be_bytes()); + } else { + frame.push(mask_bit | 0x7f); + frame.extend_from_slice(&(payload_len as u64).to_be_bytes()); + } + } + + fn apply_test_mask(payload: &mut [u8], mask_key: [u8; 4]) { + for (index, byte) in 
payload.iter_mut().enumerate() { + *byte ^= mask_key[index % 4]; + } + } + + fn compress_test_permessage_deflate(payload: &[u8]) -> Vec { + let mut compressor = Compress::new(Compression::fast(), false); + let mut out = Vec::with_capacity(payload.len().saturating_add(128)); + loop { + let consumed = usize::try_from(compressor.total_in()).unwrap(); + if consumed >= payload.len() { + break; + } + let before_in = compressor.total_in(); + let before_out = compressor.total_out(); + let status = compressor + .compress_vec(&payload[consumed..], &mut out, FlushCompress::None) + .unwrap(); + if matches!(status, Status::BufError) + || (compressor.total_in() == before_in && compressor.total_out() == before_out) + { + out.reserve(out.capacity().max(1024)); + } + } + loop { + out.reserve(64); + let before_out = compressor.total_out(); + compressor + .compress_vec(&[], &mut out, FlushCompress::Sync) + .unwrap(); + if out.ends_with(&[0x00, 0x00, 0xff, 0xff]) { + break; + } + if compressor.total_out() == before_out { + out.reserve(out.capacity().max(1024)); + } + } + out.truncate(out.len() - 4); + out + } + + fn decompress_test_permessage_deflate(payload: &[u8]) -> Vec { + let mut decoder = Decompress::new(false); + let mut input = Vec::with_capacity(payload.len() + 4); + input.extend_from_slice(payload); + input.extend_from_slice(&[0x00, 0x00, 0xff, 0xff]); + let mut out = Vec::new(); + let mut input_pos = 0usize; + let mut scratch = [0u8; RELAY_BUF_SIZE]; + loop { + let before_in = decoder.total_in(); + let before_out = decoder.total_out(); + let status = decoder + .decompress(&input[input_pos..], &mut scratch, FlushDecompress::Sync) + .unwrap(); + let read = usize::try_from(decoder.total_in() - before_in).unwrap(); + let written = usize::try_from(decoder.total_out() - before_out).unwrap(); + input_pos += read; + out.extend_from_slice(&scratch[..written]); + if matches!(status, Status::StreamEnd) { + break; + } + if input_pos >= input.len() && written < scratch.len() { + break; 
+ } + assert!( + read != 0 || written != 0, + "test permessage-deflate decompression did not make progress" + ); + } + out + } + + fn websocket_request(extension: Option<&str>) -> L7Request { + let mut raw_header = format!( + "GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {VALID_WS_KEY}\r\n" + ); + if let Some(extension) = extension { + raw_header.push_str("Sec-WebSocket-Extensions: "); + raw_header.push_str(extension); + raw_header.push_str("\r\n"); + } + raw_header.push_str("Sec-WebSocket-Version: 13\r\n\r\n"); + L7Request { + action: "GET".to_string(), + target: "/ws".to_string(), + query_params: HashMap::new(), + raw_header: raw_header.into_bytes(), + body_length: BodyLength::None, + } + } + + async fn run_upgraded_websocket_case( + request_extension: Option<&'static str>, + response_extension: Option<&'static str>, + extension_mode: WebSocketExtensionMode, + resolver: Option>, + client_frame: Vec, + ) -> (String, CapturedFrame) { + let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(16384); + let (mut client_app, mut proxy_to_client) = tokio::io::duplex(16384); + let req = websocket_request(request_extension); + let resolver_for_header = resolver.clone(); + let resolver_for_upgrade = resolver.clone(); + + let upstream_task = tokio::spawn(async move { + let forwarded = read_http_header_block(&mut upstream_side).await; + let forwarded = String::from_utf8(forwarded).unwrap(); + let mut response = format!( + "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: {VALID_WS_ACCEPT}\r\n" + ); + if let Some(extension) = response_extension { + response.push_str("Sec-WebSocket-Extensions: "); + response.push_str(extension); + response.push_str("\r\n"); + } + response.push_str("\r\n"); + upstream_side.write_all(response.as_bytes()).await.unwrap(); + upstream_side.flush().await.unwrap(); + let frame = read_websocket_frame(&mut upstream_side).await; + 
(forwarded, frame) + }); + + let relay_task = tokio::spawn(async move { + let outcome = relay_http_request_with_options_guarded( + &req, + &mut proxy_to_client, + &mut proxy_to_upstream, + RelayRequestOptions { + resolver: resolver_for_header.as_deref(), + websocket_extensions: extension_mode, + ..Default::default() + }, + ) + .await + .expect("handshake relay should succeed"); + let RelayOutcome::Upgraded { + overflow, + websocket_permessage_deflate, + } = outcome + else { + panic!("expected upgraded relay outcome"); + }; + let credential_rewrite = resolver_for_upgrade.is_some(); + crate::l7::relay::handle_upgrade( + &mut proxy_to_client, + &mut proxy_to_upstream, + overflow, + "example.com", + 443, + crate::l7::relay::UpgradeRelayOptions { + websocket_request: true, + websocket: crate::l7::relay::WebSocketUpgradeBehavior { + credential_rewrite, + permessage_deflate: websocket_permessage_deflate, + ..Default::default() + }, + secret_resolver: resolver_for_upgrade, + target: "/ws".to_string(), + policy_name: "test-policy".to_string(), + ..Default::default() + }, + ) + .await + }); + + let response = read_http_header_block(&mut client_app).await; + assert!( + String::from_utf8_lossy(&response).contains("101 Switching Protocols"), + "client must receive the upgrade before frame relay starts" + ); + client_app.write_all(&client_frame).await.unwrap(); + client_app.flush().await.unwrap(); + + let result = upstream_task.await.expect("upstream task should complete"); + drop(client_app); + let _ = tokio::time::timeout(std::time::Duration::from_secs(2), relay_task).await; + result + } + + #[test] + fn deny_response_body_is_agent_readable_and_redacted() { + // Agent-readable next_steps is gated on the proposals feature flag. 
+ let _proposals = crate::test_helpers::ProposalsFlagGuard::set_blocking(true); + let req = L7Request { + action: "PUT".to_string(), + target: "/repos/NVIDIA/OpenShell/contents/README.md?access_token=secret-token" + .to_string(), + query_params: HashMap::new(), + raw_header: Vec::new(), + body_length: BodyLength::ContentLength(128), + }; + + let body = deny_response_body( + &req, + "github-readonly", + "no matching L7 allow rule", + Some("/repos/NVIDIA/OpenShell/contents/README.md"), + Some(DenyResponseContext { + host: Some("api.github.com"), + port: Some(443), + binary: Some("/usr/bin/gh"), + }), + ); - const TEST_POLICY: &str = include_str!("../../data/sandbox-policy.rego"); + assert_eq!(body["error"], "policy_denied"); + assert_eq!(body["policy"], "github-readonly"); + assert_eq!(body["layer"], "l7"); + assert_eq!(body["protocol"], "rest"); + assert_eq!(body["method"], "PUT"); + assert_eq!(body["host"], "api.github.com"); + assert_eq!(body["port"], 443); + assert_eq!(body["binary"], "/usr/bin/gh"); + assert_eq!(body["path"], "/repos/NVIDIA/OpenShell/contents/README.md"); + assert_eq!( + body["rule"], + "PUT /repos/NVIDIA/OpenShell/contents/README.md" + ); + assert_eq!(body["rule_missing"]["type"], "rest_allow"); + assert_eq!(body["rule_missing"]["layer"], "l7"); + assert_eq!(body["rule_missing"]["method"], "PUT"); + assert_eq!( + body["rule_missing"]["path"], + "/repos/NVIDIA/OpenShell/contents/README.md" + ); + assert_eq!(body["rule_missing"]["host"], "api.github.com"); + assert_eq!(body["rule_missing"]["port"], 443); + assert_eq!(body["rule_missing"]["binary"], "/usr/bin/gh"); + assert_eq!(body["next_steps"][0]["action"], "read_skill"); + assert_eq!( + body["next_steps"][0]["path"], + "/etc/openshell/skills/policy_advisor.md" + ); + assert_eq!(body["next_steps"][3]["body_type"], "PolicyMergeOperation"); + assert!( + !body.to_string().contains("secret-token"), + "deny body must not leak query params or credential values" + ); + } + + #[tokio::test] + async fn 
send_deny_response_writes_structured_json_403() { + // Agent-readable next_steps is gated on the proposals feature flag. + let _proposals = crate::test_helpers::ProposalsFlagGuard::set(true).await; + let (mut client, mut server) = tokio::io::duplex(4096); + let send = tokio::spawn(async move { + let req = L7Request { + action: "POST".to_string(), + target: "/user/repos".to_string(), + query_params: HashMap::new(), + raw_header: Vec::new(), + body_length: BodyLength::ContentLength(64), + }; + send_deny_response( + &req, + "github-readonly", + "no matching L7 allow rule", + &mut server, + None, + Some(DenyResponseContext { + host: Some("api.github.com"), + port: Some(443), + binary: Some("/usr/bin/gh"), + }), + ) + .await + .unwrap(); + }); + + let mut received = Vec::new(); + client.read_to_end(&mut received).await.unwrap(); + send.await.unwrap(); + + let response = String::from_utf8(received).unwrap(); + assert!(response.starts_with("HTTP/1.1 403 Forbidden")); + assert!(response.contains("Content-Type: application/json")); + assert!(response.contains("X-OpenShell-Policy: github-readonly")); + + let (_, body) = response.split_once("\r\n\r\n").unwrap(); + let body: serde_json::Value = serde_json::from_str(body).unwrap(); + assert_eq!(body["error"], "policy_denied"); + assert_eq!(body["method"], "POST"); + assert_eq!(body["path"], "/user/repos"); + assert_eq!(body["rule_missing"]["host"], "api.github.com"); + assert_eq!(body["next_steps"][2]["action"], "inspect_recent_denials"); + } #[test] fn parse_content_length() { @@ -1531,7 +2940,12 @@ mod tests { let result = tokio::time::timeout( std::time::Duration::from_secs(2), - relay_response("GET", &mut upstream_read, &mut client_write), + relay_response( + "GET", + &mut upstream_read, + &mut client_write, + RelayResponseOptions::default(), + ), ) .await .expect("relay_response should not deadlock"); @@ -1572,7 +2986,12 @@ mod tests { let result = tokio::time::timeout( std::time::Duration::from_secs(2), - 
relay_response("GET", &mut upstream_read, &mut client_write), + relay_response( + "GET", + &mut upstream_read, + &mut client_write, + RelayResponseOptions::default(), + ), ) .await .expect("must not block when no Connection: close"); @@ -1608,7 +3027,12 @@ mod tests { let result = tokio::time::timeout( std::time::Duration::from_secs(2), - relay_response("HEAD", &mut upstream_read, &mut client_write), + relay_response( + "HEAD", + &mut upstream_read, + &mut client_write, + RelayResponseOptions::default(), + ), ) .await .expect("HEAD relay must not deadlock waiting for body"); @@ -1623,281 +3047,706 @@ mod tests { let mut received = Vec::new(); client_read.read_to_end(&mut received).await.unwrap(); let received_str = String::from_utf8_lossy(&received); - assert!(received_str.contains("200 OK")); - // Should NOT contain body bytes - assert!(!received_str.contains('\0')); + assert!(received_str.contains("200 OK")); + // Should NOT contain body bytes + assert!(!received_str.contains('\0')); + } + + #[tokio::test] + async fn relay_response_204_no_body() { + let response = b"HTTP/1.1 204 No Content\r\nServer: test\r\n\r\n"; + + let (mut upstream_read, mut upstream_write) = tokio::io::duplex(4096); + let (mut client_read, mut client_write) = tokio::io::duplex(4096); + + tokio::spawn(async move { + upstream_write.write_all(response).await.unwrap(); + }); + + let result = tokio::time::timeout( + std::time::Duration::from_secs(2), + relay_response( + "GET", + &mut upstream_read, + &mut client_write, + RelayResponseOptions::default(), + ), + ) + .await + .expect("204 relay must not deadlock"); + + let outcome = result.expect("relay_response should succeed"); + assert!( + matches!(outcome, RelayOutcome::Reusable), + "204 response should be reusable" + ); + + client_write.shutdown().await.unwrap(); + let mut received = Vec::new(); + client_read.read_to_end(&mut received).await.unwrap(); + assert!(String::from_utf8_lossy(&received).contains("204 No Content")); + } + + 
#[tokio::test] + async fn relay_response_chunked_body_complete_in_overflow() { + // Entire chunked body (including terminal 0\r\n\r\n) arrives with + // headers in the same read. relay_chunked must NOT be called or it + // will block forever waiting for data that was already consumed. + let response = + b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n5\r\nhello\r\n0\r\n\r\n"; + + let (mut upstream_read, mut upstream_write) = tokio::io::duplex(4096); + let (mut client_read, mut client_write) = tokio::io::duplex(4096); + + tokio::spawn(async move { + upstream_write.write_all(response).await.unwrap(); + // Do NOT close — if relay_chunked is called it will block forever + }); + + let result = tokio::time::timeout( + std::time::Duration::from_secs(2), + relay_response( + "GET", + &mut upstream_read, + &mut client_write, + RelayResponseOptions::default(), + ), + ) + .await + .expect("must not block when chunked body is complete in overflow"); + + let outcome = result.expect("relay_response should succeed"); + assert!( + matches!(outcome, RelayOutcome::Reusable), + "connection should be reusable" + ); + + client_write.shutdown().await.unwrap(); + let mut received = Vec::new(); + client_read.read_to_end(&mut received).await.unwrap(); + let received_str = String::from_utf8_lossy(&received); + assert!( + received_str.contains("hello"), + "chunked body should be forwarded" + ); + } + + #[tokio::test] + async fn relay_response_chunked_with_trailers_does_not_wait_for_eof() { + // Last-chunk can be followed by trailers, so body terminator is not + // always literal "0\r\n\r\n". We must stop at final empty trailer + // line without waiting for upstream connection close. 
+ let response = b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n5\r\nhello\r\n0\r\nx-checksum: abc123\r\n\r\n"; + + let (mut upstream_read, mut upstream_write) = tokio::io::duplex(4096); + let (mut client_read, mut client_write) = tokio::io::duplex(4096); + + tokio::spawn(async move { + upstream_write.write_all(response).await.unwrap(); + // Keep stream open to ensure relay terminates by framing, not EOF. + tokio::time::sleep(std::time::Duration::from_secs(10)).await; + }); + + let result = tokio::time::timeout( + std::time::Duration::from_secs(2), + relay_response( + "GET", + &mut upstream_read, + &mut client_write, + RelayResponseOptions::default(), + ), + ) + .await + .expect("must not block when chunked response has trailers"); + + let outcome = result.expect("relay_response should succeed"); + assert!( + matches!(outcome, RelayOutcome::Reusable), + "chunked response should be reusable" + ); + + client_write.shutdown().await.unwrap(); + let mut received = Vec::new(); + client_read.read_to_end(&mut received).await.unwrap(); + let received_str = String::from_utf8_lossy(&received); + assert!( + received_str.contains("hello"), + "chunked body should be forwarded" + ); + assert!( + received_str.contains("x-checksum: abc123"), + "trailers should be forwarded" + ); + } + + #[tokio::test] + async fn relay_response_normal_content_length() { + let response = b"HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello"; + + let (mut upstream_read, mut upstream_write) = tokio::io::duplex(4096); + let (mut client_read, mut client_write) = tokio::io::duplex(4096); + + tokio::spawn(async move { + upstream_write.write_all(response).await.unwrap(); + }); + + let result = tokio::time::timeout( + std::time::Duration::from_secs(2), + relay_response( + "GET", + &mut upstream_read, + &mut client_write, + RelayResponseOptions::default(), + ), + ) + .await + .expect("normal relay must not deadlock"); + + let outcome = result.expect("relay_response should succeed"); + assert!( + 
matches!(outcome, RelayOutcome::Reusable), + "Content-Length response should be reusable" + ); + + client_write.shutdown().await.unwrap(); + let mut received = Vec::new(); + client_read.read_to_end(&mut received).await.unwrap(); + let received_str = String::from_utf8_lossy(&received); + assert!(received_str.contains("hello")); + } + + #[tokio::test] + async fn relay_response_connection_close_with_content_length() { + let response = b"HTTP/1.1 200 OK\r\nContent-Length: 5\r\nConnection: close\r\n\r\nhello"; + + let (mut upstream_read, mut upstream_write) = tokio::io::duplex(4096); + let (mut client_read, mut client_write) = tokio::io::duplex(4096); + + tokio::spawn(async move { + upstream_write.write_all(response).await.unwrap(); + }); + + let result = tokio::time::timeout( + std::time::Duration::from_secs(2), + relay_response( + "GET", + &mut upstream_read, + &mut client_write, + RelayResponseOptions::default(), + ), + ) + .await + .expect("relay must not deadlock"); + + let outcome = result.expect("relay_response should succeed"); + // With explicit framing, Connection: close is still reported as reusable + // so the relay loop continues. The *next* upstream write will fail and + // exit the loop via the normal error path. + assert!( + matches!(outcome, RelayOutcome::Reusable), + "explicit framing keeps loop alive despite Connection: close" + ); + + client_write.shutdown().await.unwrap(); + let mut received = Vec::new(); + client_read.read_to_end(&mut received).await.unwrap(); + assert!(String::from_utf8_lossy(&received).contains("hello")); + } + + #[tokio::test] + async fn relay_response_101_switching_protocols_returns_upgraded_with_overflow() { + // Build a 101 response followed by WebSocket frame data (overflow). 
+ let mut response = Vec::new(); + response.extend_from_slice(b"HTTP/1.1 101 Switching Protocols\r\n"); + response.extend_from_slice(b"Upgrade: websocket\r\n"); + response.extend_from_slice(b"Connection: Upgrade\r\n"); + response.extend_from_slice(b"\r\n"); + response.extend_from_slice(b"\x81\x05hello"); // WebSocket frame + + let (upstream_read, mut upstream_write) = tokio::io::duplex(4096); + let (mut client_read, client_write) = tokio::io::duplex(4096); + + upstream_write.write_all(&response).await.unwrap(); + drop(upstream_write); + + let mut upstream_read = upstream_read; + let mut client_write = client_write; + + let result = tokio::time::timeout( + std::time::Duration::from_secs(2), + relay_response( + "GET", + &mut upstream_read, + &mut client_write, + RelayResponseOptions::default(), + ), + ) + .await + .expect("relay_response should not deadlock"); + + let outcome = result.expect("relay_response should succeed"); + match outcome { + RelayOutcome::Upgraded { overflow, .. } => { + assert_eq!( + &overflow, b"\x81\x05hello", + "overflow should contain WebSocket frame data" + ); + } + other => panic!("Expected Upgraded, got {other:?}"), + } + + client_write.shutdown().await.unwrap(); + let mut received = Vec::new(); + client_read.read_to_end(&mut received).await.unwrap(); + let received_str = String::from_utf8_lossy(&received); + assert!( + received_str.contains("101 Switching Protocols"), + "client should receive the 101 response headers" + ); } #[tokio::test] - async fn relay_response_204_no_body() { - let response = b"HTTP/1.1 204 No Content\r\nServer: test\r\n\r\n"; + async fn relay_response_101_no_overflow() { + // 101 response with no trailing bytes — overflow should be empty. 
+ let response = b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n"; - let (mut upstream_read, mut upstream_write) = tokio::io::duplex(4096); - let (mut client_read, mut client_write) = tokio::io::duplex(4096); + let (upstream_read, mut upstream_write) = tokio::io::duplex(4096); + let (_client_read, client_write) = tokio::io::duplex(4096); - tokio::spawn(async move { - upstream_write.write_all(response).await.unwrap(); - }); + upstream_write.write_all(response).await.unwrap(); + drop(upstream_write); + + let mut upstream_read = upstream_read; + let mut client_write = client_write; let result = tokio::time::timeout( std::time::Duration::from_secs(2), - relay_response("GET", &mut upstream_read, &mut client_write), + relay_response( + "GET", + &mut upstream_read, + &mut client_write, + RelayResponseOptions::default(), + ), ) .await - .expect("204 relay must not deadlock"); - - let outcome = result.expect("relay_response should succeed"); - assert!( - matches!(outcome, RelayOutcome::Reusable), - "204 response should be reusable" - ); + .expect("relay_response should not deadlock"); - client_write.shutdown().await.unwrap(); - let mut received = Vec::new(); - client_read.read_to_end(&mut received).await.unwrap(); - assert!(String::from_utf8_lossy(&received).contains("204 No Content")); + match result.expect("should succeed") { + RelayOutcome::Upgraded { overflow, .. } => { + assert!(overflow.is_empty(), "no overflow expected"); + } + other => panic!("Expected Upgraded, got {other:?}"), + } } #[tokio::test] - async fn relay_response_chunked_body_complete_in_overflow() { - // Entire chunked body (including terminal 0\r\n\r\n) arrives with - // headers in the same read. relay_chunked must NOT be called or it - // will block forever waiting for data that was already consumed. 
- let response = - b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n5\r\nhello\r\n0\r\n\r\n"; + async fn relay_rejects_unsolicited_101_without_client_upgrade_header() { + // Client sends a normal GET without Upgrade headers. + // Upstream responds with 101 (non-compliant). The relay should + // reject the upgrade and return Consumed instead. + let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); + let (mut _app_side, mut proxy_to_client) = tokio::io::duplex(8192); - let (mut upstream_read, mut upstream_write) = tokio::io::duplex(4096); - let (mut client_read, mut client_write) = tokio::io::duplex(4096); + let req = L7Request { + action: "GET".to_string(), + target: "/api".to_string(), + query_params: HashMap::new(), + raw_header: b"GET /api HTTP/1.1\r\nHost: example.com\r\n\r\n".to_vec(), + body_length: BodyLength::None, + }; - tokio::spawn(async move { - upstream_write.write_all(response).await.unwrap(); - // Do NOT close — if relay_chunked is called it will block forever + let upstream_task = tokio::spawn(async move { + // Read the request + let mut buf = vec![0u8; 4096]; + let mut total = 0; + loop { + let n = upstream_side.read(&mut buf[total..]).await.unwrap(); + if n == 0 { + break; + } + total += n; + if buf[..total].windows(4).any(|w| w == b"\r\n\r\n") { + break; + } + } + // Send unsolicited 101 + upstream_side + .write_all( + b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n", + ) + .await + .unwrap(); + upstream_side.flush().await.unwrap(); }); let result = tokio::time::timeout( - std::time::Duration::from_secs(2), - relay_response("GET", &mut upstream_read, &mut client_write), + std::time::Duration::from_secs(5), + relay_http_request_with_resolver( + &req, + &mut proxy_to_client, + &mut proxy_to_upstream, + None, + ), ) .await - .expect("must not block when chunked body is complete in overflow"); + .expect("relay must not deadlock"); - let outcome = result.expect("relay_response should 
succeed"); + let outcome = result.expect("relay should succeed"); assert!( - matches!(outcome, RelayOutcome::Reusable), - "connection should be reusable" + matches!(outcome, RelayOutcome::Consumed), + "unsolicited 101 should be rejected as Consumed, got {outcome:?}" ); - client_write.shutdown().await.unwrap(); - let mut received = Vec::new(); - client_read.read_to_end(&mut received).await.unwrap(); - let received_str = String::from_utf8_lossy(&received); - assert!( - received_str.contains("hello"), - "chunked body should be forwarded" - ); + upstream_task.await.expect("upstream task should complete"); } #[tokio::test] - async fn relay_response_chunked_with_trailers_does_not_wait_for_eof() { - // Last-chunk can be followed by trailers, so body terminator is not - // always literal "0\r\n\r\n". We must stop at final empty trailer - // line without waiting for upstream connection close. - let response = b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n5\r\nhello\r\n0\r\nx-checksum: abc123\r\n\r\n"; + async fn relay_accepts_101_with_client_upgrade_header() { + // Client sends a proper upgrade request with Upgrade + Connection headers. + let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); + let (mut _app_side, mut proxy_to_client) = tokio::io::duplex(8192); - let (mut upstream_read, mut upstream_write) = tokio::io::duplex(4096); - let (mut client_read, mut client_write) = tokio::io::duplex(4096); + let req = L7Request { + action: "GET".to_string(), + target: "/ws".to_string(), + query_params: HashMap::new(), + raw_header: b"GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n".to_vec(), + body_length: BodyLength::None, + }; - tokio::spawn(async move { - upstream_write.write_all(response).await.unwrap(); - // Keep stream open to ensure relay terminates by framing, not EOF. 
- tokio::time::sleep(std::time::Duration::from_secs(10)).await; + let upstream_task = tokio::spawn(async move { + let mut buf = vec![0u8; 4096]; + let mut total = 0; + loop { + let n = upstream_side.read(&mut buf[total..]).await.unwrap(); + if n == 0 { + break; + } + total += n; + if buf[..total].windows(4).any(|w| w == b"\r\n\r\n") { + break; + } + } + upstream_side + .write_all( + b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n", + ) + .await + .unwrap(); + upstream_side.flush().await.unwrap(); }); let result = tokio::time::timeout( - std::time::Duration::from_secs(2), - relay_response("GET", &mut upstream_read, &mut client_write), + std::time::Duration::from_secs(5), + relay_http_request_with_resolver( + &req, + &mut proxy_to_client, + &mut proxy_to_upstream, + None, + ), ) .await - .expect("must not block when chunked response has trailers"); + .expect("relay must not deadlock"); - let outcome = result.expect("relay_response should succeed"); + let outcome = result.expect("relay should succeed"); assert!( - matches!(outcome, RelayOutcome::Reusable), - "chunked response should be reusable" + matches!(outcome, RelayOutcome::Upgraded { .. 
}), + "proper upgrade request should be accepted, got {outcome:?}" ); - client_write.shutdown().await.unwrap(); - let mut received = Vec::new(); - client_read.read_to_end(&mut received).await.unwrap(); - let received_str = String::from_utf8_lossy(&received); + upstream_task.await.expect("upstream task should complete"); + } + + #[tokio::test] + async fn opted_in_websocket_relay_rejects_invalid_upgrade_before_upstream_write() { + let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); + let (mut _app_side, mut proxy_to_client) = tokio::io::duplex(8192); + + let req = L7Request { + action: "GET".to_string(), + target: "/ws".to_string(), + query_params: HashMap::new(), + raw_header: b"GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Version: 13\r\n\r\n".to_vec(), + body_length: BodyLength::None, + }; + + let result = relay_http_request_with_options_guarded( + &req, + &mut proxy_to_client, + &mut proxy_to_upstream, + RelayRequestOptions { + websocket_extensions: WebSocketExtensionMode::PermessageDeflate, + ..Default::default() + }, + ) + .await; + assert!( - received_str.contains("hello"), - "chunked body should be forwarded" + result.is_err(), + "missing Sec-WebSocket-Key must fail closed" ); + drop(proxy_to_upstream); + let mut forwarded = Vec::new(); + upstream_side.read_to_end(&mut forwarded).await.unwrap(); assert!( - received_str.contains("x-checksum: abc123"), - "trailers should be forwarded" + forwarded.is_empty(), + "invalid opted-in upgrade must not reach upstream" ); } #[tokio::test] - async fn relay_response_normal_content_length() { - let response = b"HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello"; + async fn opted_in_websocket_relay_strips_request_extensions_and_rejects_response_extensions() { + let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); + let (mut app_side, mut proxy_to_client) = tokio::io::duplex(8192); - let (mut upstream_read, mut upstream_write) = 
tokio::io::duplex(4096); - let (mut client_read, mut client_write) = tokio::io::duplex(4096); + let req = L7Request { + action: "GET".to_string(), + target: "/ws".to_string(), + query_params: HashMap::new(), + raw_header: format!( + "GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {VALID_WS_KEY}\r\nSec-WebSocket-Extensions: permessage-deflate\r\nSec-WebSocket-Version: 13\r\n\r\n" + ) + .into_bytes(), + body_length: BodyLength::None, + }; - tokio::spawn(async move { - upstream_write.write_all(response).await.unwrap(); + let upstream_task = tokio::spawn(async move { + let mut buf = vec![0u8; 4096]; + let mut total = 0; + loop { + let n = upstream_side.read(&mut buf[total..]).await.unwrap(); + if n == 0 { + break; + } + total += n; + if buf[..total].windows(4).any(|w| w == b"\r\n\r\n") { + break; + } + } + let forwarded = String::from_utf8_lossy(&buf[..total]); + assert!( + !forwarded + .to_ascii_lowercase() + .contains("sec-websocket-extensions"), + "opted-in request must strip extension negotiation" + ); + upstream_side + .write_all( + format!( + "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: {VALID_WS_ACCEPT}\r\nSec-WebSocket-Extensions: permessage-deflate\r\n\r\n" + ) + .as_bytes(), + ) + .await + .unwrap(); + upstream_side.flush().await.unwrap(); }); - let result = tokio::time::timeout( - std::time::Duration::from_secs(2), - relay_response("GET", &mut upstream_read, &mut client_write), - ) - .await - .expect("normal relay must not deadlock"); + let result = relay_http_request_with_options_guarded( + &req, + &mut proxy_to_client, + &mut proxy_to_upstream, + RelayRequestOptions { + websocket_extensions: WebSocketExtensionMode::PermessageDeflate, + ..Default::default() + }, + ) + .await; - let outcome = result.expect("relay_response should succeed"); - assert!( - matches!(outcome, RelayOutcome::Reusable), - "Content-Length response should be reusable" - 
); + let err = result.expect_err("upstream extension negotiation must fail closed"); + assert!(err.to_string().contains("not offered")); + upstream_task.await.expect("upstream task should complete"); - client_write.shutdown().await.unwrap(); + drop(proxy_to_client); let mut received = Vec::new(); - client_read.read_to_end(&mut received).await.unwrap(); - let received_str = String::from_utf8_lossy(&received); - assert!(received_str.contains("hello")); + app_side.read_to_end(&mut received).await.unwrap(); + assert!( + received.is_empty(), + "rejected extension negotiation must not forward 101 headers" + ); } #[tokio::test] - async fn relay_response_connection_close_with_content_length() { - let response = b"HTTP/1.1 200 OK\r\nContent-Length: 5\r\nConnection: close\r\n\r\nhello"; + async fn permessage_deflate_mode_allows_supported_no_context_takeover() { + let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); + let (mut _app_side, mut proxy_to_client) = tokio::io::duplex(8192); - let (mut upstream_read, mut upstream_write) = tokio::io::duplex(4096); - let (mut client_read, mut client_write) = tokio::io::duplex(4096); + let req = L7Request { + action: "GET".to_string(), + target: "/ws".to_string(), + query_params: HashMap::new(), + raw_header: format!( + "GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {VALID_WS_KEY}\r\nSec-WebSocket-Extensions: permessage-deflate; client_no_context_takeover; server_no_context_takeover\r\nSec-WebSocket-Version: 13\r\n\r\n" + ) + .into_bytes(), + body_length: BodyLength::None, + }; - tokio::spawn(async move { - upstream_write.write_all(response).await.unwrap(); + let upstream_task = tokio::spawn(async move { + let mut buf = vec![0u8; 4096]; + let mut total = 0; + loop { + let n = upstream_side.read(&mut buf[total..]).await.unwrap(); + if n == 0 { + break; + } + total += n; + if buf[..total].windows(4).any(|w| w == b"\r\n\r\n") { + break; + } + } + let forwarded 
= String::from_utf8_lossy(&buf[..total]).to_ascii_lowercase(); + assert!(forwarded.contains( + "sec-websocket-extensions: permessage-deflate; client_no_context_takeover; server_no_context_takeover" + )); + upstream_side + .write_all( + format!( + "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: {VALID_WS_ACCEPT}\r\nSec-WebSocket-Extensions: permessage-deflate; client_no_context_takeover; server_no_context_takeover\r\n\r\n" + ) + .as_bytes(), + ) + .await + .unwrap(); + upstream_side.flush().await.unwrap(); }); - let result = tokio::time::timeout( - std::time::Duration::from_secs(2), - relay_response("GET", &mut upstream_read, &mut client_write), + let outcome = relay_http_request_with_options_guarded( + &req, + &mut proxy_to_client, + &mut proxy_to_upstream, + RelayRequestOptions { + websocket_extensions: WebSocketExtensionMode::PermessageDeflate, + ..Default::default() + }, ) .await - .expect("relay must not deadlock"); + .expect("safe permessage-deflate negotiation should pass"); - let outcome = result.expect("relay_response should succeed"); - // With explicit framing, Connection: close is still reported as reusable - // so the relay loop continues. The *next* upstream write will fail and - // exit the loop via the normal error path. assert!( - matches!(outcome, RelayOutcome::Reusable), - "explicit framing keeps loop alive despite Connection: close" + matches!( + outcome, + RelayOutcome::Upgraded { + websocket_permessage_deflate: true, + .. 
+ } + ), + "safe permessage-deflate must be marked negotiated" ); - - client_write.shutdown().await.unwrap(); - let mut received = Vec::new(); - client_read.read_to_end(&mut received).await.unwrap(); - assert!(String::from_utf8_lossy(&received).contains("hello")); + upstream_task.await.expect("upstream task should complete"); } #[tokio::test] - async fn relay_response_101_switching_protocols_returns_upgraded_with_overflow() { - // Build a 101 response followed by WebSocket frame data (overflow). - let mut response = Vec::new(); - response.extend_from_slice(b"HTTP/1.1 101 Switching Protocols\r\n"); - response.extend_from_slice(b"Upgrade: websocket\r\n"); - response.extend_from_slice(b"Connection: Upgrade\r\n"); - response.extend_from_slice(b"\r\n"); - response.extend_from_slice(b"\x81\x05hello"); // WebSocket frame - - let (upstream_read, mut upstream_write) = tokio::io::duplex(4096); - let (mut client_read, client_write) = tokio::io::duplex(4096); - - upstream_write.write_all(&response).await.unwrap(); - drop(upstream_write); - - let mut upstream_read = upstream_read; - let mut client_write = client_write; - - let result = tokio::time::timeout( - std::time::Duration::from_secs(2), - relay_response("GET", &mut upstream_read, &mut client_write), + async fn websocket_conformance_preserve_mode_relays_raw_frames_without_validation() { + let (forwarded, frame) = run_upgraded_websocket_case( + None, + None, + WebSocketExtensionMode::Preserve, + None, + unmasked_frame(TEXT_OPCODE, b"raw-unmasked"), ) - .await - .expect("relay_response should not deadlock"); - - let outcome = result.expect("relay_response should succeed"); - match outcome { - RelayOutcome::Upgraded { overflow } => { - assert_eq!( - &overflow, b"\x81\x05hello", - "overflow should contain WebSocket frame data" - ); - } - other => panic!("Expected Upgraded, got {other:?}"), - } + .await; - client_write.shutdown().await.unwrap(); - let mut received = Vec::new(); - client_read.read_to_end(&mut 
received).await.unwrap(); - let received_str = String::from_utf8_lossy(&received); assert!( - received_str.contains("101 Switching Protocols"), - "client should receive the 101 response headers" + forwarded.contains("Upgrade: websocket"), + "raw preserve path should still forward the upgrade request" + ); + assert!( + !frame.masked, + "raw preserve path must not validate or rewrite client frame masking" ); + assert_eq!(frame.fin_opcode & 0x0f, TEXT_OPCODE); + assert_eq!(frame.payload, b"raw-unmasked"); } #[tokio::test] - async fn relay_response_101_no_overflow() { - // 101 response with no trailing bytes — overflow should be empty. - let response = b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n"; - - let (upstream_read, mut upstream_write) = tokio::io::duplex(4096); - let (_client_read, client_write) = tokio::io::duplex(4096); + async fn websocket_conformance_rewrite_mode_rewrites_text_after_upgrade() { + let (child_env, resolver) = SecretResolver::from_provider_env( + [("DISCORD_BOT_TOKEN".to_string(), "real-token".to_string())] + .into_iter() + .collect(), + ); + let placeholder = child_env.get("DISCORD_BOT_TOKEN").unwrap(); + let payload = format!(r#"{{"op":2,"d":{{"token":"{placeholder}"}}}}"#); - upstream_write.write_all(response).await.unwrap(); - drop(upstream_write); + let (forwarded, frame) = run_upgraded_websocket_case( + None, + None, + WebSocketExtensionMode::PermessageDeflate, + resolver.map(Arc::new), + masked_frame_with_rsv(TEXT_OPCODE, 0, payload.as_bytes()), + ) + .await; - let mut upstream_read = upstream_read; - let mut client_write = client_write; + assert!( + !forwarded + .to_ascii_lowercase() + .contains("sec-websocket-extensions"), + "plain rewrite path should not offer compression when the client did not offer a safe subset" + ); + assert!(frame.masked, "parsed relay must preserve client masking"); + assert_eq!(frame.fin_opcode & 0x0f, TEXT_OPCODE); + assert_eq!( + 
String::from_utf8(frame.payload).unwrap(), + r#"{"op":2,"d":{"token":"real-token"}}"# + ); + } - let result = tokio::time::timeout( - std::time::Duration::from_secs(2), - relay_response("GET", &mut upstream_read, &mut client_write), + #[tokio::test] + async fn websocket_conformance_deflate_rewrites_compressed_text_after_upgrade() { + let (child_env, resolver) = SecretResolver::from_provider_env( + [("DISCORD_BOT_TOKEN".to_string(), "real-token".to_string())] + .into_iter() + .collect(), + ); + let placeholder = child_env.get("DISCORD_BOT_TOKEN").unwrap(); + let payload = format!(r#"{{"op":2,"d":{{"token":"{placeholder}"}}}}"#); + let compressed = compress_test_permessage_deflate(payload.as_bytes()); + + let (forwarded, frame) = run_upgraded_websocket_case( + Some("permessage-deflate; server_no_context_takeover; client_no_context_takeover"), + Some("permessage-deflate; server_no_context_takeover; client_no_context_takeover"), + WebSocketExtensionMode::PermessageDeflate, + resolver.map(Arc::new), + masked_frame_with_rsv(TEXT_OPCODE, 0x40, &compressed), ) - .await - .expect("relay_response should not deadlock"); + .await; - match result.expect("should succeed") { - RelayOutcome::Upgraded { overflow } => { - assert!(overflow.is_empty(), "no overflow expected"); - } - other => panic!("Expected Upgraded, got {other:?}"), - } + assert!( + forwarded.to_ascii_lowercase().contains( + "sec-websocket-extensions: permessage-deflate; client_no_context_takeover; server_no_context_takeover" + ), + "safe extension offer should be canonicalized before forwarding" + ); + assert!(frame.masked, "parsed relay must preserve client masking"); + assert_eq!(frame.fin_opcode & 0x0f, TEXT_OPCODE); + assert!( + frame.fin_opcode & 0x40 != 0, + "rewritten compressed text must retain RSV1" + ); + assert_eq!( + String::from_utf8(decompress_test_permessage_deflate(&frame.payload)).unwrap(), + r#"{"op":2,"d":{"token":"real-token"}}"# + ); } #[tokio::test] - async fn 
relay_rejects_unsolicited_101_without_client_upgrade_header() { - // Client sends a normal GET without Upgrade headers. - // Upstream responds with 101 (non-compliant). The relay should - // reject the upgrade and return Consumed instead. + async fn opted_in_websocket_relay_rejects_invalid_accept_before_forwarding_101() { let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); - let (mut _app_side, mut proxy_to_client) = tokio::io::duplex(8192); + let (mut app_side, mut proxy_to_client) = tokio::io::duplex(8192); let req = L7Request { action: "GET".to_string(), - target: "/api".to_string(), + target: "/ws".to_string(), query_params: HashMap::new(), - raw_header: b"GET /api HTTP/1.1\r\nHost: example.com\r\n\r\n".to_vec(), + raw_header: format!( + "GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {VALID_WS_KEY}\r\nSec-WebSocket-Version: 13\r\n\r\n" + ) + .into_bytes(), body_length: BodyLength::None, }; let upstream_task = tokio::spawn(async move { - // Read the request let mut buf = vec![0u8; 4096]; let mut total = 0; loop { @@ -1910,92 +3759,249 @@ mod tests { break; } } - // Send unsolicited 101 upstream_side .write_all( - b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n", + b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: invalid\r\n\r\n", ) .await .unwrap(); upstream_side.flush().await.unwrap(); }); - let result = tokio::time::timeout( - std::time::Duration::from_secs(5), - relay_http_request_with_resolver( - &req, - &mut proxy_to_client, - &mut proxy_to_upstream, - None, + let result = relay_http_request_with_options_guarded( + &req, + &mut proxy_to_client, + &mut proxy_to_upstream, + RelayRequestOptions { + websocket_extensions: WebSocketExtensionMode::PermessageDeflate, + ..Default::default() + }, + ) + .await; + + let err = result.expect_err("invalid Sec-WebSocket-Accept must fail closed"); + 
assert!(err.to_string().contains("Sec-WebSocket-Accept")); + upstream_task.await.expect("upstream task should complete"); + + drop(proxy_to_client); + let mut received = Vec::new(); + app_side.read_to_end(&mut received).await.unwrap(); + assert!( + received.is_empty(), + "invalid websocket response must not forward 101 headers" + ); + } + + #[test] + fn websocket_accept_matches_rfc_6455_sample() { + assert_eq!(websocket_accept_for_key(VALID_WS_KEY), VALID_WS_ACCEPT); + } + + #[test] + fn strict_response_validation_rejects_missing_upgrade_headers() { + let validation = WebSocketResponseValidation { + expected_accept: VALID_WS_ACCEPT.to_string(), + expected_extension: None, + offered_subprotocols: Vec::new(), + }; + + let err = validate_websocket_response( + "HTTP/1.1 101 Switching Protocols\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\r\n", + WebSocketExtensionMode::PermessageDeflate, + Some(&validation), + ) + .expect_err("missing Upgrade/Connection must fail"); + + assert!(err.to_string().contains("Upgrade: websocket")); + } + + #[test] + fn permessage_deflate_response_must_match_exact_safe_offer() { + let validation = WebSocketResponseValidation { + expected_accept: VALID_WS_ACCEPT.to_string(), + expected_extension: Some( + "permessage-deflate; client_no_context_takeover; server_no_context_takeover" + .to_string(), + ), + offered_subprotocols: Vec::new(), + }; + + let err = validate_websocket_response( + "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\nSec-WebSocket-Extensions: permessage-deflate; client_no_context_takeover\r\n\r\n", + WebSocketExtensionMode::PermessageDeflate, + Some(&validation), + ) + .expect_err("extension response must exactly match the safe offer"); + + assert!(err.to_string().contains("safe offer")); + } + + #[test] + fn permessage_deflate_offer_requires_client_no_context_takeover() { + let raw = format!( + "GET /ws HTTP/1.1\r\nHost: 
example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {VALID_WS_KEY}\r\nSec-WebSocket-Extensions: permessage-deflate; client_max_window_bits\r\nSec-WebSocket-Version: 13\r\n\r\n" + ); + assert!( + supported_permessage_deflate_offer(&raw) + .expect("valid unsupported extension offer should parse") + .is_none() + ); + } + + #[test] + fn permessage_deflate_offer_canonicalizes_safe_params() { + let raw = format!( + "GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {VALID_WS_KEY}\r\nSec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\nSec-WebSocket-Version: 13\r\n\r\n" + ); + + assert_eq!( + supported_permessage_deflate_offer(&raw) + .expect("safe extension offer should parse") + .as_deref(), + Some("permessage-deflate; client_no_context_takeover; server_no_context_takeover") + ); + } + + #[test] + fn permessage_deflate_offer_rejects_duplicate_safe_params() { + let raw = format!( + "GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {VALID_WS_KEY}\r\nSec-WebSocket-Extensions: permessage-deflate; client_no_context_takeover; client_no_context_takeover\r\nSec-WebSocket-Version: 13\r\n\r\n" + ); + + assert!( + supported_permessage_deflate_offer(&raw) + .expect("duplicate safe param should parse but not be supported") + .is_none() + ); + } + + #[test] + fn permessage_deflate_offer_rejects_quoted_values() { + let raw = format!( + "GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {VALID_WS_KEY}\r\nSec-WebSocket-Extensions: permessage-deflate; client_no_context_takeover=\"true\"\r\nSec-WebSocket-Version: 13\r\n\r\n" + ); + + let err = supported_permessage_deflate_offer(&raw) + .expect_err("quoted permessage-deflate parameter values should fail closed"); + assert!(err.to_string().contains("parameter value")); + } + + #[test] + fn 
permessage_deflate_response_accepts_reordered_safe_params() { + let validation = WebSocketResponseValidation { + expected_accept: VALID_WS_ACCEPT.to_string(), + expected_extension: Some( + "permessage-deflate; client_no_context_takeover; server_no_context_takeover" + .to_string(), ), + offered_subprotocols: Vec::new(), + }; + + let negotiated = validate_websocket_response( + "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\nSec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n\r\n", + WebSocketExtensionMode::PermessageDeflate, + Some(&validation), + ) + .expect("reordered safe extension params should canonicalize"); + + assert!(negotiated); + } + + #[test] + fn permessage_deflate_response_rejects_duplicate_safe_params() { + let validation = WebSocketResponseValidation { + expected_accept: VALID_WS_ACCEPT.to_string(), + expected_extension: Some("permessage-deflate; client_no_context_takeover".to_string()), + offered_subprotocols: Vec::new(), + }; + + let err = validate_websocket_response( + "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\nSec-WebSocket-Extensions: permessage-deflate; client_no_context_takeover; client_no_context_takeover\r\n\r\n", + WebSocketExtensionMode::PermessageDeflate, + Some(&validation), + ) + .expect_err("duplicate extension params should fail closed"); + + assert!(err.to_string().contains("unsupported permessage-deflate")); + } + + #[test] + fn preserve_mode_leaves_malformed_extension_response_raw() { + let negotiated = validate_websocket_response( + "HTTP/1.1 101 Switching Protocols\r\nSec-WebSocket-Extensions: permessage-deflate; client_no_context_takeover=\"true\"\r\n\r\n", + WebSocketExtensionMode::Preserve, + None, ) - .await - .expect("relay must not deadlock"); + .expect("preserve mode should not parse or reject 
raw extension negotiation"); - let outcome = result.expect("relay should succeed"); - assert!( - matches!(outcome, RelayOutcome::Consumed), - "unsolicited 101 should be rejected as Consumed, got {outcome:?}" + assert!(!negotiated); + } + + #[test] + fn parse_websocket_upgrade_request_tracks_subprotocols() { + let raw = format!( + "GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {VALID_WS_KEY}\r\nSec-WebSocket-Protocol: chat, superchat\r\nSec-WebSocket-Version: 13\r\n\r\n" ); - upstream_task.await.expect("upstream task should complete"); - } + let request = parse_websocket_upgrade_request(raw.as_bytes()) + .expect("request should parse") + .expect("request should be websocket"); - #[tokio::test] - async fn relay_accepts_101_with_client_upgrade_header() { - // Client sends a proper upgrade request with Upgrade + Connection headers. - let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); - let (mut _app_side, mut proxy_to_client) = tokio::io::duplex(8192); + assert_eq!(request.subprotocols, ["chat", "superchat"]); + } - let req = L7Request { - action: "GET".to_string(), - target: "/ws".to_string(), - query_params: HashMap::new(), - raw_header: b"GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n".to_vec(), - body_length: BodyLength::None, + #[test] + fn strict_response_validation_allows_offered_subprotocol() { + let validation = WebSocketResponseValidation { + expected_accept: VALID_WS_ACCEPT.to_string(), + expected_extension: None, + offered_subprotocols: vec!["chat".to_string(), "superchat".to_string()], }; - let upstream_task = tokio::spawn(async move { - let mut buf = vec![0u8; 4096]; - let mut total = 0; - loop { - let n = upstream_side.read(&mut buf[total..]).await.unwrap(); - if n == 0 { - break; - } - total += n; - if buf[..total].windows(4).any(|w| w == b"\r\n\r\n") { - break; - } - } - upstream_side - .write_all( - b"HTTP/1.1 101 Switching 
Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n", - ) - .await - .unwrap(); - upstream_side.flush().await.unwrap(); - }); + let negotiated = validate_websocket_response( + "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\nSec-WebSocket-Protocol: superchat\r\n\r\n", + WebSocketExtensionMode::PermessageDeflate, + Some(&validation), + ) + .expect("offered subprotocol should validate"); - let result = tokio::time::timeout( - std::time::Duration::from_secs(5), - relay_http_request_with_resolver( - &req, - &mut proxy_to_client, - &mut proxy_to_upstream, - None, - ), + assert!(!negotiated); + } + + #[test] + fn strict_response_validation_rejects_unoffered_subprotocol() { + let validation = WebSocketResponseValidation { + expected_accept: VALID_WS_ACCEPT.to_string(), + expected_extension: None, + offered_subprotocols: vec!["chat".to_string()], + }; + + let err = validate_websocket_response( + "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\nSec-WebSocket-Protocol: admin\r\n\r\n", + WebSocketExtensionMode::PermessageDeflate, + Some(&validation), ) - .await - .expect("relay must not deadlock"); + .expect_err("unoffered subprotocol should fail closed"); - let outcome = result.expect("relay should succeed"); - assert!( - matches!(outcome, RelayOutcome::Upgraded { .. 
}), - "proper upgrade request should be accepted, got {outcome:?}" - ); + assert!(err.to_string().contains("subprotocol")); + } - upstream_task.await.expect("upstream task should complete"); + #[test] + fn strict_response_validation_rejects_multiple_subprotocol_headers() { + let validation = WebSocketResponseValidation { + expected_accept: VALID_WS_ACCEPT.to_string(), + expected_extension: None, + offered_subprotocols: vec!["chat".to_string(), "superchat".to_string()], + }; + + let err = validate_websocket_response( + "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\nSec-WebSocket-Protocol: chat\r\nSec-WebSocket-Protocol: superchat\r\n\r\n", + WebSocketExtensionMode::PermessageDeflate, + Some(&validation), + ) + .expect_err("multiple selected subprotocols should fail closed"); + + assert!(err.to_string().contains("Sec-WebSocket-Protocol")); } #[tokio::test] @@ -2063,6 +4069,94 @@ mod tests { assert!(client_requested_upgrade(headers)); } + #[test] + fn request_is_websocket_upgrade_detects_websocket_upgrade() { + let raw = format!( + "GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: keep-alive, Upgrade\r\nSec-WebSocket-Key: {VALID_WS_KEY}\r\nSec-WebSocket-Version: 13\r\n\r\n" + ); + assert!(request_is_websocket_upgrade(raw.as_bytes())); + } + + #[test] + fn request_is_websocket_upgrade_rejects_missing_key() { + let raw = b"GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Version: 13\r\n\r\n"; + assert!(!request_is_websocket_upgrade(raw)); + assert!(validate_websocket_upgrade_request(raw).is_err()); + } + + #[test] + fn request_is_websocket_upgrade_rejects_wrong_method() { + let raw = format!( + "POST /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {VALID_WS_KEY}\r\nSec-WebSocket-Version: 13\r\n\r\n" + ); + 
assert!(!request_is_websocket_upgrade(raw.as_bytes())); + assert!(validate_websocket_upgrade_request(raw.as_bytes()).is_err()); + } + + #[test] + fn request_is_websocket_upgrade_rejects_wrong_version() { + let raw = format!( + "GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {VALID_WS_KEY}\r\nSec-WebSocket-Version: 12\r\n\r\n" + ); + assert!(!request_is_websocket_upgrade(raw.as_bytes())); + assert!(validate_websocket_upgrade_request(raw.as_bytes()).is_err()); + } + + #[test] + fn validate_websocket_upgrade_ignores_plain_rest_request() { + let raw = b"GET /api HTTP/1.1\r\nHost: example.com\r\n\r\n"; + assert!(!request_is_websocket_upgrade(raw)); + assert!(!validate_websocket_upgrade_request(raw).expect("plain request should parse")); + } + + #[test] + fn validate_websocket_upgrade_ignores_non_websocket_upgrade() { + let raw = b"GET /h2c HTTP/1.1\r\nHost: example.com\r\nUpgrade: h2c\r\nConnection: Upgrade\r\n\r\n"; + assert!(!request_is_websocket_upgrade(raw)); + assert!(!validate_websocket_upgrade_request(raw).expect("h2c request should parse")); + } + + #[test] + fn strip_websocket_extensions_removes_extension_negotiation() { + let raw = format!( + "GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {VALID_WS_KEY}\r\nSec-WebSocket-Extensions: permessage-deflate; client_max_window_bits\r\nSec-WebSocket-Version: 13\r\n\r\n" + ); + + let (stripped, offered) = rewrite_websocket_extensions_for_mode( + raw.as_bytes(), + WebSocketExtensionMode::PermessageDeflate, + true, + ) + .expect("strip should succeed"); + assert!(offered.is_none()); + let stripped = String::from_utf8(stripped).unwrap(); + + assert!(stripped.contains("Upgrade: websocket\r\n")); + assert!(stripped.contains("Sec-WebSocket-Key: ")); + assert!(stripped.contains("Sec-WebSocket-Version: 13\r\n")); + assert!( + !stripped + .to_ascii_lowercase() + .contains("sec-websocket-extensions") + ); + 
assert!(stripped.ends_with("\r\n\r\n")); + } + + #[test] + fn strip_websocket_extensions_leaves_non_websocket_request_unchanged() { + let raw = b"GET /api HTTP/1.1\r\nHost: example.com\r\nSec-WebSocket-Extensions: permessage-deflate\r\n\r\n"; + + let (stripped, offered) = rewrite_websocket_extensions_for_mode( + raw, + WebSocketExtensionMode::PermessageDeflate, + false, + ) + .expect("strip should succeed"); + + assert!(offered.is_none()); + assert_eq!(stripped, raw); + } + #[test] fn rewrite_header_block_resolves_placeholder_auth_headers() { let (_, resolver) = SecretResolver::from_provider_env( @@ -2334,6 +4428,257 @@ mod tests { Ok(forwarded) } + async fn relay_and_capture_with_options( + raw_header: Vec<u8>, + body_length: BodyLength, + resolver: Option<&SecretResolver>, + request_body_credential_rewrite: bool, + ) -> Result<String> { + let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); + let (mut _app_side, mut proxy_to_client) = tokio::io::duplex(8192); + + let header_str = String::from_utf8_lossy(&raw_header); + let first_line = header_str.lines().next().unwrap_or(""); + let parts: Vec<&str> = first_line.splitn(3, ' ').collect(); + let action = parts.first().unwrap_or(&"GET").to_string(); + let target = parts.get(1).unwrap_or(&"/").to_string(); + + let req = L7Request { + action, + target, + query_params: HashMap::new(), + raw_header, + body_length, + }; + + let upstream_task = tokio::spawn(async move { + let mut buf = vec![0u8; 8192]; + let mut total = 0usize; + let mut header_end = None; + let mut expected_total = None; + loop { + let n = upstream_side.read(&mut buf[total..]).await.unwrap(); + if n == 0 { + break; + } + total += n; + if header_end.is_none() + && let Some(end) = buf[..total].windows(4).position(|w| w == b"\r\n\r\n") + { + let end = end + 4; + let headers = String::from_utf8_lossy(&buf[..end]); + let len = headers + .lines() + .find_map(|line| { + let (name, value) = line.split_once(':')?; + 
name.eq_ignore_ascii_case("content-length") + .then(|| value.trim().parse::<usize>().ok()) + .flatten() + }) + .unwrap_or(0); + header_end = Some(end); + expected_total = Some(end + len); + } + if expected_total.is_some_and(|expected| total >= expected) { + break; + } + } + upstream_side + .write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok") + .await + .unwrap(); + upstream_side.flush().await.unwrap(); + String::from_utf8_lossy(&buf[..total]).to_string() + }); + + relay_http_request_with_options_guarded( + &req, + &mut proxy_to_client, + &mut proxy_to_upstream, + RelayRequestOptions { + resolver, + request_body_credential_rewrite, + ..Default::default() + }, + ) + .await?; + + upstream_task + .await + .map_err(|e| miette!("upstream task failed: {e}")) + } + + #[tokio::test] + async fn relay_request_body_rewrites_provider_alias_header_and_urlencoded_token() { + let (_, resolver) = SecretResolver::from_provider_env( + [("API_TOKEN".to_string(), "provider-real-token".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let alias = "provider.v1-OPENSHELL-RESOLVE-ENV-API_TOKEN"; + let body = format!("token={alias}&channel=C123"); + let raw = format!( + "POST /api/messages HTTP/1.1\r\n\ + Host: api.example.com\r\n\ + Authorization: Bearer {alias}\r\n\ + Content-Type: application/x-www-form-urlencoded\r\n\ + Content-Length: {}\r\n\r\n{}", + body.len(), + body + ); + + let forwarded = relay_and_capture_with_options( + raw.into_bytes(), + BodyLength::ContentLength(body.len() as u64), + Some(&resolver), + true, + ) + .await + .expect("relay should succeed"); + + let expected_body = "token=provider-real-token&channel=C123"; + assert!(forwarded.contains("Authorization: Bearer provider-real-token\r\n")); + assert!(forwarded.contains(&format!("Content-Length: {}\r\n", expected_body.len()))); + assert!(forwarded.ends_with(expected_body)); + assert!(!forwarded.contains("OPENSHELL-RESOLVE-ENV")); + } + + #[tokio::test] + async fn 
relay_request_body_rewrites_percent_encoded_canonical_urlencoded_token() { + let (_, resolver) = SecretResolver::from_provider_env( + [("API_TOKEN".to_string(), "provider-real-token".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let body = "token=openshell%3Aresolve%3Aenv%3AAPI_TOKEN&note=hello+world"; + let raw = format!( + "POST /api/messages HTTP/1.1\r\n\ + Host: api.example.com\r\n\ + Content-Type: application/x-www-form-urlencoded\r\n\ + Content-Length: {}\r\n\r\n{}", + body.len(), + body + ); + + let forwarded = relay_and_capture_with_options( + raw.into_bytes(), + BodyLength::ContentLength(body.len() as u64), + Some(&resolver), + true, + ) + .await + .expect("relay should succeed"); + + let expected_body = "token=provider-real-token&note=hello+world"; + assert!(forwarded.contains(&format!("Content-Length: {}\r\n", expected_body.len()))); + assert!(forwarded.ends_with(expected_body)); + assert!(!forwarded.contains("openshell%3Aresolve%3Aenv%3AAPI_TOKEN")); + assert!(!forwarded.contains("openshell:resolve:env:API_TOKEN")); + } + + #[tokio::test] + async fn relay_request_body_unresolved_alias_fails_before_upstream_write() { + let (_, resolver) = SecretResolver::from_provider_env( + [("API_TOKEN".to_string(), "provider-real-token".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let body = "token=provider-OPENSHELL-RESOLVE-ENV-APP_TOKEN"; + let raw = format!( + "POST /api/connections.open HTTP/1.1\r\n\ + Host: api.example.com\r\n\ + Content-Type: application/x-www-form-urlencoded\r\n\ + Content-Length: {}\r\n\r\n{}", + body.len(), + body + ); + let req = L7Request { + action: "POST".to_string(), + target: "/api/connections.open".to_string(), + query_params: HashMap::new(), + raw_header: raw.into_bytes(), + body_length: BodyLength::ContentLength(body.len() as u64), + }; + let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); + let (mut _app_side, mut 
proxy_to_client) = tokio::io::duplex(8192); + + let err = relay_http_request_with_options_guarded( + &req, + &mut proxy_to_client, + &mut proxy_to_upstream, + RelayRequestOptions { + resolver: Some(&resolver), + request_body_credential_rewrite: true, + ..Default::default() + }, + ) + .await + .expect_err("unknown body alias should fail closed"); + + assert!(!err.to_string().contains("provider-real-token")); + drop(proxy_to_upstream); + let mut forwarded = Vec::new(); + upstream_side.read_to_end(&mut forwarded).await.unwrap(); + assert!( + forwarded.is_empty(), + "failed body rewrite must not reach upstream" + ); + } + + #[tokio::test] + async fn relay_request_body_unresolved_encoded_canonical_fails_before_upstream_write() { + let (_, resolver) = SecretResolver::from_provider_env( + [("API_TOKEN".to_string(), "provider-real-token".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let body = "token=openshell%3Aresolve%3Aenv%3AMISSING_TOKEN"; + let raw = format!( + "POST /api/messages HTTP/1.1\r\n\ + Host: api.example.com\r\n\ + Content-Type: application/x-www-form-urlencoded\r\n\ + Content-Length: {}\r\n\r\n{}", + body.len(), + body + ); + let req = L7Request { + action: "POST".to_string(), + target: "/api/messages".to_string(), + query_params: HashMap::new(), + raw_header: raw.into_bytes(), + body_length: BodyLength::ContentLength(body.len() as u64), + }; + let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); + let (mut _app_side, mut proxy_to_client) = tokio::io::duplex(8192); + + let err = relay_http_request_with_options_guarded( + &req, + &mut proxy_to_client, + &mut proxy_to_upstream, + RelayRequestOptions { + resolver: Some(&resolver), + request_body_credential_rewrite: true, + ..Default::default() + }, + ) + .await + .expect_err("unknown encoded body placeholder should fail closed"); + + assert!(!err.to_string().contains("provider-real-token")); + 
assert!(!err.to_string().contains("MISSING_TOKEN")); + drop(proxy_to_upstream); + let mut forwarded = Vec::new(); + upstream_side.read_to_end(&mut forwarded).await.unwrap(); + assert!( + forwarded.is_empty(), + "failed body rewrite must not reach upstream" + ); + } + #[tokio::test] async fn relay_injects_bearer_header_credential() { let (child_env, resolver) = SecretResolver::from_provider_env( diff --git a/crates/openshell-sandbox/src/l7/websocket.rs b/crates/openshell-sandbox/src/l7/websocket.rs new file mode 100644 index 000000000..2dc1b25c3 --- /dev/null +++ b/crates/openshell-sandbox/src/l7/websocket.rs @@ -0,0 +1,1937 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! WebSocket relay for opt-in credential placeholder rewriting and message policy. +//! +//! The relay parses only client-to-server frames. Server-to-client bytes stay +//! raw passthrough so inspection and rewriting cannot expose response payloads. 
+ +use crate::l7::relay::{L7EvalContext, evaluate_l7_request}; +use crate::l7::{EnforcementMode, L7RequestInfo}; +use crate::opa::TunnelPolicyEngine; +use crate::secrets::SecretResolver; +use flate2::{Compress, Compression, Decompress, FlushCompress, FlushDecompress, Status}; +use miette::{IntoDiagnostic, Result, miette}; +use openshell_ocsf::{ + ActionId, ActivityId, DispositionId, Endpoint, NetworkActivityBuilder, SeverityId, StatusId, + ocsf_emit, +}; +use std::collections::HashMap; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; + +const MAX_TEXT_MESSAGE_BYTES: usize = 1024 * 1024; +const MAX_RAW_FRAME_PAYLOAD_BYTES: u64 = 16 * 1024 * 1024; +const COPY_BUF_SIZE: usize = 8192; +const OPCODE_CONTINUATION: u8 = 0x0; +const OPCODE_TEXT: u8 = 0x1; +const OPCODE_BINARY: u8 = 0x2; +const OPCODE_CLOSE: u8 = 0x8; +const OPCODE_PING: u8 = 0x9; +const OPCODE_PONG: u8 = 0xA; + +#[derive(Debug)] +struct FrameHeader { + fin: bool, + rsv: u8, + opcode: u8, + masked: bool, + payload_len: u64, + mask_key: Option<[u8; 4]>, + raw_header: Vec<u8>, +} + +#[derive(Debug)] +enum FragmentState { + None, + Text { payload: Vec<u8>, compressed: bool }, + Binary, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(super) enum WebSocketCompression { + None, + PermessageDeflate, +} + +pub(super) struct InspectionOptions<'a> { + pub(super) engine: &'a TunnelPolicyEngine, + pub(super) ctx: &'a L7EvalContext, + pub(super) enforcement: EnforcementMode, + pub(super) target: String, + pub(super) query_params: HashMap>, + pub(super) graphql_policy: bool, +} + +pub(super) struct RelayOptions<'a> { + pub(super) policy_name: &'a str, + pub(super) resolver: Option<&'a SecretResolver>, + pub(super) inspector: Option<InspectionOptions<'a>>, + pub(super) compression: WebSocketCompression, +} + +/// Relay an upgraded WebSocket connection with optional client text inspection, +/// credential rewriting, and strict permessage-deflate handling. 
+pub(super) async fn relay_with_options<C, U>( + client: &mut C, + upstream: &mut U, + overflow: Vec<u8>, + host: &str, + port: u16, + options: RelayOptions<'_>, +) -> Result<()> +where + C: AsyncRead + AsyncWrite + Unpin + Send, + U: AsyncRead + AsyncWrite + Unpin + Send, +{ + let (mut client_read, mut client_write) = tokio::io::split(client); + let (mut upstream_read, mut upstream_write) = tokio::io::split(upstream); + + if !overflow.is_empty() { + client_write.write_all(&overflow).await.into_diagnostic()?; + client_write.flush().await.into_diagnostic()?; + } + + let client_to_server = + relay_client_to_server(&mut client_read, &mut upstream_write, host, port, &options); + let server_to_client = async { + tokio::io::copy(&mut upstream_read, &mut client_write) + .await + .into_diagnostic()?; + client_write.flush().await.into_diagnostic()?; + Ok::<(), miette::Report>(()) + }; + + let result = tokio::select! { + result = client_to_server => result, + result = server_to_client => result, + }; + let _ = upstream_write.shutdown().await; + let _ = client_write.shutdown().await; + result +} + +async fn relay_client_to_server<R, W>( + reader: &mut R, + writer: &mut W, + host: &str, + port: u16, + options: &RelayOptions<'_>, +) -> Result<()> +where + R: AsyncRead + Unpin, + W: AsyncWrite + Unpin, +{ + let mut fragments = FragmentState::None; + let mut close_seen = false; + + loop { + let Some(frame) = read_frame_header(reader).await.inspect_err(|e| { + emit_protocol_failure(host, port, options.policy_name, protocol_failure_class(e)); + })? 
+ else { + writer.shutdown().await.into_diagnostic()?; + return Ok(()); + }; + + if close_seen { + let e = miette!("websocket frame received after close frame"); + emit_protocol_failure(host, port, options.policy_name, protocol_failure_class(&e)); + return Err(e); + } + + if let Err(e) = validate_frame_header(&frame, &fragments, options.compression) { + emit_protocol_failure(host, port, options.policy_name, protocol_failure_class(&e)); + return Err(e); + } + + match frame.opcode { + OPCODE_TEXT => { + let payload = read_masked_payload(reader, &frame).await.inspect_err(|e| { + emit_protocol_failure( + host, + port, + options.policy_name, + protocol_failure_class(e), + ); + })?; + let compressed = frame.rsv == 0x40; + if frame.fin { + relay_text_payload( + writer, &frame, payload, false, compressed, host, port, options, + ) + .await + .inspect_err(|e| { + emit_protocol_failure( + host, + port, + options.policy_name, + protocol_failure_class(e), + ); + })?; + } else { + fragments = FragmentState::Text { + payload, + compressed, + }; + } + } + OPCODE_CONTINUATION => match &mut fragments { + FragmentState::Text { + payload, + compressed, + } => { + let next = read_masked_payload(reader, &frame).await.inspect_err(|e| { + emit_protocol_failure( + host, + port, + options.policy_name, + protocol_failure_class(e), + ); + })?; + if let Err(e) = append_text_fragment(payload, next) { + emit_protocol_failure( + host, + port, + options.policy_name, + protocol_failure_class(&e), + ); + return Err(e); + } + if frame.fin { + let complete = std::mem::take(payload); + let was_compressed = *compressed; + fragments = FragmentState::None; + relay_text_payload( + writer, + &frame, + complete, + true, + was_compressed, + host, + port, + options, + ) + .await + .inspect_err(|e| { + emit_protocol_failure( + host, + port, + options.policy_name, + protocol_failure_class(e), + ); + })?; + } + } + FragmentState::Binary => { + copy_raw_frame_payload(reader, writer, &frame) + .await + 
.inspect_err(|e| { + emit_protocol_failure( + host, + port, + options.policy_name, + protocol_failure_class(e), + ); + })?; + if frame.fin { + fragments = FragmentState::None; + } + } + FragmentState::None => { + let e = + miette!("websocket continuation frame without active fragmented message"); + emit_protocol_failure( + host, + port, + options.policy_name, + protocol_failure_class(&e), + ); + return Err(e); + } + }, + OPCODE_BINARY => { + if !frame.fin { + fragments = FragmentState::Binary; + } + copy_raw_frame_payload(reader, writer, &frame) + .await + .inspect_err(|e| { + emit_protocol_failure( + host, + port, + options.policy_name, + protocol_failure_class(e), + ); + })?; + } + OPCODE_CLOSE | OPCODE_PING | OPCODE_PONG => { + relay_control_frame(reader, writer, &frame) + .await + .inspect_err(|e| { + emit_protocol_failure( + host, + port, + options.policy_name, + protocol_failure_class(e), + ); + })?; + if frame.opcode == OPCODE_CLOSE { + close_seen = true; + } + } + _ => unreachable!("validated opcode"), + } + } +} + +async fn read_frame_header<R: AsyncRead + Unpin>(reader: &mut R) -> Result<Option<FrameHeader>> { + let first = match reader.read_u8().await { + Ok(byte) => byte, + Err(e) + if matches!( + e.kind(), + std::io::ErrorKind::UnexpectedEof + | std::io::ErrorKind::ConnectionReset + | std::io::ErrorKind::BrokenPipe + ) => + { + return Ok(None); + } + Err(e) => return Err(miette!("{e}")), + }; + let second = reader + .read_u8() + .await + .map_err(|e| miette!("malformed websocket frame header: {e}"))?; + + let mut raw_header = vec![first, second]; + let len_code = second & 0x7F; + let payload_len = match len_code { + 0..=125 => u64::from(len_code), + 126 => { + let mut bytes = [0u8; 2]; + reader + .read_exact(&mut bytes) + .await + .map_err(|e| miette!("malformed websocket extended length: {e}"))?; + raw_header.extend_from_slice(&bytes); + let len = u64::from(u16::from_be_bytes(bytes)); + if len < 126 { + return Err(miette!( + "websocket frame uses non-minimal 16-bit extended length" + )); + } 
+ len + } + 127 => { + let mut bytes = [0u8; 8]; + reader + .read_exact(&mut bytes) + .await + .map_err(|e| miette!("malformed websocket extended length: {e}"))?; + if bytes[0] & 0x80 != 0 { + return Err(miette!("websocket frame uses non-canonical 64-bit length")); + } + raw_header.extend_from_slice(&bytes); + let len = u64::from_be_bytes(bytes); + if u16::try_from(len).is_ok() { + return Err(miette!( + "websocket frame uses non-minimal 64-bit extended length" + )); + } + len + } + _ => unreachable!("7-bit length code"), + }; + + let masked = second & 0x80 != 0; + let mask_key = if masked { + let mut key = [0u8; 4]; + reader + .read_exact(&mut key) + .await + .map_err(|e| miette!("malformed websocket mask key: {e}"))?; + raw_header.extend_from_slice(&key); + Some(key) + } else { + None + }; + + Ok(Some(FrameHeader { + fin: first & 0x80 != 0, + rsv: first & 0x70, + opcode: first & 0x0F, + masked, + payload_len, + mask_key, + raw_header, + })) +} + +fn validate_frame_header( + frame: &FrameHeader, + fragments: &FragmentState, + compression: WebSocketCompression, +) -> Result<()> { + if !valid_rsv_bits(frame, fragments, compression) { + return Err(miette!( + "websocket frame has unsupported RSV bits or extension state" + )); + } + if !frame.masked { + return Err(miette!("websocket client frame is not masked")); + } + if !matches!( + frame.opcode, + OPCODE_CONTINUATION + | OPCODE_TEXT + | OPCODE_BINARY + | OPCODE_CLOSE + | OPCODE_PING + | OPCODE_PONG + ) { + return Err(miette!("websocket frame uses reserved opcode")); + } + if matches!(frame.opcode, OPCODE_CLOSE | OPCODE_PING | OPCODE_PONG) { + if !frame.fin { + return Err(miette!("websocket control frame is fragmented")); + } + if frame.payload_len > 125 { + return Err(miette!("websocket control frame exceeds 125 bytes")); + } + } + if matches!(frame.opcode, OPCODE_TEXT | OPCODE_BINARY) + && !matches!(fragments, FragmentState::None) + { + return Err(miette!( + "websocket data frame started before previous fragmented 
message completed" + )); + } + if matches!(frame.opcode, OPCODE_CONTINUATION) && matches!(fragments, FragmentState::None) { + return Err(miette!( + "websocket continuation frame without active fragmented message" + )); + } + if (frame.opcode == OPCODE_BINARY + || (frame.opcode == OPCODE_CONTINUATION && matches!(fragments, FragmentState::Binary))) + && frame.payload_len > MAX_RAW_FRAME_PAYLOAD_BYTES + { + return Err(miette!( + "websocket binary frame exceeds {MAX_RAW_FRAME_PAYLOAD_BYTES} byte relay limit" + )); + } + Ok(()) +} + +fn valid_rsv_bits( + frame: &FrameHeader, + fragments: &FragmentState, + compression: WebSocketCompression, +) -> bool { + if frame.rsv == 0 { + return true; + } + if compression != WebSocketCompression::PermessageDeflate || frame.rsv != 0x40 { + return false; + } + matches!(fragments, FragmentState::None) && matches!(frame.opcode, OPCODE_TEXT | OPCODE_BINARY) +} + +async fn read_masked_payload<R: AsyncRead + Unpin>( + reader: &mut R, + frame: &FrameHeader, +) -> Result<Vec<u8>> { + let payload_len = usize::try_from(frame.payload_len) + .map_err(|_| miette!("websocket text frame is too large to buffer"))?; + if payload_len > MAX_TEXT_MESSAGE_BYTES { + return Err(miette!( + "websocket text message exceeds {MAX_TEXT_MESSAGE_BYTES} byte limit" + )); + } + let mut payload = vec![0u8; payload_len]; + reader + .read_exact(&mut payload) + .await + .map_err(|e| miette!("malformed websocket payload: {e}"))?; + let mask_key = frame + .mask_key + .ok_or_else(|| miette!("websocket client frame is not masked"))?; + apply_mask(&mut payload, mask_key); + Ok(payload) +} + +fn append_text_fragment(buffer: &mut Vec<u8>, next: Vec<u8>) -> Result<()> { + let new_len = buffer + .len() + .checked_add(next.len()) + .ok_or_else(|| miette!("websocket text message length overflow"))?; + if new_len > MAX_TEXT_MESSAGE_BYTES { + return Err(miette!( + "websocket text message exceeds {MAX_TEXT_MESSAGE_BYTES} byte limit" + )); + } + buffer.extend_from_slice(&next); + Ok(()) +} + 
+#[allow(clippy::too_many_arguments)] +async fn relay_text_payload<W: AsyncWrite + Unpin>( + writer: &mut W, + frame: &FrameHeader, + payload: Vec<u8>, + force_reframe: bool, + compressed: bool, + host: &str, + port: u16, + options: &RelayOptions<'_>, +) -> Result<()> { + let message_payload = if compressed { + decompress_permessage_deflate(&payload)? + } else { + payload + }; + let mut text = String::from_utf8(message_payload) + .map_err(|_| miette!("websocket text message is not valid UTF-8"))?; + let replacements = if let Some(resolver) = options.resolver { + resolver + .rewrite_websocket_text_placeholders(&mut text) + .map_err(|_| miette!("websocket credential placeholder resolution failed"))? + } else { + 0 + }; + + if let Some(inspector) = options.inspector.as_ref() { + inspect_websocket_text_message(host, port, options.policy_name, inspector, &text)?; + } + + if replacements == 0 && !force_reframe && !compressed { + writer + .write_all(&frame.raw_header) + .await + .into_diagnostic()?; + let mut payload = text.into_bytes(); + let mask_key = frame + .mask_key + .ok_or_else(|| miette!("websocket client frame is not masked"))?; + apply_mask(&mut payload, mask_key); + writer.write_all(&payload).await.into_diagnostic()?; + writer.flush().await.into_diagnostic()?; + return Ok(()); + } + + if replacements > 0 { + emit_rewrite_event(host, port, options.policy_name, replacements); + } + if compressed { + let compressed_payload = compress_permessage_deflate(text.as_bytes())?; + return write_masked_frame_with_rsv(writer, OPCODE_TEXT, 0x40, &compressed_payload).await; + } + write_masked_frame(writer, OPCODE_TEXT, text.as_bytes()).await +} + +fn inspect_websocket_text_message( + host: &str, + port: u16, + policy_name: &str, + inspector: &InspectionOptions<'_>, + text: &str, +) -> Result<()> { + if inspector.graphql_policy { + return inspect_graphql_websocket_message(host, port, policy_name, inspector, text); + } + + let request_info = L7RequestInfo { + action: "WEBSOCKET_TEXT".to_string(), + 
target: inspector.target.clone(), + query_params: inspector.query_params.clone(), + graphql: None, + }; + let (allowed, reason) = evaluate_l7_request(inspector.engine, inspector.ctx, &request_info)?; + let decision = match (allowed, inspector.enforcement) { + (true, _) => "allow", + (false, EnforcementMode::Audit) => "audit", + (false, EnforcementMode::Enforce) => "deny", + }; + emit_websocket_l7_event( + host, + port, + policy_name, + &request_info, + decision, + &reason, + None, + ); + if !allowed && inspector.enforcement == EnforcementMode::Enforce { + return Err(miette!("websocket text message denied by policy")); + } + Ok(()) +} + +fn inspect_graphql_websocket_message( + host: &str, + port: u16, + policy_name: &str, + inspector: &InspectionOptions<'_>, + text: &str, +) -> Result<()> { + match classify_graphql_websocket_message(text) { + GraphqlWebSocketMessage::Control { message_type } => { + let request_info = L7RequestInfo { + action: "WEBSOCKET_CONTROL".to_string(), + target: inspector.target.clone(), + query_params: inspector.query_params.clone(), + graphql: None, + }; + emit_websocket_l7_event( + host, + port, + policy_name, + &request_info, + "allow", + &format!("GraphQL WebSocket control message {message_type}"), + None, + ); + Ok(()) + } + GraphqlWebSocketMessage::Operation { + message_type, + graphql, + } => { + let request_info = L7RequestInfo { + action: "WEBSOCKET_TEXT".to_string(), + target: inspector.target.clone(), + query_params: inspector.query_params.clone(), + graphql: Some(graphql.clone()), + }; + let parse_error_reason = graphql + .error + .as_deref() + .map(|error| format!("GraphQL WebSocket message rejected: {error}")); + let force_deny = parse_error_reason.is_some(); + let (allowed, reason) = if let Some(reason) = parse_error_reason { + (false, reason) + } else { + evaluate_l7_request(inspector.engine, inspector.ctx, &request_info)? 
+ }; + let decision = match (allowed, inspector.enforcement) { + (_, _) if force_deny => "deny", + (true, _) => "allow", + (false, EnforcementMode::Audit) => "audit", + (false, EnforcementMode::Enforce) => "deny", + }; + let reason = format!("graphql_ws_type={message_type} {reason}"); + emit_websocket_l7_event( + host, + port, + policy_name, + &request_info, + decision, + &reason, + Some(&graphql), + ); + if (!allowed && inspector.enforcement == EnforcementMode::Enforce) || force_deny { + return Err(miette!("websocket GraphQL message denied by policy")); + } + Ok(()) + } + } +} + +#[derive(Debug)] +enum GraphqlWebSocketMessage { + Control { + message_type: String, + }, + Operation { + message_type: String, + graphql: crate::l7::graphql::GraphqlRequestInfo, + }, +} + +fn classify_graphql_websocket_message(text: &str) -> GraphqlWebSocketMessage { + let value = match serde_json::from_str::<serde_json::Value>(text) { + Ok(value) => value, + Err(err) => { + return GraphqlWebSocketMessage::Operation { + message_type: "unknown".to_string(), + graphql: graphql_error(format!( + "GraphQL WebSocket message is not valid JSON: {err}" + )), + }; + } + }; + let Some(obj) = value.as_object() else { + return GraphqlWebSocketMessage::Operation { + message_type: "unknown".to_string(), + graphql: graphql_error("GraphQL WebSocket message must be a JSON object"), + }; + }; + let Some(message_type) = obj.get("type").and_then(serde_json::Value::as_str) else { + return GraphqlWebSocketMessage::Operation { + message_type: "unknown".to_string(), + graphql: graphql_error("GraphQL WebSocket message missing string type"), + }; + }; + + match message_type { + "subscribe" | "start" => { + if obj + .get("id") + .and_then(serde_json::Value::as_str) + .is_none_or(str::is_empty) + { + return GraphqlWebSocketMessage::Operation { + message_type: message_type.to_string(), + graphql: graphql_error( + "GraphQL WebSocket operation message missing non-empty id", + ), + }; + } + let Some(payload) = 
obj.get("payload").filter(|value| value.is_object()) else { + return GraphqlWebSocketMessage::Operation { + message_type: message_type.to_string(), + graphql: graphql_error( + "GraphQL WebSocket operation message missing object payload", + ), + }; + }; + GraphqlWebSocketMessage::Operation { + message_type: message_type.to_string(), + graphql: crate::l7::graphql::classify_json_envelope_value(payload), + } + } + "connection_init" | "connection_terminate" | "ping" | "pong" | "complete" | "stop" => { + GraphqlWebSocketMessage::Control { + message_type: message_type.to_string(), + } + } + _ => GraphqlWebSocketMessage::Operation { + message_type: message_type.to_string(), + graphql: graphql_error(format!( + "unsupported GraphQL WebSocket client message type {message_type:?}" + )), + }, + } +} + +fn graphql_error(message: impl Into) -> crate::l7::graphql::GraphqlRequestInfo { + crate::l7::graphql::GraphqlRequestInfo { + operations: Vec::new(), + error: Some(message.into()), + } +} + +async fn relay_control_frame( + reader: &mut R, + writer: &mut W, + frame: &FrameHeader, +) -> Result<()> +where + R: AsyncRead + Unpin, + W: AsyncWrite + Unpin, +{ + let raw_payload_len = usize::try_from(frame.payload_len) + .map_err(|_| miette!("websocket control frame payload length overflow"))?; + let mut raw_payload = vec![0u8; raw_payload_len]; + reader + .read_exact(&mut raw_payload) + .await + .map_err(|e| miette!("malformed websocket control payload: {e}"))?; + + if frame.opcode == OPCODE_CLOSE { + let mut payload = raw_payload.clone(); + let mask_key = frame + .mask_key + .ok_or_else(|| miette!("websocket client frame is not masked"))?; + apply_mask(&mut payload, mask_key); + validate_close_payload(&payload)?; + } + + writer + .write_all(&frame.raw_header) + .await + .into_diagnostic()?; + writer.write_all(&raw_payload).await.into_diagnostic()?; + writer.flush().await.into_diagnostic()?; + Ok(()) +} + +fn validate_close_payload(payload: &[u8]) -> Result<()> { + if payload.len() == 1 
{ + return Err(miette!( + "websocket close frame payload cannot be exactly one byte" + )); + } + if payload.len() < 2 { + return Ok(()); + } + + let code = u16::from_be_bytes([payload[0], payload[1]]); + if !valid_close_code(code) { + return Err(miette!("websocket close frame uses invalid close code")); + } + if std::str::from_utf8(&payload[2..]).is_err() { + return Err(miette!("websocket close frame reason is not valid UTF-8")); + } + Ok(()) +} + +fn valid_close_code(code: u16) -> bool { + (matches!(code, 1000..=1014) && !matches!(code, 1004..=1006)) || (3000..=4999).contains(&code) +} + +async fn copy_raw_frame_payload( + reader: &mut R, + writer: &mut W, + frame: &FrameHeader, +) -> Result<()> +where + R: AsyncRead + Unpin, + W: AsyncWrite + Unpin, +{ + writer + .write_all(&frame.raw_header) + .await + .into_diagnostic()?; + let mut remaining = frame.payload_len; + let mut buf = [0u8; COPY_BUF_SIZE]; + while remaining > 0 { + let to_read = usize::try_from(remaining) + .unwrap_or(buf.len()) + .min(buf.len()); + let n = reader.read(&mut buf[..to_read]).await.into_diagnostic()?; + if n == 0 { + return Err(miette!("websocket payload ended before declared length")); + } + writer.write_all(&buf[..n]).await.into_diagnostic()?; + remaining -= n as u64; + } + writer.flush().await.into_diagnostic()?; + Ok(()) +} + +async fn write_masked_frame( + writer: &mut W, + opcode: u8, + payload: &[u8], +) -> Result<()> { + write_masked_frame_with_rsv(writer, opcode, 0, payload).await +} + +async fn write_masked_frame_with_rsv( + writer: &mut W, + opcode: u8, + rsv: u8, + payload: &[u8], +) -> Result<()> { + let mut header = Vec::with_capacity(14); + header.push(0x80 | rsv | opcode); + match payload.len() { + 0..=125 => header.push(0x80 | u8::try_from(payload.len()).expect("payload <= 125")), + 126..=65_535 => { + header.push(0x80 | 0x7e); + header.extend_from_slice( + &u16::try_from(payload.len()) + .expect("payload <= 65535") + .to_be_bytes(), + ); + } + _ => { + header.push(0x80 
| 127); + header.extend_from_slice(&(payload.len() as u64).to_be_bytes()); + } + } + let mask_key = new_mask_key(); + header.extend_from_slice(&mask_key); + + let mut masked = payload.to_vec(); + apply_mask(&mut masked, mask_key); + writer.write_all(&header).await.into_diagnostic()?; + writer.write_all(&masked).await.into_diagnostic()?; + writer.flush().await.into_diagnostic()?; + Ok(()) +} + +fn decompress_permessage_deflate(payload: &[u8]) -> Result> { + let mut decoder = Decompress::new(false); + let mut input = Vec::with_capacity(payload.len() + 4); + input.extend_from_slice(payload); + input.extend_from_slice(&[0x00, 0x00, 0xff, 0xff]); + let mut out = Vec::with_capacity(payload.len().saturating_mul(2).min(MAX_TEXT_MESSAGE_BYTES)); + let mut input_pos = 0usize; + let mut scratch = [0u8; COPY_BUF_SIZE]; + loop { + let before_in = decoder.total_in(); + let before_out = decoder.total_out(); + let status = decoder + .decompress(&input[input_pos..], &mut scratch, FlushDecompress::Sync) + .map_err(|e| miette!("websocket permessage-deflate decompression failed: {e}"))?; + let read = usize::try_from(decoder.total_in() - before_in) + .map_err(|_| miette!("websocket permessage-deflate input length overflow"))?; + let written = usize::try_from(decoder.total_out() - before_out) + .map_err(|_| miette!("websocket permessage-deflate output length overflow"))?; + input_pos = input_pos + .checked_add(read) + .ok_or_else(|| miette!("websocket permessage-deflate input length overflow"))?; + if out.len().saturating_add(written) > MAX_TEXT_MESSAGE_BYTES { + return Err(miette!( + "websocket text message exceeds {MAX_TEXT_MESSAGE_BYTES} byte limit" + )); + } + out.extend_from_slice(&scratch[..written]); + if matches!(status, Status::StreamEnd) { + break; + } + if input_pos >= input.len() && written < scratch.len() { + break; + } + if read == 0 && written == 0 { + return Err(miette!( + "websocket permessage-deflate decompression did not make progress" + )); + } + } + Ok(out) +} + +fn 
compress_permessage_deflate(payload: &[u8]) -> Result> { + let mut compressor = Compress::new(Compression::fast(), false); + let expansion = payload.len() / 16; + let mut out = Vec::with_capacity(payload.len().saturating_add(expansion).saturating_add(128)); + loop { + let consumed = usize::try_from(compressor.total_in()) + .map_err(|_| miette!("websocket permessage-deflate input length overflow"))?; + if consumed >= payload.len() { + break; + } + let before_in = compressor.total_in(); + let before_out = compressor.total_out(); + let status = compressor + .compress_vec(&payload[consumed..], &mut out, FlushCompress::None) + .map_err(|e| miette!("websocket permessage-deflate compression failed: {e}"))?; + if matches!(status, Status::BufError) + || (compressor.total_in() == before_in && compressor.total_out() == before_out) + { + out.reserve(out.capacity().max(1024)); + } + } + loop { + out.reserve(64); + let before_out = compressor.total_out(); + compressor + .compress_vec(&[], &mut out, FlushCompress::Sync) + .map_err(|e| miette!("websocket permessage-deflate compression failed: {e}"))?; + if out.ends_with(&[0x00, 0x00, 0xff, 0xff]) { + break; + } + if compressor.total_out() == before_out { + out.reserve(out.capacity().max(1024)); + } + } + if !out.ends_with(&[0x00, 0x00, 0xff, 0xff]) { + return Err(miette!( + "websocket permessage-deflate compression missing sync marker" + )); + } + out.truncate(out.len() - 4); + Ok(out) +} + +fn new_mask_key() -> [u8; 4] { + let bytes = uuid::Uuid::new_v4().into_bytes(); + [bytes[0], bytes[1], bytes[2], bytes[3]] +} + +fn apply_mask(payload: &mut [u8], mask_key: [u8; 4]) { + for (i, byte) in payload.iter_mut().enumerate() { + *byte ^= mask_key[i % 4]; + } +} + +fn emit_rewrite_event(host: &str, port: u16, policy_name: &str, replacements: usize) { + let policy_name = if policy_name.is_empty() { + "-" + } else { + policy_name + }; + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Other) + 
.action(ActionId::Allowed) + .disposition(DispositionId::Allowed) + .severity(SeverityId::Informational) + .status(StatusId::Success) + .dst_endpoint(Endpoint::from_domain(host, port)) + .firewall_rule(policy_name, "l7-websocket") + .message(rewrite_event_message(host, port, replacements)) + .build(); + ocsf_emit!(event); +} + +fn rewrite_event_message(host: &str, port: u16, replacements: usize) -> String { + format!( + "WEBSOCKET_CREDENTIAL_REWRITE rewrote client text message [host:{host} port:{port} replacements:{replacements}]" + ) +} + +fn emit_websocket_l7_event( + host: &str, + port: u16, + policy_name: &str, + request_info: &L7RequestInfo, + decision: &str, + reason: &str, + graphql: Option<&crate::l7::graphql::GraphqlRequestInfo>, +) { + let policy_name = if policy_name.is_empty() { + "-" + } else { + policy_name + }; + let (action_id, disposition_id, severity) = match decision { + "deny" => (ActionId::Denied, DispositionId::Blocked, SeverityId::Medium), + "allow" | "audit" => ( + ActionId::Allowed, + DispositionId::Allowed, + SeverityId::Informational, + ), + _ => ( + ActionId::Other, + DispositionId::Other, + SeverityId::Informational, + ), + }; + let summary = graphql.map(graphql_log_summary).unwrap_or_default(); + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Other) + .action(action_id) + .disposition(disposition_id) + .severity(severity) + .status(StatusId::Success) + .dst_endpoint(Endpoint::from_domain(host, port)) + .firewall_rule(policy_name, "l7-websocket") + .message(format!( + "WEBSOCKET_L7_REQUEST {decision} {} {host}:{port}{}{} reason={reason}", + request_info.action, request_info.target, summary + )) + .build(); + ocsf_emit!(event); +} + +fn graphql_log_summary(info: &crate::l7::graphql::GraphqlRequestInfo) -> String { + if let Some(error) = info.error.as_deref() { + return format!(" graphql_error={error:?}"); + } + let ops: Vec = info + .operations + .iter() + .map(|op| { + let name = 
op.operation_name.as_deref().unwrap_or("-"); + let fields = if op.fields.is_empty() { + "-".to_string() + } else { + op.fields.join(",") + }; + let persisted = op + .persisted_query_hash + .as_deref() + .or(op.persisted_query_id.as_deref()) + .unwrap_or("-"); + format!( + "type={} name={} fields={} persisted={}", + op.operation_type, name, fields, persisted + ) + }) + .collect(); + format!(" graphql_ops={}", ops.join(";")) +} + +fn protocol_failure_class(error: &miette::Report) -> &'static str { + let msg = error.to_string().to_ascii_lowercase(); + if msg.contains("credential") { + "credential_resolution_failed" + } else if msg.contains("utf-8") { + "invalid_utf8" + } else if msg.contains("close frame") || msg.contains("after close") { + "invalid_close_frame" + } else if msg.contains("control frame") { + "invalid_control_frame" + } else if msg.contains("length") + || msg.contains("too large") + || msg.contains("exceeds") + || msg.contains("overflow") + { + "invalid_length" + } else if msg.contains("continuation") || msg.contains("fragmented") { + "invalid_fragmentation" + } else if msg.contains("reserved opcode") { + "reserved_opcode" + } else if msg.contains("not masked") { + "unmasked_client_frame" + } else if msg.contains("rsv") { + "rsv_bits" + } else if msg.contains("malformed") { + "malformed_frame" + } else { + "protocol_error" + } +} + +fn emit_protocol_failure(host: &str, port: u16, policy_name: &str, failure_class: &str) { + let policy_name = if policy_name.is_empty() { + "-" + } else { + policy_name + }; + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Open) + .action(ActionId::Denied) + .disposition(DispositionId::Blocked) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .dst_endpoint(Endpoint::from_domain(host, port)) + .firewall_rule(policy_name, "l7-websocket") + .message(protocol_failure_message(host, port)) + .status_detail(failure_class) + .build(); + ocsf_emit!(event); +} + +fn 
protocol_failure_message(host: &str, port: u16) -> String { + format!("WEBSOCKET_CREDENTIAL_REWRITE closed ambiguous client frame [host:{host} port:{port}]") +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::l7::relay::L7EvalContext; + use crate::opa::{NetworkInput, OpaEngine}; + use crate::secrets::SecretResolver; + use std::path::PathBuf; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + const TEST_POLICY: &str = include_str!("../../data/sandbox-policy.rego"); + const GRAPHQL_WS_POLICY: &str = r#" +network_policies: + graphql_ws: + name: graphql_ws + endpoints: + - host: realtime.graphql.test + port: 443 + path: "/graphql" + protocol: websocket + enforcement: enforce + rules: + - allow: + method: GET + path: "/graphql" + - allow: + operation_type: query + fields: [viewer] + - allow: + operation_type: subscription + fields: [messageAdded] + binaries: + - { path: /usr/bin/node } +"#; + + fn resolver() -> (HashMap, SecretResolver) { + let (child_env, resolver) = SecretResolver::from_provider_env( + std::iter::once(("DISCORD_BOT_TOKEN".to_string(), "real-token".to_string())).collect(), + ); + (child_env, resolver.expect("resolver")) + } + + fn masked_frame(fin: bool, opcode: u8, payload: &[u8]) -> Vec { + masked_frame_with_rsv(fin, opcode, 0, payload) + } + + fn masked_frame_with_rsv(fin: bool, opcode: u8, rsv: u8, payload: &[u8]) -> Vec { + let mask_key = [0x37, 0xfa, 0x21, 0x3d]; + let mut frame = Vec::new(); + frame.push((if fin { 0x80 } else { 0 }) | rsv | opcode); + match payload.len() { + 0..=125 => frame.push(0x80 | u8::try_from(payload.len()).expect("payload <= 125")), + 126..=65_535 => { + frame.push(0x80 | 0x7e); + frame.extend_from_slice( + &u16::try_from(payload.len()) + .expect("payload <= 65535") + .to_be_bytes(), + ); + } + _ => { + frame.push(0x80 | 127); + frame.extend_from_slice(&(payload.len() as u64).to_be_bytes()); + } + } + frame.extend_from_slice(&mask_key); + for (i, byte) in payload.iter().enumerate() { + frame.push(byte ^ 
mask_key[i % 4]); + } + frame + } + + fn unmasked_frame(opcode: u8, payload: &[u8]) -> Vec { + let mut frame = Vec::new(); + frame.push(0x80 | opcode); + frame.push(u8::try_from(payload.len()).expect("test payload fits in one byte")); + frame.extend_from_slice(payload); + frame + } + + fn masked_frame_with_declared_len(opcode: u8, declared_len: u64) -> Vec { + let mut frame = Vec::new(); + frame.push(0x80 | opcode); + frame.push(0x80 | 127); + frame.extend_from_slice(&declared_len.to_be_bytes()); + frame.extend_from_slice(&[0x37, 0xfa, 0x21, 0x3d]); + frame + } + + fn masked_frame_with_non_minimal_16_bit_len(opcode: u8, payload: &[u8]) -> Vec { + let mask_key = [0x37, 0xfa, 0x21, 0x3d]; + let mut frame = Vec::new(); + frame.push(0x80 | opcode); + frame.push(0x80 | 0x7e); + frame.extend_from_slice( + &u16::try_from(payload.len()) + .expect("test payload fits u16") + .to_be_bytes(), + ); + frame.extend_from_slice(&mask_key); + for (i, byte) in payload.iter().enumerate() { + frame.push(byte ^ mask_key[i % 4]); + } + frame + } + + fn close_payload(code: u16, reason: &[u8]) -> Vec { + let mut payload = Vec::with_capacity(2 + reason.len()); + payload.extend_from_slice(&code.to_be_bytes()); + payload.extend_from_slice(reason); + payload + } + + async fn run_client_to_server(input: Vec) -> Result> { + let (_, resolver) = resolver(); + let (mut client_write, mut relay_read) = tokio::io::duplex(MAX_TEXT_MESSAGE_BYTES + 1024); + let (mut relay_write, mut upstream_read) = tokio::io::duplex(MAX_TEXT_MESSAGE_BYTES + 1024); + + client_write.write_all(&input).await.unwrap(); + drop(client_write); + + let options = RelayOptions { + policy_name: "test-policy", + resolver: Some(&resolver), + inspector: None, + compression: WebSocketCompression::None, + }; + let result = relay_client_to_server( + &mut relay_read, + &mut relay_write, + "gateway.example.test", + 443, + &options, + ) + .await; + drop(relay_write); + + let mut output = Vec::new(); + upstream_read.read_to_end(&mut 
output).await.unwrap(); + result.map(|()| output) + } + + async fn run_client_to_server_with_graphql_policy( + input: Vec, + resolver: Option<&SecretResolver>, + ) -> Result> { + let engine = OpaEngine::from_strings(TEST_POLICY, GRAPHQL_WS_POLICY) + .expect("GraphQL WebSocket policy should load"); + let network_input = NetworkInput { + host: "realtime.graphql.test".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/node"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let generation = engine + .evaluate_network_action_with_generation(&network_input) + .expect("network action should evaluate") + .1; + let tunnel_engine = engine + .clone_engine_for_tunnel(generation) + .expect("tunnel engine"); + let ctx = L7EvalContext { + host: "realtime.graphql.test".into(), + port: 443, + policy_name: "graphql_ws".into(), + binary_path: "/usr/bin/node".into(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: None, + }; + let (mut client_write, mut relay_read) = tokio::io::duplex(MAX_TEXT_MESSAGE_BYTES + 1024); + let (mut relay_write, mut upstream_read) = tokio::io::duplex(MAX_TEXT_MESSAGE_BYTES + 1024); + + client_write.write_all(&input).await.unwrap(); + drop(client_write); + + let options = RelayOptions { + policy_name: "graphql_ws", + resolver, + inspector: Some(InspectionOptions { + engine: &tunnel_engine, + ctx: &ctx, + enforcement: EnforcementMode::Enforce, + target: "/graphql".to_string(), + query_params: HashMap::new(), + graphql_policy: true, + }), + compression: WebSocketCompression::None, + }; + let result = relay_client_to_server( + &mut relay_read, + &mut relay_write, + "realtime.graphql.test", + 443, + &options, + ) + .await; + drop(relay_write); + + let mut output = Vec::new(); + upstream_read.read_to_end(&mut output).await.unwrap(); + result.map(|()| output) + } + + async fn run_client_to_server_compressed(input: Vec) -> Result> { + let (_, resolver) = resolver(); + let (mut client_write, mut 
relay_read) = tokio::io::duplex(MAX_TEXT_MESSAGE_BYTES + 1024); + let (mut relay_write, mut upstream_read) = tokio::io::duplex(MAX_TEXT_MESSAGE_BYTES + 1024); + + client_write.write_all(&input).await.unwrap(); + drop(client_write); + + let options = RelayOptions { + policy_name: "test-policy", + resolver: Some(&resolver), + inspector: None, + compression: WebSocketCompression::PermessageDeflate, + }; + let result = relay_client_to_server( + &mut relay_read, + &mut relay_write, + "gateway.example.test", + 443, + &options, + ) + .await; + drop(relay_write); + + let mut output = Vec::new(); + upstream_read.read_to_end(&mut output).await.unwrap(); + result.map(|()| output) + } + + fn decode_masked_text_frame(frame: &[u8]) -> String { + assert_eq!(frame[0] & 0x0F, OPCODE_TEXT); + assert_ne!(frame[1] & 0x80, 0); + String::from_utf8(decode_masked_payload(frame)).unwrap() + } + + fn decode_masked_payload(frame: &[u8]) -> Vec { + assert_ne!(frame[1] & 0x80, 0); + let len_code = frame[1] & 0x7F; + let (payload_len, mask_offset) = match len_code { + 0..=125 => (usize::from(len_code), 2), + 126 => (usize::from(u16::from_be_bytes([frame[2], frame[3]])), 4), + 127 => { + let len = u64::from_be_bytes(frame[2..10].try_into().unwrap()); + (usize::try_from(len).unwrap(), 10) + } + _ => unreachable!(), + }; + let mask_key: [u8; 4] = frame[mask_offset..mask_offset + 4].try_into().unwrap(); + let mut payload = frame[mask_offset + 4..mask_offset + 4 + payload_len].to_vec(); + apply_mask(&mut payload, mask_key); + payload + } + + fn decode_compressed_masked_text_frame(frame: &[u8]) -> String { + assert_eq!(frame[0] & 0x0F, OPCODE_TEXT); + assert_eq!(frame[0] & 0x40, 0x40); + let payload = decode_masked_payload(frame); + String::from_utf8(decompress_permessage_deflate(&payload).unwrap()).unwrap() + } + + async fn read_one_frame(reader: &mut R) -> Vec { + let mut header = [0u8; 2]; + reader.read_exact(&mut header).await.unwrap(); + let len_code = header[1] & 0x7F; + let extended_len = 
match len_code { + 0..=125 => Vec::new(), + 126 => { + let mut bytes = vec![0u8; 2]; + reader.read_exact(&mut bytes).await.unwrap(); + bytes + } + 127 => { + let mut bytes = vec![0u8; 8]; + reader.read_exact(&mut bytes).await.unwrap(); + bytes + } + _ => unreachable!(), + }; + let payload_len = match len_code { + 0..=125 => usize::from(len_code), + 126 => usize::from(u16::from_be_bytes( + extended_len.as_slice().try_into().unwrap(), + )), + 127 => usize::try_from(u64::from_be_bytes( + extended_len.as_slice().try_into().unwrap(), + )) + .unwrap(), + _ => unreachable!(), + }; + let mask_len = if header[1] & 0x80 != 0 { 4 } else { 0 }; + let mut rest = vec![0u8; extended_len.len() + mask_len + payload_len]; + rest[..extended_len.len()].copy_from_slice(&extended_len); + reader + .read_exact(&mut rest[extended_len.len()..]) + .await + .unwrap(); + + let mut frame = header.to_vec(); + frame.extend_from_slice(&rest); + frame + } + + #[test] + fn classifies_graphql_transport_ws_subscribe_operation() { + let message = r#"{"type":"subscribe","id":"1","payload":{"query":"subscription NewMessages { messageAdded }"}}"#; + + match classify_graphql_websocket_message(message) { + GraphqlWebSocketMessage::Operation { + message_type, + graphql, + } => { + assert_eq!(message_type, "subscribe"); + assert!( + graphql.error.is_none(), + "unexpected error: {:?}", + graphql.error + ); + assert_eq!(graphql.operations.len(), 1); + assert_eq!(graphql.operations[0].operation_type, "subscription"); + assert_eq!( + graphql.operations[0].operation_name.as_deref(), + Some("NewMessages") + ); + assert_eq!(graphql.operations[0].fields, vec!["messageAdded"]); + } + other @ GraphqlWebSocketMessage::Control { .. 
} => { + panic!("expected operation, got {other:?}") + } + } + } + + #[test] + fn classifies_legacy_graphql_ws_start_operation() { + let message = r#"{"type":"start","id":"1","payload":{"query":"query Viewer { viewer }"}}"#; + + match classify_graphql_websocket_message(message) { + GraphqlWebSocketMessage::Operation { + message_type, + graphql, + } => { + assert_eq!(message_type, "start"); + assert!( + graphql.error.is_none(), + "unexpected error: {:?}", + graphql.error + ); + assert_eq!(graphql.operations[0].operation_type, "query"); + assert_eq!(graphql.operations[0].fields, vec!["viewer"]); + } + other @ GraphqlWebSocketMessage::Control { .. } => { + panic!("expected operation, got {other:?}") + } + } + } + + #[test] + fn classifies_graphql_websocket_control_message_without_payload_logging() { + match classify_graphql_websocket_message( + r#"{"type":"connection_init","payload":{"authorization":"secret"}}"#, + ) { + GraphqlWebSocketMessage::Control { message_type } => { + assert_eq!(message_type, "connection_init"); + } + other @ GraphqlWebSocketMessage::Operation { .. } => { + panic!("expected control message, got {other:?}") + } + } + } + + #[test] + fn unsupported_graphql_websocket_message_type_fails_closed() { + match classify_graphql_websocket_message(r#"{"type":"next","id":"1"}"#) { + GraphqlWebSocketMessage::Operation { graphql, .. } => { + assert!( + graphql + .error + .as_deref() + .is_some_and(|error| error.contains("unsupported")) + ); + } + other @ GraphqlWebSocketMessage::Control { .. 
} => { + panic!("expected operation error, got {other:?}") + } + } + } + + #[test] + fn graphql_websocket_log_summary_excludes_payload_variables_and_secrets() { + let placeholder = "openshell:resolve:env:T"; + let message = format!( + r#"{{"type":"subscribe","id":"1","payload":{{"query":"query Viewer {{ viewer }}","variables":{{"token":"{placeholder}"}}}}}}"# + ); + let graphql = match classify_graphql_websocket_message(&message) { + GraphqlWebSocketMessage::Operation { graphql, .. } => graphql, + other @ GraphqlWebSocketMessage::Control { .. } => { + panic!("expected operation, got {other:?}") + } + }; + let summary = graphql_log_summary(&graphql); + + assert!(summary.contains("type=query")); + assert!(summary.contains("fields=viewer")); + assert!(!summary.contains(placeholder)); + assert!(!summary.contains("real-token")); + assert!(!summary.contains("variables")); + assert!(!summary.contains("token")); + assert!(!summary.contains("secret_len")); + } + + #[tokio::test] + async fn rewrites_discord_like_identify_text_payload() { + let (child_env, _) = resolver(); + let placeholder = child_env.get("DISCORD_BOT_TOKEN").unwrap(); + let payload = format!(r#"{{"op":2,"d":{{"token":"{placeholder}"}}}}"#); + + let output = run_client_to_server(masked_frame(true, OPCODE_TEXT, payload.as_bytes())) + .await + .expect("relay should succeed"); + + assert_eq!( + decode_masked_text_frame(&output), + r#"{"op":2,"d":{"token":"real-token"}}"# + ); + } + + #[tokio::test] + async fn upgraded_relay_rewrites_client_text_before_upstream_receives_it() { + let (child_env, resolver) = resolver(); + let placeholder = child_env.get("DISCORD_BOT_TOKEN").unwrap(); + let payload = format!(r#"{{"op":2,"d":{{"token":"{placeholder}"}}}}"#); + let client_frame = masked_frame(true, OPCODE_TEXT, payload.as_bytes()); + assert!( + !String::from_utf8_lossy(&client_frame).contains("real-token"), + "client-side fixture must not contain the real token" + ); + + let (mut client_app, mut relay_client) = 
tokio::io::duplex(4096); + let (mut relay_upstream, mut upstream_app) = tokio::io::duplex(4096); + let relay = tokio::spawn(async move { + relay_with_options( + &mut relay_client, + &mut relay_upstream, + Vec::new(), + "gateway.example.test", + 443, + RelayOptions { + policy_name: "test-policy", + resolver: Some(&resolver), + inspector: None, + compression: WebSocketCompression::None, + }, + ) + .await + }); + + client_app.write_all(&client_frame).await.unwrap(); + client_app.flush().await.unwrap(); + + let upstream_frame = tokio::time::timeout( + std::time::Duration::from_secs(2), + read_one_frame(&mut upstream_app), + ) + .await + .expect("upstream should receive rewritten frame"); + assert_eq!( + decode_masked_text_frame(&upstream_frame), + r#"{"op":2,"d":{"token":"real-token"}}"# + ); + + drop(client_app); + drop(upstream_app); + let _ = tokio::time::timeout(std::time::Duration::from_secs(2), relay).await; + } + + #[tokio::test] + async fn graphql_websocket_policy_allows_subscription_operation() { + let payload = r#"{"type":"subscribe","id":"1","payload":{"query":"subscription NewMessages { messageAdded }"}}"#; + let frame = masked_frame(true, OPCODE_TEXT, payload.as_bytes()); + + let output = run_client_to_server_with_graphql_policy(frame.clone(), None) + .await + .expect("allowed subscription should relay"); + + assert_eq!(output, frame); + assert_eq!(decode_masked_text_frame(&output), payload); + } + + #[tokio::test] + async fn graphql_websocket_policy_denies_unlisted_operation_field() { + let payload = + r#"{"type":"subscribe","id":"1","payload":{"query":"query Admin { adminAuditLog }"}}"#; + let frame = masked_frame(true, OPCODE_TEXT, payload.as_bytes()); + + let err = run_client_to_server_with_graphql_policy(frame, None) + .await + .expect_err("unlisted field should be denied"); + + assert!(err.to_string().contains("websocket GraphQL message denied")); + } + + #[tokio::test] + async fn graphql_websocket_control_message_rewrites_credentials_before_relay() 
{ + let (child_env, resolver) = SecretResolver::from_provider_env( + std::iter::once(("T".to_string(), "real-token".to_string())).collect(), + ); + let resolver = resolver.expect("resolver"); + let placeholder = child_env.get("T").expect("placeholder env"); + let payload = format!( + r#"{{"type":"connection_init","payload":{{"authorization":"{placeholder}"}}}}"# + ); + let frame = masked_frame(true, OPCODE_TEXT, payload.as_bytes()); + + let output = run_client_to_server_with_graphql_policy(frame, Some(&resolver)) + .await + .expect("control message should relay after credential rewrite"); + + let rewritten = decode_masked_text_frame(&output); + assert_eq!( + rewritten, + r#"{"type":"connection_init","payload":{"authorization":"real-token"}}"# + ); + assert!(!rewritten.contains(placeholder)); + } + + #[tokio::test] + async fn text_without_placeholder_passes_semantically_unchanged() { + let frame = masked_frame(true, OPCODE_TEXT, br#"{"op":1,"d":42}"#); + let output = run_client_to_server(frame.clone()) + .await + .expect("relay should succeed"); + + assert_eq!(output, frame); + assert_eq!(decode_masked_text_frame(&output), r#"{"op":1,"d":42}"#); + } + + #[tokio::test] + async fn unknown_placeholder_fails_closed() { + let frame = masked_frame( + true, + OPCODE_TEXT, + br#"{"token":"openshell:resolve:env:UNKNOWN"}"#, + ); + + let err = run_client_to_server(frame) + .await + .expect_err("unknown placeholder should fail"); + + assert!( + err.to_string() + .contains("credential placeholder resolution") + ); + } + + #[tokio::test] + async fn fragmented_text_rewrites_after_final_continuation() { + let (child_env, _) = resolver(); + let placeholder = child_env.get("DISCORD_BOT_TOKEN").unwrap(); + let first = format!(r#"{{"token":"{placeholder}"#); + let second = r#""}"#; + let mut input = masked_frame(false, OPCODE_TEXT, first.as_bytes()); + input.extend(masked_frame(true, OPCODE_CONTINUATION, second.as_bytes())); + + let output = run_client_to_server(input) + .await + 
.expect("relay should succeed"); + + assert_eq!( + decode_masked_text_frame(&output), + r#"{"token":"real-token"}"# + ); + } + + #[tokio::test] + async fn rejects_rsv_bits() { + let mut frame = masked_frame(true, OPCODE_TEXT, b"hello"); + frame[0] |= 0x40; + + let err = run_client_to_server(frame) + .await + .expect_err("RSV frame should fail"); + + assert!(err.to_string().contains("RSV bits")); + } + + #[tokio::test] + async fn rejects_unmasked_client_frame() { + let err = run_client_to_server(unmasked_frame(OPCODE_TEXT, b"hello")) + .await + .expect_err("unmasked frame should fail"); + + assert!(err.to_string().contains("not masked")); + } + + #[tokio::test] + async fn rejects_invalid_utf8_text() { + let err = run_client_to_server(masked_frame(true, OPCODE_TEXT, &[0xff])) + .await + .expect_err("invalid UTF-8 should fail"); + + assert!(err.to_string().contains("valid UTF-8")); + } + + #[tokio::test] + async fn rejects_oversize_text_message() { + let payload = vec![b'a'; MAX_TEXT_MESSAGE_BYTES + 1]; + let err = run_client_to_server(masked_frame(true, OPCODE_TEXT, &payload)) + .await + .expect_err("oversize text should fail"); + + assert!(err.to_string().contains("exceeds")); + } + + #[tokio::test] + async fn fragmented_text_allows_interleaved_ping_pong_and_rewrites_at_completion() { + let (child_env, _) = resolver(); + let placeholder = child_env.get("DISCORD_BOT_TOKEN").unwrap(); + let first = format!(r#"{{"token":"{placeholder}"#); + let first_control_frame = masked_frame(true, OPCODE_PING, b"p"); + let second_control_frame = masked_frame(true, OPCODE_PONG, b"q"); + let mut input = masked_frame(false, OPCODE_TEXT, first.as_bytes()); + input.extend_from_slice(&first_control_frame); + input.extend_from_slice(&second_control_frame); + input.extend(masked_frame(true, OPCODE_CONTINUATION, br#""}"#)); + + let output = run_client_to_server(input) + .await + .expect("relay should allow interleaved control frames"); + + assert!(output.starts_with(&first_control_frame)); 
+ assert_eq!( + &output + [first_control_frame.len()..first_control_frame.len() + second_control_frame.len()], + second_control_frame.as_slice() + ); + assert_eq!( + decode_masked_text_frame( + &output[first_control_frame.len() + second_control_frame.len()..] + ), + r#"{"token":"real-token"}"# + ); + } + + #[tokio::test] + async fn compressed_text_rewrites_with_permessage_deflate() { + let (child_env, _) = resolver(); + let placeholder = child_env.get("DISCORD_BOT_TOKEN").unwrap(); + let payload = format!(r#"{{"token":"{placeholder}"}}"#); + let compressed = compress_permessage_deflate(payload.as_bytes()).unwrap(); + let input = masked_frame_with_rsv(true, OPCODE_TEXT, 0x40, &compressed); + + let output = run_client_to_server_compressed(input) + .await + .expect("compressed text should relay"); + + assert_eq!( + decode_compressed_masked_text_frame(&output), + r#"{"token":"real-token"}"# + ); + } + + #[tokio::test] + async fn compressed_text_rejects_decompressed_oversize_message() { + let payload = vec![b'a'; MAX_TEXT_MESSAGE_BYTES + 1]; + let compressed = compress_permessage_deflate(&payload).unwrap(); + let input = masked_frame_with_rsv(true, OPCODE_TEXT, 0x40, &compressed); + + let err = run_client_to_server_compressed(input) + .await + .expect_err("oversize decompressed text should fail"); + + assert!(err.to_string().contains("exceeds")); + } + + #[tokio::test] + async fn binary_frame_passes_through_unchanged() { + let frame = masked_frame(true, OPCODE_BINARY, &[0, 1, 2, 3, 255]); + + let output = run_client_to_server(frame.clone()) + .await + .expect("binary frame should pass through"); + + assert_eq!(output, frame); + } + + #[tokio::test] + async fn rejects_reserved_opcode() { + let err = run_client_to_server(masked_frame(true, 0x3, b"reserved")) + .await + .expect_err("reserved opcode should fail"); + + assert!(err.to_string().contains("reserved opcode")); + } + + #[tokio::test] + async fn rejects_continuation_without_active_message() { + let err = 
run_client_to_server(masked_frame(true, OPCODE_CONTINUATION, b"orphan")) + .await + .expect_err("orphan continuation should fail"); + + assert!(err.to_string().contains("continuation")); + } + + #[tokio::test] + async fn rejects_new_data_frame_before_fragment_completion() { + let mut input = masked_frame(false, OPCODE_TEXT, b"partial"); + input.extend(masked_frame(true, OPCODE_TEXT, b"second")); + + let err = run_client_to_server(input) + .await + .expect_err("new data frame during fragmentation should fail"); + + assert!(err.to_string().contains("previous fragmented message")); + } + + #[tokio::test] + async fn rejects_fragmented_control_frame() { + let err = run_client_to_server(masked_frame(false, OPCODE_PING, b"ping")) + .await + .expect_err("fragmented control frame should fail"); + + assert!(err.to_string().contains("control frame is fragmented")); + } + + #[tokio::test] + async fn rejects_control_frame_over_125_bytes() { + let payload = vec![b'a'; 126]; + let err = run_client_to_server(masked_frame(true, OPCODE_PING, &payload)) + .await + .expect_err("oversize control frame should fail"); + + assert!(err.to_string().contains("control frame exceeds")); + } + + #[tokio::test] + async fn rejects_non_minimal_extended_length() { + let err = run_client_to_server(masked_frame_with_non_minimal_16_bit_len( + OPCODE_TEXT, + b"hello", + )) + .await + .expect_err("non-minimal length should fail"); + + assert!(err.to_string().contains("non-minimal")); + } + + #[tokio::test] + async fn rejects_oversize_binary_frame_before_payload_buffering() { + let err = run_client_to_server(masked_frame_with_declared_len( + OPCODE_BINARY, + MAX_RAW_FRAME_PAYLOAD_BYTES + 1, + )) + .await + .expect_err("oversize binary frame should fail"); + + assert!(err.to_string().contains("binary frame exceeds")); + } + + #[tokio::test] + async fn validates_close_frame_payloads() { + let frame = masked_frame(true, OPCODE_CLOSE, &close_payload(1000, b"done")); + + let output = 
run_client_to_server(frame.clone()) + .await + .expect("valid close frame should pass through"); + + assert_eq!(output, frame); + } + + #[tokio::test] + async fn rejects_close_frame_with_one_byte_payload() { + let err = run_client_to_server(masked_frame(true, OPCODE_CLOSE, &[0x03])) + .await + .expect_err("one-byte close frame should fail"); + + assert!(err.to_string().contains("exactly one byte")); + } + + #[tokio::test] + async fn rejects_reserved_close_code() { + let err = run_client_to_server(masked_frame(true, OPCODE_CLOSE, &close_payload(1005, b""))) + .await + .expect_err("reserved close code should fail"); + + assert!(err.to_string().contains("invalid close code")); + } + + #[tokio::test] + async fn rejects_close_reason_with_invalid_utf8() { + let err = run_client_to_server(masked_frame( + true, + OPCODE_CLOSE, + &close_payload(1000, &[0xff]), + )) + .await + .expect_err("invalid close reason should fail"); + + assert!(err.to_string().contains("valid UTF-8")); + } + + #[tokio::test] + async fn rejects_frames_after_client_close_frame() { + let mut input = masked_frame(true, OPCODE_CLOSE, &close_payload(1000, b"done")); + input.extend(masked_frame(true, OPCODE_TEXT, b"late")); + + let err = run_client_to_server(input) + .await + .expect_err("frames after close should fail"); + + assert!(err.to_string().contains("after close")); + } + + #[test] + fn websocket_ocsf_messages_do_not_include_payload_or_secret_material() { + let placeholder = "openshell:resolve:env:DISCORD_BOT_TOKEN"; + let secret = "real-token"; + let payload = format!(r#"{{"op":2,"d":{{"token":"{placeholder}"}}}}"#); + + let rewrite = rewrite_event_message("gateway.example.test", 443, 1); + let failure = protocol_failure_message("gateway.example.test", 443); + let messages = [rewrite, failure]; + + for message in messages { + assert!(!message.contains(placeholder)); + assert!(!message.contains(secret)); + assert!(!message.contains(&payload)); + assert!(!message.contains("secret_len")); + 
assert!(!message.contains("payload_len")); + } + } +} diff --git a/crates/openshell-sandbox/src/lib.rs b/crates/openshell-sandbox/src/lib.rs index 19424bd2b..a21c0c130 100644 --- a/crates/openshell-sandbox/src/lib.rs +++ b/crates/openshell-sandbox/src/lib.rs @@ -7,6 +7,7 @@ pub mod bypass_monitor; mod child_env; +pub mod debug_rpc; pub mod denial_aggregator; mod grpc_client; mod identity; @@ -15,11 +16,14 @@ pub mod log_push; pub mod mechanistic_mapper; pub mod opa; mod policy; +mod policy_local; mod process; pub mod procfs; +mod provider_credentials; pub mod proxy; mod sandbox; mod secrets; +mod skills; mod ssh; mod supervisor_session; @@ -87,6 +91,83 @@ pub(crate) fn ocsf_ctx() -> &'static SandboxContext { OCSF_CTX.get().unwrap_or(&OCSF_CTX_FALLBACK) } +/// Process-wide flag for the agent-driven policy proposal surface. +/// Set once during `run_sandbox()` startup and updated by the settings poll +/// loop when `agent_policy_proposals_enabled` changes. Read by the +/// `policy.local` route handler and the L7 deny body's `next_steps` builder +/// to gate the agent-controlled mutation surface. Exposed `pub(crate)` so +/// unit tests in sibling modules can flip the flag through a serialized +/// guard (see `policy_local::tests::ProposalsFlagGuard`). +pub(crate) static AGENT_PROPOSALS_ENABLED: OnceLock<Arc<std::sync::atomic::AtomicBool>> = + OnceLock::new(); + +/// Read the current value of the agent proposals feature flag. +/// +/// Returns `false` if `run_sandbox()` has not initialized the flag (e.g. +/// during unit tests), matching the documented default for the setting. +pub(crate) fn agent_proposals_enabled() -> bool { + AGENT_PROPOSALS_ENABLED + .get() + .is_some_and(|flag| flag.load(Ordering::Relaxed)) +} + +/// Test-only helpers shared across sibling test modules.
+#[cfg(test)] +pub(crate) mod test_helpers { + #![allow( + clippy::redundant_pub_crate, + reason = "intentional crate-private module" + )] + use std::sync::Arc; + use std::sync::LazyLock; + use std::sync::atomic::{AtomicBool, Ordering}; + use tokio::sync::MutexGuard; + + static PROPOSALS_FLAG_LOCK: LazyLock<tokio::sync::Mutex<()>> = + LazyLock::new(|| tokio::sync::Mutex::new(())); + + /// Guard for tests that toggle the process-wide + /// `AGENT_PROPOSALS_ENABLED` flag. Acquires a process-wide async mutex, + /// swaps in the requested value, and restores the previous value on drop. + /// Hold the guard for the duration of any code that reads + /// `agent_proposals_enabled()`. + pub(crate) struct ProposalsFlagGuard { + prev: bool, + flag: Arc<AtomicBool>, + _lock: MutexGuard<'static, ()>, + } + + impl ProposalsFlagGuard { + pub(crate) async fn set(enabled: bool) -> Self { + let lock = PROPOSALS_FLAG_LOCK.lock().await; + Self::with_lock(enabled, lock) + } + + pub(crate) fn set_blocking(enabled: bool) -> Self { + let lock = PROPOSALS_FLAG_LOCK.blocking_lock(); + Self::with_lock(enabled, lock) + } + + fn with_lock(enabled: bool, lock: MutexGuard<'static, ()>) -> Self { + let flag = super::AGENT_PROPOSALS_ENABLED + .get_or_init(|| Arc::new(AtomicBool::new(false))) + .clone(); + let prev = flag.swap(enabled, Ordering::Relaxed); + Self { + prev, + flag, + _lock: lock, + } + } + } + + impl Drop for ProposalsFlagGuard { + fn drop(&mut self) { + self.flag.store(self.prev, Ordering::Relaxed); + } + } +} + use crate::identity::BinaryIdentityCache; use crate::l7::tls::{ CertCache, ProxyTlsState, SandboxCa, build_upstream_client_config, read_system_ca_bundle, @@ -97,7 +178,6 @@ use crate::policy::{NetworkMode, NetworkPolicy, ProxyPolicy, SandboxPolicy}; use crate::proxy::ProxyHandle; #[cfg(target_os = "linux")] use crate::sandbox::linux::netns::NetworkNamespace; -use crate::secrets::SecretResolver; pub use process::{ProcessHandle, ProcessStatus}; pub use sandbox::apply_supervisor_startup_hardening; @@ -212,8
+292,6 @@ pub async fn run_sandbox( policy_rules: Option, policy_data: Option, ssh_socket_path: Option, - ssh_handshake_secret: Option, - ssh_handshake_skew_secs: u64, _health_check: bool, _health_port: u16, inference_routes: Option, @@ -260,6 +338,11 @@ pub async fn run_sandbox( policy_data, ) .await?; + let policy_local_ctx = Arc::new(policy_local::PolicyLocalContext::new( + retained_proto.clone(), + openshell_endpoint.clone(), + sandbox_name_for_agg.clone().or_else(|| sandbox_id.clone()), + )); // Validate that the required "sandbox" user exists in this image. // All sandbox images must include this user for privilege dropping. @@ -269,42 +352,59 @@ pub async fn run_sandbox( // Fetch provider environment variables from the server. // This is done after loading the policy so the sandbox can still start // even if provider env fetch fails (graceful degradation). - let provider_env = if let (Some(id), Some(endpoint)) = (&sandbox_id, &openshell_endpoint) { - match grpc_client::fetch_provider_environment(endpoint, id).await { - Ok(env) => { - ocsf_emit!( - ConfigStateChangeBuilder::new(ocsf_ctx()) - .severity(SeverityId::Informational) - .status(StatusId::Success) - .state(StateId::Enabled, "loaded") - .message(format!( - "Fetched provider environment [env_count:{}]", - env.len() - )) - .build() - ); - env - } - Err(e) => { - ocsf_emit!( - ConfigStateChangeBuilder::new(ocsf_ctx()) - .severity(SeverityId::Medium) - .status(StatusId::Failure) - .state(StateId::Other, "degraded") - .message(format!( - "Failed to fetch provider environment, continuing without: {e}" - )) - .build() - ); - std::collections::HashMap::new() + let (provider_env_revision, provider_env, provider_credential_expires_at_ms) = + if let (Some(id), Some(endpoint)) = (&sandbox_id, &openshell_endpoint) { + match grpc_client::fetch_provider_environment(endpoint, id).await { + Ok(result) => { + ocsf_emit!( + ConfigStateChangeBuilder::new(ocsf_ctx()) + .severity(SeverityId::Informational) + 
.status(StatusId::Success) + .state(StateId::Enabled, "loaded") + .message(format!( + "Fetched provider environment [env_count:{}]", + result.environment.len() + )) + .build() + ); + ( + result.provider_env_revision, + result.environment, + result.credential_expires_at_ms, + ) + } + Err(e) => { + ocsf_emit!( + ConfigStateChangeBuilder::new(ocsf_ctx()) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .state(StateId::Other, "degraded") + .message(format!( + "Failed to fetch provider environment, continuing without: {e}" + )) + .build() + ); + ( + 0, + std::collections::HashMap::new(), + std::collections::HashMap::new(), + ) + } } - } - } else { - std::collections::HashMap::new() - }; + } else { + ( + 0, + std::collections::HashMap::new(), + std::collections::HashMap::new(), + ) + }; - let (provider_env, secret_resolver) = SecretResolver::from_provider_env(provider_env); - let secret_resolver = secret_resolver.map(Arc::new); + let provider_credentials = provider_credentials::ProviderCredentialState::from_environment( + provider_env_revision, + provider_env, + provider_credential_expires_at_ms, + ); + let provider_env = provider_credentials.snapshot().child_env.clone(); // Create identity cache for SHA256 TOFU when OPA is active let identity_cache = opa_engine @@ -314,6 +414,49 @@ pub async fn run_sandbox( // Prepare filesystem: create and chown read_write directories prepare_filesystem(&policy)?; + // Initialize the agent-proposals feature flag. Default false until the + // initial settings fetch (or the poll loop) tells us otherwise. The flag + // gates the skill install, the policy.local route handler, and the L7 + // deny body's `next_steps` field — see `agent_proposals_enabled()`. 
+ let proposals_enabled = Arc::new(std::sync::atomic::AtomicBool::new(false)); + if AGENT_PROPOSALS_ENABLED + .set(proposals_enabled.clone()) + .is_err() + { + debug!("agent proposals flag already initialized, keeping existing"); + } + + // Eagerly fetch the initial settings so skill install can honor the flag + // at startup rather than waiting for the poll loop's first tick. In + // offline/file-mode there is no gateway, so the flag stays false. + if let (Some(id), Some(endpoint)) = (&sandbox_id, &openshell_endpoint) + && let Ok(client) = grpc_client::CachedOpenShellClient::connect(endpoint).await + && let Ok(result) = client.poll_settings(id).await + { + let initial = extract_bool_setting( + &result.settings, + openshell_core::settings::AGENT_POLICY_PROPOSALS_ENABLED_KEY, + ) + .unwrap_or(false); + proposals_enabled.store(initial, Ordering::Relaxed); + } + + if agent_proposals_enabled() { + match skills::install_static_skills() { + Ok(installed) => { + info!( + path = %installed.policy_advisor.display(), + "Installed sandbox agent skill" + ); + } + Err(error) => { + warn!(error = %error, "Failed to install sandbox agent skill"); + } + } + } else { + debug!("agent_policy_proposals_enabled is false at startup; skipping skill install"); + } + // Generate ephemeral CA and TLS state for HTTPS L7 inspection. // The CA cert is written to disk so sandbox processes can trust it. let (tls_state, ca_file_paths) = if matches!(policy.network.mode, NetworkMode::Proxy) { @@ -380,7 +523,7 @@ pub async fn run_sandbox( let netns = if matches!(policy.network.mode, NetworkMode::Proxy) { match NetworkNamespace::create() { Ok(ns) => { - // Install bypass detection rules (iptables LOG + REJECT). + // Install bypass detection rules (nftables log + reject). // This provides fast-fail UX and diagnostic logging for direct // connection attempts that bypass the HTTP CONNECT proxy. 
let proxy_port = policy @@ -421,7 +564,7 @@ pub async fn run_sandbox( let _netns: Option<()> = None; // Install the supervisor seccomp prelude after privileged startup helpers - // (network namespace setup, iptables probes) complete, but before the SSH + // (network namespace setup, nftables probes) complete, but before the SSH // listener and workload process are exposed. apply_supervisor_startup_hardening()?; @@ -480,7 +623,8 @@ pub async fn run_sandbox( entrypoint_pid.clone(), tls_state, inference_ctx, - secret_resolver.clone(), + Some(provider_credentials.clone()), + Some(policy_local_ctx.clone()), denial_tx, ) .await?; @@ -490,7 +634,7 @@ pub async fn run_sandbox( }; // Spawn bypass detection monitor (Linux only, proxy mode only). - // Reads /dev/kmsg for iptables LOG entries and emits structured + // Reads /dev/kmsg for nftables log entries and emits structured // tracing events for direct connection attempts that bypass the proxy. #[cfg(target_os = "linux")] let _bypass_monitor = netns.as_ref().and_then(|ns| { @@ -614,12 +758,10 @@ pub async fn run_sandbox( if let Some(listen_path) = ssh_socket_path.clone() { let policy_clone = policy.clone(); let workdir_clone = workdir.clone(); - let _ = ssh_handshake_secret; // retained in the signature for compat; unused - let _ = ssh_handshake_skew_secs; let proxy_url = ssh_proxy_url; let netns_fd = ssh_netns_fd; let ca_paths = ca_file_paths.clone(); - let provider_env_clone = provider_env.clone(); + let provider_credentials_clone = provider_credentials.clone(); let (ssh_ready_tx, ssh_ready_rx) = tokio::sync::oneshot::channel(); @@ -632,7 +774,7 @@ pub async fn run_sandbox( netns_fd, proxy_url, ca_paths, - provider_env_clone, + provider_credentials_clone, ) .await { @@ -685,7 +827,7 @@ pub async fn run_sandbox( sandbox_id.as_ref(), ssh_socket_path.as_ref(), ) { - supervisor_session::spawn(endpoint.clone(), id.clone(), socket.clone()); + supervisor_session::spawn(endpoint.clone(), id.clone(), socket.clone(), 
ssh_netns_fd); info!("supervisor session task spawned"); } @@ -796,22 +938,25 @@ pub async fn run_sandbox( let poll_engine = engine.clone(); let poll_ocsf_enabled = ocsf_enabled.clone(); let poll_pid = entrypoint_pid.clone(); + let poll_provider_credentials = provider_credentials.clone(); + let poll_policy_local = policy_local_ctx.clone(); let poll_interval_secs: u64 = std::env::var("OPENSHELL_POLICY_POLL_INTERVAL_SECS") .ok() .and_then(|v| v.parse().ok()) .unwrap_or(10); + let poll_ctx = PolicyPollLoopContext { + endpoint: poll_endpoint, + sandbox_id: poll_id, + opa_engine: poll_engine, + entrypoint_pid: poll_pid, + interval_secs: poll_interval_secs, + ocsf_enabled: poll_ocsf_enabled, + provider_credentials: poll_provider_credentials, + policy_local_ctx: Some(poll_policy_local), + }; tokio::spawn(async move { - if let Err(e) = run_policy_poll_loop( - &poll_endpoint, - &poll_id, - &poll_engine, - &poll_pid, - poll_interval_secs, - &poll_ocsf_enabled, - ) - .await - { + if let Err(e) = run_policy_poll_loop(poll_ctx).await { ocsf_emit!( AppLifecycleBuilder::new(ocsf_ctx()) .activity(ActivityId::Fail) @@ -1306,22 +1451,39 @@ fn enumerate_gpu_device_nodes() -> Vec { paths } -/// Collect all baseline paths for enrichment: proxy defaults + GPU (if present). -/// Returns `(read_only, read_write)` as owned `String` vecs. 
-fn baseline_enrichment_paths() -> (Vec<String>, Vec<String>) { - let mut ro: Vec<String> = PROXY_BASELINE_READ_ONLY - .iter() - .map(|&s| s.to_string()) - .collect(); - let mut rw: Vec<String> = PROXY_BASELINE_READ_WRITE - .iter() - .map(|&s| s.to_string()) - .collect(); +fn push_unique(paths: &mut Vec<String>, path: String) { + if !paths.iter().any(|p| p == &path) { + paths.push(path); + } +} - if has_gpu_devices() { - ro.extend(GPU_BASELINE_READ_ONLY.iter().map(|&s| s.to_string())); - rw.extend(GPU_BASELINE_READ_WRITE.iter().map(|&s| s.to_string())); - rw.extend(enumerate_gpu_device_nodes()); +fn collect_baseline_enrichment_paths( + include_proxy: bool, + include_gpu: bool, + gpu_device_nodes: Vec<String>, +) -> (Vec<String>, Vec<String>) { + let mut ro = Vec::new(); + let mut rw = Vec::new(); + + if include_proxy { + for &path in PROXY_BASELINE_READ_ONLY { + push_unique(&mut ro, path.to_string()); + } + for &path in PROXY_BASELINE_READ_WRITE { + push_unique(&mut rw, path.to_string()); + } + } + + if include_gpu { + for &path in GPU_BASELINE_READ_ONLY { + push_unique(&mut ro, path.to_string()); + } + for &path in GPU_BASELINE_READ_WRITE { + push_unique(&mut rw, path.to_string()); + } + for path in gpu_device_nodes { + push_unique(&mut rw, path); + } } // A path promoted to read_write (e.g. /proc for GPU) should not also @@ -1332,14 +1494,33 @@ fn baseline_enrichment_paths() -> (Vec<String>, Vec<String>) { (ro, rw) } -/// Ensure a proto `SandboxPolicy` includes the baseline filesystem paths -/// required for proxy-mode sandboxes. Paths are only added if missing; -/// user-specified paths are never removed. -/// -/// Returns `true` if the policy was modified (caller may want to sync back). -fn enrich_proto_baseline_paths(proto: &mut openshell_core::proto::SandboxPolicy) -> bool { - // Only enrich if network_policies are present (proxy mode indicator).
- if proto.network_policies.is_empty() { +fn active_baseline_enrichment_paths(include_proxy: bool) -> (Vec<String>, Vec<String>) { + let include_gpu = has_gpu_devices(); + let gpu_device_nodes = if include_gpu { + enumerate_gpu_device_nodes() + } else { + Vec::new() + }; + collect_baseline_enrichment_paths(include_proxy, include_gpu, gpu_device_nodes) +} + +/// Collect all active baseline paths for tests and diagnostics. +/// Returns `(read_only, read_write)` as owned `String` vecs. +#[cfg(test)] +fn baseline_enrichment_paths() -> (Vec<String>, Vec<String>) { + active_baseline_enrichment_paths(true) +} + +fn enrich_proto_baseline_paths_with<F>( + proto: &mut openshell_core::proto::SandboxPolicy, + ro: &[String], + rw: &[String], + path_exists: F, +) -> bool +where + F: Fn(&str) -> bool, +{ + if ro.is_empty() && rw.is_empty() { return false; } @@ -1350,17 +1531,10 @@ fn enrich_proto_baseline_paths(proto: &mut openshell_core::proto::SandboxPolicy) ..Default::default() }); - let (ro, rw) = baseline_enrichment_paths(); - - // Baseline paths are system-injected, not user-specified. Skip paths - // that do not exist in this container image to avoid noisy warnings from - // Landlock and, more critically, to prevent a single missing baseline - // path from abandoning the entire Landlock ruleset under best-effort - // mode (see issue #664).
let mut modified = false; - for path in &ro { + for path in ro { if !fs.read_only.iter().any(|p| p == path) && !fs.read_write.iter().any(|p| p == path) { - if !std::path::Path::new(path).exists() { + if !path_exists(path) { debug!( path, "Baseline read-only path does not exist, skipping enrichment" @@ -1371,11 +1545,11 @@ fn enrich_proto_baseline_paths(proto: &mut openshell_core::proto::SandboxPolicy) modified = true; } } - for path in &rw { + for path in rw { if fs.read_only.iter().any(|p| p == path) || fs.read_write.iter().any(|p| p == path) { continue; } - if !std::path::Path::new(path).exists() { + if !path_exists(path) { debug!( path, "Baseline read-write path does not exist, skipping enrichment" @@ -1386,6 +1560,26 @@ fn enrich_proto_baseline_paths(proto: &mut openshell_core::proto::SandboxPolicy) modified = true; } + modified +} + +/// Ensure a proto `SandboxPolicy` includes the baseline filesystem paths +/// required by proxy-mode sandboxes and GPU runtimes. Paths are only added if +/// missing; user-specified paths are never removed. +/// +/// Returns `true` if the policy was modified (caller may want to sync back). +fn enrich_proto_baseline_paths(proto: &mut openshell_core::proto::SandboxPolicy) -> bool { + let (ro, rw) = active_baseline_enrichment_paths(!proto.network_policies.is_empty()); + + // Baseline paths are system-injected, not user-specified. Skip paths + // that do not exist in this container image to avoid noisy warnings from + // Landlock and, more critically, to prevent a single missing baseline + // path from abandoning the entire Landlock ruleset under best-effort + // mode (see issue #664). 
+ let modified = enrich_proto_baseline_paths_with(proto, &ro, &rw, |path| { + std::path::Path::new(path).exists() + }); + if modified { ocsf_emit!( ConfigStateChangeBuilder::new(ocsf_ctx()) @@ -1401,15 +1595,15 @@ fn enrich_proto_baseline_paths(proto: &mut openshell_core::proto::SandboxPolicy) } /// Ensure a `SandboxPolicy` (Rust type) includes the baseline filesystem -/// paths required for proxy-mode sandboxes. Used for the local-file code -/// path where no proto is available. +/// paths required by proxy-mode sandboxes and GPU runtimes. Used for the +/// local-file code path where no proto is available. fn enrich_sandbox_baseline_paths(policy: &mut SandboxPolicy) { - if !matches!(policy.network.mode, NetworkMode::Proxy) { + let (ro, rw) = + active_baseline_enrichment_paths(matches!(policy.network.mode, NetworkMode::Proxy)); + if ro.is_empty() && rw.is_empty() { return; } - let (ro, rw) = baseline_enrichment_paths(); - let mut modified = false; for path in &ro { let p = std::path::PathBuf::from(path); @@ -1564,6 +1758,31 @@ mod baseline_tests { ); } + #[test] + fn proto_gpu_enrichment_adds_devices_without_network_policy() { + let mut policy = openshell_policy::restrictive_default_policy(); + assert!( + policy.network_policies.is_empty(), + "regression setup must exercise the no-network default path" + ); + let (ro, rw) = + collect_baseline_enrichment_paths(false, true, vec!["/dev/nvidia0".to_string()]); + + let enriched = enrich_proto_baseline_paths_with(&mut policy, &ro, &rw, |path| { + matches!(path, "/proc" | "/dev/nvidia0") + }); + + let filesystem = policy.filesystem.expect("filesystem policy"); + assert!( + enriched, + "GPU enrichment should not require network policies" + ); + assert!( + filesystem.read_write.contains(&"/dev/nvidia0".to_string()), + "GPU enrichment should add enumerated device nodes without network policies" + ); + } + #[test] fn gpu_baseline_read_write_contains_dxg() { // /dev/dxg must be present so WSL2 sandboxes get the Landlock @@ 
-2124,7 +2343,7 @@ async fn flush_proposals_to_gateway( // Run the mechanistic mapper sandbox-side to generate proposals. // The gateway is a thin persistence + validation layer — it never // generates proposals itself. - let proposals = mechanistic_mapper::generate_proposals(&proto_summaries).await; + let proposals = mechanistic_mapper::generate_proposals(&proto_summaries); info!( sandbox_name = %sandbox_name, @@ -2145,20 +2364,25 @@ async fn flush_proposals_to_gateway( /// /// When the entrypoint PID is available, policy reloads include symlink /// resolution for binary paths via the container filesystem. -async fn run_policy_poll_loop( - endpoint: &str, - sandbox_id: &str, - opa_engine: &Arc, - entrypoint_pid: &Arc, +struct PolicyPollLoopContext { + endpoint: String, + sandbox_id: String, + opa_engine: Arc, + entrypoint_pid: Arc, interval_secs: u64, - ocsf_enabled: &std::sync::atomic::AtomicBool, -) -> Result<()> { + ocsf_enabled: Arc, + provider_credentials: provider_credentials::ProviderCredentialState, + policy_local_ctx: Option>, +} + +async fn run_policy_poll_loop(ctx: PolicyPollLoopContext) -> Result<()> { use crate::grpc_client::CachedOpenShellClient; use openshell_core::proto::PolicySource; use std::sync::atomic::Ordering; - let client = CachedOpenShellClient::connect(endpoint).await?; + let client = CachedOpenShellClient::connect(&ctx.endpoint).await?; let mut current_config_revision: u64 = 0; + let mut current_provider_env_revision: u64 = ctx.provider_credentials.snapshot().revision; let mut current_policy_hash = String::new(); let mut current_settings: std::collections::HashMap< String, @@ -2166,7 +2390,7 @@ async fn run_policy_poll_loop( > = std::collections::HashMap::new(); // Initialize revision from the first poll. 
- match client.poll_settings(sandbox_id).await { + match client.poll_settings(&ctx.sandbox_id).await { Ok(result) => { current_config_revision = result.config_revision; current_policy_hash = result.policy_hash.clone(); @@ -2181,11 +2405,11 @@ async fn run_policy_poll_loop( } } - let interval = Duration::from_secs(interval_secs); + let interval = Duration::from_secs(ctx.interval_secs); loop { tokio::time::sleep(interval).await; - let result = match client.poll_settings(sandbox_id).await { + let result = match client.poll_settings(&ctx.sandbox_id).await { Ok(r) => r, Err(e) => { debug!(error = %e, "Settings poll: server unreachable, will retry"); @@ -2193,7 +2417,8 @@ async fn run_policy_poll_loop( } }; - if result.config_revision == current_config_revision { + let provider_env_changed = result.provider_env_revision != current_provider_env_revision; + if result.config_revision == current_config_revision && !provider_env_changed { continue; } @@ -2209,12 +2434,47 @@ async fn run_policy_poll_loop( .unmapped("old_config_revision", serde_json::json!(current_config_revision)) .unmapped("new_config_revision", serde_json::json!(result.config_revision)) .unmapped("policy_changed", serde_json::json!(policy_changed)) + .unmapped("provider_env_changed", serde_json::json!(provider_env_changed)) .message(format!( - "Settings poll: config change detected [old_revision:{current_config_revision} new_revision:{} policy_changed:{policy_changed}]", + "Settings poll: config change detected [old_revision:{current_config_revision} new_revision:{} policy_changed:{policy_changed} provider_env_changed:{provider_env_changed}]", result.config_revision )) .build()); + if provider_env_changed { + match grpc_client::fetch_provider_environment(&ctx.endpoint, &ctx.sandbox_id).await { + Ok(env_result) => { + let env_count = ctx.provider_credentials.install_environment( + env_result.provider_env_revision, + env_result.environment, + env_result.credential_expires_at_ms, + ); + 
current_provider_env_revision = env_result.provider_env_revision; + ocsf_emit!( + ConfigStateChangeBuilder::new(ocsf_ctx()) + .severity(SeverityId::Informational) + .status(StatusId::Success) + .state(StateId::Enabled, "loaded") + .unmapped( + "provider_env_revision", + serde_json::json!(current_provider_env_revision) + ) + .message(format!( + "Provider environment refreshed [revision:{current_provider_env_revision} env_count:{env_count}]" + )) + .build() + ); + } + Err(e) => { + warn!( + error = %e, + provider_env_revision = result.provider_env_revision, + "Settings poll: failed to refresh provider environment" + ); + } + } + } + // Only reload OPA when the policy payload actually changed. if policy_changed { let Some(policy) = result.policy.as_ref() else { @@ -2230,9 +2490,12 @@ async fn run_policy_poll_loop( continue; }; - let pid = entrypoint_pid.load(Ordering::Acquire); - match opa_engine.reload_from_proto_with_pid(policy, pid) { + let pid = ctx.entrypoint_pid.load(Ordering::Acquire); + match ctx.opa_engine.reload_from_proto_with_pid(policy, pid) { Ok(()) => { + if let Some(policy_local_ctx) = ctx.policy_local_ctx.as_ref() { + policy_local_ctx.set_current_policy(policy.clone()).await; + } if result.global_policy_version > 0 { ocsf_emit!(ConfigStateChangeBuilder::new(ocsf_ctx()) .severity(SeverityId::Informational) @@ -2263,7 +2526,7 @@ async fn run_policy_poll_loop( if result.version > 0 && result.policy_source == PolicySource::Sandbox && let Err(e) = client - .report_policy_status(sandbox_id, result.version, true, "") + .report_policy_status(&ctx.sandbox_id, result.version, true, "") .await { warn!(error = %e, "Failed to report policy load success"); @@ -2284,7 +2547,12 @@ async fn run_policy_poll_loop( if result.version > 0 && result.policy_source == PolicySource::Sandbox && let Err(report_err) = client - .report_policy_status(sandbox_id, result.version, false, &e.to_string()) + .report_policy_status( + &ctx.sandbox_id, + result.version, + false, + 
&e.to_string(), + ) .await { warn!(error = %report_err, "Failed to report policy load failure"); @@ -2295,11 +2563,45 @@ async fn run_policy_poll_loop( // Apply OCSF JSON toggle from the `ocsf_json_enabled` setting. let new_ocsf = extract_bool_setting(&result.settings, "ocsf_json_enabled").unwrap_or(false); - let prev_ocsf = ocsf_enabled.swap(new_ocsf, Ordering::Relaxed); + let prev_ocsf = ctx.ocsf_enabled.swap(new_ocsf, Ordering::Relaxed); if new_ocsf != prev_ocsf { info!(ocsf_json_enabled = new_ocsf, "OCSF JSONL logging toggled"); } + // Apply the agent-proposals feature toggle. On a false→true transition + // we lazily install the skill so a sandbox that started with the flag + // off picks up the surface without a recreate. We never uninstall on + // a true→false transition: stale skill content on disk is harmless + // because route_request and agent_next_steps both gate on the live + // atomic, so the agent that reads the skill will see 404s and an + // empty `next_steps` array regardless. 
+ if let Some(flag) = AGENT_PROPOSALS_ENABLED.get() { + let new_proposals = extract_bool_setting( + &result.settings, + openshell_core::settings::AGENT_POLICY_PROPOSALS_ENABLED_KEY, + ) + .unwrap_or(false); + let prev_proposals = flag.swap(new_proposals, Ordering::Relaxed); + if new_proposals != prev_proposals { + info!( + agent_policy_proposals_enabled = new_proposals, + "agent-driven policy proposals toggled" + ); + if new_proposals && !prev_proposals { + match skills::install_static_skills() { + Ok(installed) => info!( + path = %installed.policy_advisor.display(), + "Installed sandbox agent skill on toggle-on" + ), + Err(error) => warn!( + error = %error, + "Failed to install sandbox agent skill on toggle-on" + ), + } + } + } + } + current_config_revision = result.config_revision; current_policy_hash = result.policy_hash; current_settings = result.settings; diff --git a/crates/openshell-sandbox/src/log_push.rs b/crates/openshell-sandbox/src/log_push.rs index 8e053f79f..fd33d1e07 100644 --- a/crates/openshell-sandbox/src/log_push.rs +++ b/crates/openshell-sandbox/src/log_push.rs @@ -9,7 +9,6 @@ use crate::grpc_client::CachedOpenShellClient; use openshell_core::proto::{PushSandboxLogsRequest, SandboxLogLine}; -use std::time::{SystemTime, UNIX_EPOCH}; use tokio::sync::mpsc; use tracing::{Event, Subscriber}; use tracing_subscriber::Layer; @@ -67,7 +66,7 @@ impl Layer for LogPushLayer { visitor.into_parts(meta.name()) }; - let ts = current_time_ms().unwrap_or(0); + let ts = openshell_core::time::now_ms(); let is_ocsf = meta.target() == openshell_ocsf::OCSF_TARGET; @@ -298,8 +297,3 @@ impl tracing::field::Visit for LogVisitor { } } } - -fn current_time_ms() -> Option { - let now = SystemTime::now().duration_since(UNIX_EPOCH).ok()?; - i64::try_from(now.as_millis()).ok() -} diff --git a/crates/openshell-sandbox/src/main.rs b/crates/openshell-sandbox/src/main.rs index 20d455663..3c9e21578 100644 --- a/crates/openshell-sandbox/src/main.rs +++ 
b/crates/openshell-sandbox/src/main.rs @@ -19,11 +19,20 @@ use openshell_sandbox::run_sandbox; /// Subcommand name used to self-copy the supervisor binary into a shared volume. /// -/// The supervisor image only ships the binary itself, so init containers -/// cannot rely on `sh`/`cp` to copy the binary out. Invoking the binary itself -/// with this argument performs the copy in pure Rust. +/// Init containers invoke the binary directly instead of relying on `sh`/`cp` +/// to copy the binary out. Invoking the binary itself with this argument +/// performs the copy in pure Rust. const COPY_SELF_SUBCOMMAND: &str = "copy-self"; +/// Subcommand for one-shot debug RPCs from inside a sandbox container. +/// +/// Reads the same token sources as the supervisor (env, file, K8s SA +/// bootstrap) and issues a single gRPC call against the gateway. Useful +/// for end-to-end verification: e.g. `docker exec` into a sandbox, then +/// run `openshell-sandbox debug-rpc get-sandbox-config --sandbox-id ` +/// to confirm the cross-sandbox IDOR guard fires. +const DEBUG_RPC_SUBCOMMAND: &str = "debug-rpc"; + /// `OpenShell` Sandbox - process isolation and monitoring. #[derive(Parser, Debug)] #[command(name = "openshell-sandbox")] @@ -50,17 +59,17 @@ struct Args { /// Sandbox ID for fetching policy via gRPC from `OpenShell` server. /// Requires --openshell-endpoint to be set. - #[arg(long, env = "OPENSHELL_SANDBOX_ID")] + #[arg(long, env = openshell_core::sandbox_env::SANDBOX_ID)] sandbox_id: Option, /// Sandbox (used for policy sync when the sandbox discovers policy /// from disk or falls back to the restrictive default). - #[arg(long, env = "OPENSHELL_SANDBOX")] + #[arg(long, env = openshell_core::sandbox_env::SANDBOX)] sandbox: Option, /// `OpenShell` server gRPC endpoint for fetching policy. /// Required when using --sandbox-id. 
- #[arg(long, env = "OPENSHELL_ENDPOINT")] + #[arg(long, env = openshell_core::sandbox_env::ENDPOINT)] openshell_endpoint: Option, /// Path to Rego policy file for OPA-based network access control. @@ -74,23 +83,15 @@ struct Args { policy_data: Option, /// Log level (trace, debug, info, warn, error). - #[arg(long, default_value = "warn", env = "OPENSHELL_LOG_LEVEL")] + #[arg(long, default_value = "warn", env = openshell_core::sandbox_env::LOG_LEVEL)] log_level: String, /// Filesystem path to the Unix socket the embedded SSH daemon binds. /// The supervisor bridges `RelayStream` traffic from the gateway onto /// this socket; nothing else should connect to it. - #[arg(long, env = "OPENSHELL_SSH_SOCKET_PATH")] + #[arg(long, env = openshell_core::sandbox_env::SSH_SOCKET_PATH)] ssh_socket_path: Option, - /// Shared secret for gateway-to-sandbox SSH handshake. - #[arg(long, env = "OPENSHELL_SSH_HANDSHAKE_SECRET")] - ssh_handshake_secret: Option, - - /// Allowed clock skew for SSH handshake validation. - #[arg(long, env = "OPENSHELL_SSH_HANDSHAKE_SKEW_SECS", default_value = "300")] - ssh_handshake_skew_secs: u64, - /// Path to YAML inference routes for standalone routing. /// When set, inference routes are loaded from this file instead of /// fetching a bundle from the gateway. @@ -148,9 +149,8 @@ fn copy_self(dest: &str) -> Result<()> { fn main() -> Result<()> { // Handle `copy-self ` before clap so it works without any of the - // sandbox flags. The supervisor image only ships the binary itself, and - // Kubernetes init containers invoke this path to seed an emptyDir volume - // that the agent container then executes from. + // sandbox flags. Kubernetes init containers invoke this path to seed an + // emptyDir volume that the agent container then executes from. 
let raw_args: Vec = std::env::args().collect(); if raw_args.get(1).map(String::as_str) == Some(COPY_SELF_SUBCOMMAND) { let dest = raw_args.get(2).ok_or_else(|| { @@ -159,6 +159,20 @@ fn main() -> Result<()> { return copy_self(dest); } + // Handle `debug-rpc [args]` before clap. Uses a small + // dedicated runtime so we don't pay the supervisor's full startup cost. + if raw_args.get(1).map(String::as_str) == Some(DEBUG_RPC_SUBCOMMAND) { + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .into_diagnostic()?; + return runtime.block_on(async move { + let _ = rustls::crypto::ring::default_provider().install_default(); + let exit = openshell_sandbox::debug_rpc::run(&raw_args[2..]).await?; + std::process::exit(exit); + }); + } + let args = Args::parse(); // Try to open a rolling log file; fall back to stderr-only logging if it fails @@ -266,7 +280,7 @@ fn main() -> Result<()> { // Get command - either from CLI args, environment variable, or default to /bin/bash let command = if !args.command.is_empty() { args.command - } else if let Ok(c) = std::env::var("OPENSHELL_SANDBOX_COMMAND") { + } else if let Ok(c) = std::env::var(openshell_core::sandbox_env::SANDBOX_COMMAND) { // Simple shell-like splitting on whitespace c.split_whitespace().map(String::from).collect() } else { @@ -289,8 +303,6 @@ fn main() -> Result<()> { args.policy_rules, args.policy_data, args.ssh_socket_path, - args.ssh_handshake_secret, - args.ssh_handshake_skew_secs, args.health_check, args.health_port, args.inference_routes, diff --git a/crates/openshell-sandbox/src/mechanistic_mapper.rs b/crates/openshell-sandbox/src/mechanistic_mapper.rs index cb6daa550..521c882a0 100644 --- a/crates/openshell-sandbox/src/mechanistic_mapper.rs +++ b/crates/openshell-sandbox/src/mechanistic_mapper.rs @@ -12,7 +12,7 @@ //! The LLM-powered `PolicyAdvisor` (issue #205) wraps and enriches these //! mechanistic proposals with context-aware rationale and smarter grouping. 
-use openshell_core::net::{is_always_blocked_ip, is_internal_ip}; +use openshell_core::net::is_always_blocked_ip; use openshell_core::proto::{ DenialSummary, L7Allow, L7Rule, NetworkBinary, NetworkEndpoint, NetworkPolicyRule, PolicyChunk, }; @@ -48,12 +48,15 @@ const WELL_KNOWN_PORTS: &[(u16, &str)] = &[ /// single binary. This produces one proposal per binary so each /// `(sandbox_id, host, port, binary)` maps to exactly one DB row. /// -/// When a host resolves to a private IP (RFC 1918, loopback, link-local), -/// the proposed endpoint includes `allowed_ips` so the proxy's SSRF override -/// accepts the connection. Public IPs do not need `allowed_ips`. +/// Proposals never include `allowed_ips`. If the user applies a proposed rule +/// and the host resolves to a private IP, the proxy's SSRF defense will deny +/// the connection. That SSRF denial flows back through the aggregator, and the +/// user can then explicitly add `allowed_ips` to their policy. This two-step +/// flow avoids DNS resolution in the mapper, which would leak the denied +/// hostname via DNS even though the connection was blocked. See #1169. /// /// Returns an empty vec if there are no actionable denials. -pub async fn generate_proposals(summaries: &[DenialSummary]) -> Vec { +pub fn generate_proposals(summaries: &[DenialSummary]) -> Vec { // Group denials by (host, port, binary). let mut groups: HashMap<(String, u32, String), Vec<&DenialSummary>> = HashMap::new(); @@ -116,11 +119,6 @@ pub async fn generate_proposals(summaries: &[DenialSummary]) -> Vec continue; } - // Resolve the host and check if any IP is private. When a host - // resolves to private IP space, the proxy requires `allowed_ips` as an - // explicit SSRF override. Public IPs don't need this. - let allowed_ips = resolve_allowed_ips_if_private(host, *port).await; - // Build proposed NetworkPolicyRule. 
let l7_rules = build_l7_rules(&l7_methods); let endpoint = if has_l7 && !l7_rules.is_empty() { @@ -131,7 +129,6 @@ pub async fn generate_proposals(summaries: &[DenialSummary]) -> Vec protocol: "rest".to_string(), enforcement: "enforce".to_string(), rules: l7_rules, - allowed_ips: allowed_ips.clone(), ..Default::default() } } else { @@ -139,7 +136,6 @@ pub async fn generate_proposals(summaries: &[DenialSummary]) -> Vec host: host.clone(), port: *port, ports: vec![*port], - allowed_ips: allowed_ips.clone(), ..Default::default() } }; @@ -178,15 +174,6 @@ pub async fn generate_proposals(summaries: &[DenialSummary]) -> Vec .map(|(_, name)| format!(" ({name})")) .unwrap_or_default(); - let private_ip_note = if allowed_ips.is_empty() { - String::new() - } else { - format!( - " Host resolves to private IP ({}); allowed_ips included for SSRF override.", - allowed_ips.join(", ") - ) - }; - // Note: hit_count in the DB accumulates across flush cycles, so we // don't bake a denial count into the rationale text (it would go stale). let rationale = if has_l7 && !l7_methods.is_empty() { @@ -194,13 +181,13 @@ pub async fn generate_proposals(summaries: &[DenialSummary]) -> Vec format!( "Allow {binary_list} to connect to {host}:{port}{port_name} \ with L7 inspection. \ - Allowed paths: {}.{private_ip_note}", + Allowed paths: {}.", paths.join(", ") ) } else { format!( "Allow {binary_list} to connect to \ - {host}:{port}{port_name}.{private_ip_note}" + {host}:{port}{port_name}." ) }; @@ -230,6 +217,8 @@ pub async fn generate_proposals(summaries: &[DenialSummary]) -> Vec first_seen_ms, last_seen_ms, binary: binary.clone(), + validation_result: String::new(), + rejection_reason: String::new(), }); } @@ -290,17 +279,18 @@ fn generate_security_notes(host: &str, port: u16, is_ssrf: bool) -> String { if is_ssrf { notes.push( "This connection was blocked by SSRF protection. \ - Allowing it bypasses internal-IP safety checks." 
+ Private IP access requires an explicit `allowed_ips` policy entry." .to_string(), ); } - // Check for private IP patterns in the host. + // Check for private/reserved IP patterns in the host. if host.starts_with("10.") || host.starts_with("172.") || host.starts_with("192.168.") || host == "localhost" || host.starts_with("127.") + || host.starts_with("169.254.") { notes.push(format!( "Destination '{host}' appears to be an internal/private address." @@ -438,82 +428,6 @@ fn is_always_blocked_destination(host: &str) -> bool { host_lc == "localhost" || host_lc == "localhost." } -/// Resolve a hostname and return the IPs as `allowed_ips` strings only if any -/// resolved address is in private IP space. -/// -/// When a host resolves entirely to public IPs, the proxy doesn't need -/// `allowed_ips` — it passes public traffic through after the OPA check. -/// When any resolved IP is private/internal, the proxy requires `allowed_ips` -/// as an explicit SSRF override, so the mapper includes them. -/// -/// Returns an empty vec for public hosts or on DNS resolution failure. 
-async fn resolve_allowed_ips_if_private(host: &str, port: u32) -> Vec { - let addr = format!("{host}:{port}"); - let addrs = match tokio::net::lookup_host(&addr).await { - Ok(addrs) => addrs.collect::>(), - Err(e) => { - let port_u16 = u16::try_from(port).unwrap_or(u16::MAX); - let event = openshell_ocsf::NetworkActivityBuilder::new(crate::ocsf_ctx()) - .activity(openshell_ocsf::ActivityId::Fail) - .severity(openshell_ocsf::SeverityId::Low) - .dst_endpoint(openshell_ocsf::Endpoint::from_domain(host, port_u16)) - .message(format!("DNS resolution failed for allowed_ips check: {e}")) - .build(); - openshell_ocsf::ocsf_emit!(event); - return Vec::new(); - } - }; - - if addrs.is_empty() { - let port_u16 = u16::try_from(port).unwrap_or(u16::MAX); - let event = openshell_ocsf::NetworkActivityBuilder::new(crate::ocsf_ctx()) - .activity(openshell_ocsf::ActivityId::Fail) - .severity(openshell_ocsf::SeverityId::Low) - .dst_endpoint(openshell_ocsf::Endpoint::from_domain(host, port_u16)) - .message(format!( - "DNS resolution returned no addresses for {host}:{port}" - )) - .build(); - openshell_ocsf::ocsf_emit!(event); - return Vec::new(); - } - - let has_private = addrs.iter().any(|a| is_internal_ip(a.ip())); - if !has_private { - return Vec::new(); - } - - // Host has private IPs — include non-always-blocked resolved IPs in - // allowed_ips. Always-blocked addresses (loopback, link-local, - // unspecified) are filtered out since the proxy will reject them - // regardless of policy. - let mut ips: Vec = addrs - .iter() - .filter(|a| !is_always_blocked_ip(a.ip())) - .map(|a| a.ip().to_string()) - .collect(); - ips.sort(); - ips.dedup(); - - if ips.is_empty() { - // All resolved IPs were always-blocked — no viable allowed_ips. 
- tracing::debug!( - host, - port, - "All resolved IPs are always-blocked; skipping allowed_ips" - ); - return Vec::new(); - } - - tracing::debug!( - host, - port, - ?ips, - "Host resolves to private IP; adding allowed_ips" - ); - ips -} - #[cfg(test)] mod tests { use super::*; @@ -547,14 +461,14 @@ mod tests { assert!(notes.contains("SSRF")); } - #[tokio::test] - async fn test_generate_proposals_empty() { - let proposals = generate_proposals(&[]).await; + #[test] + fn test_generate_proposals_empty() { + let proposals = generate_proposals(&[]); assert!(proposals.is_empty()); } - #[tokio::test] - async fn test_generate_proposals_basic() { + #[test] + fn test_generate_proposals_basic() { let summaries = vec![DenialSummary { sandbox_id: "test".to_string(), host: "api.example.com".to_string(), @@ -575,7 +489,7 @@ mod tests { l7_inspection_active: false, }]; - let proposals = generate_proposals(&summaries).await; + let proposals = generate_proposals(&summaries); assert_eq!(proposals.len(), 1); assert_eq!(proposals[0].rule_name, "allow_api_example_com_443"); assert!(proposals[0].proposed_rule.is_some()); @@ -591,15 +505,12 @@ mod tests { assert!(rule.endpoints[0].protocol.is_empty()); assert!(rule.endpoints[0].rules.is_empty()); - // Public host should NOT have allowed_ips. - assert!( - rule.endpoints[0].allowed_ips.is_empty(), - "Public host should not get allowed_ips" - ); + // Proposals never include allowed_ips (two-step approval flow). 
+ assert!(rule.endpoints[0].allowed_ips.is_empty()); } - #[tokio::test] - async fn test_generate_proposals_with_l7_samples() { + #[test] + fn test_generate_proposals_with_l7_samples() { use openshell_core::proto::L7RequestSample; let summaries = vec![DenialSummary { @@ -635,7 +546,7 @@ mod tests { l7_inspection_active: true, }]; - let proposals = generate_proposals(&summaries).await; + let proposals = generate_proposals(&summaries); assert_eq!(proposals.len(), 1); let rule = proposals[0].proposed_rule.as_ref().unwrap(); @@ -664,76 +575,6 @@ mod tests { assert!(proposals[0].rationale.contains("L7")); } - // -- is_internal_ip tests ------------------------------------------------- - - #[test] - fn test_is_internal_ip_private_v4() { - use std::net::Ipv4Addr; - // RFC 1918 ranges - assert!(is_internal_ip(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)))); - assert!(is_internal_ip(IpAddr::V4(Ipv4Addr::new(10, 110, 50, 3)))); - assert!(is_internal_ip(IpAddr::V4(Ipv4Addr::new(172, 16, 0, 1)))); - assert!(is_internal_ip(IpAddr::V4(Ipv4Addr::new(172, 31, 255, 255)))); - assert!(is_internal_ip(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)))); - } - - #[test] - fn test_is_internal_ip_loopback_and_link_local() { - use std::net::Ipv4Addr; - assert!(is_internal_ip(IpAddr::V4(Ipv4Addr::LOCALHOST))); - assert!(is_internal_ip(IpAddr::V4(Ipv4Addr::new( - 169, 254, 169, 254 - )))); - } - - #[test] - fn test_is_internal_ip_public_v4() { - use std::net::Ipv4Addr; - assert!(!is_internal_ip(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)))); - assert!(!is_internal_ip(IpAddr::V4(Ipv4Addr::new(208, 95, 112, 1)))); - assert!(!is_internal_ip(IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)))); - } - - #[test] - fn test_is_internal_ip_cgnat() { - use std::net::Ipv4Addr; - assert!(is_internal_ip(IpAddr::V4(Ipv4Addr::new(100, 64, 0, 1)))); - assert!(is_internal_ip(IpAddr::V4(Ipv4Addr::new(100, 100, 50, 3)))); - assert!(is_internal_ip(IpAddr::V4(Ipv4Addr::new( - 100, 127, 255, 255 - )))); - // Just outside the /10 boundary - 
assert!(!is_internal_ip(IpAddr::V4(Ipv4Addr::new(100, 128, 0, 1)))); - } - - #[test] - fn test_is_internal_ip_special_use() { - use std::net::Ipv4Addr; - assert!(is_internal_ip(IpAddr::V4(Ipv4Addr::new(192, 0, 0, 1)))); - assert!(is_internal_ip(IpAddr::V4(Ipv4Addr::new(198, 18, 0, 1)))); - assert!(is_internal_ip(IpAddr::V4(Ipv4Addr::new(198, 51, 100, 1)))); - assert!(is_internal_ip(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 1)))); - } - - #[test] - fn test_is_internal_ip_v6() { - use std::net::Ipv6Addr; - // Loopback - assert!(is_internal_ip(IpAddr::V6(Ipv6Addr::LOCALHOST))); - // Link-local fe80::1 - assert!(is_internal_ip(IpAddr::V6(Ipv6Addr::new( - 0xfe80, 0, 0, 0, 0, 0, 0, 1 - )))); - // ULA fd00::1 - assert!(is_internal_ip(IpAddr::V6(Ipv6Addr::new( - 0xfd00, 0, 0, 0, 0, 0, 0, 1 - )))); - // Public 2001:db8::1 - assert!(!is_internal_ip(IpAddr::V6(Ipv6Addr::new( - 0x2001, 0xdb8, 0, 0, 0, 0, 0, 1 - )))); - } - // -- is_always_blocked_destination tests ------------------------------------ #[test] @@ -770,8 +611,8 @@ mod tests { // -- generate_proposals: always-blocked filtering tests -------------------- - #[tokio::test] - async fn test_generate_proposals_skips_loopback_destination() { + #[test] + fn test_generate_proposals_skips_loopback_destination() { let summaries = vec![DenialSummary { host: "127.0.0.1".to_string(), port: 80, @@ -783,15 +624,15 @@ mod tests { ..Default::default() }]; - let proposals = generate_proposals(&summaries).await; + let proposals = generate_proposals(&summaries); assert!( proposals.is_empty(), "should skip proposals for loopback: {proposals:?}" ); } - #[tokio::test] - async fn test_generate_proposals_skips_link_local_destination() { + #[test] + fn test_generate_proposals_skips_link_local_destination() { let summaries = vec![DenialSummary { host: "169.254.169.254".to_string(), port: 80, @@ -803,15 +644,15 @@ mod tests { ..Default::default() }]; - let proposals = generate_proposals(&summaries).await; + let proposals = 
generate_proposals(&summaries); assert!( proposals.is_empty(), "should skip proposals for link-local: {proposals:?}" ); } - #[tokio::test] - async fn test_generate_proposals_skips_localhost_hostname() { + #[test] + fn test_generate_proposals_skips_localhost_hostname() { let summaries = vec![DenialSummary { host: "localhost".to_string(), port: 8080, @@ -823,15 +664,15 @@ mod tests { ..Default::default() }]; - let proposals = generate_proposals(&summaries).await; + let proposals = generate_proposals(&summaries); assert!( proposals.is_empty(), "should skip proposals for localhost: {proposals:?}" ); } - #[tokio::test] - async fn test_generate_proposals_keeps_public_destination() { + #[test] + fn test_generate_proposals_keeps_public_destination() { let summaries = vec![DenialSummary { host: "api.github.com".to_string(), port: 443, @@ -843,7 +684,7 @@ mod tests { ..Default::default() }]; - let proposals = generate_proposals(&summaries).await; + let proposals = generate_proposals(&summaries); assert_eq!(proposals.len(), 1, "should keep proposals for public host"); } diff --git a/crates/openshell-sandbox/src/opa.rs b/crates/openshell-sandbox/src/opa.rs index 5897679a0..b49875b78 100644 --- a/crates/openshell-sandbox/src/opa.rs +++ b/crates/openshell-sandbox/src/opa.rs @@ -1061,6 +1061,12 @@ fn proto_to_opa_data_json(proto: &ProtoSandboxPolicy, entrypoint_pid: u32) -> St if e.allow_encoded_slash { ep["allow_encoded_slash"] = true.into(); } + if e.websocket_credential_rewrite { + ep["websocket_credential_rewrite"] = true.into(); + } + if e.request_body_credential_rewrite { + ep["request_body_credential_rewrite"] = true.into(); + } if !e.persisted_queries.is_empty() { ep["persisted_queries"] = e.persisted_queries.clone().into(); } @@ -1811,6 +1817,28 @@ network_policies: access: read-only binaries: - { path: /usr/bin/curl } + graphql_ws: + name: graphql_ws + endpoints: + - host: realtime.graphql.com + ports: [443] + path: "/graphql" + protocol: websocket + enforcement: 
enforce + rules: + - allow: + method: GET + path: "/graphql" + - allow: + operation_type: query + fields: [viewer] + - allow: + operation_type: subscription + fields: [messageAdded] + deny_rules: + - operation_type: mutation + binaries: + - { path: /usr/bin/curl } l4_only: name: l4_only endpoints: @@ -1897,6 +1925,25 @@ process: }) } + fn l7_websocket_graphql_input(host: &str, operations: serde_json::Value) -> serde_json::Value { + serde_json::json!({ + "network": { "host": host, "port": 443 }, + "exec": { + "path": "/usr/bin/curl", + "ancestors": [], + "cmdline_paths": [] + }, + "request": { + "method": "WEBSOCKET_TEXT", + "path": "/graphql", + "query_params": {}, + "graphql": { + "operations": operations + } + } + }) + } + fn eval_l7(engine: &OpaEngine, input: &serde_json::Value) -> bool { let mut eng = engine.engine.lock().unwrap(); eng.set_input_json(&input.to_string()).unwrap(); @@ -2134,6 +2181,97 @@ process: assert!(!eval_l7(&engine, &mutation)); } + #[test] + fn l7_websocket_graphql_subscription_allowed_by_field_rule() { + let engine = l7_engine(); + let input = l7_websocket_graphql_input( + "realtime.graphql.com", + serde_json::json!([{ + "operation_type": "subscription", + "operation_name": "NewMessages", + "fields": ["messageAdded"], + "persisted_query": false + }]), + ); + assert!(eval_l7(&engine, &input)); + } + + #[test] + fn l7_websocket_graphql_unlisted_field_denied() { + let engine = l7_engine(); + let input = l7_websocket_graphql_input( + "realtime.graphql.com", + serde_json::json!([{ + "operation_type": "query", + "fields": ["adminAuditLog"], + "persisted_query": false + }]), + ); + assert!(!eval_l7(&engine, &input)); + } + + #[test] + fn l7_websocket_graphql_deny_rule_takes_precedence() { + let engine = l7_engine(); + let input = l7_websocket_graphql_input( + "realtime.graphql.com", + serde_json::json!([{ + "operation_type": "mutation", + "operation_name": "DeleteRepo", + "fields": ["deleteRepository"], + "persisted_query": false + }]), + ); + 
assert!(!eval_l7(&engine, &input)); + } + + #[test] + fn l7_websocket_graphql_not_bypassed_by_generic_text_rule() { + let data = r#" +network_policies: + graphql_ws: + name: graphql_ws + endpoints: + - host: realtime.graphql.com + ports: [443] + path: "/graphql" + protocol: websocket + enforcement: enforce + rules: + - allow: + method: GET + path: "/graphql" + - allow: + method: WEBSOCKET_TEXT + path: "/graphql" + - allow: + operation_type: query + fields: [viewer] + binaries: + - { path: /usr/bin/curl } +"#; + let data_json: serde_json::Value = + serde_yml::from_str(data).expect("fixture should parse as YAML"); + let mut rego = regorus::Engine::new(); + rego.add_policy("policy.rego".into(), TEST_POLICY.into()) + .expect("policy should load"); + rego.add_data_json(&data_json.to_string()) + .expect("data should load"); + let engine = OpaEngine { + engine: Mutex::new(rego), + generation: Arc::new(AtomicU64::new(0)), + }; + let input = l7_websocket_graphql_input( + "realtime.graphql.com", + serde_json::json!([{ + "operation_type": "query", + "fields": ["adminAuditLog"], + "persisted_query": false + }]), + ); + assert!(!eval_l7(&engine, &input)); + } + #[test] fn l7_endpoint_path_scopes_rest_and_graphql_on_same_host() { let data = r#" @@ -2463,6 +2601,120 @@ network_policies: assert!(l7.allow_encoded_slash); } + #[test] + fn l7_endpoint_config_preserves_proto_websocket_credential_rewrite() { + let mut network_policies = std::collections::HashMap::new(); + network_policies.insert( + "gateway".to_string(), + NetworkPolicyRule { + name: "gateway".to_string(), + endpoints: vec![NetworkEndpoint { + host: "gateway.example.com".to_string(), + port: 443, + protocol: "rest".to_string(), + enforcement: "enforce".to_string(), + access: "full".to_string(), + websocket_credential_rewrite: true, + ..Default::default() + }], + binaries: vec![NetworkBinary { + path: "/usr/bin/node".to_string(), + ..Default::default() + }], + }, + ); + let proto = ProtoSandboxPolicy { + version: 1, + 
filesystem: Some(ProtoFs { + include_workdir: true, + read_only: vec![], + read_write: vec![], + }), + landlock: Some(openshell_core::proto::LandlockPolicy { + compatibility: "best_effort".to_string(), + }), + process: Some(ProtoProc { + run_as_user: "sandbox".to_string(), + run_as_group: "sandbox".to_string(), + }), + network_policies, + }; + + let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); + let input = NetworkInput { + host: "gateway.example.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/node"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + + let config = engine + .query_endpoint_config(&input) + .unwrap() + .expect("endpoint config"); + let l7 = crate::l7::parse_l7_config(&config).unwrap(); + assert!(l7.websocket_credential_rewrite); + } + + #[test] + fn l7_endpoint_config_preserves_proto_request_body_credential_rewrite() { + let mut network_policies = std::collections::HashMap::new(); + network_policies.insert( + "slack".to_string(), + NetworkPolicyRule { + name: "slack".to_string(), + endpoints: vec![NetworkEndpoint { + host: "slack.com".to_string(), + port: 443, + protocol: "rest".to_string(), + enforcement: "enforce".to_string(), + access: "read-write".to_string(), + request_body_credential_rewrite: true, + ..Default::default() + }], + binaries: vec![NetworkBinary { + path: "/usr/bin/node".to_string(), + ..Default::default() + }], + }, + ); + let proto = ProtoSandboxPolicy { + version: 1, + filesystem: Some(ProtoFs { + include_workdir: true, + read_only: vec![], + read_write: vec![], + }), + landlock: Some(openshell_core::proto::LandlockPolicy { + compatibility: "best_effort".to_string(), + }), + process: Some(ProtoProc { + run_as_user: "sandbox".to_string(), + run_as_group: "sandbox".to_string(), + }), + network_policies, + }; + + let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); + let input = NetworkInput { + host: "slack.com".into(), + port: 443, + 
binary_path: PathBuf::from("/usr/bin/node"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + + let config = engine + .query_endpoint_config(&input) + .unwrap() + .expect("endpoint config"); + let l7 = crate::l7::parse_l7_config(&config).unwrap(); + assert!(l7.request_body_credential_rewrite); + } + #[test] fn l7_endpoint_config_none_for_l4_only() { let engine = l7_engine(); @@ -3739,6 +3991,69 @@ network_policies: assert!(!decision.allowed, "Wildcard host on wrong port should deny"); } + #[test] + fn wildcard_host_intra_label_matches() { + // First-label intra-label wildcard: `*` matches the variable prefix + // within a single DNS label. Locks validator/runtime alignment for + // the pattern accepted by `validate_host_wildcard`. + let data = r#" +network_policies: + intra_label: + name: intra_label + endpoints: + - { host: "*-aiplatform.googleapis.com", port: 443 } + binaries: + - { path: /usr/bin/curl } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + let input = NetworkInput { + host: "us-central1-aiplatform.googleapis.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/curl"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let decision = engine.evaluate_network(&input).unwrap(); + assert!( + decision.allowed, + "*-aiplatform.googleapis.com should match us-central1-aiplatform.googleapis.com: {}", + decision.reason + ); + } + + #[test] + fn wildcard_host_intra_label_does_not_cross_dot() { + // `glob.match(..., ["."])` treats `.` as a label boundary that `*` + // cannot cross. `*-aiplatform.googleapis.com` must not match a host + // whose first label is `us-central1` and where `aiplatform` is a + // separate label. 
+ let data = r#" +network_policies: + intra_label: + name: intra_label + endpoints: + - { host: "*-aiplatform.googleapis.com", port: 443 } + binaries: + - { path: /usr/bin/curl } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + let input = NetworkInput { + host: "us-central1.aiplatform.googleapis.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/curl"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let decision = engine.evaluate_network(&input).unwrap(); + assert!( + !decision.allowed, + "*-aiplatform.googleapis.com must NOT match us-central1.aiplatform.googleapis.com \ + (would cross a `.` boundary)" + ); + } + #[test] fn wildcard_host_multi_port() { let data = r#" @@ -4553,4 +4868,41 @@ network_policies: decision.reason ); } + + #[test] + fn l7_head_allowed_where_get_is_allowed() { + let engine = l7_engine(); + let input = l7_input("api.example.com", 8080, "HEAD", "/repos/myorg/foo"); + assert!(eval_l7(&engine, &input)); + } + + #[test] + fn l7_head_denied_when_only_post_allowed() { + let engine = OpaEngine::from_strings( + TEST_POLICY, + "network_policies:\n p:\n name: p\n endpoints:\n - host: h.test\n port: 80\n protocol: rest\n enforcement: enforce\n rules:\n - allow: {method: POST, path: \"/\"}\n binaries:\n - {path: /usr/bin/curl}\n", + ) + .unwrap(); + let input = l7_input("h.test", 80, "HEAD", "/"); + assert!(!eval_l7(&engine, &input)); + } + + #[test] + fn l7_options_not_implicitly_allowed_by_get() { + let engine = l7_engine(); + let input = l7_input("api.example.com", 8080, "OPTIONS", "/repos/myorg/foo"); + assert!(!eval_l7(&engine, &input)); + } + + #[test] + fn l7_head_blocked_by_deny_rule_targeting_get() { + // deny_rules use method_matches() too; a deny on GET must also block HEAD. 
+ let engine = OpaEngine::from_strings( + TEST_POLICY, + "network_policies:\n p:\n name: p\n endpoints:\n - host: h.test\n port: 80\n protocol: rest\n enforcement: enforce\n access: full\n deny_rules:\n - method: GET\n path: \"/protected\"\n binaries:\n - {path: /usr/bin/curl}\n", + ) + .unwrap(); + let input = l7_input("h.test", 80, "HEAD", "/protected"); + assert!(!eval_l7(&engine, &input)); + } } diff --git a/crates/openshell-sandbox/src/policy_local.rs b/crates/openshell-sandbox/src/policy_local.rs new file mode 100644 index 000000000..657fd760f --- /dev/null +++ b/crates/openshell-sandbox/src/policy_local.rs @@ -0,0 +1,1981 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Sandbox-local policy advisor HTTP API. + +use miette::{IntoDiagnostic, Result}; +use openshell_core::proto::{ + L7Allow, L7DenyRule, L7Rule, NetworkBinary, NetworkEndpoint, NetworkPolicyRule, PolicyChunk, + SandboxPolicy as ProtoSandboxPolicy, +}; +use openshell_ocsf::{ConfigStateChangeBuilder, SeverityId, StateId, StatusId, ocsf_emit}; +use serde::Deserialize; +use std::collections::HashMap; +use std::path::{Path, PathBuf}; +use std::sync::Arc; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use tokio::sync::RwLock; + +pub const POLICY_LOCAL_HOST: &str = "policy.local"; + +/// Filesystem path of the static agent guidance bundle inside the sandbox. +/// Single source of truth: the skill installer writes here, the L7 deny body +/// references this path in `next_steps`, and the skill's own documentation +/// renders the same path. Changing the location is a one-line update here. +pub const SKILL_PATH: &str = "/etc/openshell/skills/policy_advisor.md"; + +/// Routes served by the in-sandbox policy advisor API. 
Held in one place so +/// the L7 deny `next_steps` array, the route dispatcher, the skill content, +/// and tests all stay in sync — change the wire path here and every caller +/// follows. See `agent_next_steps()` for the consumer that surfaces these +/// to the agent on a 403. +pub const ROUTE_POLICY_CURRENT: &str = "/v1/policy/current"; +pub const ROUTE_DENIALS: &str = "/v1/denials"; +pub const ROUTE_PROPOSALS: &str = "/v1/proposals"; +/// Per-proposal status and long-poll routes live below this prefix: +/// `GET /v1/proposals/{chunk_id}` — immediate status +/// `GET /v1/proposals/{chunk_id}/wait?timeout` — long-poll until terminal +/// Trailing slash differentiates from the bare `POST /v1/proposals` submit. +const ROUTE_PROPOSALS_PREFIX: &str = "/v1/proposals/"; + +/// Long-poll bounds for `GET /v1/proposals/{id}/wait?timeout=`. The agent +/// re-issues on timeout, so the cap is a hold ceiling, not a hard limit on +/// how long the agent can wait overall. +const PROPOSAL_WAIT_DEFAULT_SECS: u64 = 60; +const PROPOSAL_WAIT_MIN_SECS: u64 = 1; +const PROPOSAL_WAIT_MAX_SECS: u64 = 300; +const PROPOSAL_WAIT_POLL_INTERVAL: std::time::Duration = std::time::Duration::from_secs(1); +/// Minimum window the reload-readiness phase gets after a chunk +/// terminalizes, even if the caller's deadline is shorter. Without this, +/// approvals that arrive at T-50ms always return `policy_reloaded=false` +/// and force a re-issue. 500ms is well below typical supervisor poll +/// latency but enough to cover the in-memory coverage check. +const RELOAD_WAIT_MIN_FLOOR: std::time::Duration = std::time::Duration::from_millis(500); + +const MAX_POLICY_LOCAL_BODY_BYTES: usize = 64 * 1024; +/// Hard ceiling on how long a single request body read can stall. Bounds a +/// slowloris-style upload from an in-sandbox process; the proxy listener only +/// accepts loopback connections, so practical impact is limited, but this is +/// cheap defense-in-depth. 
+const POLICY_LOCAL_BODY_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(15); +const DEFAULT_DENIALS_LIMIT: usize = 10; +const MAX_DENIALS_LIMIT: usize = 100; +/// The shorthand rolling appender keeps three files (daily rotation); read the +/// most recent two so a request just past midnight still has yesterday's +/// denials. +const DENIAL_LOG_FILES_TO_SCAN: usize = 2; +const LOG_DIR: &str = "/var/log"; +/// Shorthand log filenames are `openshell.YYYY-MM-DD.log`. The trailing dot in +/// the prefix is intentional: it disambiguates from the OCSF JSONL appender's +/// `openshell-ocsf.YYYY-MM-DD.log`, which we never want to surface here (the +/// JSONL is opt-in via `ocsf_json_enabled` and not the source of truth for +/// `/v1/denials`). +const SHORTHAND_LOG_PREFIX: &str = "openshell."; +/// Defensive cap on per-line length returned to the agent so a pathological +/// log entry (very long URL path, etc.) cannot blow up the response. +const MAX_DENIAL_LINE_BYTES: usize = 4096; + +#[derive(Debug)] +pub struct PolicyLocalContext { + current_policy: Arc<RwLock<Option<ProtoSandboxPolicy>>>, + gateway_endpoint: Option<String>, + sandbox_name: Option<String>, + shorthand_log_dir: PathBuf, +} + +impl PolicyLocalContext { + pub fn new( + current_policy: Option<ProtoSandboxPolicy>, + gateway_endpoint: Option<String>, + sandbox_name: Option<String>, + ) -> Self { + Self::with_log_dir( + current_policy, + gateway_endpoint, + sandbox_name, + PathBuf::from(LOG_DIR), + ) + } + + fn with_log_dir( + current_policy: Option<ProtoSandboxPolicy>, + gateway_endpoint: Option<String>, + sandbox_name: Option<String>, + shorthand_log_dir: PathBuf, + ) -> Self { + Self { + current_policy: Arc::new(RwLock::new(current_policy)), + gateway_endpoint, + sandbox_name, + shorthand_log_dir, + } + } + + pub async fn set_current_policy(&self, policy: ProtoSandboxPolicy) { + *self.current_policy.write().await = Some(policy); + } +} + +pub async fn handle_forward_request<S>( + ctx: &PolicyLocalContext, + method: &str, + path: &str, + initial_request: &[u8], + client: &mut S, +) -> Result<()> +where + S:
AsyncRead + AsyncWrite + Unpin, +{ + let body = read_request_body(initial_request, client).await?; + let (status, payload) = route_request(ctx, method, path, &body).await; + write_json_response(client, status, payload).await +} + +async fn route_request( + ctx: &PolicyLocalContext, + method: &str, + path: &str, + body: &[u8], +) -> (u16, serde_json::Value) { + let (route, query) = path.split_once('?').map_or((path, ""), |(r, q)| (r, q)); + // Gate every route on the feature flag so the agent surface is fully off + // when the flag is off — including the diagnostic `current_policy` and + // `denials` routes. The skill is also not installed in that mode, so a + // disabled sandbox has no entry point into this API at all. + if !crate::agent_proposals_enabled() { + return ( + 404, + serde_json::json!({ + "error": "feature_disabled", + "detail": "agent-driven policy proposals are not enabled in this sandbox; set the `agent_policy_proposals_enabled` setting to true to enable" + }), + ); + } + match (method, route) { + ("GET", ROUTE_POLICY_CURRENT) => current_policy_response(ctx).await, + ("GET", ROUTE_DENIALS) => recent_denials_response(ctx, query).await, + ("POST", ROUTE_PROPOSALS) => submit_proposal(ctx, body).await, + ("GET", path) if path.starts_with(ROUTE_PROPOSALS_PREFIX) => { + proposal_state_route(ctx, path, query).await + } + _ => ( + 404, + serde_json::json!({ + "error": "not_found", + "detail": format!("policy.local route not found: {method} {route}") + }), + ), + } +} + +/// Parse `{chunk_id}` (status) or `{chunk_id}/wait` (long-poll) from the path +/// suffix and dispatch. Empty `chunk_id` or extra segments return 404 so a +/// malformed path cannot trigger a gateway call. 
+async fn proposal_state_route( + ctx: &PolicyLocalContext, + path: &str, + query: &str, +) -> (u16, serde_json::Value) { + let suffix = path + .strip_prefix(ROUTE_PROPOSALS_PREFIX) + .unwrap_or_default(); + let (chunk_id, wait) = match suffix.split_once('/') { + Some((id, "wait")) => (id, true), + Some(_) => return not_found_payload(path), + None => (suffix, false), + }; + if chunk_id.is_empty() { + return not_found_payload(path); + } + if wait { + proposal_wait_response(ctx, chunk_id, query).await + } else { + proposal_status_response(ctx, chunk_id).await + } +} + +fn not_found_payload(path: &str) -> (u16, serde_json::Value) { + ( + 404, + serde_json::json!({ + "error": "not_found", + "detail": format!("policy.local proposal sub-route not found: {path}") + }), + ) +} + +/// Build the `next_steps` array embedded in the L7 deny body so the agent has +/// machine-readable pointers to this API. Centralizes the shape here to keep +/// the deny body and the actual route table from drifting — adding or +/// renaming a route only requires touching the route constants above. +/// +/// Returns an empty array when `agent_proposals_enabled()` is false so a +/// disabled sandbox doesn't advertise a surface that 404s. The deny body +/// caller still emits the field (with `[]`) so the wire shape is stable. 
+#[must_use] +pub fn agent_next_steps() -> serde_json::Value { + if !crate::agent_proposals_enabled() { + return serde_json::json!([]); + } + let host = POLICY_LOCAL_HOST; + serde_json::json!([ + { + "action": "read_skill", + "path": SKILL_PATH, + }, + { + "action": "inspect_policy", + "method": "GET", + "url": format!("http://{host}{ROUTE_POLICY_CURRENT}"), + }, + { + "action": "inspect_recent_denials", + "method": "GET", + "url": format!("http://{host}{ROUTE_DENIALS}?last=5"), + }, + { + "action": "submit_proposal", + "method": "POST", + "url": format!("http://{host}{ROUTE_PROPOSALS}"), + "body_type": "PolicyMergeOperation", + }, + ]) +} + +async fn current_policy_response(ctx: &PolicyLocalContext) -> (u16, serde_json::Value) { + let Some(policy) = ctx.current_policy.read().await.clone() else { + return ( + 404, + serde_json::json!({ + "error": "policy_unavailable", + "detail": "no current sandbox policy is loaded" + }), + ); + }; + + match openshell_policy::serialize_sandbox_policy(&policy) { + Ok(policy_yaml) => ( + 200, + serde_json::json!({ + "format": "yaml", + "policy_yaml": policy_yaml + }), + ), + Err(error) => ( + 500, + serde_json::json!({ + "error": "policy_serialize_failed", + "detail": error.to_string() + }), + ), + } +} + +async fn recent_denials_response( + ctx: &PolicyLocalContext, + query: &str, +) -> (u16, serde_json::Value) { + let limit = parse_last_query(query).unwrap_or(DEFAULT_DENIALS_LIMIT); + let log_dir = ctx.shorthand_log_dir.clone(); + + // Distinguish "shorthand log exists and no denials happened" from "no log + // file yet, so we have nothing to read." Without this flag the agent sees + // `[]` in both cases and cannot tell the difference. The shorthand log is + // always-on (no setting gates it), so the only way `log_available=false` + // happens in practice is if the supervisor has not flushed any events to + // disk yet, or `/var/log` is not writable in this image. 
+ let log_available = matches!( + collect_shorthand_log_files(&log_dir, 1), + Ok(files) if !files.is_empty() + ); + + let denials = tokio::task::spawn_blocking(move || read_recent_denial_lines(&log_dir, limit)) + .await + .unwrap_or_default(); + + let mut payload = serde_json::json!({ + "denials": denials, + "log_available": log_available, + }); + if !log_available { + payload["note"] = serde_json::json!( + "no shorthand log file is present yet at /var/log/openshell.YYYY-MM-DD.log; the supervisor may not have emitted any events to disk yet" + ); + } + + (200, payload) +} + +fn parse_last_query(query: &str) -> Option<usize> { + if query.is_empty() { + return None; + } + for pair in query.split('&') { + let Some((key, value)) = pair.split_once('=') else { + continue; + }; + if key == "last" { + return value + .parse::<usize>() + .ok() + .map(|n| n.clamp(1, MAX_DENIALS_LIMIT)); + } + } + None +} + +/// Walk the shorthand log files (most-recent first) and return up to `limit` +/// raw denial lines in newest-first order. The agent receives the same +/// human-readable text that `openshell logs` displays — no parsing back into +/// structured form. Updating the shorthand format adds fields automatically; +/// no schema rev required. +/// +/// Reads files synchronously and is intended to run inside `spawn_blocking`. +fn read_recent_denial_lines(log_dir: &Path, limit: usize) -> Vec<String> { + let Ok(files) = collect_shorthand_log_files(log_dir, DENIAL_LOG_FILES_TO_SCAN) else { + return Vec::new(); + }; + + let mut lines: Vec<String> = Vec::with_capacity(limit); + for path in files { + let Ok(contents) = std::fs::read_to_string(&path) else { + continue; + }; + // Walk lines newest-first. Within a single file, the last line written + // is the freshest event. + for line in contents.lines().rev() { + if !is_ocsf_denial_line(line) { + continue; + } + // Defense-in-depth: redact query strings before truncation.
The + // FORWARD deny path in `proxy.rs` populates the OCSF `message` + // and URL with the raw request path including `?query=...`, which + // the shorthand layer then renders verbatim. Stripping queries + // here means the agent never sees the secret even if an upstream + // emit site forgets to redact (TODO: harden the emit sites in + // proxy.rs FORWARD path so the on-disk shorthand log itself is + // clean — tracked separately). Redact first so truncation cannot + // slice mid-secret. + let redacted = redact_query_strings(line); + let surfaced = truncate_at_char_boundary(&redacted, MAX_DENIAL_LINE_BYTES); + lines.push(surfaced); + if lines.len() >= limit { + return lines; + } + } + } + lines +} + +/// Replace any `?` substring with `?[redacted]` to keep query-string +/// secrets out of the agent's view. Walks per Unicode scalar value so multi-byte +/// content is safe. A query is everything from `?` until the next whitespace or +/// `]` (the shorthand format uses `[...]` for context tags). +fn redact_query_strings(line: &str) -> String { + let mut out = String::with_capacity(line.len()); + let mut chars = line.chars(); + while let Some(c) = chars.next() { + if c == '?' { + out.push('?'); + out.push_str("[redacted]"); + // Consume until whitespace or `]` (preserved as the next token's + // boundary by writing it back out). + for next in chars.by_ref() { + if next.is_whitespace() || next == ']' { + out.push(next); + break; + } + } + } else { + out.push(c); + } + } + out +} + +/// Truncate `s` at the largest UTF-8 char boundary <= `max_bytes`, appending a +/// `...[truncated]` suffix. Returning a `String` (not `&str`) avoids surprising +/// callers about lifetime relationships with `s`. 
+fn truncate_at_char_boundary(s: &str, max_bytes: usize) -> String { + if s.len() <= max_bytes { + return s.to_string(); + } + let mut end = max_bytes; + while end > 0 && !s.is_char_boundary(end) { + end -= 1; + } + let mut out = String::with_capacity(end + "...[truncated]".len()); + out.push_str(&s[..end]); + out.push_str("...[truncated]"); + out +} + +/// True for OCSF denial events as rendered by the shorthand layer. The format +/// is ` OCSF <[SEV]> ...`. The literal +/// ` OCSF ` substring identifies an OCSF event (vs. a non-OCSF tracing line); +/// ` DENIED ` is the OCSF action label uppercased and surrounded by spaces, so +/// matching it is safe against substring collisions in URLs or hostnames. +fn is_ocsf_denial_line(line: &str) -> bool { + line.contains(" OCSF ") && line.contains(" DENIED ") +} + +fn collect_shorthand_log_files(log_dir: &Path, max_files: usize) -> std::io::Result<Vec<PathBuf>> { + let mut entries: Vec<(std::time::SystemTime, PathBuf)> = std::fs::read_dir(log_dir)? + .filter_map(std::result::Result::ok) + .filter_map(|entry| { + let path = entry.path(); + let name = entry.file_name(); + let name = name.to_string_lossy(); + // `openshell.YYYY-MM-DD.log` only — the trailing dot in the prefix + // disambiguates from `openshell-ocsf.YYYY-MM-DD.log`.
+ if !name.starts_with(SHORTHAND_LOG_PREFIX) || !name.ends_with(".log") { + return None; + } + let modified = entry.metadata().and_then(|m| m.modified()).ok()?; + Some((modified, path)) + }) + .collect(); + + entries.sort_by_key(|entry| std::cmp::Reverse(entry.0)); + Ok(entries + .into_iter() + .take(max_files) + .map(|(_, p)| p) + .collect()) +} + +async fn submit_proposal(ctx: &PolicyLocalContext, body: &[u8]) -> (u16, serde_json::Value) { + let Some(endpoint) = ctx.gateway_endpoint.as_deref() else { + return ( + 503, + serde_json::json!({ + "error": "gateway_unavailable", + "detail": "policy proposal submission requires a gateway-connected sandbox" + }), + ); + }; + let Some(sandbox_name) = ctx + .sandbox_name + .as_deref() + .map(str::trim) + .filter(|name| !name.is_empty()) + else { + return ( + 503, + serde_json::json!({ + "error": "sandbox_name_unavailable", + "detail": "policy proposal submission requires a sandbox name" + }), + ); + }; + + let chunks = match proposal_chunks_from_body(body) { + Ok(chunks) => chunks, + Err(error) => return (400, error_payload("invalid_proposal", error)), + }; + + let client = match crate::grpc_client::CachedOpenShellClient::connect(endpoint).await { + Ok(client) => client, + Err(error) => { + return ( + 502, + serde_json::json!({ + "error": "gateway_connect_failed", + "detail": error.to_string() + }), + ); + } + }; + + // Pre-compute the audit summaries before handing `chunks` to the + // gateway client (which consumes the vec). The summaries pair up with + // the gateway's `accepted_chunk_ids` by index for the propose events + // emitted after submit returns. 
+ let audit_summaries: Vec<String> = chunks.iter().map(summarize_chunk_for_audit).collect(); + + let response = match client + .submit_policy_analysis(sandbox_name, vec![], chunks, "agent_authored") + .await + { + Ok(response) => response, + Err(error) => { + return ( + 502, + serde_json::json!({ + "error": "proposal_submit_failed", + "detail": error.to_string() + }), + ); + } + }; + + // One OCSF event per accepted chunk so the audit trace in + // `openshell logs ` carries the propose beat alongside the + // proxy deny and policy reload that bracket it. + // + // The gateway compresses its `accepted_chunk_ids` by skipping rejected + // chunks (`grpc/policy.rs:1357-1436`); the proto does not promise 1:1 + // ordering against the request. Today client-side validation catches + // both rejection causes (missing rule_name, missing proposed_rule) + // before submit, so the lengths match in practice. If they don't, we + // can't safely pair audit_summaries by index — fall back to a generic + // event per accepted chunk_id rather than mis-attribute a summary. + let pairing_is_safe = response.accepted_chunk_ids.len() == audit_summaries.len(); + for (idx, chunk_id) in response.accepted_chunk_ids.iter().enumerate() { + let summary = if pairing_is_safe { + audit_summaries[idx].as_str() + } else { + "(summary unavailable: gateway partially accepted)" + }; + emit_policy_propose_event(chunk_id, summary); + } + + ( + 202, + serde_json::json!({ + "status": "submitted", + "accepted_chunks": response.accepted_chunks, + "rejected_chunks": response.rejected_chunks, + "rejection_reasons": response.rejection_reasons, + "accepted_chunk_ids": response.accepted_chunk_ids, + }), + ) +} + +/// Emit one CONFIG:PROPOSED audit event for an agent-authored proposal that +/// the gateway just accepted. The message names the `chunk_id`, the binary, +/// and the endpoint the agent is asking to reach — what a developer needs +/// to see in the audit trace to correlate against the inbox card.
+fn emit_policy_propose_event(chunk_id: &str, summary: &str) { + ocsf_emit!( + ConfigStateChangeBuilder::new(crate::ocsf_ctx()) + .severity(SeverityId::Informational) + .status(StatusId::Success) + .state(StateId::Other, "PROPOSED") + .unmapped("chunk_id", serde_json::json!(chunk_id)) + .message(format!( + "agent_authored proposal chunk:{chunk_id} {summary}" + )) + .build() + ); +} + +/// Emit one CONFIG:APPROVED or CONFIG:REJECTED audit event observed by the +/// `/wait` poll loop. The reviewer's free-form `rejection_reason` (if any) +/// is included verbatim so the audit trace shows what guidance the agent +/// received. +fn emit_policy_decision_event(chunk: &PolicyChunk) { + let summary = summarize_chunk_for_audit(chunk); + match chunk.status.as_str() { + "approved" => ocsf_emit!( + ConfigStateChangeBuilder::new(crate::ocsf_ctx()) + .severity(SeverityId::Informational) + .status(StatusId::Success) + .state(StateId::Enabled, "APPROVED") + .unmapped("chunk_id", serde_json::json!(chunk.id)) + .message(format!("chunk:{} approved {summary}", chunk.id)) + .build() + ), + "rejected" => { + // The reviewer's free-form rejection_reason is opaque user + // input. The agent reads the raw text via `GET /v1/proposals/ + // {id}` to redraft; the OCSF surface (which can be shipped to + // external SIEMs per AGENTS.md) gets a sanitized copy — caps + // length and strips control characters so a stray credential + // or escape sequence cannot leak into the audit log. 
+ let sanitized = sanitize_reason_for_audit(&chunk.rejection_reason); + let reason_display = if sanitized.is_empty() { + "(no guidance)".to_string() + } else { + format!("\"{sanitized}\"") + }; + ocsf_emit!( + ConfigStateChangeBuilder::new(crate::ocsf_ctx()) + .severity(SeverityId::Low) + .status(StatusId::Success) + .state(StateId::Disabled, "REJECTED") + .unmapped("chunk_id", serde_json::json!(chunk.id)) + .unmapped("rejection_reason", serde_json::json!(sanitized)) + .message(format!( + "chunk:{} rejected {summary} reason:{reason_display}", + chunk.id + )) + .build() + ); + } + // Caller is gated on `is_terminal_status`, so a non-terminal status + // here is a code change that broke the invariant. Warn loudly so + // the audit gap doesn't go silent. + other => tracing::warn!( + chunk_id = %chunk.id, + status = %other, + "emit_policy_decision_event called on non-terminal status; no audit event emitted" + ), + } +} + +/// Sanitize a free-form reviewer-typed string before it lands in the OCSF +/// audit surface. The agent still reads the raw text via the API — this is +/// audit-side defense only. +fn sanitize_reason_for_audit(raw: &str) -> String { + const MAX_CHARS: usize = 200; + let cleaned: String = raw + .chars() + .filter(|c| !c.is_control() || *c == ' ') + .take(MAX_CHARS) + .collect(); + if raw.chars().count() > MAX_CHARS { + format!("{cleaned}…") + } else { + cleaned + } +} + +/// One-line audit description of a chunk's target: binary, host, port, and +/// L7 method/path if present. Used by both the propose and approve/reject +/// audit events so the trace can be grepped by endpoint without parsing +/// JSON. 
+fn summarize_chunk_for_audit(chunk: &PolicyChunk) -> String { + let Some(rule) = chunk.proposed_rule.as_ref() else { + return format!("rule_name:{}", chunk.rule_name); + }; + let endpoint = rule.endpoints.first().map_or_else( + || "unknown".to_string(), + |ep| format!("{}:{}", ep.host, ep.port), + ); + let l7 = rule + .endpoints + .first() + .and_then(|ep| ep.rules.first()) + .and_then(|r| r.allow.as_ref()) + .map(|a| format!(" {} {}", a.method, a.path)) + .unwrap_or_default(); + let binary = if chunk.binary.is_empty() { + String::new() + } else { + format!(" by {}", chunk.binary) + }; + format!("on {endpoint}{l7}{binary}") +} + +/// `GET /v1/proposals/{chunk_id}` — immediate state. One gateway call, no loop. +async fn proposal_status_response( + ctx: &PolicyLocalContext, + chunk_id: &str, +) -> (u16, serde_json::Value) { + let session = match open_lookup_session(ctx).await { + Ok(session) => session, + Err(err) => return err, + }; + fetch_chunk_or_404(&session, chunk_id, false).await +} + +/// `GET /v1/proposals/{chunk_id}/wait?timeout=` — block until terminal or +/// timeout. Returns the chunk's current state on a status transition; on +/// timeout, returns the still-pending state with `timed_out: true` so the +/// agent can re-issue without ambiguity. The agent's wait costs zero LLM +/// tokens — the tool call sits in a socket recv until we return. 
+async fn proposal_wait_response( + ctx: &PolicyLocalContext, + chunk_id: &str, + query: &str, +) -> (u16, serde_json::Value) { + let session = match open_lookup_session(ctx).await { + Ok(session) => session, + Err(err) => return err, + }; + let timeout_secs = parse_timeout_query(query); + let deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(timeout_secs); + loop { + match fetch_chunk(&session, chunk_id).await { + Ok(Some(chunk)) if is_terminal_status(&chunk.status) => { + // Audit beat: emit at the moment this sandbox observes the + // decision so the trace correlates with the proxy events + // bracketing the loop. Multiple waiters on the same chunk + // each fire one event — acceptable for a wakeup audit. + emit_policy_decision_event(&chunk); + let policy_reloaded = if chunk.status == "approved" { + // Hold the wait until the local supervisor has loaded a + // policy that semantically contains this chunk's + // proposed rule. Reloads triggered by *other* chunks or + // settings changes do not wake us; a missing + // proposed_rule (defensive) skips the check and + // returns reloaded=false so the agent can decide. + // + // Floor the reload-wait window to RELOAD_WAIT_MIN_FLOOR + // so an approval that arrives at T-50ms still gets a + // realistic shot at seeing the reload. Worst case we + // overshoot the caller's deadline by this floor — + // preferable to returning reloaded=false on every + // short-budget call and forcing the agent to re-issue. + let reload_deadline = std::cmp::max( + deadline, + tokio::time::Instant::now() + RELOAD_WAIT_MIN_FLOOR, + ); + match chunk.proposed_rule.as_ref() { + Some(rule) => { + wait_for_local_policy_to_cover(ctx, rule, reload_deadline).await + } + None => false, + } + } else { + // Rejected: no reload semantics — the agent reads + // rejection_reason and redrafts. 
+ false + }; + return (200, chunk_state_payload(&chunk, false, policy_reloaded)); + } + Ok(Some(chunk)) => { + let remaining = deadline.saturating_duration_since(tokio::time::Instant::now()); + if remaining.is_zero() { + return (200, chunk_state_payload(&chunk, true, false)); + } + let sleep_for = std::cmp::min(remaining, PROPOSAL_WAIT_POLL_INTERVAL); + tokio::time::sleep(sleep_for).await; + } + Ok(None) => return chunk_not_found_payload(chunk_id), + Err(err) => return err, + } + } +} + +fn chunk_not_found_payload(chunk_id: &str) -> (u16, serde_json::Value) { + ( + 404, + error_payload( + "chunk_not_found", + format!("chunk '{chunk_id}' is not present in this sandbox's draft policy"), + ), + ) +} + +async fn fetch_chunk_or_404( + session: &LookupSession<'_>, + chunk_id: &str, + timed_out: bool, +) -> (u16, serde_json::Value) { + match fetch_chunk(session, chunk_id).await { + Ok(Some(chunk)) => (200, chunk_state_payload(&chunk, timed_out, false)), + Ok(None) => chunk_not_found_payload(chunk_id), + Err(err) => err, + } +} + +/// Build the agent-facing response for a chunk. +/// +/// Selection rule: include the fields the agent needs to decide what to do +/// next on the redraft loop — identity (`chunk_id`, `status`), the proposal +/// it submitted (`rule_name`, `binary`), the two feedback signals +/// (`rejection_reason` from the reviewer, `validation_result` from the +/// gateway prover), and (on /wait) `policy_reloaded` so the agent can tell +/// "approved AND the new rule is loaded — safe to retry" from "approved +/// but the supervisor hasn't reloaded yet — re-issue /wait or surface to +/// user". Display-only proto fields (`hit_count`, `confidence`, `stage`, +/// timing) are left off until a concrete agent need surfaces them. 
+fn chunk_state_payload( + chunk: &PolicyChunk, + timed_out: bool, + policy_reloaded: bool, +) -> serde_json::Value { + let mut payload = serde_json::json!({ + "chunk_id": chunk.id, + "status": chunk.status, + "rule_name": chunk.rule_name, + "binary": chunk.binary, + "rejection_reason": chunk.rejection_reason, + "validation_result": chunk.validation_result, + }); + if timed_out { + payload["timed_out"] = serde_json::json!(true); + } + if chunk.status == "approved" { + payload["policy_reloaded"] = serde_json::json!(policy_reloaded); + } + payload +} + +fn is_terminal_status(status: &str) -> bool { + matches!(status, "approved" | "rejected") +} + +/// After a chunk is approved upstream, wait until the local supervisor has +/// loaded a policy that semantically contains the chunk's proposed rule. +/// Returns `true` if coverage was observed before the deadline, `false` +/// otherwise — the caller reports that bool back to the agent as +/// `policy_reloaded` so it can decide whether to retry immediately or +/// re-issue `/wait`. +/// +/// Why rule-coverage instead of whole-policy diff (as we used to do): +/// +/// 1. **False sleep.** If the agent re-issues `/wait` after a `timed_out` +/// response, the chunk may have approved AND the supervisor may have +/// reloaded between the two `/wait` calls. A diff-based check snapshots +/// the already-updated policy as baseline and then waits forever for +/// another change. The skill tells the agent to re-issue on +/// `timed_out`, so the diff approach is broken on the happy path. +/// 2. **False wakeup.** Any unrelated reload (another agent's approval, +/// settings change) flips a whole-policy diff, but the chunk's actual +/// rule may not be loaded yet. The agent retries, hits another +/// `policy_denied`, and the revise-loop fires with no real signal to +/// revise on. +/// +/// The polling cadence here is faster than `PROPOSAL_WAIT_POLL_INTERVAL` +/// (which paces upstream gateway calls). 
This loop only reads in-memory +/// state, so 200ms gives a responsive handoff to the agent's retry once +/// the supervisor's own policy poll catches up. +async fn wait_for_local_policy_to_cover( + ctx: &PolicyLocalContext, + proposed_rule: &NetworkPolicyRule, + deadline: tokio::time::Instant, +) -> bool { + const TICK: std::time::Duration = std::time::Duration::from_millis(200); + loop { + // Clone the snapshot out of the RwLock before running coverage — + // otherwise the read guard is held across `policy_covers_rule`'s + // iteration of `network_policies`, serializing a writer (supervisor + // reload) on the very thing we're waiting for. Clone-per-tick on + // a few-KB struct is cheap for the bounded wait window here. + let snapshot = ctx.current_policy.read().await.clone(); + if let Some(policy) = snapshot.as_ref() + && openshell_policy::policy_covers_rule(policy, proposed_rule) + { + return true; + } + let remaining = deadline.saturating_duration_since(tokio::time::Instant::now()); + if remaining.is_zero() { + return false; + } + tokio::time::sleep(std::cmp::min(remaining, TICK)).await; + } +} + +/// Parse `?timeout=` from the query string. Default applies for missing +/// or unparseable values; bounds clamp to keep the agent's hold ceiling +/// sane. Re-issue is the right pattern for longer waits. +fn parse_timeout_query(query: &str) -> u64 { + let raw = query + .split('&') + .filter_map(|kv| kv.split_once('=')) + .find(|(k, _)| *k == "timeout") + .map_or("", |(_, v)| v); + raw.parse::<u64>() + .unwrap_or(PROPOSAL_WAIT_DEFAULT_SECS) + .clamp(PROPOSAL_WAIT_MIN_SECS, PROPOSAL_WAIT_MAX_SECS) +} + +/// One connected gateway client + the validated sandbox name. Built once +/// per request and reused for every `fetch_chunk` call in a wait loop so a +/// 60-second wait does one TLS handshake, not sixty. +struct LookupSession<'a> { + client: crate::grpc_client::CachedOpenShellClient, + sandbox_name: &'a str, +} + +/// Validate ctx and open one gateway channel.
Failures map to the canonical +/// error payload shape used by both `/proposals/{id}` and `/wait`. +async fn open_lookup_session( + ctx: &PolicyLocalContext, +) -> std::result::Result<LookupSession<'_>, (u16, serde_json::Value)> { + let endpoint = ctx.gateway_endpoint.as_deref().ok_or_else(|| { + ( + 503, + error_payload( + "gateway_unavailable", + "proposal state lookup requires a gateway-connected sandbox".to_string(), + ), + ) + })?; + let sandbox_name = ctx + .sandbox_name + .as_deref() + .map(str::trim) + .filter(|name| !name.is_empty()) + .ok_or_else(|| { + ( + 503, + error_payload( + "sandbox_name_unavailable", + "proposal state lookup requires a sandbox name".to_string(), + ), + ) + })?; + let client = crate::grpc_client::CachedOpenShellClient::connect(endpoint) + .await + .map_err(|e| (502, error_payload("gateway_connect_failed", e.to_string())))?; + Ok(LookupSession { + client, + sandbox_name, + }) +} + +/// One gateway call: list the sandbox's draft chunks and find the matching +/// id. Returns `Ok(None)` only when the gateway responded successfully but +/// no chunk in this sandbox matches.
+async fn fetch_chunk( + session: &LookupSession<'_>, + chunk_id: &str, +) -> std::result::Result<Option<PolicyChunk>, (u16, serde_json::Value)> { + let chunks = session + .client + .get_draft_policy(session.sandbox_name, "") + .await + .map_err(|e| (502, error_payload("gateway_lookup_failed", e.to_string())))?; + Ok(chunks.into_iter().find(|c| c.id == chunk_id)) +} + +fn proposal_chunks_from_body(body: &[u8]) -> std::result::Result<Vec<PolicyChunk>, String> { + let request: ProposalRequest = serde_json::from_slice(body).map_err(|e| e.to_string())?; + if request.operations.is_empty() { + return Err("proposal requires at least one operation".to_string()); + } + + let mut chunks = Vec::new(); + for operation in request.operations { + let Some(add_rule) = operation.get("addRule").cloned() else { + return Err( + "this MVP accepts `addRule` operations; submit a full narrow NetworkPolicyRule" + .to_string(), + ); + }; + let add_rule: AddNetworkRuleJson = + serde_json::from_value(add_rule).map_err(|e| e.to_string())?; + chunks.push(policy_chunk_from_add_rule( + add_rule, + request.intent_summary.as_deref().unwrap_or_default(), + )?); + } + + Ok(chunks) +} + +fn policy_chunk_from_add_rule( + add_rule: AddNetworkRuleJson, + intent_summary: &str, +) -> std::result::Result<PolicyChunk, String> { + let mut rule = network_rule_from_json(add_rule.rule)?; + let rule_name = add_rule + .rule_name + .as_deref() + .map(str::trim) + .filter(|name| !name.is_empty()) + .map_or_else(|| rule.name.clone(), ToString::to_string); + if rule_name.trim().is_empty() { + return Err("addRule.ruleName or rule.name is required".to_string()); + } + if rule.name.trim().is_empty() { + rule.name.clone_from(&rule_name); + } + + let binary = rule + .binaries + .first() + .map(|binary| binary.path.clone()) + .unwrap_or_default(); + + Ok(PolicyChunk { + id: String::new(), + status: "pending".to_string(), + rule_name, + proposed_rule: Some(rule), + rationale: intent_summary.to_string(), + security_notes: String::new(), + confidence: 0.75, + denial_summary_ids:
vec![], + created_at_ms: 0, + decided_at_ms: 0, + stage: "agent".to_string(), + supersedes_chunk_id: String::new(), + hit_count: 1, + first_seen_ms: 0, + last_seen_ms: 0, + binary, + validation_result: String::new(), + rejection_reason: String::new(), + }) +} + +fn network_rule_from_json( + rule: NetworkPolicyRuleJson, +) -> std::result::Result<NetworkPolicyRule, String> { + if rule.endpoints.is_empty() { + return Err("rule.endpoints must contain at least one endpoint".to_string()); + } + + let endpoints = rule + .endpoints + .into_iter() + .map(network_endpoint_from_json) + .collect::<std::result::Result<Vec<_>, _>>()?; + let binaries = rule + .binaries + .into_iter() + .map(|binary| NetworkBinary { + path: binary.path, + ..Default::default() + }) + .collect(); + + Ok(NetworkPolicyRule { + name: rule.name.unwrap_or_default(), + endpoints, + binaries, + }) +} + +fn network_endpoint_from_json( + endpoint: NetworkEndpointJson, +) -> std::result::Result<NetworkEndpoint, String> { + if endpoint.host.trim().is_empty() { + return Err("endpoint.host is required".to_string()); + } + + let mut ports = endpoint.ports; + if ports.is_empty() && endpoint.port > 0 { + ports.push(endpoint.port); + } + if ports.is_empty() { + return Err("endpoint.port or endpoint.ports is required".to_string()); + } + if endpoint + .rules + .iter() + .any(|rule| rule.allow.path.contains('?')) + { + return Err("L7 allow paths must not include query strings".to_string()); + } + + let port = ports.first().copied().unwrap_or_default(); + let rules = endpoint + .rules + .into_iter() + .map(|rule| L7Rule { + allow: Some(L7Allow { + method: rule.allow.method, + path: rule.allow.path, + command: rule.allow.command, + query: HashMap::new(), + // GraphQL fields default empty — agent-authored proposals from + // policy.local target REST/SQL/L4 endpoints; GraphQL operation + // matching is set on the policy server side or via direct YAML.
+ operation_type: String::new(), + operation_name: String::new(), + fields: Vec::new(), + }), + }) + .collect(); + let deny_rules = endpoint + .deny_rules + .into_iter() + .map(|rule| L7DenyRule { + method: rule.method, + path: rule.path, + command: rule.command, + query: HashMap::new(), + operation_type: String::new(), + operation_name: String::new(), + fields: Vec::new(), + }) + .collect(); + + Ok(NetworkEndpoint { + host: endpoint.host, + port, + protocol: endpoint.protocol, + tls: endpoint.tls, + enforcement: endpoint.enforcement, + access: endpoint.access, + rules, + allowed_ips: endpoint.allowed_ips, + ports, + deny_rules, + allow_encoded_slash: endpoint.allow_encoded_slash, + websocket_credential_rewrite: false, + request_body_credential_rewrite: false, + // GraphQL persisted-query knobs and path scoping default empty — + // agent proposals don't author them today. + persisted_queries: String::new(), + graphql_persisted_queries: HashMap::new(), + graphql_max_body_bytes: 0, + path: String::new(), + }) +} + +async fn read_request_body(initial_request: &[u8], client: &mut S) -> Result> +where + S: AsyncRead + Unpin, +{ + let Some(header_end) = find_header_end(initial_request) else { + return Ok(Vec::new()); + }; + let content_length = parse_content_length(&initial_request[..header_end])?; + if content_length > MAX_POLICY_LOCAL_BODY_BYTES { + return Err(miette::miette!( + "policy.local request body exceeds {MAX_POLICY_LOCAL_BODY_BYTES} bytes" + )); + } + + let mut body = initial_request[header_end..].to_vec(); + if body.len() > content_length { + body.truncate(content_length); + } + let read_loop = async { + while body.len() < content_length { + let remaining = content_length - body.len(); + let mut chunk = vec![0u8; remaining.min(8192)]; + let n = client.read(&mut chunk).await.into_diagnostic()?; + if n == 0 { + return Err(miette::miette!("policy.local request body ended early")); + } + body.extend_from_slice(&chunk[..n]); + } + Ok::<(), miette::Report>(()) + 
}; + tokio::time::timeout(POLICY_LOCAL_BODY_READ_TIMEOUT, read_loop) + .await + .map_err(|_| miette::miette!("policy.local request body read timed out"))??; + + Ok(body) +} + +fn parse_content_length(headers: &[u8]) -> Result { + let headers = String::from_utf8_lossy(headers); + for line in headers.lines().skip(1) { + if let Some((name, value)) = line.split_once(':') + && name.eq_ignore_ascii_case("content-length") + { + return value + .trim() + .parse::() + .into_diagnostic() + .map_err(|_| miette::miette!("invalid policy.local Content-Length")); + } + } + Ok(0) +} + +fn find_header_end(buf: &[u8]) -> Option { + buf.windows(4) + .position(|window| window == b"\r\n\r\n") + .map(|idx| idx + 4) +} + +async fn write_json_response( + client: &mut S, + status: u16, + payload: serde_json::Value, +) -> Result<()> +where + S: AsyncWrite + Unpin, +{ + let body = payload.to_string(); + let response = format!( + "HTTP/1.1 {status} {}\r\n\ + Content-Type: application/json\r\n\ + Content-Length: {}\r\n\ + Connection: close\r\n\ + \r\n\ + {}", + status_text(status), + body.len(), + body + ); + client + .write_all(response.as_bytes()) + .await + .into_diagnostic()?; + client.flush().await.into_diagnostic()?; + Ok(()) +} + +fn status_text(status: u16) -> &'static str { + match status { + 202 => "Accepted", + 400 => "Bad Request", + 404 => "Not Found", + 500 => "Internal Server Error", + 502 => "Bad Gateway", + 503 => "Service Unavailable", + _ => "OK", + } +} + +fn error_payload(error: &str, detail: String) -> serde_json::Value { + serde_json::json!({ + "error": error, + "detail": detail + }) +} + +#[derive(Debug, Deserialize)] +struct ProposalRequest { + #[serde(default)] + intent_summary: Option, + #[serde(default)] + operations: Vec, +} + +#[derive(Debug, Deserialize)] +struct AddNetworkRuleJson { + #[serde(default, rename = "ruleName")] + rule_name: Option, + rule: NetworkPolicyRuleJson, +} + +#[derive(Debug, Deserialize)] +struct NetworkPolicyRuleJson { + #[serde(default)] + 
name: Option, + #[serde(default)] + endpoints: Vec, + #[serde(default)] + binaries: Vec, +} + +#[derive(Debug, Deserialize)] +struct NetworkEndpointJson { + host: String, + #[serde(default)] + port: u32, + #[serde(default)] + ports: Vec, + #[serde(default)] + protocol: String, + #[serde(default)] + tls: String, + #[serde(default)] + enforcement: String, + #[serde(default)] + access: String, + #[serde(default)] + rules: Vec, + #[serde(default)] + allowed_ips: Vec, + #[serde(default)] + deny_rules: Vec, + #[serde(default)] + allow_encoded_slash: bool, +} + +#[derive(Debug, Deserialize)] +struct NetworkBinaryJson { + path: String, +} + +#[derive(Debug, Deserialize)] +struct L7RuleJson { + allow: L7AllowJson, +} + +#[derive(Debug, Deserialize)] +struct L7AllowJson { + #[serde(default)] + method: String, + #[serde(default)] + path: String, + #[serde(default)] + command: String, +} + +#[derive(Debug, Deserialize)] +struct L7DenyRuleJson { + #[serde(default)] + method: String, + #[serde(default)] + path: String, + #[serde(default)] + command: String, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn proposal_chunks_from_body_accepts_add_rule_operation() { + let body = br#"{ + "intent_summary": "Allow gh to create one repo.", + "operations": [ + { + "addRule": { + "ruleName": "github_api_repo_create", + "rule": { + "endpoints": [ + { + "host": "api.github.com", + "port": 443, + "protocol": "rest", + "tls": "terminate", + "enforcement": "enforce", + "rules": [ + { + "allow": { + "method": "POST", + "path": "/user/repos" + } + } + ] + } + ], + "binaries": [ + { + "path": "/usr/bin/gh" + } + ] + } + } + } + ] + }"#; + + let chunks = proposal_chunks_from_body(body).unwrap(); + + assert_eq!(chunks.len(), 1); + assert_eq!(chunks[0].rule_name, "github_api_repo_create"); + assert_eq!(chunks[0].rationale, "Allow gh to create one repo."); + assert_eq!(chunks[0].binary, "/usr/bin/gh"); + let rule = chunks[0].proposed_rule.as_ref().unwrap(); + assert_eq!(rule.name, 
"github_api_repo_create"); + assert_eq!(rule.endpoints[0].host, "api.github.com"); + assert_eq!(rule.endpoints[0].port, 443); + assert_eq!(rule.endpoints[0].ports, vec![443]); + assert_eq!(rule.endpoints[0].protocol, "rest"); + assert_eq!( + rule.endpoints[0].rules[0].allow.as_ref().unwrap().path, + "/user/repos" + ); + } + + #[test] + fn proposal_chunks_from_body_rejects_query_in_l7_path() { + let body = br#"{ + "operations": [ + { + "addRule": { + "ruleName": "bad", + "rule": { + "endpoints": [ + { + "host": "api.github.com", + "port": 443, + "rules": [ + { + "allow": { + "method": "GET", + "path": "/repos?token=secret" + } + } + ] + } + ] + } + } + } + ] + }"#; + + let error = proposal_chunks_from_body(body).unwrap_err(); + assert!(error.contains("query strings")); + assert!(!error.contains("secret")); + } + + #[test] + fn parse_last_query_clamps_to_max() { + assert_eq!(parse_last_query("last=5"), Some(5)); + assert_eq!(parse_last_query("foo=bar&last=20"), Some(20)); + assert_eq!(parse_last_query("last=999"), Some(MAX_DENIALS_LIMIT)); + assert_eq!(parse_last_query("last=0"), Some(1)); + assert_eq!(parse_last_query(""), None); + assert_eq!(parse_last_query("other=1"), None); + } + + #[test] + fn is_ocsf_denial_line_filters_correctly() { + // OCSF denial — match. + assert!(is_ocsf_denial_line( + "2026-05-06T17:02:00.000Z OCSF HTTP:PUT [MED] DENIED PUT http://api.github.com:443/x [policy:p engine:l7]" + )); + assert!(is_ocsf_denial_line( + "2026-05-06T17:02:00.000Z OCSF NET:OPEN [MED] DENIED curl(42) -> blocked.com:443 [policy:- engine:opa]" + )); + + // OCSF allowed — must not match. + assert!(!is_ocsf_denial_line( + "2026-05-06T17:02:00.000Z OCSF NET:OPEN [INFO] ALLOWED curl(42) -> api.example.com:443" + )); + + // Non-OCSF tracing line — must not match even if it contains the word DENIED. + assert!(!is_ocsf_denial_line( + "2026-05-06T17:02:00.000Z INFO some::module: request DENIED in upstream" + )); + + // Empty line — must not match. 
+ assert!(!is_ocsf_denial_line("")); + } + + #[tokio::test] + async fn recent_denials_returns_newest_first_from_shorthand_lines() { + let dir = tempfile::tempdir().unwrap(); + let log_path = dir.path().join("openshell.2026-05-06.log"); + // Mixed file: allowed events, non-OCSF info lines, two denials. + // Lines are written in chronological order; reader walks newest-first. + let body = "\ +2026-05-06T17:02:00.000Z OCSF NET:OPEN [INFO] ALLOWED curl(10) -> api.example.com:443 [policy:default engine:opa] +2026-05-06T17:02:01.000Z INFO some::module: routine status check +2026-05-06T17:02:02.000Z OCSF HTTP:GET [MED] DENIED GET http://blocked.example/v1/data [policy:default-deny engine:l7] +2026-05-06T17:02:03.000Z OCSF NET:OPEN [INFO] ALLOWED curl(11) -> api.example.com:443 +2026-05-06T17:02:04.000Z OCSF HTTP:PUT [MED] DENIED PUT http://api.github.com:443/repos/x/y/contents/z [policy:gh_readonly engine:l7] +"; + std::fs::write(&log_path, body).unwrap(); + + let ctx = PolicyLocalContext::with_log_dir(None, None, None, dir.path().to_path_buf()); + let (status, payload) = recent_denials_response(&ctx, "last=10").await; + assert_eq!(status, 200); + assert_eq!(payload["log_available"], true); + let denials = payload["denials"].as_array().unwrap(); + assert_eq!(denials.len(), 2); + // Newest first. + assert!(denials[0].as_str().unwrap().contains("HTTP:PUT")); + assert!( + denials[0] + .as_str() + .unwrap() + .contains("/repos/x/y/contents/z") + ); + assert!(denials[1].as_str().unwrap().contains("HTTP:GET")); + assert!(denials[1].as_str().unwrap().contains("blocked.example")); + } + + #[tokio::test] + async fn recent_denials_skips_jsonl_log_files() { + // The shorthand reader must not surface `openshell-ocsf.*.log` content + // even if a deny-looking line is present, so the response stays + // independent of the JSONL appender's enabled state. 
+ let dir = tempfile::tempdir().unwrap(); + let jsonl = dir.path().join("openshell-ocsf.2026-05-06.log"); + std::fs::write( + &jsonl, + r#"{"class_uid":4002,"action_id":2,"message":"DENIED","time":1}"#, + ) + .unwrap(); + + let ctx = PolicyLocalContext::with_log_dir(None, None, None, dir.path().to_path_buf()); + let (status, payload) = recent_denials_response(&ctx, "").await; + assert_eq!(status, 200); + assert_eq!(payload["log_available"], false); + assert_eq!(payload["denials"].as_array().unwrap().len(), 0); + } + + #[tokio::test] + async fn recent_denials_signals_when_log_is_missing() { + let dir = tempfile::tempdir().unwrap(); + let ctx = PolicyLocalContext::with_log_dir(None, None, None, dir.path().to_path_buf()); + let (status, payload) = recent_denials_response(&ctx, "").await; + assert_eq!(status, 200); + assert_eq!(payload["log_available"], false); + assert_eq!(payload["denials"].as_array().unwrap().len(), 0); + assert!( + payload["note"] + .as_str() + .unwrap() + .contains("/var/log/openshell.") + ); + } + + #[test] + fn redact_query_strings_removes_query_from_url_token() { + let line = "2026-05-06T17:02:00.000Z OCSF HTTP:PUT [MED] DENIED PUT http://api.github.com/x?access_token=secret-token-1234 [policy:p engine:l7]"; + let redacted = redact_query_strings(line); + assert!(!redacted.contains("secret-token-1234")); + assert!(!redacted.contains("access_token")); + assert!(redacted.contains("?[redacted]")); + // Bracketed tag after the URL preserved. + assert!(redacted.contains("[policy:p engine:l7]")); + } + + #[test] + fn redact_query_strings_removes_query_in_reason_tag() { + // The FORWARD deny path's `message` becomes `[reason:...]` and may + // include a path with query string lacking a `://` prefix. 
+ let line = "2026-05-06T17:02:00.000Z OCSF HTTP:PUT [MED] DENIED PUT http://api.github.com/x [policy:p engine:opa] [reason:FORWARD denied PUT api.github.com:443/x?token=secret-456]"; + let redacted = redact_query_strings(line); + assert!(!redacted.contains("secret-456")); + assert!(!redacted.contains("token=secret")); + assert!(redacted.contains("?[redacted]]")); + } + + #[test] + fn redact_query_strings_handles_multibyte_chars() { + let line = "ÜLÅUTF8 ? secret-x [policy:p]"; + // The `?` here is not part of a URL; whatever the redaction outcome, + // multibyte input must not panic. + let _ = redact_query_strings(line); + } + + #[test] + fn truncate_at_char_boundary_does_not_panic_on_multibyte() { + // 4-byte emoji sequence so byte-naive slicing would panic. + let s = "🚀".repeat(2000); // 8000 bytes + let truncated = truncate_at_char_boundary(&s, 4096); + assert!(truncated.len() <= 4096 + "...[truncated]".len()); + assert!(truncated.ends_with("...[truncated]")); + // Result must be valid UTF-8; that is implicit if we return without panic. + } + + #[tokio::test] + async fn recent_denials_truncates_pathological_lines() { + let dir = tempfile::tempdir().unwrap(); + let log_path = dir.path().join("openshell.2026-05-06.log"); + // A single OCSF denial line exceeding MAX_DENIAL_LINE_BYTES. 
+ let huge_path = "/".to_string() + &"a".repeat(MAX_DENIAL_LINE_BYTES + 100); + let line = format!( + "2026-05-06T17:02:00.000Z OCSF HTTP:PUT [MED] DENIED PUT http://x{huge_path} [policy:p engine:l7]\n" + ); + std::fs::write(&log_path, line).unwrap(); + + let ctx = PolicyLocalContext::with_log_dir(None, None, None, dir.path().to_path_buf()); + let (_, payload) = recent_denials_response(&ctx, "last=1").await; + let denials = payload["denials"].as_array().unwrap(); + assert_eq!(denials.len(), 1); + let surfaced = denials[0].as_str().unwrap(); + assert!(surfaced.len() <= MAX_DENIAL_LINE_BYTES + "...[truncated]".len()); + assert!(surfaced.ends_with("...[truncated]")); + } + + use crate::test_helpers::ProposalsFlagGuard; + + #[test] + fn agent_next_steps_returns_empty_when_flag_off() { + let _guard = ProposalsFlagGuard::set_blocking(false); + let steps = agent_next_steps(); + let arr = steps.as_array().expect("agent_next_steps is an array"); + assert!( + arr.is_empty(), + "expected empty next_steps when feature is off, got {steps}" + ); + } + + #[test] + fn agent_next_steps_returns_full_array_when_flag_on() { + let _guard = ProposalsFlagGuard::set_blocking(true); + let steps = agent_next_steps(); + let arr = steps.as_array().expect("agent_next_steps is an array"); + assert_eq!(arr.len(), 4, "expected 4 next_steps when feature is on"); + let actions: Vec<&str> = arr + .iter() + .filter_map(|v| v.get("action").and_then(serde_json::Value::as_str)) + .collect(); + assert!(actions.contains(&"read_skill")); + assert!(actions.contains(&"submit_proposal")); + } + + #[tokio::test] + async fn route_request_returns_feature_disabled_when_flag_off() { + let _guard = ProposalsFlagGuard::set(false).await; + let ctx = PolicyLocalContext::new( + Some(ProtoSandboxPolicy { + version: 1, + ..Default::default() + }), + None, + None, + ); + + // Even the otherwise-public `current_policy` route returns 404 with + // a feature_disabled error: when the surface is off it's off + // entirely, not 
selectively. + let (status, payload) = route_request(&ctx, "GET", ROUTE_POLICY_CURRENT, &[]).await; + assert_eq!(status, 404); + assert_eq!(payload["error"], "feature_disabled"); + assert!( + payload["detail"] + .as_str() + .unwrap() + .contains("agent_policy_proposals_enabled"), + "feature_disabled detail must name the setting key for actionability" + ); + } + + #[tokio::test] + async fn current_policy_route_returns_yaml_envelope() { + let _guard = ProposalsFlagGuard::set(true).await; + let ctx = PolicyLocalContext::new( + Some(ProtoSandboxPolicy { + version: 1, + ..Default::default() + }), + None, + None, + ); + + let (mut client, mut server) = tokio::io::duplex(4096); + let request = + b"GET http://policy.local/v1/policy/current HTTP/1.1\r\nHost: policy.local\r\n\r\n"; + let task = tokio::spawn(async move { + handle_forward_request(&ctx, "GET", "/v1/policy/current", request, &mut server) + .await + .unwrap(); + }); + + let mut received = Vec::new(); + client.read_to_end(&mut received).await.unwrap(); + task.await.unwrap(); + + let response = String::from_utf8(received).unwrap(); + assert!(response.starts_with("HTTP/1.1 200 OK")); + let (_, body) = response.split_once("\r\n\r\n").unwrap(); + let body: serde_json::Value = serde_json::from_str(body).unwrap(); + assert_eq!(body["format"], "yaml"); + assert!(body["policy_yaml"].as_str().unwrap().contains("version: 1")); + } + + #[test] + fn parse_timeout_query_defaults_and_clamps() { + assert_eq!(parse_timeout_query(""), PROPOSAL_WAIT_DEFAULT_SECS); + assert_eq!(parse_timeout_query("timeout="), PROPOSAL_WAIT_DEFAULT_SECS); + assert_eq!( + parse_timeout_query("timeout=abc"), + PROPOSAL_WAIT_DEFAULT_SECS + ); + assert_eq!(parse_timeout_query("timeout=30"), 30); + assert_eq!(parse_timeout_query("foo=1&timeout=45"), 45); + // Below floor clamps up; above ceiling clamps down. 
+ assert_eq!(parse_timeout_query("timeout=0"), PROPOSAL_WAIT_MIN_SECS); + assert_eq!(parse_timeout_query("timeout=9999"), PROPOSAL_WAIT_MAX_SECS); + } + + #[test] + fn is_terminal_status_matches_only_approved_and_rejected() { + assert!(!is_terminal_status("pending")); + assert!(is_terminal_status("approved")); + assert!(is_terminal_status("rejected")); + assert!(!is_terminal_status("")); + } + + #[test] + fn chunk_state_payload_surfaces_loop_fields() { + let chunk = PolicyChunk { + id: "chunk-x".to_string(), + status: "rejected".to_string(), + rule_name: "allow_example".to_string(), + binary: "/usr/bin/curl".to_string(), + rejection_reason: "scope too broad".to_string(), + validation_result: "no exfil paths".to_string(), + ..Default::default() + }; + let pending = chunk_state_payload(&chunk, false, false); + assert_eq!(pending["chunk_id"], "chunk-x"); + assert_eq!(pending["status"], "rejected"); + assert_eq!(pending["rejection_reason"], "scope too broad"); + assert_eq!(pending["validation_result"], "no exfil paths"); + // timed_out and policy_reloaded only appear when relevant. 
+ assert!(pending.get("timed_out").is_none()); + assert!( + pending.get("policy_reloaded").is_none(), + "policy_reloaded is only meaningful for approved chunks" + ); + + let timed = chunk_state_payload(&chunk, true, false); + assert_eq!(timed["timed_out"], true); + } + + #[test] + fn chunk_state_payload_includes_policy_reloaded_when_approved() { + let chunk = PolicyChunk { + id: "chunk-y".to_string(), + status: "approved".to_string(), + rule_name: "allow_github".to_string(), + binary: "/usr/bin/curl".to_string(), + ..Default::default() + }; + let reloaded = chunk_state_payload(&chunk, false, true); + assert_eq!(reloaded["status"], "approved"); + assert_eq!(reloaded["policy_reloaded"], true); + + let not_reloaded = chunk_state_payload(&chunk, false, false); + assert_eq!(not_reloaded["policy_reloaded"], false); + } + + #[tokio::test] + async fn proposal_routes_reject_malformed_paths() { + let _guard = ProposalsFlagGuard::set(true).await; + let ctx = PolicyLocalContext::new(None, None, None); + + // Empty chunk_id after the prefix is 404, not a wildcard list. + let (status, _) = route_request(&ctx, "GET", "/v1/proposals/", &[]).await; + assert_eq!(status, 404); + + // More than one segment after the id (not "/wait") is 404, not a + // partial match. Prevents `/v1/proposals/abc/extra` from silently + // dispatching as a status lookup for "abc/extra". + let (status, _) = route_request(&ctx, "GET", "/v1/proposals/abc/extra", &[]).await; + assert_eq!(status, 404); + + // Trailing path after `/wait` also 404 — must not match the wait + // arm as a wildcard. 
+ let (status, _) = route_request(&ctx, "GET", "/v1/proposals/abc/wait/extra", &[]).await; + assert_eq!(status, 404); + } + + #[tokio::test] + async fn proposal_status_route_returns_503_when_no_gateway() { + let _guard = ProposalsFlagGuard::set(true).await; + let ctx = PolicyLocalContext::new(None, None, Some("test-sandbox".to_string())); + + let (status, body) = route_request(&ctx, "GET", "/v1/proposals/chunk-id", &[]).await; + assert_eq!(status, 503); + assert_eq!(body["error"], "gateway_unavailable"); + } + + #[tokio::test] + async fn proposal_wait_route_returns_503_when_no_gateway() { + let _guard = ProposalsFlagGuard::set(true).await; + let ctx = PolicyLocalContext::new(None, None, Some("test-sandbox".to_string())); + + let (status, body) = + route_request(&ctx, "GET", "/v1/proposals/chunk-id/wait?timeout=1", &[]).await; + assert_eq!(status, 503); + assert_eq!(body["error"], "gateway_unavailable"); + } + + #[tokio::test] + async fn proposal_routes_return_feature_disabled_when_flag_off() { + let _guard = ProposalsFlagGuard::set(false).await; + let ctx = PolicyLocalContext::new(None, None, Some("test-sandbox".to_string())); + + let (status, body) = route_request(&ctx, "GET", "/v1/proposals/abc", &[]).await; + assert_eq!(status, 404); + assert_eq!(body["error"], "feature_disabled"); + + let (status, _) = route_request(&ctx, "GET", "/v1/proposals/abc/wait", &[]).await; + assert_eq!(status, 404); + } + + #[test] + fn summarize_chunk_for_audit_includes_endpoint_l7_path_and_binary() { + let chunk = PolicyChunk { + id: "ignored".to_string(), + rule_name: "github_write".to_string(), + binary: "/usr/bin/curl".to_string(), + proposed_rule: Some(NetworkPolicyRule { + name: "github_write".to_string(), + endpoints: vec![NetworkEndpoint { + host: "api.github.com".to_string(), + port: 443, + rules: vec![L7Rule { + allow: Some(L7Allow { + method: "PUT".to_string(), + path: "/repos/foo/bar/contents/x.md".to_string(), + ..Default::default() + }), + }], + ..Default::default() + 
}], + binaries: vec![NetworkBinary { + path: "/usr/bin/curl".to_string(), + ..Default::default() + }], + }), + ..Default::default() + }; + let summary = summarize_chunk_for_audit(&chunk); + assert!(summary.contains("api.github.com:443")); + assert!(summary.contains("PUT /repos/foo/bar/contents/x.md")); + assert!(summary.contains("/usr/bin/curl")); + } + + // Helpers — synthetic proposed rule + policy with that rule already + // merged. Both reused across reload-readiness tests. + fn proposed_curl_rule_for_github() -> NetworkPolicyRule { + NetworkPolicyRule { + name: "agent_proposed".to_string(), + endpoints: vec![NetworkEndpoint { + host: "api.github.com".to_string(), + port: 443, + ports: vec![443], + ..Default::default() + }], + binaries: vec![NetworkBinary { + path: "/usr/bin/curl".to_string(), + ..Default::default() + }], + } + } + + fn policy_with_rule(rule: NetworkPolicyRule) -> ProtoSandboxPolicy { + ProtoSandboxPolicy { + version: 1, + network_policies: HashMap::from([(rule.name.clone(), rule)]), + ..Default::default() + } + } + + #[tokio::test] + async fn wait_returns_reloaded_true_when_rule_already_loaded() { + // John's false-sleep case: the supervisor has already reloaded a + // policy containing the proposed rule before /wait starts. A + // whole-policy diff would never see another change and burn the + // full timeout. Rule-coverage must return immediately. 
+ let proposed = proposed_curl_rule_for_github(); + let ctx = PolicyLocalContext::new(Some(policy_with_rule(proposed.clone())), None, None); + let deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(2); + + let start = tokio::time::Instant::now(); + let reloaded = wait_for_local_policy_to_cover(&ctx, &proposed, deadline).await; + let elapsed = start.elapsed(); + + assert!(reloaded, "should report reloaded=true on coverage"); + assert!( + elapsed < std::time::Duration::from_millis(200), + "should return immediately, not poll-and-wait; took {elapsed:?}" + ); + } + + #[tokio::test] + async fn wait_does_not_wake_on_unrelated_policy_change() { + // John's false-wakeup case: a *different* rule gets added to the + // local policy (other agent's approval, settings change, etc.). + // The agent's specific rule is still not loaded. A diff-based + // check would wake here; coverage must not. + let proposed = proposed_curl_rule_for_github(); + // Start with a policy that does NOT contain the proposed rule. + let initial = ProtoSandboxPolicy { + version: 1, + ..Default::default() + }; + let ctx = PolicyLocalContext::new(Some(initial), None, None); + + // Concurrently, an unrelated rule lands. We must not return. 
+ let unrelated_load = { + let policy = ctx.current_policy.clone(); + tokio::spawn(async move { + tokio::time::sleep(std::time::Duration::from_millis(50)).await; + *policy.write().await = Some(policy_with_rule(NetworkPolicyRule { + name: "unrelated".to_string(), + endpoints: vec![NetworkEndpoint { + host: "api.example.com".to_string(), + port: 443, + ports: vec![443], + ..Default::default() + }], + binaries: vec![NetworkBinary { + path: "/usr/bin/curl".to_string(), + ..Default::default() + }], + })); + }) + }; + + let deadline = tokio::time::Instant::now() + std::time::Duration::from_millis(400); + let start = tokio::time::Instant::now(); + let reloaded = wait_for_local_policy_to_cover(&ctx, &proposed, deadline).await; + unrelated_load.await.unwrap(); + let elapsed = start.elapsed(); + + assert!( + !reloaded, + "must not wake on an unrelated reload; coverage was never satisfied" + ); + assert!( + elapsed >= std::time::Duration::from_millis(350), + "should have held until the deadline; only waited {elapsed:?}" + ); + } + + #[tokio::test] + async fn wait_wakes_when_matching_rule_arrives_mid_flight() { + // Sandbox starts without the rule, then a reload lands containing + // it. /wait should observe coverage and return reloaded=true. 
+ let proposed = proposed_curl_rule_for_github(); + let ctx = PolicyLocalContext::new( + Some(ProtoSandboxPolicy { + version: 1, + ..Default::default() + }), + None, + None, + ); + + let matching_load = { + let policy = ctx.current_policy.clone(); + let target = proposed.clone(); + tokio::spawn(async move { + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + *policy.write().await = Some(policy_with_rule(target)); + }) + }; + + let deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(2); + let start = tokio::time::Instant::now(); + let reloaded = wait_for_local_policy_to_cover(&ctx, &proposed, deadline).await; + matching_load.await.unwrap(); + let elapsed = start.elapsed(); + + assert!(reloaded, "should report reloaded=true after coverage lands"); + assert!( + elapsed < std::time::Duration::from_millis(800), + "should return shortly after coverage; took {elapsed:?}" + ); + } + + #[tokio::test] + async fn wait_returns_reloaded_false_at_deadline_when_no_coverage() { + // Deadline budget exhausted, the proposed rule never showed up. + // Coverage check returns false — the agent gets policy_reloaded= + // false and decides whether to retry blind or re-issue /wait. 
+ let proposed = proposed_curl_rule_for_github(); + let ctx = PolicyLocalContext::new( + Some(ProtoSandboxPolicy { + version: 1, + ..Default::default() + }), + None, + None, + ); + let deadline = tokio::time::Instant::now() + std::time::Duration::from_millis(300); + let start = tokio::time::Instant::now(); + let reloaded = wait_for_local_policy_to_cover(&ctx, &proposed, deadline).await; + let elapsed = start.elapsed(); + + assert!(!reloaded); + assert!( + elapsed >= std::time::Duration::from_millis(250), + "should wait until ~deadline; only waited {elapsed:?}" + ); + assert!( + elapsed < std::time::Duration::from_millis(800), + "should not extend past deadline by much; took {elapsed:?}" + ); + } + + #[test] + fn sanitize_reason_for_audit_strips_control_chars_and_caps_length() { + // Tabs and newlines are stripped; ordinary printable chars survive; + // multi-byte characters count as one char in the cap. + let raw = "line one\nline\ttwo\u{0001}\u{0007}"; + let cleaned = sanitize_reason_for_audit(raw); + assert!(!cleaned.contains('\n')); + assert!(!cleaned.contains('\t')); + assert!(!cleaned.contains('\u{0001}')); + assert!(cleaned.contains("line one")); + assert!(cleaned.contains("linetwo")); + + // Length cap with ellipsis marker so a downstream reader can tell + // the audit string is truncated. + let long: String = "x".repeat(500); + let capped = sanitize_reason_for_audit(&long); + assert!(capped.chars().count() <= 201); + assert!(capped.ends_with('…')); + + // Empty input maps to empty output (caller renders "(no guidance)"). 
+ assert_eq!(sanitize_reason_for_audit(""), ""); + } + + #[test] + fn summarize_chunk_for_audit_falls_back_to_rule_name_without_rule() { + let chunk = PolicyChunk { + rule_name: "fallback".to_string(), + proposed_rule: None, + ..Default::default() + }; + assert_eq!(summarize_chunk_for_audit(&chunk), "rule_name:fallback"); + } +} diff --git a/crates/openshell-sandbox/src/process.rs b/crates/openshell-sandbox/src/process.rs index 0dc513836..9bbcfe66c 100644 --- a/crates/openshell-sandbox/src/process.rs +++ b/crates/openshell-sandbox/src/process.rs @@ -22,58 +22,43 @@ use std::process::Stdio; use tokio::process::{Child, Command}; use tracing::debug; -const SSH_HANDSHAKE_SECRET_ENV: &str = "OPENSHELL_SSH_HANDSHAKE_SECRET"; - fn inject_provider_env(cmd: &mut Command, provider_env: &HashMap) { for (key, value) in provider_env { cmd.env(key, value); } } -fn scrub_sensitive_env(cmd: &mut Command) { - cmd.env_remove(SSH_HANDSHAKE_SECRET_ENV); -} - #[cfg(unix)] -#[allow(unsafe_code, clippy::borrow_as_ptr)] pub fn harden_child_process() -> Result<()> { - let core_limit = libc::rlimit { - rlim_cur: 0, - rlim_max: 0, - }; - let rc = unsafe { libc::setrlimit(libc::RLIMIT_CORE, &raw const core_limit) }; - if rc != 0 { - return Err(miette::miette!( - "Failed to disable core dumps: {}", - std::io::Error::last_os_error() - )); - } + use rustix::process::{Resource, Rlimit, setrlimit}; + + setrlimit( + Resource::Core, + Rlimit { + current: Some(0), + maximum: Some(0), + }, + ) + .map_err(|e| miette::miette!("Failed to disable core dumps: {e}"))?; // Limit process creation to prevent fork bombs. 512 processes per UID is // sufficient for typical agent workloads (shell, compilers, language servers) // while preventing runaway forking. Set as a hard limit so the sandbox user // cannot raise it after privilege drop. 
- let nproc_limit = libc::rlimit { - rlim_cur: 512, - rlim_max: 512, - }; - let rc = unsafe { libc::setrlimit(libc::RLIMIT_NPROC, &raw const nproc_limit) }; - if rc != 0 { - return Err(miette::miette!( - "Failed to set RLIMIT_NPROC: {}", - std::io::Error::last_os_error() - )); - } + setrlimit( + Resource::Nproc, + Rlimit { + current: Some(512), + maximum: Some(512), + }, + ) + .map_err(|e| miette::miette!("Failed to set RLIMIT_NPROC: {e}"))?; #[cfg(target_os = "linux")] { - let rc = unsafe { libc::prctl(libc::PR_SET_DUMPABLE, 0, 0, 0, 0) }; - if rc != 0 { - return Err(miette::miette!( - "Failed to set PR_SET_DUMPABLE=0: {}", - std::io::Error::last_os_error() - )); - } + use rustix::process::{DumpableBehavior, set_dumpable_behavior}; + set_dumpable_behavior(DumpableBehavior::NotDumpable) + .map_err(|e| miette::miette!("Failed to set PR_SET_DUMPABLE=0: {e}"))?; } Ok(()) @@ -159,9 +144,17 @@ impl ProcessHandle { .stdout(Stdio::inherit()) .stderr(Stdio::inherit()) .kill_on_drop(true) - .env("OPENSHELL_SANDBOX", "1"); + .env(openshell_core::sandbox_env::SANDBOX, "1"); + + // Strip supervisor-only credentials from the entrypoint's inherited + // environment. The entrypoint drops to the sandbox user before + // `exec`; without this strip, anything running as the sandbox user + // (e.g. an SSH-spawned shell) could read /proc//environ + // and recover the gateway-minted JWT. Issue #1354. + cmd.env_remove(openshell_core::sandbox_env::SANDBOX_TOKEN) + .env_remove(openshell_core::sandbox_env::SANDBOX_TOKEN_FILE) + .env_remove(openshell_core::sandbox_env::K8S_SA_TOKEN_FILE); - scrub_sensitive_env(&mut cmd); inject_provider_env(&mut cmd, provider_env); if let Some(dir) = workdir { @@ -286,9 +279,17 @@ impl ProcessHandle { .stdout(Stdio::inherit()) .stderr(Stdio::inherit()) .kill_on_drop(true) - .env("OPENSHELL_SANDBOX", "1"); + .env(openshell_core::sandbox_env::SANDBOX, "1"); + + // Strip supervisor-only credentials from the entrypoint's inherited + // environment. 
The entrypoint drops to the sandbox user before + // `exec`; without this strip, anything running as the sandbox user + // (e.g. an SSH-spawned shell) could read /proc//environ + // and recover the gateway-minted JWT. Issue #1354. + cmd.env_remove(openshell_core::sandbox_env::SANDBOX_TOKEN) + .env_remove(openshell_core::sandbox_env::SANDBOX_TOKEN_FILE) + .env_remove(openshell_core::sandbox_env::K8S_SA_TOKEN_FILE); - scrub_sensitive_env(&mut cmd); inject_provider_env(&mut cmd, provider_env); if let Some(dir) = workdir { @@ -804,21 +805,6 @@ mod tests { assert_eq!(probe_hardened_child(dumpable_flag_probe), 0); } - #[tokio::test] - async fn scrub_sensitive_env_removes_ssh_handshake_secret() { - let mut cmd = Command::new("/usr/bin/env"); - cmd.stdin(StdStdio::null()) - .stdout(StdStdio::piped()) - .stderr(StdStdio::null()) - .env(SSH_HANDSHAKE_SECRET_ENV, "super-secret"); - - scrub_sensitive_env(&mut cmd); - - let output = cmd.output().await.expect("spawn env"); - let stdout = String::from_utf8(output.stdout).expect("utf8"); - assert!(!stdout.contains(SSH_HANDSHAKE_SECRET_ENV)); - } - #[tokio::test] async fn inject_provider_env_sets_placeholder_values() { let mut cmd = Command::new("/usr/bin/env"); diff --git a/crates/openshell-sandbox/src/procfs.rs b/crates/openshell-sandbox/src/procfs.rs index e02e850b5..3ac8dbe14 100644 --- a/crates/openshell-sandbox/src/procfs.rs +++ b/crates/openshell-sandbox/src/procfs.rs @@ -762,6 +762,17 @@ mod tests { use std::net::{TcpListener, TcpStream}; use std::time::{Duration, Instant}; + struct ChildGuard(libc::pid_t); + impl Drop for ChildGuard { + fn drop(&mut self) { + #[allow(unsafe_code)] + unsafe { + libc::kill(self.0, libc::SIGKILL); + libc::waitpid(self.0, std::ptr::null_mut(), 0); + } + } + } + let listener = TcpListener::bind("127.0.0.1:0").expect("bind listener"); let listener_port = listener.local_addr().unwrap().port(); let stream = TcpStream::connect(("127.0.0.1", listener_port)).expect("connect"); @@ -781,9 +792,10 @@ 
mod tests { } } + let _guard = ChildGuard(child_pid); let child_pid_u32 = child_pid.cast_unsigned(); let entrypoint_pid = std::process::id(); - let deadline = Instant::now() + Duration::from_secs(2); + let deadline = Instant::now() + Duration::from_secs(5); let owners = loop { let owners = resolve_tcp_peer_socket_owners(entrypoint_pid, peer_port) .expect("resolve socket owners"); @@ -802,13 +814,6 @@ mod tests { std::thread::sleep(Duration::from_millis(20)); }; - // libc/syscall FFI requires unsafe - #[allow(unsafe_code)] - unsafe { - libc::kill(child_pid, libc::SIGKILL); - libc::waitpid(child_pid, std::ptr::null_mut(), 0); - } - let owner_pids = owners .owners .iter() diff --git a/crates/openshell-sandbox/src/provider_credentials.rs b/crates/openshell-sandbox/src/provider_credentials.rs new file mode 100644 index 000000000..ae91e8d6e --- /dev/null +++ b/crates/openshell-sandbox/src/provider_credentials.rs @@ -0,0 +1,229 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Runtime provider credential snapshots. 
+ +use crate::secrets::SecretResolver; +use std::collections::{HashMap, VecDeque}; +use std::sync::{Arc, RwLock}; + +const MAX_RETAINED_CREDENTIAL_GENERATIONS: usize = 8; + +#[derive(Debug, Clone, Default)] +pub struct ProviderCredentialSnapshot { + pub revision: u64, + pub child_env: HashMap<String, String>, +} + +#[derive(Debug)] +struct ProviderCredentialStateInner { + current: Arc<ProviderCredentialSnapshot>, + generations: VecDeque<Arc<SecretResolver>>, + current_resolver: Option<Arc<SecretResolver>>, + combined_resolver: Option<Arc<SecretResolver>>, +} + +#[derive(Debug, Clone)] +pub struct ProviderCredentialState { + inner: Arc<RwLock<ProviderCredentialStateInner>>, +} + +impl ProviderCredentialState { + pub fn from_environment( + revision: u64, + env: HashMap<String, String>, + credential_expires_at_ms: HashMap<String, i64>, + ) -> Self { + let (child_env, generation_resolver, current_resolver) = + SecretResolver::from_provider_env_for_current_revision( + env, + credential_expires_at_ms, + revision, + ); + let snapshot = Arc::new(ProviderCredentialSnapshot { + revision, + child_env, + }); + let generations: VecDeque<_> = generation_resolver.map(Arc::new).into_iter().collect(); + let current_resolver = current_resolver.map(Arc::new); + let combined_resolver = merge_resolvers(&generations, current_resolver.as_ref()); + + Self { + inner: Arc::new(RwLock::new(ProviderCredentialStateInner { + current: snapshot, + generations, + current_resolver, + combined_resolver, + })), + } + } + + pub fn snapshot(&self) -> Arc<ProviderCredentialSnapshot> { + self.inner + .read() + .expect("provider credential state poisoned") + .current + .clone() + } + + pub fn resolver(&self) -> Option<Arc<SecretResolver>> { + self.inner + .read() + .expect("provider credential state poisoned") + .combined_resolver + .clone() + } + + pub fn install_environment( + &self, + revision: u64, + env: HashMap<String, String>, + credential_expires_at_ms: HashMap<String, i64>, + ) -> usize { + let (child_env, generation_resolver, current_resolver) = + SecretResolver::from_provider_env_for_current_revision( + env, + credential_expires_at_ms, + revision, + ); + let mut inner = self + .inner + .write() + .expect("provider credential state
poisoned"); + + inner.current = Arc::new(ProviderCredentialSnapshot { + revision, + child_env, + }); + inner.current_resolver = current_resolver.map(Arc::new); + + if let Some(resolver) = generation_resolver { + inner.generations.push_back(Arc::new(resolver)); + while inner.generations.len() > MAX_RETAINED_CREDENTIAL_GENERATIONS { + inner.generations.pop_front(); + } + } + inner.combined_resolver = + merge_resolvers(&inner.generations, inner.current_resolver.as_ref()); + inner.current.child_env.len() + } +} + +fn merge_resolvers( + generations: &VecDeque<Arc<SecretResolver>>, + current_resolver: Option<&Arc<SecretResolver>>, +) -> Option<Arc<SecretResolver>> { + SecretResolver::merge( + generations + .iter() + .map(Arc::as_ref) + .chain(current_resolver.into_iter().map(Arc::as_ref)), + ) + .map(Arc::new) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn snapshots_use_revision_scoped_placeholders() { + let state = ProviderCredentialState::from_environment( + 10, + HashMap::from([("GITHUB_TOKEN".to_string(), "old".to_string())]), + HashMap::new(), + ); + let first = state.snapshot(); + assert_eq!( + first.child_env.get("GITHUB_TOKEN").map(String::as_str), + Some("openshell:resolve:env:v10_GITHUB_TOKEN") + ); + + state.install_environment( + 11, + HashMap::from([("GITHUB_TOKEN".to_string(), "new".to_string())]), + HashMap::new(), + ); + let second = state.snapshot(); + assert_eq!( + second.child_env.get("GITHUB_TOKEN").map(String::as_str), + Some("openshell:resolve:env:v11_GITHUB_TOKEN") + ); + + let resolver = state.resolver().expect("resolver"); + assert_eq!( + resolver.resolve_placeholder("openshell:resolve:env:v10_GITHUB_TOKEN"), + Some("old") + ); + assert_eq!( + resolver.resolve_placeholder("openshell:resolve:env:v11_GITHUB_TOKEN"), + Some("new") + ); + assert_eq!( + resolver.resolve_placeholder("openshell:resolve:env:GITHUB_TOKEN"), + Some("new") + ); + assert_eq!( + resolver.resolve_placeholder("provider-OPENSHELL-RESOLVE-ENV-GITHUB_TOKEN"), + Some("new") + ); + } + + #[test] + fn
empty_refresh_removes_current_aliases_but_retains_revisioned_resolver() { + let state = ProviderCredentialState::from_environment( + 10, + HashMap::from([("GITHUB_TOKEN".to_string(), "old".to_string())]), + HashMap::new(), + ); + + state.install_environment(11, HashMap::new(), HashMap::new()); + + assert!(state.snapshot().child_env.is_empty()); + let resolver = state.resolver().expect("old resolver retained"); + assert_eq!( + resolver.resolve_placeholder("openshell:resolve:env:v10_GITHUB_TOKEN"), + Some("old") + ); + assert_eq!( + resolver.resolve_placeholder("openshell:resolve:env:GITHUB_TOKEN"), + None + ); + assert_eq!( + resolver.resolve_placeholder("provider-OPENSHELL-RESOLVE-ENV-GITHUB_TOKEN"), + None + ); + } + + #[test] + fn expired_retained_generation_does_not_resolve() { + let now_ms = i64::try_from( + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_millis(), + ) + .unwrap(); + let state = ProviderCredentialState::from_environment( + 10, + HashMap::from([("GITHUB_TOKEN".to_string(), "old".to_string())]), + HashMap::from([("GITHUB_TOKEN".to_string(), now_ms - 1_000)]), + ); + + state.install_environment( + 11, + HashMap::from([("GITHUB_TOKEN".to_string(), "new".to_string())]), + HashMap::from([("GITHUB_TOKEN".to_string(), now_ms + 60_000)]), + ); + + let resolver = state.resolver().expect("resolver"); + assert_eq!( + resolver.resolve_placeholder("openshell:resolve:env:v10_GITHUB_TOKEN"), + None + ); + assert_eq!( + resolver.resolve_placeholder("openshell:resolve:env:v11_GITHUB_TOKEN"), + Some("new") + ); + } +} diff --git a/crates/openshell-sandbox/src/proxy.rs b/crates/openshell-sandbox/src/proxy.rs index 5344374ac..88deb1596 100644 --- a/crates/openshell-sandbox/src/proxy.rs +++ b/crates/openshell-sandbox/src/proxy.rs @@ -8,9 +8,11 @@ use crate::identity::BinaryIdentityCache; use crate::l7::tls::ProxyTlsState; use crate::opa::{NetworkAction, OpaEngine, PolicyGenerationGuard}; use crate::policy::ProxyPolicy; -use 
crate::secrets::{SecretResolver, rewrite_header_line}; +use crate::policy_local::{POLICY_LOCAL_HOST, PolicyLocalContext}; +use crate::provider_credentials::ProviderCredentialState; +use crate::secrets::{SecretResolver, rewrite_header_line_checked}; use miette::{IntoDiagnostic, Result}; -use openshell_core::net::{is_always_blocked_ip, is_internal_ip}; +use openshell_core::net::{is_always_blocked_ip, is_internal_ip, is_link_local_ip}; use openshell_ocsf::{ ActionId, ActivityId, DispositionId, Endpoint, HttpActivityBuilder, HttpRequest, NetworkActivityBuilder, Process, SeverityId, StatusId, Url as OcsfUrl, ocsf_emit, @@ -31,15 +33,41 @@ const MAX_HEADER_BYTES: usize = 8192; const INFERENCE_LOCAL_HOST: &str = "inference.local"; const INFERENCE_LOCAL_PORT: u16 = 443; +/// Hostnames injected by compute drivers as `/etc/hosts` aliases for the host +/// machine. Traffic to these names is eligible for the trusted-gateway SSRF +/// exemption when the resolved IP matches the driver-injected value read from +/// `/etc/hosts` at proxy startup. +const HOST_GATEWAY_ALIASES: &[&str] = &[ + "host.openshell.internal", + "host.containers.internal", + "host.docker.internal", +]; + +/// Cloud instance metadata IPs that are NEVER exempted from SSRF blocking, +/// even when they coincidentally match a host-gateway alias resolution. +/// This list covers the well-known IMDS endpoints across major cloud providers. +const CLOUD_METADATA_IPS: &[IpAddr] = &[ + // AWS / GCP / Azure instance metadata service + IpAddr::V4(std::net::Ipv4Addr::new(169, 254, 169, 254)), +]; + /// Maximum total bytes for a streaming inference response body (32 MiB). +#[cfg(not(test))] const MAX_STREAMING_BODY: usize = 32 * 1024 * 1024; +// Keep unit tests deterministic without pushing tens of MiB through loopback. +#[cfg(test)] +const MAX_STREAMING_BODY: usize = 1024; /// Idle timeout per chunk when relaying streaming inference responses. /// /// Reasoning models (e.g. 
nemotron-3-super, o1, o3) can pause for 60+ seconds /// between "thinking" and output phases. 120s provides headroom while still /// catching genuinely stuck streams. +#[cfg(not(test))] const CHUNK_IDLE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(120); +// Exercise idle-timeout truncation without slowing the full package test suite. +#[cfg(test)] +const CHUNK_IDLE_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(100); /// Result of a proxy CONNECT policy decision. struct ConnectDecision { @@ -147,7 +175,7 @@ impl ProxyHandle { /// The proxy uses OPA for network decisions with process-identity binding /// via `/proc/net/tcp`. All connections are evaluated through OPA policy. #[allow(clippy::too_many_arguments)] - pub async fn start_with_bind_addr( + pub(crate) async fn start_with_bind_addr( policy: &ProxyPolicy, bind_addr: Option, opa_engine: Arc, @@ -155,7 +183,8 @@ impl ProxyHandle { entrypoint_pid: Arc, tls_state: Option>, inference_ctx: Option>, - secret_resolver: Option>, + provider_credentials: Option, + policy_local_ctx: Option>, denial_tx: Option>, ) -> Result { // Use override bind_addr, fall back to policy http_addr, then default @@ -185,6 +214,18 @@ impl ProxyHandle { ocsf_emit!(event); } + // Detect the trusted host gateway IP from /etc/hosts before user code + // runs. This is read once at startup so later /etc/hosts modifications + // by sandbox workloads cannot influence the stored value. 
+ let trusted_host_gateway: Arc> = Arc::new(detect_trusted_host_gateway()); + if let Some(ref ip) = *trusted_host_gateway { + tracing::info!( + %ip, + "Trusted host gateway detected from /etc/hosts; \ + host-gateway aliases exempt from SSRF always-blocked check" + ); + } + let join = tokio::spawn(async move { loop { match listener.accept().await { @@ -194,11 +235,24 @@ impl ProxyHandle { let spid = entrypoint_pid.clone(); let tls = tls_state.clone(); let inf = inference_ctx.clone(); - let resolver = secret_resolver.clone(); + let policy_local = policy_local_ctx.clone(); + let gw = trusted_host_gateway.clone(); + let resolver = provider_credentials + .as_ref() + .and_then(ProviderCredentialState::resolver); let dtx = denial_tx.clone(); tokio::spawn(async move { if let Err(err) = handle_tcp_connection( - stream, opa, cache, spid, tls, inf, resolver, dtx, + stream, + opa, + cache, + spid, + tls, + inf, + policy_local, + gw, + resolver, + dtx, ) .await { @@ -313,6 +367,8 @@ async fn handle_tcp_connection( entrypoint_pid: Arc, tls_state: Option>, inference_ctx: Option>, + policy_local_ctx: Option>, + trusted_host_gateway: Arc>, secret_resolver: Option>, denial_tx: Option>, ) -> Result<()> { @@ -357,6 +413,8 @@ async fn handle_tcp_connection( opa_engine, identity_cache, entrypoint_pid, + policy_local_ctx, + trusted_host_gateway, secret_resolver, denial_tx.as_ref(), ) @@ -511,7 +569,63 @@ async fn handle_tcp_connection( // The "non-empty" branch is the explicit-allowlist path; reading it first // matches the policy decision narrative. #[allow(clippy::if_not_else)] - let mut upstream = if !raw_allowed_ips.is_empty() { + let mut upstream = if is_host_gateway_alias(&host_lc) + && let Some(gw) = *trusted_host_gateway + { + // Trusted host-gateway path. The compute driver injected this hostname + // into /etc/hosts pointing at a known IP (read at proxy startup before + // user code runs). 
Bypass the normal SSRF tiers so link-local gateway + // addresses (used by rootless Podman with pasta) are not hard-blocked. + // Cloud metadata IPs and control-plane ports are still rejected. + match resolve_and_check_trusted_gateway(&host, port, gw, sandbox_entrypoint_pid).await { + Ok(addrs) => TcpStream::connect(addrs.as_slice()) + .await + .into_diagnostic()?, + Err(reason) => { + { + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Open) + .action(ActionId::Denied) + .disposition(DispositionId::Blocked) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .dst_endpoint(Endpoint::from_domain(&host_lc, port)) + .src_endpoint_addr(peer_addr.ip(), peer_addr.port()) + .actor_process( + Process::from_bypass(&binary_str, &pid_str, &ancestors_str) + .with_cmd_line(&cmdline_str), + ) + .firewall_rule("-", "ssrf") + .message(format!( + "CONNECT blocked: trusted-gateway check failed for {host_lc}:{port}" + )) + .status_detail(&reason) + .build(); + ocsf_emit!(event); + } + emit_denial( + &denial_tx, + &host_lc, + port, + &binary_str, + &decision, + &reason, + "ssrf", + ); + respond( + &mut client, + &build_json_error_response( + 403, + "Forbidden", + "ssrf_denied", + &format!("CONNECT {host_lc}:{port} blocked: trusted-gateway check failed"), + ), + ) + .await?; + return Ok(()); + } + } + } else if !raw_allowed_ips.is_empty() { // allowed_ips mode: validate resolved IPs against CIDR allowlist. // Loopback and link-local are still always blocked. match parse_allowed_ips(&raw_allowed_ips) { @@ -1798,6 +1912,149 @@ fn normalize_host_lookup_key(host: &str) -> &str { .unwrap_or(host) } +/// Returns `true` if `host` is one of the well-known driver-injected aliases +/// for the host machine (e.g. `host.openshell.internal`). 
+fn is_host_gateway_alias(host: &str) -> bool { + let h = normalize_host_lookup_key(host); + HOST_GATEWAY_ALIASES + .iter() + .any(|alias| alias.eq_ignore_ascii_case(h)) +} + +/// Returns `true` if `ip` is a known cloud instance metadata endpoint that +/// must never be exempted from SSRF blocking. +/// +/// IPv4-mapped IPv6 addresses (e.g. `::ffff:169.254.169.254`) are normalized +/// to their embedded IPv4 representation before comparison, so the invariant +/// holds regardless of how the address is represented. +fn is_cloud_metadata_ip(ip: IpAddr) -> bool { + match ip { + IpAddr::V4(_) => CLOUD_METADATA_IPS.contains(&ip), + IpAddr::V6(v6) => v6 + .to_ipv4_mapped() + .is_some_and(|v4| CLOUD_METADATA_IPS.contains(&IpAddr::V4(v4))), + } +} + +/// Read the proxy's own `/etc/hosts` at startup and return the IP mapped to +/// `host.openshell.internal`, if present and safe. +/// +/// This is called once before user code runs, so the returned value is immune +/// to later `/etc/hosts` tampering by sandbox workloads. Returns `None` if no +/// entry exists, the entry cannot be parsed, or the mapped IP is a cloud +/// metadata address. +#[cfg(any(target_os = "linux", test))] +fn detect_trusted_host_gateway() -> Option { + let contents = std::fs::read_to_string("/etc/hosts").ok()?; + let ips = parse_hosts_file_for_host(&contents, "host.openshell.internal"); + + // Multiple distinct IPs for the alias is unexpected — compute drivers + // always inject exactly one. Warn loudly so operators can diagnose the + // inconsistency; we still proceed with the first entry rather than + // disabling the exemption entirely, because the mismatch guard in + // resolve_and_check_trusted_gateway() will reject any runtime resolution + // that returns a different IP. + if ips.len() > 1 { + warn!( + ips = ?ips, + "host.openshell.internal has {} distinct IPs in /etc/hosts; \ + expected exactly one. Using first entry. 
\ + Connections resolving to any other IP will be rejected.", + ips.len() + ); + } + + let ip = ips.into_iter().next()?; + + if is_cloud_metadata_ip(ip) { + warn!( + %ip, + "host.openshell.internal resolves to a cloud metadata IP; \ + trusted-gateway SSRF exemption disabled" + ); + return None; + } + // The exemption exists solely for link-local IPs used by rootless Podman + // with pasta. Private RFC 1918 addresses (e.g. Docker bridge 172.17.0.1, + // Kubernetes node 192.168.x.x), loopback, unspecified, and all other + // non-link-local addresses are never legitimate candidates for the + // link-local SSRF exemption — they must fall through to the normal + // allowed_ips / resolve_and_reject_internal() enforcement path. + if !is_link_local_ip(ip) { + warn!( + %ip, + "host.openshell.internal maps to a non-link-local IP; \ + trusted-gateway SSRF exemption disabled" + ); + return None; + } + Some(ip) +} + +#[cfg(not(any(target_os = "linux", test)))] +fn detect_trusted_host_gateway() -> Option { + None +} + +/// Resolve `host:port` and validate that every resolved address matches the +/// trusted host gateway IP. +/// +/// This bypasses the normal SSRF tiers (always-blocked and internal-IP) for +/// driver-injected host-gateway aliases, allowing link-local addresses used +/// by rootless Podman with pasta without opening up arbitrary link-local or +/// cloud metadata access. 
+/// +/// Rejects: +/// - Any resolved IP that is a cloud metadata address (defense-in-depth) +/// - Any resolved IP that does not match `trusted_gw` (prevents /etc/hosts tampering) +/// - Control-plane ports (etcd, K8s API, kubelet) regardless of IP +async fn resolve_and_check_trusted_gateway( + host: &str, + port: u16, + trusted_gw: IpAddr, + entrypoint_pid: u32, +) -> std::result::Result, String> { + if BLOCKED_CONTROL_PLANE_PORTS.contains(&port) { + return Err(format!( + "port {port} is a blocked control-plane port, connection rejected" + )); + } + let addrs = resolve_socket_addrs(host, port, entrypoint_pid).await?; + if addrs.is_empty() { + return Err(format!( + "DNS resolution returned no addresses for {}", + normalize_host_lookup_key(host) + )); + } + for addr in &addrs { + if is_cloud_metadata_ip(addr.ip()) { + return Err(format!( + "{host} resolves to cloud metadata address {}, connection rejected", + addr.ip() + )); + } + if addr.ip() != trusted_gw { + return Err(format!( + "{host} resolves to {} which does not match trusted host gateway \ + {trusted_gw}, connection rejected", + addr.ip() + )); + } + // Defense-in-depth: even if the resolved IP matches trusted_gw, reject + // any non-link-local address. detect_trusted_host_gateway() already + // enforces this at startup, but we re-check here to guard against any + // unanticipated code path that might admit a private or loopback IP. 
+ if !is_link_local_ip(addr.ip()) { + return Err(format!( + "{host} resolves to non-link-local address {}, \ + connection rejected", + addr.ip() + )); + } + } + Ok(addrs) +} + fn resolve_ip_literal(host: &str, port: u16) -> Option> { normalize_host_lookup_key(host) .parse::() @@ -2261,11 +2518,17 @@ fn rewrite_forward_request( used: usize, path: &str, secret_resolver: Option<&SecretResolver>, + request_body_credential_rewrite: bool, ) -> Result, crate::secrets::UnresolvedPlaceholderError> { let header_end = raw[..used] .windows(4) .position(|w| w == b"\r\n\r\n") .map_or(used, |p| p + 4); + let websocket_upgrade = crate::l7::rest::request_is_websocket_upgrade(&raw[..header_end]); + let upstream_path = match secret_resolver { + Some(resolver) => crate::secrets::rewrite_target_for_eval(path, resolver)?.resolved, + None => path.to_string(), + }; let header_str = String::from_utf8_lossy(&raw[..header_end]); let lines = header_str.split("\r\n").collect::>(); @@ -2282,7 +2545,7 @@ fn rewrite_forward_request( if parts.len() == 3 { output.extend_from_slice(parts[0].as_bytes()); output.push(b' '); - output.extend_from_slice(path.as_bytes()); + output.extend_from_slice(upstream_path.as_bytes()); output.push(b' '); output.extend_from_slice(parts[2].as_bytes()); } else { @@ -2309,14 +2572,19 @@ fn rewrite_forward_request( // Replace Connection header if lower.starts_with("connection:") { has_connection = true; + if websocket_upgrade { + output.extend_from_slice(line.as_bytes()); + output.extend_from_slice(b"\r\n"); + continue; + } output.extend_from_slice(b"Connection: close\r\n"); continue; } - let rewritten_line = secret_resolver.map_or_else( - || line.to_string(), - |resolver| rewrite_header_line(line, resolver), - ); + let rewritten_line = match secret_resolver { + Some(resolver) => rewrite_header_line_checked(line, resolver)?, + None => line.to_string(), + }; output.extend_from_slice(rewritten_line.as_bytes()); output.extend_from_slice(b"\r\n"); @@ -2327,7 +2595,7 @@ fn 
rewrite_forward_request( } // Inject missing headers - if !has_connection { + if !has_connection && !websocket_upgrade { output.extend_from_slice(b"Connection: close\r\n"); } if !has_via { @@ -2336,6 +2604,7 @@ fn rewrite_forward_request( // End of headers output.extend_from_slice(b"\r\n"); + let rewritten_header_end = output.len(); // Append any overflow body bytes from the original buffer if header_end < used { @@ -2344,8 +2613,15 @@ fn rewrite_forward_request( // Fail-closed: scan for any remaining unresolved placeholders if secret_resolver.is_some() { - let output_str = String::from_utf8_lossy(&output); - if output_str.contains(crate::secrets::PLACEHOLDER_PREFIX_PUBLIC) { + let scan_end = if request_body_credential_rewrite { + rewritten_header_end + } else { + output.len() + }; + let output_str = String::from_utf8_lossy(&output[..scan_end]); + if output_str.contains(crate::secrets::PLACEHOLDER_PREFIX_PUBLIC) + || output_str.contains(crate::secrets::PROVIDER_ALIAS_MARKER_PUBLIC) + { return Err(crate::secrets::UnresolvedPlaceholderError { location: "header" }); } } @@ -2353,13 +2629,20 @@ fn rewrite_forward_request( Ok(output) } +struct ForwardRelayOptions<'a> { + generation_guard: &'a PolicyGenerationGuard, + websocket_extensions: crate::l7::rest::WebSocketExtensionMode, + secret_resolver: Option<&'a SecretResolver>, + request_body_credential_rewrite: bool, +} + async fn relay_rewritten_forward_request( method: &str, path: &str, rewritten: Vec, client: &mut C, upstream: &mut U, - generation_guard: &PolicyGenerationGuard, + options: ForwardRelayOptions<'_>, ) -> Result where C: TokioAsyncRead + TokioAsyncWrite + Unpin, @@ -2380,12 +2663,16 @@ where body_length, }; - crate::l7::rest::relay_http_request_with_resolver_guarded( + crate::l7::rest::relay_http_request_with_options_guarded( &req, client, upstream, - None, - Some(generation_guard), + crate::l7::rest::RelayRequestOptions { + resolver: options.secret_resolver, + generation_guard: 
Some(options.generation_guard), + websocket_extensions: options.websocket_extensions, + request_body_credential_rewrite: options.request_body_credential_rewrite, + }, ) .await } @@ -2408,6 +2695,8 @@ async fn handle_forward_proxy( opa_engine: Arc, identity_cache: Arc, entrypoint_pid: Arc, + policy_local_ctx: Option>, + trusted_host_gateway: Arc>, secret_resolver: Option>, denial_tx: Option<&mpsc::UnboundedSender>, ) -> Result<()> { @@ -2431,6 +2720,38 @@ async fn handle_forward_proxy( }; let host_lc = host.to_ascii_lowercase(); + if host_lc == POLICY_LOCAL_HOST { + if scheme != "http" || port != 80 { + respond( + client, + &build_json_error_response( + 400, + "Bad Request", + "invalid_policy_local_scheme", + "Use http://policy.local only", + ), + ) + .await?; + return Ok(()); + } + if let Some(ctx) = policy_local_ctx { + return crate::policy_local::handle_forward_request( + &ctx, + method, + &path, + &buf[..used], + client, + ) + .await; + } + respond( + client, + b"HTTP/1.1 503 Service Unavailable\r\nContent-Length: 31\r\n\r\npolicy.local is not configured", + ) + .await?; + return Ok(()); + } + // 2. 
Reject HTTPS — must use CONNECT for TLS if scheme == "https" { { @@ -2574,6 +2895,35 @@ async fn handle_forward_proxy( }; let mut forward_request_bytes = buf[..used].to_vec(); let mut upstream_target = path.clone(); + let mut websocket_extensions = crate::l7::rest::WebSocketExtensionMode::Preserve; + let mut forward_tunnel_engine: Option = None; + let mut forward_upgrade_config: Option = None; + let mut forward_upgrade_target = String::new(); + let mut forward_upgrade_query_params = std::collections::HashMap::new(); + let mut forward_websocket_request = + crate::l7::rest::request_is_websocket_upgrade(&forward_request_bytes); + let mut request_body_credential_rewrite = false; + let l7_ctx = crate::l7::relay::L7EvalContext { + host: host_lc.clone(), + port, + policy_name: matched_policy.clone().unwrap_or_default(), + binary_path: decision + .binary + .as_ref() + .map(|p| p.to_string_lossy().into_owned()) + .unwrap_or_default(), + ancestors: decision + .ancestors + .iter() + .map(|p| p.to_string_lossy().into_owned()) + .collect(), + cmdline_paths: decision + .cmdline_paths + .iter() + .map(|p| p.to_string_lossy().into_owned()) + .collect(), + secret_resolver: secret_resolver.clone(), + }; // 4b. If the endpoint has L7 config, evaluate the request against // L7 policy. The forward proxy handles exactly one request per @@ -2621,28 +2971,6 @@ async fn handle_forward_proxy( } }; - let l7_ctx = crate::l7::relay::L7EvalContext { - host: host_lc.clone(), - port, - policy_name: matched_policy.clone().unwrap_or_default(), - binary_path: decision - .binary - .as_ref() - .map(|p| p.to_string_lossy().into_owned()) - .unwrap_or_default(), - ancestors: decision - .ancestors - .iter() - .map(|p| p.to_string_lossy().into_owned()) - .collect(), - cmdline_paths: decision - .cmdline_paths - .iter() - .map(|p| p.to_string_lossy().into_owned()) - .collect(), - secret_resolver: secret_resolver.clone(), - }; - // Canonicalize the request-target. 
The canonical form is fed to OPA // AND reassigned to the outer `path` variable so the later call to // `rewrite_forward_request` writes canonical bytes to the upstream. @@ -2711,6 +3039,14 @@ async fn handle_forward_proxy( .await?; return Ok(()); }; + forward_websocket_request = + crate::l7::rest::request_is_websocket_upgrade(&forward_request_bytes); + websocket_extensions = crate::l7::relay::websocket_extension_mode(&l7_config.config); + request_body_credential_rewrite = l7_config.config.protocol == crate::l7::L7Protocol::Rest + && l7_config.config.request_body_credential_rewrite; + forward_upgrade_config = Some(l7_config.config.clone()); + forward_upgrade_target = path.clone(); + forward_upgrade_query_params = query_params.clone(); let graphql = if l7_config.config.protocol == crate::l7::L7Protocol::Graphql { let header_end = forward_request_bytes .windows(4) @@ -2871,9 +3207,12 @@ async fn handle_forward_proxy( .await?; return Ok(()); } + forward_tunnel_engine = Some(tunnel_engine); } // 5. DNS resolution + SSRF defence (mirrors the CONNECT path logic). + // - If the host is a driver-injected host-gateway alias: bypass SSRF + // tiers and validate only against the trusted gateway IP. // - If allowed_ips is set: validate resolved IPs against the allowlist // (this is the SSRF override for private IP destinations). // - If allowed_ips is empty: reject internal IPs, allow public IPs through. @@ -2884,70 +3223,18 @@ async fn handle_forward_proxy( raw_allowed_ips = implicit_allowed_ips_for_ip_host(&host); } - // The "non-empty" branch is the explicit-allowlist path; reading it first - // matches the policy decision narrative. + // The trusted-gateway branch is the first path; reading it before the + // allowed_ips and default branches matches the policy decision narrative. #[allow(clippy::if_not_else)] - let addrs = - if !raw_allowed_ips.is_empty() { - // allowed_ips mode: validate resolved IPs against CIDR allowlist. 
- match parse_allowed_ips(&raw_allowed_ips) { - Ok(nets) => { - match resolve_and_check_allowed_ips(&host, port, &nets, sandbox_entrypoint_pid) - .await - { - Ok(addrs) => addrs, - Err(reason) => { - { - let event = HttpActivityBuilder::new(crate::ocsf_ctx()) - .activity(ActivityId::Other) - .action(ActionId::Denied) - .disposition(DispositionId::Blocked) - .severity(SeverityId::Medium) - .status(StatusId::Failure) - .http_request(HttpRequest::new( - method, - OcsfUrl::new("http", &host_lc, &path, port), - )) - .dst_endpoint(Endpoint::from_domain(&host_lc, port)) - .src_endpoint(Endpoint::from_ip(peer_addr.ip(), peer_addr.port())) - .actor_process( - Process::from_bypass(&binary_str, &pid_str, &ancestors_str) - .with_cmd_line(&cmdline_str), - ) - .firewall_rule(policy_str, "ssrf") - .message(format!( - "FORWARD blocked: allowed_ips check failed for {host_lc}:{port}" - )) - .status_detail(&reason) - .build(); - ocsf_emit!(event); - } - emit_denial_simple( - denial_tx, - &host_lc, - port, - &binary_str, - &decision, - &reason, - "ssrf", - ); - respond( - client, - &build_json_error_response( - 403, - "Forbidden", - "ssrf_denied", - &format!("{method} {host_lc}:{port} blocked: allowed_ips check failed"), - ), - ) - .await?; - return Ok(()); - } - } - } - Err(reason) => { - { - let event = HttpActivityBuilder::new(crate::ocsf_ctx()) + let addrs = if is_host_gateway_alias(&host_lc) + && let Some(gw) = *trusted_host_gateway + { + // Trusted host-gateway path. Mirrors the CONNECT path logic. 
+ match resolve_and_check_trusted_gateway(&host, port, gw, sandbox_entrypoint_pid).await { + Ok(addrs) => addrs, + Err(reason) => { + { + let event = HttpActivityBuilder::new(crate::ocsf_ctx()) .activity(ActivityId::Other) .action(ActionId::Denied) .disposition(DispositionId::Blocked) @@ -2965,89 +3252,196 @@ async fn handle_forward_proxy( ) .firewall_rule(policy_str, "ssrf") .message(format!( - "FORWARD blocked: invalid allowed_ips in policy for {host_lc}:{port}" + "FORWARD blocked: trusted-gateway check failed for {host_lc}:{port}" )) .status_detail(&reason) .build(); - ocsf_emit!(event); - } - emit_denial_simple( - denial_tx, - &host_lc, - port, - &binary_str, - &decision, - &reason, - "ssrf", - ); - respond( - client, - &build_json_error_response( - 403, - "Forbidden", - "ssrf_denied", - &format!( - "{method} {host_lc}:{port} blocked: invalid allowed_ips in policy" - ), - ), - ) - .await?; - return Ok(()); + ocsf_emit!(event); } + emit_denial_simple( + denial_tx, + &host_lc, + port, + &binary_str, + &decision, + &reason, + "ssrf", + ); + respond( + client, + &build_json_error_response( + 403, + "Forbidden", + "ssrf_denied", + &format!("{method} {host_lc}:{port} blocked: trusted-gateway check failed"), + ), + ) + .await?; + return Ok(()); } - } else { - // No allowed_ips: reject internal IPs, allow public IPs through. 
- match resolve_and_reject_internal(&host, port, sandbox_entrypoint_pid).await { - Ok(addrs) => addrs, - Err(reason) => { - { - let event = HttpActivityBuilder::new(crate::ocsf_ctx()) - .activity(ActivityId::Other) - .action(ActionId::Denied) - .disposition(DispositionId::Blocked) - .severity(SeverityId::Medium) - .status(StatusId::Failure) - .http_request(HttpRequest::new( - method, - OcsfUrl::new("http", &host_lc, &path, port), - )) - .dst_endpoint(Endpoint::from_domain(&host_lc, port)) - .src_endpoint(Endpoint::from_ip(peer_addr.ip(), peer_addr.port())) - .actor_process( - Process::from_bypass(&binary_str, &pid_str, &ancestors_str) - .with_cmd_line(&cmdline_str), - ) - .firewall_rule(policy_str, "ssrf") - .message(format!( - "FORWARD blocked: internal IP without allowed_ips for {host_lc}:{port}" - )) - .status_detail(&reason) - .build(); - ocsf_emit!(event); - } - emit_denial_simple( - denial_tx, - &host_lc, - port, - &binary_str, - &decision, - &reason, - "ssrf", - ); - respond( - client, - &build_json_error_response( - 403, - "Forbidden", - "ssrf_denied", - &format!("{method} {host_lc}:{port} blocked: internal address"), - ), - ) - .await?; - return Ok(()); + } + } else if !raw_allowed_ips.is_empty() { + // allowed_ips mode: validate resolved IPs against CIDR allowlist. 
+ match parse_allowed_ips(&raw_allowed_ips) { + Ok(nets) => { + match resolve_and_check_allowed_ips(&host, port, &nets, sandbox_entrypoint_pid) + .await + { + Ok(addrs) => addrs, + Err(reason) => { + { + let event = HttpActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Other) + .action(ActionId::Denied) + .disposition(DispositionId::Blocked) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .http_request(HttpRequest::new( + method, + OcsfUrl::new("http", &host_lc, &path, port), + )) + .dst_endpoint(Endpoint::from_domain(&host_lc, port)) + .src_endpoint(Endpoint::from_ip(peer_addr.ip(), peer_addr.port())) + .actor_process( + Process::from_bypass(&binary_str, &pid_str, &ancestors_str) + .with_cmd_line(&cmdline_str), + ) + .firewall_rule(policy_str, "ssrf") + .message(format!( + "FORWARD blocked: allowed_ips check failed for {host_lc}:{port}" + )) + .status_detail(&reason) + .build(); + ocsf_emit!(event); + } + emit_denial_simple( + denial_tx, + &host_lc, + port, + &binary_str, + &decision, + &reason, + "ssrf", + ); + respond( + client, + &build_json_error_response( + 403, + "Forbidden", + "ssrf_denied", + &format!( + "{method} {host_lc}:{port} blocked: allowed_ips check failed" + ), + ), + ) + .await?; + return Ok(()); + } } } - }; + Err(reason) => { + { + let event = HttpActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Other) + .action(ActionId::Denied) + .disposition(DispositionId::Blocked) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .http_request(HttpRequest::new( + method, + OcsfUrl::new("http", &host_lc, &path, port), + )) + .dst_endpoint(Endpoint::from_domain(&host_lc, port)) + .src_endpoint(Endpoint::from_ip(peer_addr.ip(), peer_addr.port())) + .actor_process( + Process::from_bypass(&binary_str, &pid_str, &ancestors_str) + .with_cmd_line(&cmdline_str), + ) + .firewall_rule(policy_str, "ssrf") + .message(format!( + "FORWARD blocked: invalid allowed_ips in policy for {host_lc}:{port}" + )) + 
.status_detail(&reason) + .build(); + ocsf_emit!(event); + } + emit_denial_simple( + denial_tx, + &host_lc, + port, + &binary_str, + &decision, + &reason, + "ssrf", + ); + respond( + client, + &build_json_error_response( + 403, + "Forbidden", + "ssrf_denied", + &format!( + "{method} {host_lc}:{port} blocked: invalid allowed_ips in policy" + ), + ), + ) + .await?; + return Ok(()); + } + } + } else { + // No allowed_ips: reject internal IPs, allow public IPs through. + match resolve_and_reject_internal(&host, port, sandbox_entrypoint_pid).await { + Ok(addrs) => addrs, + Err(reason) => { + { + let event = HttpActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Other) + .action(ActionId::Denied) + .disposition(DispositionId::Blocked) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .http_request(HttpRequest::new( + method, + OcsfUrl::new("http", &host_lc, &path, port), + )) + .dst_endpoint(Endpoint::from_domain(&host_lc, port)) + .src_endpoint(Endpoint::from_ip(peer_addr.ip(), peer_addr.port())) + .actor_process( + Process::from_bypass(&binary_str, &pid_str, &ancestors_str) + .with_cmd_line(&cmdline_str), + ) + .firewall_rule(policy_str, "ssrf") + .message(format!( + "FORWARD blocked: internal IP without allowed_ips for {host_lc}:{port}" + )) + .status_detail(&reason) + .build(); + ocsf_emit!(event); + } + emit_denial_simple( + denial_tx, + &host_lc, + port, + &binary_str, + &decision, + &reason, + "ssrf", + ); + respond( + client, + &build_json_error_response( + 403, + "Forbidden", + "ssrf_denied", + &format!("{method} {host_lc}:{port} blocked: internal address"), + ), + ) + .await?; + return Ok(()); + } + } + }; if let Err(e) = forward_generation_guard.ensure_current() { emit_l7_tunnel_close_after_policy_change(&host_lc, port, e); @@ -3131,6 +3525,7 @@ async fn handle_forward_proxy( forward_request_bytes.len(), &upstream_target, secret_resolver.as_deref(), + request_body_credential_rewrite, ) { Ok(bytes) => bytes, Err(e) => { @@ -3173,11 
+3568,47 @@ async fn handle_forward_proxy( rewritten, client, &mut upstream, - &forward_generation_guard, + ForwardRelayOptions { + generation_guard: &forward_generation_guard, + websocket_extensions, + secret_resolver: secret_resolver.as_deref(), + request_body_credential_rewrite, + }, ) .await?; - if let crate::l7::provider::RelayOutcome::Upgraded { overflow } = outcome { - crate::l7::relay::handle_upgrade(client, &mut upstream, overflow, &host_lc, port).await?; + if let crate::l7::provider::RelayOutcome::Upgraded { + overflow, + websocket_permessage_deflate, + } = outcome + { + let mut upgrade_options = if let (Some(config), Some(engine)) = ( + forward_upgrade_config.as_ref(), + forward_tunnel_engine.as_ref(), + ) { + crate::l7::relay::upgrade_options( + config, + &l7_ctx, + forward_websocket_request, + &forward_upgrade_target, + &forward_upgrade_query_params, + Some(engine), + ) + } else { + crate::l7::relay::UpgradeRelayOptions { + websocket_request: forward_websocket_request, + ..Default::default() + } + }; + upgrade_options.websocket.permessage_deflate = websocket_permessage_deflate; + crate::l7::relay::handle_upgrade( + client, + &mut upstream, + overflow, + &host_lc, + port, + upgrade_options, + ) + .await?; } Ok(()) @@ -3248,7 +3679,477 @@ fn is_benign_relay_error(err: &miette::Report) -> bool { )] mod tests { use super::*; + use std::future::Future; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; + use std::sync::Arc; + use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt}; + use tokio::net::{TcpListener, TcpStream}; + + fn websocket_l7_config( + protocol: crate::l7::L7Protocol, + websocket_credential_rewrite: bool, + ) -> crate::l7::L7EndpointConfig { + crate::l7::L7EndpointConfig { + protocol, + path: "/**".to_string(), + tls: crate::l7::TlsMode::Auto, + enforcement: crate::l7::EnforcementMode::Enforce, + graphql_max_body_bytes: crate::l7::graphql::DEFAULT_MAX_BODY_BYTES, + allow_encoded_slash: false, + websocket_credential_rewrite, + 
request_body_credential_rewrite: false, + websocket_graphql_policy: false, + } + } + + fn forward_test_guard() -> PolicyGenerationGuard { + let policy = include_str!("../data/sandbox-policy.rego"); + let policy_data = "network_policies: {}\n"; + let engine = OpaEngine::from_strings(policy, policy_data).unwrap(); + engine + .generation_guard(engine.current_generation()) + .unwrap() + } + + async fn relay_forward_request_and_capture( + method: &str, + path: &str, + raw: &[u8], + resolver: Option<&SecretResolver>, + request_body_credential_rewrite: bool, + ) -> Result<String> { + let guard = forward_test_guard(); + let rewritten = rewrite_forward_request( + raw, + raw.len(), + path, + resolver, + request_body_credential_rewrite, + ) + .map_err(|e| miette::miette!("{e}"))?; + let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); + let (mut _app_side, mut proxy_to_client) = tokio::io::duplex(8192); + + let upstream_task = tokio::spawn(async move { + let mut buf = vec![0u8; 8192]; + let mut total = 0usize; + let mut expected_total = None; + loop { + let n = upstream_side.read(&mut buf[total..]).await.unwrap(); + if n == 0 { + break; + } + total += n; + if expected_total.is_none() + && let Some(end) = buf[..total].windows(4).position(|w| w == b"\r\n\r\n") + { + let header_end = end + 4; + let headers = String::from_utf8_lossy(&buf[..header_end]); + let len = headers + .lines() + .find_map(|line| { + let (name, value) = line.split_once(':')?; + name.eq_ignore_ascii_case("content-length") + .then(|| value.trim().parse::<usize>().ok()) + .flatten() + }) + .unwrap_or(0); + expected_total = Some(header_end + len); + } + if expected_total.is_some_and(|expected| total >= expected) { + break; + } + } + upstream_side + .write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok") + .await + .unwrap(); + upstream_side.flush().await.unwrap(); + String::from_utf8_lossy(&buf[..total]).to_string() + }); + + relay_rewritten_forward_request( + method, + path, + rewritten, + &mut
proxy_to_client, + &mut proxy_to_upstream, + ForwardRelayOptions { + generation_guard: &guard, + websocket_extensions: crate::l7::rest::WebSocketExtensionMode::Preserve, + secret_resolver: resolver, + request_body_credential_rewrite, + }, + ) + .await?; + + upstream_task + .await + .map_err(|e| miette::miette!("upstream task failed: {e}")) + } + + fn forward_websocket_policy_parts( + data: &str, + host: &str, + port: u16, + path: &str, + policy_name: &str, + ) -> ( + crate::l7::L7EndpointConfig, + crate::opa::TunnelPolicyEngine, + crate::l7::relay::L7EvalContext, + ) { + let policy = include_str!("../data/sandbox-policy.rego"); + let engine = OpaEngine::from_strings(policy, data).unwrap(); + let decision = ConnectDecision { + action: NetworkAction::Allow { + matched_policy: Some(policy_name.to_string()), + }, + generation: engine.current_generation(), + binary: Some(PathBuf::from("/usr/bin/node")), + binary_pid: None, + ancestors: vec![], + cmdline_paths: vec![], + }; + let route = + query_l7_route_snapshot(&engine, &decision, host, port).expect("L7 route should match"); + let config = select_l7_config_for_path(&route.configs, path) + .expect("path-specific L7 config should match") + .config + .clone(); + let tunnel_engine = engine + .clone_engine_for_tunnel(route.generation) + .expect("tunnel engine"); + let ctx = crate::l7::relay::L7EvalContext { + host: host.to_string(), + port, + policy_name: policy_name.to_string(), + binary_path: "/usr/bin/node".to_string(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: None, + }; + (config, tunnel_engine, ctx) + } + + async fn read_http_headers<R: AsyncRead + Unpin>(reader: &mut R) -> Vec<u8> { + let mut bytes = Vec::new(); + let mut chunk = [0u8; 256]; + loop { + let n = + tokio::time::timeout(std::time::Duration::from_secs(1), reader.read(&mut chunk)) + .await + .expect("HTTP headers should arrive") + .expect("header read should succeed"); + assert!(n > 0, "stream closed before HTTP headers"); +
bytes.extend_from_slice(&chunk[..n]); + if bytes.windows(4).any(|w| w == b"\r\n\r\n") { + return bytes; + } + } + } + + fn masked_text_frame(payload: &[u8]) -> Vec<u8> { + let mask = [0x11, 0x22, 0x33, 0x44]; + assert!( + payload.len() <= 125, + "test helper only supports small frames" + ); + let payload_len = u8::try_from(payload.len()).expect("small frame length"); + let mut frame = vec![0x81, 0x80 | payload_len]; + frame.extend_from_slice(&mask); + frame.extend( + payload + .iter() + .enumerate() + .map(|(idx, byte)| byte ^ mask[idx % 4]), + ); + frame + } + + async fn forward_websocket_denied_after_upgrade( + config: crate::l7::L7EndpointConfig, + tunnel_engine: crate::opa::TunnelPolicyEngine, + ctx: crate::l7::relay::L7EvalContext, + path: &str, + payload: &str, + ) -> (miette::Report, Vec<u8>) { + let host = ctx.host.clone(); + let port = ctx.port; + let raw = format!( + "GET http://{host}{path} HTTP/1.1\r\n\ + Host: {host}\r\n\ + Upgrade: websocket\r\n\ + Connection: Upgrade\r\n\ + Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\ + Sec-WebSocket-Version: 13\r\n\r\n" + ); + let rewritten = rewrite_forward_request(raw.as_bytes(), raw.len(), path, None, false) + .expect("forward websocket request should rewrite to origin form"); + let websocket_extensions = crate::l7::relay::websocket_extension_mode(&config); + let target = path.to_string(); + let query_params = std::collections::HashMap::new(); + let (mut proxy_to_upstream, mut upstream) = tokio::io::duplex(8192); + let (mut app, mut proxy_to_client) = tokio::io::duplex(8192); + + let relay = tokio::spawn(async move { + let guard = tunnel_engine.generation_guard(); + let outcome = relay_rewritten_forward_request( + "GET", + &target, + rewritten, + &mut proxy_to_client, + &mut proxy_to_upstream, + ForwardRelayOptions { + generation_guard: guard, + websocket_extensions, + secret_resolver: None, + request_body_credential_rewrite: false, + }, + ) + .await?; + if let crate::l7::provider::RelayOutcome::Upgraded { +
overflow, + websocket_permessage_deflate, + } = outcome + { + let mut options = crate::l7::relay::upgrade_options( + &config, + &ctx, + true, + &target, + &query_params, + Some(&tunnel_engine), + ); + options.websocket.permessage_deflate = websocket_permessage_deflate; + crate::l7::relay::handle_upgrade( + &mut proxy_to_client, + &mut proxy_to_upstream, + overflow, + &host, + port, + options, + ) + .await?; + } + Ok::<(), miette::Report>(()) + }); + + let forwarded_headers = read_http_headers(&mut upstream).await; + let forwarded_headers = String::from_utf8_lossy(&forwarded_headers); + assert!(forwarded_headers.starts_with(&format!("GET {path} HTTP/1.1\r\n"))); + assert!(forwarded_headers.contains("Upgrade: websocket\r\n")); + + upstream + .write_all( + b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\r\n", + ) + .await + .unwrap(); + + let response = read_http_headers(&mut app).await; + assert!(String::from_utf8_lossy(&response).contains("101 Switching Protocols")); + + app.write_all(&masked_text_frame(payload.as_bytes())) + .await + .unwrap(); + + let err = tokio::time::timeout(std::time::Duration::from_secs(1), relay) + .await + .expect("websocket relay should fail closed after denied frame") + .expect("relay task should not panic") + .expect_err("denied websocket frame should fail the forward relay"); + + let mut leaked = Vec::new(); + tokio::time::timeout( + std::time::Duration::from_secs(1), + upstream.read_to_end(&mut leaked), + ) + .await + .expect("upstream side should close") + .expect("upstream read should succeed"); + (err, leaked) + } + + #[test] + fn forward_websocket_upgrade_options_enable_native_policy_context() { + let (_, resolver) = SecretResolver::from_provider_env( + [("DISCORD_BOT_TOKEN".to_string(), "discord-real".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.map(Arc::new); + let policy = 
include_str!("../data/sandbox-policy.rego"); + let policy_data = "network_policies: {}\n"; + let engine = OpaEngine::from_strings(policy, policy_data).unwrap(); + let tunnel_engine = engine + .clone_engine_for_tunnel(engine.current_generation()) + .unwrap(); + let ctx = crate::l7::relay::L7EvalContext { + host: "gateway.example.test".to_string(), + port: 80, + policy_name: "ws_api".to_string(), + binary_path: "/usr/bin/node".to_string(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: resolver, + }; + let query_params = std::collections::HashMap::new(); + + let extensions = crate::l7::relay::websocket_extension_mode(&websocket_l7_config( + crate::l7::L7Protocol::Websocket, + true, + )); + let options = crate::l7::relay::upgrade_options( + &websocket_l7_config(crate::l7::L7Protocol::Websocket, true), + &ctx, + true, + "/ws", + &query_params, + Some(&tunnel_engine), + ); + + assert_eq!( + extensions, + crate::l7::rest::WebSocketExtensionMode::PermessageDeflate + ); + assert!(options.websocket.credential_rewrite); + assert!(options.secret_resolver.is_some()); + assert!(options.engine.is_some()); + assert!(options.ctx.is_some()); + assert!(matches!( + options.websocket.message_policy, + crate::l7::relay::WebSocketMessagePolicy::Transport + )); + } + + #[test] + fn forward_websocket_upgrade_options_preserve_rest_without_rewrite() { + let ctx = crate::l7::relay::L7EvalContext { + host: "gateway.example.test".to_string(), + port: 80, + policy_name: "rest_api".to_string(), + binary_path: "/usr/bin/node".to_string(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: None, + }; + let query_params = std::collections::HashMap::new(); + let config = websocket_l7_config(crate::l7::L7Protocol::Rest, false); + let extensions = crate::l7::relay::websocket_extension_mode(&config); + let options = + crate::l7::relay::upgrade_options(&config, &ctx, true, "/ws", &query_params, None); + + assert_eq!( + extensions, + 
crate::l7::rest::WebSocketExtensionMode::Preserve + ); + assert!(!options.websocket.credential_rewrite); + assert!(options.secret_resolver.is_none()); + assert!(options.engine.is_none()); + assert!(options.ctx.is_none()); + assert!(matches!( + options.websocket.message_policy, + crate::l7::relay::WebSocketMessagePolicy::None + )); + } + + #[tokio::test] + async fn forward_websocket_upgrade_blocks_text_frame_by_policy() { + let data = r#" +network_policies: + ws_api: + name: ws_api + endpoints: + - host: gateway.example.test + port: 80 + path: "/ws" + protocol: websocket + enforcement: enforce + rules: + - allow: + method: GET + path: "/ws" + - allow: + method: WEBSOCKET_TEXT + path: "/ws" + deny_rules: + - method: WEBSOCKET_TEXT + path: "/ws" + binaries: + - { path: /usr/bin/node } +"#; + let (config, tunnel_engine, ctx) = + forward_websocket_policy_parts(data, "gateway.example.test", 80, "/ws", "ws_api"); + + let (err, leaked) = forward_websocket_denied_after_upgrade( + config, + tunnel_engine, + ctx, + "/ws", + r#"{"type":"unsafe"}"#, + ) + .await; + + assert!(err.to_string().contains("websocket text message denied")); + assert!( + leaked.is_empty(), + "denied forward-proxy WebSocket text frames must not reach upstream" + ); + } + + #[tokio::test] + async fn forward_graphql_websocket_upgrade_blocks_unallowed_operation() { + let data = r#" +network_policies: + graphql_ws: + name: graphql_ws + endpoints: + - host: gateway.example.test + port: 80 + path: "/graphql" + protocol: websocket + enforcement: enforce + rules: + - allow: + method: GET + path: "/graphql" + - allow: + operation_type: query + fields: [viewer] + deny_rules: + - operation_type: query + fields: [admin] + binaries: + - { path: /usr/bin/node } +"#; + let (config, tunnel_engine, ctx) = forward_websocket_policy_parts( + data, + "gateway.example.test", + 80, + "/graphql", + "graphql_ws", + ); + assert!( + config.websocket_graphql_policy, + "operation rules should enable GraphQL-over-WebSocket 
inspection" + ); + + let (err, leaked) = forward_websocket_denied_after_upgrade( + config, + tunnel_engine, + ctx, + "/graphql", + r#"{"id":"1","type":"subscribe","payload":{"query":"query { admin }"}}"#, + ) + .await; + + assert!(err.to_string().contains("websocket GraphQL message denied")); + assert!( + leaked.is_empty(), + "denied forward-proxy GraphQL WebSocket operations must not reach upstream" + ); + } #[test] fn l7_route_selection_prefers_path_specific_graphql_endpoint() { @@ -3261,6 +4162,9 @@ mod tests { enforcement: crate::l7::EnforcementMode::Enforce, graphql_max_body_bytes: crate::l7::graphql::DEFAULT_MAX_BODY_BYTES, allow_encoded_slash: false, + websocket_credential_rewrite: false, + request_body_credential_rewrite: false, + websocket_graphql_policy: false, }, }, L7ConfigSnapshot { @@ -3271,6 +4175,9 @@ mod tests { enforcement: crate::l7::EnforcementMode::Enforce, graphql_max_body_bytes: crate::l7::graphql::DEFAULT_MAX_BODY_BYTES, allow_encoded_slash: false, + websocket_credential_rewrite: false, + request_body_credential_rewrite: false, + websocket_graphql_policy: false, }, }, ]; @@ -3467,61 +4374,499 @@ mod tests { "partial alias match should not resolve" ); - let result = resolve_from_hosts_file_contents(contents, "searxng.local", 8080); - assert_eq!( - result, - vec![SocketAddr::new( - IpAddr::V4(Ipv4Addr::new(192, 168, 1, 105)), - 8080 - )] + let result = resolve_from_hosts_file_contents(contents, "searxng.local", 8080); + assert_eq!( + result, + vec![SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(192, 168, 1, 105)), + 8080 + )] + ); + } + + #[test] + fn test_resolve_from_hosts_file_contents_public_ip_passes_default_ssrf_check() { + let addrs = + resolve_from_hosts_file_contents("93.184.216.34 example.local\n", "example.local", 80); + assert!(reject_internal_resolved_addrs("example.local", &addrs).is_ok()); + } + + #[test] + fn test_resolve_from_hosts_file_contents_private_ip_requires_allowed_ips() { + let addrs = resolve_from_hosts_file_contents( + 
"192.168.1.105 searxng.local\n", + "searxng.local", + 8080, + ); + + let err = reject_internal_resolved_addrs("searxng.local", &addrs).unwrap_err(); + assert!( + err.contains("internal address"), + "expected private hosts-file resolution to remain blocked: {err}" + ); + + let nets = parse_allowed_ips(&["192.168.1.105/32".to_string()]).unwrap(); + assert!( + validate_allowed_ips_for_resolved_addrs("searxng.local", 8080, &addrs, &nets).is_ok() + ); + } + + #[test] + fn test_resolve_from_hosts_file_contents_always_blocked_ip_stays_blocked() { + let addrs = + resolve_from_hosts_file_contents("127.0.0.1 loopback.local\n", "loopback.local", 80); + let nets = vec!["127.0.0.0/8".parse::().unwrap()]; + let err = validate_allowed_ips_for_resolved_addrs("loopback.local", 80, &addrs, &nets) + .unwrap_err(); + assert!( + err.contains("always-blocked"), + "expected always-blocked hosts-file resolution to stay blocked: {err}" + ); + } + + #[test] + fn test_resolve_from_hosts_file_contents_returns_empty_without_match() { + let result = + resolve_from_hosts_file_contents("192.168.1.105 searxng.local\n", "missing.local", 80); + assert!(result.is_empty()); + } + + // -- is_host_gateway_alias -- + + #[test] + fn test_is_host_gateway_alias_recognises_known_aliases() { + assert!(is_host_gateway_alias("host.openshell.internal")); + assert!(is_host_gateway_alias("host.containers.internal")); + assert!(is_host_gateway_alias("host.docker.internal")); + } + + #[test] + fn test_is_host_gateway_alias_is_case_insensitive() { + assert!(is_host_gateway_alias("HOST.OPENSHELL.INTERNAL")); + assert!(is_host_gateway_alias("Host.Containers.Internal")); + assert!(is_host_gateway_alias("HOST.DOCKER.INTERNAL")); + } + + #[test] + fn test_is_host_gateway_alias_rejects_unknown_hosts() { + assert!(!is_host_gateway_alias("api.example.com")); + assert!(!is_host_gateway_alias("host.openshell.internal.evil.com")); + assert!(!is_host_gateway_alias("evil.host.openshell.internal")); + 
assert!(!is_host_gateway_alias("openshell.internal")); + assert!(!is_host_gateway_alias("")); + } + + // -- is_cloud_metadata_ip -- + + #[test] + fn test_is_cloud_metadata_ip_blocks_known_metadata_ip() { + assert!(is_cloud_metadata_ip(IpAddr::V4(Ipv4Addr::new( + 169, 254, 169, 254 + )))); + } + + #[test] + fn test_is_cloud_metadata_ip_allows_other_link_local() { + // The pasta gateway address on this test host — not a metadata IP. + assert!(!is_cloud_metadata_ip(IpAddr::V4(Ipv4Addr::new( + 169, 254, 1, 2 + )))); + assert!(!is_cloud_metadata_ip(IpAddr::V4(Ipv4Addr::new( + 169, 254, 0, 1 + )))); + } + + #[test] + fn test_is_cloud_metadata_ip_allows_private_and_public() { + assert!(!is_cloud_metadata_ip(IpAddr::V4(Ipv4Addr::new( + 10, 0, 0, 1 + )))); + assert!(!is_cloud_metadata_ip(IpAddr::V4(Ipv4Addr::new( + 192, 168, 1, 1 + )))); + assert!(!is_cloud_metadata_ip(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)))); + } + + #[test] + fn test_is_cloud_metadata_ip_blocks_ipv4_mapped_metadata() { + // ::ffff:169.254.169.254 is the IPv4-mapped IPv6 representation of the + // AWS/GCP/Azure IMDS endpoint. is_link_local_ip() recognizes it as + // link-local, so is_cloud_metadata_ip() must also catch it — otherwise + // the trusted-gateway exemption would be granted to the metadata service. + let mapped = Ipv4Addr::new(169, 254, 169, 254).to_ipv6_mapped(); + assert!( + is_cloud_metadata_ip(IpAddr::V6(mapped)), + "::ffff:169.254.169.254 must be recognized as cloud metadata" + ); + } + + #[test] + fn test_is_cloud_metadata_ip_allows_other_ipv4_mapped_link_local() { + // Other IPv4-mapped link-local addresses are NOT metadata. 
+ let mapped = Ipv4Addr::new(169, 254, 1, 2).to_ipv6_mapped(); + assert!( + !is_cloud_metadata_ip(IpAddr::V6(mapped)), + "::ffff:169.254.1.2 should not be flagged as cloud metadata" + ); + } + + // -- detect_trusted_host_gateway -- + + #[test] + fn test_detect_trusted_host_gateway_returns_ip_from_hosts_content() { + // We test the underlying parser directly since detect_trusted_host_gateway + // reads the real /etc/hosts. The production code composes these same primitives. + let contents = "169.254.1.2\thost.openshell.internal host.containers.internal\n"; + let ips = parse_hosts_file_for_host(contents, "host.openshell.internal"); + assert_eq!(ips, vec![IpAddr::V4(Ipv4Addr::new(169, 254, 1, 2))]); + } + + #[test] + fn test_detect_trusted_host_gateway_ignores_cloud_metadata_ip() { + // Simulate a /etc/hosts where the driver injected the cloud metadata IP — + // this should be caught and suppressed. + let contents = "169.254.169.254\thost.openshell.internal\n"; + let ips = parse_hosts_file_for_host(contents, "host.openshell.internal"); + assert_eq!(ips, vec![IpAddr::V4(Ipv4Addr::new(169, 254, 169, 254))]); + // is_cloud_metadata_ip should flag it, preventing the exemption. + assert!(is_cloud_metadata_ip(ips[0])); + } + + #[test] + fn test_detect_trusted_host_gateway_no_entry_returns_empty() { + let contents = "127.0.0.1 localhost\n"; + let ips = parse_hosts_file_for_host(contents, "host.openshell.internal"); + assert!(ips.is_empty()); + } + + #[test] + fn test_detect_trusted_host_gateway_rejects_loopback() { + // Loopback is not link-local — must not receive the SSRF exemption. + let ip = IpAddr::V4(Ipv4Addr::LOCALHOST); + assert!(!is_cloud_metadata_ip(ip)); + assert!(!is_link_local_ip(ip)); + // The guard: !link-local → reject. + assert!(!is_link_local_ip(ip)); + } + + #[test] + fn test_detect_trusted_host_gateway_rejects_unspecified() { + // Unspecified (0.0.0.0) is not link-local — must not be trusted. 
+ let ip = IpAddr::V4(Ipv4Addr::UNSPECIFIED); + assert!(!is_cloud_metadata_ip(ip)); + assert!(!is_link_local_ip(ip)); + assert!(!is_link_local_ip(ip)); + } + + #[test] + fn test_detect_trusted_host_gateway_rejects_loopback_v6() { + let ip = IpAddr::V6(Ipv6Addr::LOCALHOST); + assert!(!is_cloud_metadata_ip(ip)); + assert!(!is_link_local_ip(ip)); + } + + #[test] + fn test_detect_trusted_host_gateway_rejects_private_ip() { + // Docker bridge (172.17.0.1) and K8s host gateway (192.168.x.x) are + // RFC 1918 private addresses — not link-local. Before this fix they + // slipped through the old always-blocked guard and received the SSRF + // exemption. The new guard (!is_link_local_ip) rejects them, so + // connections to these hosts fall through to resolve_and_reject_internal(). + for ip in [ + IpAddr::V4(Ipv4Addr::new(172, 17, 0, 1)), + IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)), + IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)), + ] { + assert!(!is_cloud_metadata_ip(ip), "{ip} should not be metadata"); + assert!(!is_link_local_ip(ip), "{ip} should not be link-local"); + // Guard fires — exemption disabled. + assert!(!is_link_local_ip(ip), "{ip}: guard must reject"); + } + } + + #[test] + fn test_detect_trusted_host_gateway_allows_link_local_non_metadata() { + // 169.254.1.2 (rootless Podman pasta gateway) IS link-local and is + // not a cloud metadata IP — it is the only address class the exemption + // is designed for. + let ip = IpAddr::V4(Ipv4Addr::new(169, 254, 1, 2)); + assert!(!is_cloud_metadata_ip(ip)); + assert!(is_link_local_ip(ip)); + // Guard does NOT fire — this IP is eligible for the exemption. + assert!(is_link_local_ip(ip)); + } + + // -- parse_hosts_file_for_host: multi-entry / duplicate scenarios -- + + #[test] + fn test_parse_hosts_file_single_entry() { + // Normal driver-injected case: exactly one IP for the alias. 
+ let contents = "169.254.1.2\thost.openshell.internal host.containers.internal\n"; + let ips = parse_hosts_file_for_host(contents, "host.openshell.internal"); + assert_eq!(ips, vec![IpAddr::V4(Ipv4Addr::new(169, 254, 1, 2))]); + } + + #[test] + fn test_parse_hosts_file_duplicate_same_ip_deduplicated() { + // Same IP on two separate lines for the same alias — deduplicated to one. + let contents = "169.254.1.2\thost.openshell.internal\n\ + 169.254.1.2\thost.openshell.internal\n"; + let ips = parse_hosts_file_for_host(contents, "host.openshell.internal"); + assert_eq!( + ips, + vec![IpAddr::V4(Ipv4Addr::new(169, 254, 1, 2))], + "identical IPs across lines must be deduplicated" + ); + } + + #[test] + fn test_parse_hosts_file_multiple_distinct_ips() { + // Two distinct IPs for the same alias — both returned, first entry wins + // in detect_trusted_host_gateway(), second would cause mismatch rejection + // in resolve_and_check_trusted_gateway(). + let contents = "169.254.1.2\thost.openshell.internal\n\ + 169.254.1.3\thost.openshell.internal\n"; + let ips = parse_hosts_file_for_host(contents, "host.openshell.internal"); + assert_eq!(ips.len(), 2, "two distinct IPs must both be returned"); + assert_eq!(ips[0], IpAddr::V4(Ipv4Addr::new(169, 254, 1, 2))); + assert_eq!(ips[1], IpAddr::V4(Ipv4Addr::new(169, 254, 1, 3))); + } + + #[test] + fn test_parse_hosts_file_first_entry_wins_on_ambiguity() { + // detect_trusted_host_gateway() pins to the first entry via .next(). + // Verify the ordering guarantee: first line wins. + let contents = "169.254.1.3\thost.openshell.internal\n\ + 169.254.1.2\thost.openshell.internal\n"; + let ips = parse_hosts_file_for_host(contents, "host.openshell.internal"); + assert_eq!( + ips[0], + IpAddr::V4(Ipv4Addr::new(169, 254, 1, 3)), + "first line must be first in the returned vec" + ); + } + + #[test] + fn test_parse_hosts_file_ignores_other_aliases_on_same_line() { + // An entry with multiple aliases — only the matching alias counts. 
+ let contents = + "169.254.1.2\thost.containers.internal host.openshell.internal host.docker.internal\n"; + let ips = parse_hosts_file_for_host(contents, "host.openshell.internal"); + assert_eq!(ips, vec![IpAddr::V4(Ipv4Addr::new(169, 254, 1, 2))]); + // Non-matching aliases on the same line do not produce extra entries. + let ips2 = parse_hosts_file_for_host(contents, "host.docker.internal"); + assert_eq!(ips2, vec![IpAddr::V4(Ipv4Addr::new(169, 254, 1, 2))]); + } + + #[test] + fn test_parse_hosts_file_alias_not_present() { + let contents = "127.0.0.1\tlocalhost\n\ + ::1\t\tlocalhost\n"; + let ips = parse_hosts_file_for_host(contents, "host.openshell.internal"); + assert!(ips.is_empty()); + } + + #[test] + fn test_parse_hosts_file_comment_lines_skipped() { + let contents = "# 169.254.1.2 host.openshell.internal\n\ + 169.254.1.2\thost.openshell.internal\n"; + let ips = parse_hosts_file_for_host(contents, "host.openshell.internal"); + // Commented-out line must not produce an entry. + assert_eq!(ips, vec![IpAddr::V4(Ipv4Addr::new(169, 254, 1, 2))]); + } + + #[test] + fn test_parse_hosts_file_inline_comment_stripped() { + // Anything after '#' on a data line is treated as a comment. + let contents = "169.254.1.2\thost.openshell.internal # injected by driver\n"; + let ips = parse_hosts_file_for_host(contents, "host.openshell.internal"); + assert_eq!(ips, vec![IpAddr::V4(Ipv4Addr::new(169, 254, 1, 2))]); + } + + // -- resolve_and_check_trusted_gateway -- + + #[tokio::test] + async fn test_trusted_gateway_allows_link_local_gateway_ip() { + // Simulate the rootless Podman pasta case: host.openshell.internal + // points to a link-local address which is the only path to the host. + let trusted_gw = IpAddr::V4(Ipv4Addr::new(169, 254, 1, 2)); + + // We resolve via /etc/hosts (pid=0 falls back to system), so we + // exercise the trusted_gw mismatch / cloud-metadata guards directly + // against a known resolved address. 
+ let addrs = [SocketAddr::new(trusted_gw, 8080)]; + + // Validate the guard logic inline (mirrors resolve_and_check_trusted_gateway). + assert!(!is_cloud_metadata_ip(trusted_gw)); + assert_eq!(addrs[0].ip(), trusted_gw); + } + + #[tokio::test] + async fn test_trusted_gateway_rejects_cloud_metadata_ip() { + let trusted_gw = IpAddr::V4(Ipv4Addr::new(169, 254, 1, 2)); + let metadata_ip = IpAddr::V4(Ipv4Addr::new(169, 254, 169, 254)); + + // Simulate resolution returning the metadata IP. + let addrs = [SocketAddr::new(metadata_ip, 80)]; + + // Cloud metadata check must fire before the trusted_gw equality check. + let err: Result<(), String> = if is_cloud_metadata_ip(addrs[0].ip()) { + Err(format!( + "host resolves to cloud metadata address {}, connection rejected", + addrs[0].ip() + )) + } else if addrs[0].ip() != trusted_gw { + Err(format!( + "host resolves to {} which does not match trusted host gateway \ + {trusted_gw}, connection rejected", + addrs[0].ip() + )) + } else { + Ok(()) + }; + + assert!(err.is_err()); + assert!( + err.unwrap_err().contains("cloud metadata"), + "expected cloud-metadata rejection" + ); + } + + #[tokio::test] + async fn test_trusted_gateway_rejects_mismatched_ip() { + let trusted_gw = IpAddr::V4(Ipv4Addr::new(169, 254, 1, 2)); + let other_ip = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)); + + let addrs = [SocketAddr::new(other_ip, 8080)]; + + let err: Result<(), String> = if is_cloud_metadata_ip(addrs[0].ip()) { + Err("cloud metadata".to_string()) + } else if addrs[0].ip() != trusted_gw { + Err(format!( + "{} does not match trusted host gateway {trusted_gw}", + addrs[0].ip() + )) + } else { + Ok(()) + }; + + assert!(err.is_err()); + assert!( + err.unwrap_err() + .contains("does not match trusted host gateway"), + "expected mismatch rejection" + ); + } + + #[tokio::test] + async fn test_trusted_gateway_rejects_control_plane_port() { + // Control-plane port check runs before resolution. 
+ let result = resolve_and_check_trusted_gateway( + "host.openshell.internal", + 6443, + IpAddr::V4(Ipv4Addr::new(169, 254, 1, 2)), + 0, + ) + .await; + assert!(result.is_err()); + assert!( + result.unwrap_err().contains("blocked control-plane port"), + "expected control-plane port rejection" ); } - #[test] - fn test_resolve_from_hosts_file_contents_public_ip_passes_default_ssrf_check() { - let addrs = - resolve_from_hosts_file_contents("93.184.216.34 example.local\n", "example.local", 80); - assert!(reject_internal_resolved_addrs("example.local", &addrs).is_ok()); + #[tokio::test] + async fn test_trusted_gateway_rejects_all_control_plane_ports() { + let trusted_gw = IpAddr::V4(Ipv4Addr::new(169, 254, 1, 2)); + for &port in BLOCKED_CONTROL_PLANE_PORTS { + let result = + resolve_and_check_trusted_gateway("host.openshell.internal", port, trusted_gw, 0) + .await; + assert!( + result.is_err(), + "port {port} should be blocked by control-plane guard" + ); + assert!( + result.unwrap_err().contains("blocked control-plane port"), + "expected control-plane rejection for port {port}" + ); + } } - #[test] - fn test_resolve_from_hosts_file_contents_private_ip_requires_allowed_ips() { - let addrs = resolve_from_hosts_file_contents( - "192.168.1.105 searxng.local\n", - "searxng.local", - 8080, + #[tokio::test] + async fn test_trusted_gateway_rejects_loopback_as_trusted_gw() { + // Defense-in-depth: even if detect_trusted_host_gateway somehow admitted + // a loopback IP, resolve_and_check_trusted_gateway must reject it. + // Using an IP literal as the host bypasses DNS and gives a deterministic + // resolved address, allowing us to exercise the actual function. 
+ let loopback = IpAddr::V4(Ipv4Addr::LOCALHOST); + let result = resolve_and_check_trusted_gateway("127.0.0.1", 8080, loopback, 0).await; + assert!(result.is_err(), "loopback must be rejected"); + let err = result.unwrap_err(); + assert!( + err.contains("non-link-local"), + "expected non-link-local rejection, got: {err}" ); + } - let err = reject_internal_resolved_addrs("searxng.local", &addrs).unwrap_err(); + #[tokio::test] + async fn test_trusted_gateway_rejects_unspecified_as_trusted_gw() { + // Defense-in-depth: 0.0.0.0 as trusted_gw must be rejected. + // IP literal resolves to 0.0.0.0 directly, bypassing DNS. + let unspecified = IpAddr::V4(Ipv4Addr::UNSPECIFIED); + let result = resolve_and_check_trusted_gateway("0.0.0.0", 8080, unspecified, 0).await; + assert!(result.is_err(), "unspecified must be rejected"); + let err = result.unwrap_err(); assert!( - err.contains("internal address"), - "expected private hosts-file resolution to remain blocked: {err}" + err.contains("non-link-local"), + "expected non-link-local rejection, got: {err}" ); + } - let nets = parse_allowed_ips(&["192.168.1.105/32".to_string()]).unwrap(); + #[tokio::test] + async fn test_trusted_gateway_rejects_ip_literal_mismatch() { + // If the requested IP literal doesn't match trusted_gw, the mismatch + // guard fires. This exercises the full resolution→validation path. 
+ let trusted_gw = IpAddr::V4(Ipv4Addr::new(169, 254, 1, 2)); + let other_ip = "10.0.0.1"; // RFC1918, resolves as a literal + let result = resolve_and_check_trusted_gateway(other_ip, 8080, trusted_gw, 0).await; + assert!(result.is_err(), "IP mismatch must be rejected"); + let err = result.unwrap_err(); assert!( - validate_allowed_ips_for_resolved_addrs("searxng.local", 8080, &addrs, &nets).is_ok() + err.contains("does not match trusted host gateway"), + "expected mismatch rejection, got: {err}" ); } - #[test] - fn test_resolve_from_hosts_file_contents_always_blocked_ip_stays_blocked() { - let addrs = - resolve_from_hosts_file_contents("127.0.0.1 loopback.local\n", "loopback.local", 80); - let nets = vec!["127.0.0.0/8".parse::().unwrap()]; - let err = validate_allowed_ips_for_resolved_addrs("loopback.local", 80, &addrs, &nets) - .unwrap_err(); + #[tokio::test] + async fn test_trusted_gateway_rejects_cloud_metadata_literal() { + // Cloud metadata IP as a literal address — must be rejected even when + // it matches trusted_gw (which detect_trusted_host_gateway prevents, + // but this is the defense-in-depth layer). + let metadata = IpAddr::V4(Ipv4Addr::new(169, 254, 169, 254)); + let result = resolve_and_check_trusted_gateway("169.254.169.254", 80, metadata, 0).await; + assert!(result.is_err(), "cloud metadata IP must be rejected"); + let err = result.unwrap_err(); assert!( - err.contains("always-blocked"), - "expected always-blocked hosts-file resolution to stay blocked: {err}" + err.contains("cloud metadata"), + "expected cloud-metadata rejection, got: {err}" ); } - #[test] - fn test_resolve_from_hosts_file_contents_returns_empty_without_match() { - let result = - resolve_from_hosts_file_contents("192.168.1.105 searxng.local\n", "missing.local", 80); - assert!(result.is_empty()); + #[tokio::test] + async fn test_trusted_gateway_rejects_private_ip_as_trusted_gw() { + // Defense-in-depth: a private RFC 1918 IP (e.g. 
Docker bridge 172.17.0.1) + // must be rejected even if it somehow matched trusted_gw. + // detect_trusted_host_gateway() already blocks these via !is_link_local_ip(), + // but resolve_and_check_trusted_gateway() must enforce the same invariant. + let docker_bridge = IpAddr::V4(Ipv4Addr::new(172, 17, 0, 1)); + let result = resolve_and_check_trusted_gateway("172.17.0.1", 8080, docker_bridge, 0).await; + assert!(result.is_err(), "private RFC 1918 IP must be rejected"); + let err = result.unwrap_err(); + assert!( + err.contains("non-link-local"), + "expected non-link-local rejection for private IP, got: {err}" + ); } #[tokio::test] @@ -3666,6 +5011,184 @@ mod tests { assert!(!forwarded_lc.contains("cookie:")); } + fn streaming_inference_route(endpoint: String) -> openshell_router::config::ResolvedRoute { + openshell_router::config::ResolvedRoute { + name: "inference.local".to_string(), + endpoint, + model: "meta/llama-3.1-8b-instruct".to_string(), + api_key: "test-api-key".to_string(), + protocols: vec!["openai_chat_completions".to_string()], + auth: openshell_router::config::AuthHeader::Bearer, + default_headers: vec![], + passthrough_headers: vec![], + timeout: openshell_router::config::DEFAULT_ROUTE_TIMEOUT, + } + } + + async fn read_forwarded_inference_request(stream: &mut S) { + use crate::l7::inference::{ParseResult, try_parse_http_request}; + + let mut buf = Vec::new(); + let mut chunk = [0u8; 4096]; + loop { + let n = stream.read(&mut chunk).await.unwrap(); + assert!(n > 0, "upstream request closed before completion"); + buf.extend_from_slice(&chunk[..n]); + + match try_parse_http_request(&buf) { + ParseResult::Complete(_, _) => return, + ParseResult::Incomplete => continue, + ParseResult::Invalid(reason) => { + panic!("forwarded request should parse cleanly: {reason}"); + } + } + } + } + + async fn run_live_streaming_inference(serve_upstream: F) -> String + where + F: FnOnce(TcpStream) -> Fut + Send + 'static, + Fut: Future + Send + 'static, + { + let 
listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let upstream_addr = listener.local_addr().unwrap(); + let upstream_task = tokio::spawn(async move { + let (mut upstream, _) = listener.accept().await.unwrap(); + read_forwarded_inference_request(&mut upstream).await; + serve_upstream(upstream).await; + }); + + let router = openshell_router::Router::new().unwrap(); + let patterns = crate::l7::inference::default_patterns(); + let ctx = InferenceContext::new( + patterns, + router, + vec![streaming_inference_route(format!("http://{upstream_addr}"))], + vec![], + ); + + let body = r#"{"model":"ignored","messages":[{"role":"user","content":"hi"}]}"#; + let request = format!( + "POST /v1/chat/completions HTTP/1.1\r\n\ + Host: inference.local\r\n\ + Content-Type: application/json\r\n\ + Accept: text/event-stream\r\n\ + Content-Length: {}\r\n\r\n{}", + body.len(), + body, + ); + + let (client, mut server) = tokio::io::duplex(65536); + let (mut client_read, mut client_write) = tokio::io::split(client); + let server_task = + tokio::spawn(async move { process_inference_keepalive(&mut server, &ctx, 443).await }); + + client_write.write_all(request.as_bytes()).await.unwrap(); + client_write.shutdown().await.unwrap(); + + let mut response = Vec::new(); + client_read.read_to_end(&mut response).await.unwrap(); + + let outcome = server_task.await.unwrap().unwrap(); + assert!( + matches!(outcome, InferenceOutcome::Routed), + "expected Routed outcome, got: {outcome:?}" + ); + upstream_task.await.unwrap(); + + String::from_utf8(response).unwrap() + } + + fn assert_streaming_sse_error(response: &str, message: &str) { + assert!( + response.starts_with("HTTP/1.1 200 OK\r\n"), + "expected successful streaming response, got: {response}" + ); + assert!( + response + .to_ascii_lowercase() + .contains("transfer-encoding: chunked"), + "expected chunked streaming response, got: {response}" + ); + assert!( + response.contains("\"type\":\"proxy_stream_error\""), + "expected 
proxy_stream_error SSE event, got: {response}" + ); + assert!( + response.contains(&format!("\"message\":\"{message}\"")), + "expected SSE message {message:?}, got: {response}" + ); + assert!( + response.ends_with("0\r\n\r\n"), + "streaming response must end with chunked terminator, got: {response}" + ); + } + + #[tokio::test] + async fn inference_stream_byte_limit_injects_sse_error() { + let response = run_live_streaming_inference(|mut upstream| async move { + use crate::l7::inference::{format_chunk, format_chunk_terminator}; + + upstream + .write_all( + b"HTTP/1.1 200 OK\r\n\ + Content-Type: text/event-stream\r\n\ + Transfer-Encoding: chunked\r\n\r\n", + ) + .await + .unwrap(); + let body = vec![b'a'; MAX_STREAMING_BODY + 1]; + let _ = upstream.write_all(&format_chunk(&body)).await; + let _ = upstream.write_all(format_chunk_terminator()).await; + }) + .await; + + assert_streaming_sse_error( + &response, + "response truncated: exceeded maximum streaming body size", + ); + } + + #[tokio::test] + async fn inference_stream_upstream_read_error_injects_sse_error() { + let response = run_live_streaming_inference(|mut upstream| async move { + upstream + .write_all( + b"HTTP/1.1 200 OK\r\n\ + Content-Type: text/event-stream\r\n\ + Content-Length: 64\r\n\r\n\ + partial", + ) + .await + .unwrap(); + }) + .await; + + assert!( + response.contains("partial"), + "expected initial upstream bytes before truncation, got: {response}" + ); + assert_streaming_sse_error(&response, "response truncated: upstream read error"); + } + + #[tokio::test] + async fn inference_stream_idle_timeout_injects_sse_error() { + let response = run_live_streaming_inference(|mut upstream| async move { + upstream + .write_all( + b"HTTP/1.1 200 OK\r\n\ + Content-Type: text/event-stream\r\n\ + Transfer-Encoding: chunked\r\n\r\n", + ) + .await + .unwrap(); + tokio::time::sleep(CHUNK_IDLE_TIMEOUT + std::time::Duration::from_millis(50)).await; + }) + .await; + + assert_streaming_sse_error(&response, "response 
truncated: chunk idle timeout exceeded"); + } + // -- router_error_to_http -- #[test] @@ -4197,7 +5720,8 @@ mod tests { fn test_rewrite_get_request() { let raw = b"GET http://10.0.0.1:8000/api HTTP/1.1\r\nHost: 10.0.0.1:8000\r\nAccept: */*\r\n\r\n"; - let result = rewrite_forward_request(raw, raw.len(), "/api", None).expect("should succeed"); + let result = + rewrite_forward_request(raw, raw.len(), "/api", None, false).expect("should succeed"); let result_str = String::from_utf8_lossy(&result); assert!(result_str.starts_with("GET /api HTTP/1.1\r\n")); assert!(result_str.contains("Host: 10.0.0.1:8000")); @@ -4208,7 +5732,8 @@ mod tests { #[test] fn test_rewrite_strips_proxy_headers() { let raw = b"GET http://host/p HTTP/1.1\r\nHost: host\r\nProxy-Authorization: Basic abc\r\nProxy-Connection: keep-alive\r\nAccept: */*\r\n\r\n"; - let result = rewrite_forward_request(raw, raw.len(), "/p", None).expect("should succeed"); + let result = + rewrite_forward_request(raw, raw.len(), "/p", None, false).expect("should succeed"); let result_str = String::from_utf8_lossy(&result); assert!( !result_str @@ -4222,7 +5747,8 @@ mod tests { #[test] fn test_rewrite_replaces_connection_header() { let raw = b"GET http://host/p HTTP/1.1\r\nHost: host\r\nConnection: keep-alive\r\n\r\n"; - let result = rewrite_forward_request(raw, raw.len(), "/p", None).expect("should succeed"); + let result = + rewrite_forward_request(raw, raw.len(), "/p", None, false).expect("should succeed"); let result_str = String::from_utf8_lossy(&result); assert!(result_str.contains("Connection: close")); assert!(!result_str.contains("keep-alive")); @@ -4231,7 +5757,8 @@ mod tests { #[test] fn test_rewrite_preserves_body_overflow() { let raw = b"POST http://host/api HTTP/1.1\r\nHost: host\r\nContent-Length: 13\r\n\r\n{\"key\":\"val\"}"; - let result = rewrite_forward_request(raw, raw.len(), "/api", None).expect("should succeed"); + let result = + rewrite_forward_request(raw, raw.len(), "/api", None, 
false).expect("should succeed"); let result_str = String::from_utf8_lossy(&result); assert!(result_str.contains("{\"key\":\"val\"}")); assert!(result_str.contains("POST /api HTTP/1.1")); @@ -4240,7 +5767,8 @@ mod tests { #[test] fn test_rewrite_preserves_existing_via() { let raw = b"GET http://host/p HTTP/1.1\r\nHost: host\r\nVia: 1.0 upstream\r\n\r\n"; - let result = rewrite_forward_request(raw, raw.len(), "/p", None).expect("should succeed"); + let result = + rewrite_forward_request(raw, raw.len(), "/p", None, false).expect("should succeed"); let result_str = String::from_utf8_lossy(&result); assert!(result_str.contains("Via: 1.0 upstream")); // Should not add a second Via header @@ -4263,7 +5791,7 @@ mod tests { .expect("canonicalization should succeed for the attack payload"); assert_eq!(canon.path, "/secret"); - let rewritten = rewrite_forward_request(raw, raw.len(), &canon.path, None) + let rewritten = rewrite_forward_request(raw, raw.len(), &canon.path, None, false) .expect("rewrite_forward_request should succeed"); let rewritten_str = String::from_utf8_lossy(&rewritten); assert!( @@ -4289,7 +5817,7 @@ mod tests { _ => canon.path, }; - let rewritten = rewrite_forward_request(raw, raw.len(), &upstream_target, None) + let rewritten = rewrite_forward_request(raw, raw.len(), &upstream_target, None, false) .expect("rewrite_forward_request should succeed"); let rewritten_str = String::from_utf8_lossy(&rewritten); assert!( @@ -4308,13 +5836,169 @@ mod tests { .collect(), ); let raw = b"GET http://host/p HTTP/1.1\r\nHost: host\r\nAuthorization: Bearer openshell:resolve:env:ANTHROPIC_API_KEY\r\n\r\n"; - let result = rewrite_forward_request(raw, raw.len(), "/p", resolver.as_ref()) + let result = rewrite_forward_request(raw, raw.len(), "/p", resolver.as_ref(), false) .expect("should succeed"); let result_str = String::from_utf8_lossy(&result); assert!(result_str.contains("Authorization: Bearer sk-test")); 
assert!(!result_str.contains("openshell:resolve:env:ANTHROPIC_API_KEY")); } + #[tokio::test] + async fn forward_relay_rewrites_urlencoded_body_alias_from_initial_read() { + let (_, resolver) = SecretResolver::from_provider_env( + [("API_TOKEN".to_string(), "provider-real-token".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let alias = "provider-OPENSHELL-RESOLVE-ENV-API_TOKEN"; + let body = format!("token={alias}&channel=C123"); + let raw = format!( + "POST http://api.example.com/api/messages HTTP/1.1\r\n\ + Host: api.example.com\r\n\ + Authorization: Bearer {alias}\r\n\ + Content-Type: application/x-www-form-urlencoded\r\n\ + Content-Length: {}\r\n\r\n{}", + body.len(), + body + ); + + let forwarded = relay_forward_request_and_capture( + "POST", + "/api/messages", + raw.as_bytes(), + Some(&resolver), + true, + ) + .await + .expect("forward relay should rewrite credentials"); + + let expected_body = "token=provider-real-token&channel=C123"; + assert!(forwarded.starts_with("POST /api/messages HTTP/1.1\r\n")); + assert!(forwarded.contains("Authorization: Bearer provider-real-token\r\n")); + assert!(forwarded.contains(&format!("Content-Length: {}\r\n", expected_body.len()))); + assert!(forwarded.ends_with(expected_body)); + assert!(!forwarded.contains("OPENSHELL-RESOLVE-ENV")); + } + + #[tokio::test] + async fn forward_relay_rewrites_urlencoded_canonical_body_from_initial_read() { + let (_, resolver) = SecretResolver::from_provider_env( + [("API_TOKEN".to_string(), "provider-real-token".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let alias = "provider-OPENSHELL-RESOLVE-ENV-API_TOKEN"; + let body = "token=openshell%3Aresolve%3Aenv%3AAPI_TOKEN&channel=C123"; + let raw = format!( + "POST http://api.example.com/api/messages HTTP/1.1\r\n\ + Host: api.example.com\r\n\ + Authorization: Bearer {alias}\r\n\ + Content-Type: application/x-www-form-urlencoded\r\n\ + 
Content-Length: {}\r\n\r\n{}", + body.len(), + body + ); + + let forwarded = relay_forward_request_and_capture( + "POST", + "/api/messages", + raw.as_bytes(), + Some(&resolver), + true, + ) + .await + .expect("forward relay should rewrite credentials"); + + let expected_body = "token=provider-real-token&channel=C123"; + assert!(forwarded.contains("Authorization: Bearer provider-real-token\r\n")); + assert!(forwarded.contains(&format!("Content-Length: {}\r\n", expected_body.len()))); + assert!(forwarded.ends_with(expected_body)); + assert!(!forwarded.contains("openshell%3Aresolve%3Aenv%3AAPI_TOKEN")); + assert!(!forwarded.contains("openshell:resolve:env:API_TOKEN")); + } + + #[tokio::test] + async fn forward_relay_unresolved_body_placeholder_fails_before_upstream_write() { + let (_, resolver) = SecretResolver::from_provider_env( + [("API_TOKEN".to_string(), "provider-real-token".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let alias = "provider-OPENSHELL-RESOLVE-ENV-API_TOKEN"; + let body = "token=provider-OPENSHELL-RESOLVE-ENV-MISSING_TOKEN"; + let raw = format!( + "POST http://api.example.com/api/messages HTTP/1.1\r\n\ + Host: api.example.com\r\n\ + Authorization: Bearer {alias}\r\n\ + Content-Type: application/x-www-form-urlencoded\r\n\ + Content-Length: {}\r\n\r\n{}", + body.len(), + body + ); + let guard = forward_test_guard(); + let rewritten = rewrite_forward_request( + raw.as_bytes(), + raw.len(), + "/api/messages", + Some(&resolver), + true, + ) + .expect("header rewrite should defer body overflow to body rewriter"); + let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); + let (mut _app_side, mut proxy_to_client) = tokio::io::duplex(8192); + + let err = relay_rewritten_forward_request( + "POST", + "/api/messages", + rewritten, + &mut proxy_to_client, + &mut proxy_to_upstream, + ForwardRelayOptions { + generation_guard: &guard, + websocket_extensions: 
crate::l7::rest::WebSocketExtensionMode::Preserve, + secret_resolver: Some(&resolver), + request_body_credential_rewrite: true, + }, + ) + .await + .expect_err("unresolved body placeholder should fail closed"); + + assert!(!err.to_string().contains("provider-real-token")); + assert!(!err.to_string().contains("MISSING_TOKEN")); + drop(proxy_to_upstream); + let mut forwarded = Vec::new(); + upstream_side.read_to_end(&mut forwarded).await.unwrap(); + assert!( + forwarded.is_empty(), + "failed forward body rewrite must not reach upstream" + ); + } + + #[test] + fn test_forward_rewrite_preserves_websocket_upgrade_connection_header() { + let raw = "GET http://gateway.example.test/ws HTTP/1.1\r\n\ + Host: gateway.example.test\r\n\ + Upgrade: websocket\r\n\ + Connection: keep-alive, Upgrade\r\n\ + Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\ + Sec-WebSocket-Extensions: permessage-deflate; client_no_context_takeover\r\n\ + Sec-WebSocket-Version: 13\r\n\r\n"; + + let result = rewrite_forward_request(raw.as_bytes(), raw.len(), "/ws", None, false) + .expect("websocket forward rewrite should succeed"); + let result_str = String::from_utf8_lossy(&result); + + assert!(result_str.starts_with("GET /ws HTTP/1.1\r\n")); + assert!(result_str.contains("Connection: keep-alive, Upgrade\r\n")); + assert!( + !result_str.contains("Connection: close\r\n"), + "websocket forward proxy must not strip the upgrade token" + ); + } + #[tokio::test] async fn test_forward_relay_guard_blocks_stale_generation_before_upstream_write() { let policy = include_str!("../data/sandbox-policy.rego"); @@ -4326,8 +6010,8 @@ mod tests { engine.reload(policy, policy_data).unwrap(); let raw = b"GET http://host/api HTTP/1.1\r\nHost: host\r\n\r\n"; - let rewritten = - rewrite_forward_request(raw, raw.len(), "/api", None).expect("rewrite should succeed"); + let rewritten = rewrite_forward_request(raw, raw.len(), "/api", None, false) + .expect("rewrite should succeed"); let (mut proxy_to_upstream, mut upstream_side) 
= tokio::io::duplex(8192); let (mut _app_side, mut proxy_to_client) = tokio::io::duplex(8192); @@ -4337,7 +6021,12 @@ mod tests { rewritten, &mut proxy_to_client, &mut proxy_to_upstream, - &guard, + ForwardRelayOptions { + generation_guard: &guard, + websocket_extensions: crate::l7::rest::WebSocketExtensionMode::Preserve, + secret_resolver: None, + request_body_credential_rewrite: false, + }, ) .await; assert!( @@ -4364,8 +6053,8 @@ mod tests { .unwrap(); let raw = b"POST http://host/api HTTP/1.1\r\nHost: host\r\nContent-Length: 4\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\n"; - let rewritten = - rewrite_forward_request(raw, raw.len(), "/api", None).expect("rewrite should succeed"); + let rewritten = rewrite_forward_request(raw, raw.len(), "/api", None, false) + .expect("rewrite should succeed"); let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); let (mut _app_side, mut proxy_to_client) = tokio::io::duplex(8192); @@ -4375,7 +6064,12 @@ mod tests { rewritten, &mut proxy_to_client, &mut proxy_to_upstream, - &guard, + ForwardRelayOptions { + generation_guard: &guard, + websocket_extensions: crate::l7::rest::WebSocketExtensionMode::Preserve, + secret_resolver: None, + request_body_credential_rewrite: false, + }, ) .await; assert!(result.is_err(), "forward relay must reject CL/TE ambiguity"); @@ -4821,6 +6515,10 @@ mod tests { #[cfg(target_os = "linux")] #[test] + // TODO: exec'ing /bin/sleep (SELinux label bin_t) from a user_home_t test + // binary causes /proc//exe readlink to return ENOENT on + // SELinux-enforcing hosts. Fix by building a test-sleep-helper binary in + // the same crate so it inherits the user_home_t label. 
fn resolve_process_identity_denies_fork_exec_shared_socket_ambiguity() { use crate::identity::BinaryIdentityCache; use std::ffi::CString; @@ -4828,11 +6526,32 @@ mod tests { use std::os::fd::AsRawFd; use std::time::{Duration, Instant}; + struct ChildGuard(libc::pid_t); + impl Drop for ChildGuard { + fn drop(&mut self) { + #[allow(unsafe_code)] + unsafe { + libc::kill(self.0, libc::SIGKILL); + libc::waitpid(self.0, std::ptr::null_mut(), 0); + } + } + } + if !std::path::Path::new("/bin/sleep").exists() { eprintln!("skipping: /bin/sleep not available"); return; } + if std::process::Command::new("getenforce") + .output() + .is_ok_and(|o| String::from_utf8_lossy(&o.stdout).trim() == "Enforcing") + { + eprintln!( + "skipping: SELinux is enforcing — cross-label /proc//exe readlink fails" + ); + return; + } + let listener = TcpListener::bind("127.0.0.1:0").expect("bind listener"); let listener_port = listener.local_addr().unwrap().port(); let stream = TcpStream::connect(("127.0.0.1", listener_port)).expect("connect"); @@ -4873,7 +6592,10 @@ mod tests { } } - let deadline = Instant::now() + Duration::from_secs(2); + let _guard = ChildGuard(child_pid); + let entrypoint_pid = std::process::id(); + + let deadline = Instant::now() + Duration::from_secs(5); loop { if let Ok(link) = std::fs::read_link(format!("/proc/{child_pid}/exe")) && link.to_string_lossy().contains("sleep") @@ -4882,18 +6604,14 @@ mod tests { } assert!( Instant::now() < deadline, - "child pid {child_pid} did not exec into sleep within 2s" + "child pid {child_pid} did not exec into sleep within 5s" ); std::thread::sleep(Duration::from_millis(20)); } let cache = BinaryIdentityCache::new(); - // Resolve with a brief retry loop — under heavy CI load the child's - // procfs entry can momentarily fail to resolve even though the loop - // above just verified `/proc//exe` pointed at `sleep`. Retry a - // few times before declaring failure so the test is not flaky. 
- let mut result = resolve_process_identity(std::process::id(), peer_port, &cache); + let mut result = resolve_process_identity(entrypoint_pid, peer_port, &cache); for _ in 0..5 { match &result { Err(err) @@ -4901,19 +6619,12 @@ mod tests { || err.reason.contains("os error 2") => { std::thread::sleep(Duration::from_millis(50)); - result = resolve_process_identity(std::process::id(), peer_port, &cache); + result = resolve_process_identity(entrypoint_pid, peer_port, &cache); } _ => break, } } - // libc/syscall FFI requires unsafe - #[allow(unsafe_code)] - unsafe { - libc::kill(child_pid, libc::SIGKILL); - libc::waitpid(child_pid, std::ptr::null_mut(), 0); - } - match result { Ok(identity) => panic!( "resolve_process_identity unexpectedly succeeded for shared socket owned by PID {}", @@ -4926,7 +6637,7 @@ mod tests { err.reason ); assert!( - err.reason.contains(&std::process::id().to_string()), + err.reason.contains(&entrypoint_pid.to_string()), "error should include parent PID; got: {}", err.reason ); @@ -4938,4 +6649,40 @@ mod tests { } } } + + #[test] + fn test_emit_denial_enqueues_denial_event() { + let (tx, mut rx) = mpsc::unbounded_channel::(); + let decision = ConnectDecision { + action: NetworkAction::Deny { + reason: "no matching policy".into(), + }, + generation: 0, + binary: Some(PathBuf::from("/usr/bin/curl")), + binary_pid: Some(1234), + ancestors: vec![], + cmdline_paths: vec![], + }; + + emit_denial( + &Some(tx), + "blocked.invalid", + 443, + "/usr/bin/curl", + &decision, + "no matching policy", + "connect", + ); + + let event = rx + .try_recv() + .expect("DenialEvent should be enqueued after L4 deny"); + assert_eq!(event.host, "blocked.invalid"); + assert_eq!(event.port, 443); + assert_eq!(event.binary, "/usr/bin/curl"); + assert_eq!(event.denial_stage, "connect"); + assert_eq!(event.deny_reason, "no matching policy"); + assert!(event.l7_method.is_none()); + assert!(event.l7_path.is_none()); + } } diff --git 
a/crates/openshell-sandbox/src/sandbox/linux/landlock.rs b/crates/openshell-sandbox/src/sandbox/linux/landlock.rs index 214fc700a..6b121e0ca 100644 --- a/crates/openshell-sandbox/src/sandbox/linux/landlock.rs +++ b/crates/openshell-sandbox/src/sandbox/linux/landlock.rs @@ -119,6 +119,45 @@ pub fn prepare(policy: &SandboxPolicy, workdir: Option<&str>) -> Result { + openshell_ocsf::ocsf_emit!( + openshell_ocsf::DetectionFindingBuilder::new(crate::ocsf_ctx()) + .activity(openshell_ocsf::ActivityId::Open) + .severity(openshell_ocsf::SeverityId::High) + .confidence(openshell_ocsf::ConfidenceId::High) + .is_alert(true) + .finding_info( + openshell_ocsf::FindingInfo::new( + "landlock-unavailable", + "Landlock Filesystem Sandbox Unavailable", + ) + .with_desc(&format!( + "Running WITHOUT filesystem restrictions: Landlock is {availability}. \ + Set landlock.compatibility to 'hard_requirement' to make this fatal." + )), + ) + .message(format!( + "Landlock filesystem sandbox unavailable: {availability}" + )) + .build() + ); + return Ok(None); + } + LandlockCompatibility::HardRequirement => { + return Err(miette::miette!( + "Landlock unavailable in hard_requirement mode: {availability}" + )); + } + } + } + let total_paths = read_only.len() + read_write.len(); let abi = ABI::V2; openshell_ocsf::ocsf_emit!( @@ -135,8 +174,6 @@ pub fn prepare(policy: &SandboxPolicy, workdir: Option<&str>) -> Result = (|| { let access_all = AccessFs::from_all(abi); let access_read = AccessFs::from_read(abi); diff --git a/crates/openshell-sandbox/src/sandbox/linux/mod.rs b/crates/openshell-sandbox/src/sandbox/linux/mod.rs index 848ab1e3b..a3a32c77a 100644 --- a/crates/openshell-sandbox/src/sandbox/linux/mod.rs +++ b/crates/openshell-sandbox/src/sandbox/linux/mod.rs @@ -5,6 +5,7 @@ mod landlock; pub mod netns; +mod nft_ruleset; mod seccomp; use crate::policy::SandboxPolicy; diff --git a/crates/openshell-sandbox/src/sandbox/linux/netns.rs b/crates/openshell-sandbox/src/sandbox/linux/netns.rs index 
019036e53..433f70b1c 100644 --- a/crates/openshell-sandbox/src/sandbox/linux/netns.rs +++ b/crates/openshell-sandbox/src/sandbox/linux/netns.rs @@ -242,7 +242,7 @@ impl NetworkNamespace { self.ns_fd } - /// Install iptables rules for bypass detection inside the namespace. + /// Install nftables rules for bypass detection inside the namespace. /// /// Sets up OUTPUT chain rules that: /// 1. ACCEPT traffic destined for the proxy (`host_ip:proxy_port`) @@ -253,22 +253,21 @@ impl NetworkNamespace { /// This provides two benefits: /// - **Fast-fail UX**: applications get immediate ECONNREFUSED instead of /// a 30-second timeout when they bypass the proxy - /// - **Diagnostics**: iptables LOG entries are picked up by the bypass + /// - **Diagnostics**: nftables LOG entries are picked up by the bypass /// monitor to emit structured tracing events /// - /// Degrades gracefully if `iptables` is not available — the namespace + /// Degrades gracefully if `nft` is not available — the namespace /// still provides isolation via routing, just without fast-fail and /// diagnostic logging. pub fn install_bypass_rules(&self, proxy_port: u16) -> Result<()> { - // Check if iptables is available before attempting to install rules. 
- let Some(iptables_path) = find_iptables() else { + let Some(nft_path) = find_nft() else { openshell_ocsf::ocsf_emit!( openshell_ocsf::ConfigStateChangeBuilder::new(crate::ocsf_ctx()) .severity(openshell_ocsf::SeverityId::Medium) .status(openshell_ocsf::StatusId::Failure) .state(openshell_ocsf::StateId::Disabled, "degraded") .message(format!( - "iptables not found; bypass detection rules will not be installed [ns:{}]", + "nft not found; bypass detection rules will not be installed [ns:{}]", self.name )) .build() @@ -277,49 +276,53 @@ impl NetworkNamespace { }; let host_ip_str = self.host_ip.to_string(); - let proxy_port_str = proxy_port.to_string(); let log_prefix = format!("openshell:bypass:{}:", &self.name); - // "Installing bypass detection rules" is a transient step — skip OCSF. - // The completion event below covers the outcome. + // The kernel's nf_log_syslog module suppresses log output from + // non-init network namespaces by default. Enable it so the bypass + // monitor can see log entries from the sandbox namespace. + enable_nf_log_all_netns(); - // Install IPv4 rules - if let Err(e) = self.install_bypass_rules_for( - &iptables_path, + // Try combined ruleset with log rules first. Log rules must appear + // before reject rules in the chain so packets are logged before being + // rejected. If the kernel lacks nft_log support, fall back to the + // reject-only ruleset. + let ruleset_with_log = super::nft_ruleset::generate_bypass_ruleset( &host_ip_str, - &proxy_port_str, - &log_prefix, - ) { - openshell_ocsf::ocsf_emit!( - openshell_ocsf::ConfigStateChangeBuilder::new(crate::ocsf_ctx()) - .severity(openshell_ocsf::SeverityId::Medium) - .status(openshell_ocsf::StatusId::Failure) - .state(openshell_ocsf::StateId::Disabled, "failed") - .message(format!( - "Failed to install IPv4 bypass detection rules [ns:{}]: {e}", - self.name - )) - .build() - ); - return Err(e); - } + proxy_port, + Some(&log_prefix), + ); - // Install IPv6 rules — best-effort. 
- // Skip the proxy ACCEPT rule for IPv6 since the proxy address is IPv4. - if let Some(ip6_path) = find_ip6tables(&iptables_path) - && let Err(e) = self.install_bypass_rules_for_v6(&ip6_path, &log_prefix) - { + if let Err(e) = run_nft_netns(&self.name, &nft_path, &ruleset_with_log) { openshell_ocsf::ocsf_emit!( openshell_ocsf::ConfigStateChangeBuilder::new(crate::ocsf_ctx()) .severity(openshell_ocsf::SeverityId::Low) .status(openshell_ocsf::StatusId::Failure) .state(openshell_ocsf::StateId::Other, "degraded") .message(format!( - "Failed to install IPv6 bypass detection rules (non-fatal) [ns:{}]: {e}", + "Failed to install bypass log rules (non-fatal), falling back to reject-only [ns:{}]: {e}", self.name )) .build() ); + + let ruleset_no_log = + super::nft_ruleset::generate_bypass_ruleset(&host_ip_str, proxy_port, None); + + if let Err(e) = run_nft_netns(&self.name, &nft_path, &ruleset_no_log) { + openshell_ocsf::ocsf_emit!( + openshell_ocsf::ConfigStateChangeBuilder::new(crate::ocsf_ctx()) + .severity(openshell_ocsf::SeverityId::Medium) + .status(openshell_ocsf::StatusId::Failure) + .state(openshell_ocsf::StateId::Disabled, "failed") + .message(format!( + "Failed to install bypass detection rules [ns:{}]: {e}", + self.name + )) + .build() + ); + return Err(e); + } } openshell_ocsf::ocsf_emit!( @@ -336,297 +339,6 @@ impl NetworkNamespace { Ok(()) } - - /// Install bypass detection rules for a specific iptables variant (iptables or ip6tables). 
- fn install_bypass_rules_for( - &self, - iptables_cmd: &str, - host_ip: &str, - proxy_port: &str, - log_prefix: &str, - ) -> Result<()> { - // Rule 1: ACCEPT traffic to the proxy - run_iptables_netns( - &self.name, - iptables_cmd, - &[ - "-A", - "OUTPUT", - "-d", - &format!("{host_ip}/32"), - "-p", - "tcp", - "--dport", - proxy_port, - "-j", - "ACCEPT", - ], - )?; - - // Rule 2: ACCEPT loopback traffic - run_iptables_netns( - &self.name, - iptables_cmd, - &["-A", "OUTPUT", "-o", "lo", "-j", "ACCEPT"], - )?; - - // Rule 3: ACCEPT established/related connections (response packets) - run_iptables_netns( - &self.name, - iptables_cmd, - &[ - "-A", - "OUTPUT", - "-m", - "conntrack", - "--ctstate", - "ESTABLISHED,RELATED", - "-j", - "ACCEPT", - ], - )?; - - // Rule 4: LOG TCP SYN bypass attempts (rate-limited) - // LOG rule failure is non-fatal — the REJECT rule still provides fast-fail. - if let Err(e) = run_iptables_netns( - &self.name, - iptables_cmd, - &[ - "-A", - "OUTPUT", - "-p", - "tcp", - "--syn", - "-m", - "limit", - "--limit", - "5/sec", - "--limit-burst", - "10", - "-j", - "LOG", - "--log-prefix", - log_prefix, - "--log-uid", - ], - ) { - openshell_ocsf::ocsf_emit!(openshell_ocsf::ConfigStateChangeBuilder::new( - crate::ocsf_ctx() - ) - .severity(openshell_ocsf::SeverityId::Low) - .status(openshell_ocsf::StatusId::Failure) - .state(openshell_ocsf::StateId::Other, "degraded") - .message(format!( - "Failed to install LOG rule for TCP (xt_LOG module may not be loaded) [ns:{}]: {e}", - self.name - )) - .build()); - } - - // Rule 5: REJECT TCP bypass attempts (fast-fail) - run_iptables_netns( - &self.name, - iptables_cmd, - &[ - "-A", - "OUTPUT", - "-p", - "tcp", - "-j", - "REJECT", - "--reject-with", - "icmp-port-unreachable", - ], - )?; - - // Rule 6: LOG UDP bypass attempts (rate-limited, covers DNS bypass) - if let Err(e) = run_iptables_netns( - &self.name, - iptables_cmd, - &[ - "-A", - "OUTPUT", - "-p", - "udp", - "-m", - "limit", - "--limit", - "5/sec", - 
"--limit-burst", - "10", - "-j", - "LOG", - "--log-prefix", - log_prefix, - "--log-uid", - ], - ) { - openshell_ocsf::ocsf_emit!( - openshell_ocsf::ConfigStateChangeBuilder::new(crate::ocsf_ctx()) - .severity(openshell_ocsf::SeverityId::Low) - .status(openshell_ocsf::StatusId::Failure) - .state(openshell_ocsf::StateId::Other, "degraded") - .message(format!( - "Failed to install LOG rule for UDP [ns:{}]: {e}", - self.name - )) - .build() - ); - } - - // Rule 7: REJECT UDP bypass attempts (covers DNS bypass) - run_iptables_netns( - &self.name, - iptables_cmd, - &[ - "-A", - "OUTPUT", - "-p", - "udp", - "-j", - "REJECT", - "--reject-with", - "icmp-port-unreachable", - ], - )?; - - Ok(()) - } - - /// Install IPv6 bypass detection rules. - /// - /// Similar to `install_bypass_rules_for` but omits the proxy ACCEPT rule - /// (the proxy listens on an IPv4 address) and uses IPv6-appropriate - /// REJECT types. - fn install_bypass_rules_for_v6(&self, ip6tables_cmd: &str, log_prefix: &str) -> Result<()> { - // ACCEPT loopback traffic - run_iptables_netns( - &self.name, - ip6tables_cmd, - &["-A", "OUTPUT", "-o", "lo", "-j", "ACCEPT"], - )?; - - // ACCEPT established/related connections - run_iptables_netns( - &self.name, - ip6tables_cmd, - &[ - "-A", - "OUTPUT", - "-m", - "conntrack", - "--ctstate", - "ESTABLISHED,RELATED", - "-j", - "ACCEPT", - ], - )?; - - // LOG TCP SYN bypass attempts (rate-limited) - if let Err(e) = run_iptables_netns( - &self.name, - ip6tables_cmd, - &[ - "-A", - "OUTPUT", - "-p", - "tcp", - "--syn", - "-m", - "limit", - "--limit", - "5/sec", - "--limit-burst", - "10", - "-j", - "LOG", - "--log-prefix", - log_prefix, - "--log-uid", - ], - ) { - openshell_ocsf::ocsf_emit!( - openshell_ocsf::ConfigStateChangeBuilder::new(crate::ocsf_ctx()) - .severity(openshell_ocsf::SeverityId::Low) - .status(openshell_ocsf::StatusId::Failure) - .state(openshell_ocsf::StateId::Other, "degraded") - .message(format!( - "Failed to install IPv6 LOG rule for TCP [ns:{}]: 
{e}", - self.name - )) - .build() - ); - } - - // REJECT TCP bypass attempts - run_iptables_netns( - &self.name, - ip6tables_cmd, - &[ - "-A", - "OUTPUT", - "-p", - "tcp", - "-j", - "REJECT", - "--reject-with", - "icmp6-port-unreachable", - ], - )?; - - // LOG UDP bypass attempts (rate-limited) - if let Err(e) = run_iptables_netns( - &self.name, - ip6tables_cmd, - &[ - "-A", - "OUTPUT", - "-p", - "udp", - "-m", - "limit", - "--limit", - "5/sec", - "--limit-burst", - "10", - "-j", - "LOG", - "--log-prefix", - log_prefix, - "--log-uid", - ], - ) { - openshell_ocsf::ocsf_emit!( - openshell_ocsf::ConfigStateChangeBuilder::new(crate::ocsf_ctx()) - .severity(openshell_ocsf::SeverityId::Low) - .status(openshell_ocsf::StatusId::Failure) - .state(openshell_ocsf::StateId::Other, "degraded") - .message(format!( - "Failed to install IPv6 LOG rule for UDP [ns:{}]: {e}", - self.name - )) - .build() - ); - } - - // REJECT UDP bypass attempts - run_iptables_netns( - &self.name, - ip6tables_cmd, - &[ - "-A", - "OUTPUT", - "-p", - "udp", - "-j", - "REJECT", - "--reject-with", - "icmp6-port-unreachable", - ], - )?; - - Ok(()) - } } impl Drop for NetworkNamespace { @@ -732,34 +444,43 @@ fn run_ip_netns(netns: &str, args: &[&str]) -> Result<()> { Ok(()) } -/// Run an iptables command inside a network namespace via `nsenter --net=`. +/// Load an nftables ruleset inside a network namespace via `nsenter --net=`. /// -/// Uses `nsenter` instead of `ip netns exec` to avoid the sysfs remount -/// that fails in rootless container runtimes. See `run_ip_netns` for details. -fn run_iptables_netns(netns: &str, iptables_cmd: &str, args: &[&str]) -> Result<()> { +/// Writes the ruleset to a temp file and loads it with `nft -f `. +/// A temp file is used instead of piping to stdin (`nft -f -`) because +/// `nft` resolves `-` to `/dev/stdin`, which may not exist in minimal +/// VM guest environments (e.g. virtiofs rootfs without /proc mounted +/// at nft invocation time). 
+fn run_nft_netns(netns: &str, nft_cmd: &str, ruleset: &str) -> Result<()> { + use std::io::Write; + let mut tmp = tempfile::Builder::new() + .prefix("openshell-nft-") + .suffix(".conf") + .tempfile() + .into_diagnostic()?; + tmp.write_all(ruleset.as_bytes()).into_diagnostic()?; + let ruleset_path = tmp.path().to_string_lossy().to_string(); + let nsenter_path = find_trusted_binary("nsenter", NSENTER_SEARCH_PATHS)?; let ns_path = format!("/var/run/netns/{netns}"); let net_flag = format!("--net={ns_path}"); - let mut full_args = vec![net_flag.as_str(), "--", iptables_cmd]; - full_args.extend(args); - debug!( - command = %format!("{nsenter_path} {}", full_args.join(" ")), - "Running iptables in namespace via nsenter" + command = %format!("{nsenter_path} {net_flag} -- {nft_cmd} -f {ruleset_path}"), + "Loading nftables ruleset in namespace" ); let output = Command::new(nsenter_path) - .args(&full_args) + .args([net_flag.as_str(), "--", nft_cmd, "-f", &ruleset_path]) .output() .into_diagnostic()?; + drop(tmp); + if !output.status.success() { let stderr = String::from_utf8_lossy(&output.stderr); return Err(miette::miette!( - "{nsenter_path} --net={} {} failed: {}", - ns_path, - iptables_cmd, + "nft ruleset load failed in netns {netns}: {}", stderr.trim() )); } @@ -767,11 +488,35 @@ fn run_iptables_netns(netns: &str, iptables_cmd: &str, args: &[&str]) -> Result< Ok(()) } -/// Well-known paths where iptables may be installed. -/// The sandbox container PATH often excludes `/usr/sbin`, so we probe -/// explicit paths rather than relying on `which`. -const IPTABLES_SEARCH_PATHS: &[&str] = - &["/usr/sbin/iptables", "/sbin/iptables", "/usr/bin/iptables"]; +const NF_LOG_ALL_NETNS_PATH: &str = "/proc/sys/net/netfilter/nf_log_all_netns"; + +/// Enable nftables logging from non-init network namespaces. +/// +/// The kernel's `nf_log_syslog` module silently suppresses log output from +/// non-init network namespaces unless `net.netfilter.nf_log_all_netns` is +/// set to 1. 
Since sandbox bypass rules live in a per-sandbox network +/// namespace, the bypass monitor can't see log entries without this. +fn enable_nf_log_all_netns() { + use std::path::Path; + if !Path::new(NF_LOG_ALL_NETNS_PATH).exists() { + debug!("nf_log_all_netns sysctl not available (may already be set by init)"); + return; + } + match std::fs::write(NF_LOG_ALL_NETNS_PATH, "1") { + Ok(()) => { + debug!("Enabled nf_log_all_netns for non-init namespace logging"); + } + Err(e) => { + debug!( + error = %e, + "Could not enable nf_log_all_netns; bypass log rules may not produce output" + ); + } + } +} + +/// Well-known paths where nft may be installed. +const NFT_SEARCH_PATHS: &[&str] = &["/usr/sbin/nft", "/sbin/nft", "/usr/bin/nft"]; fn find_trusted_binary<'a>(name: &str, paths: &'a [&str]) -> Result<&'a str> { paths @@ -789,100 +534,11 @@ fn find_trusted_binary<'a>(name: &str, paths: &'a [&str]) -> Result<&'a str> { }) } -/// Returns true if xt extension modules (e.g. `xt_comment`) cannot be used -/// via the given iptables binary. -/// -/// Some kernels have `nf_tables` but lack the `nft_compat` bridge that allows -/// xt extension modules to be used through the `nf_tables` path (e.g. Jetson -/// Linux 5.15-tegra). This probe detects that condition by attempting to -/// insert a rule using the `xt_comment` extension. If it fails, xt extensions -/// are unavailable and the caller should fall back to iptables-legacy. -fn xt_extensions_unavailable(iptables_path: &str) -> bool { - // Create a temporary probe chain. If this fails (e.g. no CAP_NET_ADMIN), - // we can't determine availability — assume extensions are available. - let created = Command::new(iptables_path) - .args(["-t", "filter", "-N", "_xt_probe"]) - .output() - .is_ok_and(|o| o.status.success()); - - if !created { - return false; - } - - // Attempt to insert a rule using xt_comment. Failure means nft_compat - // cannot bridge xt extension modules on this kernel. 
- let probe_ok = Command::new(iptables_path) - .args([ - "-t", - "filter", - "-A", - "_xt_probe", - "-m", - "comment", - "--comment", - "probe", - "-j", - "ACCEPT", - ]) - .output() - .is_ok_and(|o| o.status.success()); - - // Clean up — best-effort, ignore failures. - let _ = Command::new(iptables_path) - .args([ - "-t", - "filter", - "-D", - "_xt_probe", - "-m", - "comment", - "--comment", - "probe", - "-j", - "ACCEPT", - ]) - .output(); - let _ = Command::new(iptables_path) - .args(["-t", "filter", "-X", "_xt_probe"]) - .output(); - - !probe_ok -} - -/// Find the iptables binary path, checking well-known locations. -/// -/// If xt extension modules are unavailable via the standard binary and -/// `iptables-legacy` is available alongside it, the legacy binary is returned -/// instead. This ensures bypass-detection rules can be installed on kernels -/// where `nft_compat` is unavailable (e.g. Jetson Linux 5.15-tegra). -fn find_iptables() -> Option { - let standard_path = IPTABLES_SEARCH_PATHS - .iter() - .find(|path| Path::new(path).exists()) - .copied()?; - - if xt_extensions_unavailable(standard_path) { - let legacy_path = standard_path.replace("iptables", "iptables-legacy"); - if Path::new(&legacy_path).exists() { - debug!( - legacy = legacy_path, - "xt extensions unavailable; using iptables-legacy" - ); - return Some(legacy_path); - } - } - - Some(standard_path.to_string()) -} - -/// Find the ip6tables binary path, deriving it from the iptables location. -fn find_ip6tables(iptables_path: &str) -> Option { - let ip6_path = iptables_path.replace("iptables", "ip6tables"); - if Path::new(&ip6_path).exists() { - Some(ip6_path) - } else { - None - } +/// Find the nft binary path, checking well-known locations. 
+fn find_nft() -> Option<String> { + find_trusted_binary("nft", NFT_SEARCH_PATHS) + .ok() + .map(String::from) } #[cfg(test)] @@ -914,6 +570,16 @@ mod tests { assert!(err.to_string().contains("trusted nsenter helper not found")); } + #[test] + fn nft_search_paths_are_absolute() { + for path in NFT_SEARCH_PATHS { + assert!( + path.starts_with('/'), + "NFT_SEARCH_PATHS entry must be absolute: {path}" + ); + } + } + #[test] #[ignore = "requires root privileges"] fn test_create_and_drop_namespace() { diff --git a/crates/openshell-sandbox/src/sandbox/linux/nft_ruleset.rs b/crates/openshell-sandbox/src/sandbox/linux/nft_ruleset.rs new file mode 100644 index 000000000..ba7aeb936 --- /dev/null +++ b/crates/openshell-sandbox/src/sandbox/linux/nft_ruleset.rs @@ -0,0 +1,148 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! nftables ruleset generation for sandbox network bypass enforcement. +//! +//! This module provides pure functions to generate nftables rulesets that enforce +//! the sandbox network policy: all traffic must go through the proxy, with bypass +//! attempts logged and rejected. + +/// Generate a complete nftables ruleset for sandbox network bypass enforcement. +/// +/// Creates an `inet` family table (handles both IPv4 and IPv6) with rules that: +/// 1. Accept traffic to the proxy (IPv4 only) +/// 2. Accept loopback traffic +/// 3. Accept established/related connections +/// 4. Reject TCP and UDP bypass attempts (both IPv4 and IPv6) +/// +/// If `log_prefix` is provided, log rules are inserted before each reject rule +/// so that bypass attempts are recorded in the kernel ring buffer before being +/// rejected. The `log` expression requires kernel `nft_log` module support; +/// pass `None` for `log_prefix` as a fallback when that module is unavailable.
+pub fn generate_bypass_ruleset(host_ip: &str, proxy_port: u16, log_prefix: Option<&str>) -> String { + let log_tcp = log_prefix + .map(|p| { + format!( + "\n tcp flags syn limit rate 5/second burst 10 packets log prefix \"{p}\" flags skuid" + ) + }) + .unwrap_or_default(); + let log_udp = log_prefix + .map(|p| { + format!( + "\n meta l4proto udp limit rate 5/second burst 10 packets log prefix \"{p}\" flags skuid" + ) + }) + .unwrap_or_default(); + + format!( + r#"table inet openshell_bypass {{ + chain output {{ + type filter hook output priority 0; policy accept; + + ip daddr {host_ip} tcp dport {proxy_port} accept + oifname "lo" accept + ct state established,related accept{log_tcp} + meta nfproto ipv4 meta l4proto tcp reject with icmp type port-unreachable + meta nfproto ipv6 meta l4proto tcp reject with icmpv6 type port-unreachable{log_udp} + meta nfproto ipv4 meta l4proto udp reject with icmp type port-unreachable + meta nfproto ipv6 meta l4proto udp reject with icmpv6 type port-unreachable + }} +}} +"# + ) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn generates_bypass_ruleset_with_proxy_rule() { + let ruleset = generate_bypass_ruleset("10.0.2.2", 8080, None); + assert!(ruleset.contains("table inet openshell_bypass")); + assert!(ruleset.contains("chain output")); + assert!(ruleset.contains("ip daddr 10.0.2.2 tcp dport 8080 accept")); + } + + #[test] + fn ruleset_has_inet_family_table_and_output_chain() { + let ruleset = generate_bypass_ruleset("192.168.1.1", 3128, None); + assert!(ruleset.contains("table inet openshell_bypass")); + assert!(ruleset.contains("type filter hook output priority 0; policy accept;")); + } + + #[test] + fn proxy_accept_rule_uses_provided_ip_and_port() { + let ruleset = generate_bypass_ruleset("172.16.0.1", 9999, None); + assert!(ruleset.contains("ip daddr 172.16.0.1 tcp dport 9999 accept")); + } + + #[test] + fn rules_are_ordered_accept_then_reject() { + let ruleset = generate_bypass_ruleset("10.0.2.2", 8080, None); 
+ let proxy_pos = ruleset.find("ip daddr").unwrap(); + let lo_pos = ruleset.find("oifname \"lo\"").unwrap(); + let ct_pos = ruleset.find("ct state established,related").unwrap(); + let reject_pos = ruleset.find("reject with icmp type").unwrap(); + + assert!(proxy_pos < lo_pos); + assert!(lo_pos < ct_pos); + assert!(ct_pos < reject_pos); + } + + #[test] + fn both_ipv4_and_ipv6_reject_types_are_present() { + let ruleset = generate_bypass_ruleset("10.0.2.2", 8080, None); + let icmp_count = ruleset + .matches("reject with icmp type port-unreachable") + .count(); + let icmpv6_count = ruleset + .matches("reject with icmpv6 type port-unreachable") + .count(); + assert_eq!(icmp_count, 2, "need IPv4 ICMP rejects for TCP + UDP"); + assert_eq!(icmpv6_count, 2, "need IPv6 ICMPv6 rejects for TCP + UDP"); + } + + #[test] + fn no_log_ruleset_omits_log_rules() { + let ruleset = generate_bypass_ruleset("10.0.2.2", 8080, None); + assert!( + !ruleset.contains("log prefix"), + "no-log ruleset must not contain log rules" + ); + } + + #[test] + fn log_ruleset_contains_prefix_for_tcp_and_udp() { + let ruleset = generate_bypass_ruleset("10.0.2.2", 8080, Some("openshell:bypass:test:")); + let count = ruleset + .matches("log prefix \"openshell:bypass:test:\"") + .count(); + assert_eq!(count, 2, "need log rules for both TCP and UDP"); + assert!(ruleset.contains("tcp flags syn limit rate 5/second burst 10 packets")); + assert!(ruleset.contains("meta l4proto udp limit rate 5/second burst 10 packets")); + } + + #[test] + fn log_rules_appear_before_reject_rules() { + let ruleset = generate_bypass_ruleset("10.0.2.2", 8080, Some("openshell:bypass:test:")); + let tcp_log_pos = ruleset.find("tcp flags syn").unwrap(); + let tcp_reject_pos = ruleset + .find("meta nfproto ipv4 meta l4proto tcp reject") + .unwrap(); + let udp_log_pos = ruleset.find("meta l4proto udp limit rate").unwrap(); + let udp_reject_pos = ruleset + .find("meta nfproto ipv4 meta l4proto udp reject") + .unwrap(); + + assert!( + 
tcp_log_pos < tcp_reject_pos, + "TCP log rule must come before TCP reject rule" + ); + assert!( + udp_log_pos < udp_reject_pos, + "UDP log rule must come before UDP reject rule" + ); + } +} diff --git a/crates/openshell-sandbox/src/sandbox/linux/seccomp.rs b/crates/openshell-sandbox/src/sandbox/linux/seccomp.rs index f61464023..1044623f5 100644 --- a/crates/openshell-sandbox/src/sandbox/linux/seccomp.rs +++ b/crates/openshell-sandbox/src/sandbox/linux/seccomp.rs @@ -25,6 +25,16 @@ use tracing::debug; /// Value of `SECCOMP_SET_MODE_FILTER` (linux/seccomp.h). const SECCOMP_SET_MODE_FILTER: u64 = 1; +// libc 0.2.185 omits `SYS_kexec_file_load` from the musl/aarch64 bindings even +// though the kernel exposes syscall 294. Fall back to the literal so the +// supervisor's seccomp filter still blocks fileless kernel-image loads when +// built statically against musl on aarch64. +#[cfg(all(target_arch = "aarch64", target_env = "musl"))] +#[allow(non_upper_case_globals)] +const SYS_kexec_file_load: libc::c_long = 294; +#[cfg(not(all(target_arch = "aarch64", target_env = "musl")))] +use libc::SYS_kexec_file_load; + /// Apply the supervisor seccomp filter across the running process. 
/// /// This runs after privileged startup helpers complete and synchronizes the @@ -81,7 +91,7 @@ fn build_supervisor_prelude_rules() -> BTreeMap> { libc::SYS_finit_module, libc::SYS_delete_module, libc::SYS_kexec_load, - libc::SYS_kexec_file_load, + SYS_kexec_file_load, ] { rules.entry(syscall).or_default(); } @@ -423,7 +433,7 @@ mod tests { libc::SYS_finit_module, libc::SYS_delete_module, libc::SYS_kexec_load, - libc::SYS_kexec_file_load, + SYS_kexec_file_load, ] { assert!( filter_rules.contains_key(&syscall), diff --git a/crates/openshell-sandbox/src/secrets.rs b/crates/openshell-sandbox/src/secrets.rs index 63e253e50..de7804393 100644 --- a/crates/openshell-sandbox/src/secrets.rs +++ b/crates/openshell-sandbox/src/secrets.rs @@ -6,9 +6,11 @@ use std::collections::HashMap; use std::fmt; const PLACEHOLDER_PREFIX: &str = "openshell:resolve:env:"; +const PROVIDER_ALIAS_MARKER: &str = "OPENSHELL-RESOLVE-ENV-"; /// Public access to the placeholder prefix for fail-closed scanning in other modules. pub const PLACEHOLDER_PREFIX_PUBLIC: &str = PLACEHOLDER_PREFIX; +pub const PROVIDER_ALIAS_MARKER_PUBLIC: &str = PROVIDER_ALIAS_MARKER; /// Characters that are valid in an env var key name (used to extract /// placeholder boundaries within concatenated strings like path segments). @@ -16,6 +18,29 @@ fn is_env_key_char(b: u8) -> bool { b.is_ascii_alphanumeric() || b == b'_' } +fn is_alias_token_char(b: u8) -> bool { + b.is_ascii_alphanumeric() || matches!(b, b'_' | b'-' | b'.' 
| b'~') +} + +fn contains_raw_reserved_marker(value: &str) -> bool { + value.contains(PLACEHOLDER_PREFIX) || value.contains(PROVIDER_ALIAS_MARKER) +} + +fn current_time_ms() -> i64 { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|duration| i64::try_from(duration.as_millis()).unwrap_or(i64::MAX)) + .unwrap_or_default() +} + +pub fn contains_reserved_credential_marker(value: &str) -> bool { + if contains_raw_reserved_marker(value) { + return true; + } + let decoded = percent_decode(value); + contains_raw_reserved_marker(&decoded) +} + // --------------------------------------------------------------------------- // Error and result types // --------------------------------------------------------------------------- @@ -31,7 +56,7 @@ impl fmt::Display for UnresolvedPlaceholderError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, - "unresolved credential placeholder in {}: detected openshell:resolve:env:* token that could not be resolved", + "unresolved credential placeholder in {}: detected reserved credential token that could not be resolved", self.location ) } @@ -64,14 +89,91 @@ pub struct RewriteTargetResult { // SecretResolver // --------------------------------------------------------------------------- -#[derive(Debug, Clone, Default)] +#[derive(Clone, Default)] pub struct SecretResolver { - by_placeholder: HashMap, + by_placeholder: HashMap, +} + +#[derive(Clone)] +struct SecretValue { + value: String, + expires_at_ms: i64, +} + +// Manual `Debug` impl: the auto-derived `Debug` would format the +// `by_placeholder` map, exposing both placeholder keys (which reveal which +// provider env var names are configured) and the resolved secret values +// themselves. Any accidental `{:?}` in a tracing call, or a +// derived `Debug` on a containing struct, would write secrets to logs. +// +// We expose only the count of registered placeholders without leaking anything. 
+impl fmt::Debug for SecretResolver { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("SecretResolver") + .field("placeholders", &self.by_placeholder.len()) + .finish_non_exhaustive() // Use to show that the struct is not empty + } } impl SecretResolver { + #[cfg_attr(not(test), allow(dead_code))] pub(crate) fn from_provider_env( provider_env: HashMap, + ) -> (HashMap, Option) { + Self::from_provider_env_for_revision(provider_env, HashMap::new(), 0) + } + + pub(crate) fn from_provider_env_for_revision( + provider_env: HashMap, + credential_expires_at_ms: HashMap, + revision: u64, + ) -> (HashMap, Option) { + Self::from_provider_env_for_revision_with_current_aliases( + provider_env, + credential_expires_at_ms, + revision, + false, + ) + } + + pub(crate) fn from_provider_env_for_current_revision( + provider_env: HashMap, + credential_expires_at_ms: HashMap, + revision: u64, + ) -> (HashMap, Option, Option) { + if revision == 0 { + let (child_env, current_resolver) = + Self::from_provider_env_for_revision_with_current_aliases( + provider_env, + credential_expires_at_ms, + 0, + true, + ); + return (child_env, None, current_resolver); + } + let provider_env_for_current = provider_env.clone(); + let credential_expires_at_ms_for_current = credential_expires_at_ms.clone(); + let (child_env, revision_resolver) = + Self::from_provider_env_for_revision_with_current_aliases( + provider_env, + credential_expires_at_ms, + revision, + false, + ); + let (_, current_resolver) = Self::from_provider_env_for_revision_with_current_aliases( + provider_env_for_current, + credential_expires_at_ms_for_current, + revision, + true, + ); + (child_env, revision_resolver, current_resolver) + } + + fn from_provider_env_for_revision_with_current_aliases( + provider_env: HashMap, + credential_expires_at_ms: HashMap, + revision: u64, + include_current_aliases: bool, ) -> (HashMap, Option) { if provider_env.is_empty() { return (HashMap::new(), None); @@ -81,21 +183,56 
@@ impl SecretResolver { let mut by_placeholder = HashMap::with_capacity(provider_env.len()); for (key, value) in provider_env { - let placeholder = placeholder_for_env_key(&key); - child_env.insert(key, placeholder.clone()); - by_placeholder.insert(placeholder, value); + let placeholder = placeholder_for_env_key_for_revision(&key, revision); + let secret = SecretValue { + value, + expires_at_ms: credential_expires_at_ms + .get(&key) + .copied() + .unwrap_or_default(), + }; + child_env.insert(key.clone(), placeholder.clone()); + by_placeholder.insert(placeholder, secret.clone()); + if include_current_aliases && revision != 0 { + by_placeholder.insert(placeholder_for_env_key(&key), secret.clone()); + } } (child_env, Some(Self { by_placeholder })) } + pub(crate) fn merge<'a>(resolvers: impl IntoIterator) -> Option { + let mut by_placeholder = HashMap::new(); + for resolver in resolvers { + by_placeholder.extend(resolver.by_placeholder.clone()); + } + if by_placeholder.is_empty() { + None + } else { + Some(Self { by_placeholder }) + } + } + /// Resolve a placeholder string to the real secret value. /// /// Returns `None` if the placeholder is unknown or the resolved value /// contains prohibited control characters (CRLF, null byte). pub(crate) fn resolve_placeholder(&self, value: &str) -> Option<&str> { - let secret = self.by_placeholder.get(value).map(String::as_str)?; - match validate_resolved_secret(secret) { + let secret = if let Some(secret) = self.by_placeholder.get(value) { + secret + } else { + let key = alias_env_key(value)?; + let canonical = placeholder_for_env_key(key); + self.by_placeholder.get(&canonical)? 
+ }; + if secret.expires_at_ms > 0 && secret.expires_at_ms <= current_time_ms() { + tracing::warn!( + location = "resolve_placeholder", + "credential resolution rejected: credential is expired" + ); + return None; + } + match validate_resolved_secret(&secret.value) { Ok(s) => Some(s), Err(reason) => { tracing::warn!( @@ -108,10 +245,13 @@ impl SecretResolver { } } - pub(crate) fn rewrite_header_value(&self, value: &str) -> Option { + pub(crate) fn rewrite_header_value( + &self, + value: &str, + ) -> Result, UnresolvedPlaceholderError> { // Direct placeholder match: `x-api-key: openshell:resolve:env:KEY` if let Some(secret) = self.resolve_placeholder(value.trim()) { - return Some(secret.to_string()); + return Ok(Some(secret.to_string())); } let trimmed = value.trim(); @@ -122,62 +262,242 @@ impl SecretResolver { .strip_prefix("Basic ") .or_else(|| trimmed.strip_prefix("basic ")) .map(str::trim) - && let Some(rewritten) = self.rewrite_basic_auth_token(encoded) + && let Some(rewritten) = self.rewrite_basic_auth_token(encoded)? 
{ - return Some(format!("Basic {rewritten}")); + return Ok(Some(format!("Basic {rewritten}"))); } // Prefixed placeholder: `Bearer openshell:resolve:env:KEY` - let split_at = trimmed.find(char::is_whitespace)?; + let Some(split_at) = trimmed.find(char::is_whitespace) else { + if contains_reserved_credential_marker(trimmed) { + return Err(UnresolvedPlaceholderError { location: "header" }); + } + return Ok(None); + }; let prefix = &trimmed[..split_at]; let candidate = trimmed[split_at..].trim(); - let secret = self.resolve_placeholder(candidate)?; - Some(format!("{prefix} {secret}")) + if let Some(secret) = self.resolve_placeholder(candidate) { + return Ok(Some(format!("{prefix} {secret}"))); + } + + if contains_reserved_credential_marker(candidate) { + return Err(UnresolvedPlaceholderError { location: "header" }); + } + + Ok(None) + } + + pub(crate) fn rewrite_text_placeholders( + &self, + text: &mut String, + location: &'static str, + ) -> Result { + if !contains_raw_reserved_marker(text) { + return Ok(0); + } + + let mut rewritten = String::with_capacity(text.len()); + let mut pos = 0; + let mut replacements = 0; + + while pos < text.len() { + let next_canonical = text[pos..].find(PLACEHOLDER_PREFIX).map(|p| pos + p); + let next_alias = text[pos..].find(PROVIDER_ALIAS_MARKER).map(|marker_pos| { + let marker_abs = pos + marker_pos; + alias_start_for_marker(text, marker_abs) + }); + let Some(abs_start) = [next_canonical, next_alias].into_iter().flatten().min() else { + rewritten.push_str(&text[pos..]); + break; + }; + + rewritten.push_str(&text[pos..abs_start]); + + if text[abs_start..].starts_with(PLACEHOLDER_PREFIX) { + let Some((token_end, token)) = self.credential_token_at(text, abs_start) else { + return Err(UnresolvedPlaceholderError { location }); + }; + let Some(secret) = self.resolve_placeholder(token) else { + return Err(UnresolvedPlaceholderError { location }); + }; + rewritten.push_str(secret); + replacements += 1; + pos = token_end; + continue; + } + + 
if let Some((token_end, token)) = alias_token_at(text, abs_start) { + let Some(secret) = self.resolve_placeholder(token) else { + return Err(UnresolvedPlaceholderError { location }); + }; + rewritten.push_str(secret); + replacements += 1; + pos = token_end; + continue; + } + + return Err(UnresolvedPlaceholderError { location }); + } + + if contains_raw_reserved_marker(&rewritten) { + return Err(UnresolvedPlaceholderError { location }); + } + + *text = rewritten; + Ok(replacements) + } + + /// Rewrite credential placeholders inside a WebSocket text message. + /// + /// The message is mutated only after all placeholders resolve + /// successfully. The return value is the number of replacements; callers + /// must not log the rewritten text. + pub(crate) fn rewrite_websocket_text_placeholders( + &self, + text: &mut String, + ) -> Result { + self.rewrite_text_placeholders(text, "websocket") + } + + fn credential_token_at<'a>( + &'a self, + text: &'a str, + abs_start: usize, + ) -> Option<(usize, &'a str)> { + self.longest_known_token_match(text, abs_start) + .or_else(|| canonical_token_at(text, abs_start)) + .or_else(|| alias_token_at(text, abs_start)) + } + + fn longest_known_token_match<'a>( + &'a self, + text: &str, + abs_start: usize, + ) -> Option<(usize, &'a str)> { + let suffix = &text[abs_start..]; + self.by_placeholder + .keys() + .filter_map(|placeholder| { + if !suffix.starts_with(placeholder) { + return None; + } + let key_end = abs_start + placeholder.len(); + let boundary_ok = token_boundary_ok(text, abs_start, key_end, placeholder); + boundary_ok.then_some((key_end, placeholder.as_str())) + }) + .max_by_key(|(_, placeholder)| placeholder.len()) } /// Decode a Base64-encoded Basic auth token, resolve any placeholders in /// the decoded `username:password` string, and re-encode. /// /// Returns `None` if decoding fails or no placeholders are found. 
- fn rewrite_basic_auth_token(&self, encoded: &str) -> Option { + fn rewrite_basic_auth_token( + &self, + encoded: &str, + ) -> Result, UnresolvedPlaceholderError> { let b64 = base64::engine::general_purpose::STANDARD; - let decoded_bytes = b64.decode(encoded.trim()).ok()?; - let decoded = std::str::from_utf8(&decoded_bytes).ok()?; - - // Check if the decoded string contains any placeholder - if !decoded.contains(PLACEHOLDER_PREFIX) { - return None; + let Some(decoded_bytes) = b64.decode(encoded.trim()).ok() else { + return Ok(None); + }; + let Some(decoded) = std::str::from_utf8(&decoded_bytes).ok() else { + return Ok(None); + }; + + if !contains_raw_reserved_marker(decoded) { + return Ok(None); } - // Rewrite all placeholder occurrences in the decoded string let mut rewritten = decoded.to_string(); - for (placeholder, secret) in &self.by_placeholder { - if rewritten.contains(placeholder.as_str()) { - // Validate the resolved secret for control characters - if validate_resolved_secret(secret).is_err() { - tracing::warn!( - location = "basic_auth", - "credential resolution rejected: resolved value contains prohibited characters" - ); - return None; - } - rewritten = rewritten.replace(placeholder.as_str(), secret); - } - } + let replacements = self.rewrite_text_placeholders(&mut rewritten, "header")?; - // Only return if we actually changed something - if rewritten == decoded { - return None; + if replacements == 0 { + return Ok(None); } - Some(b64.encode(rewritten.as_bytes())) + Ok(Some(b64.encode(rewritten.as_bytes()))) + } +} + +fn alias_start_for_marker(text: &str, marker_abs: usize) -> usize { + let mut start = marker_abs; + let bytes = text.as_bytes(); + while start > 0 && is_alias_token_char(bytes[start - 1]) { + start -= 1; + } + start +} + +fn canonical_token_at(text: &str, abs_start: usize) -> Option<(usize, &str)> { + if !text[abs_start..].starts_with(PLACEHOLDER_PREFIX) { + return None; } + let key_start = abs_start + PLACEHOLDER_PREFIX.len(); + let 
key_end = text[key_start..] + .bytes() + .position(|b| !is_env_key_char(b)) + .map_or(text.len(), |p| key_start + p); + (key_end > key_start).then_some((key_end, &text[abs_start..key_end])) +} + +fn alias_token_at(text: &str, abs_start: usize) -> Option<(usize, &str)> { + let suffix = &text[abs_start..]; + let marker_rel = suffix.find(PROVIDER_ALIAS_MARKER)?; + if marker_rel == 0 { + return None; + } + let key_start = abs_start + marker_rel + PROVIDER_ALIAS_MARKER.len(); + let key_end = text[key_start..] + .bytes() + .position(|b| !is_env_key_char(b)) + .map_or(text.len(), |p| key_start + p); + if key_end == key_start { + return None; + } + let before_ok = abs_start == 0 || !is_alias_token_char(text.as_bytes()[abs_start - 1]); + let after_ok = key_end == text.len() || !is_alias_token_char(text.as_bytes()[key_end]); + (before_ok && after_ok).then_some((key_end, &text[abs_start..key_end])) +} + +fn alias_env_key(token: &str) -> Option<&str> { + let marker_start = token.find(PROVIDER_ALIAS_MARKER)?; + if marker_start == 0 { + return None; + } + if !token[..marker_start].bytes().all(is_alias_token_char) { + return None; + } + let key_start = marker_start + PROVIDER_ALIAS_MARKER.len(); + let key_end = token[key_start..] 
+ .bytes() + .position(|b| !is_env_key_char(b)) + .map_or(token.len(), |p| key_start + p); + (key_end == token.len() && key_end > key_start).then_some(&token[key_start..key_end]) +} + +fn token_boundary_ok(text: &str, abs_start: usize, token_end: usize, token: &str) -> bool { + if token.starts_with(PLACEHOLDER_PREFIX) { + return token_end == text.len() + || !is_env_key_char(text.as_bytes()[token_end]) + || text[token_end..].starts_with(PLACEHOLDER_PREFIX); + } + let before_ok = abs_start == 0 || !is_alias_token_char(text.as_bytes()[abs_start - 1]); + let after_ok = token_end == text.len() || !is_alias_token_char(text.as_bytes()[token_end]); + before_ok && after_ok } pub fn placeholder_for_env_key(key: &str) -> String { format!("{PLACEHOLDER_PREFIX}{key}") } +pub fn placeholder_for_env_key_for_revision(key: &str, revision: u64) -> String { + if revision == 0 { + placeholder_for_env_key(key) + } else { + format!("{PLACEHOLDER_PREFIX}v{revision}_{key}") + } +} + // --------------------------------------------------------------------------- // Secret validation (F1 — CWE-113) // --------------------------------------------------------------------------- @@ -359,8 +679,9 @@ fn rewrite_request_line( return unchanged(); }; - // Only rewrite if the URI contains a placeholder - if !uri.contains(PLACEHOLDER_PREFIX) { + // Only rewrite if the URI contains a placeholder or a provider-shaped + // credential alias, including percent-encoded canonical placeholders. 
+ if !contains_reserved_credential_marker(uri) { return unchanged(); } @@ -416,10 +737,6 @@ fn rewrite_uri_path( path: &str, resolver: &SecretResolver, ) -> Result, UnresolvedPlaceholderError> { - if !path.contains(PLACEHOLDER_PREFIX) { - return Ok(None); - } - let segments: Vec<&str> = path.split('/').collect(); let mut resolved_segments = Vec::with_capacity(segments.len()); let mut redacted_segments = Vec::with_capacity(segments.len()); @@ -427,7 +744,7 @@ fn rewrite_uri_path( for segment in &segments { let decoded = percent_decode(segment); - if !decoded.contains(PLACEHOLDER_PREFIX) { + if !contains_raw_reserved_marker(&decoded) { resolved_segments.push(segment.to_string()); redacted_segments.push(segment.to_string()); continue; @@ -467,28 +784,23 @@ fn rewrite_path_segment( let bytes = segment.as_bytes(); while pos < bytes.len() { - if let Some(start) = segment[pos..].find(PLACEHOLDER_PREFIX) { - let abs_start = pos + start; + let next_canonical = segment[pos..].find(PLACEHOLDER_PREFIX).map(|p| pos + p); + let next_alias = segment[pos..] + .find(PROVIDER_ALIAS_MARKER) + .map(|marker_pos| { + let marker_abs = pos + marker_pos; + alias_start_for_marker(segment, marker_abs) + }); + if let Some(abs_start) = [next_canonical, next_alias].into_iter().flatten().min() { // Copy literal prefix before the placeholder resolved.push_str(&segment[pos..abs_start]); redacted.push_str(&segment[pos..abs_start]); - // Extract the key name using the env var grammar: [A-Za-z_][A-Za-z0-9_]* - let key_start = abs_start + PLACEHOLDER_PREFIX.len(); - let key_end = segment[key_start..] 
- .bytes() - .position(|b| !is_env_key_char(b)) - .map_or(segment.len(), |p| key_start + p); - - if key_end == key_start { - // Empty key — not a valid placeholder, copy literally - resolved.push_str(&segment[abs_start..abs_start + PLACEHOLDER_PREFIX.len()]); - redacted.push_str(&segment[abs_start..abs_start + PLACEHOLDER_PREFIX.len()]); - pos = abs_start + PLACEHOLDER_PREFIX.len(); - continue; - } - - let full_placeholder = &segment[abs_start..key_end]; + let Some((token_end, full_placeholder)) = canonical_token_at(segment, abs_start) + .or_else(|| alias_token_at(segment, abs_start)) + else { + return Err(UnresolvedPlaceholderError { location: "path" }); + }; if let Some(secret) = resolver.resolve_placeholder(full_placeholder) { validate_credential_for_path(secret).map_err(|reason| { tracing::warn!( @@ -503,7 +815,7 @@ fn rewrite_path_segment( } else { return Err(UnresolvedPlaceholderError { location: "path" }); } - pos = key_end; + pos = token_end; } else { // No more placeholders in remainder resolved.push_str(&segment[pos..]); @@ -522,7 +834,7 @@ fn rewrite_uri_query_params( query: &str, resolver: &SecretResolver, ) -> Result, UnresolvedPlaceholderError> { - if !query.contains(PLACEHOLDER_PREFIX) { + if !contains_reserved_credential_marker(query) { return Ok(None); } @@ -533,15 +845,18 @@ fn rewrite_uri_query_params( for param in query.split('&') { if let Some((key, value)) = param.split_once('=') { let decoded_value = percent_decode(value); - if let Some(secret) = resolver.resolve_placeholder(&decoded_value) { - resolved_params.push(format!("{key}={}", percent_encode_query(secret))); + if contains_raw_reserved_marker(&decoded_value) { + let mut rewritten = decoded_value.clone(); + let replacements = + resolver.rewrite_text_placeholders(&mut rewritten, "query_param")?; + if replacements == 0 || contains_raw_reserved_marker(&rewritten) { + return Err(UnresolvedPlaceholderError { + location: "query_param", + }); + } + resolved_params.push(format!("{key}={}", 
percent_encode_query(&rewritten))); redacted_params.push(format!("{key}=[CREDENTIAL]")); any_rewritten = true; - } else if decoded_value.contains(PLACEHOLDER_PREFIX) { - // Placeholder detected but not resolved - return Err(UnresolvedPlaceholderError { - location: "query_param", - }); } else { resolved_params.push(param.to_string()); redacted_params.push(param.to_string()); @@ -611,41 +926,42 @@ pub fn rewrite_http_header_block( break; } - output.extend_from_slice(rewrite_header_line(line, resolver).as_bytes()); + output.extend_from_slice(rewrite_header_line_checked(line, resolver)?.as_bytes()); output.extend_from_slice(b"\r\n"); } output.extend_from_slice(b"\r\n"); output.extend_from_slice(&raw[header_end..]); - // Fail-closed scan: check for any remaining unresolved placeholders - // in both raw form and percent-decoded form of the output header block. + // Fail-closed scan: check for any remaining unresolved placeholders or + // provider-shaped aliases in both raw and percent-decoded header bytes. 
let output_header = String::from_utf8_lossy(&output[..output.len().min(header_end + 256)]); - if output_header.contains(PLACEHOLDER_PREFIX) { + if contains_reserved_credential_marker(&output_header) { return Err(UnresolvedPlaceholderError { location: "header" }); } - // Also check percent-decoded form of the request line (F5 — encoded placeholder bypass) - let rewritten_rl = output_header.split("\r\n").next().unwrap_or(""); - let decoded_rl = percent_decode(rewritten_rl); - if decoded_rl.contains(PLACEHOLDER_PREFIX) { - return Err(UnresolvedPlaceholderError { location: "path" }); - } - Ok(RewriteResult { rewritten: output, redacted_target: rl_result.redacted_target, }) } +#[cfg_attr(not(test), allow(dead_code))] pub fn rewrite_header_line(line: &str, resolver: &SecretResolver) -> String { + rewrite_header_line_checked(line, resolver).unwrap_or_else(|_| line.to_string()) +} + +pub fn rewrite_header_line_checked( + line: &str, + resolver: &SecretResolver, +) -> Result { let Some((name, value)) = line.split_once(':') else { - return line.to_string(); + return Ok(line.to_string()); }; - resolver.rewrite_header_value(value.trim()).map_or_else( - || line.to_string(), - |rewritten| format!("{name}: {rewritten}"), + resolver.rewrite_header_value(value.trim())?.map_or_else( + || Ok(line.to_string()), + |rewritten| Ok(format!("{name}: {rewritten}")), ) } @@ -660,12 +976,7 @@ pub fn rewrite_target_for_eval( target: &str, resolver: &SecretResolver, ) -> Result { - if !target.contains(PLACEHOLDER_PREFIX) { - // Also check percent-decoded form - let decoded = percent_decode(target); - if decoded.contains(PLACEHOLDER_PREFIX) { - return Err(UnresolvedPlaceholderError { location: "path" }); - } + if !contains_reserved_credential_marker(target) { return Ok(RewriteTargetResult { resolved: target.to_string(), redacted: target.to_string(), @@ -772,6 +1083,50 @@ mod tests { ); } + #[test] + fn rewrites_provider_shaped_alias_header_values() { + let (_, resolver) = 
SecretResolver::from_provider_env( + [ + ("API_TOKEN".to_string(), "provider-real-token".to_string()), + ("CHAT_APP_TOKEN".to_string(), "app-real-token".to_string()), + ] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + + assert_eq!( + rewrite_header_line( + "Authorization: Bearer vendor-OPENSHELL-RESOLVE-ENV-API_TOKEN", + &resolver, + ), + "Authorization: Bearer provider-real-token" + ); + assert_eq!( + rewrite_header_line( + "x-app-token: token.v1-OPENSHELL-RESOLVE-ENV-CHAT_APP_TOKEN", + &resolver, + ), + "x-app-token: app-real-token" + ); + } + + #[test] + fn unresolved_provider_shaped_alias_fails_closed() { + let (_, resolver) = SecretResolver::from_provider_env( + [("API_TOKEN".to_string(), "provider-real-token".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let raw = b"GET / HTTP/1.1\r\nAuthorization: Bearer vendor-OPENSHELL-RESOLVE-ENV-UNKNOWN_TOKEN\r\n\r\n"; + + let err = rewrite_http_header_block(raw, Some(&resolver)) + .expect_err("unknown alias should fail closed"); + + assert_eq!(err.location, "header"); + } + #[test] fn rewrites_http_header_blocks_and_preserves_body() { let (_, resolver) = SecretResolver::from_provider_env( @@ -1382,6 +1737,29 @@ mod tests { ); } + #[test] + fn percent_encoded_canonical_placeholder_in_query_rewrites() { + let (_, resolver) = SecretResolver::from_provider_env( + [("API_TOKEN".to_string(), "provider-real-token".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let encoded = "openshell%3Aresolve%3Aenv%3AAPI_TOKEN"; + let raw = format!("GET /api?token={encoded} HTTP/1.1\r\nHost: x\r\n\r\n"); + + let result = + rewrite_http_header_block(raw.as_bytes(), Some(&resolver)).expect("should rewrite"); + let rewritten = String::from_utf8(result.rewritten).expect("utf8"); + + assert!(rewritten.starts_with("GET /api?token=provider-real-token HTTP/1.1")); + assert!(!rewritten.contains("openshell")); + 
assert_eq!( + result.redacted_target.as_deref(), + Some("/api?token=[CREDENTIAL]") + ); + } + #[test] fn all_resolved_succeeds() { let (child_env, resolver) = SecretResolver::from_provider_env( @@ -1416,6 +1794,129 @@ mod tests { assert_eq!(raw.as_slice(), result.rewritten.as_slice()); } + #[test] + fn rewrite_websocket_text_replaces_placeholders_and_returns_count() { + let (child_env, resolver) = SecretResolver::from_provider_env( + [ + ("DISCORD_BOT_TOKEN".to_string(), "real-token".to_string()), + ("APP_ID".to_string(), "app-123".to_string()), + ] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let token = child_env.get("DISCORD_BOT_TOKEN").unwrap(); + let app_id = child_env.get("APP_ID").unwrap(); + let mut payload = + format!(r#"{{"op":2,"d":{{"token":"{token}","properties":{{"app":"{app_id}"}}}}}}"#); + + let count = resolver + .rewrite_websocket_text_placeholders(&mut payload) + .expect("rewrite should succeed"); + + assert_eq!(count, 2); + assert!(payload.contains(r#""token":"real-token""#)); + assert!(payload.contains(r#""app":"app-123""#)); + assert!(!payload.contains(PLACEHOLDER_PREFIX)); + } + + #[test] + fn rewrite_websocket_text_replaces_provider_shaped_alias() { + let (_, resolver) = SecretResolver::from_provider_env( + [("APP_TOKEN".to_string(), "app-real-token".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let mut payload = r#"{"token":"provider-OPENSHELL-RESOLVE-ENV-APP_TOKEN"}"#.to_string(); + + let count = resolver + .rewrite_websocket_text_placeholders(&mut payload) + .expect("alias should rewrite"); + + assert_eq!(count, 1); + assert_eq!(payload, r#"{"token":"app-real-token"}"#); + } + + #[test] + fn rewrite_websocket_text_without_placeholder_is_unchanged() { + let (_, resolver) = SecretResolver::from_provider_env( + [("KEY".to_string(), "secret".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let mut payload = 
r#"{"op":1,"d":42}"#.to_string(); + + let count = resolver + .rewrite_websocket_text_placeholders(&mut payload) + .expect("rewrite should succeed"); + + assert_eq!(count, 0); + assert_eq!(payload, r#"{"op":1,"d":42}"#); + } + + #[test] + fn rewrite_websocket_text_unknown_placeholder_fails_closed_without_mutating() { + let (_, resolver) = SecretResolver::from_provider_env( + [("KEY".to_string(), "secret".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let original = r#"{"token":"openshell:resolve:env:UNKNOWN"}"#.to_string(); + let mut payload = original.clone(); + + let err = resolver + .rewrite_websocket_text_placeholders(&mut payload) + .expect_err("unknown placeholder should fail"); + + assert_eq!(err.location, "websocket"); + assert_eq!(payload, original); + } + + #[test] + fn rewrite_websocket_text_handles_repeated_adjacent_and_unicode_placeholders() { + let (child_env, resolver) = SecretResolver::from_provider_env( + [ + ("TOKEN".to_string(), "tok".to_string()), + ("APP".to_string(), "app".to_string()), + ] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let token = child_env.get("TOKEN").unwrap(); + let app = child_env.get("APP").unwrap(); + let mut payload = format!("prefix-☃-{token}{app}-{token}-suffix"); + + let count = resolver + .rewrite_websocket_text_placeholders(&mut payload) + .expect("rewrite should succeed"); + + assert_eq!(count, 3); + assert_eq!(payload, "prefix-☃-tokapp-tok-suffix"); + } + + #[test] + fn rewrite_websocket_text_placeholder_like_prefix_fails_without_mutating() { + let (_, resolver) = SecretResolver::from_provider_env( + [("KEY".to_string(), "secret".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let original = "openshell:resolve:env:-not-a-key".to_string(); + let mut payload = original.clone(); + + let err = resolver + .rewrite_websocket_text_placeholders(&mut payload) + .expect_err("placeholder-like 
prefix should fail closed"); + + assert_eq!(err.location, "websocket"); + assert_eq!(payload, original); + } + // === Redaction tests === #[test] @@ -1471,6 +1972,66 @@ mod tests { assert_eq!(result.redacted, "/v1/chat/completions?format=json"); } + #[test] + fn debug_format_does_not_leak_secret_values() { + let (_, resolver) = SecretResolver::from_provider_env( + [ + ( + "ANTHROPIC_API_KEY".to_string(), + "sk-very-secret-value-12345".to_string(), + ), + ("DB_PASSWORD".to_string(), "very-secret-value".to_string()), + ] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + + let plain = format!("{resolver:?}"); + let pretty = format!("{resolver:#?}"); + + for output in [&plain, &pretty] { + assert!( + !output.contains("sk-very-secret-value-12345"), + "secret value leaked via Debug: {output}" + ); + assert!( + !output.contains("very-secret-value"), + "secret value leaked via Debug: {output}" + ); + assert!( + !output.contains("ANTHROPIC_API_KEY"), + "placeholder key (env var name) leaked via Debug: {output}" + ); + assert!( + !output.contains("DB_PASSWORD"), + "placeholder key (env var name) leaked via Debug: {output}" + ); + assert!( + !output.contains(PLACEHOLDER_PREFIX), + "placeholder prefix leaked via Debug: {output}" + ); + assert!( + output.contains("SecretResolver"), + "Debug output should still identify the type: {output}" + ); + } + + assert!( + plain.contains('2'), + "Debug output should expose the placeholder count: {plain}" + ); + } + + #[test] + fn debug_format_of_empty_resolver_is_safe() { + let resolver = SecretResolver::default(); + let output = format!("{resolver:?}"); + assert!(output.contains("SecretResolver")); + assert!(output.contains('0')); + assert!(!output.contains(PLACEHOLDER_PREFIX)); + } + #[test] fn rewrite_target_for_eval_roundtrip() { let (child_env, resolver) = SecretResolver::from_provider_env( diff --git a/crates/openshell-sandbox/src/skills.rs b/crates/openshell-sandbox/src/skills.rs new file mode 100644 
index 000000000..d29d56247 --- /dev/null +++ b/crates/openshell-sandbox/src/skills.rs @@ -0,0 +1,75 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Static agent guidance files exposed inside the sandbox. + +use miette::{IntoDiagnostic, Result}; +use std::path::{Path, PathBuf}; + +const SKILLS_RELATIVE_DIR: &str = "etc/openshell/skills"; +const POLICY_ADVISOR_FILE: &str = "policy_advisor.md"; +const POLICY_ADVISOR_CONTENT: &str = include_str!("skills/policy_advisor.md"); + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct InstalledSkills { + pub policy_advisor: PathBuf, +} + +pub fn install_static_skills() -> Result { + install_static_skills_at(Path::new("/")) +} + +fn install_static_skills_at(root: &Path) -> Result { + let skills_dir = root.join(SKILLS_RELATIVE_DIR); + std::fs::create_dir_all(&skills_dir).into_diagnostic()?; + + let policy_advisor = skills_dir.join(POLICY_ADVISOR_FILE); + std::fs::write(&policy_advisor, POLICY_ADVISOR_CONTENT).into_diagnostic()?; + + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt as _; + + std::fs::set_permissions(&policy_advisor, std::fs::Permissions::from_mode(0o444)) + .into_diagnostic()?; + } + + Ok(InstalledSkills { policy_advisor }) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn install_static_skills_at_writes_policy_advisor() { + let dir = tempfile::tempdir().unwrap(); + + let installed = install_static_skills_at(dir.path()).unwrap(); + + let expected = dir + .path() + .join("etc") + .join("openshell") + .join("skills") + .join("policy_advisor.md"); + assert_eq!(installed.policy_advisor, expected); + + let content = std::fs::read_to_string(expected).unwrap(); + assert!(content.contains("# OpenShell Policy Advisor")); + assert!(content.contains("policy.local")); + assert!(content.contains("addRule")); + // The wait-loop teaching is load-bearing for the agent feedback + // UX; lock the 
workflow language in so future skill edits cannot + // drop it silently. Each substring targets a directive, not the + // field name (which could appear in the API doc block alone). + assert!(content.contains("/v1/proposals/{chunk_id}/wait")); + assert!(content.contains("read `rejection_reason`")); + // policy_reloaded distinguishes "safe to retry" from "approval + // landed but supervisor hasn't reloaded yet"; without both + // branches taught the agent retries blind on approve+not-yet + // and re-runs into policy_denied. + assert!(content.contains("`policy_reloaded: true`")); + assert!(content.contains("`policy_reloaded: false`")); + } +} diff --git a/crates/openshell-sandbox/src/skills/policy_advisor.md b/crates/openshell-sandbox/src/skills/policy_advisor.md new file mode 100644 index 000000000..8ca64f977 --- /dev/null +++ b/crates/openshell-sandbox/src/skills/policy_advisor.md @@ -0,0 +1,141 @@ +# OpenShell Policy Advisor + +Use this when OpenShell blocks a network request and the response or logs say +`policy_denied`. + +## Goal + +Draft the smallest policy proposal that allows the user's current task without +giving the sandbox broad new network access. The developer approves or rejects +the proposal; do not try to bypass policy. + +## Local API + +The sandbox-local policy API is reachable at `http://policy.local`: + +- `GET /v1/policy/current` — current effective policy as YAML. +- `GET /v1/denials?last=10` — most recent network/L7 denials seen by this + sandbox (newest first), returned as raw shorthand log lines. Each line + carries the timestamp, class, severity, action, host/port, binary, policy + name, and (for denied events) a short reason. Read the lines directly; you + do not need to parse them into structured fields. +- `POST /v1/proposals` — submit a proposal. The 202 response carries + `accepted_chunk_ids` (one ID per operation the gateway accepted) and + `rejection_reasons` (one entry per operation the gateway refused at + submit-time). 
The two arrays together account for every operation you + sent. +- `GET /v1/proposals/{chunk_id}` — immediate state of one proposal. + Returns `status` (`pending` / `approved` / `rejected`), + `rejection_reason` (the reviewer's free-form note, only set on reject), + and `validation_result` (the gateway prover's verdict on this chunk; + may be empty). +- `GET /v1/proposals/{chunk_id}/wait?timeout=` — block on this + proposal until the developer decides or the timeout expires. Default + 60s, clamped [1, 300]. On timeout you get `status: "pending"` plus + `timed_out: true`. On approval the response also carries + `policy_reloaded: true|false` indicating whether the local sandbox has + already loaded a policy containing the approved rule. Use this endpoint + instead of polling `/v1/proposals/{chunk_id}`. + +The proposal body takes an `intent_summary` and one or more `addRule` +operations. Each `addRule` carries a complete narrow `NetworkPolicyRule`. + +## Workflow + +1. Read the denial response body. Use `layer`, `method`, `path`, `host`, + `port`, `binary`, `rule_missing`, and `detail` as evidence. +2. Fetch the current policy from `/v1/policy/current`. +3. Fetch recent denials from `/v1/denials` if the response body is incomplete. +4. Prefer L7 REST rules for REST APIs. Use L4 only for non-REST protocols or + when the client tunnels opaque traffic that OpenShell cannot inspect. +5. Draft the narrowest rule: exact host, exact port, exact binary when known, + exact method, and the smallest safe path. +6. Submit the proposal, save `accepted_chunk_ids` from the response, and + tell the developer what you proposed. If the response also carries + `rejection_reasons`, the gateway refused those operations at + submit-time before any human review; fix them and resubmit before + waiting on the rest. +7. 
For each accepted chunk_id, call + `GET /v1/proposals/{chunk_id}/wait?timeout=300` and act on the result: + - `status: "approved"` with `policy_reloaded: true` — retry the + original denied action. The merged policy is already loaded; the + request should succeed. If it still fails with `policy_denied`, + re-read the denial — your rule may not match. If it fails for any + other reason, surface to the user. + - `status: "approved"` with `policy_reloaded: false` — approval + landed but the local sandbox hasn't observed the reload within the + `/wait` window. Re-issue the same `/wait` call once with + `timeout=30`. If the second response is still + `policy_reloaded: false`, surface to the user rather than retrying + blind; do not loop tightly. + - `status: "rejected"` — read `rejection_reason` and + `validation_result`. `rejection_reason` is what the reviewer typed; + `validation_result` is the prover's verdict, which often explains + a reject driven by automated checks. If either has content, address + the specific feedback and submit a revised proposal. If both are + empty, draft something materially different or ask the user. + - `status: "pending"` with `timed_out: true` — call `/wait` again. + - Any non-2xx response — surface to the user; do not retry the denied + action without approval. 
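The branches of step 7 can be read as a small decision table. The following is a minimal sketch of that table as a pure Rust function, not part of the patch. `WaitResponse`, `NextStep`, and `next_step` are hypothetical names introduced here for illustration; the field names mirror the ones the `/wait` endpoint is described as returning. The sketch assumes the HTTP status was already checked (a non-2xx response goes to the user before this function is reached) and treats any undocumented combination as "surface to the user".

```rust
/// Hypothetical view of a 2xx `/wait` response body. Field names follow the
/// skill text; the struct itself is an assumption made for this sketch.
#[derive(Debug, Clone, PartialEq, Eq)]
struct WaitResponse {
    status: String,
    policy_reloaded: Option<bool>,
    timed_out: bool,
    rejection_reason: Option<String>,
    validation_result: Option<String>,
}

/// What the agent should do next, per step 7 of the workflow.
#[derive(Debug, PartialEq, Eq)]
enum NextStep {
    RetryDeniedAction,
    WaitOnceMoreForReload,
    ReviseProposal,
    DraftDifferentOrAsk,
    WaitAgain,
    SurfaceToUser,
}

/// Map one `/wait` response to the next action.
/// `already_rewaited_for_reload` is true once the single follow-up `/wait`
/// (the `timeout=30` call) has been issued, so the agent never loops tightly.
fn next_step(resp: &WaitResponse, already_rewaited_for_reload: bool) -> NextStep {
    match resp.status.as_str() {
        "approved" => match resp.policy_reloaded {
            Some(true) => NextStep::RetryDeniedAction,
            _ if already_rewaited_for_reload => NextStep::SurfaceToUser,
            _ => NextStep::WaitOnceMoreForReload,
        },
        "rejected" => {
            // Either field counts as feedback if it has non-blank content.
            let has_feedback = [&resp.rejection_reason, &resp.validation_result]
                .iter()
                .any(|field| field.as_deref().is_some_and(|s| !s.trim().is_empty()));
            if has_feedback {
                NextStep::ReviseProposal
            } else {
                NextStep::DraftDifferentOrAsk
            }
        }
        "pending" if resp.timed_out => NextStep::WaitAgain,
        // Assumption: anything the skill text does not describe goes to the user.
        _ => NextStep::SurfaceToUser,
    }
}
```

Keeping the decision separate from the HTTP call makes the "re-wait once, then surface" rule easy to test without a gateway.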
+
+## Proposal shape
+
+A complete narrow REST-inspected rule looks like this:
+
+```json
+{
+  "intent_summary": "Allow gh to update repository contents in NVIDIA/OpenShell only.",
+  "operations": [
+    {
+      "addRule": {
+        "ruleName": "github_api_repo_contents_write",
+        "rule": {
+          "name": "github_api_repo_contents_write",
+          "endpoints": [
+            {
+              "host": "api.github.com",
+              "port": 443,
+              "protocol": "rest",
+              "enforcement": "enforce",
+              "rules": [
+                {
+                  "allow": {
+                    "method": "PUT",
+                    "path": "/repos/NVIDIA/OpenShell/contents/**"
+                  }
+                }
+              ]
+            }
+          ],
+          "binaries": [
+            {
+              "path": "/usr/bin/gh"
+            }
+          ]
+        }
+      }
+    }
+  ]
+}
+```
+
+## Norms
+
+- Do not propose wildcard hosts such as `**` or `*.com`.
+- Do not propose `access: full` to fix a single denied REST request.
+- Do not include query strings, tokens, credentials, or secret values in
+  paths.
+- Explain uncertainty in `intent_summary` instead of widening the rule.
+- If pushing with `git` fails, that is a separate L4 or protocol-specific
+  path from GitHub REST API access. Propose it separately.
+
+## Local logs (read-only)
+
+Two local files complement the API and are useful when debugging policy
+behavior:
+
+- `/var/log/openshell.YYYY-MM-DD.log` — shorthand log of sandbox activity.
+  This is what `/v1/denials` reads from.
+- `/var/log/openshell-ocsf.YYYY-MM-DD.log` — full OCSF JSON events, only
+  written when the `ocsf_json_enabled` setting is on. Not used by
+  `/v1/denials`; useful for SIEM ingestion.
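Two of the norms in the skill file (no wildcard hosts, no query strings in paths) are mechanical enough to check before a proposal is submitted. The sketch below is one possible pre-submit check, not something the patch adds. `norm_violations` is a hypothetical helper, and it conservatively assumes that any `*` in a host counts as a wildcard, which is stricter than the skill text strictly requires. Path globs such as `/contents/**` are left alone, since the example rule uses one.

```rust
/// Hypothetical pre-submit check for two of the skill's norms.
/// Returns a human-readable label for each norm the endpoint would break.
fn norm_violations(host: &str, path: &str) -> Vec<&'static str> {
    let mut violations = Vec::new();
    // Assumption: treat any `*` in the host as a wildcard host.
    if host.contains('*') {
        violations.push("wildcard host");
    }
    // A `?` in the path means a query string was included in the rule.
    if path.contains('?') {
        violations.push("query string in path");
    }
    violations
}
```

The remaining norms (no tokens or secret values in paths, no `access: full` for a single denial) depend on context and are not covered by this check.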
diff --git a/crates/openshell-sandbox/src/ssh.rs b/crates/openshell-sandbox/src/ssh.rs index 9434d0a16..67fbc7e57 100644 --- a/crates/openshell-sandbox/src/ssh.rs +++ b/crates/openshell-sandbox/src/ssh.rs @@ -6,6 +6,7 @@ use crate::child_env; use crate::policy::SandboxPolicy; use crate::process::drop_privileges; +use crate::provider_credentials::ProviderCredentialState; use crate::sandbox; #[cfg(target_os = "linux")] use crate::{register_managed_child, unregister_managed_child}; @@ -105,7 +106,7 @@ pub async fn run_ssh_server( netns_fd: Option, proxy_url: Option, ca_file_paths: Option<(PathBuf, PathBuf)>, - provider_env: HashMap, + provider_credentials: ProviderCredentialState, ) -> Result<()> { let (listener, config, ca_paths) = match ssh_server_init(&listen_path, &ca_file_paths) { Ok(v) => { @@ -129,7 +130,7 @@ pub async fn run_ssh_server( let workdir = workdir.clone(); let proxy_url = proxy_url.clone(); let ca_paths = ca_paths.clone(); - let provider_env = provider_env.clone(); + let provider_credentials = provider_credentials.clone(); tokio::spawn(async move { if let Err(err) = handle_connection( @@ -140,7 +141,7 @@ pub async fn run_ssh_server( netns_fd, proxy_url, ca_paths, - provider_env, + provider_credentials, ) .await { @@ -166,7 +167,7 @@ async fn handle_connection( netns_fd: Option, proxy_url: Option, ca_file_paths: Option>, - provider_env: HashMap, + provider_credentials: ProviderCredentialState, ) -> Result<()> { // Access is gated by the Unix-socket filesystem permissions (root-only), // not by an application-level preface. 
The supervisor bridges the @@ -188,7 +189,7 @@ async fn handle_connection( netns_fd, proxy_url, ca_file_paths, - provider_env, + provider_credentials, ); russh::server::run_stream(config, stream, handler) .await @@ -215,7 +216,7 @@ struct SshHandler { netns_fd: Option, proxy_url: Option, ca_file_paths: Option>, - provider_env: HashMap, + provider_credentials: ProviderCredentialState, channels: HashMap, } @@ -226,7 +227,7 @@ impl SshHandler { netns_fd: Option, proxy_url: Option, ca_file_paths: Option>, - provider_env: HashMap, + provider_credentials: ProviderCredentialState, ) -> Self { Self { policy, @@ -234,7 +235,7 @@ impl SshHandler { netns_fd, proxy_url, ca_file_paths, - provider_env, + provider_credentials, channels: HashMap::new(), } } @@ -456,7 +457,7 @@ impl russh::server::Handler for SshHandler { self.netns_fd, self.proxy_url.clone(), self.ca_file_paths.clone(), - &self.provider_env, + &self.provider_credentials.snapshot().child_env, )?; let state = self.channels.get_mut(&channel).ok_or_else(|| { anyhow::anyhow!("subsystem_request on unknown channel {channel:?}") @@ -533,6 +534,7 @@ impl SshHandler { handle: Handle, command: Option, ) -> anyhow::Result<()> { + let provider_snapshot = self.provider_credentials.snapshot(); let state = self .channels .get_mut(&channel) @@ -550,7 +552,7 @@ impl SshHandler { self.netns_fd, self.proxy_url.clone(), self.ca_file_paths.clone(), - &self.provider_env, + &provider_snapshot.child_env, )?; state.pty_master = Some(pty_master); state.input_sender = Some(input_sender); @@ -567,7 +569,7 @@ impl SshHandler { self.netns_fd, self.proxy_url.clone(), self.ca_file_paths.clone(), - &self.provider_env, + &provider_snapshot.child_env, )?; state.input_sender = Some(input_sender); } @@ -588,7 +590,7 @@ impl SshHandler { /// the calling thread's network namespace permanently — a tokio blocking-pool /// thread could be reused for unrelated tasks and must not be contaminated. 
/// On non-Linux platforms (no network namespace support), we connect directly. -async fn connect_in_netns( +pub async fn connect_in_netns( addr: &str, netns_fd: Option, ) -> std::io::Result { @@ -678,7 +680,7 @@ fn apply_child_env( let path = std::env::var("PATH").unwrap_or_else(|_| "/usr/local/bin:/usr/bin:/bin".into()); cmd.env_clear() - .env("OPENSHELL_SANDBOX", "1") + .env(openshell_core::sandbox_env::SANDBOX, "1") .env("HOME", session_home) .env("USER", session_user) .env("SHELL", "/bin/bash") @@ -1487,7 +1489,7 @@ mod tests { // Skip if running as root: drop_privileges would try to switch to // "sandbox" which may not exist in the test environment. - if nix::unistd::geteuid().is_root() { + if rustix::process::geteuid().is_root() { return; } diff --git a/crates/openshell-sandbox/src/supervisor_session.rs b/crates/openshell-sandbox/src/supervisor_session.rs index 490a0cba7..4d7392ee3 100644 --- a/crates/openshell-sandbox/src/supervisor_session.rs +++ b/crates/openshell-sandbox/src/supervisor_session.rs @@ -4,27 +4,30 @@ //! Persistent supervisor-to-gateway session. //! //! Maintains a long-lived `ConnectSupervisor` bidirectional gRPC stream to the -//! gateway. When the gateway sends `RelayOpen`, the supervisor initiates a -//! `RelayStream` gRPC call (a new HTTP/2 stream multiplexed over the same -//! TCP+TLS connection as the control stream) and bridges it to the local SSH -//! daemon. The supervisor is a dumb byte bridge — it has no protocol awareness -//! of the SSH or NSSH1 bytes flowing through. - +//! gateway. When the gateway sends `RelayOpen`, the supervisor dials the +//! requested local target, initiates a `RelayStream` gRPC call (a new HTTP/2 +//! stream multiplexed over the same TCP+TLS connection as the control stream), +//! and bridges bytes. The supervisor is a dumb byte bridge after target +//! selection — it has no protocol awareness of the bytes flowing through. 
+ +use std::net::IpAddr; +#[cfg(target_os = "linux")] +use std::os::fd::RawFd; use std::time::Duration; use openshell_core::proto::open_shell_client::OpenShellClient; use openshell_core::proto::{ - GatewayMessage, RelayFrame, RelayInit, SupervisorHeartbeat, SupervisorHello, SupervisorMessage, - gateway_message, supervisor_message, + GatewayMessage, RelayFrame, RelayInit, RelayOpen, RelayOpenResult, SupervisorHeartbeat, + SupervisorHello, SupervisorMessage, TcpRelayTarget, gateway_message, relay_open, + supervisor_message, }; use openshell_ocsf::{ - ActivityId, Endpoint, NetworkActivityBuilder, OcsfEvent, SandboxContext, SeverityId, StatusId, - ocsf_emit, + ActivityId, ConnectionInfo, Endpoint, NetworkActivityBuilder, OcsfEvent, SandboxContext, + SeverityId, StatusId, ocsf_emit, }; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::sync::mpsc; use tokio_stream::StreamExt; -use tonic::transport::Channel; use tracing::{debug, warn}; use crate::grpc_client; @@ -91,33 +94,102 @@ fn session_failed_event( .build() } -fn relay_open_event(ctx: &SandboxContext, channel_id: &str) -> OcsfEvent { - NetworkActivityBuilder::new(ctx) +fn relay_target_endpoint(open: &RelayOpen) -> Option { + let relay_open::Target::Tcp(target) = open.target.as_ref()? 
else { + return None; + }; + let host = target.host.trim(); + let port = u16::try_from(target.port).ok()?; + host.parse().map_or_else( + |_| Some(Endpoint::from_domain(host, port)), + |ip| Some(Endpoint::from_ip(ip, port)), + ) +} + +fn relay_target_kind(open: &RelayOpen) -> &'static str { + match open.target.as_ref() { + Some(relay_open::Target::Tcp(_)) => "tcp relay", + Some(relay_open::Target::Ssh(_)) | None => "ssh relay", + } +} + +fn relay_target_message( + open: &RelayOpen, + state: &str, + ssh_socket_path: &std::path::Path, +) -> String { + let target = match open.target.as_ref() { + Some(relay_open::Target::Tcp(target)) => { + format!("{}:{}", target.host.trim(), target.port) + } + Some(relay_open::Target::Ssh(_)) | None => { + format!("unix:{}", ssh_socket_path.display()) + } + }; + + format!( + "{} {state} (channel_id={}, target={target})", + relay_target_kind(open), + open.channel_id + ) +} + +fn relay_open_event( + ctx: &SandboxContext, + open: &RelayOpen, + ssh_socket_path: &std::path::Path, +) -> OcsfEvent { + let mut builder = NetworkActivityBuilder::new(ctx) .activity(ActivityId::Open) .severity(SeverityId::Informational) .status(StatusId::Success) - .message(format!("relay open (channel_id={channel_id})")) - .build() + .message(relay_target_message(open, "open", ssh_socket_path)); + if let Some(endpoint) = relay_target_endpoint(open) { + builder = builder + .dst_endpoint(endpoint) + .connection_info(ConnectionInfo::new("tcp")); + } + builder.build() } -fn relay_closed_event(ctx: &SandboxContext, channel_id: &str) -> OcsfEvent { - NetworkActivityBuilder::new(ctx) +fn relay_closed_event( + ctx: &SandboxContext, + open: &RelayOpen, + ssh_socket_path: &std::path::Path, +) -> OcsfEvent { + let mut builder = NetworkActivityBuilder::new(ctx) .activity(ActivityId::Close) .severity(SeverityId::Informational) .status(StatusId::Success) - .message(format!("relay closed (channel_id={channel_id})")) - .build() + .message(relay_target_message(open, "closed", 
ssh_socket_path)); + if let Some(endpoint) = relay_target_endpoint(open) { + builder = builder + .dst_endpoint(endpoint) + .connection_info(ConnectionInfo::new("tcp")); + } + builder.build() } -fn relay_failed_event(ctx: &SandboxContext, channel_id: &str, error: &str) -> OcsfEvent { - NetworkActivityBuilder::new(ctx) +fn relay_failed_event( + ctx: &SandboxContext, + open: &RelayOpen, + ssh_socket_path: &std::path::Path, + error: &str, +) -> OcsfEvent { + let mut builder = NetworkActivityBuilder::new(ctx) .activity(ActivityId::Fail) .severity(SeverityId::Low) .status(StatusId::Failure) .message(format!( - "relay bridge failed (channel_id={channel_id}): {error}" - )) - .build() + "{}: {error}", + relay_target_message(open, "bridge failed", ssh_socket_path) + )); + if let Some(endpoint) = relay_target_endpoint(open) { + builder = builder + .dst_endpoint(endpoint) + .connection_info(ConnectionInfo::new("tcp")); + } + builder.build() } fn relay_close_from_gateway_event( @@ -139,6 +211,10 @@ fn relay_close_from_gateway_event( /// HTTP/2 frame size so each `RelayFrame::data` fits in one frame. 
const RELAY_CHUNK_SIZE: usize = 16 * 1024; +trait TargetStream: AsyncRead + AsyncWrite + Send + Unpin {} + +impl TargetStream for T where T: AsyncRead + AsyncWrite + Send + Unpin {} + fn map_stream_message( message: Result, tonic::Status>, eof_error: &'static str, @@ -158,14 +234,21 @@ pub fn spawn( endpoint: String, sandbox_id: String, ssh_socket_path: std::path::PathBuf, + netns_fd: Option, ) -> tokio::task::JoinHandle<()> { - tokio::spawn(run_session_loop(endpoint, sandbox_id, ssh_socket_path)) + tokio::spawn(run_session_loop( + endpoint, + sandbox_id, + ssh_socket_path, + netns_fd, + )) } async fn run_session_loop( endpoint: String, sandbox_id: String, ssh_socket_path: std::path::PathBuf, + netns_fd: Option, ) { let mut backoff = INITIAL_BACKOFF; let mut attempt: u64 = 0; @@ -173,7 +256,7 @@ async fn run_session_loop( loop { attempt += 1; - match run_single_session(&endpoint, &sandbox_id, &ssh_socket_path).await { + match run_single_session(&endpoint, &sandbox_id, &ssh_socket_path, netns_fd).await { Ok(()) => { let event = session_closed_event(crate::ocsf_ctx(), &endpoint, &sandbox_id); ocsf_emit!(event); @@ -194,6 +277,7 @@ async fn run_single_session( endpoint: &str, sandbox_id: &str, ssh_socket_path: &std::path::Path, + netns_fd: Option, ) -> Result<(), Box> { // Connect to the gateway. 
The same `Channel` is used for both the // long-lived control stream and all data-plane `RelayStream` calls, so @@ -262,7 +346,9 @@ async fn run_single_session( &msg, sandbox_id, ssh_socket_path, + netns_fd, &channel, + &tx, ); } _ = heartbeat_interval.tick() => { @@ -283,7 +369,9 @@ fn handle_gateway_message( msg: &GatewayMessage, sandbox_id: &str, ssh_socket_path: &std::path::Path, - channel: &Channel, + netns_fd: Option, + channel: &grpc_client::AuthedChannel, + tx: &mpsc::Sender, ) { match &msg.payload { Some(gateway_message::Payload::Heartbeat(_)) => { @@ -291,22 +379,30 @@ fn handle_gateway_message( } Some(gateway_message::Payload::RelayOpen(open)) => { let channel_id = open.channel_id.clone(); + let relay_open = open.clone(); let sandbox_id = sandbox_id.to_string(); let channel = channel.clone(); let ssh_socket_path = ssh_socket_path.to_path_buf(); + let tx = tx.clone(); - let event = relay_open_event(crate::ocsf_ctx(), &channel_id); + let event = relay_open_event(crate::ocsf_ctx(), &relay_open, &ssh_socket_path); ocsf_emit!(event); tokio::spawn(async move { - match handle_relay_open(&channel_id, &ssh_socket_path, channel).await { + let event_open = relay_open.clone(); + match handle_relay_open(relay_open, &ssh_socket_path, netns_fd, channel, tx).await { Ok(()) => { - let event = relay_closed_event(crate::ocsf_ctx(), &channel_id); + let event = + relay_closed_event(crate::ocsf_ctx(), &event_open, &ssh_socket_path); ocsf_emit!(event); } Err(e) => { - let event = - relay_failed_event(crate::ocsf_ctx(), &channel_id, &e.to_string()); + let event = relay_failed_event( + crate::ocsf_ctx(), + &event_open, + &ssh_socket_path, + &e.to_string(), + ); ocsf_emit!(event); warn!( sandbox_id = %sandbox_id, @@ -336,10 +432,23 @@ fn handle_gateway_message( /// TLS handshake. The first `RelayFrame` we send is a `RelayInit`; subsequent /// frames carry raw SSH bytes in `data`. 
async fn handle_relay_open( - channel_id: &str, + relay_open: RelayOpen, ssh_socket_path: &std::path::Path, - channel: Channel, + netns_fd: Option, + channel: grpc_client::AuthedChannel, + tx: mpsc::Sender, ) -> Result<(), Box> { + let channel_id = relay_open.channel_id.clone(); + let target = match open_target(&relay_open, ssh_socket_path, netns_fd).await { + Ok(target) => target, + Err(err) => { + send_relay_open_result(&tx, &channel_id, false, err.to_string()).await; + return Err(err); + } + }; + + send_relay_open_result(&tx, &channel_id, true, String::new()).await; + let mut client = OpenShellClient::new(channel); // Outbound chunks to the gateway. @@ -351,7 +460,7 @@ async fn handle_relay_open( .send(RelayFrame { payload: Some(openshell_core::proto::relay_frame::Payload::Init( RelayInit { - channel_id: channel_id.to_string(), + channel_id: channel_id.clone(), }, )), }) @@ -366,21 +475,19 @@ async fn handle_relay_open( let mut inbound = response.into_inner(); // Connect to the local SSH daemon on its Unix socket. - let ssh = tokio::net::UnixStream::connect(ssh_socket_path).await?; - let (mut ssh_r, mut ssh_w) = ssh.into_split(); + let (mut target_r, mut target_w) = tokio::io::split(target); debug!( channel_id = %channel_id, - socket = %ssh_socket_path.display(), - "relay bridge: connected to local SSH daemon" + "relay bridge: connected to local target" ); - // SSH → gRPC (out_tx): read local SSH, forward as `RelayFrame::data`. + // Target → gRPC (out_tx): read local target, forward as `RelayFrame::data`. let out_tx_writer = out_tx.clone(); - let ssh_to_grpc = tokio::spawn(async move { + let target_to_grpc = tokio::spawn(async move { let mut buf = vec![0u8; RELAY_CHUNK_SIZE]; loop { - match ssh_r.read(&mut buf).await { + match target_r.read(&mut buf).await { Ok(0) | Err(_) => break, Ok(n) => { let chunk = RelayFrame { @@ -396,7 +503,7 @@ async fn handle_relay_open( } }); - // gRPC (inbound) → SSH: drain inbound chunks into the local SSH socket. 
+ // gRPC (inbound) → target: drain inbound chunks into the local target socket. let mut inbound_err: Option = None; while let Some(next) = inbound.next().await { match next { @@ -409,8 +516,8 @@ async fn handle_relay_open( if data.is_empty() { continue; } - if let Err(e) = ssh_w.write_all(&data).await { - inbound_err = Some(format!("write to ssh failed: {e}")); + if let Err(e) = target_w.write_all(&data).await { + inbound_err = Some(format!("write to target failed: {e}")); break; } } @@ -421,13 +528,13 @@ async fn handle_relay_open( } } - // Half-close the SSH socket's write side so the daemon sees EOF. - let _ = ssh_w.shutdown().await; + // Half-close the target socket's write side so the service sees EOF. + let _ = target_w.shutdown().await; // Dropping out_tx closes the outbound gRPC stream, letting the gateway // observe EOF on its side too. drop(out_tx); - let _ = ssh_to_grpc.await; + let _ = target_to_grpc.await; if let Some(e) = inbound_err { return Err(e.into()); @@ -435,6 +542,165 @@ async fn handle_relay_open( Ok(()) } +async fn send_relay_open_result( + tx: &mpsc::Sender, + channel_id: &str, + success: bool, + error: String, +) { + let _ = tx + .send(SupervisorMessage { + payload: Some(supervisor_message::Payload::RelayOpenResult( + RelayOpenResult { + channel_id: channel_id.to_string(), + success, + error, + }, + )), + }) + .await; +} + +async fn open_target( + relay_open: &RelayOpen, + ssh_socket_path: &std::path::Path, + netns_fd: Option, +) -> Result, Box> { + match relay_open.target.as_ref() { + Some(relay_open::Target::Tcp(target)) => open_tcp_target(target, netns_fd).await, + Some(relay_open::Target::Ssh(_)) | None => { + let stream = tokio::net::UnixStream::connect(ssh_socket_path).await?; + Ok(Box::new(stream)) + } + } +} + +async fn open_tcp_target( + target: &TcpRelayTarget, + netns_fd: Option, +) -> Result, Box> { + let host = normalize_tcp_target_host(target)?; + let port = u16::try_from(target.port).map_err(|_| "tcp target port must fit in 
u16")?; + let stream = connect_tcp_target(host, port, netns_fd).await?; + Ok(Box::new(stream)) +} + +#[cfg(target_os = "linux")] +async fn connect_tcp_target( + host: String, + port: u16, + netns_fd: Option, +) -> Result> { + if let Some(fd) = netns_fd { + let (tx, rx) = tokio::sync::oneshot::channel(); + std::thread::spawn(move || { + let result = (|| -> std::io::Result { + #[allow(unsafe_code)] + let rc = unsafe { libc::setns(fd, libc::CLONE_NEWNET) }; + if rc != 0 { + return Err(std::io::Error::last_os_error()); + } + std::net::TcpStream::connect((host.as_str(), port)) + })(); + let _ = tx.send(result); + }); + + let stream = rx + .await + .map_err(|_| "netns tcp connect thread panicked")??; + stream.set_nonblocking(true)?; + return Ok(tokio::net::TcpStream::from_std(stream)?); + } + + Ok(tokio::net::TcpStream::connect((host.as_str(), port)).await?) +} + +#[cfg(not(target_os = "linux"))] +async fn connect_tcp_target( + host: String, + port: u16, + _netns_fd: Option, +) -> Result> { + Ok(tokio::net::TcpStream::connect((host.as_str(), port)).await?) 
+} + +#[cfg(test)] +fn validate_tcp_target(target: &TcpRelayTarget) -> Result<(), String> { + normalize_tcp_target_host(target).map(|_| ()) +} + +fn normalize_tcp_target_host(target: &TcpRelayTarget) -> Result { + if target.port == 0 || target.port > u32::from(u16::MAX) { + return Err("tcp target port must be between 1 and 65535".to_string()); + } + + let host = target.host.trim(); + if host.is_empty() { + return Err("tcp target host is required".to_string()); + } + if host.eq_ignore_ascii_case("localhost") { + return Ok("127.0.0.1".to_string()); + } + + let ip: IpAddr = host + .parse() + .map_err(|_| "tcp target host must be loopback".to_string())?; + if ip.is_loopback() { + Ok(ip.to_string()) + } else { + Err("tcp target host must be loopback".to_string()) + } +} + +#[cfg(test)] +mod target_tests { + use super::*; + + fn tcp(host: &str, port: u32) -> TcpRelayTarget { + TcpRelayTarget { + host: host.to_string(), + port, + } + } + + #[test] + fn tcp_target_allows_loopback_hosts() { + validate_tcp_target(&tcp("127.0.0.1", 8080)).expect("ipv4 loopback"); + validate_tcp_target(&tcp("::1", 8080)).expect("ipv6 loopback"); + validate_tcp_target(&tcp("localhost", 8080)).expect("localhost"); + } + + #[test] + fn tcp_target_normalizes_localhost_before_dialing() { + assert_eq!( + normalize_tcp_target_host(&tcp("localhost", 8080)).expect("localhost"), + "127.0.0.1" + ); + assert_eq!( + normalize_tcp_target_host(&tcp("LOCALHOST", 8080)).expect("localhost"), + "127.0.0.1" + ); + } + + #[test] + fn tcp_target_rejects_non_loopback_hosts() { + let err = validate_tcp_target(&tcp("10.0.0.1", 8080)).expect_err("private ip rejected"); + assert_eq!(err, "tcp target host must be loopback"); + + let err = validate_tcp_target(&tcp("example.com", 8080)).expect_err("hostname rejected"); + assert_eq!(err, "tcp target host must be loopback"); + } + + #[test] + fn tcp_target_rejects_invalid_ports() { + let err = validate_tcp_target(&tcp("127.0.0.1", 0)).expect_err("zero rejected"); + 
assert_eq!(err, "tcp target port must be between 1 and 65535"); + + let err = validate_tcp_target(&tcp("127.0.0.1", 70000)).expect_err("too large rejected"); + assert_eq!(err, "tcp target port must be between 1 and 65535"); + } +} + #[cfg(test)] mod ocsf_event_tests { use super::*; @@ -479,6 +745,31 @@ mod ocsf_event_tests { } } + fn ssh_relay_open(channel_id: &str) -> RelayOpen { + RelayOpen { + channel_id: channel_id.to_string(), + target: Some(relay_open::Target::Ssh( + openshell_core::proto::SshRelayTarget::default(), + )), + service_id: String::new(), + } + } + + fn tcp_relay_open(channel_id: &str, host: &str, port: u32) -> RelayOpen { + RelayOpen { + channel_id: channel_id.to_string(), + target: Some(relay_open::Target::Tcp(TcpRelayTarget { + host: host.to_string(), + port, + })), + service_id: String::new(), + } + } + + fn ssh_socket_path() -> &'static std::path::Path { + std::path::Path::new("/run/openshell/ssh.sock") + } + #[test] fn session_established_emits_network_open_success() { let event = session_established_event(&ctx(), "https://gw:443", "sess-1", 30); @@ -518,22 +809,43 @@ mod ocsf_event_tests { #[test] fn relay_open_emits_network_open_success() { - let event = relay_open_event(&ctx(), "ch-42"); + let event = relay_open_event(&ctx(), &ssh_relay_open("ch-42"), ssh_socket_path()); let na = network_activity(&event); assert_eq!(na.base.activity_id, ActivityId::Open.as_u8()); assert_eq!(na.base.severity, SeverityId::Informational); + let msg = na.base.message.as_deref().unwrap_or_default(); + assert!(msg.contains("ch-42"), "message: {msg}"); assert!( - na.base - .message - .as_deref() - .unwrap_or_default() - .contains("ch-42") + msg.contains("target=unix:/run/openshell/ssh.sock"), + "message: {msg}" + ); + } + + #[test] + fn tcp_relay_open_emits_target_endpoint() { + let event = relay_open_event( + &ctx(), + &tcp_relay_open("ch-42", "127.0.0.1", 8765), + ssh_socket_path(), + ); + let na = network_activity(&event); + assert_eq!(na.base.activity_id, 
ActivityId::Open.as_u8()); + assert_eq!( + na.dst_endpoint.as_ref().and_then(|e| e.ip.as_deref()), + Some("127.0.0.1") + ); + assert_eq!(na.dst_endpoint.as_ref().and_then(|e| e.port), Some(8765)); + assert_eq!( + na.connection_info + .as_ref() + .map(|c| c.protocol_name.as_str()), + Some("tcp") ); } #[test] fn relay_closed_emits_network_close_success() { - let event = relay_closed_event(&ctx(), "ch-42"); + let event = relay_closed_event(&ctx(), &ssh_relay_open("ch-42"), ssh_socket_path()); let na = network_activity(&event); assert_eq!(na.base.activity_id, ActivityId::Close.as_u8()); assert_eq!(na.base.status, Some(StatusId::Success)); @@ -541,7 +853,12 @@ mod ocsf_event_tests { #[test] fn relay_failed_emits_network_fail_low() { - let event = relay_failed_event(&ctx(), "ch-42", "write to ssh failed"); + let event = relay_failed_event( + &ctx(), + &ssh_relay_open("ch-42"), + ssh_socket_path(), + "write to ssh failed", + ); let na = network_activity(&event); assert_eq!(na.base.activity_id, ActivityId::Fail.as_u8()); assert_eq!(na.base.severity, SeverityId::Low); diff --git a/crates/openshell-sandbox/tests/websocket_upgrade.rs b/crates/openshell-sandbox/tests/websocket_upgrade.rs index e4cd232ce..b35076a9a 100644 --- a/crates/openshell-sandbox/tests/websocket_upgrade.rs +++ b/crates/openshell-sandbox/tests/websocket_upgrade.rs @@ -124,7 +124,7 @@ async fn websocket_upgrade_through_l7_relay_exchanges_message() { .expect("relay should succeed"); match outcome { - RelayOutcome::Upgraded { overflow } => { + RelayOutcome::Upgraded { overflow, .. 
} => { // This is what handle_upgrade() does in relay.rs if !overflow.is_empty() { client_proxy.write_all(&overflow).await.unwrap(); diff --git a/crates/openshell-server/Cargo.toml b/crates/openshell-server/Cargo.toml index cb6561f3e..69319f63a 100644 --- a/crates/openshell-server/Cargo.toml +++ b/crates/openshell-server/Cargo.toml @@ -15,6 +15,7 @@ name = "openshell-gateway" path = "src/main.rs" [dependencies] +openshell-bootstrap = { path = "../openshell-bootstrap" } openshell-core = { path = "../openshell-core" } openshell-driver-docker = { path = "../openshell-driver-docker" } openshell-driver-kubernetes = { path = "../openshell-driver-kubernetes" } @@ -24,6 +25,10 @@ openshell-policy = { path = "../openshell-policy" } openshell-providers = { path = "../openshell-providers" } openshell-router = { path = "../openshell-router" } +# Kubernetes client (used by the `generate-certs` subcommand) +kube = { workspace = true } +k8s-openapi = { workspace = true } + # Async runtime tokio = { workspace = true } @@ -69,6 +74,7 @@ bytes = { workspace = true } pin-project-lite = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } +toml = { workspace = true } tokio-stream = { workspace = true } sqlx = { workspace = true } reqwest = { workspace = true } @@ -76,12 +82,16 @@ uuid = { workspace = true } hmac = "0.12" sha2 = { workspace = true } jsonwebtoken = { workspace = true } +async-trait = "0.1" +url = { workspace = true } hex = "0.4" russh = "0.57" rand = { workspace = true } petname = "2" ipnet = "2" tempfile = "3" +rustix = { workspace = true } +x509-parser = "0.16" [features] dev-settings = ["openshell-core/dev-settings"] diff --git a/crates/openshell-server/migrations/postgres/005_add_resource_version.sql b/crates/openshell-server/migrations/postgres/005_add_resource_version.sql new file mode 100644 index 000000000..e6a294d62 --- /dev/null +++ b/crates/openshell-server/migrations/postgres/005_add_resource_version.sql @@ -0,0 +1,5 @@ +-- Add 
resource_version column for optimistic concurrency control +ALTER TABLE objects ADD COLUMN resource_version BIGINT NOT NULL DEFAULT 1; + +-- Backfill existing rows with resource_version = 1 +-- (DEFAULT clause handles this automatically for existing rows in PostgreSQL) diff --git a/crates/openshell-server/migrations/sqlite/005_add_resource_version.sql b/crates/openshell-server/migrations/sqlite/005_add_resource_version.sql new file mode 100644 index 000000000..50aacb99d --- /dev/null +++ b/crates/openshell-server/migrations/sqlite/005_add_resource_version.sql @@ -0,0 +1,5 @@ +-- Add resource_version column for optimistic concurrency control +ALTER TABLE objects ADD COLUMN resource_version INTEGER NOT NULL DEFAULT 1; + +-- Backfill existing rows with resource_version = 1 +-- (DEFAULT clause handles this automatically for existing rows in SQLite) diff --git a/crates/openshell-server/src/auth/authenticator.rs b/crates/openshell-server/src/auth/authenticator.rs new file mode 100644 index 000000000..f5d5c7b2a --- /dev/null +++ b/crates/openshell-server/src/auth/authenticator.rs @@ -0,0 +1,229 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Pluggable authentication trait + chain dispatch. +//! +//! The gateway runs every authenticated request through an +//! [`AuthenticatorChain`] of [`Authenticator`] implementations. The chain +//! evaluates authenticators in order; the first one that recognizes the +//! caller produces the [`Principal`]. An authenticator that does not apply +//! (e.g. an OIDC authenticator seeing no Bearer header) returns `Ok(None)` +//! so the chain falls through to the next. An authenticator that *does* +//! apply but rejects the caller returns `Err(Status)`, which terminates +//! the chain — fail-closed. +//! +//! Live authenticators slotting into the chain: +//! - [`super::sandbox_jwt::SandboxJwtAuthenticator`] — gateway-minted JWTs +//! 
- [`super::k8s_sa::K8sServiceAccountAuthenticator`] — K8s projected SA +//! tokens (path-scoped to `IssueSandboxToken`) +//! - [`super::oidc::OidcAuthenticator`] — user OIDC Bearer tokens +use super::principal::Principal; +use async_trait::async_trait; +use std::sync::Arc; +use tonic::Status; + +/// Pluggable authentication step. +/// +/// Implementations are expected to be cheap to clone (they live behind +/// `Arc` inside an [`AuthenticatorChain`]). +#[async_trait] +pub trait Authenticator: Send + Sync + 'static { + /// Inspect an inbound request and return the authenticated principal. + /// + /// - `Ok(Some(principal))` — this authenticator recognized the caller. + /// The chain stops and the principal is inserted into request + /// extensions. + /// - `Ok(None)` — this authenticator does not apply (e.g. no Bearer + /// token for an OIDC authenticator). The chain falls through to + /// the next authenticator. + /// - `Err(status)` — this authenticator applies but rejected the + /// caller. The chain terminates and the status is returned to the + /// client. Fail-closed. + async fn authenticate( + &self, + headers: &http::HeaderMap, + path: &str, + ) -> Result, Status>; +} + +/// First-match-wins authenticator chain. +/// +/// The chain owns its authenticators behind `Arc` so the entire chain is +/// cheap to clone — required because `tower::Service::call` clones the +/// router on every request. +#[derive(Clone)] +pub struct AuthenticatorChain { + authenticators: Arc<[Arc]>, +} + +impl AuthenticatorChain { + /// Build a chain from an ordered list of authenticators. Earlier + /// entries are evaluated first. + pub fn new(authenticators: Vec>) -> Self { + Self { + authenticators: Arc::from(authenticators), + } + } + + /// Run the chain. Returns the first principal produced. If every + /// authenticator returns `Ok(None)`, the result is `Ok(None)` — the + /// router translates that to `unauthenticated`. 
+ pub async fn authenticate( + &self, + headers: &http::HeaderMap, + path: &str, + ) -> Result, Status> { + for authenticator in self.authenticators.iter() { + if let Some(principal) = authenticator.authenticate(headers, path).await? { + return Ok(Some(principal)); + } + } + Ok(None) + } +} + +impl std::fmt::Debug for AuthenticatorChain { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("AuthenticatorChain") + .field("len", &self.authenticators.len()) + .finish() + } +} + +#[cfg(test)] +pub mod test_support { + use super::*; + use std::sync::Mutex; + + /// Authenticator that always returns the configured outcome. Used by + /// tests to inject a known principal (or rejection) without running real + /// crypto. Each call records the path it was invoked with so tests can + /// assert chain ordering. + pub struct MockAuthenticator { + pub outcome: Result, Status>, + pub calls: Mutex>, + } + + impl MockAuthenticator { + pub fn returning(outcome: Result, Status>) -> Self { + Self { + outcome, + calls: Mutex::new(Vec::new()), + } + } + + pub fn call_count(&self) -> usize { + self.calls.lock().unwrap().len() + } + } + + #[async_trait] + impl Authenticator for MockAuthenticator { + async fn authenticate( + &self, + _headers: &http::HeaderMap, + path: &str, + ) -> Result, Status> { + self.calls.lock().unwrap().push(path.to_string()); + self.outcome.clone() + } + } +} + +#[cfg(test)] +mod tests { + use super::test_support::MockAuthenticator; + use super::*; + use crate::auth::identity::{Identity, IdentityProvider}; + use crate::auth::principal::UserPrincipal; + + fn user_principal(subject: &str) -> Principal { + Principal::User(UserPrincipal { + identity: Identity { + subject: subject.to_string(), + display_name: None, + roles: vec![], + scopes: vec![], + provider: IdentityProvider::Oidc, + }, + }) + } + + #[tokio::test] + async fn chain_returns_first_match() { + let first = Arc::new(MockAuthenticator::returning(Ok(Some(user_principal( + 
"alice", + ))))); + let second = Arc::new(MockAuthenticator::returning(Ok(Some(user_principal( + "bob", + ))))); + let chain = AuthenticatorChain::new(vec![first.clone(), second.clone()]); + let result = chain + .authenticate(&http::HeaderMap::new(), "/some/path") + .await + .unwrap() + .expect("expected a principal"); + match result { + Principal::User(u) => assert_eq!(u.identity.subject, "alice"), + _ => panic!("expected user principal"), + } + assert_eq!(first.call_count(), 1); + assert_eq!( + second.call_count(), + 0, + "second authenticator must be skipped after first matches" + ); + } + + #[tokio::test] + async fn chain_falls_through_on_none() { + let first = Arc::new(MockAuthenticator::returning(Ok(None))); + let second = Arc::new(MockAuthenticator::returning(Ok(Some(user_principal( + "bob", + ))))); + let chain = AuthenticatorChain::new(vec![first.clone(), second.clone()]); + let result = chain + .authenticate(&http::HeaderMap::new(), "/some/path") + .await + .unwrap() + .expect("expected a principal"); + match result { + Principal::User(u) => assert_eq!(u.identity.subject, "bob"), + _ => panic!("expected user principal"), + } + assert_eq!(first.call_count(), 1); + assert_eq!(second.call_count(), 1); + } + + #[tokio::test] + async fn chain_fails_closed_on_first_error() { + let first = Arc::new(MockAuthenticator::returning(Err(Status::unauthenticated( + "bad token", + )))); + let second = Arc::new(MockAuthenticator::returning(Ok(Some(user_principal( + "bob", + ))))); + let chain = AuthenticatorChain::new(vec![first.clone(), second.clone()]); + let err = chain + .authenticate(&http::HeaderMap::new(), "/some/path") + .await + .expect_err("must short-circuit on error"); + assert_eq!(err.code(), tonic::Code::Unauthenticated); + assert_eq!(first.call_count(), 1); + assert_eq!( + second.call_count(), + 0, + "must not consult later authenticators after an error" + ); + } + + #[tokio::test] + async fn empty_chain_returns_none() { + let chain = 
AuthenticatorChain::new(vec![]); + let result = chain + .authenticate(&http::HeaderMap::new(), "/some/path") + .await + .unwrap(); + assert!(result.is_none()); + } +} diff --git a/crates/openshell-server/src/auth/authz.rs b/crates/openshell-server/src/auth/authz.rs index 05ac19354..b4aa072d6 100644 --- a/crates/openshell-server/src/auth/authz.rs +++ b/crates/openshell-server/src/auth/authz.rs @@ -9,7 +9,8 @@ //! identity verification. //! //! This separation follows RFC 0001's control-plane identity design: -//! authentication is a driver concern, authorization is a gateway concern. +//! authentication is handled by explicit application-layer authenticators, +//! authorization is a gateway concern. use super::identity::Identity; use tonic::Status; @@ -22,6 +23,9 @@ const ADMIN_METHODS: &[&str] = &[ "/openshell.v1.OpenShell/CreateProvider", "/openshell.v1.OpenShell/UpdateProvider", "/openshell.v1.OpenShell/DeleteProvider", + "/openshell.v1.OpenShell/ConfigureProviderRefresh", + "/openshell.v1.OpenShell/RotateProviderCredential", + "/openshell.v1.OpenShell/DeleteProviderRefresh", // Global config and policy "/openshell.v1.OpenShell/UpdateConfig", // Draft policy approvals @@ -41,8 +45,14 @@ const SCOPED_METHODS: &[(&str, &str)] = &[ // sandbox:read ("/openshell.v1.OpenShell/GetSandbox", "sandbox:read"), ("/openshell.v1.OpenShell/ListSandboxes", "sandbox:read"), + ( + "/openshell.v1.OpenShell/ListSandboxProviders", + "sandbox:read", + ), ("/openshell.v1.OpenShell/WatchSandbox", "sandbox:read"), ("/openshell.v1.OpenShell/GetSandboxLogs", "sandbox:read"), + ("/openshell.v1.OpenShell/GetService", "sandbox:read"), + ("/openshell.v1.OpenShell/ListServices", "sandbox:read"), ( "/openshell.v1.OpenShell/GetSandboxPolicyStatus", "sandbox:read", @@ -55,15 +65,42 @@ const SCOPED_METHODS: &[(&str, &str)] = &[ ("/openshell.v1.OpenShell/CreateSandbox", "sandbox:write"), ("/openshell.v1.OpenShell/DeleteSandbox", "sandbox:write"), ("/openshell.v1.OpenShell/ExecSandbox", 
"sandbox:write"), + ("/openshell.v1.OpenShell/ForwardTcp", "sandbox:write"), ("/openshell.v1.OpenShell/CreateSshSession", "sandbox:write"), ("/openshell.v1.OpenShell/RevokeSshSession", "sandbox:write"), + ("/openshell.v1.OpenShell/ExposeService", "sandbox:write"), + ("/openshell.v1.OpenShell/DeleteService", "sandbox:write"), + ( + "/openshell.v1.OpenShell/AttachSandboxProvider", + "sandbox:write", + ), + ( + "/openshell.v1.OpenShell/DetachSandboxProvider", + "sandbox:write", + ), // provider:read ("/openshell.v1.OpenShell/GetProvider", "provider:read"), ("/openshell.v1.OpenShell/ListProviders", "provider:read"), + ( + "/openshell.v1.OpenShell/GetProviderRefreshStatus", + "provider:read", + ), // provider:write ("/openshell.v1.OpenShell/CreateProvider", "provider:write"), ("/openshell.v1.OpenShell/UpdateProvider", "provider:write"), ("/openshell.v1.OpenShell/DeleteProvider", "provider:write"), + ( + "/openshell.v1.OpenShell/ConfigureProviderRefresh", + "provider:write", + ), + ( + "/openshell.v1.OpenShell/RotateProviderCredential", + "provider:write", + ), + ( + "/openshell.v1.OpenShell/DeleteProviderRefresh", + "provider:write", + ), // config:read ("/openshell.v1.OpenShell/GetGatewayConfig", "config:read"), ("/openshell.v1.OpenShell/GetSandboxConfig", "config:read"), @@ -398,11 +435,51 @@ mod tests { .check(&id, "/openshell.v1.OpenShell/ListSandboxes") .is_ok() ); + assert!( + policy + .check(&id, "/openshell.v1.OpenShell/ListSandboxProviders") + .is_ok() + ); + assert!( + policy + .check(&id, "/openshell.v1.OpenShell/ListServices") + .is_ok() + ); + assert!( + policy + .check(&id, "/openshell.v1.OpenShell/GetService") + .is_ok() + ); assert!( policy .check(&id, "/openshell.v1.OpenShell/CreateSandbox") .is_ok() ); + assert!( + policy + .check(&id, "/openshell.v1.OpenShell/ForwardTcp") + .is_ok() + ); + assert!( + policy + .check(&id, "/openshell.v1.OpenShell/ExposeService") + .is_ok() + ); + assert!( + policy + .check(&id, "/openshell.v1.OpenShell/DeleteService") 
+ .is_ok() + ); + assert!( + policy + .check(&id, "/openshell.v1.OpenShell/AttachSandboxProvider") + .is_ok() + ); + assert!( + policy + .check(&id, "/openshell.v1.OpenShell/DetachSandboxProvider") + .is_ok() + ); } #[test] @@ -414,11 +491,76 @@ mod tests { .check(&id, "/openshell.v1.OpenShell/ListSandboxes") .is_ok() ); + assert!( + policy + .check(&id, "/openshell.v1.OpenShell/ListServices") + .is_ok() + ); + assert!( + policy + .check(&id, "/openshell.v1.OpenShell/GetService") + .is_ok() + ); + let err = policy + .check(&id, "/openshell.v1.OpenShell/AttachSandboxProvider") + .unwrap_err(); + assert_eq!(err.code(), tonic::Code::PermissionDenied); + assert!(err.message().contains("sandbox:write")); + let err = policy - .check(&id, "/openshell.v1.OpenShell/CreateSandbox") + .check(&id, "/openshell.v1.OpenShell/ExposeService") .unwrap_err(); assert_eq!(err.code(), tonic::Code::PermissionDenied); assert!(err.message().contains("sandbox:write")); + + let err = policy + .check(&id, "/openshell.v1.OpenShell/DeleteService") + .unwrap_err(); + assert_eq!(err.code(), tonic::Code::PermissionDenied); + assert!(err.message().contains("sandbox:write")); + } + + #[test] + fn provider_refresh_methods_require_provider_scopes_and_admin_for_writes() { + let policy = scoped_policy(); + let reader = identity_with_roles_and_scopes(&["openshell-user"], &["provider:read"]); + assert!( + policy + .check(&reader, "/openshell.v1.OpenShell/GetProviderRefreshStatus") + .is_ok() + ); + + let writer_without_admin = + identity_with_roles_and_scopes(&["openshell-user"], &["provider:write"]); + let err = policy + .check( + &writer_without_admin, + "/openshell.v1.OpenShell/ConfigureProviderRefresh", + ) + .unwrap_err(); + assert_eq!(err.code(), tonic::Code::PermissionDenied); + assert!(err.message().contains("openshell-admin")); + + let admin_without_scope = + identity_with_roles_and_scopes(&["openshell-admin"], &["provider:read"]); + let err = policy + .check( + &admin_without_scope, + 
"/openshell.v1.OpenShell/RotateProviderCredential", + ) + .unwrap_err(); + assert_eq!(err.code(), tonic::Code::PermissionDenied); + assert!(err.message().contains("provider:write")); + + let admin_writer = + identity_with_roles_and_scopes(&["openshell-admin"], &["provider:write"]); + for method in [ + "/openshell.v1.OpenShell/ConfigureProviderRefresh", + "/openshell.v1.OpenShell/RotateProviderCredential", + "/openshell.v1.OpenShell/DeleteProviderRefresh", + ] { + assert!(policy.check(&admin_writer, method).is_ok(), "{method}"); + } } #[test] diff --git a/crates/openshell-server/src/auth/guard.rs b/crates/openshell-server/src/auth/guard.rs new file mode 100644 index 000000000..edcd6bc01 --- /dev/null +++ b/crates/openshell-server/src/auth/guard.rs @@ -0,0 +1,177 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Per-handler sandbox-scope guards. +//! +//! Closes the IDOR half of issue #1354: a sandbox principal may only +//! reference its own sandbox, identified by its [`Principal::Sandbox`]'s +//! `sandbox_id`. User principals retain the broad scope the RBAC layer +//! already evaluated. + +use super::principal::Principal; +use super::principal::SandboxPrincipal; +use tonic::Status; +use tracing::info; + +/// Reject a sandbox-class request whose body references a sandbox other +/// than the one the calling principal was authenticated against. +/// +/// - [`Principal::User`] passes through (RBAC has already evaluated user +/// scope at the router level). +/// - [`Principal::Sandbox`] must reference the same canonical UUID it +/// was authenticated with. +/// - [`Principal::Anonymous`] is rejected — sandbox-class methods are +/// never anonymously callable. +/// +/// `claimed_sandbox_id` is the canonical UUID the request is operating +/// on. Name-keyed handlers must resolve the name to a UUID via the +/// store before calling this guard. 
+#[allow(clippy::result_large_err)] +pub fn ensure_sandbox_scope(principal: &Principal, claimed_sandbox_id: &str) -> Result<(), Status> { + match principal { + Principal::User(_) => Ok(()), + Principal::Sandbox(p) => { + if p.sandbox_id == claimed_sandbox_id { + Ok(()) + } else { + info!( + principal_sandbox_id = %p.sandbox_id, + requested_sandbox_id = %claimed_sandbox_id, + "cross-sandbox access denied" + ); + Err(Status::permission_denied( + "cross-sandbox access denied: principal does not own this sandbox", + )) + } + } + Principal::Anonymous => Err(Status::unauthenticated( + "sandbox-scoped methods require an authenticated caller", + )), + } +} + +/// Convenience: read the `Principal` out of a request and apply +/// [`ensure_sandbox_scope`]. Returns the principal so callers can read it +/// further (e.g. for audit logging). +#[allow(clippy::result_large_err)] +pub fn enforce_sandbox_scope<T>( + request: &tonic::Request<T>, + claimed_sandbox_id: &str, +) -> Result<Principal, Status> { + let principal = request + .extensions() + .get::<Principal>() + .cloned() + .ok_or_else(|| Status::unauthenticated("missing principal"))?; + ensure_sandbox_scope(&principal, claimed_sandbox_id)?; + Ok(principal) +} + +/// Require a sandbox principal and reject users or anonymous callers. +/// +/// Supervisor-only control/data plane RPCs (`ConnectSupervisor`, +/// `RelayStream`) must be presented by the sandbox supervisor itself. +/// User principals intentionally pass [`ensure_sandbox_scope`] for normal +/// CLI/TUI APIs because RBAC is their gate, but they are not valid +/// supervisor identities.
+#[allow(clippy::result_large_err)] +pub fn ensure_sandbox_principal_scope( + principal: &Principal, + claimed_sandbox_id: &str, +) -> Result<SandboxPrincipal, Status> { + match principal { + Principal::Sandbox(p) => { + ensure_sandbox_scope(principal, claimed_sandbox_id)?; + Ok(p.clone()) + } + Principal::User(_) => Err(Status::permission_denied( + "supervisor RPCs require a sandbox principal", + )), + Principal::Anonymous => Err(Status::unauthenticated( + "supervisor RPCs require an authenticated sandbox principal", + )), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::auth::identity::{Identity, IdentityProvider}; + use crate::auth::principal::{SandboxIdentitySource, SandboxPrincipal, UserPrincipal}; + + fn user(subject: &str) -> Principal { + Principal::User(UserPrincipal { + identity: Identity { + subject: subject.to_string(), + display_name: None, + roles: vec![], + scopes: vec![], + provider: IdentityProvider::Oidc, + }, + }) + } + + fn sandbox(id: &str) -> Principal { + Principal::Sandbox(SandboxPrincipal { + sandbox_id: id.to_string(), + source: SandboxIdentitySource::BootstrapJwt { + issuer: "openshell-gateway:test".to_string(), + }, + trust_domain: Some("openshell".to_string()), + }) + } + + #[test] + fn user_principal_bypasses_equality_check() { + // RBAC was the user's gate at the router layer.
+ assert!(ensure_sandbox_scope(&user("alice"), "any-sandbox").is_ok()); + } + + #[test] + fn sandbox_principal_matching_id_is_allowed() { + assert!(ensure_sandbox_scope(&sandbox("sbx-1"), "sbx-1").is_ok()); + } + + #[test] + fn sandbox_principal_mismatched_id_is_denied() { + let err = + ensure_sandbox_scope(&sandbox("sbx-1"), "sbx-2").expect_err("must deny cross-sandbox"); + assert_eq!(err.code(), tonic::Code::PermissionDenied); + } + + #[test] + fn anonymous_principal_is_rejected() { + let err = + ensure_sandbox_scope(&Principal::Anonymous, "sbx-1").expect_err("must reject anon"); + assert_eq!(err.code(), tonic::Code::Unauthenticated); + } + + #[test] + fn sandbox_principal_scope_returns_matching_sandbox() { + let principal = sandbox("sbx-1"); + let scoped = ensure_sandbox_principal_scope(&principal, "sbx-1").expect("scope OK"); + assert_eq!(scoped.sandbox_id, "sbx-1"); + } + + #[test] + fn sandbox_principal_scope_rejects_users() { + let err = ensure_sandbox_principal_scope(&user("alice"), "sbx-1") + .expect_err("users are not supervisor identities"); + assert_eq!(err.code(), tonic::Code::PermissionDenied); + } + + #[test] + fn enforce_reads_from_request_extensions() { + let mut req = tonic::Request::new(()); + req.extensions_mut().insert(sandbox("sbx-1")); + let result = enforce_sandbox_scope(&req, "sbx-1").expect("scope OK"); + assert!(matches!(result, Principal::Sandbox(_))); + } + + #[test] + fn enforce_rejects_request_without_principal() { + let req = tonic::Request::new(()); + let err = enforce_sandbox_scope(&req, "sbx-1").expect_err("must require principal"); + assert_eq!(err.code(), tonic::Code::Unauthenticated); + } +} diff --git a/crates/openshell-server/src/auth/identity.rs b/crates/openshell-server/src/auth/identity.rs index 8c3da3cc3..fa504e34b 100644 --- a/crates/openshell-server/src/auth/identity.rs +++ b/crates/openshell-server/src/auth/identity.rs @@ -39,6 +39,6 @@ pub enum IdentityProvider { Mtls, /// Cloudflare Access JWT. 
CloudflareAccess, - /// Internal (skip-listed methods, sandbox supervisor RPCs). - Internal, + /// Explicit unauthenticated local-development user principal. + LocalDev, } diff --git a/crates/openshell-server/src/auth/k8s_sa.rs b/crates/openshell-server/src/auth/k8s_sa.rs new file mode 100644 index 000000000..59f6651bb --- /dev/null +++ b/crates/openshell-server/src/auth/k8s_sa.rs @@ -0,0 +1,840 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Kubernetes `ServiceAccount` bootstrap authenticator. +//! +//! Path-scoped to `IssueSandboxToken`. Validates a projected SA token +//! presented by a sandbox pod, reads the pod's `openshell.io/sandbox-id` +//! annotation, verifies the pod is controlled by the corresponding Sandbox CR, +//! and returns a [`Principal::Sandbox`] with +//! [`SandboxIdentitySource::K8sServiceAccount`]. The `IssueSandboxToken` handler +//! then mints a gateway-signed JWT for that sandbox id; subsequent gRPC calls +//! from the supervisor use the gateway-minted JWT validated by +//! [`super::sandbox_jwt::SandboxJwtAuthenticator`]. +//! +//! This is the only authenticator that talks to the K8s apiserver. It is +//! optional — the gateway boots without it in singleplayer deployments. + +use super::authenticator::Authenticator; +use super::principal::{Principal, SandboxIdentitySource, SandboxPrincipal}; +use async_trait::async_trait; +use k8s_openapi::api::{ + authentication::v1::{TokenReview, TokenReviewSpec, TokenReviewStatus, UserInfo}, + core::v1::Pod, +}; +use k8s_openapi::apimachinery::pkg::apis::meta::v1::ObjectMeta; +use kube::api::{Api, ApiResource, PostParams}; +use kube::core::{DynamicObject, gvk::GroupVersionKind}; +use std::sync::Arc; +use tonic::Status; +use tracing::{debug, info, warn}; + +/// gRPC method path that this authenticator accepts. 
All other paths fall +/// through (return `Ok(None)`) so a gateway-minted JWT is required there. +pub const ISSUE_SANDBOX_TOKEN_PATH: &str = "/openshell.v1.OpenShell/IssueSandboxToken"; + +/// Pod annotation that binds a sandbox pod to its UUID. Set by the +/// Kubernetes compute driver at pod-create time. The gateway accepts this +/// annotation only after validating the pod's `TokenReview` binding, live UID, +/// and owning Sandbox CR. The K8s `Role` granted to the gateway must not +/// include `patch pods` (see plan §11.8). +pub const SANDBOX_ID_ANNOTATION: &str = "openshell.io/sandbox-id"; +const SANDBOX_API_GROUP: &str = "agents.x-k8s.io"; +const SANDBOX_API_VERSION: &str = "v1alpha1"; +const SANDBOX_API_VERSION_FULL: &str = "agents.x-k8s.io/v1alpha1"; +const SANDBOX_KIND: &str = "Sandbox"; +const SANDBOX_ID_LABEL: &str = "openshell.ai/sandbox-id"; +const POD_NAME_EXTRA: &str = "authentication.kubernetes.io/pod-name"; +const POD_UID_EXTRA: &str = "authentication.kubernetes.io/pod-uid"; + +/// Resolved identity extracted from a validated SA token + pod lookup. +#[derive(Debug, Clone)] +pub struct ResolvedK8sIdentity { + pub sandbox_id: String, + pub pod_name: String, + pub pod_uid: String, +} + +/// Apiserver-facing operations the authenticator depends on. Split out so +/// tests can fake the apiserver without standing up a kube cluster. +#[async_trait] +pub trait K8sIdentityResolver: Send + Sync + 'static { + /// Validate `token` via `TokenReview` (`aud == openshell-gateway`), + /// extract the pod name/uid, then `GET` the pod and read + /// `openshell.io/sandbox-id`. Returns `Ok(None)` when the token is + /// well-formed but does not authenticate (e.g. wrong audience); returns + /// `Err` for transport/server errors. + async fn resolve(&self, token: &str) -> Result<Option<ResolvedK8sIdentity>, Status>; +} + +/// Authenticator wrapper around a [`K8sIdentityResolver`].
+pub struct K8sServiceAccountAuthenticator { + resolver: Arc<dyn K8sIdentityResolver>, +} + +impl std::fmt::Debug for K8sServiceAccountAuthenticator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("K8sServiceAccountAuthenticator") + .finish_non_exhaustive() + } +} + +impl K8sServiceAccountAuthenticator { + pub fn new(resolver: Arc<dyn K8sIdentityResolver>) -> Self { + Self { resolver } + } +} + +#[async_trait] +impl Authenticator for K8sServiceAccountAuthenticator { + async fn authenticate( + &self, + headers: &http::HeaderMap, + path: &str, + ) -> Result<Option<Principal>, Status> { + // Scope: only the bootstrap RPC. Other paths fall through so the + // SandboxJwtAuthenticator (or OIDC) handles them. + if path != ISSUE_SANDBOX_TOKEN_PATH { + return Ok(None); + } + + let Some(token) = headers + .get("authorization") + .and_then(|v| v.to_str().ok()) + .and_then(|v| v.strip_prefix("Bearer ")) + else { + return Ok(None); + }; + + let Some(resolved) = self.resolver.resolve(token).await? else { + debug!("K8s SA token did not authenticate; falling through"); + return Ok(None); + }; + + if resolved.sandbox_id.is_empty() { + warn!( + pod = %resolved.pod_name, + "pod missing openshell.io/sandbox-id annotation; rejecting" + ); + return Err(Status::permission_denied( + "pod is not bound to a sandbox identity", + )); + } + + Ok(Some(Principal::Sandbox(SandboxPrincipal { + sandbox_id: resolved.sandbox_id, + source: SandboxIdentitySource::K8sServiceAccount { + pod_name: resolved.pod_name, + pod_uid: resolved.pod_uid, + }, + trust_domain: Some("openshell".to_string()), + }))) + } +} + +#[derive(Debug)] +struct TokenReviewIdentity { + pod_name: String, + pod_uid: String, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct SandboxOwnerReference { + name: String, + uid: String, +} + +/// Resolver backed by the apiserver's `TokenReview` API and `kube::Client` +/// for the per-pod annotation lookup.
+pub struct LiveK8sResolver { + token_reviews_api: Api<TokenReview>, + pods_api: Api<Pod>, + sandboxes_api: Api<DynamicObject>, + expected_audience: String, + sandbox_namespace: String, + expected_service_account: String, +} + +impl LiveK8sResolver { + pub fn new( + client: kube::Client, + namespace: &str, + expected_audience: String, + expected_service_account: String, + ) -> Self { + let token_reviews_api: Api<TokenReview> = Api::all(client.clone()); + let pods_api: Api<Pod> = Api::namespaced(client.clone(), namespace); + let sandbox_gvk = + GroupVersionKind::gvk(SANDBOX_API_GROUP, SANDBOX_API_VERSION, SANDBOX_KIND); + let sandbox_resource = ApiResource::from_gvk(&sandbox_gvk); + let sandboxes_api: Api<DynamicObject> = + Api::namespaced_with(client, namespace, &sandbox_resource); + Self { + token_reviews_api, + pods_api, + sandboxes_api, + expected_audience, + sandbox_namespace: namespace.to_string(), + expected_service_account, + } + } +} + +#[async_trait] +impl K8sIdentityResolver for LiveK8sResolver { + async fn resolve(&self, token: &str) -> Result<Option<ResolvedK8sIdentity>, Status> { + let review = TokenReview { + metadata: ObjectMeta::default(), + spec: TokenReviewSpec { + audiences: Some(vec![self.expected_audience.clone()]), + token: Some(token.to_string()), + }, + status: None, + }; + + let review = self + .token_reviews_api + .create(&PostParams::default(), &review) + .await + .map_err(|e| { + warn!(error = %e, "K8s TokenReview failed"); + Status::internal(format!("tokenreview failed: {e}")) + })?; + let status = review + .status + .ok_or_else(|| Status::internal("TokenReview response missing status"))?; + let Some(identity) = token_review_identity( + &status, + &self.expected_audience, + &self.sandbox_namespace, + &self.expected_service_account, + )? + else { + return Ok(None); + }; + + info!( + pod_name = %identity.pod_name, + pod_uid = %identity.pod_uid, + service_account = %self.expected_service_account, + "validated K8s SA token via TokenReview" + ); + + // Look up the pod and read its sandbox-id annotation.
+ let pod = self + .pods_api + .get_opt(&identity.pod_name) + .await + .map_err(|e| { + warn!( + pod = %identity.pod_name, + error = %e, + "failed to fetch sandbox pod for annotation lookup" + ); + Status::internal(format!("pod GET failed: {e}")) + })?; + let Some(pod) = pod else { + warn!( + pod = %identity.pod_name, + "sandbox pod referenced by SA token not found in this namespace" + ); + return Err(Status::not_found("sandbox pod not found")); + }; + + // Defense-in-depth: confirm the pod UID matches the SA token's + // `kubernetes.io.pod.uid`. Prevents a replayed token from a + // recreated pod with the same name. + let actual_uid = pod.metadata.uid.as_deref().unwrap_or_default(); + if actual_uid != identity.pod_uid { + warn!( + pod = %identity.pod_name, + claimed_uid = %identity.pod_uid, + actual_uid = %actual_uid, + "SA token pod UID does not match live pod; rejecting" + ); + return Err(Status::permission_denied("SA token pod UID mismatch")); + } + + let sandbox_id = pod_sandbox_id(&pod)?; + + let owner = sandbox_owner_reference(&pod)?; + let sandbox_cr = self.sandboxes_api.get_opt(&owner.name).await.map_err(|e| { + warn!( + pod = %identity.pod_name, + sandbox_owner = %owner.name, + error = %e, + "failed to fetch owning Sandbox CR for pod identity validation" + ); + Status::internal(format!("sandbox GET failed: {e}")) + })?; + let Some(sandbox_cr) = sandbox_cr else { + warn!( + pod = %identity.pod_name, + sandbox_owner = %owner.name, + "pod ownerReference points to a Sandbox CR that does not exist" + ); + return Err(Status::permission_denied("sandbox owner not found")); + }; + validate_sandbox_owner_reference(&owner, &sandbox_id, &sandbox_cr)?; + + Ok(Some(ResolvedK8sIdentity { + sandbox_id, + pod_name: identity.pod_name, + pod_uid: identity.pod_uid, + })) + } +} + +#[allow(clippy::result_large_err)] +fn token_review_identity( + status: &TokenReviewStatus, + expected_audience: &str, + sandbox_namespace: &str, + expected_service_account: &str, +) -> Result<Option<TokenReviewIdentity>, 
Status> { + if status.authenticated != Some(true) { + debug!( + error = status.error.as_deref().unwrap_or_default(), + "K8s TokenReview did not authenticate token" + ); + return Ok(None); + } + + let audiences = status.audiences.as_deref().unwrap_or_default(); + if !audiences.iter().any(|aud| aud == expected_audience) { + warn!( + expected_audience = %expected_audience, + audiences = ?audiences, + "K8s TokenReview authenticated token without expected audience" + ); + return Err(Status::unauthenticated("SA token audience not accepted")); + } + + let user = status + .user + .as_ref() + .ok_or_else(|| Status::permission_denied("TokenReview response missing user info"))?; + let username = user + .username + .as_deref() + .ok_or_else(|| Status::permission_denied("TokenReview response missing username"))?; + let expected_username = + format!("system:serviceaccount:{sandbox_namespace}:{expected_service_account}"); + if username != expected_username { + warn!( + username = %username, + sandbox_namespace = %sandbox_namespace, + service_account = %expected_service_account, + "K8s TokenReview principal is not the configured sandbox service account" + ); + return Err(Status::permission_denied( + "SA token is not from the configured sandbox service account", + )); + } + + let pod_name = user_extra_one(user, POD_NAME_EXTRA)?; + let pod_uid = user_extra_one(user, POD_UID_EXTRA)?; + Ok(Some(TokenReviewIdentity { pod_name, pod_uid })) +} + +#[allow(clippy::result_large_err)] +fn user_extra_one(user: &UserInfo, key: &str) -> Result<String, Status> { + let Some(values) = user.extra.as_ref().and_then(|extra| extra.get(key)) else { + return Err(Status::permission_denied("SA token is not pod-bound")); + }; + if values.len() != 1 || values[0].is_empty() { + return Err(Status::permission_denied( + "SA token has invalid pod binding", + )); + } + Ok(values[0].clone()) +} + +#[allow(clippy::result_large_err)] +fn pod_sandbox_id(pod: &Pod) -> Result<String, Status> { + let sandbox_id = pod + .metadata + .annotations + 
.as_ref() + .and_then(|a| a.get(SANDBOX_ID_ANNOTATION)) + .cloned() + .unwrap_or_default(); + if sandbox_id.is_empty() { + return Err(Status::permission_denied( + "pod is not bound to a sandbox identity", + )); + } + Ok(sandbox_id) +} + +#[allow(clippy::result_large_err)] +fn sandbox_owner_reference(pod: &Pod) -> Result<SandboxOwnerReference, Status> { + let owner_refs = pod.metadata.owner_references.as_deref().unwrap_or_default(); + let mut sandbox_refs = owner_refs.iter().filter(|owner| { + owner.api_version == SANDBOX_API_VERSION_FULL && owner.kind == SANDBOX_KIND + }); + let Some(owner) = sandbox_refs.next() else { + return Err(Status::permission_denied( + "pod is not controlled by an OpenShell Sandbox", + )); + }; + if sandbox_refs.next().is_some() { + return Err(Status::permission_denied( + "pod has multiple OpenShell Sandbox owners", + )); + } + if owner.controller != Some(true) { + return Err(Status::permission_denied( + "pod Sandbox ownerReference is not controlling", + )); + } + if owner.name.is_empty() || owner.uid.is_empty() { + return Err(Status::permission_denied( + "pod Sandbox ownerReference is incomplete", + )); + } + Ok(SandboxOwnerReference { + name: owner.name.clone(), + uid: owner.uid.clone(), + }) +} + +#[allow(clippy::result_large_err)] +fn validate_sandbox_owner_reference( + owner: &SandboxOwnerReference, + sandbox_id: &str, + sandbox_cr: &DynamicObject, +) -> Result<(), Status> { + let actual_uid = sandbox_cr.metadata.uid.as_deref().unwrap_or_default(); + if actual_uid != owner.uid { + warn!( + sandbox_owner = %owner.name, + owner_uid = %owner.uid, + actual_uid = %actual_uid, + "pod Sandbox ownerReference UID does not match live Sandbox CR" + ); + return Err(Status::permission_denied("sandbox owner UID mismatch")); + } + + let actual_sandbox_id = sandbox_cr + .metadata + .labels + .as_ref() + .and_then(|labels| labels.get(SANDBOX_ID_LABEL)) + .map(String::as_str) + .unwrap_or_default(); + if actual_sandbox_id != sandbox_id { + warn!( + sandbox_owner = %owner.name, + 
owner_uid = %owner.uid, + pod_sandbox_id = %sandbox_id, + cr_sandbox_id = %actual_sandbox_id, + "pod sandbox annotation does not match owning Sandbox CR label" + ); + return Err(Status::permission_denied("sandbox owner ID mismatch")); + } + + Ok(()) +} + +#[cfg(test)] +pub mod test_support { + use super::*; + use std::sync::Mutex; + + /// Fake resolver for unit tests. Returns the configured outcome on + /// every call and records the tokens it observed. + pub struct FakeResolver { + pub outcome: Result<Option<ResolvedK8sIdentity>, Status>, + pub seen_tokens: Mutex<Vec<String>>, + } + + impl FakeResolver { + pub fn returning(outcome: Result<Option<ResolvedK8sIdentity>, Status>) -> Self { + Self { + outcome, + seen_tokens: Mutex::new(Vec::new()), + } + } + } + + #[async_trait] + impl K8sIdentityResolver for FakeResolver { + async fn resolve(&self, token: &str) -> Result<Option<ResolvedK8sIdentity>, Status> { + self.seen_tokens.lock().unwrap().push(token.to_string()); + match &self.outcome { + Ok(opt) => Ok(opt.clone()), + Err(s) => Err(Status::new(s.code(), s.message())), + } + } + } +} + +#[cfg(test)] +mod tests { + use super::test_support::FakeResolver; + use super::*; + use k8s_openapi::apimachinery::pkg::apis::meta::v1::OwnerReference; + use std::collections::BTreeMap; + + fn bearer_headers(token: &str) -> http::HeaderMap { + let mut h = http::HeaderMap::new(); + h.insert( + "authorization", + http::HeaderValue::from_str(&format!("Bearer {token}")).unwrap(), + ); + h + } + + fn token_review_status( + authenticated: bool, + audiences: Vec<&str>, + username: &str, + extra: Vec<(&str, &str)>, + ) -> TokenReviewStatus { + TokenReviewStatus { + authenticated: Some(authenticated), + audiences: Some(audiences.into_iter().map(str::to_string).collect()), + error: None, + user: Some(UserInfo { + username: Some(username.to_string()), + uid: Some("sa-uid".to_string()), + groups: Some(vec![ + "system:serviceaccounts".to_string(), + "system:serviceaccounts:openshell".to_string(), + "system:authenticated".to_string(), + ]), + extra: Some( + extra + .into_iter() + 
.map(|(k, v)| (k.to_string(), vec![v.to_string()])) + .collect::<BTreeMap<_, _>>(), + ), + }), + } + } + + fn sandbox_owner(name: &str, uid: &str) -> OwnerReference { + OwnerReference { + api_version: SANDBOX_API_VERSION_FULL.to_string(), + block_owner_deletion: None, + controller: Some(true), + kind: SANDBOX_KIND.to_string(), + name: name.to_string(), + uid: uid.to_string(), + } + } + + fn pod_with_owner_refs(owner_references: Vec<OwnerReference>) -> Pod { + Pod { + metadata: ObjectMeta { + owner_references: Some(owner_references), + ..Default::default() + }, + ..Default::default() + } + } + + fn pod_with_sandbox_id(sandbox_id: Option<&str>) -> Pod { + Pod { + metadata: ObjectMeta { + annotations: sandbox_id.map(|id| { + BTreeMap::from([(SANDBOX_ID_ANNOTATION.to_string(), id.to_string())]) + }), + ..Default::default() + }, + ..Default::default() + } + } + + fn sandbox_cr(name: &str, uid: &str, sandbox_id: &str) -> DynamicObject { + let sandbox_gvk = + GroupVersionKind::gvk(SANDBOX_API_GROUP, SANDBOX_API_VERSION, SANDBOX_KIND); + let sandbox_resource = ApiResource::from_gvk(&sandbox_gvk); + let mut cr = DynamicObject::new(name, &sandbox_resource); + cr.metadata.uid = Some(uid.to_string()); + cr.metadata.labels = Some(BTreeMap::from([( + SANDBOX_ID_LABEL.to_string(), + sandbox_id.to_string(), + )])); + cr + } + + #[test] + fn token_review_identity_extracts_pod_binding() { + let status = token_review_status( + true, + vec!["openshell-gateway"], + "system:serviceaccount:openshell:default", + vec![ + (POD_NAME_EXTRA, "openshell-sandbox-a"), + (POD_UID_EXTRA, "uid-a"), + ], + ); + + let identity = token_review_identity(&status, "openshell-gateway", "openshell", "default") + .unwrap() + .expect("authenticated token should resolve"); + + assert_eq!(identity.pod_name, "openshell-sandbox-a"); + assert_eq!(identity.pod_uid, "uid-a"); + } + + #[test] + fn token_review_identity_returns_none_when_not_authenticated() { + let status = TokenReviewStatus { + authenticated: Some(false), + error: Some("invalid 
audience".to_string()), + ..Default::default() + }; + + assert!( + token_review_identity(&status, "openshell-gateway", "openshell", "default") + .unwrap() + .is_none() + ); + } + + #[test] + fn token_review_identity_requires_expected_audience() { + let status = token_review_status( + true, + vec!["kubernetes.default.svc"], + "system:serviceaccount:openshell:default", + vec![ + (POD_NAME_EXTRA, "openshell-sandbox-a"), + (POD_UID_EXTRA, "uid-a"), + ], + ); + + let err = token_review_identity(&status, "openshell-gateway", "openshell", "default") + .expect_err("wrong audience must fail closed"); + assert_eq!(err.code(), tonic::Code::Unauthenticated); + } + + #[test] + fn token_review_identity_requires_sandbox_namespace() { + let status = token_review_status( + true, + vec!["openshell-gateway"], + "system:serviceaccount:other:default", + vec![ + (POD_NAME_EXTRA, "openshell-sandbox-a"), + (POD_UID_EXTRA, "uid-a"), + ], + ); + + let err = token_review_identity(&status, "openshell-gateway", "openshell", "default") + .expect_err("other namespace must be rejected"); + assert_eq!(err.code(), tonic::Code::PermissionDenied); + } + + #[test] + fn token_review_identity_requires_configured_service_account() { + let status = token_review_status( + true, + vec!["openshell-gateway"], + "system:serviceaccount:openshell:other", + vec![ + (POD_NAME_EXTRA, "openshell-sandbox-a"), + (POD_UID_EXTRA, "uid-a"), + ], + ); + + let err = token_review_identity(&status, "openshell-gateway", "openshell", "default") + .expect_err("other service account must be rejected"); + assert_eq!(err.code(), tonic::Code::PermissionDenied); + } + + #[test] + fn token_review_identity_requires_pod_bound_extras() { + let status = token_review_status( + true, + vec!["openshell-gateway"], + "system:serviceaccount:openshell:default", + vec![], + ); + + let err = token_review_identity(&status, "openshell-gateway", "openshell", "default") + .expect_err("non pod-bound tokens must be rejected"); + assert_eq!(err.code(), 
tonic::Code::PermissionDenied); + } + + #[test] + fn pod_sandbox_id_requires_annotation() { + assert_eq!( + pod_sandbox_id(&pod_with_sandbox_id(Some("sandbox-id-a"))).unwrap(), + "sandbox-id-a" + ); + + let err = pod_sandbox_id(&pod_with_sandbox_id(None)) + .expect_err("missing sandbox-id annotation must fail"); + assert_eq!(err.code(), tonic::Code::PermissionDenied); + } + + #[test] + fn sandbox_owner_reference_extracts_controlling_sandbox_owner() { + let pod = pod_with_owner_refs(vec![sandbox_owner("sandbox-a", "cr-uid-a")]); + + let owner = sandbox_owner_reference(&pod).expect("expected Sandbox owner"); + + assert_eq!( + owner, + SandboxOwnerReference { + name: "sandbox-a".to_string(), + uid: "cr-uid-a".to_string(), + } + ); + } + + #[test] + fn sandbox_owner_reference_rejects_missing_owner() { + let pod = pod_with_owner_refs(vec![]); + + let err = sandbox_owner_reference(&pod).expect_err("missing owner must fail"); + + assert_eq!(err.code(), tonic::Code::PermissionDenied); + } + + #[test] + fn sandbox_owner_reference_requires_controlling_owner() { + let mut owner = sandbox_owner("sandbox-a", "cr-uid-a"); + owner.controller = Some(false); + let pod = pod_with_owner_refs(vec![owner]); + + let err = sandbox_owner_reference(&pod).expect_err("non-controller owner must fail"); + + assert_eq!(err.code(), tonic::Code::PermissionDenied); + } + + #[test] + fn sandbox_owner_reference_rejects_ambiguous_sandbox_owners() { + let pod = pod_with_owner_refs(vec![ + sandbox_owner("sandbox-a", "cr-uid-a"), + sandbox_owner("sandbox-b", "cr-uid-b"), + ]); + + let err = sandbox_owner_reference(&pod).expect_err("multiple owners must fail"); + + assert_eq!(err.code(), tonic::Code::PermissionDenied); + } + + #[test] + fn validate_sandbox_owner_reference_requires_matching_cr_uid_and_label() { + let owner = SandboxOwnerReference { + name: "sandbox-a".to_string(), + uid: "cr-uid-a".to_string(), + }; + let cr = sandbox_cr("sandbox-a", "cr-uid-a", "sandbox-id-a"); + 
validate_sandbox_owner_reference(&owner, "sandbox-id-a", &cr) + .expect("matching CR should be accepted"); + + let wrong_uid = sandbox_cr("sandbox-a", "cr-uid-b", "sandbox-id-a"); + let err = validate_sandbox_owner_reference(&owner, "sandbox-id-a", &wrong_uid) + .expect_err("wrong CR UID must fail"); + assert_eq!(err.code(), tonic::Code::PermissionDenied); + + let wrong_label = sandbox_cr("sandbox-a", "cr-uid-a", "sandbox-id-b"); + let err = validate_sandbox_owner_reference(&owner, "sandbox-id-a", &wrong_label) + .expect_err("wrong sandbox-id label must fail"); + assert_eq!(err.code(), tonic::Code::PermissionDenied); + } + + #[tokio::test] + async fn authenticates_on_issue_path_only() { + let resolved = ResolvedK8sIdentity { + sandbox_id: "sandbox-a".to_string(), + pod_name: "openshell-sandbox-a".to_string(), + pod_uid: "uid-a".to_string(), + }; + let fake = Arc::new(FakeResolver::returning(Ok(Some(resolved)))); + let auth = K8sServiceAccountAuthenticator::new(fake.clone()); + + let on_issue = auth + .authenticate(&bearer_headers("sa-jwt"), ISSUE_SANDBOX_TOKEN_PATH) + .await + .unwrap() + .expect("expected principal"); + match on_issue { + Principal::Sandbox(p) => { + assert_eq!(p.sandbox_id, "sandbox-a"); + assert!(matches!( + p.source, + SandboxIdentitySource::K8sServiceAccount { .. 
} + )); + } + _ => panic!("expected sandbox principal"), + } + + let off_issue = auth + .authenticate( + &bearer_headers("sa-jwt"), + "/openshell.v1.OpenShell/GetSandboxConfig", + ) + .await + .unwrap(); + assert!( + off_issue.is_none(), + "K8s SA authenticator must be scoped to IssueSandboxToken" + ); + assert_eq!( + fake.seen_tokens.lock().unwrap().len(), + 1, + "off-path call must not consult the apiserver" + ); + } + + #[tokio::test] + async fn missing_bearer_yields_none() { + let fake = Arc::new(FakeResolver::returning(Ok(None))); + let auth = K8sServiceAccountAuthenticator::new(fake); + let result = auth + .authenticate(&http::HeaderMap::new(), ISSUE_SANDBOX_TOKEN_PATH) + .await + .unwrap(); + assert!(result.is_none()); + } + + #[tokio::test] + async fn resolver_returning_none_falls_through() { + let fake = Arc::new(FakeResolver::returning(Ok(None))); + let auth = K8sServiceAccountAuthenticator::new(fake); + let result = auth + .authenticate( + &bearer_headers("not-a-real-sa-token"), + ISSUE_SANDBOX_TOKEN_PATH, + ) + .await + .unwrap(); + assert!(result.is_none(), "non-authenticating tokens fall through"); + } + + #[tokio::test] + async fn pod_without_annotation_is_rejected() { + let resolved = ResolvedK8sIdentity { + sandbox_id: String::new(), + pod_name: "stray-pod".to_string(), + pod_uid: "uid".to_string(), + }; + let fake = Arc::new(FakeResolver::returning(Ok(Some(resolved)))); + let auth = K8sServiceAccountAuthenticator::new(fake); + let err = auth + .authenticate(&bearer_headers("sa-jwt"), ISSUE_SANDBOX_TOKEN_PATH) + .await + .expect_err("unbound pod must be rejected"); + assert_eq!(err.code(), tonic::Code::PermissionDenied); + } + + #[tokio::test] + async fn resolver_error_propagates() { + let fake = Arc::new(FakeResolver::returning(Err(Status::unavailable( + "apiserver down", + )))); + let auth = K8sServiceAccountAuthenticator::new(fake); + let err = auth + .authenticate(&bearer_headers("sa-jwt"), ISSUE_SANDBOX_TOKEN_PATH) + .await + 
.expect_err("resolver error must propagate"); + assert_eq!(err.code(), tonic::Code::Unavailable); + } +} diff --git a/crates/openshell-server/src/auth/mod.rs b/crates/openshell-server/src/auth/mod.rs index 8e4f332d8..ca032a006 100644 --- a/crates/openshell-server/src/auth/mod.rs +++ b/crates/openshell-server/src/auth/mod.rs @@ -8,9 +8,15 @@ //! - `identity`: Provider-agnostic identity representation //! - `http`: HTTP endpoints for auth discovery and token exchange +pub mod authenticator; pub mod authz; +pub mod guard; mod http; pub mod identity; +pub mod k8s_sa; pub mod oidc; +pub mod principal; +pub mod sandbox_jwt; +pub mod sandbox_methods; pub use http::router; diff --git a/crates/openshell-server/src/auth/oidc.rs b/crates/openshell-server/src/auth/oidc.rs index d3b74aa81..5e5a23500 100644 --- a/crates/openshell-server/src/auth/oidc.rs +++ b/crates/openshell-server/src/auth/oidc.rs @@ -10,7 +10,10 @@ //! This module owns authentication (verifying who the caller is). //! Authorization (deciding what the caller can do) is in `authz.rs`. +use super::authenticator::Authenticator; use super::identity::{Identity, IdentityProvider}; +use super::principal::{Principal, UserPrincipal}; +use async_trait::async_trait; use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header}; use openshell_core::OidcConfig; use reqwest::Client; @@ -22,14 +25,6 @@ use tokio::sync::RwLock; use tonic::Status; use tracing::{debug, info, warn}; -/// Internal metadata header set by the auth middleware after it validates -/// a sandbox-secret-authenticated request. This is stripped from all incoming -/// requests first so external callers cannot spoof it. -pub const INTERNAL_AUTH_SOURCE_HEADER: &str = "x-openshell-auth-source"; -/// Internal auth-source marker for requests authenticated via the shared -/// sandbox secret. -pub const AUTH_SOURCE_SANDBOX_SECRET: &str = "sandbox-secret"; - /// Truly unauthenticated methods — health probes and infrastructure. 
const UNAUTHENTICATED_METHODS: &[&str] = &[ "/openshell.v1.OpenShell/Health", @@ -39,33 +34,6 @@ const UNAUTHENTICATED_METHODS: &[&str] = &[ /// Path prefixes that bypass OIDC validation (gRPC reflection, health probes). const UNAUTHENTICATED_PREFIXES: &[&str] = &["/grpc.reflection.", "/grpc.health."]; -/// Sandbox-to-server RPCs that use the shared sandbox secret instead of -/// OIDC Bearer tokens. These require the `x-sandbox-secret` metadata header -/// matching the server's SSH handshake secret. -const SANDBOX_SECRET_METHODS: &[&str] = &[ - "/openshell.v1.OpenShell/ReportPolicyStatus", - "/openshell.v1.OpenShell/PushSandboxLogs", - "/openshell.v1.OpenShell/GetSandboxProviderEnvironment", - "/openshell.v1.OpenShell/SubmitPolicyAnalysis", - "/openshell.sandbox.v1.SandboxService/GetSandboxConfig", - "/openshell.inference.v1.Inference/GetInferenceBundle", -]; - -/// Methods that accept either OIDC Bearer token (CLI users) or sandbox -/// secret (supervisor). `UpdateConfig` is called by both CLI -/// (policy/settings mutations) and the sandbox supervisor (policy sync on -/// startup). `OpenShell/GetSandboxConfig` serves CLI settings reads while -/// remaining compatible with sandbox-secret-authenticated callers. -const DUAL_AUTH_METHODS: &[&str] = &[ - "/openshell.v1.OpenShell/UpdateConfig", - "/openshell.v1.OpenShell/GetSandboxConfig", -]; - -/// Returns `true` if the method accepts either Bearer or sandbox-secret auth. -pub fn is_dual_auth_method(path: &str) -> bool { - DUAL_AUTH_METHODS.contains(&path) -} - /// Returns `true` if the method needs no authentication at all. pub fn is_unauthenticated_method(path: &str) -> bool { UNAUTHENTICATED_METHODS.contains(&path) @@ -74,52 +42,6 @@ pub fn is_unauthenticated_method(path: &str) -> bool { .any(|prefix| path.starts_with(prefix)) } -/// Returns `true` if the method authenticates via the sandbox shared secret -/// rather than an OIDC Bearer token. 
-pub fn is_sandbox_secret_method(path: &str) -> bool { - SANDBOX_SECRET_METHODS.contains(&path) -} - -/// Validate the `x-sandbox-secret` header against the server's handshake secret. -#[allow(clippy::result_large_err)] -pub fn validate_sandbox_secret( - headers: &http::HeaderMap, - expected_secret: &str, -) -> Result<(), Status> { - let provided = headers - .get("x-sandbox-secret") - .and_then(|v| v.to_str().ok()) - .ok_or_else(|| Status::unauthenticated("sandbox secret required for this method"))?; - - if provided != expected_secret { - return Err(Status::unauthenticated("invalid sandbox secret")); - } - - Ok(()) -} - -/// Remove internal auth-source markers from the request before any auth -/// decision is made so external callers cannot spoof them. -pub fn clear_internal_auth_markers(headers: &mut http::HeaderMap) { - headers.remove(INTERNAL_AUTH_SOURCE_HEADER); -} - -/// Mark the request as authenticated via the shared sandbox secret. -pub fn mark_sandbox_secret_authenticated(headers: &mut http::HeaderMap) { - headers.insert( - INTERNAL_AUTH_SOURCE_HEADER, - http::HeaderValue::from_static(AUTH_SOURCE_SANDBOX_SECRET), - ); -} - -/// Returns `true` if the request metadata indicates sandbox-secret auth. -pub fn is_sandbox_secret_authenticated(metadata: &tonic::metadata::MetadataMap) -> bool { - metadata - .get(INTERNAL_AUTH_SOURCE_HEADER) - .and_then(|v| v.to_str().ok()) - == Some(AUTH_SOURCE_SANDBOX_SECRET) -} - /// Cached JWKS key set fetched from the OIDC issuer. /// /// A `refresh_mutex` ensures that only one refresh runs at a time, @@ -429,6 +351,42 @@ impl JwksCache { } } +/// Authenticator that validates `Authorization: Bearer ` headers against +/// the configured OIDC issuer. +/// +/// Returns `Ok(None)` when no Bearer header is present, so the chain can fall +/// through to other authenticators (e.g. the gateway-minted sandbox JWT +/// authenticator). 
+pub struct OidcAuthenticator { + cache: Arc, +} + +impl OidcAuthenticator { + pub fn new(cache: Arc) -> Self { + Self { cache } + } +} + +#[async_trait] +impl Authenticator for OidcAuthenticator { + async fn authenticate( + &self, + headers: &http::HeaderMap, + _path: &str, + ) -> Result, Status> { + let Some(token) = headers + .get("authorization") + .and_then(|v| v.to_str().ok()) + .and_then(|v| v.strip_prefix("Bearer ")) + else { + return Ok(None); + }; + + let identity = self.cache.validate_token(token).await?; + Ok(Some(Principal::User(UserPrincipal { identity }))) + } +} + #[cfg(test)] mod tests { use super::*; @@ -443,9 +401,6 @@ mod tests { assert!(!is_unauthenticated_method( "/openshell.v1.OpenShell/CreateSandbox" )); - assert!(!is_sandbox_secret_method( - "/openshell.v1.OpenShell/CreateSandbox" - )); } #[test] @@ -463,52 +418,6 @@ mod tests { assert!(is_unauthenticated_method("/grpc.health.v1.Health/Check")); } - #[test] - fn sandbox_rpcs_use_sandbox_secret() { - assert!(is_sandbox_secret_method( - "/openshell.sandbox.v1.SandboxService/GetSandboxConfig" - )); - assert!(is_sandbox_secret_method( - "/openshell.v1.OpenShell/GetSandboxProviderEnvironment" - )); - assert!(is_sandbox_secret_method( - "/openshell.v1.OpenShell/ReportPolicyStatus" - )); - assert!(is_sandbox_secret_method( - "/openshell.v1.OpenShell/PushSandboxLogs" - )); - assert!(is_sandbox_secret_method( - "/openshell.v1.OpenShell/SubmitPolicyAnalysis" - )); - assert!(is_sandbox_secret_method( - "/openshell.inference.v1.Inference/GetInferenceBundle" - )); - } - - #[test] - fn openshell_get_sandbox_config_is_dual_auth() { - assert!(!is_sandbox_secret_method( - "/openshell.v1.OpenShell/GetSandboxConfig" - )); - assert!(is_dual_auth_method( - "/openshell.v1.OpenShell/GetSandboxConfig" - )); - } - - #[test] - fn sandbox_secret_validation() { - let mut headers = http::HeaderMap::new(); - headers.insert("x-sandbox-secret", "test-secret".parse().unwrap()); - assert!(validate_sandbox_secret(&headers, 
"test-secret").is_ok()); - assert!(validate_sandbox_secret(&headers, "wrong-secret").is_err()); - } - - #[test] - fn sandbox_secret_missing_header() { - let headers = http::HeaderMap::new(); - assert!(validate_sandbox_secret(&headers, "test-secret").is_err()); - } - #[test] fn extract_roles_keycloak_path() { let json = serde_json::json!({ diff --git a/crates/openshell-server/src/auth/principal.rs b/crates/openshell-server/src/auth/principal.rs new file mode 100644 index 000000000..a95eb831b --- /dev/null +++ b/crates/openshell-server/src/auth/principal.rs @@ -0,0 +1,78 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Authenticated caller principals. +//! +//! A `Principal` is the result of running the [`super::authenticator::Authenticator`] +//! chain on an inbound request. It generalizes over the kinds of callers the +//! gateway recognizes — human users (OIDC), sandbox supervisors (gateway-minted +//! JWT, future SPIFFE), and anonymous callers (truly unauthenticated methods +//! like health probes). +//! +//! Handlers read the principal from the gRPC `Request` extensions and gate +//! access accordingly. Sandbox-class handlers MUST compare +//! `Principal::Sandbox.sandbox_id` against the request body's `sandbox_id` +//! to prevent cross-sandbox access (see issue #1354). + +use super::identity::Identity; + +/// Who is calling. +/// +/// Inserted into `tonic::Request::extensions` by the auth router. Handlers +/// retrieve it via `req.extensions().get::()`. +#[derive(Debug, Clone)] +pub enum Principal { + /// Human caller authenticated via OIDC (Keycloak, Entra ID, Okta, etc.). + User(UserPrincipal), + /// Sandbox supervisor authenticated by an identity bound to a specific + /// sandbox UUID. The wrapped `sandbox_id` MUST match any sandbox referenced + /// in the request body for sandbox-class methods. 
+ Sandbox(#[allow(dead_code)] SandboxPrincipal), + /// Truly unauthenticated caller (health probes, reflection). Sandbox-class + /// and user-class methods reject this variant. + #[allow(dead_code)] + Anonymous, +} + +/// User caller — wraps the existing provider-agnostic [`Identity`]. +#[derive(Debug, Clone)] +pub struct UserPrincipal { + /// The verified identity from the authentication provider. + pub identity: Identity, +} + +/// Sandbox caller — bound to one specific sandbox UUID. +/// +/// `sandbox_id` and `source` are consumed by the router and handler guards. +#[derive(Debug, Clone)] +#[allow(dead_code)] +pub struct SandboxPrincipal { + /// Canonical sandbox UUID populated from a verified sandbox credential. + pub sandbox_id: String, + /// How this principal was verified — used for audit logs and method-specific + /// authorization checks. + pub source: SandboxIdentitySource, + /// SPIFFE trust domain. Populated when the credential is SPIFFE-shaped; + /// reserved for future per-sandbox cert / SPIRE authenticators. + pub trust_domain: Option, +} + +/// How a [`SandboxPrincipal`] was authenticated. +/// +/// Variant fields are populated by the producing authenticator and consumed +/// by audit logging and method-specific authorization checks. +#[derive(Debug, Clone)] +#[allow(dead_code)] +pub enum SandboxIdentitySource { + /// Gateway-minted JWT validated against the gateway's signing key. + /// Produced by [`super::sandbox_jwt::SandboxJwtAuthenticator`]. + BootstrapJwt { issuer: String }, + /// Per-sandbox client certificate. Reserved for channel-bound sandbox + /// identity. + BootstrapCert { fingerprint: String }, + /// SPIRE-issued SVID. Reserved for SPIFFE/SPIRE sandbox identity. + SpiffeSvid { spiffe_id: String }, + /// K8s `ServiceAccount` token used to bootstrap a gateway-minted JWT + /// via `IssueSandboxToken`. Populated only on that one RPC path. 
+ K8sServiceAccount { pod_name: String, pod_uid: String }, +} diff --git a/crates/openshell-server/src/auth/sandbox_jwt.rs b/crates/openshell-server/src/auth/sandbox_jwt.rs new file mode 100644 index 000000000..2ec890249 --- /dev/null +++ b/crates/openshell-server/src/auth/sandbox_jwt.rs @@ -0,0 +1,347 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Gateway-minted per-sandbox JWTs. +//! +//! The gateway signs an Ed25519 JWT for each sandbox at create time and +//! the sandbox supervisor presents it as `Authorization: Bearer ` on +//! supervisor-to-gateway gRPC calls. This module implements both sides of the +//! gateway-controlled token: +//! - [`SandboxJwtIssuer`] mints fresh tokens (called from +//! `handle_create_sandbox` and the `IssueSandboxToken` RPC). +//! - [`SandboxJwtAuthenticator`] validates tokens on inbound requests and +//! produces a [`Principal::Sandbox`] with [`SandboxIdentitySource::BootstrapJwt`]. +//! +//! Algorithm: `EdDSA` (Ed25519). Pinned via `Validation::algorithms` to +//! prevent algorithm-confusion attacks. + +use super::authenticator::Authenticator; +use super::principal::{Principal, SandboxIdentitySource, SandboxPrincipal}; +use async_trait::async_trait; +use jsonwebtoken::{ + Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, decode_header, encode, +}; +use serde::{Deserialize, Serialize}; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; +use tonic::Status; +use tracing::{debug, warn}; + +/// SPIFFE-shaped subject prefix. Embedded in the `sub` claim of every +/// minted token so a future migration to per-sandbox certs or SPIRE can +/// reuse the same subject namespace without breaking handler equality +/// checks. +const SPIFFE_SUBJECT_PREFIX: &str = "spiffe://openshell/sandbox/"; + +/// JWT claim set serialized in every gateway-minted sandbox token. 
+#[derive(Debug, Serialize, Deserialize)] +pub struct SandboxJwtClaims { + /// `spiffe://openshell/sandbox/`. SPIFFE-shaped for forward + /// compatibility with channel-bound identity (per-sandbox cert / SPIRE). + pub sub: String, + /// Gateway identity (`openshell-gateway:`). Both `iss` and + /// `aud` use the same value so any future replicas of the same + /// deployment validate each other's tokens without configuration. + pub iss: String, + pub aud: String, + pub iat: i64, + pub exp: i64, + /// Canonical sandbox UUID, denormalized from `sub` for cheap parsing + /// without a SPIFFE library. + pub sandbox_id: String, +} + +/// Mints fresh sandbox JWTs. +pub struct SandboxJwtIssuer { + encoding_key: EncodingKey, + kid: String, + issuer: String, + audience: String, + ttl: Duration, +} + +impl std::fmt::Debug for SandboxJwtIssuer { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SandboxJwtIssuer") + .field("kid", &self.kid) + .field("issuer", &self.issuer) + .field("audience", &self.audience) + .field("ttl", &self.ttl) + .finish_non_exhaustive() + } +} + +/// Outcome of a successful mint. +#[derive(Debug, Clone)] +pub struct MintedToken { + pub token: String, + pub expires_at_ms: i64, +} + +impl SandboxJwtIssuer { + pub fn from_pem( + signing_key_pem: &[u8], + kid: String, + gateway_id: &str, + ttl: Duration, + ) -> Result { + let encoding_key = EncodingKey::from_ed_pem(signing_key_pem) + .map_err(|e| format!("failed to parse Ed25519 signing key PEM: {e}"))?; + let identity = format!("openshell-gateway:{gateway_id}"); + Ok(Self { + encoding_key, + kid, + issuer: identity.clone(), + audience: identity, + ttl, + }) + } + + /// Mint a fresh token for `sandbox_id`. 
+ #[allow(clippy::result_large_err)] // `tonic::Status` is the natural error here + pub fn mint(&self, sandbox_id: &str) -> Result { + let now = now_secs(); + let exp = now + i64::try_from(self.ttl.as_secs()).unwrap_or(3_600); + let claims = SandboxJwtClaims { + sub: format!("{SPIFFE_SUBJECT_PREFIX}{sandbox_id}"), + iss: self.issuer.clone(), + aud: self.audience.clone(), + iat: now, + exp, + sandbox_id: sandbox_id.to_string(), + }; + let mut header = Header::new(Algorithm::EdDSA); + header.kid = Some(self.kid.clone()); + let token = encode(&header, &claims, &self.encoding_key).map_err(|e| { + warn!(error = %e, "failed to mint sandbox JWT"); + Status::internal("failed to mint sandbox token") + })?; + Ok(MintedToken { + token, + expires_at_ms: exp.saturating_mul(1000), + }) + } + + pub fn ttl(&self) -> Duration { + self.ttl + } +} + +/// Authenticator that validates gateway-minted sandbox JWTs. +pub struct SandboxJwtAuthenticator { + decoding_key: DecodingKey, + kid: String, + issuer: String, + audience: String, +} + +impl std::fmt::Debug for SandboxJwtAuthenticator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SandboxJwtAuthenticator") + .field("kid", &self.kid) + .field("issuer", &self.issuer) + .field("audience", &self.audience) + .finish_non_exhaustive() + } +} + +impl SandboxJwtAuthenticator { + pub fn from_pem(public_key_pem: &[u8], kid: String, gateway_id: &str) -> Result { + let decoding_key = DecodingKey::from_ed_pem(public_key_pem) + .map_err(|e| format!("failed to parse Ed25519 public key PEM: {e}"))?; + let identity = format!("openshell-gateway:{gateway_id}"); + Ok(Self { + decoding_key, + kid, + issuer: identity.clone(), + audience: identity, + }) + } + + #[allow(clippy::result_large_err)] + fn validate_bearer(&self, token: &str) -> Result, Status> { + let header = decode_header(token).map_err(|e| { + debug!(error = %e, "sandbox JWT header decode failed"); + Status::unauthenticated("invalid token") + })?; + + 
// Fall through to other authenticators when the kid does not match — + // OIDC issuers may share the Bearer slot. + if header.kid.as_deref() != Some(self.kid.as_str()) { + return Ok(None); + } + if !matches!(header.alg, Algorithm::EdDSA) { + return Ok(None); + } + + let mut validation = Validation::new(Algorithm::EdDSA); + validation.algorithms = vec![Algorithm::EdDSA]; + validation.set_issuer(&[&self.issuer]); + validation.set_audience(&[&self.audience]); + validation.set_required_spec_claims(&["iss", "aud", "exp", "sub"]); + + let data = + decode::(token, &self.decoding_key, &validation).map_err(|e| { + debug!(error = %e, "sandbox JWT validation failed"); + Status::unauthenticated(format!("invalid token: {e}")) + })?; + + let claims = data.claims; + Ok(Some(Principal::Sandbox(SandboxPrincipal { + sandbox_id: claims.sandbox_id, + source: SandboxIdentitySource::BootstrapJwt { issuer: claims.iss }, + trust_domain: Some("openshell".to_string()), + }))) + } +} + +#[async_trait] +impl Authenticator for SandboxJwtAuthenticator { + async fn authenticate( + &self, + headers: &http::HeaderMap, + _path: &str, + ) -> Result, Status> { + let Some(token) = headers + .get("authorization") + .and_then(|v| v.to_str().ok()) + .and_then(|v| v.strip_prefix("Bearer ")) + else { + return Ok(None); + }; + self.validate_bearer(token) + } +} + +fn now_secs() -> i64 { + i64::try_from( + SystemTime::now() + .duration_since(UNIX_EPOCH) + .map_or(0, |d| d.as_secs()), + ) + .unwrap_or(i64::MAX) +} + +#[cfg(test)] +mod tests { + use super::*; + use openshell_bootstrap::jwt::generate_jwt_key; + + fn header_map_with_bearer(token: &str) -> http::HeaderMap { + let mut h = http::HeaderMap::new(); + h.insert( + "authorization", + http::HeaderValue::from_str(&format!("Bearer {token}")).unwrap(), + ); + h + } + + fn pair() -> (SandboxJwtIssuer, SandboxJwtAuthenticator) { + let mat = generate_jwt_key().expect("jwt key"); + let issuer = SandboxJwtIssuer::from_pem( + mat.signing_key_pem.as_bytes(), + 
mat.kid.clone(), + "test-gateway", + Duration::from_secs(3600), + ) + .unwrap(); + let auth = SandboxJwtAuthenticator::from_pem( + mat.public_key_pem.as_bytes(), + mat.kid, + "test-gateway", + ) + .unwrap(); + (issuer, auth) + } + + #[tokio::test] + async fn mint_and_validate_round_trip() { + let (issuer, auth) = pair(); + let minted = issuer.mint("sandbox-a").unwrap(); + let principal = auth + .authenticate(&header_map_with_bearer(&minted.token), "/anything") + .await + .unwrap() + .expect("expected principal"); + match principal { + Principal::Sandbox(p) => { + assert_eq!(p.sandbox_id, "sandbox-a"); + match p.source { + SandboxIdentitySource::BootstrapJwt { issuer: iss } => { + assert_eq!(iss, "openshell-gateway:test-gateway"); + } + other => panic!("unexpected source: {other:?}"), + } + } + _ => panic!("expected Sandbox principal"), + } + } + + #[tokio::test] + async fn token_signed_by_other_key_is_rejected() { + let (_, auth_a) = pair(); + let (issuer_b, _) = pair(); // different keypair + let minted = issuer_b.mint("sandbox-b").unwrap(); + // The token has a different `kid` than auth_a expects, so the + // authenticator yields None (lets the chain fall through). That is + // the documented behavior for cross-issuer Bearer headers. 
+ let result = auth_a + .authenticate(&header_map_with_bearer(&minted.token), "/anything") + .await + .unwrap(); + assert!(result.is_none(), "different kid must fall through"); + } + + #[tokio::test] + async fn missing_bearer_yields_none() { + let (_, auth) = pair(); + let result = auth + .authenticate(&http::HeaderMap::new(), "/anything") + .await + .unwrap(); + assert!(result.is_none()); + } + + #[tokio::test] + async fn malformed_token_is_rejected() { + let (_, auth) = pair(); + let err = auth + .authenticate(&header_map_with_bearer("not.a.jwt"), "/anything") + .await + .expect_err("malformed must reject"); + assert_eq!(err.code(), tonic::Code::Unauthenticated); + } + + #[tokio::test] + async fn expired_token_is_rejected() { + // Mint a token whose iat is far in the past so its TTL window is + // already closed by `now`. We sign the JWT directly with the same + // signing key to bypass the issuer's TTL-vs-now coupling. + let mat = generate_jwt_key().unwrap(); + let issuer = SandboxJwtIssuer::from_pem( + mat.signing_key_pem.as_bytes(), + mat.kid.clone(), + "g", + Duration::from_secs(3600), + ) + .unwrap(); + let auth = + SandboxJwtAuthenticator::from_pem(mat.public_key_pem.as_bytes(), mat.kid.clone(), "g") + .unwrap(); + let claims = SandboxJwtClaims { + sub: format!("{SPIFFE_SUBJECT_PREFIX}sandbox-c"), + iss: "openshell-gateway:g".to_string(), + aud: "openshell-gateway:g".to_string(), + iat: now_secs() - 7200, + exp: now_secs() - 3600, + sandbox_id: "sandbox-c".to_string(), + }; + let mut header = Header::new(Algorithm::EdDSA); + header.kid = Some(mat.kid); + let token = encode(&header, &claims, &issuer.encoding_key).unwrap(); + let err = auth + .authenticate(&header_map_with_bearer(&token), "/anything") + .await + .expect_err("expired token must reject"); + assert_eq!(err.code(), tonic::Code::Unauthenticated); + } +} diff --git a/crates/openshell-server/src/auth/sandbox_methods.rs b/crates/openshell-server/src/auth/sandbox_methods.rs new file mode 100644 index 
000000000..e03b8eeb6 --- /dev/null +++ b/crates/openshell-server/src/auth/sandbox_methods.rs @@ -0,0 +1,70 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Method-level allowlist for sandbox principals. +//! +//! Gateway-minted sandbox JWTs identify a single sandbox supervisor. They +//! must not authorize user-facing or admin APIs. The router rejects sandbox +//! principals for every method outside this supervisor-to-gateway allowlist; +//! handlers still perform same-sandbox checks on request bodies. + +/// Methods a `Principal::Sandbox` may invoke. +const ALLOWED_SANDBOX_METHODS: &[&str] = &[ + "/openshell.v1.OpenShell/IssueSandboxToken", + "/openshell.v1.OpenShell/RefreshSandboxToken", + "/openshell.v1.OpenShell/ConnectSupervisor", + "/openshell.v1.OpenShell/RelayStream", + "/openshell.v1.OpenShell/GetSandboxConfig", + "/openshell.v1.OpenShell/GetSandboxProviderEnvironment", + "/openshell.v1.OpenShell/UpdateConfig", + "/openshell.v1.OpenShell/ReportPolicyStatus", + "/openshell.v1.OpenShell/PushSandboxLogs", + "/openshell.v1.OpenShell/SubmitPolicyAnalysis", + "/openshell.v1.OpenShell/GetDraftPolicy", + "/openshell.inference.v1.Inference/GetInferenceBundle", +]; + +pub fn is_sandbox_callable(path: &str) -> bool { + ALLOWED_SANDBOX_METHODS.contains(&path) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn supervisor_callbacks_are_allowed() { + assert!(is_sandbox_callable( + "/openshell.v1.OpenShell/ConnectSupervisor" + )); + assert!(is_sandbox_callable("/openshell.v1.OpenShell/RelayStream")); + assert!(is_sandbox_callable( + "/openshell.v1.OpenShell/GetSandboxConfig" + )); + assert!(is_sandbox_callable( + "/openshell.inference.v1.Inference/GetInferenceBundle" + )); + } + + #[test] + fn user_and_admin_methods_are_not_allowed() { + assert!(!is_sandbox_callable( + "/openshell.v1.OpenShell/ListSandboxes" + )); + assert!(!is_sandbox_callable( + 
"/openshell.v1.OpenShell/DeleteSandbox" + )); + assert!(!is_sandbox_callable( + "/openshell.v1.OpenShell/CreateProvider" + )); + assert!(!is_sandbox_callable( + "/openshell.v1.OpenShell/ApproveDraftChunk" + )); + assert!(!is_sandbox_callable( + "/openshell.inference.v1.Inference/GetClusterInference" + )); + assert!(!is_sandbox_callable( + "/openshell.inference.v1.Inference/SetClusterInference" + )); + } +} diff --git a/crates/openshell-server/src/certgen.rs b/crates/openshell-server/src/certgen.rs new file mode 100644 index 000000000..b7ce0421c --- /dev/null +++ b/crates/openshell-server/src/certgen.rs @@ -0,0 +1,988 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! `generate-certs` subcommand: bootstrap mTLS PKI for the gateway. +//! +//! Two output modes, dispatched by the presence of `--output-dir`: +//! +//! - **Kubernetes mode** (default): create two `kubernetes.io/tls` Secrets +//! in the supplied namespace. Used by the Helm pre-install hook. Requires +//! `--namespace`, `--server-secret-name`, `--client-secret-name`. +//! - **Local mode** (`--output-dir `): write PEMs to the local package +//! filesystem layout. Used by systemd units' `ExecStartPre`. Also copies +//! client materials to +//! `$XDG_CONFIG_HOME/openshell/gateways/openshell/mtls/` so the local CLI +//! picks them up automatically. +//! +//! Both modes share the same idempotency contract: all targets present → +//! skip; legacy TLS-only state → add the JWT signing material; partial +//! state → error with a recovery hint; nothing present → generate and write. 
+ +use clap::Args; +use k8s_openapi::ByteString; +use k8s_openapi::api::core::v1::Secret; +use kube::Client; +use kube::api::{Api, ObjectMeta, PostParams}; +use miette::{IntoDiagnostic, Result, WrapErr}; +use openshell_bootstrap::pki::{DEFAULT_SERVER_SANS, PkiBundle, generate_pki}; +use openshell_core::paths::{create_dir_restricted, set_file_owner_only}; +use std::collections::{BTreeMap, BTreeSet}; +use std::fmt; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; +use std::path::{Path, PathBuf}; +use tracing::{info, warn}; +use tracing_subscriber::EnvFilter; + +#[derive(Args, Debug)] +pub struct CertgenArgs { + /// Write PEMs to a filesystem directory instead of Kubernetes Secrets. + /// When set, the kube-related flags are not required. + #[arg(long, value_name = "DIR")] + output_dir: Option, + + /// Kubernetes namespace to create Secrets in. + /// Default comes from `POD_NAMESPACE`, which the Helm hook injects via + /// the downward API. + #[arg(long, env = "POD_NAMESPACE", required_unless_present = "output_dir")] + namespace: Option, + + /// Name of the server TLS Secret (`kubernetes.io/tls`) to create. + #[arg(long, required_unless_present = "output_dir")] + server_secret_name: Option, + + /// Name of the client TLS Secret (`kubernetes.io/tls`) to create. + #[arg(long, required_unless_present = "output_dir")] + client_secret_name: Option, + + /// Name of the sandbox-JWT signing-key Secret (`Opaque`) to create. + /// Holds `signing.pem`, `public.pem`, and `kid` keys. Mounted on the + /// gateway pod (only) so it can mint and validate per-sandbox JWTs. + #[arg(long, required_unless_present = "output_dir")] + jwt_secret_name: Option, + + /// Extra Subject Alternative Name for the server certificate. Repeatable. + /// Auto-detected as an IP address or DNS name. + #[arg(long = "server-san", value_name = "SAN")] + server_sans: Vec, + + /// Print the generated PEM materials to stdout instead of writing them. + /// For local debugging. 
+ #[arg(long)] + dry_run: bool, +} + +pub async fn run(args: CertgenArgs) -> Result<()> { + tracing_subscriber::fmt() + .with_env_filter( + EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")), + ) + .init(); + + if args.dry_run { + let bundle = generate_pki(&args.server_sans)?; + print_bundle(&bundle); + return Ok(()); + } + + if let Some(dir) = args.output_dir.as_deref() { + run_local(dir, &args.server_sans) + } else { + let bundle = generate_pki(&args.server_sans)?; + run_kubernetes(&args, &bundle).await + } +} + +// ─────────────────────────── Kubernetes mode ─────────────────────────── + +#[derive(Debug, PartialEq, Eq)] +enum K8sAction { + SkipExists, + CreateJwtOnly, + PartialState, + CreateAll, +} + +fn decide_k8s(server_exists: bool, client_exists: bool, jwt_exists: bool) -> K8sAction { + match (server_exists, client_exists, jwt_exists) { + (true, true, true) => K8sAction::SkipExists, + (true, true, false) => K8sAction::CreateJwtOnly, + (false, false, false) => K8sAction::CreateAll, + _ => K8sAction::PartialState, + } +} + +async fn run_kubernetes(args: &CertgenArgs, bundle: &PkiBundle) -> Result<()> { + let namespace = args + .namespace + .as_deref() + .ok_or_else(|| miette::miette!("--namespace is required (or set POD_NAMESPACE)"))?; + let server_name = args + .server_secret_name + .as_deref() + .ok_or_else(|| miette::miette!("--server-secret-name is required"))?; + let client_name = args + .client_secret_name + .as_deref() + .ok_or_else(|| miette::miette!("--client-secret-name is required"))?; + let jwt_name = args + .jwt_secret_name + .as_deref() + .ok_or_else(|| miette::miette!("--jwt-secret-name is required"))?; + + let client = Client::try_default() + .await + .into_diagnostic() + .wrap_err("failed to construct in-cluster Kubernetes client")?; + let api: Api = Api::namespaced(client, namespace); + + let server_exists = api + .get_opt(server_name) + .await + .into_diagnostic() + .wrap_err_with(|| format!("failed to read secret 
{server_name}"))? + .is_some(); + let client_exists = api + .get_opt(client_name) + .await + .into_diagnostic() + .wrap_err_with(|| format!("failed to read secret {client_name}"))? + .is_some(); + let jwt_exists = api + .get_opt(jwt_name) + .await + .into_diagnostic() + .wrap_err_with(|| format!("failed to read secret {jwt_name}"))? + .is_some(); + + match decide_k8s(server_exists, client_exists, jwt_exists) { + K8sAction::SkipExists => { + info!( + namespace = %namespace, + server = %server_name, + client = %client_name, + jwt = %jwt_name, + "PKI secrets already exist, skipping." + ); + return Ok(()); + } + K8sAction::PartialState => { + return Err(miette::miette!( + "partial PKI state in namespace {namespace}: only some of \ + {server_name} / {client_name} / {jwt_name} exist. Recover with: \ + kubectl delete secret -n {namespace} {server_name} {client_name} {jwt_name}", + )); + } + K8sAction::CreateJwtOnly => { + let jwt_secret = jwt_signing_secret( + jwt_name, + &bundle.jwt_signing_key_pem, + &bundle.jwt_public_key_pem, + &bundle.jwt_key_id, + ); + api.create(&PostParams::default(), &jwt_secret) + .await + .into_diagnostic() + .wrap_err_with(|| format!("failed to create secret {jwt_name}"))?; + info!( + namespace = %namespace, + jwt = %jwt_name, + "JWT signing secret created for existing TLS install." 
+ ); + return Ok(()); + } + K8sAction::CreateAll => {} + } + + let server_secret = tls_secret( + server_name, + &bundle.server_cert_pem, + &bundle.server_key_pem, + &bundle.ca_cert_pem, + ); + let client_secret = tls_secret( + client_name, + &bundle.client_cert_pem, + &bundle.client_key_pem, + &bundle.ca_cert_pem, + ); + let jwt_secret = jwt_signing_secret( + jwt_name, + &bundle.jwt_signing_key_pem, + &bundle.jwt_public_key_pem, + &bundle.jwt_key_id, + ); + + api.create(&PostParams::default(), &server_secret) + .await + .into_diagnostic() + .wrap_err_with(|| format!("failed to create secret {server_name}"))?; + api.create(&PostParams::default(), &client_secret) + .await + .into_diagnostic() + .wrap_err_with(|| format!("failed to create secret {client_name}"))?; + api.create(&PostParams::default(), &jwt_secret) + .await + .into_diagnostic() + .wrap_err_with(|| format!("failed to create secret {jwt_name}"))?; + + info!( + namespace = %namespace, + server = %server_name, + client = %client_name, + jwt = %jwt_name, + "PKI secrets created." + ); + Ok(()) +} + +fn tls_secret(name: &str, crt_pem: &str, key_pem: &str, ca_pem: &str) -> Secret { + let mut data = BTreeMap::new(); + data.insert( + "tls.crt".to_string(), + ByteString(crt_pem.as_bytes().to_vec()), + ); + data.insert( + "tls.key".to_string(), + ByteString(key_pem.as_bytes().to_vec()), + ); + data.insert("ca.crt".to_string(), ByteString(ca_pem.as_bytes().to_vec())); + Secret { + metadata: ObjectMeta { + name: Some(name.to_string()), + ..Default::default() + }, + type_: Some("kubernetes.io/tls".to_string()), + data: Some(data), + ..Default::default() + } +} + +/// Build an `Opaque` Secret carrying the gateway-minted sandbox JWT +/// signing material. Mounted only on the gateway pod — sandbox pods +/// receive a per-pod gateway-signed token, never the signing key itself. 
+fn jwt_signing_secret(name: &str, signing_pem: &str, public_pem: &str, kid: &str) -> Secret { + let mut data = BTreeMap::new(); + data.insert( + "signing.pem".to_string(), + ByteString(signing_pem.as_bytes().to_vec()), + ); + data.insert( + "public.pem".to_string(), + ByteString(public_pem.as_bytes().to_vec()), + ); + data.insert("kid".to_string(), ByteString(kid.as_bytes().to_vec())); + Secret { + metadata: ObjectMeta { + name: Some(name.to_string()), + ..Default::default() + }, + type_: Some("Opaque".to_string()), + data: Some(data), + ..Default::default() + } +} + +// ─────────────────────────────── Local mode ─────────────────────────────── + +#[derive(Debug, PartialEq, Eq)] +enum LocalAction { + Skip, + CreateJwtOnly, + PartialState, + CreateAll, +} + +#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] +enum CertSan { + Dns(String), + Ip(IpAddr), +} + +impl fmt::Display for CertSan { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Dns(name) => write!(f, "{name}"), + Self::Ip(addr) => write!(f, "{addr}"), + } + } +} + +/// Layout under ``: +/// +/// ```text +/// /ca.crt +/// /ca.key +/// /server/tls.crt +/// /server/tls.key +/// /client/tls.crt +/// /client/tls.key +/// ``` +struct LocalPaths { + ca_crt: PathBuf, + ca_key: PathBuf, + server_dir: PathBuf, + server_crt: PathBuf, + server_key: PathBuf, + client_dir: PathBuf, + client_crt: PathBuf, + client_key: PathBuf, + jwt_dir: PathBuf, + jwt_signing: PathBuf, + jwt_public: PathBuf, + jwt_kid: PathBuf, +} + +impl LocalPaths { + fn resolve(dir: &Path) -> Self { + let server_dir = dir.join("server"); + let client_dir = dir.join("client"); + let jwt_dir = dir.join("jwt"); + Self { + ca_crt: dir.join("ca.crt"), + ca_key: dir.join("ca.key"), + server_crt: server_dir.join("tls.crt"), + server_key: server_dir.join("tls.key"), + server_dir, + client_crt: client_dir.join("tls.crt"), + client_key: client_dir.join("tls.key"), + client_dir, + jwt_signing: 
jwt_dir.join("signing.pem"), + jwt_public: jwt_dir.join("public.pem"), + jwt_kid: jwt_dir.join("kid"), + jwt_dir, + } + } + + fn tls_files(&self) -> [&Path; 6] { + [ + &self.ca_crt, + &self.ca_key, + &self.server_crt, + &self.server_key, + &self.client_crt, + &self.client_key, + ] + } + + fn jwt_files(&self) -> [&Path; 3] { + [&self.jwt_signing, &self.jwt_public, &self.jwt_kid] + } + + #[cfg(test)] + fn all_files(&self) -> [&Path; 9] { + [ + &self.ca_crt, + &self.ca_key, + &self.server_crt, + &self.server_key, + &self.client_crt, + &self.client_key, + &self.jwt_signing, + &self.jwt_public, + &self.jwt_kid, + ] + } + + fn tls_existence_count(&self) -> usize { + self.tls_files().iter().filter(|p| p.exists()).count() + } + + fn jwt_existence_count(&self) -> usize { + self.jwt_files().iter().filter(|p| p.exists()).count() + } +} + +fn decide_local(tls_present: usize, jwt_present: usize) -> LocalAction { + match (tls_present, jwt_present) { + (6, 3) => LocalAction::Skip, + (6, 0) => LocalAction::CreateJwtOnly, + (0, 0) => LocalAction::CreateAll, + _ => LocalAction::PartialState, + } +} + +fn run_local(dir: &Path, server_sans: &[String]) -> Result<()> { + let paths = LocalPaths::resolve(dir); + + let bundle = match decide_local(paths.tls_existence_count(), paths.jwt_existence_count()) { + LocalAction::Skip => { + let missing_sans = missing_required_server_sans(&paths, server_sans)?; + if missing_sans.is_empty() { + info!(dir = %dir.display(), "PKI files already exist, skipping."); + } else { + let bundle = generate_pki(server_sans)?; + write_local_tls_bundle(&bundle, &paths)?; + info!( + dir = %dir.display(), + missing_sans = %format_cert_sans(&missing_sans), + "server TLS certificate refreshed for current SAN set.", + ); + } + read_local_bundle(&paths)? 
+ } + LocalAction::CreateJwtOnly => { + let bundle = generate_pki(server_sans)?; + let missing_sans = missing_required_server_sans(&paths, server_sans)?; + if missing_sans.is_empty() { + write_local_jwt_bundle(&bundle, &paths)?; + info!(dir = %dir.display(), "JWT signing files created for existing TLS install."); + } else { + write_local_bundle(dir, &bundle, &paths)?; + info!( + dir = %dir.display(), + missing_sans = %format_cert_sans(&missing_sans), + "PKI files refreshed for current SAN set and JWT signing material.", + ); + } + read_local_bundle(&paths)? + } + LocalAction::PartialState => { + return Err(miette::miette!( + "partial PKI state in {dir}: some files exist but not all. \ + Recover with: rm -rf {dir} (the gateway will regenerate on next start)", + dir = dir.display(), + )); + } + LocalAction::CreateAll => { + let bundle = generate_pki(server_sans)?; + write_local_bundle(dir, &bundle, &paths)?; + info!(dir = %dir.display(), "PKI files created."); + bundle + } + }; + + // Always make sure the CLI auto-discovery copy is in place. This + // self-heals the case where the operator wiped ~/.config/openshell but + // left the gateway state directory intact. 
+ if let Err(e) = openshell_bootstrap::mtls::store_pki_bundle("openshell", &bundle) { + warn!(error = %e, "failed to copy client mTLS materials for CLI auto-discovery"); + } + + Ok(()) +} + +fn required_server_sans(server_sans: &[String]) -> BTreeSet<CertSan> { + DEFAULT_SERVER_SANS + .iter() + .copied() + .chain(server_sans.iter().map(String::as_str)) + .filter_map(|san| { + san.parse::<IpAddr>() + .map(CertSan::Ip) + .ok() + .or_else(|| san.is_ascii().then(|| CertSan::Dns(san.to_string()))) + }) + .collect() +} + +fn missing_required_server_sans( + paths: &LocalPaths, + server_sans: &[String], +) -> Result<Vec<CertSan>> { + let required = required_server_sans(server_sans); + let actual = server_cert_sans(&paths.server_crt)?; + Ok(required.difference(&actual).cloned().collect()) +} + +fn server_cert_sans(path: &Path) -> Result<BTreeSet<CertSan>> { + use x509_parser::pem::parse_x509_pem; + use x509_parser::prelude::{FromDer, GeneralName, X509Certificate}; + + let pem = std::fs::read(path) + .into_diagnostic() + .wrap_err_with(|| format!("failed to read {}", path.display()))?; + let (_, pem) = parse_x509_pem(&pem).map_err(|e| { + miette::miette!( + "failed to parse server certificate PEM {}: {e:?}", + path.display() + ) + })?; + let (_, cert) = X509Certificate::from_der(&pem.contents).map_err(|e| { + miette::miette!( + "failed to parse server certificate {}: {e:?}", + path.display() + ) + })?; + + let Some(ext) = cert.subject_alternative_name().map_err(|e| { + miette::miette!( + "failed to read server certificate SANs {}: {e:?}", + path.display() + ) + })? 
+ else { + return Ok(BTreeSet::new()); + }; + + let mut sans = BTreeSet::new(); + for name in &ext.value.general_names { + match name { + GeneralName::DNSName(name) => { + sans.insert(CertSan::Dns((*name).to_string())); + } + GeneralName::IPAddress(raw) => match raw.len() { + 4 => { + sans.insert(CertSan::Ip(IpAddr::V4(Ipv4Addr::new( + raw[0], raw[1], raw[2], raw[3], + )))); + } + 16 => { + let octets: [u8; 16] = (*raw).try_into().expect("checked IPv6 SAN length"); + sans.insert(CertSan::Ip(IpAddr::V6(Ipv6Addr::from(octets)))); + } + _ => {} + }, + _ => {} + } + } + Ok(sans) +} + +fn format_cert_sans(sans: &[CertSan]) -> String { + sans.iter() + .map(ToString::to_string) + .collect::<Vec<_>>() + .join(", ") +} + +fn read_local_bundle(paths: &LocalPaths) -> Result<PkiBundle> { + Ok(PkiBundle { + ca_cert_pem: read_pem(&paths.ca_crt)?, + ca_key_pem: read_pem(&paths.ca_key)?, + server_cert_pem: read_pem(&paths.server_crt)?, + server_key_pem: read_pem(&paths.server_key)?, + client_cert_pem: read_pem(&paths.client_crt)?, + client_key_pem: read_pem(&paths.client_key)?, + jwt_signing_key_pem: read_pem(&paths.jwt_signing)?, + jwt_public_key_pem: read_pem(&paths.jwt_public)?, + jwt_key_id: read_pem(&paths.jwt_kid)?.trim().to_string(), + }) +} + +fn read_pem(path: &Path) -> Result<String> { + std::fs::read_to_string(path) + .into_diagnostic() + .wrap_err_with(|| format!("failed to read {}", path.display())) +} + +fn write_local_bundle(dir: &Path, bundle: &PkiBundle, paths: &LocalPaths) -> Result<()> { + // Stage to a sibling tmp dir so individual renames into the final layout + // are atomic on the same filesystem. 
+ let temp = sibling_temp_dir(dir); + if temp.exists() { + std::fs::remove_dir_all(&temp) + .into_diagnostic() + .wrap_err_with(|| format!("failed to remove stale {}", temp.display()))?; + } + + let temp_server = temp.join("server"); + let temp_client = temp.join("client"); + let temp_jwt = temp.join("jwt"); + create_dir_restricted(&temp)?; + create_dir_restricted(&temp_server)?; + create_dir_restricted(&temp_client)?; + create_dir_restricted(&temp_jwt)?; + + write_pem(&temp.join("ca.crt"), &bundle.ca_cert_pem, false)?; + write_pem(&temp.join("ca.key"), &bundle.ca_key_pem, true)?; + write_pem(&temp_server.join("tls.crt"), &bundle.server_cert_pem, false)?; + write_pem(&temp_server.join("tls.key"), &bundle.server_key_pem, true)?; + write_pem(&temp_client.join("tls.crt"), &bundle.client_cert_pem, false)?; + write_pem(&temp_client.join("tls.key"), &bundle.client_key_pem, true)?; + write_pem( + &temp_jwt.join("signing.pem"), + &bundle.jwt_signing_key_pem, + true, + )?; + write_pem( + &temp_jwt.join("public.pem"), + &bundle.jwt_public_key_pem, + false, + )?; + write_pem(&temp_jwt.join("kid"), &bundle.jwt_key_id, false)?; + + // Final destination (might not exist yet on first run). 
+ create_dir_restricted(dir)?; + create_dir_restricted(&paths.server_dir)?; + create_dir_restricted(&paths.client_dir)?; + create_dir_restricted(&paths.jwt_dir)?; + + let renames: [(PathBuf, &Path); 9] = [ + (temp.join("ca.crt"), paths.ca_crt.as_path()), + (temp.join("ca.key"), paths.ca_key.as_path()), + (temp_server.join("tls.crt"), paths.server_crt.as_path()), + (temp_server.join("tls.key"), paths.server_key.as_path()), + (temp_client.join("tls.crt"), paths.client_crt.as_path()), + (temp_client.join("tls.key"), paths.client_key.as_path()), + (temp_jwt.join("signing.pem"), paths.jwt_signing.as_path()), + (temp_jwt.join("public.pem"), paths.jwt_public.as_path()), + (temp_jwt.join("kid"), paths.jwt_kid.as_path()), + ]; + for (from, to) in &renames { + std::fs::rename(from, to) + .into_diagnostic() + .wrap_err_with(|| format!("failed to move {} -> {}", from.display(), to.display()))?; + } + + let _ = std::fs::remove_dir_all(&temp); + Ok(()) +} + +fn write_local_tls_bundle(bundle: &PkiBundle, paths: &LocalPaths) -> Result<()> { + let temp = sibling_temp_dir(&paths.server_dir); + if temp.exists() { + std::fs::remove_dir_all(&temp) + .into_diagnostic() + .wrap_err_with(|| format!("failed to remove stale {}", temp.display()))?; + } + + let temp_server = temp.join("server"); + let temp_client = temp.join("client"); + create_dir_restricted(&temp)?; + create_dir_restricted(&temp_server)?; + create_dir_restricted(&temp_client)?; + + write_pem(&temp.join("ca.crt"), &bundle.ca_cert_pem, false)?; + write_pem(&temp.join("ca.key"), &bundle.ca_key_pem, true)?; + write_pem(&temp_server.join("tls.crt"), &bundle.server_cert_pem, false)?; + write_pem(&temp_server.join("tls.key"), &bundle.server_key_pem, true)?; + write_pem(&temp_client.join("tls.crt"), &bundle.client_cert_pem, false)?; + write_pem(&temp_client.join("tls.key"), &bundle.client_key_pem, true)?; + + create_dir_restricted(&paths.server_dir)?; + create_dir_restricted(&paths.client_dir)?; + let renames: [(PathBuf, &Path); 6] 
= [ + (temp.join("ca.crt"), paths.ca_crt.as_path()), + (temp.join("ca.key"), paths.ca_key.as_path()), + (temp_server.join("tls.crt"), paths.server_crt.as_path()), + (temp_server.join("tls.key"), paths.server_key.as_path()), + (temp_client.join("tls.crt"), paths.client_crt.as_path()), + (temp_client.join("tls.key"), paths.client_key.as_path()), + ]; + for (from, to) in &renames { + std::fs::rename(from, to) + .into_diagnostic() + .wrap_err_with(|| format!("failed to move {} -> {}", from.display(), to.display()))?; + } + + let _ = std::fs::remove_dir_all(&temp); + Ok(()) +} + +fn write_local_jwt_bundle(bundle: &PkiBundle, paths: &LocalPaths) -> Result<()> { + let temp = sibling_temp_dir(&paths.jwt_dir); + if temp.exists() { + std::fs::remove_dir_all(&temp) + .into_diagnostic() + .wrap_err_with(|| format!("failed to remove stale {}", temp.display()))?; + } + + create_dir_restricted(&temp)?; + write_pem(&temp.join("signing.pem"), &bundle.jwt_signing_key_pem, true)?; + write_pem(&temp.join("public.pem"), &bundle.jwt_public_key_pem, false)?; + write_pem(&temp.join("kid"), &bundle.jwt_key_id, false)?; + + create_dir_restricted(&paths.jwt_dir)?; + let renames: [(PathBuf, &Path); 3] = [ + (temp.join("signing.pem"), paths.jwt_signing.as_path()), + (temp.join("public.pem"), paths.jwt_public.as_path()), + (temp.join("kid"), paths.jwt_kid.as_path()), + ]; + for (from, to) in &renames { + std::fs::rename(from, to) + .into_diagnostic() + .wrap_err_with(|| format!("failed to move {} -> {}", from.display(), to.display()))?; + } + + let _ = std::fs::remove_dir_all(&temp); + Ok(()) +} + +fn write_pem(path: &Path, contents: &str, owner_only: bool) -> Result<()> { + std::fs::write(path, contents) + .into_diagnostic() + .wrap_err_with(|| format!("failed to write {}", path.display()))?; + if owner_only { + set_file_owner_only(path)?; + } + Ok(()) +} + +fn sibling_temp_dir(dir: &Path) -> PathBuf { + // Use a sibling so std::fs::rename succeeds (same filesystem). 
+ let mut name = dir + .file_name() + .map(std::ffi::OsStr::to_os_string) + .unwrap_or_default(); + name.push(".certgen.tmp"); + dir.with_file_name(name) +} + +// ────────────────────────────── Shared utility ───────────────────────────── + +fn print_bundle(bundle: &PkiBundle) { + println!("# CA certificate\n{}", bundle.ca_cert_pem); + println!("# Server certificate\n{}", bundle.server_cert_pem); + println!("# Server key\n{}", bundle.server_key_pem); + println!("# Client certificate\n{}", bundle.client_cert_pem); + println!("# Client key\n{}", bundle.client_key_pem); +} + +#[cfg(test)] +mod tests { + use super::{ + CertSan, K8sAction, LocalAction, LocalPaths, decide_k8s, decide_local, jwt_signing_secret, + missing_required_server_sans, read_local_bundle, sibling_temp_dir, tls_secret, + write_local_bundle, write_local_jwt_bundle, write_local_tls_bundle, + }; + use openshell_bootstrap::pki::generate_pki; + use std::path::Path; + + // ── Kubernetes-mode decision ── + + #[test] + fn decide_k8s_skip_when_all_three_exist() { + assert_eq!(decide_k8s(true, true, true), K8sAction::SkipExists); + } + + #[test] + fn decide_k8s_create_when_none_exist() { + assert_eq!(decide_k8s(false, false, false), K8sAction::CreateAll); + } + + #[test] + fn decide_k8s_creates_jwt_only_for_existing_tls() { + assert_eq!(decide_k8s(true, true, false), K8sAction::CreateJwtOnly); + } + + #[test] + fn decide_k8s_partial_for_any_mixed_state() { + let mixes = [ + (true, false, false), + (false, true, false), + (false, false, true), + (true, false, true), + (false, true, true), + ]; + for (s, c, j) in mixes { + assert_eq!( + decide_k8s(s, c, j), + K8sAction::PartialState, + "({s},{c},{j})" + ); + } + } + + #[test] + fn tls_secret_has_kubernetes_io_tls_type_and_three_keys() { + let s = tls_secret("foo", "CRT-PEM", "KEY-PEM", "CA-PEM"); + assert_eq!(s.metadata.name.as_deref(), Some("foo")); + assert_eq!(s.type_.as_deref(), Some("kubernetes.io/tls")); + let data = s.data.expect("data set"); + 
assert_eq!(data.len(), 3); + assert_eq!(data["tls.crt"].0, b"CRT-PEM"); + assert_eq!(data["tls.key"].0, b"KEY-PEM"); + assert_eq!(data["ca.crt"].0, b"CA-PEM"); + } + + #[test] + fn jwt_signing_secret_has_opaque_type_and_three_keys() { + let s = jwt_signing_secret("jwt", "SIGN", "PUB", "kid-1"); + assert_eq!(s.metadata.name.as_deref(), Some("jwt")); + assert_eq!(s.type_.as_deref(), Some("Opaque")); + let data = s.data.expect("data set"); + assert_eq!(data.len(), 3); + assert_eq!(data["signing.pem"].0, b"SIGN"); + assert_eq!(data["public.pem"].0, b"PUB"); + assert_eq!(data["kid"].0, b"kid-1"); + } + + // ── Local-mode decision ── + + #[test] + fn decide_local_skip_when_all_nine_present() { + assert_eq!(decide_local(6, 3), LocalAction::Skip); + } + + #[test] + fn decide_local_create_when_none_present() { + assert_eq!(decide_local(0, 0), LocalAction::CreateAll); + } + + #[test] + fn decide_local_creates_jwt_only_for_existing_tls() { + assert_eq!(decide_local(6, 0), LocalAction::CreateJwtOnly); + } + + #[test] + fn decide_local_partial_for_incomplete_tls_or_jwt_sets() { + for tls in 0..=6 { + for jwt in 0..=3 { + if matches!((tls, jwt), (6, 3 | 0) | (0, 0)) { + continue; + } + assert_eq!( + decide_local(tls, jwt), + LocalAction::PartialState, + "tls={tls} jwt={jwt}" + ); + } + } + } + + // ── Local-mode layout & writes ── + + #[test] + fn local_paths_resolve_matches_init_pki_layout() { + let p = LocalPaths::resolve(Path::new("/tmp/openshell/tls")); + assert_eq!(p.ca_crt, Path::new("/tmp/openshell/tls/ca.crt")); + assert_eq!(p.ca_key, Path::new("/tmp/openshell/tls/ca.key")); + assert_eq!(p.server_crt, Path::new("/tmp/openshell/tls/server/tls.crt")); + assert_eq!(p.server_key, Path::new("/tmp/openshell/tls/server/tls.key")); + assert_eq!(p.client_crt, Path::new("/tmp/openshell/tls/client/tls.crt")); + assert_eq!(p.client_key, Path::new("/tmp/openshell/tls/client/tls.key")); + } + + #[test] + fn sibling_temp_dir_is_adjacent_to_target() { + assert_eq!( + 
sibling_temp_dir(Path::new("/var/lib/openshell/tls")), + Path::new("/var/lib/openshell/tls.certgen.tmp") + ); + } + + #[test] + fn write_local_bundle_writes_six_files_and_removes_temp() { + let parent = tempfile::tempdir().expect("tempdir"); + let dir = parent.path().join("tls"); + let bundle = generate_pki(&[]).expect("generate_pki"); + let paths = LocalPaths::resolve(&dir); + + write_local_bundle(&dir, &bundle, &paths).expect("write_local_bundle"); + + for f in paths.all_files() { + assert!(f.is_file(), "missing {}", f.display()); + } + assert!( + !sibling_temp_dir(&dir).exists(), + "temp dir should be cleaned up" + ); + + // Spot-check contents. + let ca = std::fs::read_to_string(&paths.ca_crt).unwrap(); + assert!(ca.contains("BEGIN CERTIFICATE")); + let server_key = std::fs::read_to_string(&paths.server_key).unwrap(); + assert!(server_key.contains("BEGIN PRIVATE KEY")); + } + + #[test] + fn write_local_jwt_bundle_preserves_existing_tls_files() { + let parent = tempfile::tempdir().expect("tempdir"); + let dir = parent.path().join("tls"); + let old_bundle = generate_pki(&[]).expect("generate_pki"); + let new_bundle = generate_pki(&[]).expect("generate_pki"); + let paths = LocalPaths::resolve(&dir); + + write_local_bundle(&dir, &old_bundle, &paths).expect("write_local_bundle"); + std::fs::remove_dir_all(&paths.jwt_dir).expect("remove jwt dir"); + + write_local_jwt_bundle(&new_bundle, &paths).expect("write_local_jwt_bundle"); + + let read = read_local_bundle(&paths).expect("read_local_bundle"); + assert_eq!(read.ca_cert_pem, old_bundle.ca_cert_pem); + assert_eq!(read.server_cert_pem, old_bundle.server_cert_pem); + assert_eq!(read.client_cert_pem, old_bundle.client_cert_pem); + assert_eq!(read.jwt_key_id, new_bundle.jwt_key_id); + assert_eq!(read.jwt_public_key_pem, new_bundle.jwt_public_key_pem); + } + + #[test] + fn write_local_tls_bundle_preserves_existing_jwt_files() { + let parent = tempfile::tempdir().expect("tempdir"); + let dir = parent.path().join("tls"); + 
let old_bundle = generate_pki(&[]).expect("generate_pki"); + let new_bundle = generate_pki(&["extra.example.test".to_string()]).expect("generate_pki"); + let paths = LocalPaths::resolve(&dir); + + write_local_bundle(&dir, &old_bundle, &paths).expect("write_local_bundle"); + write_local_tls_bundle(&new_bundle, &paths).expect("write_local_tls_bundle"); + + let read = read_local_bundle(&paths).expect("read_local_bundle"); + assert_eq!(read.ca_cert_pem, new_bundle.ca_cert_pem); + assert_eq!(read.server_cert_pem, new_bundle.server_cert_pem); + assert_eq!(read.client_cert_pem, new_bundle.client_cert_pem); + assert_eq!(read.jwt_key_id, old_bundle.jwt_key_id); + assert_eq!(read.jwt_public_key_pem, old_bundle.jwt_public_key_pem); + } + + #[test] + fn missing_required_server_sans_detects_new_required_name() { + let parent = tempfile::tempdir().expect("tempdir"); + let dir = parent.path().join("tls"); + let bundle = generate_pki(&[]).expect("generate_pki"); + let paths = LocalPaths::resolve(&dir); + + write_local_bundle(&dir, &bundle, &paths).expect("write_local_bundle"); + + assert!( + missing_required_server_sans(&paths, &[]) + .unwrap() + .is_empty() + ); + + let missing = + missing_required_server_sans(&paths, &["future.example.test".to_string()]).unwrap(); + assert_eq!( + missing, + vec![CertSan::Dns("future.example.test".to_string())] + ); + } + + #[test] + fn read_local_bundle_uses_existing_files() { + let parent = tempfile::tempdir().expect("tempdir"); + let dir = parent.path().join("tls"); + let bundle = generate_pki(&[]).expect("generate_pki"); + let paths = LocalPaths::resolve(&dir); + + write_local_bundle(&dir, &bundle, &paths).expect("write_local_bundle"); + + let read = read_local_bundle(&paths).expect("read_local_bundle"); + assert_eq!(read.ca_cert_pem, bundle.ca_cert_pem); + assert_eq!(read.client_cert_pem, bundle.client_cert_pem); + assert_eq!(read.client_key_pem, bundle.client_key_pem); + } + + #[cfg(unix)] + #[test] + fn 
write_local_bundle_sets_owner_only_on_keys() { + use std::os::unix::fs::PermissionsExt; + let parent = tempfile::tempdir().expect("tempdir"); + let dir = parent.path().join("tls"); + let bundle = generate_pki(&[]).expect("generate_pki"); + let paths = LocalPaths::resolve(&dir); + + write_local_bundle(&dir, &bundle, &paths).expect("write_local_bundle"); + + for key in [&paths.ca_key, &paths.server_key, &paths.client_key] { + let mode = std::fs::metadata(key).unwrap().permissions().mode() & 0o777; + assert_eq!(mode, 0o600, "key {} has mode {:o}", key.display(), mode); + } + } + + #[test] + fn write_local_bundle_recovers_from_stale_temp_dir() { + let parent = tempfile::tempdir().expect("tempdir"); + let dir = parent.path().join("tls"); + let stale = sibling_temp_dir(&dir); + std::fs::create_dir_all(&stale).unwrap(); + std::fs::write(stale.join("garbage"), "stale").unwrap(); + + let bundle = generate_pki(&[]).expect("generate_pki"); + let paths = LocalPaths::resolve(&dir); + write_local_bundle(&dir, &bundle, &paths).expect("write_local_bundle"); + + assert!(paths.ca_crt.is_file()); + assert!(!stale.exists(), "stale temp dir should be removed"); + } +} diff --git a/crates/openshell-server/src/cli.rs b/crates/openshell-server/src/cli.rs index 6eb1ab2db..b8d345f9e 100644 --- a/crates/openshell-server/src/cli.rs +++ b/crates/openshell-server/src/cli.rs @@ -3,26 +3,55 @@ //! Shared CLI entrypoint for the gateway binaries. 
-use clap::{Command, CommandFactory, FromArgMatches, Parser}; +use clap::parser::ValueSource; +use clap::{ArgAction, ArgMatches, Command, CommandFactory, FromArgMatches, Parser}; use miette::{IntoDiagnostic, Result}; use openshell_core::ComputeDriverKind; -use openshell_core::config::{ - DEFAULT_DOCKER_NETWORK_NAME, DEFAULT_SERVER_PORT, DEFAULT_SSH_HANDSHAKE_SKEW_SECS, - DEFAULT_SSH_PORT, -}; +use openshell_core::config::DEFAULT_SERVER_PORT; use std::net::{IpAddr, SocketAddr}; use std::path::PathBuf; -use tracing::info; +use tracing::{info, warn}; use tracing_subscriber::EnvFilter; +use crate::certgen; use crate::compute::{DockerComputeConfig, VmComputeConfig}; +use crate::config_file::{self, ConfigFile, GatewayFileSection}; +use crate::defaults::{self, LocalTlsPaths}; use crate::{run_server, tracing_bus::TracingLogBus}; /// `OpenShell` gateway process - gRPC and HTTP server with protocol multiplexing. +/// +/// Top-level CLI. When invoked without a subcommand the binary runs the +/// gateway server using `RunArgs`. The `generate-certs` subcommand is used by +/// the Helm pre-install hook to bootstrap mTLS Secrets. #[derive(Parser, Debug)] #[command(version = openshell_core::VERSION)] #[command(about = "OpenShell gRPC/HTTP server", long_about = None)] -struct Args { +struct Cli { + #[command(subcommand)] + command: Option<Commands>, + + #[command(flatten)] + run: RunArgs, +} + +#[derive(clap::Subcommand, Debug)] +enum Commands { + /// Generate mTLS PKI and write Kubernetes Secrets (Helm pre-install hook). + GenerateCerts(certgen::CertgenArgs), +} + +#[derive(clap::Args, Debug)] +#[allow(clippy::struct_excessive_bools)] +struct RunArgs { + /// Path to a TOML configuration file (see RFC 0003). + /// + /// When set, gateway-wide settings and per-driver tables are read from + /// the file. Gateway command-line flags and `OPENSHELL_*` environment + /// variables continue to take precedence over gateway file values. 
+ #[arg(long, env = "OPENSHELL_GATEWAY_CONFIG")] + config: Option, + /// IP address to bind the server, health, and metrics listeners to. #[arg(long, default_value = "127.0.0.1", env = "OPENSHELL_BIND_ADDRESS")] bind_address: IpAddr, @@ -58,8 +87,12 @@ struct Args { tls_client_ca: Option, /// Database URL for persistence. - #[arg(long, env = "OPENSHELL_DB_URL", required = true)] - db_url: String, + /// + /// When unset, the gateway stores state under the `XDG` state + /// directory. Kept as an Option at the clap layer so the `generate-certs` + /// subcommand can run without gateway runtime defaults. + #[arg(long, env = "OPENSHELL_DB_URL")] + db_url: Option, /// Compute drivers configured for this gateway. /// @@ -67,8 +100,8 @@ struct Args { /// `kubernetes,podman`. The configuration format is future-proofed for /// multiple drivers, but the gateway currently requires exactly one. /// When unset, the gateway auto-detects the driver based on the runtime - /// environment (Kubernetes → Podman → Docker). VM is never auto-detected - /// and requires explicit configuration. + /// environment (Kubernetes → Podman → Docker CLI or socket). VM is never + /// auto-detected and requires explicit configuration. #[arg( long, alias = "driver", @@ -78,166 +111,31 @@ struct Args { )] drivers: Vec, - /// Kubernetes namespace for sandboxes. - #[arg(long, env = "OPENSHELL_SANDBOX_NAMESPACE", default_value = "default")] - sandbox_namespace: String, - - /// Default container image for sandboxes. - #[arg(long, env = "OPENSHELL_SANDBOX_IMAGE")] - sandbox_image: Option, - - /// Kubernetes `imagePullPolicy` for sandbox pods (Always, `IfNotPresent`, Never). - #[arg(long, env = "OPENSHELL_SANDBOX_IMAGE_PULL_POLICY")] - sandbox_image_pull_policy: Option, - - /// gRPC endpoint for sandboxes to callback to `OpenShell`. - /// This should be reachable from within the Kubernetes cluster. 
- #[arg(long, env = "OPENSHELL_GRPC_ENDPOINT")] - grpc_endpoint: Option, - - /// Public host for the SSH gateway. - #[arg(long, env = "OPENSHELL_SSH_GATEWAY_HOST", default_value = "127.0.0.1")] - ssh_gateway_host: String, - - /// Public port for the SSH gateway. - #[arg(long, env = "OPENSHELL_SSH_GATEWAY_PORT", default_value_t = DEFAULT_SERVER_PORT)] - ssh_gateway_port: u16, - - /// HTTP path for SSH CONNECT/upgrade. - #[arg( - long, - env = "OPENSHELL_SSH_CONNECT_PATH", - default_value = "/connect/ssh" - )] - ssh_connect_path: String, - - /// SSH port inside sandbox pods. - #[arg(long, env = "OPENSHELL_SANDBOX_SSH_PORT", default_value_t = DEFAULT_SSH_PORT)] - sandbox_ssh_port: u16, - /// Shared secret for gateway-to-sandbox SSH handshake. - #[arg(long, env = "OPENSHELL_SSH_HANDSHAKE_SECRET")] - ssh_handshake_secret: Option, - - /// Allowed clock skew in seconds for SSH handshake. - #[arg(long, env = "OPENSHELL_SSH_HANDSHAKE_SKEW_SECS", default_value_t = DEFAULT_SSH_HANDSHAKE_SKEW_SECS)] - ssh_handshake_skew_secs: u64, - - /// Kubernetes secret name containing client TLS materials for sandbox pods. - #[arg(long, env = "OPENSHELL_CLIENT_TLS_SECRET_NAME")] - client_tls_secret_name: Option, - - /// Host gateway IP for sandbox pod hostAliases. - /// When set, sandbox pods get hostAliases entries mapping - /// host.docker.internal and host.openshell.internal to this IP. - #[arg(long, env = "OPENSHELL_HOST_GATEWAY_IP")] - host_gateway_ip: Option, - - /// Working directory for VM driver sandbox state. - #[arg( - long, - env = "OPENSHELL_VM_DRIVER_STATE_DIR", - default_value_os_t = VmComputeConfig::default_state_dir() - )] - vm_driver_state_dir: PathBuf, - - /// Directory searched for compute-driver binaries (e.g. - /// `openshell-driver-vm`) when an explicit binary override isn't - /// configured. 
When unset, the gateway searches - /// `$HOME/.local/libexec/openshell`, `/usr/libexec/openshell`, - /// `/usr/local/libexec/openshell`, `/usr/local/libexec`, then a sibling - /// of the gateway binary. - #[arg(long, env = "OPENSHELL_DRIVER_DIR")] - driver_dir: Option, - - /// libkrun log level used by the VM helper. - #[arg( - long, - env = "OPENSHELL_VM_KRUN_LOG_LEVEL", - default_value_t = VmComputeConfig::default_krun_log_level() - )] - vm_krun_log_level: u32, - - /// Default vCPU count for VM sandboxes. - #[arg( - long, - env = "OPENSHELL_VM_DRIVER_VCPUS", - default_value_t = VmComputeConfig::default_vcpus() - )] - vm_vcpus: u8, - - /// Default memory allocation for VM sandboxes, in MiB. - #[arg( - long, - env = "OPENSHELL_VM_DRIVER_MEM_MIB", - default_value_t = VmComputeConfig::default_mem_mib() - )] - vm_mem_mib: u32, - - /// CA certificate installed into VM sandboxes for gateway mTLS. - #[arg(long, env = "OPENSHELL_VM_TLS_CA")] - vm_tls_ca: Option, - - /// Client certificate installed into VM sandboxes for gateway mTLS. - #[arg(long, env = "OPENSHELL_VM_TLS_CERT")] - vm_tls_cert: Option, - - /// Client private key installed into VM sandboxes for gateway mTLS. - #[arg(long, env = "OPENSHELL_VM_TLS_KEY")] - vm_tls_key: Option, - - /// Linux `openshell-sandbox` binary bind-mounted into Docker sandboxes. - /// - /// When unset the gateway falls back to (in order) a sibling - /// `openshell-sandbox` next to the gateway binary, a local cargo build, - /// or extracting the binary from `--docker-supervisor-image`. - #[arg(long, env = "OPENSHELL_DOCKER_SUPERVISOR_BIN")] - docker_supervisor_bin: Option, - - /// Image the Docker driver pulls to extract the Linux - /// `openshell-sandbox` binary when no explicit `--docker-supervisor-bin` - /// override or local build is available. Defaults to - /// `ghcr.io/nvidia/openshell/supervisor:`. 
- #[arg(long, env = "OPENSHELL_DOCKER_SUPERVISOR_IMAGE")] - docker_supervisor_image: Option, - - /// CA certificate bind-mounted into Docker sandboxes for gateway mTLS. - #[arg(long, env = "OPENSHELL_DOCKER_TLS_CA")] - docker_tls_ca: Option, - - /// Client certificate bind-mounted into Docker sandboxes for gateway mTLS. - #[arg(long, env = "OPENSHELL_DOCKER_TLS_CERT")] - docker_tls_cert: Option, - - /// Client private key bind-mounted into Docker sandboxes for gateway mTLS. - #[arg(long, env = "OPENSHELL_DOCKER_TLS_KEY")] - docker_tls_key: Option, - - /// Docker bridge network used for sandbox containers. - #[arg( - long, - env = "OPENSHELL_DOCKER_NETWORK_NAME", - default_value = DEFAULT_DOCKER_NETWORK_NAME - )] - docker_network_name: String, - /// Disable TLS entirely — listen on plaintext HTTP. /// Use this when the gateway sits behind a reverse proxy or tunnel /// (e.g. Cloudflare Tunnel) that terminates TLS at the edge. #[arg(long, env = "OPENSHELL_DISABLE_TLS")] disable_tls: bool, - /// Disable gateway authentication (mTLS client certificate requirement). - /// When set, the TLS handshake accepts connections without a client - /// certificate. Ignored when --disable-tls is set. - #[arg(long, env = "OPENSHELL_DISABLE_GATEWAY_AUTH")] - disable_gateway_auth: bool, - /// OIDC issuer URL for JWT-based authentication. /// When set, the server validates `authorization: Bearer` tokens on gRPC /// requests against the issuer's JWKS endpoint. #[arg(long, env = "OPENSHELL_OIDC_ISSUER")] oidc_issuer: Option, + /// Enable mTLS client certificate authentication for local single-user gateways. + /// + /// When unset, this defaults on for Docker, Podman, and VM gateways that + /// have client certificate verification configured and no OIDC issuer. + /// Kubernetes deployments must use OIDC or fronting-proxy auth instead. 
+ #[arg( + long = "enable-mtls-auth", + env = "OPENSHELL_ENABLE_MTLS_AUTH", + default_value_t = false, + action = ArgAction::Set + )] + enable_mtls_auth: bool, + /// Expected OIDC audience claim (typically the client ID). #[arg(long, env = "OPENSHELL_OIDC_AUDIENCE", default_value = "openshell-cli")] oidc_audience: String, @@ -276,10 +174,28 @@ struct Args { /// Keycloak: "scope". Okta: "scp". Leave empty to disable scope enforcement. #[arg(long, env = "OPENSHELL_OIDC_SCOPES_CLAIM", default_value = "")] oidc_scopes_claim: String, + + /// Subject Alternative Names configured on the gateway server certificate. + /// Wildcard DNS SANs also enable sandbox service URLs under that domain. + #[arg( + long = "server-san", + env = "OPENSHELL_SERVER_SAN", + value_delimiter = ',' + )] + server_sans: Vec, + + /// Enable plaintext HTTP routing for loopback sandbox service URLs. + #[arg( + long, + env = "OPENSHELL_ENABLE_LOOPBACK_SERVICE_HTTP", + default_value_t = true, + action = ArgAction::Set + )] + enable_loopback_service_http: bool, } pub fn command() -> Command { - Args::command() + Cli::command() .name("openshell-gateway") .bin_name("openshell-gateway") } @@ -289,12 +205,32 @@ pub async fn run_cli() -> Result<()> { .install_default() .map_err(|e| miette::miette!("failed to install rustls crypto provider: {e:?}"))?; - let args = Args::from_arg_matches(&command().get_matches()).expect("clap validated args"); + let matches = command().get_matches(); + let cli = Cli::from_arg_matches(&matches).expect("clap validated args"); - Box::pin(run_from_args(args)).await + match cli.command { + Some(Commands::GenerateCerts(args)) => certgen::run(args).await, + None => Box::pin(run_from_args(cli.run, matches)).await, + } } -async fn run_from_args(args: Args) -> Result<()> { +async fn run_from_args(mut args: RunArgs, matches: ArgMatches) -> Result<()> { + // Load TOML when explicitly requested, or from the default XDG location + // when that file exists. 
Missing default config is not an error: runtime + // defaults and OPENSHELL_* env vars are enough for package-managed starts. + let config_path = resolve_config_path(&args)?; + let file: Option = if let Some(path) = config_path { + Some(config_file::load(&path).map_err(|e| miette::miette!("{e}"))?) + } else { + None + }; + if let Some(file) = file.as_ref() { + merge_file_into_args(&mut args, &file.openshell.gateway, &matches); + } + + let local_tls = apply_runtime_defaults(&mut args)?; + let local_jwt = defaults::complete_local_jwt_config()?; + let tracing_log_bus = TracingLogBus::new(); tracing_log_bus.install_subscriber( EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(&args.log_level)), @@ -302,154 +238,495 @@ async fn run_from_args(args: Args) -> Result<()> { let bind = SocketAddr::new(args.bind_address, args.port); + let has_client_ca = args.tls_client_ca.is_some(); + let has_oidc = args.oidc_issuer.is_some(); + let mtls_auth_enabled = resolve_mtls_auth_enabled(&args, &matches, file.as_ref()); + + if args.disable_tls && has_client_ca { + return Err(miette::miette!( + "--disable-tls and --tls-client-ca are mutually exclusive. Client certificate verification requires that TLS be enabled." + )); + } + if mtls_auth_enabled && args.disable_tls { + return Err(miette::miette!( + "mTLS user authentication requires TLS. Remove --disable-tls or disable --enable-mtls-auth." + )); + } + if mtls_auth_enabled && !has_client_ca { + return Err(miette::miette!( + "mTLS user authentication requires --tls-client-ca so client certificates can be verified." + )); + } + if mtls_auth_enabled + && matches!( + effective_single_driver(&args), + Some(ComputeDriverKind::Kubernetes) + ) + { + return Err(miette::miette!( + "mTLS user authentication is not supported with the Kubernetes compute driver. Configure OIDC or a trusted fronting proxy for user authentication." 
+ )); + } + let tls = if args.disable_tls { None } else { - let cert_path = args.tls_cert.ok_or_else(|| { + let cert_path = args.tls_cert.clone().ok_or_else(|| { miette::miette!( "--tls-cert is required when TLS is enabled (use --disable-tls to skip)" ) })?; - let key_path = args.tls_key.ok_or_else(|| { + let key_path = args.tls_key.clone().ok_or_else(|| { miette::miette!("--tls-key is required when TLS is enabled (use --disable-tls to skip)") })?; - let client_ca_path = args.tls_client_ca.ok_or_else(|| { - miette::miette!( - "--tls-client-ca is required when TLS is enabled (use --disable-tls to skip)" - ) - })?; Some(openshell_core::TlsConfig { cert_path, key_path, - client_ca_path, - allow_unauthenticated: args.disable_gateway_auth, + require_client_auth: has_client_ca && !has_oidc, + client_ca_path: args.tls_client_ca.clone(), }) }; + let db_url = args + .db_url + .clone() + .expect("runtime defaults populate db_url"); + let mut config = openshell_core::Config::new(tls) .with_bind_address(bind) .with_log_level(&args.log_level); + if let Some(auth) = file.as_ref().and_then(|f| f.openshell.gateway.auth.clone()) { + config.auth = auth; + } + config.mtls_auth.enabled = mtls_auth_enabled; + + // Listener addresses for the health and metrics endpoints. The file may + // pin a different interface than the main listener (e.g. health on + // 127.0.0.1 while gRPC binds 0.0.0.0); the full `SocketAddr` from the + // file is preserved unless CLI/env supplied an explicit `--health-port` / + // `--metrics-port`, in which case the port overrides the file value + // while the IP defaults to `args.bind_address`. 
+ let file_gateway = file.as_ref().map(|f| &f.openshell.gateway); + let health_bind = resolve_aux_listener( + args.bind_address, + args.health_port, + &matches, + "health_port", + || file_gateway.and_then(|g| g.health_bind_address), + ); + let metrics_bind = resolve_aux_listener( + args.bind_address, + args.metrics_port, + &matches, + "metrics_port", + || file_gateway.and_then(|g| g.metrics_bind_address), + ); - if args.health_port != 0 { - if args.port == args.health_port { + if let Some(addr) = health_bind { + if args.port == addr.port() { return Err(miette::miette!( "--port and --health-port must be different (both set to {})", args.port )); } - let health_bind = SocketAddr::new(args.bind_address, args.health_port); - config = config.with_health_bind_address(health_bind); + config = config.with_health_bind_address(addr); } - if args.metrics_port != 0 { - if args.port == args.metrics_port { + if let Some(addr) = metrics_bind { + if args.port == addr.port() { return Err(miette::miette!( "--port and --metrics-port must be different (both set to {})", args.port )); } - if args.health_port != 0 && args.health_port == args.metrics_port { + if let Some(health) = health_bind + && health.port() == addr.port() + { return Err(miette::miette!( "--health-port and --metrics-port must be different (both set to {})", - args.health_port + health.port() )); } - let metrics_bind = SocketAddr::new(args.bind_address, args.metrics_port); - config = config.with_metrics_bind_address(metrics_bind); + config = config.with_metrics_bind_address(addr); } config = config - .with_database_url(args.db_url) - .with_compute_drivers(args.drivers) - .with_sandbox_namespace(args.sandbox_namespace) - .with_ssh_gateway_host(args.ssh_gateway_host) - .with_ssh_gateway_port(args.ssh_gateway_port) - .with_ssh_connect_path(args.ssh_connect_path) - .with_sandbox_ssh_port(args.sandbox_ssh_port) - .with_ssh_handshake_skew_secs(args.ssh_handshake_skew_secs); - - if let Some(image) = args.sandbox_image { - 
config = config.with_sandbox_image(image); + .with_database_url(db_url) + .with_compute_drivers(args.drivers.clone()) + .with_server_sans(args.server_sans.clone()) + .with_loopback_service_http(args.enable_loopback_service_http); + + if let Some(ttl) = file + .as_ref() + .and_then(|f| f.openshell.gateway.ssh_session_ttl_secs) + { + config = config.with_ssh_session_ttl_secs(ttl); } - if let Some(policy) = args.sandbox_image_pull_policy { - config = config.with_sandbox_image_pull_policy(policy); + if let Some(issuer) = args.oidc_issuer.clone() { + config = config.with_oidc(openshell_core::OidcConfig { + issuer, + audience: args.oidc_audience.clone(), + jwks_ttl_secs: args.oidc_jwks_ttl, + roles_claim: args.oidc_roles_claim.clone(), + admin_role: args.oidc_admin_role.clone(), + user_role: args.oidc_user_role.clone(), + scopes_claim: args.oidc_scopes_claim.clone(), + }); } - if let Some(endpoint) = args.grpc_endpoint { - config = config.with_grpc_endpoint(endpoint); + // `gateway_jwt` is configured through TOML in cluster deployments. Local + // package-managed starts also auto-detect the JWT bundle written next to + // the generated TLS bundle so upgrades pick up sandbox auth without a + // user-authored config file. 
+ if let Some(jwt) = file + .as_ref() + .and_then(|f| f.openshell.gateway.gateway_jwt.clone()) + { + config.gateway_jwt = Some(jwt); + } else if let Some(jwt) = local_jwt { + config.gateway_jwt = Some(jwt); } - if let Some(secret) = args.ssh_handshake_secret { - config = config.with_ssh_handshake_secret(secret); + let vm_config = build_vm_config( + file.as_ref(), + local_tls.as_ref(), + args.disable_tls, + args.port, + )?; + let docker_config = build_docker_config(file.as_ref(), local_tls.as_ref())?; + + if args.disable_tls { + warn!("TLS disabled — listening on plaintext HTTP"); + } else { + info!("TLS enabled — listening on encrypted HTTPS"); } - if let Some(name) = args.client_tls_secret_name { - config = config.with_client_tls_secret_name(name); + if has_client_ca { + info!("TLS client certificate verification enabled"); + } + if config.mtls_auth.enabled { + info!("mTLS user authentication enabled"); + } + if has_oidc { + info!("OIDC authentication enabled"); + } + if config.auth.allow_unauthenticated_users { + warn!( + "Unauthenticated user access enabled — only use this for trusted local development or a fully trusted fronting proxy" + ); } - if let Some(ip) = args.host_gateway_ip { - config = config.with_host_gateway_ip(ip); + if !config.auth.allow_unauthenticated_users + && !config.mtls_auth.enabled + && !has_oidc + && config.gateway_jwt.is_none() + { + warn!( + "Neither mTLS user auth nor OIDC nor sandbox JWT auth is configured — \ + the gateway has no authentication mechanism" + ); } - if let Some(issuer) = args.oidc_issuer { - config = config.with_oidc(openshell_core::OidcConfig { - issuer, - audience: args.oidc_audience, - jwks_ttl_secs: args.oidc_jwks_ttl, - roles_claim: args.oidc_roles_claim, - admin_role: args.oidc_admin_role, - user_role: args.oidc_user_role, - scopes_claim: args.oidc_scopes_claim, - }); + info!(bind = %config.bind_address, "Starting OpenShell server"); + + Box::pin(run_server( + config, + vm_config, + docker_config, + file, + 
tracing_log_bus, + )) + .await + .into_diagnostic() +} + +fn parse_compute_driver(value: &str) -> std::result::Result { + value.parse() +} + +fn resolve_config_path(args: &RunArgs) -> Result<Option<PathBuf>> { + if let Some(path) = args.config.clone() { + return Ok(Some(path)); } - let vm_config = VmComputeConfig { - state_dir: args.vm_driver_state_dir, - driver_dir: args.driver_dir, - default_image: config.sandbox_image.clone(), - krun_log_level: args.vm_krun_log_level, - vcpus: args.vm_vcpus, - mem_mib: args.vm_mem_mib, - guest_tls_ca: args.vm_tls_ca, - guest_tls_cert: args.vm_tls_cert, - guest_tls_key: args.vm_tls_key, - }; + let default_path = defaults::default_gateway_config_path()?; + Ok(default_path.is_file().then_some(default_path)) +} - let docker_config = DockerComputeConfig { - supervisor_bin: args.docker_supervisor_bin, - supervisor_image: args.docker_supervisor_image, - guest_tls_ca: args.docker_tls_ca, - guest_tls_cert: args.docker_tls_cert, - guest_tls_key: args.docker_tls_key, - network_name: args.docker_network_name, +fn apply_runtime_defaults(args: &mut RunArgs) -> Result<Option<LocalTlsPaths>> { + let local_tls = if args.disable_tls { + None + } else { + defaults::complete_local_tls_paths()?
}; - if args.disable_tls { - info!("TLS disabled — listening on plaintext HTTP"); - } else if args.disable_gateway_auth { - info!("Gateway auth disabled — accepting connections without client certificates"); + if args.db_url.is_none() { + args.db_url = Some(defaults::default_database_url()?); } - info!(bind = %config.bind_address, "Starting OpenShell server"); + if !args.disable_tls + && args.tls_cert.is_none() + && args.tls_key.is_none() + && args.tls_client_ca.is_none() + && let Some(paths) = &local_tls + { + args.tls_cert = Some(paths.server_cert.clone()); + args.tls_key = Some(paths.server_key.clone()); + args.tls_client_ca = Some(paths.ca.clone()); + } - run_server(config, vm_config, docker_config, tracing_log_bus) - .await - .into_diagnostic() + Ok(local_tls) } -fn parse_compute_driver(value: &str) -> std::result::Result { - value.parse() +/// Returns `true` when an argument's value came from clap's built-in default +/// (or was never supplied at all). When the predicate is `true`, the loader +/// is free to replace the value with one read from the TOML config file. +fn arg_defaulted(matches: &ArgMatches, id: &str) -> bool { + matches!( + matches.value_source(id), + None | Some(ValueSource::DefaultValue) + ) +} + +/// Resolve the bind address for an auxiliary listener (health / metrics). +/// +/// The precedence is: +/// 1. CLI flag or `OPENSHELL_*` env var explicitly set on the corresponding +/// port argument → `bind_address:port` (port from CLI, IP from the main +/// listener interface). +/// 2. Full `SocketAddr` from `[openshell.gateway].{health,metrics}_bind_address` +/// → used as-is (this is how operators pin a loopback-only health port +/// on a gateway whose gRPC listener is bound publicly). +/// 3. Otherwise the listener is disabled (returns `None`). 
+fn resolve_aux_listener( + bind_ip: IpAddr, + port_arg: u16, + matches: &ArgMatches, + port_id: &str, + file_addr: impl FnOnce() -> Option<SocketAddr>, +) -> Option<SocketAddr> { + if !arg_defaulted(matches, port_id) { + if port_arg == 0 { + return None; + } + return Some(SocketAddr::new(bind_ip, port_arg)); + } + if let Some(addr) = file_addr() { + return Some(addr); + } + if port_arg == 0 { + None + } else { + Some(SocketAddr::new(bind_ip, port_arg)) + } +} + +/// Apply gateway-wide values from `[openshell.gateway]` onto `RunArgs` for +/// every argument that is still sourced from clap's built-in default. +/// +/// The function intentionally does not touch `database_url` — that secret is +/// env-only and the loader already rejected it when it appears in the file. +fn merge_file_into_args(args: &mut RunArgs, file: &GatewayFileSection, matches: &ArgMatches) { + if let Some(addr) = file.bind_address { + if arg_defaulted(matches, "bind_address") { + args.bind_address = addr.ip(); + } + if arg_defaulted(matches, "port") { + args.port = addr.port(); + } + } + // Note: file's full health_bind_address / metrics_bind_address are + // consumed in `run_from_args`'s listener-resolution block so the IP + // half of the SocketAddr is preserved. Copying only the port here + // would silently relocate a loopback-intended listener onto the + // public bind address.
+ if let Some(level) = &file.log_level + && arg_defaulted(matches, "log_level") + { + args.log_level.clone_from(level); + } + if let Some(drivers) = &file.compute_drivers + && arg_defaulted(matches, "drivers") + { + args.drivers.clone_from(drivers); + } + if let Some(sans) = &file.server_sans + && args.server_sans.is_empty() + && arg_defaulted(matches, "server_sans") + { + args.server_sans.clone_from(sans); + } + if let Some(enabled) = file.enable_loopback_service_http + && arg_defaulted(matches, "enable_loopback_service_http") + { + args.enable_loopback_service_http = enabled; + } + if let Some(mtls_auth) = &file.mtls_auth + && arg_defaulted(matches, "enable_mtls_auth") + { + args.enable_mtls_auth = mtls_auth.enabled; + } + if let Some(disabled) = file.disable_tls + && arg_defaulted(matches, "disable_tls") + { + args.disable_tls = disabled; + } + // TLS gateway listener fields + if let Some(tls) = &file.tls { + if args.tls_cert.is_none() && arg_defaulted(matches, "tls_cert") { + args.tls_cert = Some(tls.cert_path.clone()); + } + if args.tls_key.is_none() && arg_defaulted(matches, "tls_key") { + args.tls_key = Some(tls.key_path.clone()); + } + if args.tls_client_ca.is_none() && arg_defaulted(matches, "tls_client_ca") { + args.tls_client_ca.clone_from(&tls.client_ca_path); + } + } + // OIDC fields + if let Some(oidc) = &file.oidc { + if args.oidc_issuer.is_none() && arg_defaulted(matches, "oidc_issuer") { + args.oidc_issuer = Some(oidc.issuer.clone()); + } + if arg_defaulted(matches, "oidc_audience") { + args.oidc_audience.clone_from(&oidc.audience); + } + if arg_defaulted(matches, "oidc_jwks_ttl") { + args.oidc_jwks_ttl = oidc.jwks_ttl_secs; + } + if arg_defaulted(matches, "oidc_roles_claim") { + args.oidc_roles_claim.clone_from(&oidc.roles_claim); + } + if arg_defaulted(matches, "oidc_admin_role") { + args.oidc_admin_role.clone_from(&oidc.admin_role); + } + if arg_defaulted(matches, "oidc_user_role") { + args.oidc_user_role.clone_from(&oidc.user_role); + } + if 
arg_defaulted(matches, "oidc_scopes_claim") { + args.oidc_scopes_claim.clone_from(&oidc.scopes_claim); + } + } +} + +fn effective_single_driver(args: &RunArgs) -> Option<ComputeDriverKind> { + match args.drivers.as_slice() { + [] => openshell_core::config::detect_driver(), + [driver] => Some(*driver), + _ => None, + } +} + +fn resolve_mtls_auth_enabled( + args: &RunArgs, + matches: &ArgMatches, + file: Option<&ConfigFile>, +) -> bool { + let file_configured = file + .and_then(|f| f.openshell.gateway.mtls_auth.as_ref()) + .is_some(); + if file_configured || !arg_defaulted(matches, "enable_mtls_auth") { + return args.enable_mtls_auth; + } + + if args.disable_tls || args.tls_client_ca.is_none() || args.oidc_issuer.is_some() { + return false; + } + + matches!( + effective_single_driver(args), + Some(ComputeDriverKind::Docker | ComputeDriverKind::Podman | ComputeDriverKind::Vm) + ) +} + +/// Build [`VmComputeConfig`] from the `[openshell.drivers.vm]` table, +/// with unset shared keys inherited from `[openshell.gateway]`. +fn build_vm_config( + file: Option<&ConfigFile>, + local_tls: Option<&LocalTlsPaths>, + disable_tls: bool, + gateway_port: u16, +) -> Result<VmComputeConfig> { + let mut cfg = if let Some(file) = file { + let merged = config_file::driver_table( + ComputeDriverKind::Vm, + &file.openshell.gateway, + file.openshell.drivers.get("vm"), + ); + merged + .try_into::<VmComputeConfig>() + .map_err(|e| miette::miette!("invalid [openshell.drivers.vm] table: {e}"))?
+ } else { + VmComputeConfig::default() + }; + + if cfg.state_dir.as_os_str().is_empty() { + cfg.state_dir = VmComputeConfig::default_state_dir(); + } + if cfg.grpc_endpoint.trim().is_empty() && (disable_tls || local_tls.is_some()) { + let scheme = if disable_tls { "http" } else { "https" }; + cfg.grpc_endpoint = format!("{scheme}://127.0.0.1:{gateway_port}"); + } + apply_guest_tls_defaults( + &mut cfg.guest_tls_ca, + &mut cfg.guest_tls_cert, + &mut cfg.guest_tls_key, + local_tls, + ); + Ok(cfg) +} + +/// Build [`DockerComputeConfig`] using the same inheritance pattern as +/// [`build_vm_config`]. +fn build_docker_config( + file: Option<&ConfigFile>, + local_tls: Option<&LocalTlsPaths>, +) -> Result<DockerComputeConfig> { + let mut cfg = if let Some(file) = file { + let merged = config_file::driver_table( + ComputeDriverKind::Docker, + &file.openshell.gateway, + file.openshell.drivers.get("docker"), + ); + merged + .try_into::<DockerComputeConfig>() + .map_err(|e| miette::miette!("invalid [openshell.drivers.docker] table: {e}"))?
+ } else { + DockerComputeConfig::default() + }; + apply_guest_tls_defaults( + &mut cfg.guest_tls_ca, + &mut cfg.guest_tls_cert, + &mut cfg.guest_tls_key, + local_tls, + ); + Ok(cfg) +} + +fn apply_guest_tls_defaults( + ca: &mut Option, + cert: &mut Option, + key: &mut Option, + local_tls: Option<&LocalTlsPaths>, +) { + if ca.is_none() + && cert.is_none() + && key.is_none() + && let Some(paths) = local_tls + { + *ca = Some(paths.ca.clone()); + *cert = Some(paths.client_cert.clone()); + *key = Some(paths.client_key.clone()); + } } #[cfg(test)] mod tests { - use super::{Args, command}; + use super::{Cli, command}; + use crate::TEST_ENV_LOCK as ENV_LOCK; use clap::Parser; use std::net::{IpAddr, Ipv4Addr}; - use std::sync::{LazyLock, Mutex}; - - static ENV_LOCK: LazyLock> = LazyLock::new(|| Mutex::new(())); struct EnvVarGuard { key: &'static str, @@ -507,9 +784,9 @@ mod tests { .lock() .unwrap_or_else(std::sync::PoisonError::into_inner); let _guard = EnvVarGuard::remove("OPENSHELL_BIND_ADDRESS"); - let args = - Args::try_parse_from(["openshell-gateway", "--db-url", "sqlite::memory:"]).unwrap(); - assert_eq!(args.bind_address, IpAddr::V4(Ipv4Addr::LOCALHOST)); + let cli = + Cli::try_parse_from(["openshell-gateway", "--db-url", "sqlite::memory:"]).unwrap(); + assert_eq!(cli.run.bind_address, IpAddr::V4(Ipv4Addr::LOCALHOST)); } #[test] @@ -518,7 +795,7 @@ mod tests { .lock() .unwrap_or_else(std::sync::PoisonError::into_inner); let _guard = EnvVarGuard::remove("OPENSHELL_BIND_ADDRESS"); - let args = Args::try_parse_from([ + let cli = Cli::try_parse_from([ "openshell-gateway", "--db-url", "sqlite::memory:", @@ -526,7 +803,7 @@ mod tests { "127.0.0.1", ]) .unwrap(); - assert_eq!(args.bind_address, IpAddr::V4(Ipv4Addr::LOCALHOST)); + assert_eq!(cli.run.bind_address, IpAddr::V4(Ipv4Addr::LOCALHOST)); } #[test] @@ -536,9 +813,665 @@ mod tests { .unwrap_or_else(std::sync::PoisonError::into_inner); let _guard = EnvVarGuard::set("OPENSHELL_BIND_ADDRESS", "0.0.0.0"); - let args = 
Args::try_parse_from(["openshell-gateway", "--db-url", "sqlite::memory:"]) + let cli = Cli::try_parse_from(["openshell-gateway", "--db-url", "sqlite::memory:"]) .expect("env should provide bind address"); + assert_eq!(cli.run.bind_address, IpAddr::V4(Ipv4Addr::UNSPECIFIED)); + } + + #[test] + fn command_enables_loopback_service_http_by_default() { + let _lock = ENV_LOCK + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let _guard = EnvVarGuard::remove("OPENSHELL_ENABLE_LOOPBACK_SERVICE_HTTP"); + + let cli = + Cli::try_parse_from(["openshell-gateway", "--db-url", "sqlite::memory:"]).unwrap(); + + assert!(cli.run.enable_loopback_service_http); + } + + #[test] + fn command_disables_loopback_service_http_with_false_value() { + let _lock = ENV_LOCK + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let _guard = EnvVarGuard::remove("OPENSHELL_ENABLE_LOOPBACK_SERVICE_HTTP"); + + let cli = Cli::try_parse_from([ + "openshell-gateway", + "--db-url", + "sqlite::memory:", + "--enable-loopback-service-http=false", + ]) + .unwrap(); + + assert!(!cli.run.enable_loopback_service_http); + } + + #[test] + fn command_reads_loopback_service_http_from_env() { + let _lock = ENV_LOCK + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let _guard = EnvVarGuard::set("OPENSHELL_ENABLE_LOOPBACK_SERVICE_HTTP", "false"); + + let cli = + Cli::try_parse_from(["openshell-gateway", "--db-url", "sqlite::memory:"]).unwrap(); + + assert!(!cli.run.enable_loopback_service_http); + } + + #[test] + fn command_reads_server_san_from_env() { + let _lock = ENV_LOCK + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let _guard = EnvVarGuard::set("OPENSHELL_SERVER_SAN", "*.apps.example.com"); + + let cli = + Cli::try_parse_from(["openshell-gateway", "--db-url", "sqlite::memory:"]).unwrap(); + + assert_eq!(cli.run.server_sans, vec!["*.apps.example.com".to_string()]); + } + + #[test] + fn command_reads_mtls_auth_from_env() { + let _lock = ENV_LOCK + 
.lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let _guard = EnvVarGuard::set("OPENSHELL_ENABLE_MTLS_AUTH", "true"); + + let cli = + Cli::try_parse_from(["openshell-gateway", "--db-url", "sqlite::memory:"]).unwrap(); + + assert!(cli.run.enable_mtls_auth); + } + + #[test] + fn command_rejects_removed_driver_flags() { + let err = command() + .try_get_matches_from([ + "openshell-gateway", + "--db-url", + "sqlite::memory:", + "--sandbox-image", + "example/sandbox:latest", + ]) + .expect_err("driver implementation flags should not be accepted"); + + assert_eq!(err.kind(), clap::error::ErrorKind::UnknownArgument); + } + + #[test] + fn command_rejects_removed_ssh_endpoint_flags() { + for flag in [ + "--ssh-gateway-host", + "--ssh-gateway-port", + "--sandbox-ssh-port", + ] { + let err = command() + .try_get_matches_from([ + "openshell-gateway", + "--db-url", + "sqlite::memory:", + flag, + "x", + ]) + .expect_err("SSH endpoint flags should not be accepted"); + + assert_eq!(err.kind(), clap::error::ErrorKind::UnknownArgument); + } + } + + #[test] + fn generate_certs_subcommand_parses_without_db_url() { + let _lock = ENV_LOCK + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let _g1 = EnvVarGuard::remove("OPENSHELL_DB_URL"); + let _g2 = EnvVarGuard::remove("POD_NAMESPACE"); + + let cli = Cli::try_parse_from([ + "openshell-gateway", + "generate-certs", + "--namespace", + "openshell", + "--server-secret-name", + "openshell-server-tls", + "--client-secret-name", + "openshell-client-tls", + "--jwt-secret-name", + "openshell-jwt-keys", + "--server-san", + "openshell.example.com", + "--server-san", + "10.0.0.1", + ]) + .expect("generate-certs should parse without --db-url"); + + assert!(matches!( + cli.command, + Some(super::Commands::GenerateCerts(_)) + )); + } + + #[test] + fn generate_certs_local_mode_parses_without_kube_flags() { + let _lock = ENV_LOCK + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let _g1 = 
EnvVarGuard::remove("OPENSHELL_DB_URL"); + let _g2 = EnvVarGuard::remove("POD_NAMESPACE"); + + let cli = Cli::try_parse_from([ + "openshell-gateway", + "generate-certs", + "--output-dir", + "/tmp/openshell-certgen", + ]) + .expect("--output-dir should make namespace/secret-name flags optional"); + + assert!(matches!( + cli.command, + Some(super::Commands::GenerateCerts(_)) + )); + } + + #[test] + fn bare_invocation_with_no_db_url_parses_for_runtime_defaults() { + // db_url is Option at the clap level so subcommand parsing + // does not require it. The Run path fills a default URL from XDG + // state when neither CLI nor env supplied one. + let _lock = ENV_LOCK + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let _g = EnvVarGuard::remove("OPENSHELL_DB_URL"); + + let cli = Cli::try_parse_from(["openshell-gateway"]).expect("parses without --db-url"); + assert!(cli.command.is_none()); + assert!(cli.run.db_url.is_none()); + } + + // ── Config-file merge tests ────────────────────────────────────────── + // + // `merge_file_into_args` is the bridge between `config_file::ConfigFile` + // and `RunArgs`. These cases lock in the precedence rule: + // + // CLI flag > OPENSHELL_* env var > TOML file > built-in default + // + // by exercising each combination on representative gateway fields. 
+ + use super::{ConfigFile, merge_file_into_args}; + use clap::FromArgMatches; + + fn parse_with_args(argv: &[&str]) -> (super::RunArgs, clap::ArgMatches) { + let matches = command().try_get_matches_from(argv).expect("parses"); + let cli = Cli::from_arg_matches(&matches).expect("from arg matches"); + (cli.run, matches) + } + + fn config_file_from_toml(toml: &str) -> ConfigFile { + toml::from_str(toml).expect("valid TOML in test fixture") + } + + #[test] + fn default_config_path_is_loaded_only_when_present() { + let _lock = ENV_LOCK + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let tmp = tempfile::tempdir().unwrap(); + let _g1 = EnvVarGuard::remove("OPENSHELL_GATEWAY_CONFIG"); + let _g2 = EnvVarGuard::set("XDG_CONFIG_HOME", tmp.path().to_str().unwrap()); + + let (args, _) = parse_with_args(&["openshell-gateway"]); + assert_eq!(super::resolve_config_path(&args).unwrap(), None); + + let config = tmp.path().join("openshell").join("gateway.toml"); + std::fs::create_dir_all(config.parent().unwrap()).unwrap(); + std::fs::write(&config, "[openshell]\nversion = 1\n").unwrap(); + + assert_eq!(super::resolve_config_path(&args).unwrap(), Some(config)); + } + + #[test] + fn explicit_config_path_is_returned_even_when_missing() { + let _lock = ENV_LOCK + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let _g = EnvVarGuard::remove("OPENSHELL_GATEWAY_CONFIG"); + + let (args, _) = parse_with_args(&["openshell-gateway", "--config", "/tmp/missing.toml"]); + + assert_eq!( + super::resolve_config_path(&args).unwrap(), + Some(std::path::PathBuf::from("/tmp/missing.toml")) + ); + } + + #[test] + fn runtime_defaults_populate_database_url_from_xdg_state() { + let _lock = ENV_LOCK + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let tmp = tempfile::tempdir().unwrap(); + let _g1 = EnvVarGuard::remove("OPENSHELL_DB_URL"); + let _g2 = EnvVarGuard::set("XDG_STATE_HOME", tmp.path().to_str().unwrap()); + + let (mut args, _) = 
parse_with_args(&["openshell-gateway", "--disable-tls"]); + let local_tls = super::apply_runtime_defaults(&mut args).unwrap(); + + let expected = format!( + "sqlite:{}", + tmp.path().join("openshell/gateway/openshell.db").display() + ); + assert!(local_tls.is_none()); + assert_eq!(args.db_url.as_deref(), Some(expected.as_str())); + assert!(tmp.path().join("openshell/gateway").is_dir()); + } + + #[test] + fn runtime_defaults_use_complete_local_tls_bundle() { + let _lock = ENV_LOCK + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let state = tempfile::tempdir().unwrap(); + let tls = tempfile::tempdir().unwrap(); + let _g1 = EnvVarGuard::remove("OPENSHELL_DB_URL"); + let _g2 = EnvVarGuard::remove("OPENSHELL_TLS_CERT"); + let _g3 = EnvVarGuard::remove("OPENSHELL_TLS_KEY"); + let _g4 = EnvVarGuard::remove("OPENSHELL_TLS_CLIENT_CA"); + let _g5 = EnvVarGuard::remove("OPENSHELL_DISABLE_TLS"); + let _g6 = EnvVarGuard::set("XDG_STATE_HOME", state.path().to_str().unwrap()); + let _g7 = EnvVarGuard::set("OPENSHELL_LOCAL_TLS_DIR", tls.path().to_str().unwrap()); + + std::fs::create_dir_all(tls.path().join("server")).unwrap(); + std::fs::create_dir_all(tls.path().join("client")).unwrap(); + for rel in [ + "ca.crt", + "server/tls.crt", + "server/tls.key", + "client/tls.crt", + "client/tls.key", + ] { + std::fs::write(tls.path().join(rel), "pem").unwrap(); + } + + let (mut args, _) = parse_with_args(&["openshell-gateway"]); + let local_tls = super::apply_runtime_defaults(&mut args) + .unwrap() + .expect("complete bundle should be returned"); + + assert_eq!(args.tls_cert, Some(tls.path().join("server/tls.crt"))); + assert_eq!(args.tls_key, Some(tls.path().join("server/tls.key"))); + assert_eq!(args.tls_client_ca, Some(tls.path().join("ca.crt"))); + assert_eq!(local_tls.client_cert, tls.path().join("client/tls.crt")); + } + + #[test] + fn mtls_auth_auto_defaults_for_local_tls_driver() { + let _lock = ENV_LOCK + .lock() + 
.unwrap_or_else(std::sync::PoisonError::into_inner); + let _guard = EnvVarGuard::remove("OPENSHELL_ENABLE_MTLS_AUTH"); + + let (args, matches) = parse_with_args(&[ + "openshell-gateway", + "--db-url", + "sqlite::memory:", + "--drivers", + "docker", + "--tls-cert", + "/tmp/server.crt", + "--tls-key", + "/tmp/server.key", + "--tls-client-ca", + "/tmp/ca.crt", + ]); + + assert!(super::resolve_mtls_auth_enabled(&args, &matches, None)); + } + + #[test] + fn mtls_auth_does_not_auto_default_for_kubernetes_driver() { + let _lock = ENV_LOCK + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let _guard = EnvVarGuard::remove("OPENSHELL_ENABLE_MTLS_AUTH"); + + let (args, matches) = parse_with_args(&[ + "openshell-gateway", + "--db-url", + "sqlite::memory:", + "--drivers", + "kubernetes", + "--tls-cert", + "/tmp/server.crt", + "--tls-key", + "/tmp/server.key", + "--tls-client-ca", + "/tmp/ca.crt", + ]); + + assert!(!super::resolve_mtls_auth_enabled(&args, &matches, None)); + } + + #[test] + fn file_mtls_auth_value_overrides_local_auto_default() { + let _lock = ENV_LOCK + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let _guard = EnvVarGuard::remove("OPENSHELL_ENABLE_MTLS_AUTH"); + + let (mut args, matches) = parse_with_args(&[ + "openshell-gateway", + "--db-url", + "sqlite::memory:", + "--drivers", + "docker", + "--tls-cert", + "/tmp/server.crt", + "--tls-key", + "/tmp/server.key", + "--tls-client-ca", + "/tmp/ca.crt", + ]); + let file = config_file_from_toml( + r" +[openshell.gateway.mtls_auth] +enabled = false +", + ); + + merge_file_into_args(&mut args, &file.openshell.gateway, &matches); + + assert!(!super::resolve_mtls_auth_enabled( + &args, + &matches, + Some(&file) + )); + } + + #[test] + fn file_value_applies_when_cli_uses_default() { + let _lock = ENV_LOCK + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let _g1 = EnvVarGuard::remove("OPENSHELL_BIND_ADDRESS"); + let _g2 = EnvVarGuard::remove("OPENSHELL_SERVER_PORT"); 
+ let _g3 = EnvVarGuard::remove("OPENSHELL_LOG_LEVEL"); + + let (mut args, matches) = + parse_with_args(&["openshell-gateway", "--db-url", "sqlite::memory:"]); + let file = config_file_from_toml( + r#" +[openshell.gateway] +bind_address = "0.0.0.0:9090" +log_level = "debug" +"#, + ); + merge_file_into_args(&mut args, &file.openshell.gateway, &matches); + assert_eq!(args.bind_address, IpAddr::V4(Ipv4Addr::UNSPECIFIED)); + assert_eq!(args.port, 9090); + assert_eq!(args.log_level, "debug"); + } + + #[test] + fn cli_flag_overrides_file_value() { + let _lock = ENV_LOCK + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let _g1 = EnvVarGuard::remove("OPENSHELL_BIND_ADDRESS"); + let _g2 = EnvVarGuard::remove("OPENSHELL_LOG_LEVEL"); + + let (mut args, matches) = parse_with_args(&[ + "openshell-gateway", + "--db-url", + "sqlite::memory:", + "--log-level", + "warn", + ]); + let file = config_file_from_toml( + r#" +[openshell.gateway] +log_level = "debug" +"#, + ); + merge_file_into_args(&mut args, &file.openshell.gateway, &matches); + + assert_eq!(args.log_level, "warn", "CLI flag must win over file"); + } + + #[test] + fn env_var_overrides_file_value() { + let _lock = ENV_LOCK + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let _g = EnvVarGuard::set("OPENSHELL_LOG_LEVEL", "trace"); + + let (mut args, matches) = + parse_with_args(&["openshell-gateway", "--db-url", "sqlite::memory:"]); + let file = config_file_from_toml( + r#" +[openshell.gateway] +log_level = "debug" +"#, + ); + merge_file_into_args(&mut args, &file.openshell.gateway, &matches); + + assert_eq!(args.log_level, "trace", "env var must win over file"); + } + + #[test] + fn file_oidc_block_populates_oidc_args() { + let _lock = ENV_LOCK + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let _g1 = EnvVarGuard::remove("OPENSHELL_OIDC_ISSUER"); + let _g2 = EnvVarGuard::remove("OPENSHELL_OIDC_AUDIENCE"); + + let (mut args, matches) = + 
parse_with_args(&["openshell-gateway", "--db-url", "sqlite::memory:"]); + let file = config_file_from_toml( + r#" +[openshell.gateway.oidc] +issuer = "https://idp.example.com" +audience = "openshell-cli" +"#, + ); + merge_file_into_args(&mut args, &file.openshell.gateway, &matches); + + assert_eq!(args.oidc_issuer.as_deref(), Some("https://idp.example.com")); + assert_eq!(args.oidc_audience, "openshell-cli"); + } + + #[test] + fn aux_listener_preserves_file_ip_against_public_bind() { + use std::net::SocketAddr; + let _lock = ENV_LOCK + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let _g = EnvVarGuard::remove("OPENSHELL_HEALTH_PORT"); + + let (_args, matches) = + parse_with_args(&["openshell-gateway", "--db-url", "sqlite::memory:"]); + let file_addr: SocketAddr = "127.0.0.1:8081".parse().unwrap(); + let resolved = super::resolve_aux_listener( + IpAddr::V4(Ipv4Addr::UNSPECIFIED), + 0, + &matches, + "health_port", + || Some(file_addr), + ); + assert_eq!( + resolved, + Some(file_addr), + "TOML health_bind_address 127.0.0.1:8081 must not be relocated to 0.0.0.0:8081" + ); + } + + #[test] + fn aux_listener_cli_port_overrides_file_addr() { + let _lock = ENV_LOCK + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let _g = EnvVarGuard::remove("OPENSHELL_HEALTH_PORT"); + + let (_args, matches) = parse_with_args(&[ + "openshell-gateway", + "--db-url", + "sqlite::memory:", + "--health-port", + "9999", + ]); + let file_addr: std::net::SocketAddr = "127.0.0.1:8081".parse().unwrap(); + let resolved = super::resolve_aux_listener( + IpAddr::V4(Ipv4Addr::UNSPECIFIED), + 9999, + &matches, + "health_port", + || Some(file_addr), + ); + assert_eq!( + resolved, + Some("0.0.0.0:9999".parse().unwrap()), + "CLI flag must win over file value" + ); + } + + #[test] + fn file_disable_tls_applies() { + let _lock = ENV_LOCK + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let _g = EnvVarGuard::remove("OPENSHELL_DISABLE_TLS"); + + let (mut 
args, matches) = + parse_with_args(&["openshell-gateway", "--db-url", "sqlite::memory:"]); + let file = config_file_from_toml( + r" +[openshell.gateway] +disable_tls = true +", + ); + merge_file_into_args(&mut args, &file.openshell.gateway, &matches); + + assert!(args.disable_tls); + } + + #[test] + fn file_ssh_session_ttl_secs_is_parsed() { + // The loader must accept and surface the documented key. The actual + // wiring into `Config` happens in `run_from_args` against the parsed + // file (not via `merge_file_into_args`, since there is no matching + // `RunArgs` field), so this test pins the schema half. + let file = config_file_from_toml( + r" +[openshell.gateway] +ssh_session_ttl_secs = 1234 +", + ); + assert_eq!(file.openshell.gateway.ssh_session_ttl_secs, Some(1234)); + } + + #[test] + fn file_populates_service_routing_fields() { + let _lock = ENV_LOCK + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let _g1 = EnvVarGuard::remove("OPENSHELL_SERVER_SAN"); + let _g2 = EnvVarGuard::remove("OPENSHELL_ENABLE_LOOPBACK_SERVICE_HTTP"); + + let (mut args, matches) = + parse_with_args(&["openshell-gateway", "--db-url", "sqlite::memory:"]); + let file = config_file_from_toml( + r#" +[openshell.gateway] +server_sans = ["gateway.local", "*.dev.openshell.localhost"] +enable_loopback_service_http = false +"#, + ); + merge_file_into_args(&mut args, &file.openshell.gateway, &matches); + + assert_eq!( + args.server_sans, + vec![ + "gateway.local".to_string(), + "*.dev.openshell.localhost".to_string() + ] + ); + assert!(!args.enable_loopback_service_http); + } + + #[test] + fn env_var_overrides_file_loopback_service_http() { + let _lock = ENV_LOCK + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let _g = EnvVarGuard::set("OPENSHELL_ENABLE_LOOPBACK_SERVICE_HTTP", "true"); + + let (mut args, matches) = + parse_with_args(&["openshell-gateway", "--db-url", "sqlite::memory:"]); + let file = config_file_from_toml( + r" +[openshell.gateway] 
+enable_loopback_service_http = false +", + ); + merge_file_into_args(&mut args, &file.openshell.gateway, &matches); + + assert!( + args.enable_loopback_service_http, + "env var must win over file" + ); + } + + #[test] + fn driver_inherits_shared_image_from_gateway_section() { + // [openshell.gateway].default_image inherits into the K8s driver + // table when the driver-specific table does not set it. + let file = config_file_from_toml( + r#" +[openshell.gateway] +default_image = "ghcr.io/nvidia/openshell/sandbox:1.0" + +[openshell.drivers.kubernetes] +namespace = "agents" +"#, + ); + let merged = crate::config_file::driver_table( + super::ComputeDriverKind::Kubernetes, + &file.openshell.gateway, + file.openshell.drivers.get("kubernetes"), + ); + let parsed = merged + .try_into::() + .expect("merged table deserializes"); + assert_eq!(parsed.default_image, "ghcr.io/nvidia/openshell/sandbox:1.0"); + assert_eq!(parsed.namespace, "agents"); + } + + #[test] + fn driver_specific_value_overrides_gateway_inheritance() { + let file = config_file_from_toml( + r#" +[openshell.gateway] +default_image = "gateway-default:1.0" + +[openshell.drivers.kubernetes] +default_image = "k8s-specific:1.0" +"#, + ); + let merged = crate::config_file::driver_table( + super::ComputeDriverKind::Kubernetes, + &file.openshell.gateway, + file.openshell.drivers.get("kubernetes"), + ); + let parsed = merged + .try_into::() + .expect("deserializes"); + assert_eq!(parsed.default_image, "k8s-specific:1.0"); } } diff --git a/crates/openshell-server/src/compute/mod.rs b/crates/openshell-server/src/compute/mod.rs index 2d6351637..98dc3fd63 100644 --- a/crates/openshell-server/src/compute/mod.rs +++ b/crates/openshell-server/src/compute/mod.rs @@ -9,7 +9,7 @@ pub use openshell_driver_docker::DockerComputeConfig; pub use vm::VmComputeConfig; use crate::grpc::policy::SANDBOX_SETTINGS_OBJECT_TYPE; -use crate::persistence::{ObjectId, ObjectName, ObjectRecord, ObjectType, Store}; +use 
crate::persistence::{ObjectId, ObjectName, ObjectRecord, ObjectType, Store, WriteCondition}; use crate::sandbox_index::SandboxIndex; use crate::sandbox_watch::SandboxWatchBus; use crate::supervisor_session::SupervisorSessionRegistry; @@ -275,6 +275,16 @@ impl ComputeRuntime { }) } + /// Serializes sandbox object read-modify-write operations within this + /// gateway process. + /// + /// This is a temporary single-gateway guard for full-object sandbox writes. + /// It is not HA-safe; replace it with DB-backed CAS/resource-version writes + /// tracked by #1255 before enabling multiple gateway writers. + pub(crate) async fn sandbox_sync_guard(&self) -> tokio::sync::OwnedMutexGuard<()> { + self.sync_lock.clone().lock_owned().await + } + pub async fn new_docker( config: openshell_core::Config, docker_config: DockerComputeConfig, @@ -411,26 +421,47 @@ impl ComputeRuntime { .map(|_| ()) } - pub async fn create_sandbox(&self, sandbox: Sandbox) -> Result { - let existing = self - .store - .get_message_by_name::(sandbox.object_name()) - .await - .map_err(|e| Status::internal(format!("fetch sandbox failed: {e}")))?; - if existing.is_some() { - return Err(Status::already_exists(format!( - "sandbox '{}' already exists", - sandbox.object_name() - ))); - } + pub async fn create_sandbox( + &self, + sandbox: Sandbox, + sandbox_token: Option, + ) -> Result { + let sandbox_id = sandbox.object_id().to_string(); + // Create with MustCreate condition to prevent duplicate creation race self.sandbox_index.update_from_sandbox(&sandbox); - self.store - .put_message(&sandbox) + let mut sandbox = sandbox; + let result = self + .store + .put_if( + Sandbox::object_type(), + &sandbox_id, + sandbox.object_name(), + &sandbox.encode_to_vec(), + None, + WriteCondition::MustCreate, + ) .await - .map_err(|e| Status::internal(format!("persist sandbox failed: {e}")))?; + .map_err(|e| { + if matches!( + e, + crate::persistence::PersistenceError::UniqueViolation { .. 
} + ) { + Status::already_exists(format!( + "sandbox '{}' already exists", + sandbox.object_name() + )) + } else { + Status::internal(format!("persist sandbox failed: {e}")) + } + })?; - let driver_sandbox = driver_sandbox_from_public(&sandbox); + let mut driver_sandbox = driver_sandbox_from_public(&sandbox); + if let Some(token) = sandbox_token + && let Some(spec) = driver_sandbox.spec.as_mut() + { + spec.sandbox_token = token; + } match self .driver .create_sandbox(Request::new(CreateSandboxRequest { @@ -440,6 +471,9 @@ impl ComputeRuntime { { Ok(_) => { self.sandbox_watch_bus.notify(sandbox.object_id()); + if let Some(metadata) = sandbox.metadata.as_mut() { + metadata.resource_version = result.resource_version; + } Ok(sandbox) } Err(status) if status.code() == Code::AlreadyExists => { @@ -473,22 +507,31 @@ impl ComputeRuntime { } pub async fn delete_sandbox(&self, name: &str) -> Result { + // Resolve sandbox ID from name let sandbox = self .store .get_message_by_name::(name) .await .map_err(|e| Status::internal(format!("fetch sandbox failed: {e}")))?; - let Some(mut sandbox) = sandbox else { + let Some(sandbox) = sandbox else { return Err(Status::not_found("sandbox not found")); }; let id = sandbox.object_id().to_string(); - sandbox.phase = SandboxPhase::Deleting as i32; - self.store - .put_message(&sandbox) + + // Use CAS to set phase to Deleting + // TODO: Accept expected_version from DeleteSandboxRequest for proper client-driven CAS + let sandbox = self + .store + .update_message_cas::(&id, 0, |s| { + s.phase = SandboxPhase::Deleting as i32; + }) .await - .map_err(|e| Status::internal(format!("persist sandbox failed: {e}")))?; + .map_err(|e| { + crate::grpc::persistence_error_to_status(e, "set sandbox phase to Deleting") + })?; + self.sandbox_index.update_from_sandbox(&sandbox); self.sandbox_watch_bus.notify(&id); self.cleanup_sandbox_owned_records(&sandbox).await; @@ -633,30 +676,40 @@ impl ComputeRuntime { async fn mark_sandbox_error(&self, sandbox: 
&Sandbox, reason: &str, message: &str) { let _guard = self.sync_lock.lock().await; - let mut updated = sandbox.clone(); - updated.phase = SandboxPhase::Error as i32; - let updated_name = updated.object_name().to_string(); - upsert_ready_condition( - &mut updated.status, - &updated_name, - SandboxCondition { - r#type: "Ready".to_string(), - status: "False".to_string(), - reason: reason.to_string(), - message: message.to_string(), - last_transition_time: String::new(), - }, - ); - self.sandbox_index.update_from_sandbox(&updated); - if let Err(err) = self.store.put_message(&updated).await { - warn!( - sandbox_id = %sandbox.object_id(), - error = %err, - "Failed to persist sandbox error state during startup resume" - ); - return; + let sandbox_id = sandbox.object_id().to_string(); + let reason = reason.to_string(); + let message = message.to_string(); + match self + .store + .update_message_cas::(&sandbox_id, 0, |s| { + s.phase = SandboxPhase::Error as i32; + let name = s.object_name().to_string(); + upsert_ready_condition( + &mut s.status, + &name, + SandboxCondition { + r#type: "Ready".to_string(), + status: "False".to_string(), + reason: reason.clone(), + message: message.clone(), + last_transition_time: String::new(), + }, + ); + }) + .await + { + Ok(updated) => { + self.sandbox_index.update_from_sandbox(&updated); + self.sandbox_watch_bus.notify(&sandbox_id); + } + Err(err) => { + warn!( + sandbox_id = %sandbox_id, + error = %err, + "Failed to persist sandbox error state during startup resume" + ); + } } - self.sandbox_watch_bus.notify(sandbox.object_id()); } async fn watch_loop(self: Arc) { @@ -707,7 +760,7 @@ impl ComputeRuntime { } async fn reconcile_store_with_backend(&self, grace_period: Duration) -> Result<(), String> { - let sweep_started_at_ms = current_time_ms(); + let sweep_started_at_ms = openshell_core::time::now_ms(); let backend_sandboxes = self .driver .list_sandboxes(Request::new(ListSandboxesRequest {})) @@ -801,86 +854,136 @@ impl ComputeRuntime 
{ .as_ref() .map(decode_sandbox_record) .transpose()?; - let previous = existing.clone(); - let mut status = incoming.status.as_ref().map(public_status_from_driver); - rewrite_user_facing_conditions( - &mut status, - existing.as_ref().and_then(|sandbox| sandbox.spec.as_ref()), - ); + // If no existing record, create initial sandbox (first watch event for this sandbox) + if existing.is_none() { + use crate::persistence::WriteCondition; + let now_ms = openshell_core::time::now_ms(); - let session_connected = self.supervisor_sessions.has_session(&incoming.id); - let mut phase = derive_phase(incoming.status.as_ref()); - let mut sandbox = existing.unwrap_or_else(|| { - use crate::persistence::current_time_ms; - let now_ms = current_time_ms().unwrap_or(0); - Sandbox { + let mut status = incoming.status.as_ref().map(public_status_from_driver); + rewrite_user_facing_conditions(&mut status, None); + + let session_connected = self.supervisor_sessions.has_session(&incoming.id); + let mut phase = derive_phase(incoming.status.as_ref()); + + let sandbox_name = incoming.name.clone(); + if session_connected + && matches!(phase, SandboxPhase::Provisioning | SandboxPhase::Unknown) + { + ensure_supervisor_ready_status(&mut status, &sandbox_name); + phase = SandboxPhase::Ready; + } + + let sandbox = Sandbox { metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { id: incoming.id.clone(), - name: incoming.name.clone(), + name: sandbox_name, created_at_ms: now_ms, labels: std::collections::HashMap::new(), + resource_version: 0, }), spec: None, - status: None, - phase: SandboxPhase::Unknown as i32, + status, + phase: phase as i32, current_policy_version: 0, - } - }); + }; - if session_connected && matches!(phase, SandboxPhase::Provisioning | SandboxPhase::Unknown) - { - ensure_supervisor_ready_status(&mut status, sandbox.object_name()); - phase = SandboxPhase::Ready; - } + self.store + .put_if( + Sandbox::object_type(), + &incoming.id, + sandbox.object_name(), + 
&sandbox.encode_to_vec(), + None, + WriteCondition::MustCreate, + ) + .await + .map_err(|e| match e { + crate::persistence::PersistenceError::Conflict { + current_resource_version, + } => format!( + "concurrent modification detected during sandbox creation (current resource_version: {})", + current_resource_version + .map_or_else(|| "unknown".to_string(), |v| v.to_string()) + ), + other => other.to_string(), + })?; - let old_phase = SandboxPhase::try_from(sandbox.phase).unwrap_or(SandboxPhase::Unknown); - if old_phase != phase { - info!( - sandbox_id = %incoming.id, - sandbox_name = %incoming.name, - old_phase = ?old_phase, - new_phase = ?phase, - "Sandbox phase changed" - ); + self.sandbox_index.update_from_sandbox(&sandbox); + self.sandbox_watch_bus.notify(sandbox.object_id()); + return Ok(()); } - if phase == SandboxPhase::Error - && let Some(ref status) = status - { - for condition in &status.conditions { - if condition.r#type == "Ready" - && condition.status.eq_ignore_ascii_case("false") - && is_terminal_failure_reason(&condition.reason) + // Single-attempt CAS: on conflict, the next watch event will naturally retry + let session_connected = self.supervisor_sessions.has_session(&incoming.id); + let sandbox_name = incoming.name.clone(); + + let sandbox = self + .store + .update_message_cas::(&incoming.id, 0, |sandbox| { + let mut status = incoming.status.as_ref().map(public_status_from_driver); + rewrite_user_facing_conditions(&mut status, sandbox.spec.as_ref()); + + let mut phase = derive_phase(incoming.status.as_ref()); + if session_connected + && matches!(phase, SandboxPhase::Provisioning | SandboxPhase::Unknown) { - warn!( + ensure_supervisor_ready_status(&mut status, &sandbox_name); + phase = SandboxPhase::Ready; + } + + let old_phase = + SandboxPhase::try_from(sandbox.phase).unwrap_or(SandboxPhase::Unknown); + if old_phase != phase { + info!( sandbox_id = %incoming.id, - sandbox_name = %incoming.name, - reason = %condition.reason, - message = 
%condition.message, - "Sandbox failed to become ready" + sandbox_name = %sandbox_name, + old_phase = ?old_phase, + new_phase = ?phase, + "Sandbox phase changed" ); } - } - } - // Update metadata fields - if let Some(metadata) = sandbox.metadata.as_mut() { - metadata.name = incoming.name; - } - // Note: namespace field removed from public Sandbox API - it remains internal to DriverSandbox - sandbox.status = status; - sandbox.phase = phase as i32; + if phase == SandboxPhase::Error + && let Some(ref status) = status + { + for condition in &status.conditions { + if condition.r#type == "Ready" + && condition.status.eq_ignore_ascii_case("false") + && is_terminal_failure_reason(&condition.reason) + { + warn!( + sandbox_id = %incoming.id, + sandbox_name = %sandbox_name, + reason = %condition.reason, + message = %condition.message, + "Sandbox failed to become ready" + ); + } + } + } - if previous.as_ref() == Some(&sandbox) { - return Ok(()); - } + // Update metadata fields + if let Some(metadata) = sandbox.metadata.as_mut() { + metadata.name.clone_from(&sandbox_name); + } + // Note: namespace field removed from public Sandbox API - it remains internal to DriverSandbox + sandbox.status = status; + sandbox.phase = phase as i32; + }) + .await + .map_err(|e| match e { + crate::persistence::PersistenceError::Conflict { + current_resource_version, + } => format!( + "concurrent modification detected during sandbox reconciliation (current resource_version: {})", + current_resource_version + .map_or_else(|| "unknown".to_string(), |v| v.to_string()) + ), + other => other.to_string(), + })?; self.sandbox_index.update_from_sandbox(&sandbox); - self.store - .put_message(&sandbox) - .await - .map_err(|e| e.to_string())?; self.sandbox_watch_bus.notify(sandbox.object_id()); Ok(()) } @@ -899,38 +1002,51 @@ impl ComputeRuntime { connected: bool, ) -> Result<(), String> { let _guard = self.sync_lock.lock().await; - let Some(record) = self + + // Use CAS to update sandbox phase based on 
supervisor session state + let result = self .store - .get(Sandbox::object_type(), sandbox_id) - .await - .map_err(|e| e.to_string())? - else { - return Ok(()); - }; + .update_message_cas::(sandbox_id, 0, |sandbox| { + let current_phase = + SandboxPhase::try_from(sandbox.phase).unwrap_or(SandboxPhase::Unknown); - let mut sandbox = decode_sandbox_record(&record)?; - let current_phase = SandboxPhase::try_from(sandbox.phase).unwrap_or(SandboxPhase::Unknown); + // Skip if sandbox is in terminal state + if current_phase == SandboxPhase::Deleting || current_phase == SandboxPhase::Error { + return; + } - if current_phase == SandboxPhase::Deleting || current_phase == SandboxPhase::Error { - return Ok(()); - } + let sandbox_name = sandbox.object_name().to_string(); + if connected { + ensure_supervisor_ready_status(&mut sandbox.status, &sandbox_name); + sandbox.phase = SandboxPhase::Ready as i32; + } else if current_phase == SandboxPhase::Ready { + ensure_supervisor_not_ready_status(&mut sandbox.status, &sandbox_name); + sandbox.phase = SandboxPhase::Provisioning as i32; + } + }) + .await; - let sandbox_name = sandbox.object_name().to_string(); - if connected { - ensure_supervisor_ready_status(&mut sandbox.status, &sandbox_name); - sandbox.phase = SandboxPhase::Ready as i32; - } else if current_phase == SandboxPhase::Ready { - ensure_supervisor_not_ready_status(&mut sandbox.status, &sandbox_name); - sandbox.phase = SandboxPhase::Provisioning as i32; - } else { - return Ok(()); - } + // Handle not found gracefully (sandbox may have been deleted) + let sandbox = match result { + Ok(s) => s, + Err(crate::persistence::PersistenceError::Database(ref msg)) + if msg.contains("not found") => + { + return Ok(()); + } + Err(crate::persistence::PersistenceError::Conflict { + current_resource_version, + }) => { + return Err(format!( + "concurrent modification detected (current resource_version: {})", + current_resource_version + .map_or_else(|| "unknown".to_string(), |v| v.to_string()) 
+ )); + } + Err(e) => return Err(e.to_string()), + }; self.sandbox_index.update_from_sandbox(&sandbox); - self.store - .put_message(&sandbox) - .await - .map_err(|e| e.to_string())?; self.sandbox_watch_bus.notify(sandbox_id); Ok(()) } @@ -1056,7 +1172,7 @@ impl ComputeRuntime { } let sandbox = decode_sandbox_record(&current_record)?; - let age_ms = current_time_ms().saturating_sub(current_record.created_at_ms); + let age_ms = openshell_core::time::now_ms().saturating_sub(current_record.created_at_ms); if age_ms < grace_ms { return Ok(()); } @@ -1122,6 +1238,7 @@ fn driver_sandbox_spec_from_public(spec: &SandboxSpec) -> DriverSandboxSpec { .map(driver_sandbox_template_from_public), gpu: spec.gpu, gpu_device: spec.gpu_device.clone(), + sandbox_token: String::new(), } } @@ -1230,6 +1347,19 @@ fn build_platform_config(template: &SandboxTemplate) -> Option ComputeError { } } -fn current_time_ms() -> i64 { - std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap_or_default() - .as_millis() - .try_into() - .unwrap_or(i64::MAX) -} - fn decode_sandbox_record(record: &ObjectRecord) -> Result { Sandbox::decode(record.payload.as_slice()).map_err(|e| e.to_string()) } @@ -1503,6 +1624,7 @@ fn is_terminal_failure_reason(reason: &str) -> bool { "dependenciesnotready", "starting", "containerstarting", + "containercreated", "healthcheckstarting", "inspectfailed", ]; @@ -1807,6 +1929,7 @@ mod tests { name: name.to_string(), created_at_ms: 1_000_000, labels: HashMap::new(), + resource_version: 0, }), phase: phase as i32, ..Default::default() @@ -1820,6 +1943,7 @@ mod tests { name: format!("session-{id}"), created_at_ms: 1_000_000, labels: HashMap::new(), + resource_version: 0, }), sandbox_id: sandbox_id.to_string(), token: format!("token-{id}"), @@ -1883,6 +2007,10 @@ mod tests { ), ("dependenciesnotready", "lowercase also works"), ("Starting", "VM is starting"), + ( + "ContainerCreated", + "Podman created the container before starting it", + ), ]; for (reason,
message) in transient_cases { @@ -1920,6 +2048,10 @@ mod tests { "Pod exists with phase: Pending; Service Exists", ), ("Starting", "VM is starting"), + ( + "ContainerCreated", + "Container exists but has not started yet", + ), ]; for (reason, message) in transient_conditions { @@ -2692,4 +2824,147 @@ mod tests { SandboxPhase::Ready ); } + + #[test] + fn build_platform_config_inverts_user_namespaces_to_host_users() { + use prost_types::value::Kind; + + // user_namespaces: true → host_users: false + let mut template = SandboxTemplate { + user_namespaces: Some(true), + ..SandboxTemplate::default() + }; + let config = build_platform_config(&template).expect("config should be Some"); + let host_users = config + .fields + .get("host_users") + .expect("host_users must exist"); + assert_eq!( + host_users.kind, + Some(Kind::BoolValue(false)), + "user_namespaces: true must produce host_users: false" + ); + + // user_namespaces: false → host_users: true + template.user_namespaces = Some(false); + let config = build_platform_config(&template).expect("config should be Some"); + let host_users = config + .fields + .get("host_users") + .expect("host_users must exist"); + assert_eq!( + host_users.kind, + Some(Kind::BoolValue(true)), + "user_namespaces: false must produce host_users: true" + ); + + // user_namespaces: None → host_users absent + template.user_namespaces = None; + let config = build_platform_config(&template); + assert!( + config.is_none() || !config.as_ref().unwrap().fields.contains_key("host_users"), + "unset user_namespaces must not produce host_users" + ); + } + + #[tokio::test] + async fn create_sandbox_returns_resource_version_one() { + let runtime = test_runtime(Arc::new(TestDriver::default())).await; + + let mut sandbox = sandbox_record("sb-new", "test-sandbox", SandboxPhase::Provisioning); + // Clear metadata to simulate incoming request + sandbox.metadata = Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: "sb-new".to_string(), + name: 
"test-sandbox".to_string(), + created_at_ms: 1_000_000, + labels: HashMap::new(), + resource_version: 0, + }); + + let created = runtime.create_sandbox(sandbox, None).await.unwrap(); + + assert_eq!( + created.metadata.as_ref().unwrap().resource_version, + 1, + "create_sandbox should return resource_version: 1 after insert" + ); + + // Verify database also has resource_version: 1 + let created_id = created.metadata.as_ref().unwrap().id.clone(); + let stored = runtime + .store + .get_message::(&created_id) + .await + .unwrap() + .unwrap(); + assert_eq!( + stored.metadata.as_ref().unwrap().resource_version, + 1, + "database should have resource_version: 1 after create" + ); + } + + #[tokio::test] + async fn concurrent_create_sandbox_rejects_duplicate() { + let runtime = Arc::new(test_runtime(Arc::new(TestDriver::default())).await); + + let sandbox = sandbox_record( + "sb-concurrent", + "test-concurrent", + SandboxPhase::Provisioning, + ); + + // Spawn two concurrent creation attempts for the same sandbox + let runtime1 = runtime.clone(); + let sandbox1 = sandbox.clone(); + let handle1 = tokio::spawn(async move { runtime1.create_sandbox(sandbox1, None).await }); + + let runtime2 = runtime.clone(); + let sandbox2 = sandbox.clone(); + let handle2 = tokio::spawn(async move { runtime2.create_sandbox(sandbox2, None).await }); + + // Wait for both to complete + let result1 = handle1.await.unwrap(); + let result2 = handle2.await.unwrap(); + + // Exactly one should succeed, one should fail with AlreadyExists + let success_count = [&result1, &result2].iter().filter(|r| r.is_ok()).count(); + let already_exists_count = [&result1, &result2] + .iter() + .filter(|r| { + r.as_ref() + .err() + .is_some_and(|e| e.code() == Code::AlreadyExists) + }) + .count(); + + assert_eq!( + success_count, 1, + "exactly one creation should succeed, got results: {result1:?} {result2:?}" + ); + assert_eq!( + already_exists_count, 1, + "exactly one creation should fail with AlreadyExists, got results: 
{result1:?} {result2:?}" + ); + + // Verify the successful sandbox can be retrieved by name + let created_sandbox = [result1, result2] + .into_iter() + .find_map(Result::ok) + .expect("should have one successful creation"); + let retrieved = runtime + .store + .get_message_by_name::("test-concurrent") + .await + .unwrap(); + assert!( + retrieved.is_some(), + "created sandbox should be retrievable by name" + ); + assert_eq!( + retrieved.unwrap().object_id(), + created_sandbox.object_id(), + "retrieved sandbox should match created sandbox" + ); + } } diff --git a/crates/openshell-server/src/compute/vm.rs b/crates/openshell-server/src/compute/vm.rs index e5b974f74..efdc9daab 100644 --- a/crates/openshell-server/src/compute/vm.rs +++ b/crates/openshell-server/src/compute/vm.rs @@ -38,6 +38,10 @@ use openshell_core::proto::compute::v1::{ GetCapabilitiesRequest, compute_driver_client::ComputeDriverClient, }; use openshell_core::{Config, Error, Result}; +#[cfg(unix)] +use std::os::unix::fs::{FileTypeExt, MetadataExt, PermissionsExt}; +#[cfg(unix)] +use std::path::Path; use std::path::PathBuf; #[cfg(unix)] use std::{io::ErrorKind, process::Stdio, sync::Arc, time::Duration}; @@ -52,9 +56,12 @@ use tonic::transport::Endpoint; use tower::service_fn; const DRIVER_BIN_NAME: &str = "openshell-driver-vm"; +const COMPUTE_DRIVER_SOCKET_RUN_DIR: &str = "run"; +const COMPUTE_DRIVER_SOCKET_NAME: &str = "compute-driver.sock"; /// Configuration for launching and talking to the VM compute driver. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +#[serde(default, deny_unknown_fields)] pub struct VmComputeConfig { /// Working directory for VM driver sandbox state. pub state_dir: PathBuf, @@ -66,6 +73,12 @@ pub struct VmComputeConfig { /// Default sandbox image the driver should use when a request omits one. pub default_image: String, + /// Gateway gRPC endpoint the sandbox guest connects back to. 
+ pub grpc_endpoint: String, + + /// Bootstrap image used to boot and prepare VM sandbox target images. + pub bootstrap_image: String, + /// libkrun log level used by the VM driver helper. pub krun_log_level: u32, @@ -75,6 +88,9 @@ pub struct VmComputeConfig { /// Default memory allocation for VM sandboxes, in MiB. pub mem_mib: u32, + /// Writable overlay disk size for each VM sandbox, in MiB. + pub overlay_disk_mib: u64, + /// Host-side CA certificate for the guest's mTLS client bundle. pub guest_tls_ca: Option, @@ -89,7 +105,10 @@ impl VmComputeConfig { /// Default working directory for VM driver state. #[must_use] pub fn default_state_dir() -> PathBuf { - PathBuf::from("target/openshell-vm-driver") + openshell_core::paths::openshell_state_dir().map_or_else( + |_| PathBuf::from("target/openshell-vm-driver"), + |dir| dir.join("vm-driver"), + ) } /// Default libkrun log level. @@ -110,6 +129,12 @@ impl VmComputeConfig { 2048 } + /// Default writable overlay disk size, in MiB. + #[must_use] + pub const fn default_overlay_disk_mib() -> u64 { + 4096 + } + #[must_use] fn default_driver_search_dirs(home: Option) -> Vec { let mut dirs = Vec::new(); @@ -128,10 +153,13 @@ impl Default for VmComputeConfig { Self { state_dir: Self::default_state_dir(), driver_dir: None, - default_image: String::new(), + default_image: openshell_core::image::default_sandbox_image(), + grpc_endpoint: String::new(), + bootstrap_image: String::new(), krun_log_level: Self::default_krun_log_level(), vcpus: Self::default_vcpus(), mem_mib: Self::default_mem_mib(), + overlay_disk_mib: Self::default_overlay_disk_mib(), guest_tls_ca: None, guest_tls_cert: None, guest_tls_key: None, @@ -151,7 +179,7 @@ pub struct VmGuestTlsPaths { /// /// Resolution order: /// 1. `{driver_dir}/openshell-driver-vm`, where `driver_dir` comes from -/// `--driver-dir` / `OPENSHELL_DRIVER_DIR`. +/// `[openshell.drivers.vm].driver_dir`. /// 2. 
Conventional install directories: /// `~/.local/libexec/openshell`, `/usr/libexec/openshell`, /// `/usr/local/libexec/openshell`, `/usr/local/libexec`. @@ -191,13 +219,27 @@ pub fn resolve_compute_driver_bin(vm_config: &VmComputeConfig) -> Result>() .join(", "); Err(Error::config(format!( - "vm compute driver binary not found (searched {searched_display}); install it under --driver-dir / OPENSHELL_DRIVER_DIR, a conventional libexec path such as ~/.local/libexec/openshell, /usr/libexec/openshell, or /usr/local/libexec{{,/openshell}}, or place it next to the gateway binary" + "vm compute driver binary not found (searched {searched_display}); install it under [openshell.drivers.vm].driver_dir, a conventional libexec path such as ~/.local/libexec/openshell, /usr/libexec/openshell, or /usr/local/libexec{{,/openshell}}, or place it next to the gateway binary" ))) } fn resolve_driver_search_dirs(vm_config: &VmComputeConfig) -> Vec { vm_config.driver_dir.clone().map_or_else( - || VmComputeConfig::default_driver_search_dirs(std::env::var_os("HOME").map(PathBuf::from)), + || { + let mut dirs = Vec::new(); + if let Ok(current_exe) = std::env::current_exe() + && let Some(prefix) = current_exe.parent().and_then(Path::parent) + { + push_unique_path(&mut dirs, prefix.join("libexec")); + push_unique_path(&mut dirs, prefix.join("libexec").join("openshell")); + } + for dir in VmComputeConfig::default_driver_search_dirs( + std::env::var_os("HOME").map(PathBuf::from), + ) { + push_unique_path(&mut dirs, dir); + } + dirs + }, |dir| vec![dir], ) } @@ -210,15 +252,156 @@ fn push_unique_path(paths: &mut Vec, path: PathBuf) { /// Path of the Unix domain socket the driver will listen on. 
pub fn compute_driver_socket_path(vm_config: &VmComputeConfig) -> PathBuf { - vm_config.state_dir.join("compute-driver.sock") + vm_config + .state_dir + .join(COMPUTE_DRIVER_SOCKET_RUN_DIR) + .join(COMPUTE_DRIVER_SOCKET_NAME) +} + +#[cfg(unix)] +fn prepare_compute_driver_socket_path( + vm_config: &VmComputeConfig, + socket_path: &Path, +) -> Result<()> { + let expected_uid = current_euid(); + prepare_vm_state_dir(&vm_config.state_dir, expected_uid)?; + let parent = socket_path.parent().ok_or_else(|| { + Error::execution(format!( + "vm compute driver socket path '{}' has no parent directory", + socket_path.display() + )) + })?; + prepare_private_socket_dir(parent, expected_uid)?; + remove_stale_socket(socket_path, expected_uid) +} + +#[cfg(unix)] +fn current_euid() -> u32 { + rustix::process::geteuid().as_raw() +} + +#[cfg(unix)] +fn prepare_vm_state_dir(state_dir: &Path, expected_uid: u32) -> Result<()> { + std::fs::create_dir_all(state_dir).map_err(|err| { + Error::execution(format!( + "failed to create vm driver state dir '{}': {err}", + state_dir.display() + )) + })?; + let metadata = checked_directory_metadata(state_dir, expected_uid, "vm driver state dir")?; + let mode = metadata.permissions().mode() & 0o777; + if mode != 0o700 { + std::fs::set_permissions(state_dir, std::fs::Permissions::from_mode(0o700)).map_err( + |err| { + Error::execution(format!( + "failed to restrict vm driver state dir '{}': {err}", + state_dir.display() + )) + }, + )?; + } + Ok(()) +} + +#[cfg(unix)] +fn prepare_private_socket_dir(socket_dir: &Path, expected_uid: u32) -> Result<()> { + std::fs::create_dir_all(socket_dir).map_err(|err| { + Error::execution(format!( + "failed to create vm compute driver socket dir '{}': {err}", + socket_dir.display() + )) + })?; + let _ = checked_directory_metadata(socket_dir, expected_uid, "vm compute driver socket dir")?; + std::fs::set_permissions(socket_dir, std::fs::Permissions::from_mode(0o700)).map_err(|err| { + Error::execution(format!( + 
"failed to restrict vm compute driver socket dir '{}': {err}", + socket_dir.display() + )) + }) +} + +#[cfg(unix)] +fn checked_directory_metadata( + path: &Path, + expected_uid: u32, + label: &str, +) -> Result { + let metadata = std::fs::symlink_metadata(path).map_err(|err| { + Error::execution(format!( + "failed to stat {label} '{}': {err}", + path.display() + )) + })?; + let file_type = metadata.file_type(); + if file_type.is_symlink() { + return Err(Error::execution(format!( + "{label} '{}' is a symlink; refusing to use it", + path.display() + ))); + } + if !file_type.is_dir() { + return Err(Error::execution(format!( + "{label} '{}' is not a directory", + path.display() + ))); + } + if metadata.uid() != expected_uid { + return Err(Error::execution(format!( + "{label} '{}' is owned by uid {} but current euid is {}", + path.display(), + metadata.uid(), + expected_uid + ))); + } + Ok(metadata) +} + +#[cfg(unix)] +fn remove_stale_socket(socket_path: &Path, expected_uid: u32) -> Result<()> { + let metadata = match std::fs::symlink_metadata(socket_path) { + Ok(metadata) => metadata, + Err(err) if err.kind() == ErrorKind::NotFound => return Ok(()), + Err(err) => { + return Err(Error::execution(format!( + "failed to stat vm compute driver socket '{}': {err}", + socket_path.display() + ))); + } + }; + let file_type = metadata.file_type(); + if file_type.is_symlink() { + return Err(Error::execution(format!( + "vm compute driver socket '{}' is a symlink; refusing to remove it", + socket_path.display() + ))); + } + if metadata.uid() != expected_uid { + return Err(Error::execution(format!( + "vm compute driver socket '{}' is owned by uid {} but current euid is {}", + socket_path.display(), + metadata.uid(), + expected_uid + ))); + } + if !file_type.is_socket() { + return Err(Error::execution(format!( + "vm compute driver socket path '{}' exists but is not a Unix socket", + socket_path.display() + ))); + } + std::fs::remove_file(socket_path).map_err(|err| { + 
Error::execution(format!( + "failed to remove stale vm compute driver socket '{}': {err}", + socket_path.display() + )) + }) } #[cfg(unix)] pub fn compute_driver_guest_tls_paths( - config: &Config, vm_config: &VmComputeConfig, ) -> Result> { - if !config.grpc_endpoint.starts_with("https://") { + if !vm_config.grpc_endpoint.starts_with("https://") { return Ok(None); } @@ -229,23 +412,23 @@ pub fn compute_driver_guest_tls_paths( ]; if provided.iter().all(Option::is_none) { return Err(Error::config( - "vm compute driver requires --vm-tls-ca, --vm-tls-cert, and --vm-tls-key when OPENSHELL_GRPC_ENDPOINT uses https://", + "vm compute driver requires guest_tls_ca, guest_tls_cert, and guest_tls_key when grpc_endpoint uses https://", )); } let Some(ca) = vm_config.guest_tls_ca.clone() else { return Err(Error::config( - "--vm-tls-ca is required when VM guest TLS materials are configured", + "guest_tls_ca is required when VM guest TLS materials are configured", )); }; let Some(cert) = vm_config.guest_tls_cert.clone() else { return Err(Error::config( - "--vm-tls-cert is required when VM guest TLS materials are configured", + "guest_tls_cert is required when VM guest TLS materials are configured", )); }; let Some(key) = vm_config.guest_tls_key.clone() else { return Err(Error::config( - "--vm-tls-key is required when VM guest TLS materials are configured", + "guest_tls_key is required when VM guest TLS materials are configured", )); }; @@ -269,7 +452,7 @@ pub async fn spawn( config: &Config, vm_config: &VmComputeConfig, ) -> Result<(Channel, Arc)> { - if config.grpc_endpoint.trim().is_empty() { + if vm_config.grpc_endpoint.trim().is_empty() { return Err(Error::config( "grpc_endpoint is required when using the vm compute driver", )); @@ -277,25 +460,8 @@ pub async fn spawn( let driver_bin = resolve_compute_driver_bin(vm_config)?; let socket_path = compute_driver_socket_path(vm_config); - let guest_tls_paths = compute_driver_guest_tls_paths(config, vm_config)?; - if let 
Some(parent) = socket_path.parent() { - std::fs::create_dir_all(parent).map_err(|e| { - Error::execution(format!( - "failed to create vm compute driver socket dir '{}': {e}", - parent.display() - )) - })?; - } - match std::fs::remove_file(&socket_path) { - Ok(()) => {} - Err(err) if err.kind() == ErrorKind::NotFound => {} - Err(err) => { - return Err(Error::execution(format!( - "failed to remove stale vm compute driver socket '{}': {err}", - socket_path.display() - ))); - } - } + let guest_tls_paths = compute_driver_guest_tls_paths(vm_config)?; + prepare_compute_driver_socket_path(vm_config, &socket_path)?; let mut command = Command::new(&driver_bin); command.kill_on_drop(true); @@ -303,30 +469,30 @@ pub async fn spawn( command.stdout(Stdio::inherit()); command.stderr(Stdio::inherit()); command.arg("--bind-socket").arg(&socket_path); + command + .arg("--expected-peer-pid") + .arg(std::process::id().to_string()); command.arg("--log-level").arg(&config.log_level); command .arg("--openshell-endpoint") - .arg(&config.grpc_endpoint); + .arg(&vm_config.grpc_endpoint); command.arg("--state-dir").arg(&vm_config.state_dir); if !vm_config.default_image.trim().is_empty() { command.arg("--default-image").arg(&vm_config.default_image); } - // Only forward the handshake secret when one is configured. The VM - // driver does not consume it, but accepts it for parity with the - // Kubernetes/Podman drivers; passing an empty value is noise. 
- if !config.ssh_handshake_secret.is_empty() { + if !vm_config.bootstrap_image.trim().is_empty() { command - .arg("--ssh-handshake-secret") - .arg(&config.ssh_handshake_secret); + .arg("--bootstrap-image") + .arg(&vm_config.bootstrap_image); } - command - .arg("--ssh-handshake-skew-secs") - .arg(config.ssh_handshake_skew_secs.to_string()); command .arg("--krun-log-level") .arg(vm_config.krun_log_level.to_string()); command.arg("--vcpus").arg(vm_config.vcpus.to_string()); command.arg("--mem-mib").arg(vm_config.mem_mib.to_string()); + command + .arg("--overlay-disk-mib") + .arg(vm_config.overlay_disk_mib.to_string()); if let Some(tls) = guest_tls_paths { command.arg("--guest-tls-ca").arg(tls.ca); command.arg("--guest-tls-cert").arg(tls.cert); @@ -356,7 +522,7 @@ pub async fn spawn( #[cfg(unix)] async fn wait_for_compute_driver( - socket_path: &std::path::Path, + socket_path: &Path, child: &mut tokio::process::Child, ) -> Result { let mut last_error: Option = None; @@ -395,7 +561,7 @@ async fn wait_for_compute_driver( } #[cfg(unix)] -async fn connect_compute_driver(socket_path: &std::path::Path) -> Result { +async fn connect_compute_driver(socket_path: &Path) -> Result { let socket_path = socket_path.to_path_buf(); let display_path = socket_path.clone(); Endpoint::from_static("http://[::]:50051") @@ -415,11 +581,12 @@ async fn connect_compute_driver(socket_path: &std::path::Path) -> Result]` table so each driver crate's +//! `Deserialize` impl sees a fully-populated table. +//! +//! The merge precedence for gateway process settings is: +//! ```text +//! CLI flag > OPENSHELL_* env var > TOML file > built-in default +//! ``` +//! Driver implementation settings are configured in the TOML driver tables. +//! Per-field application of gateway file values happens in [`crate::cli`], +//! which uses clap's `ArgMatches::value_source` to detect arguments that fell +//! back to their default and are therefore eligible for replacement by file +//! values. 
+ +use std::collections::BTreeMap; +use std::net::SocketAddr; +use std::path::{Path, PathBuf}; + +use openshell_core::config::ComputeDriverKind; +use openshell_core::{GatewayAuthConfig, GatewayJwtConfig, MtlsAuthConfig, OidcConfig, TlsConfig}; +use serde::{Deserialize, Serialize}; + +/// Latest schema version this build understands. +pub const SCHEMA_VERSION: u32 = 1; + +/// Root of the gateway TOML config file. +/// +/// The file is rooted at `[openshell]` to reserve room for future components +/// (CLI, sandbox, router) to share a single config file without key +/// collisions. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] +pub struct ConfigFile { + #[serde(default)] + pub openshell: OpenShellRoot, +} + +/// `[openshell]` table. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] +pub struct OpenShellRoot { + /// Reserved for future schema migrations. Versions greater than + /// [`SCHEMA_VERSION`] are rejected at load time. + #[serde(default)] + pub version: Option, + + #[serde(default)] + pub gateway: GatewayFileSection, + + /// `[openshell.drivers.]` tables — passed verbatim to each driver + /// crate's `Deserialize` impl after the gateway-side inheritance merge. + /// Stored as raw [`toml::Value`] so each driver can evolve its schema + /// independently of this crate. + #[serde(default)] + pub drivers: BTreeMap, +} + +/// `[openshell.gateway]` section. +/// +/// All fields are `Option` so the loader can tell whether a key was set +/// in the file (`Some`) or not (`None` — value is taken from CLI/env/default). +/// +/// The fields under "Shared driver defaults" are inherited into +/// `[openshell.drivers.]` tables per [`inheritable_keys`]. 
+#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] +pub struct GatewayFileSection { + // ── Listeners ──────────────────────────────────────────────────────── + #[serde(default)] + pub bind_address: Option, + #[serde(default)] + pub health_bind_address: Option, + #[serde(default)] + pub metrics_bind_address: Option, + + // ── Logging ────────────────────────────────────────────────────────── + #[serde(default)] + pub log_level: Option, + + // ── Drivers ────────────────────────────────────────────────────────── + #[serde(default)] + pub compute_drivers: Option>, + + // ── Sandbox / SSH ──────────────────────────────────────────────────── + #[serde(default)] + pub sandbox_namespace: Option, + #[serde(default)] + pub ssh_session_ttl_secs: Option, + + // ── Service routing ────────────────────────────────────────────────── + /// Subject Alternative Names configured on the gateway server certificate. + /// Wildcard DNS SANs also enable sandbox service URLs under that domain. + #[serde(default)] + pub server_sans: Option>, + /// Enable plaintext HTTP routing for loopback sandbox service URLs. + #[serde(default)] + pub enable_loopback_service_http: Option, + + // ── Shared driver defaults (inherited into [openshell.drivers.]) ─ + #[serde(default)] + pub default_image: Option, + #[serde(default)] + pub supervisor_image: Option, + #[serde(default)] + pub client_tls_secret_name: Option, + #[serde(default)] + pub service_account_name: Option, + #[serde(default)] + pub host_gateway_ip: Option, + #[serde(default)] + pub enable_user_namespaces: Option, + /// Lifetime (seconds) of the projected `ServiceAccount` token kubelet + /// writes for the `IssueSandboxToken` bootstrap exchange. Driver + /// clamps to `[600, 86400]`. 
+ #[serde(default)] + pub sa_token_ttl_secs: Option, + #[serde(default)] + pub guest_tls_ca: Option, + #[serde(default)] + pub guest_tls_cert: Option, + #[serde(default)] + pub guest_tls_key: Option, + + // ── TLS toggle ─────────────────────────────────────────────────────── + /// When `true`, the gateway listens on plaintext HTTP and ignores any + /// `[openshell.gateway.tls]` table. Mirrors `--disable-tls`. + #[serde(default)] + pub disable_tls: Option, + + // ── Nested tables ──────────────────────────────────────────────────── + #[serde(default)] + pub tls: Option, + #[serde(default)] + pub oidc: Option, + #[serde(default)] + pub auth: Option, + #[serde(default)] + pub mtls_auth: Option, + #[serde(default)] + pub gateway_jwt: Option, + + // ── Disallowed-in-file fields ──────────────────────────────────────── + // + // Captured so we can produce a friendly "set this via env/CLI instead" + // error rather than a generic "unknown field" message. Validated and + // rejected in [`load`]. + #[serde(default)] + pub database_url: Option, +} + +#[derive(Debug, thiserror::Error)] +pub enum ConfigFileError { + #[error("failed to read gateway config file '{}': {source}", path.display())] + Io { + path: PathBuf, + #[source] + source: std::io::Error, + }, + #[error("failed to parse gateway config file '{}': {source}", path.display())] + Parse { + path: PathBuf, + #[source] + source: toml::de::Error, + }, + #[error( + "unsupported gateway config version {version}; this build only supports version {SCHEMA_VERSION}" + )] + UnsupportedVersion { version: u32 }, + #[error( + "`{field}` is not allowed in the gateway config file — set the {env} env var or pass {cli} on the command line" + )] + SecretInFile { + field: &'static str, + env: &'static str, + cli: &'static str, + }, +} + +/// Load and validate a TOML config file. +/// +/// Returns `Ok(ConfigFile::default())` for an empty file (the gateway then +/// falls back entirely to CLI/env/built-in defaults). 
+pub fn load(path: &Path) -> Result { + let contents = std::fs::read_to_string(path).map_err(|source| ConfigFileError::Io { + path: path.to_path_buf(), + source, + })?; + if contents.trim().is_empty() { + return Ok(ConfigFile::default()); + } + let file: ConfigFile = toml::from_str(&contents).map_err(|source| ConfigFileError::Parse { + path: path.to_path_buf(), + source, + })?; + + if let Some(version) = file.openshell.version + && version > SCHEMA_VERSION + { + return Err(ConfigFileError::UnsupportedVersion { version }); + } + + if file.openshell.gateway.database_url.is_some() { + return Err(ConfigFileError::SecretInFile { + field: "database_url", + env: "OPENSHELL_DB_URL", + cli: "--db-url", + }); + } + + Ok(file) +} + +/// Build the merged TOML table for `driver` by overlaying inheritable +/// `[openshell.gateway]` defaults onto `[openshell.drivers.]`. +/// +/// The returned [`toml::Value`] is a Table ready to feed into the driver's +/// `Deserialize` impl — keys present in `raw` win over the gateway defaults. +/// Keys outside [`inheritable_keys`] for this driver are never copied from +/// the gateway section, which keeps each driver's `deny_unknown_fields` +/// invariant intact. +pub fn driver_table( + driver: ComputeDriverKind, + gateway: &GatewayFileSection, + raw: Option<&toml::Value>, +) -> toml::Value { + let mut merged = match raw { + Some(toml::Value::Table(table)) => table.clone(), + _ => toml::Table::new(), + }; + + for key in inheritable_keys(driver) { + if merged.contains_key(*key) { + continue; + } + if let Some(value) = gateway_inherited_value(gateway, key) { + merged.insert((*key).to_string(), value); + } + } + + toml::Value::Table(merged) +} + +/// Inheritance allowlist (the Q4 "high-overlap set"). Each driver opts in +/// to a specific subset so a gateway-wide default does not accidentally land +/// in a driver table that does not understand the field. 
+fn inheritable_keys(driver: ComputeDriverKind) -> &'static [&'static str] { + match driver { + ComputeDriverKind::Kubernetes => &[ + "namespace", + "default_image", + "supervisor_image", + "client_tls_secret_name", + "service_account_name", + "host_gateway_ip", + "enable_user_namespaces", + "sa_token_ttl_secs", + ], + ComputeDriverKind::Docker => &[ + "sandbox_namespace", + "default_image", + "supervisor_image", + "host_gateway_ip", + "guest_tls_ca", + "guest_tls_cert", + "guest_tls_key", + ], + ComputeDriverKind::Podman => &[ + "default_image", + "supervisor_image", + "guest_tls_ca", + "guest_tls_cert", + "guest_tls_key", + ], + ComputeDriverKind::Vm => &[ + "default_image", + "guest_tls_ca", + "guest_tls_cert", + "guest_tls_key", + ], + } +} + +fn gateway_inherited_value(g: &GatewayFileSection, key: &str) -> Option { + match key { + "namespace" | "sandbox_namespace" => g.sandbox_namespace.as_deref().map(string_value), + "default_image" => g.default_image.as_deref().map(string_value), + "supervisor_image" => g.supervisor_image.as_deref().map(string_value), + "client_tls_secret_name" => g.client_tls_secret_name.as_deref().map(string_value), + "service_account_name" => g.service_account_name.as_deref().map(string_value), + "host_gateway_ip" => g.host_gateway_ip.as_deref().map(string_value), + "enable_user_namespaces" => g.enable_user_namespaces.map(toml::Value::Boolean), + "sa_token_ttl_secs" => g.sa_token_ttl_secs.map(toml::Value::Integer), + "guest_tls_ca" => g.guest_tls_ca.as_deref().map(path_value), + "guest_tls_cert" => g.guest_tls_cert.as_deref().map(path_value), + "guest_tls_key" => g.guest_tls_key.as_deref().map(path_value), + _ => None, + } +} + +fn string_value(s: &str) -> toml::Value { + toml::Value::String(s.to_owned()) +} + +fn path_value(p: &Path) -> toml::Value { + toml::Value::String(p.display().to_string()) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::io::Write; + + fn write_tmp(contents: &str) -> tempfile::NamedTempFile { + let mut 
tmp = tempfile::Builder::new() + .suffix(".toml") + .tempfile() + .expect("tempfile"); + tmp.write_all(contents.as_bytes()).expect("write"); + tmp + } + + #[test] + fn empty_file_yields_default_config() { + let tmp = write_tmp(""); + let file = load(tmp.path()).expect("empty file parses"); + assert!(file.openshell.version.is_none()); + assert!(file.openshell.gateway.bind_address.is_none()); + assert!(file.openshell.drivers.is_empty()); + } + + #[test] + fn parses_full_example() { + let toml = r#" +[openshell] +version = 1 + +[openshell.gateway] +bind_address = "0.0.0.0:8080" +health_bind_address = "0.0.0.0:8081" +log_level = "info" +compute_drivers = ["kubernetes"] +sandbox_namespace = "agents" +default_image = "ghcr.io/nvidia/openshell/sandbox:latest" +supervisor_image = "ghcr.io/nvidia/openshell/supervisor:latest" +client_tls_secret_name = "openshell-sandbox-tls" +service_account_name = "openshell-sandbox" + +[openshell.gateway.tls] +cert_path = "/etc/openshell/certs/gateway.pem" +key_path = "/etc/openshell/certs/gateway-key.pem" +client_ca_path = "/etc/openshell/certs/client-ca.pem" + +[openshell.gateway.oidc] +issuer = "https://idp.example.com/realms/openshell" +audience = "openshell-cli" + +[openshell.drivers.kubernetes] +namespace = "agents" +grpc_endpoint = "https://openshell-gateway.agents.svc:8080" +"#; + let tmp = write_tmp(toml); + let file = load(tmp.path()).expect("valid file parses"); + let gw = &file.openshell.gateway; + assert_eq!(gw.log_level.as_deref(), Some("info")); + assert_eq!( + gw.default_image.as_deref(), + Some("ghcr.io/nvidia/openshell/sandbox:latest") + ); + assert!(gw.tls.is_some()); + assert!(gw.oidc.is_some()); + assert!(file.openshell.drivers.contains_key("kubernetes")); + } + + #[test] + fn parses_gateway_auth_config() { + let toml = r" +[openshell.gateway.auth] +allow_unauthenticated_users = true +"; + let tmp = write_tmp(toml); + let file = load(tmp.path()).expect("valid auth config parses"); + let auth = 
file.openshell.gateway.auth.expect("auth config"); + assert!(auth.allow_unauthenticated_users); + } + + #[test] + fn rejects_database_url_in_file() { + let toml = r#" +[openshell.gateway] +database_url = "sqlite::memory:" +"#; + let tmp = write_tmp(toml); + let err = load(tmp.path()).expect_err("database_url must be rejected"); + assert!(matches!( + err, + ConfigFileError::SecretInFile { + field: "database_url", + .. + } + )); + } + + #[test] + fn rejects_unknown_gateway_field() { + let toml = r" +[openshell.gateway] +nonsense = true +"; + let tmp = write_tmp(toml); + let err = load(tmp.path()).expect_err("unknown field must be rejected"); + assert!(matches!(err, ConfigFileError::Parse { .. })); + } + + #[test] + fn rejects_removed_ssh_endpoint_fields() { + let toml = r" +[openshell.gateway] +ssh_gateway_port = 8080 +"; + let tmp = write_tmp(toml); + let err = load(tmp.path()).expect_err("removed SSH endpoint keys must be rejected"); + assert!(matches!(err, ConfigFileError::Parse { .. })); + } + + #[test] + fn rejects_unsupported_version() { + let toml = r" +[openshell] +version = 2 +"; + let tmp = write_tmp(toml); + let err = load(tmp.path()).expect_err("version > 1 must be rejected"); + assert!(matches!( + err, + ConfigFileError::UnsupportedVersion { version: 2 } + )); + } + + #[test] + fn driver_table_inherits_gateway_defaults() { + let gateway = GatewayFileSection { + default_image: Some("ghcr.io/nvidia/openshell/sandbox:0.9".to_string()), + supervisor_image: Some("ghcr.io/nvidia/openshell/supervisor:0.9".to_string()), + ..Default::default() + }; + let raw = toml::toml! 
{ + namespace = "agents" + }; + let merged = driver_table( + ComputeDriverKind::Kubernetes, + &gateway, + Some(&toml::Value::Table(raw)), + ); + let table = merged.as_table().expect("table"); + assert_eq!( + table.get("namespace").and_then(|v| v.as_str()), + Some("agents") + ); + assert_eq!( + table.get("default_image").and_then(|v| v.as_str()), + Some("ghcr.io/nvidia/openshell/sandbox:0.9") + ); + assert_eq!( + table.get("supervisor_image").and_then(|v| v.as_str()), + Some("ghcr.io/nvidia/openshell/supervisor:0.9") + ); + } + + #[test] + fn docker_driver_table_inherits_gateway_defaults() { + let gateway = GatewayFileSection { + sandbox_namespace: Some("agents".to_string()), + default_image: Some("ghcr.io/nvidia/openshell/sandbox:0.9".to_string()), + host_gateway_ip: Some("10.0.0.1".to_string()), + ..Default::default() + }; + let merged = driver_table(ComputeDriverKind::Docker, &gateway, None); + let table = merged.as_table().expect("table"); + assert_eq!( + table.get("sandbox_namespace").and_then(|v| v.as_str()), + Some("agents") + ); + assert_eq!( + table.get("default_image").and_then(|v| v.as_str()), + Some("ghcr.io/nvidia/openshell/sandbox:0.9") + ); + assert_eq!( + table.get("host_gateway_ip").and_then(|v| v.as_str()), + Some("10.0.0.1") + ); + } + + #[test] + fn driver_table_specific_value_overrides_gateway_default() { + let gateway = GatewayFileSection { + default_image: Some("gateway-default".to_string()), + ..Default::default() + }; + let raw = toml::toml! { + default_image = "driver-specific" + }; + let merged = driver_table( + ComputeDriverKind::Podman, + &gateway, + Some(&toml::Value::Table(raw)), + ); + assert_eq!( + merged + .as_table() + .unwrap() + .get("default_image") + .and_then(|v| v.as_str()), + Some("driver-specific") + ); + } + + #[test] + fn driver_table_does_not_leak_keys_outside_allowlist() { + // `client_tls_secret_name` is K8s-only; Docker must not receive it + // even when set at gateway scope. 
+ let gateway = GatewayFileSection { + client_tls_secret_name: Some("openshell-sandbox-tls".to_string()), + ..Default::default() + }; + let merged = driver_table(ComputeDriverKind::Docker, &gateway, None); + assert!( + !merged + .as_table() + .unwrap() + .contains_key("client_tls_secret_name") + ); + } + + #[test] + fn missing_path_is_io_error() { + let err = load(Path::new("/nonexistent/openshell-gateway.toml")) + .expect_err("missing file must be io error"); + assert!(matches!(err, ConfigFileError::Io { .. })); + } + + /// Contract test: the RPM default config template must parse against the + /// current schema and must pin the settings that Podman deployments require. + /// + /// This test loads `deploy/rpm/gateway.toml.default` through the same + /// `load()` path that the gateway uses at runtime, catching: + /// - template corruption or unknown fields (`deny_unknown_fields`) + /// - schema drift (version bump or field renames) + /// - accidental changes to the bind address or compute driver list + #[test] + fn rpm_default_config_parses_and_has_podman_defaults() { + let path = + Path::new(env!("CARGO_MANIFEST_DIR")).join("../../deploy/rpm/gateway.toml.default"); + let config = + load(&path).expect("deploy/rpm/gateway.toml.default must parse against current schema"); + let gw = &config.openshell.gateway; + + let addr = gw + .bind_address + .expect("bind_address must be explicitly set in the RPM default config"); + assert!( + addr.ip().is_unspecified(), + "RPM default bind_address must be 0.0.0.0 so Podman sandbox containers \ + can reach the gateway over the host network bridge, got {addr}" + ); + assert_eq!( + addr.port(), + openshell_core::config::DEFAULT_SERVER_PORT, + "RPM default port must match DEFAULT_SERVER_PORT ({})", + openshell_core::config::DEFAULT_SERVER_PORT + ); + + let drivers = gw + .compute_drivers + .as_ref() + .expect("compute_drivers must be explicitly set in the RPM default config"); + assert_eq!( + drivers, + &[ComputeDriverKind::Podman], 
+ "RPM default must pin compute_drivers to [podman] to prevent unexpected \ + driver selection when Docker is also installed" + ); + } +} diff --git a/crates/openshell-server/src/defaults.rs b/crates/openshell-server/src/defaults.rs new file mode 100644 index 000000000..25179bbd3 --- /dev/null +++ b/crates/openshell-server/src/defaults.rs @@ -0,0 +1,242 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Runtime defaults for local gateway installs. + +use miette::Result; +use openshell_core::GatewayJwtConfig; +use std::path::{Path, PathBuf}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LocalTlsPaths { + pub ca: PathBuf, + pub server_cert: PathBuf, + pub server_key: PathBuf, + pub client_cert: PathBuf, + pub client_key: PathBuf, +} + +impl LocalTlsPaths { + fn resolve(dir: &Path) -> Self { + Self { + ca: dir.join("ca.crt"), + server_cert: dir.join("server").join("tls.crt"), + server_key: dir.join("server").join("tls.key"), + client_cert: dir.join("client").join("tls.crt"), + client_key: dir.join("client").join("tls.key"), + } + } + + fn files(&self) -> [&Path; 5] { + [ + &self.ca, + &self.server_cert, + &self.server_key, + &self.client_cert, + &self.client_key, + ] + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LocalJwtPaths { + pub signing_key: PathBuf, + pub public_key: PathBuf, + pub kid: PathBuf, +} + +impl LocalJwtPaths { + fn resolve(dir: &Path) -> Self { + let jwt = dir.join("jwt"); + Self { + signing_key: jwt.join("signing.pem"), + public_key: jwt.join("public.pem"), + kid: jwt.join("kid"), + } + } + + fn files(&self) -> [&Path; 3] { + [&self.signing_key, &self.public_key, &self.kid] + } +} + +pub fn default_gateway_config_path() -> Result { + Ok(openshell_core::paths::openshell_config_dir()?.join("gateway.toml")) +} + +pub fn default_database_url() -> Result { + let path = openshell_core::paths::openshell_state_dir()? 
+ .join("gateway") + .join("openshell.db"); + openshell_core::paths::ensure_parent_dir_restricted(&path)?; + Ok(format!("sqlite:{}", path.display())) +} + +fn default_local_tls_dir() -> Result { + if let Some(path) = std::env::var_os("OPENSHELL_LOCAL_TLS_DIR") { + return Ok(PathBuf::from(path)); + } + Ok(openshell_core::paths::openshell_state_dir()?.join("tls")) +} + +pub fn complete_local_tls_paths() -> Result> { + let dir = default_local_tls_dir()?; + let paths = LocalTlsPaths::resolve(&dir); + let present = paths.files().iter().filter(|path| path.is_file()).count(); + match present { + 0 => Ok(None), + 5 => Ok(Some(paths)), + _ => Err(miette::miette!( + "partial local TLS state in {}: expected ca.crt, server/tls.crt, server/tls.key, client/tls.crt, and client/tls.key", + dir.display() + )), + } +} + +pub fn complete_local_jwt_config() -> Result> { + let dir = default_local_tls_dir()?; + let paths = LocalJwtPaths::resolve(&dir); + let present = paths.files().iter().filter(|path| path.is_file()).count(); + match present { + 0 => Ok(None), + 3 => Ok(Some(GatewayJwtConfig { + signing_key_path: paths.signing_key, + public_key_path: paths.public_key, + kid_path: paths.kid, + gateway_id: "openshell".to_string(), + ttl_secs: 3_600, + })), + _ => Err(miette::miette!( + "partial local sandbox JWT state in {}: expected jwt/signing.pem, jwt/public.pem, and jwt/kid", + dir.display() + )), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::TEST_ENV_LOCK as ENV_LOCK; + + struct EnvVarGuard { + key: &'static str, + original: Option, + } + + impl EnvVarGuard { + #[allow(unsafe_code)] + fn set(key: &'static str, value: &Path) -> Self { + let original = std::env::var(key).ok(); + // SAFETY: tests serialize environment mutation with ENV_LOCK. 
+ unsafe { std::env::set_var(key, value) }; + Self { key, original } + } + } + + impl Drop for EnvVarGuard { + #[allow(unsafe_code)] + fn drop(&mut self) { + match self.original.as_deref() { + // SAFETY: tests serialize environment mutation with ENV_LOCK. + Some(value) => unsafe { std::env::set_var(self.key, value) }, + // SAFETY: tests serialize environment mutation with ENV_LOCK. + None => unsafe { std::env::remove_var(self.key) }, + } + } + } + + #[test] + fn complete_local_tls_paths_returns_none_when_bundle_absent() { + let _lock = ENV_LOCK + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let tmp = tempfile::tempdir().unwrap(); + let _guard = EnvVarGuard::set("OPENSHELL_LOCAL_TLS_DIR", tmp.path()); + + assert!(complete_local_tls_paths().unwrap().is_none()); + } + + #[test] + fn complete_local_tls_paths_rejects_partial_bundle() { + let _lock = ENV_LOCK + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let tmp = tempfile::tempdir().unwrap(); + let _guard = EnvVarGuard::set("OPENSHELL_LOCAL_TLS_DIR", tmp.path()); + std::fs::write(tmp.path().join("ca.crt"), "ca").unwrap(); + + let err = complete_local_tls_paths().unwrap_err(); + assert!(err.to_string().contains("partial local TLS state")); + } + + #[test] + fn complete_local_tls_paths_returns_full_bundle() { + let _lock = ENV_LOCK + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let tmp = tempfile::tempdir().unwrap(); + let _guard = EnvVarGuard::set("OPENSHELL_LOCAL_TLS_DIR", tmp.path()); + std::fs::create_dir_all(tmp.path().join("server")).unwrap(); + std::fs::create_dir_all(tmp.path().join("client")).unwrap(); + for rel in [ + "ca.crt", + "server/tls.crt", + "server/tls.key", + "client/tls.crt", + "client/tls.key", + ] { + std::fs::write(tmp.path().join(rel), "pem").unwrap(); + } + + let paths = complete_local_tls_paths().unwrap().unwrap(); + assert_eq!(paths.ca, tmp.path().join("ca.crt")); + assert_eq!(paths.server_cert, tmp.path().join("server/tls.crt")); + 
assert_eq!(paths.client_key, tmp.path().join("client/tls.key")); + } + + #[test] + fn complete_local_jwt_config_returns_none_when_bundle_absent() { + let _lock = ENV_LOCK + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let tmp = tempfile::tempdir().unwrap(); + let _guard = EnvVarGuard::set("OPENSHELL_LOCAL_TLS_DIR", tmp.path()); + + assert!(complete_local_jwt_config().unwrap().is_none()); + } + + #[test] + fn complete_local_jwt_config_rejects_partial_bundle() { + let _lock = ENV_LOCK + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let tmp = tempfile::tempdir().unwrap(); + let _guard = EnvVarGuard::set("OPENSHELL_LOCAL_TLS_DIR", tmp.path()); + std::fs::create_dir_all(tmp.path().join("jwt")).unwrap(); + std::fs::write(tmp.path().join("jwt/signing.pem"), "key").unwrap(); + + let err = complete_local_jwt_config().unwrap_err(); + assert!(err.to_string().contains("partial local sandbox JWT state")); + } + + #[test] + fn complete_local_jwt_config_returns_full_bundle() { + let _lock = ENV_LOCK + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let tmp = tempfile::tempdir().unwrap(); + let _guard = EnvVarGuard::set("OPENSHELL_LOCAL_TLS_DIR", tmp.path()); + std::fs::create_dir_all(tmp.path().join("jwt")).unwrap(); + for rel in ["jwt/signing.pem", "jwt/public.pem", "jwt/kid"] { + std::fs::write(tmp.path().join(rel), "pem").unwrap(); + } + + let config = complete_local_jwt_config().unwrap().unwrap(); + + assert_eq!(config.signing_key_path, tmp.path().join("jwt/signing.pem")); + assert_eq!(config.public_key_path, tmp.path().join("jwt/public.pem")); + assert_eq!(config.kid_path, tmp.path().join("jwt/kid")); + assert_eq!(config.gateway_id, "openshell"); + assert_eq!(config.ttl_secs, 3_600); + } +} diff --git a/crates/openshell-server/src/grpc/auth_rpc.rs b/crates/openshell-server/src/grpc/auth_rpc.rs new file mode 100644 index 000000000..8e98b1824 --- /dev/null +++ b/crates/openshell-server/src/grpc/auth_rpc.rs @@ -0,0 +1,365 
@@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Authentication-related RPC handlers. +//! +//! Hosts the two sandbox-identity RPCs: +//! - `IssueSandboxToken` — bootstrap exchange (K8s SA token → gateway JWT) +//! - `RefreshSandboxToken` — renew a still-valid gateway JWT +//! +//! Both end in a fresh gateway-signed JWT minted by +//! [`crate::auth::sandbox_jwt::SandboxJwtIssuer`]. Older tokens remain valid +//! until their own `exp` and are bounded by the configured short TTL. + +use crate::ServerState; +use crate::auth::principal::{Principal, SandboxIdentitySource}; +use openshell_core::proto::{ + IssueSandboxTokenRequest, IssueSandboxTokenResponse, RefreshSandboxTokenRequest, + RefreshSandboxTokenResponse, Sandbox, +}; +use std::sync::Arc; +use tonic::{Request, Response, Status}; +use tracing::{debug, info, warn}; + +#[allow(clippy::result_large_err, clippy::unused_async)] +pub async fn handle_issue_sandbox_token( + state: &Arc, + request: Request, +) -> Result, Status> { + let principal = request + .extensions() + .get::() + .cloned() + .ok_or_else(|| Status::unauthenticated("missing principal"))?; + + let Principal::Sandbox(sandbox) = principal else { + return Err(Status::permission_denied( + "IssueSandboxToken requires a sandbox principal", + )); + }; + + // Only the bootstrap K8s ServiceAccount path can mint a fresh gateway JWT + // via this RPC. Sandboxes already holding a gateway JWT use + // `RefreshSandboxToken` instead. + if !matches!( + sandbox.source, + SandboxIdentitySource::K8sServiceAccount { .. 
} + ) { + debug!( + sandbox_id = %sandbox.sandbox_id, + "IssueSandboxToken rejected: non-bootstrap principal source" + ); + return Err(Status::permission_denied( + "this principal cannot mint a sandbox token; use RefreshSandboxToken", + )); + } + + let issuer = state.sandbox_jwt_issuer.as_ref().ok_or_else(|| { + warn!( + sandbox_id = %sandbox.sandbox_id, + "IssueSandboxToken called but sandbox JWT issuer is not configured" + ); + Status::unavailable("sandbox JWT minting is not configured on this gateway") + })?; + + ensure_sandbox_exists(state, &sandbox.sandbox_id).await?; + + let minted = issuer.mint(&sandbox.sandbox_id)?; + info!( + sandbox_id = %sandbox.sandbox_id, + "issued gateway sandbox JWT" + ); + Ok(Response::new(IssueSandboxTokenResponse { + token: minted.token, + expires_at_ms: minted.expires_at_ms, + })) +} + +#[allow(clippy::result_large_err, clippy::unused_async)] +pub async fn handle_refresh_sandbox_token( + state: &Arc, + request: Request, +) -> Result, Status> { + let principal = request + .extensions() + .get::() + .cloned() + .ok_or_else(|| Status::unauthenticated("missing principal"))?; + + let Principal::Sandbox(sandbox) = principal else { + return Err(Status::permission_denied( + "RefreshSandboxToken requires a sandbox principal", + )); + }; + + // Only callers already holding a gateway-minted JWT may refresh; the + // K8s bootstrap path must use `IssueSandboxToken`. + let SandboxIdentitySource::BootstrapJwt { .. 
} = &sandbox.source else { + debug!( + sandbox_id = %sandbox.sandbox_id, + "RefreshSandboxToken rejected: non-gateway-JWT principal source" + ); + return Err(Status::permission_denied( + "this principal cannot refresh; use IssueSandboxToken for bootstrap", + )); + }; + + let issuer = state.sandbox_jwt_issuer.as_ref().ok_or_else(|| { + warn!( + sandbox_id = %sandbox.sandbox_id, + "RefreshSandboxToken called but sandbox JWT issuer is not configured" + ); + Status::unavailable("sandbox JWT minting is not configured on this gateway") + })?; + + ensure_sandbox_exists(state, &sandbox.sandbox_id).await?; + + let minted = issuer.mint(&sandbox.sandbox_id)?; + info!( + sandbox_id = %sandbox.sandbox_id, + "renewed gateway sandbox JWT" + ); + + Ok(Response::new(RefreshSandboxTokenResponse { + token: minted.token, + expires_at_ms: minted.expires_at_ms, + })) +} + +async fn ensure_sandbox_exists(state: &Arc<ServerState>, sandbox_id: &str) -> Result<(), Status> { + if sandbox_id.is_empty() { + return Err(Status::invalid_argument("sandbox_id is required")); + } + + state + .store + .get_message::<Sandbox>(sandbox_id) + .await + .map_err(|e| Status::internal(format!("fetch sandbox failed: {e}")))? 
+ .ok_or_else(|| Status::not_found("sandbox not found"))?; + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ServerState; + use crate::auth::principal::{Principal, SandboxPrincipal, UserPrincipal}; + use crate::auth::sandbox_jwt::SandboxJwtIssuer; + use crate::compute::new_test_runtime; + use crate::persistence::Store; + use crate::sandbox_index::SandboxIndex; + use crate::sandbox_watch::SandboxWatchBus; + use crate::supervisor_session::SupervisorSessionRegistry; + use crate::tracing_bus::TracingLogBus; + use openshell_bootstrap::jwt::generate_jwt_key; + use openshell_core::Config; + use openshell_core::proto::datamodel::v1::ObjectMeta; + use openshell_core::proto::{Sandbox, SandboxPhase, SandboxSpec}; + use std::collections::HashMap; + use std::time::Duration; + + async fn state_with_issuer() -> Arc<ServerState> { + let mat = generate_jwt_key().expect("jwt key"); + let store = Arc::new( + Store::connect("sqlite::memory:?cache=shared") + .await + .unwrap(), + ); + let compute = new_test_runtime(store.clone()).await; + let mut state = ServerState::new( + Config::new(None).with_database_url("sqlite::memory:?cache=shared"), + store, + compute, + SandboxIndex::new(), + SandboxWatchBus::new(), + TracingLogBus::new(), + Arc::new(SupervisorSessionRegistry::new()), + None, + ); + // We don't need the authenticator for these tests; only the issuer. 
+ let issuer = SandboxJwtIssuer::from_pem( + mat.signing_key_pem.as_bytes(), + mat.kid, + "test-gateway", + Duration::from_secs(3600), + ) + .unwrap(); + state.sandbox_jwt_issuer = Some(Arc::new(issuer)); + let state = Arc::new(state); + insert_sandbox(&state, "sandbox-a").await; + state + } + + async fn insert_sandbox(state: &Arc<ServerState>, sandbox_id: &str) { + let sandbox = Sandbox { + metadata: Some(ObjectMeta { + id: sandbox_id.to_string(), + name: sandbox_id.to_string(), + created_at_ms: 1_000_000, + labels: HashMap::default(), + resource_version: 0, + }), + spec: Some(SandboxSpec { + policy: None, + ..Default::default() + }), + phase: SandboxPhase::Ready as i32, + ..Default::default() + }; + state.store.put_message(&sandbox).await.unwrap(); + } + + fn sandbox_principal(sandbox_id: &str) -> Principal { + use crate::auth::principal::SandboxIdentitySource; + Principal::Sandbox(SandboxPrincipal { + sandbox_id: sandbox_id.to_string(), + source: SandboxIdentitySource::BootstrapJwt { + issuer: "openshell-gateway:test-gateway".to_string(), + }, + trust_domain: Some("openshell".to_string()), + }) + } + + #[tokio::test] + async fn refresh_returns_new_token() { + let state = state_with_issuer().await; + let mut req = Request::new(RefreshSandboxTokenRequest {}); + req.extensions_mut().insert(sandbox_principal("sandbox-a")); + let resp = handle_refresh_sandbox_token(&state, req) + .await + .expect("refresh OK") + .into_inner(); + assert!(!resp.token.is_empty()); + assert!(resp.expires_at_ms > 0); + } + + #[tokio::test] + async fn refresh_rejects_missing_sandbox() { + let state = state_with_issuer().await; + let mut req = Request::new(RefreshSandboxTokenRequest {}); + req.extensions_mut() + .insert(sandbox_principal("sandbox-deleted")); + let err = handle_refresh_sandbox_token(&state, req) + .await + .expect_err("missing sandbox must not refresh"); + assert_eq!(err.code(), tonic::Code::NotFound); + } + + #[tokio::test] + async fn issue_returns_token_for_existing_sandbox() { + use 
crate::auth::principal::SandboxIdentitySource; + + let state = state_with_issuer().await; + let mut req = Request::new(IssueSandboxTokenRequest {}); + req.extensions_mut() + .insert(Principal::Sandbox(SandboxPrincipal { + sandbox_id: "sandbox-a".to_string(), + source: SandboxIdentitySource::K8sServiceAccount { + pod_name: "pod-a".to_string(), + pod_uid: "uid-a".to_string(), + }, + trust_domain: Some("openshell".to_string()), + })); + let resp = handle_issue_sandbox_token(&state, req) + .await + .expect("issue OK") + .into_inner(); + assert!(!resp.token.is_empty()); + assert!(resp.expires_at_ms > 0); + } + + #[tokio::test] + async fn issue_rejects_missing_sandbox() { + use crate::auth::principal::SandboxIdentitySource; + + let state = state_with_issuer().await; + let mut req = Request::new(IssueSandboxTokenRequest {}); + req.extensions_mut() + .insert(Principal::Sandbox(SandboxPrincipal { + sandbox_id: "sandbox-deleted".to_string(), + source: SandboxIdentitySource::K8sServiceAccount { + pod_name: "pod-a".to_string(), + pod_uid: "uid-a".to_string(), + }, + trust_domain: Some("openshell".to_string()), + })); + let err = handle_issue_sandbox_token(&state, req) + .await + .expect_err("missing sandbox must not receive a token"); + assert_eq!(err.code(), tonic::Code::NotFound); + } + + #[tokio::test] + async fn refresh_rejects_user_principal() { + use crate::auth::identity::{Identity, IdentityProvider}; + let state = state_with_issuer().await; + let mut req = Request::new(RefreshSandboxTokenRequest {}); + req.extensions_mut().insert(Principal::User(UserPrincipal { + identity: Identity { + subject: "alice".to_string(), + display_name: None, + roles: vec![], + scopes: vec![], + provider: IdentityProvider::Oidc, + }, + })); + let err = handle_refresh_sandbox_token(&state, req) + .await + .expect_err("user must not refresh"); + assert_eq!(err.code(), tonic::Code::PermissionDenied); + } + + #[tokio::test] + async fn refresh_rejects_k8s_sa_principal() { + // K8s SA-bootstrap 
principals must use IssueSandboxToken, not + // RefreshSandboxToken — the refresh path assumes a still-valid + // gateway-minted JWT exists. + use crate::auth::principal::SandboxIdentitySource; + let state = state_with_issuer().await; + let mut req = Request::new(RefreshSandboxTokenRequest {}); + req.extensions_mut() + .insert(Principal::Sandbox(SandboxPrincipal { + sandbox_id: "sandbox-a".to_string(), + source: SandboxIdentitySource::K8sServiceAccount { + pod_name: "pod-a".to_string(), + pod_uid: "uid-a".to_string(), + }, + trust_domain: Some("openshell".to_string()), + })); + let err = handle_refresh_sandbox_token(&state, req) + .await + .expect_err("K8s SA principal must not refresh"); + assert_eq!(err.code(), tonic::Code::PermissionDenied); + } + + #[tokio::test] + async fn refresh_fails_when_issuer_not_configured() { + // Build a ServerState without the issuer to confirm the handler + // returns Unavailable. + let store = Arc::new( + Store::connect("sqlite::memory:?cache=shared") + .await + .unwrap(), + ); + let compute = new_test_runtime(store.clone()).await; + let state = Arc::new(ServerState::new( + Config::new(None).with_database_url("sqlite::memory:?cache=shared"), + store, + compute, + SandboxIndex::new(), + SandboxWatchBus::new(), + TracingLogBus::new(), + Arc::new(SupervisorSessionRegistry::new()), + None, + )); + insert_sandbox(&state, "sandbox-a").await; + let mut req = Request::new(RefreshSandboxTokenRequest {}); + req.extensions_mut().insert(sandbox_principal("sandbox-a")); + let err = handle_refresh_sandbox_token(&state, req) + .await + .expect_err("missing issuer must yield unavailable"); + assert_eq!(err.code(), tonic::Code::Unavailable); + } +} diff --git a/crates/openshell-server/src/grpc/mod.rs b/crates/openshell-server/src/grpc/mod.rs index 87af948ed..8538c8658 100644 --- a/crates/openshell-server/src/grpc/mod.rs +++ b/crates/openshell-server/src/grpc/mod.rs @@ -3,34 +3,45 @@ //! gRPC service implementation. 
+mod auth_rpc; pub mod policy; -mod provider; +pub mod provider; mod sandbox; +mod service; mod validation; use openshell_core::proto::{ ApproveAllDraftChunksRequest, ApproveAllDraftChunksResponse, ApproveDraftChunkRequest, - ApproveDraftChunkResponse, ClearDraftChunksRequest, ClearDraftChunksResponse, - CreateProviderRequest, CreateSandboxRequest, CreateSshSessionRequest, CreateSshSessionResponse, - DeleteProviderProfileRequest, DeleteProviderProfileResponse, DeleteProviderRequest, - DeleteProviderResponse, DeleteSandboxRequest, DeleteSandboxResponse, EditDraftChunkRequest, - EditDraftChunkResponse, ExecSandboxEvent, ExecSandboxRequest, GatewayMessage, + ApproveDraftChunkResponse, AttachSandboxProviderRequest, AttachSandboxProviderResponse, + ClearDraftChunksRequest, ClearDraftChunksResponse, ConfigureProviderRefreshRequest, + ConfigureProviderRefreshResponse, CreateProviderRequest, CreateSandboxRequest, + CreateSshSessionRequest, CreateSshSessionResponse, DeleteProviderProfileRequest, + DeleteProviderProfileResponse, DeleteProviderRefreshRequest, DeleteProviderRefreshResponse, + DeleteProviderRequest, DeleteProviderResponse, DeleteSandboxRequest, DeleteSandboxResponse, + DeleteServiceRequest, DeleteServiceResponse, DetachSandboxProviderRequest, + DetachSandboxProviderResponse, EditDraftChunkRequest, EditDraftChunkResponse, ExecSandboxEvent, + ExecSandboxInput, ExecSandboxRequest, ExposeServiceRequest, GatewayMessage, GetDraftHistoryRequest, GetDraftHistoryResponse, GetDraftPolicyRequest, GetDraftPolicyResponse, GetGatewayConfigRequest, GetGatewayConfigResponse, GetProviderProfileRequest, - GetProviderRequest, GetSandboxConfigRequest, GetSandboxConfigResponse, GetSandboxLogsRequest, + GetProviderRefreshStatusRequest, GetProviderRefreshStatusResponse, GetProviderRequest, + GetSandboxConfigRequest, GetSandboxConfigResponse, GetSandboxLogsRequest, GetSandboxLogsResponse, GetSandboxPolicyStatusRequest, GetSandboxPolicyStatusResponse, 
GetSandboxProviderEnvironmentRequest, GetSandboxProviderEnvironmentResponse, GetSandboxRequest, - HealthRequest, HealthResponse, ImportProviderProfilesRequest, ImportProviderProfilesResponse, + GetServiceRequest, HealthRequest, HealthResponse, ImportProviderProfilesRequest, + ImportProviderProfilesResponse, IssueSandboxTokenRequest, IssueSandboxTokenResponse, LintProviderProfilesRequest, LintProviderProfilesResponse, ListProviderProfilesRequest, ListProviderProfilesResponse, ListProvidersRequest, ListProvidersResponse, - ListSandboxPoliciesRequest, ListSandboxPoliciesResponse, ListSandboxesRequest, - ListSandboxesResponse, ProviderProfileResponse, ProviderResponse, PushSandboxLogsRequest, - PushSandboxLogsResponse, RejectDraftChunkRequest, RejectDraftChunkResponse, RelayFrame, - ReportPolicyStatusRequest, ReportPolicyStatusResponse, RevokeSshSessionRequest, - RevokeSshSessionResponse, SandboxResponse, SandboxStreamEvent, ServiceStatus, - SubmitPolicyAnalysisRequest, SubmitPolicyAnalysisResponse, SupervisorMessage, - UndoDraftChunkRequest, UndoDraftChunkResponse, UpdateConfigRequest, UpdateConfigResponse, - UpdateProviderRequest, WatchSandboxRequest, open_shell_server::OpenShell, + ListSandboxPoliciesRequest, ListSandboxPoliciesResponse, ListSandboxProvidersRequest, + ListSandboxProvidersResponse, ListSandboxesRequest, ListSandboxesResponse, ListServicesRequest, + ListServicesResponse, ProviderProfileResponse, ProviderResponse, PushSandboxLogsRequest, + PushSandboxLogsResponse, RefreshSandboxTokenRequest, RefreshSandboxTokenResponse, + RejectDraftChunkRequest, RejectDraftChunkResponse, RelayFrame, ReportPolicyStatusRequest, + ReportPolicyStatusResponse, RevokeSshSessionRequest, RevokeSshSessionResponse, + RotateProviderCredentialRequest, RotateProviderCredentialResponse, SandboxResponse, + SandboxStreamEvent, ServiceEndpointResponse, ServiceStatus, SubmitPolicyAnalysisRequest, + SubmitPolicyAnalysisResponse, SupervisorMessage, TcpForwardFrame, UndoDraftChunkRequest, 
+ UndoDraftChunkResponse, UpdateConfigRequest, UpdateConfigResponse, UpdateProviderRequest, + WatchSandboxRequest, open_shell_server::OpenShell, }; use serde::{Deserialize, Serialize}; use std::collections::BTreeMap; @@ -59,6 +70,29 @@ pub fn clamp_limit(raw: u32, default: u32, max: u32) -> u32 { if raw == 0 { default } else { raw.min(max) } } +/// Map a `PersistenceError` to an appropriate gRPC `Status`. +/// +/// CAS conflicts (optimistic concurrency failures) are mapped to `ABORTED` +/// to signal that the client should retry with fresh data. Other persistence +/// errors are mapped to `INTERNAL`. +pub fn persistence_error_to_status( + err: crate::persistence::PersistenceError, + operation: &str, +) -> Status { + use crate::persistence::PersistenceError; + + match err { + PersistenceError::Conflict { + current_resource_version, + } => Status::aborted(format!( + "{} failed due to concurrent modification (current resource_version: {})", + operation, + current_resource_version.map_or_else(|| "unknown".to_string(), |v| v.to_string()) + )), + other => Status::internal(format!("{operation} failed: {other}")), + } +} + // --------------------------------------------------------------------------- // Field-level size limits (shared across submodules) // --------------------------------------------------------------------------- @@ -98,6 +132,10 @@ const MAX_PROVIDER_CONFIG_ENTRIES: usize = 64; struct StoredSettings { revision: u64, settings: BTreeMap, + /// Database `resource_version` for CAS. Not persisted in the JSON payload; + /// loaded from `ObjectRecord` and used for optimistic concurrency control. 
+ #[serde(skip)] + resource_version: u64, } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] @@ -114,9 +152,8 @@ enum StoredSettingValue { // Utility // --------------------------------------------------------------------------- -fn current_time_ms() -> Result { - let now = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH)?; - Ok(i64::try_from(now.as_millis()).unwrap_or(i64::MAX)) +fn current_time_ms() -> i64 { + openshell_core::time::now_ms() } /// Validate that object metadata is present and contains required fields. @@ -199,6 +236,27 @@ impl OpenShell for OpenShellService { sandbox::handle_list_sandboxes(&self.state, request).await } + async fn list_sandbox_providers( + &self, + request: Request, + ) -> Result, Status> { + sandbox::handle_list_sandbox_providers(&self.state, request).await + } + + async fn attach_sandbox_provider( + &self, + request: Request, + ) -> Result, Status> { + sandbox::handle_attach_sandbox_provider(&self.state, request).await + } + + async fn detach_sandbox_provider( + &self, + request: Request, + ) -> Result, Status> { + sandbox::handle_detach_sandbox_provider(&self.state, request).await + } + async fn delete_sandbox( &self, request: Request, @@ -217,6 +275,25 @@ impl OpenShell for OpenShellService { sandbox::handle_exec_sandbox(&self.state, request).await } + type ForwardTcpStream = + Pin> + Send + 'static>>; + + async fn forward_tcp( + &self, + request: Request>, + ) -> Result, Status> { + sandbox::handle_forward_tcp(&self.state, request).await + } + + type ExecSandboxInteractiveStream = ReceiverStream>; + + async fn exec_sandbox_interactive( + &self, + request: Request>, + ) -> Result, Status> { + sandbox::handle_exec_sandbox_interactive(&self.state, request).await + } + // --- SSH sessions --- async fn create_ssh_session( @@ -226,6 +303,34 @@ impl OpenShell for OpenShellService { sandbox::handle_create_ssh_session(&self.state, request).await } + async fn expose_service( + &self, + request: Request, + 
) -> Result, Status> { + service::handle_expose_service(&self.state, request).await + } + + async fn get_service( + &self, + request: Request, + ) -> Result, Status> { + service::handle_get_service(&self.state, request).await + } + + async fn list_services( + &self, + request: Request, + ) -> Result, Status> { + service::handle_list_services(&self.state, request).await + } + + async fn delete_service( + &self, + request: Request, + ) -> Result, Status> { + service::handle_delete_service(&self.state, request).await + } + async fn revoke_ssh_session( &self, request: Request, @@ -291,6 +396,34 @@ impl OpenShell for OpenShellService { provider::handle_update_provider(&self.state, request).await } + async fn get_provider_refresh_status( + &self, + request: Request, + ) -> Result, Status> { + provider::handle_get_provider_refresh_status(&self.state, request).await + } + + async fn configure_provider_refresh( + &self, + request: Request, + ) -> Result, Status> { + provider::handle_configure_provider_refresh(&self.state, request).await + } + + async fn rotate_provider_credential( + &self, + request: Request, + ) -> Result, Status> { + provider::handle_rotate_provider_credential(&self.state, request).await + } + + async fn delete_provider_refresh( + &self, + request: Request, + ) -> Result, Status> { + provider::handle_delete_provider_refresh(&self.state, request).await + } + async fn delete_provider( &self, request: Request, @@ -437,6 +570,22 @@ impl OpenShell for OpenShellService { policy::handle_get_draft_history(&self.state, request).await } + // --- Sandbox identity --- + + async fn issue_sandbox_token( + &self, + request: Request, + ) -> Result, Status> { + auth_rpc::handle_issue_sandbox_token(&self.state, request).await + } + + async fn refresh_sandbox_token( + &self, + request: Request, + ) -> Result, Status> { + auth_rpc::handle_refresh_sandbox_token(&self.state, request).await + } + // --- Supervisor session --- type ConnectSupervisorStream = @@ -461,6 +610,45 @@ 
impl OpenShell for OpenShellService { } } +// --------------------------------------------------------------------------- +// Shared test support +// --------------------------------------------------------------------------- + +/// Shared test helpers for grpc submodule unit tests. +#[cfg(test)] +pub mod test_support { + use std::sync::Arc; + + use crate::ServerState; + use crate::compute::new_test_runtime; + use crate::persistence::Store; + use crate::sandbox_index::SandboxIndex; + use crate::sandbox_watch::SandboxWatchBus; + use crate::supervisor_session::SupervisorSessionRegistry; + use crate::tracing_bus::TracingLogBus; + use openshell_core::Config; + + /// Build an in-memory `ServerState` for unit tests. + pub async fn test_server_state() -> Arc<ServerState> { + let store = Arc::new( + Store::connect("sqlite::memory:?cache=shared") + .await + .unwrap(), + ); + let compute = new_test_runtime(store.clone()).await; + Arc::new(ServerState::new( + Config::new(None).with_database_url("sqlite::memory:?cache=shared"), + store, + compute, + SandboxIndex::new(), + SandboxWatchBus::new(), + TracingLogBus::new(), + Arc::new(SupervisorSessionRegistry::new()), + None, + )) + } +} + // --------------------------------------------------------------------------- // Tests for mod-level utilities // --------------------------------------------------------------------------- diff --git a/crates/openshell-server/src/grpc/policy.rs b/crates/openshell-server/src/grpc/policy.rs index 2c62c930a..bdc96d862 100644 --- a/crates/openshell-server/src/grpc/policy.rs +++ b/crates/openshell-server/src/grpc/policy.rs @@ -10,9 +10,10 @@ #![allow(clippy::cast_precision_loss)] // f64->f32 for confidence scores #![allow(clippy::items_after_statements)] // DB_PORTS const inside function -use crate::persistence::{DraftChunkRecord, ObjectId, ObjectName, PolicyRecord, Store}; +use crate::ServerState; +use crate::auth::principal::Principal; +use crate::persistence::{DraftChunkRecord, ObjectId, ObjectName, 
ObjectType, PolicyRecord, Store}; use crate::policy_store::PolicyStoreExt; -use crate::{ServerState, auth::oidc}; use openshell_core::proto::policy_merge_operation; use openshell_core::proto::setting_value; use openshell_core::proto::{ @@ -216,6 +217,12 @@ fn summarize_endpoint(endpoint: &NetworkEndpoint) -> String { if !endpoint.tls.is_empty() { parts.push(format!("tls={}", endpoint.tls)); } + if endpoint.websocket_credential_rewrite { + parts.push("websocket_credential_rewrite=true".to_string()); + } + if endpoint.request_body_credential_rewrite { + parts.push("request_body_credential_rewrite=true".to_string()); + } if !endpoint.allowed_ips.is_empty() { parts.push(format!("allowed_ips={}", endpoint.allowed_ips.len())); } @@ -308,36 +315,75 @@ fn truncate_for_log(input: &str, max_chars: usize) -> String { } } -fn is_sandbox_secret_authenticated(request: &Request) -> bool { - oidc::is_sandbox_secret_authenticated(request.metadata()) +#[cfg(test)] +fn is_sandbox_caller(request: &Request) -> bool { + matches!( + request.extensions().get::(), + Some(Principal::Sandbox(_)) + ) } -/// Sandbox-secret-authenticated callers may only perform sandbox-scoped policy -/// sync. They must not be able to mutate global config or sandbox settings. -fn validate_sandbox_secret_update(req: &UpdateConfigRequest) -> Result<(), Status> { +/// Sandbox-class callers may only perform sandbox-scoped policy sync. They +/// must not mutate global config or sandbox settings. 
+fn validate_sandbox_caller_update(req: &UpdateConfigRequest) -> Result<(), Status> { if req.global { return Err(Status::permission_denied( - "sandbox secret cannot mutate global config", + "sandbox callers cannot mutate global config", )); } if req.delete_setting { return Err(Status::permission_denied( - "sandbox secret cannot delete settings", + "sandbox callers cannot delete settings", )); } if req.name.trim().is_empty() { return Err(Status::permission_denied( - "sandbox secret may only perform sandbox policy sync", + "sandbox callers may only perform sandbox policy sync", )); } if req.policy.is_none() || !req.setting_key.trim().is_empty() { return Err(Status::permission_denied( - "sandbox secret may only perform sandbox policy sync", + "sandbox callers may only perform sandbox policy sync", )); } Ok(()) } +async fn resolve_sandbox_by_name_for_principal( + store: &Store, + principal: &Principal, + name: &str, +) -> Result { + let sandbox = store + .get_message_by_name::(name) + .await + .map_err(|e| Status::internal(format!("fetch sandbox failed: {e}")))?; + + match principal { + Principal::Sandbox(_) => { + let Some(sandbox) = sandbox else { + return Err(Status::permission_denied( + "sandbox not found or not owned by caller", + )); + }; + crate::auth::guard::ensure_sandbox_scope(principal, sandbox.object_id()).map_err( + |status| { + if status.code() == tonic::Code::PermissionDenied { + Status::permission_denied("sandbox not found or not owned by caller") + } else { + status + } + }, + )?; + Ok(sandbox) + } + Principal::User(_) => sandbox.ok_or_else(|| Status::not_found("sandbox not found")), + Principal::Anonymous => Err(Status::unauthenticated( + "sandbox-scoped methods require an authenticated caller", + )), + } +} + // --------------------------------------------------------------------------- // Config handlers // --------------------------------------------------------------------------- @@ -346,7 +392,9 @@ pub(super) async fn handle_get_sandbox_config( 
state: &Arc, request: Request, ) -> Result, Status> { - let sandbox_id = request.into_inner().sandbox_id; + let sandbox_id = request.get_ref().sandbox_id.clone(); + crate::auth::guard::enforce_sandbox_scope(&request, &sandbox_id)?; + drop(request); let sandbox = state .store @@ -472,6 +520,8 @@ pub(super) async fn handle_get_sandbox_config( let settings = merge_effective_settings(&global_settings, &sandbox_settings)?; let config_revision = compute_config_revision(policy.as_ref(), &settings, policy_source); + let provider_env_revision = + compute_provider_env_revision(state.store.as_ref(), &sandbox_provider_names).await?; Ok(Response::new(GetSandboxConfigResponse { policy, @@ -481,9 +531,58 @@ pub(super) async fn handle_get_sandbox_config( config_revision, policy_source: policy_source.into(), global_policy_version, + provider_env_revision, })) } +pub(super) async fn compute_provider_env_revision( + store: &Store, + provider_names: &[String], +) -> Result { + let mut hasher = Sha256::new(); + hasher.update(b"openshell-provider-env-revision-v1"); + + for provider_name in provider_names { + hasher.update(provider_name.as_bytes()); + match store + .get_by_name(Provider::object_type(), provider_name) + .await + .map_err(|e| { + Status::internal(format!("fetch provider '{provider_name}' failed: {e}")) + })? 
{ + Some(record) => { + hasher.update(record.id.as_bytes()); + hasher.update(record.updated_at_ms.to_le_bytes()); + + let provider = Provider::decode(record.payload.as_slice()).map_err(|e| { + Status::internal(format!("decode provider '{provider_name}' failed: {e}")) + })?; + hasher.update(provider.r#type.as_bytes()); + + let mut credential_keys: Vec<_> = provider.credentials.keys().collect(); + credential_keys.sort(); + for key in credential_keys { + hasher.update(key.as_bytes()); + } + let mut expiry_keys: Vec<_> = provider.credential_expires_at_ms.keys().collect(); + expiry_keys.sort(); + for key in expiry_keys { + hasher.update(key.as_bytes()); + hasher.update(provider.credential_expires_at_ms[key].to_le_bytes()); + } + } + None => { + hasher.update(b"missing"); + } + } + } + + let digest = hasher.finalize(); + Ok(u64::from_le_bytes(digest[..8].try_into().map_err( + |_| Status::internal("provider env revision digest too short"), + )?)) +} + async fn profile_provider_policy_layers( store: &Store, provider_names: &[String], @@ -558,7 +657,9 @@ pub(super) async fn handle_get_sandbox_provider_environment( state: &Arc, request: Request, ) -> Result, Status> { - let sandbox_id = request.into_inner().sandbox_id; + let sandbox_id = request.get_ref().sandbox_id.clone(); + crate::auth::guard::enforce_sandbox_scope(&request, &sandbox_id)?; + drop(request); let sandbox = state .store @@ -571,19 +672,25 @@ pub(super) async fn handle_get_sandbox_provider_environment( .spec .ok_or_else(|| Status::internal("sandbox has no spec"))?; - let environment = - super::provider::resolve_provider_environment(state.store.as_ref(), &spec.providers) + let provider_names = spec.providers; + let provider_env_revision = + compute_provider_env_revision(state.store.as_ref(), &provider_names).await?; + let provider_environment = + super::provider::resolve_provider_environment(state.store.as_ref(), &provider_names) .await?; info!( sandbox_id = %sandbox_id, - provider_count = spec.providers.len(), 
- env_count = environment.len(), + provider_count = provider_names.len(), + env_count = provider_environment.environment.len(), + provider_env_revision, "GetSandboxProviderEnvironment request completed successfully" ); Ok(Response::new(GetSandboxProviderEnvironmentResponse { - environment, + environment: provider_environment.environment, + provider_env_revision, + credential_expires_at_ms: provider_environment.credential_expires_at_ms, })) } @@ -595,10 +702,19 @@ pub(super) async fn handle_update_config( state: &Arc, request: Request, ) -> Result, Status> { - let sandbox_secret_auth = is_sandbox_secret_authenticated(&request); + let principal = request.extensions().get::().cloned(); + let sandbox_caller = matches!(principal, Some(Principal::Sandbox(_))); let req = request.into_inner(); - if sandbox_secret_auth { - validate_sandbox_secret_update(&req)?; + if sandbox_caller { + validate_sandbox_caller_update(&req)?; + resolve_sandbox_by_name_for_principal( + state.store.as_ref(), + principal + .as_ref() + .expect("sandbox_caller implies principal"), + &req.name, + ) + .await?; } let key = req.setting_key.trim(); let has_policy = req.policy.is_some(); @@ -950,15 +1066,25 @@ pub(super) async fn handle_update_config( validate_static_fields_unchanged(baseline_policy, &new_policy)?; validate_policy_safety(&new_policy)?; } else { - let mut sandbox = sandbox; - if let Some(ref mut spec) = sandbox.spec { - spec.policy = Some(new_policy.clone()); - } + // Backfill spec.policy using CAS (first-time policy discovery) + let _sandbox_sync_guard = state.compute.sandbox_sync_guard().await; + let sandbox_id = sandbox.object_id().to_string(); + let new_policy_clone = new_policy.clone(); state .store - .put_message(&sandbox) + .update_message_cas::( + &sandbox_id, + req.expected_resource_version, + |sandbox| { + if let Some(ref mut spec) = sandbox.spec + && spec.policy.is_none() + { + spec.policy = Some(new_policy_clone.clone()); + } + }, + ) .await - .map_err(|e| 
Status::internal(format!("backfill spec.policy failed: {e}")))?; + .map_err(|e| super::persistence_error_to_status(e, "backfill spec.policy"))?; info!( sandbox_id = %sandbox_id, "UpdateConfig: backfilled spec.policy from sandbox-discovered policy" @@ -1111,6 +1237,8 @@ pub(super) async fn handle_report_policy_status( state: &Arc, request: Request, ) -> Result, Status> { + let sandbox_id = request.get_ref().sandbox_id.clone(); + crate::auth::guard::enforce_sandbox_scope(&request, &sandbox_id)?; let req = request.into_inner(); if req.sandbox_id.is_empty() { return Err(Status::invalid_argument("sandbox_id is required")); @@ -1127,7 +1255,7 @@ pub(super) async fn handle_report_policy_status( }; let loaded_at_ms = if status_str == "loaded" { - Some(current_time_ms().map_err(|e| Status::internal(format!("timestamp error: {e}")))?) + Some(current_time_ms()) } else { None }; @@ -1159,10 +1287,19 @@ pub(super) async fn handle_report_policy_status( .store .supersede_older_policies(&req.sandbox_id, version) .await; - if let Ok(Some(mut sandbox)) = state.store.get_message::(&req.sandbox_id).await { - sandbox.current_policy_version = req.version; - let _ = state.store.put_message(&sandbox).await; - } + + // Update current_policy_version using CAS + // TODO: Accept expected_version from UpdateConfigRequest for proper client-driven CAS + let _sandbox_sync_guard = state.compute.sandbox_sync_guard().await; + let version_to_set = req.version; + state + .store + .update_message_cas::(&req.sandbox_id, 0, |sandbox| { + sandbox.current_policy_version = version_to_set; + }) + .await + .map_err(|e| super::persistence_error_to_status(e, "update current_policy_version"))?; + state.sandbox_watch_bus.notify(&req.sandbox_id); } @@ -1224,8 +1361,13 @@ pub(super) async fn handle_push_sandbox_logs( state: &Arc, request: Request>, ) -> Result, Status> { + let principal = request + .extensions() + .get::() + .cloned() + .ok_or_else(|| Status::unauthenticated("missing principal"))?; let mut stream = 
request.into_inner(); - let mut validated = false; + let mut validated_sandbox_id = None; while let Some(batch) = stream .message() @@ -1236,15 +1378,13 @@ pub(super) async fn handle_push_sandbox_logs( continue; } - if !validated { - state - .store - .get_message::(&batch.sandbox_id) - .await - .map_err(|e| Status::internal(format!("fetch sandbox failed: {e}")))? - .ok_or_else(|| Status::not_found("sandbox not found"))?; - validated = true; - } + ensure_log_stream_sandbox_scope( + state, + &principal, + &batch.sandbox_id, + &mut validated_sandbox_id, + ) + .await?; for log in batch.logs.into_iter().take(100) { let mut log = log; @@ -1257,6 +1397,32 @@ pub(super) async fn handle_push_sandbox_logs( Ok(Response::new(PushSandboxLogsResponse {})) } +async fn ensure_log_stream_sandbox_scope( + state: &Arc, + principal: &Principal, + sandbox_id: &str, + validated_sandbox_id: &mut Option, +) -> Result<(), Status> { + if let Some(validated) = validated_sandbox_id.as_deref() { + if sandbox_id != validated { + return Err(Status::permission_denied( + "log stream sandbox_id changed after validation", + )); + } + return Ok(()); + } + + crate::auth::guard::ensure_sandbox_scope(principal, sandbox_id)?; + state + .store + .get_message::(sandbox_id) + .await + .map_err(|e| Status::internal(format!("fetch sandbox failed: {e}")))? 
+ .ok_or_else(|| Status::not_found("sandbox not found"))?; + *validated_sandbox_id = Some(sandbox_id.to_string()); + Ok(()) +} + // --------------------------------------------------------------------------- // Draft policy recommendation handlers // --------------------------------------------------------------------------- @@ -1265,17 +1431,18 @@ pub(super) async fn handle_submit_policy_analysis( state: &Arc, request: Request, ) -> Result, Status> { + let principal = request + .extensions() + .get::() + .cloned() + .ok_or_else(|| Status::unauthenticated("missing principal"))?; let req = request.into_inner(); if req.name.is_empty() { return Err(Status::invalid_argument("name is required")); } - let sandbox = state - .store - .get_message_by_name::(&req.name) - .await - .map_err(|e| Status::internal(format!("fetch sandbox failed: {e}")))? - .ok_or_else(|| Status::not_found("sandbox not found"))?; + let sandbox = + resolve_sandbox_by_name_for_principal(state.store.as_ref(), &principal, &req.name).await?; let sandbox_id = sandbox.object_id().to_string(); let current_version = state @@ -1288,6 +1455,7 @@ pub(super) async fn handle_submit_policy_analysis( let mut accepted: u32 = 0; let mut rejected: u32 = 0; let mut rejection_reasons: Vec = Vec::new(); + let mut accepted_chunk_ids: Vec = Vec::new(); for chunk in &req.proposed_chunks { if chunk.rule_name.is_empty() { @@ -1301,9 +1469,7 @@ pub(super) async fn handle_submit_policy_analysis( continue; } - let chunk_id = uuid::Uuid::new_v4().to_string(); - let now_ms = - current_time_ms().map_err(|e| Status::internal(format!("timestamp error: {e}")))?; + let now_ms = current_time_ms(); let proposed_rule_bytes = chunk .proposed_rule .as_ref() @@ -1321,7 +1487,10 @@ pub(super) async fn handle_submit_policy_analysis( .unwrap_or_default(); let record = DraftChunkRecord { - id: chunk_id, + // The handler proposes an id; the store may swap it for an + // existing row's id on dedup. 
Always trust `effective_id` for + // anything user-facing. + id: uuid::Uuid::new_v4().to_string(), sandbox_id: sandbox_id.clone(), draft_version, status: "pending".to_string(), @@ -1349,13 +1518,23 @@ pub(super) async fn handle_submit_policy_analysis( } else { now_ms }, + validation_result: String::new(), + rejection_reason: String::new(), }; - state + // Mechanistic mode dedups N denials targeting the same endpoint + // into one chunk. All other modes (agent-authored proposals, future + // modes) submit each chunk as a distinct row — the redraft loop + // relies on it, and the conservative default for an unknown mode + // is to keep the proposal rather than silently fold it away. + let dedup_key = matches!(req.analysis_mode.as_str(), "mechanistic") + .then(|| crate::policy_store::observation_dedup_key(&record)); + let effective_id = state .store - .put_draft_chunk(&record) + .put_draft_chunk(&record, dedup_key.as_deref()) .await .map_err(|e| Status::internal(format!("persist draft chunk failed: {e}")))?; accepted += 1; + accepted_chunk_ids.push(effective_id); } state.sandbox_watch_bus.notify(&sandbox_id); @@ -1373,6 +1552,7 @@ pub(super) async fn handle_submit_policy_analysis( accepted_chunks: accepted, rejected_chunks: rejected, rejection_reasons, + accepted_chunk_ids, })) } @@ -1380,17 +1560,18 @@ pub(super) async fn handle_get_draft_policy( state: &Arc, request: Request, ) -> Result, Status> { + let principal = request + .extensions() + .get::() + .cloned() + .ok_or_else(|| Status::unauthenticated("missing principal"))?; let req = request.into_inner(); if req.name.is_empty() { return Err(Status::invalid_argument("name is required")); } - let sandbox = state - .store - .get_message_by_name::(&req.name) - .await - .map_err(|e| Status::internal(format!("fetch sandbox failed: {e}")))? 
- .ok_or_else(|| Status::not_found("sandbox not found"))?; + let sandbox = + resolve_sandbox_by_name_for_principal(state.store.as_ref(), &principal, &req.name).await?; let sandbox_id = sandbox.object_id().to_string(); let status_filter = if req.status_filter.is_empty() { @@ -1485,11 +1666,10 @@ pub(super) async fn handle_approve_draft_chunk( merge_chunk_into_policy(state.store.as_ref(), &sandbox_id, &chunk).await?; let chunk_summary = summarize_draft_chunk_rule(&chunk)?; - let now_ms = - current_time_ms().map_err(|e| Status::internal(format!("timestamp error: {e}")))?; + let now_ms = current_time_ms(); state .store - .update_draft_chunk_status(&req.chunk_id, "approved", Some(now_ms)) + .update_draft_chunk_status(&req.chunk_id, "approved", Some(now_ms), None) .await .map_err(|e| Status::internal(format!("update chunk status failed: {e}")))?; @@ -1585,11 +1765,18 @@ pub(super) async fn handle_reject_draft_chunk( ); } - let now_ms = - current_time_ms().map_err(|e| Status::internal(format!("timestamp error: {e}")))?; + let now_ms = current_time_ms(); + // Persist the reviewer's free-form `reason` into the chunk's + // `rejection_reason` field so the in-sandbox agent can read it back via + // GetDraftPolicy / policy.local and revise the proposal. 
+ let persisted_reason = if req.reason.is_empty() { + None + } else { + Some(req.reason.as_str()) + }; state .store - .update_draft_chunk_status(&req.chunk_id, "rejected", Some(now_ms)) + .update_draft_chunk_status(&req.chunk_id, "rejected", Some(now_ms), persisted_reason) .await .map_err(|e| Status::internal(format!("update chunk status failed: {e}")))?; @@ -1667,11 +1854,10 @@ pub(super) async fn handle_approve_all_draft_chunks( last_hash = hash; let chunk_summary = summarize_draft_chunk_rule(chunk)?; - let now_ms = - current_time_ms().map_err(|e| Status::internal(format!("timestamp error: {e}")))?; + let now_ms = current_time_ms(); state .store - .update_draft_chunk_status(&chunk.id, "approved", Some(now_ms)) + .update_draft_chunk_status(&chunk.id, "approved", Some(now_ms), None) .await .map_err(|e| Status::internal(format!("update chunk status failed: {e}")))?; @@ -1814,9 +2000,12 @@ pub(super) async fn handle_undo_draft_chunk( let (version, hash) = remove_chunk_from_policy(state, &sandbox_id, &chunk).await?; + // Clear any prior rejection_reason on the way back to "pending" so an + // agent reading the chunk via policy.local cannot see a stale guidance + // string left over from a previous reject → undo round. state .store - .update_draft_chunk_status(&req.chunk_id, "pending", None) + .update_draft_chunk_status(&req.chunk_id, "pending", None, Some("")) .await .map_err(|e| Status::internal(format!("update chunk status failed: {e}")))?; @@ -2038,6 +2227,8 @@ fn draft_chunk_record_to_proto(record: &DraftChunkRecord) -> Result String { || host.starts_with("192.168.") || host == "localhost" || host.starts_with("127.") + || host.starts_with("169.254.") { notes.push(format!( "Destination '{host}' appears to be an internal/private address." 
@@ -2573,8 +2765,11 @@ async fn load_settings_record( .await .map_err(|e| Status::internal(format!("fetch settings failed: {e}")))?; if let Some(record) = record { - serde_json::from_slice::<StoredSettings>(&record.payload) - .map_err(|e| Status::internal(format!("decode settings payload failed: {e}"))) + let mut settings = serde_json::from_slice::<StoredSettings>(&record.payload) + .map_err(|e| Status::internal(format!("decode settings payload failed: {e}")))?; + // Populate resource_version from database record for CAS + settings.resource_version = record.resource_version; + Ok(settings) } else { Ok(StoredSettings::default()) } @@ -2586,18 +2781,43 @@ async fn save_settings_record( name: &str, settings: &StoredSettings, ) -> Result<(), Status> { + use crate::persistence::WriteCondition; + let payload = serde_json::to_vec(settings) .map_err(|e| Status::internal(format!("encode settings payload failed: {e}")))?; - store - .put( - object_type, - &uuid::Uuid::new_v4().to_string(), - name, - &payload, - None, + + let (id, condition) = if settings.resource_version == 0 { + // Create new settings (resource_version 0 means never persisted) + (uuid::Uuid::new_v4().to_string(), WriteCondition::MustCreate) + } else { + // Update existing with CAS on the version from when it was loaded + // Fetch the record to get the stable ID + let existing = store + .get_by_name(object_type, name) + .await + .map_err(|e| Status::internal(format!("fetch settings for CAS failed: {e}")))? + .ok_or_else(|| Status::not_found("settings disappeared since load"))?; + + ( + existing.id, + WriteCondition::MatchResourceVersion(settings.resource_version), ) + }; + + // Single-attempt CAS write + store + .put_if(object_type, &id, name, &payload, None, condition) .await - .map_err(|e| Status::internal(format!("persist settings failed: {e}")))?; + .map_err(|e| match e { + crate::persistence::PersistenceError::Conflict { .. 
} => { + Status::aborted("settings were modified concurrently; please retry") + } + crate::persistence::PersistenceError::UniqueViolation { .. } => { + Status::aborted("settings were created concurrently; please retry") + } + other => super::persistence_error_to_status(other, "persist settings"), + })?; + Ok(()) } @@ -2694,75 +2914,169 @@ fn materialize_global_settings( #[cfg(test)] mod tests { use super::*; - use crate::ServerState; - use crate::compute::new_test_runtime; - use crate::persistence::Store; - use crate::sandbox_index::SandboxIndex; - use crate::sandbox_watch::SandboxWatchBus; - use crate::supervisor_session::SupervisorSessionRegistry; - use crate::tracing_bus::TracingLogBus; - use openshell_core::Config; + use crate::auth::identity::{Identity, IdentityProvider}; + use crate::auth::principal::{ + Principal, SandboxIdentitySource, SandboxPrincipal, UserPrincipal, + }; + use crate::grpc::test_support::test_server_state; use std::collections::HashMap; use std::sync::Arc; use tonic::Code; + async fn test_store() -> Store { + Store::connect("sqlite::memory:?cache=shared") + .await + .expect("in-memory SQLite store should connect") + } + + /// Wrap a request with a user `Principal` so handler scope guards treat + /// the test caller as a CLI user. Most handler tests exercise + /// user-facing behavior and should not trip sandbox equality checks. + fn with_user<T>(mut request: Request<T>) -> Request<T> { + request + .extensions_mut() + .insert(Principal::User(UserPrincipal { + identity: Identity { + subject: "test-user".to_string(), + display_name: None, + roles: vec![], + scopes: vec![], + provider: IdentityProvider::Oidc, + }, + })); + request + } + + /// Wrap a request with a sandbox `Principal` bound to `sandbox_id`. + /// Use for tests that exercise sandbox-caller code paths. 
+ #[allow(dead_code)] + fn with_sandbox<T>(mut request: Request<T>, sandbox_id: &str) -> Request<T> { + request + .extensions_mut() + .insert(Principal::Sandbox(SandboxPrincipal { + sandbox_id: sandbox_id.to_string(), + source: SandboxIdentitySource::BootstrapJwt { + issuer: "openshell-gateway:test".to_string(), + }, + trust_domain: Some("openshell".to_string()), + })); + request + } + #[test] - fn sandbox_secret_update_validation_allows_sandbox_policy_sync() { + fn sandbox_caller_update_validation_allows_sandbox_policy_sync() { let req = UpdateConfigRequest { name: "sandbox-1".to_string(), policy: Some(ProtoSandboxPolicy::default()), ..Default::default() }; - assert!(validate_sandbox_secret_update(&req).is_ok()); + assert!(validate_sandbox_caller_update(&req).is_ok()); } #[test] - fn sandbox_secret_update_validation_rejects_global_mutation() { + fn sandbox_caller_update_validation_rejects_global_mutation() { let req = UpdateConfigRequest { global: true, policy: Some(ProtoSandboxPolicy::default()), ..Default::default() }; - let err = validate_sandbox_secret_update(&req).unwrap_err(); + let err = validate_sandbox_caller_update(&req).unwrap_err(); assert_eq!(err.code(), Code::PermissionDenied); } #[test] - fn sandbox_secret_update_validation_rejects_setting_mutation() { + fn sandbox_caller_update_validation_rejects_setting_mutation() { let req = UpdateConfigRequest { name: "sandbox-1".to_string(), setting_key: "inference.model".to_string(), setting_value: Some(SettingValue { value: None }), ..Default::default() }; - let err = validate_sandbox_secret_update(&req).unwrap_err(); + let err = validate_sandbox_caller_update(&req).unwrap_err(); assert_eq!(err.code(), Code::PermissionDenied); } #[test] - fn sandbox_secret_marker_detected_from_metadata() { + fn sandbox_caller_detected_from_principal_extension() { + use crate::auth::principal::{Principal, SandboxIdentitySource, SandboxPrincipal}; let mut req = Request::new(()); - req.metadata_mut().insert( - 
oidc::INTERNAL_AUTH_SOURCE_HEADER, - oidc::AUTH_SOURCE_SANDBOX_SECRET.parse().unwrap(), - ); - assert!(is_sandbox_secret_authenticated(&req)); + req.extensions_mut() + .insert(Principal::Sandbox(SandboxPrincipal { + sandbox_id: "test-sandbox".to_string(), + source: SandboxIdentitySource::BootstrapJwt { + issuer: "openshell-gateway:test".to_string(), + }, + trust_domain: None, + })); + assert!(is_sandbox_caller(&req)); } - // ---- Sandbox without policy ---- + #[test] + fn user_principal_not_treated_as_sandbox_caller() { + use crate::auth::identity::{Identity, IdentityProvider}; + use crate::auth::principal::{Principal, UserPrincipal}; + let mut req = Request::new(()); + req.extensions_mut().insert(Principal::User(UserPrincipal { + identity: Identity { + subject: "alice".to_string(), + display_name: None, + roles: vec![], + scopes: vec![], + provider: IdentityProvider::Oidc, + }, + })); + assert!(!is_sandbox_caller(&req)); + } + + // ---- Sandbox IDOR guard (issue #1354) ---- #[tokio::test] - async fn sandbox_without_policy_stores_successfully() { + async fn cross_sandbox_get_sandbox_config_denied() { use openshell_core::proto::{SandboxPhase, SandboxSpec}; + let state = test_server_state().await; + // Two sandboxes; the caller is principal of A, the request body + // references B. 
+ for (id, name) in [("sb-a", "sandbox-a"), ("sb-b", "sandbox-b")] { + let sandbox = Sandbox { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: id.to_string(), + name: name.to_string(), + created_at_ms: 1_000_000, + labels: HashMap::new(), + resource_version: 0, + }), + spec: Some(SandboxSpec { + policy: None, + ..Default::default() + }), + phase: SandboxPhase::Provisioning as i32, + ..Default::default() + }; + state.store.put_message(&sandbox).await.unwrap(); + } + let req = with_sandbox( + Request::new(GetSandboxConfigRequest { + sandbox_id: "sb-b".to_string(), + }), + "sb-a", + ); + let err = handle_get_sandbox_config(&state, req) + .await + .expect_err("cross-sandbox call must be denied"); + assert_eq!(err.code(), Code::PermissionDenied); + } - let store = Store::connect("sqlite::memory:").await.unwrap(); - + #[tokio::test] + async fn same_sandbox_get_sandbox_config_allowed() { + use openshell_core::proto::{SandboxPhase, SandboxSpec}; + let state = test_server_state().await; let sandbox = Sandbox { metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { - id: "sb-no-policy".to_string(), - name: "no-policy-sandbox".to_string(), + id: "sb-self".to_string(), + name: "self".to_string(), created_at_ms: 1_000_000, - labels: std::collections::HashMap::new(), + labels: HashMap::new(), + resource_version: 0, }), spec: Some(SandboxSpec { policy: None, @@ -2771,83 +3085,308 @@ mod tests { phase: SandboxPhase::Provisioning as i32, ..Default::default() }; - store.put_message(&sandbox).await.unwrap(); - - let loaded = store - .get_message::("sb-no-policy") + state.store.put_message(&sandbox).await.unwrap(); + let req = with_sandbox( + Request::new(GetSandboxConfigRequest { + sandbox_id: "sb-self".to_string(), + }), + "sb-self", + ); + handle_get_sandbox_config(&state, req) .await - .unwrap() - .unwrap(); - assert!(loaded.spec.unwrap().policy.is_none()); + .expect("matching principal must be allowed"); } - fn test_provider(name: &str, 
provider_type: &str) -> Provider { - Provider { - metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { - id: format!("provider-{name}"), - name: name.to_string(), - created_at_ms: 1_000_000, - labels: HashMap::new(), - }), - r#type: provider_type.to_string(), - credentials: std::iter::once(("GITHUB_TOKEN".to_string(), "ghp-test".to_string())) - .collect(), - config: HashMap::new(), + #[tokio::test] + async fn cross_sandbox_submit_policy_analysis_denied() { + use openshell_core::proto::{SandboxPhase, SandboxSpec}; + let state = test_server_state().await; + for (id, name) in [("sb-a", "sandbox-a"), ("sb-b", "sandbox-b")] { + let sandbox = Sandbox { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: id.to_string(), + name: name.to_string(), + created_at_ms: 1_000_000, + labels: HashMap::new(), + resource_version: 0, + }), + spec: Some(SandboxSpec { + policy: None, + ..Default::default() + }), + phase: SandboxPhase::Provisioning as i32, + ..Default::default() + }; + state.store.put_message(&sandbox).await.unwrap(); } + let req = with_sandbox( + Request::new(SubmitPolicyAnalysisRequest { + name: "sandbox-b".to_string(), + ..Default::default() + }), + "sb-a", + ); + let err = handle_submit_policy_analysis(&state, req) + .await + .expect_err("cross-sandbox submit must be denied"); + assert_eq!(err.code(), Code::PermissionDenied); } - fn test_policy_with_rule(rule_name: &str, host: &str) -> ProtoSandboxPolicy { - ProtoSandboxPolicy { - network_policies: std::iter::once(( - rule_name.to_string(), - NetworkPolicyRule { - name: rule_name.to_string(), - endpoints: vec![NetworkEndpoint { - host: host.to_string(), - port: 443, - ..Default::default() - }], + #[tokio::test] + async fn cross_sandbox_get_draft_policy_denied() { + use openshell_core::proto::{SandboxPhase, SandboxSpec}; + let state = test_server_state().await; + for (id, name) in [("sb-a", "sandbox-a"), ("sb-b", "sandbox-b")] { + let sandbox = Sandbox { + metadata: 
Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: id.to_string(), + name: name.to_string(), + created_at_ms: 1_000_000, + labels: HashMap::new(), + resource_version: 0, + }), + spec: Some(SandboxSpec { + policy: None, ..Default::default() - }, - )) - .collect(), - ..Default::default() + }), + phase: SandboxPhase::Provisioning as i32, + ..Default::default() + }; + state.store.put_message(&sandbox).await.unwrap(); } + let req = with_sandbox( + Request::new(GetDraftPolicyRequest { + name: "sandbox-b".to_string(), + status_filter: String::new(), + }), + "sb-a", + ); + let err = handle_get_draft_policy(&state, req) + .await + .expect_err("cross-sandbox draft read must be denied"); + assert_eq!(err.code(), Code::PermissionDenied); } - fn test_sandbox( - id: &str, - name: &str, - policy: ProtoSandboxPolicy, - providers: Vec, - ) -> Sandbox { - use openshell_core::proto::{SandboxPhase, SandboxSpec}; - - Sandbox { - metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { - id: id.to_string(), - name: name.to_string(), - created_at_ms: 1_000_000, - labels: HashMap::new(), - }), - spec: Some(SandboxSpec { - policy: Some(policy), - providers, + #[tokio::test] + async fn sandbox_update_config_missing_name_returns_permission_denied() { + let state = test_server_state().await; + let req = with_sandbox( + Request::new(UpdateConfigRequest { + name: "missing-sandbox".to_string(), + policy: Some(ProtoSandboxPolicy::default()), ..Default::default() }), - phase: SandboxPhase::Ready as i32, - ..Default::default() - } + "sb-a", + ); + + let err = handle_update_config(&state, req) + .await + .expect_err("missing name must not leak existence to sandbox callers"); + assert_eq!(err.code(), Code::PermissionDenied); } - async fn enable_providers_v2(state: &Arc) { - let global_settings = StoredSettings { - revision: 1, - settings: std::iter::once(( - settings::PROVIDERS_V2_ENABLED_KEY.to_string(), - StoredSettingValue::Bool(true), - )) + #[tokio::test] + async fn 
sandbox_submit_policy_analysis_missing_name_returns_permission_denied() { + let state = test_server_state().await; + let req = with_sandbox( + Request::new(SubmitPolicyAnalysisRequest { + name: "missing-sandbox".to_string(), + ..Default::default() + }), + "sb-a", + ); + + let err = handle_submit_policy_analysis(&state, req) + .await + .expect_err("missing name must not leak existence to sandbox callers"); + assert_eq!(err.code(), Code::PermissionDenied); + } + + #[tokio::test] + async fn sandbox_get_draft_policy_missing_name_returns_permission_denied() { + let state = test_server_state().await; + let req = with_sandbox( + Request::new(GetDraftPolicyRequest { + name: "missing-sandbox".to_string(), + status_filter: String::new(), + }), + "sb-a", + ); + + let err = handle_get_draft_policy(&state, req) + .await + .expect_err("missing name must not leak existence to sandbox callers"); + assert_eq!(err.code(), Code::PermissionDenied); + } + + #[tokio::test] + async fn user_principal_can_read_any_sandbox_config() { + // RBAC was the user gate; the IDOR guard must NOT trip for users. 
+ use openshell_core::proto::{SandboxPhase, SandboxSpec}; + let state = test_server_state().await; + let sandbox = Sandbox { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: "sb-x".to_string(), + name: "x".to_string(), + created_at_ms: 1_000_000, + labels: HashMap::new(), + resource_version: 0, + }), + spec: Some(SandboxSpec { + policy: None, + ..Default::default() + }), + phase: SandboxPhase::Provisioning as i32, + ..Default::default() + }; + state.store.put_message(&sandbox).await.unwrap(); + let req = with_user(Request::new(GetSandboxConfigRequest { + sandbox_id: "sb-x".to_string(), + })); + handle_get_sandbox_config(&state, req) + .await + .expect("user principal must succeed"); + } + + #[tokio::test] + async fn log_stream_scope_rejects_sandbox_id_change_after_validation() { + let state = test_server_state().await; + for id in ["sb-a", "sb-b"] { + let sandbox = test_sandbox(id, id, ProtoSandboxPolicy::default(), vec![]); + state.store.put_message(&sandbox).await.unwrap(); + } + let req = with_sandbox(Request::new(()), "sb-a"); + let principal = req.extensions().get::<Principal>().unwrap().clone(); + let mut validated = None; + + ensure_log_stream_sandbox_scope(&state, &principal, "sb-a", &mut validated) + .await + .expect("first frame should validate"); + let err = ensure_log_stream_sandbox_scope(&state, &principal, "sb-b", &mut validated) + .await + .expect_err("later frame must not switch sandbox ids"); + + assert_eq!(err.code(), Code::PermissionDenied); + } + + #[tokio::test] + async fn log_stream_scope_rejects_missing_sandbox() { + let state = test_server_state().await; + let req = with_sandbox(Request::new(()), "sb-a"); + let principal = req.extensions().get::<Principal>().unwrap().clone(); + let mut validated = None; + + let err = ensure_log_stream_sandbox_scope(&state, &principal, "sb-a", &mut validated) + .await + .expect_err("missing sandbox must not validate"); + + assert_eq!(err.code(), Code::NotFound); + } + + // ---- 
Sandbox without policy ---- 
+ + #[tokio::test] + async fn sandbox_without_policy_stores_successfully() { + use openshell_core::proto::{SandboxPhase, SandboxSpec}; + + let store = test_store().await; + + let sandbox = Sandbox { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: "sb-no-policy".to_string(), + name: "no-policy-sandbox".to_string(), + created_at_ms: 1_000_000, + labels: std::collections::HashMap::new(), + resource_version: 0, + }), + spec: Some(SandboxSpec { + policy: None, + ..Default::default() + }), + phase: SandboxPhase::Provisioning as i32, + ..Default::default() + }; + store.put_message(&sandbox).await.unwrap(); + + let loaded = store + .get_message::<Sandbox>("sb-no-policy") + .await + .unwrap() + .unwrap(); + assert!(loaded.spec.unwrap().policy.is_none()); + } + + fn test_provider(name: &str, provider_type: &str) -> Provider { + Provider { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: format!("provider-{name}"), + name: name.to_string(), + created_at_ms: 1_000_000, + labels: HashMap::new(), + resource_version: 0, + }), + r#type: provider_type.to_string(), + credentials: std::iter::once(("GITHUB_TOKEN".to_string(), "ghp-test".to_string())) + .collect(), + config: HashMap::new(), + credential_expires_at_ms: HashMap::new(), + } + } + + fn test_policy_with_rule(rule_name: &str, host: &str) -> ProtoSandboxPolicy { + ProtoSandboxPolicy { + network_policies: std::iter::once(( + rule_name.to_string(), + NetworkPolicyRule { + name: rule_name.to_string(), + endpoints: vec![NetworkEndpoint { + host: host.to_string(), + port: 443, + ..Default::default() + }], + ..Default::default() + }, + )) + .collect(), + ..Default::default() + } + } + + fn test_sandbox( + id: &str, + name: &str, + policy: ProtoSandboxPolicy, + providers: Vec<String>, + ) -> Sandbox { + use openshell_core::proto::{SandboxPhase, SandboxSpec}; + + Sandbox { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: id.to_string(), + name: 
name.to_string(), + 
created_at_ms: 1_000_000, + labels: HashMap::new(), + resource_version: 0, + }), + spec: Some(SandboxSpec { + policy: Some(policy), + providers, + ..Default::default() + }), + phase: SandboxPhase::Ready as i32, + ..Default::default() + } + } + + async fn enable_providers_v2(state: &Arc<ServerState>) { + let global_settings = StoredSettings { + revision: 1, + settings: std::iter::once(( + settings::PROVIDERS_V2_ENABLED_KEY.to_string(), + StoredSettingValue::Bool(true), + )) .collect(), + ..Default::default() }; save_global_settings(state.store.as_ref(), &global_settings) .await @@ -2857,9 +3396,9 @@ mod tests { async fn get_sandbox_policy(state: &Arc<ServerState>, sandbox_id: &str) -> ProtoSandboxPolicy { handle_get_sandbox_config( state, - Request::new(GetSandboxConfigRequest { + with_user(Request::new(GetSandboxConfigRequest { sandbox_id: sandbox_id.to_string(), - }), + })), ) .await .unwrap() @@ -2870,7 +3409,7 @@ mod tests { #[tokio::test] async fn provider_policy_layers_skip_unknown_provider_types() { - let store = Store::connect("sqlite::memory:").await.unwrap(); + let store = test_store().await; store .put_message(&test_provider("custom-provider", "custom")) .await @@ -2885,7 +3424,7 @@ mod tests { #[tokio::test] async fn provider_policy_layers_skip_custom_profile_for_legacy_provider_type() { - let store = Store::connect("sqlite::memory:").await.unwrap(); + let store = test_store().await; store .put_message(&test_provider("custom-provider", "generic")) .await @@ -2897,6 +3436,7 @@ mod tests { name: "generic".to_string(), created_at_ms: 1_000_000, labels: HashMap::new(), + resource_version: 0, }), profile: Some(openshell_core::proto::ProviderProfile { id: "generic".to_string(), @@ -2911,6 +3451,7 @@ mod tests { }], binaries: Vec::new(), inference_capable: false, + discovery: None, }), }) .await @@ -2926,7 +3467,7 @@ mod tests { #[tokio::test] #[allow(deprecated)] async fn provider_policy_layers_include_custom_provider_profiles() { - let store = 
Store::connect("sqlite::memory:").await.unwrap(); + let store = test_store().await; store .put_message(&test_provider("work-custom", "custom-api")) .await @@ -2938,6 +3479,7 @@ mod tests { name: "custom-api".to_string(), created_at_ms: 1_000_000, labels: HashMap::new(), + resource_version: 0, }), profile: Some(openshell_core::proto::ProviderProfile { id: "custom-api".to_string(), @@ -2966,6 +3508,7 @@ mod tests { harness: true, }], inference_capable: false, + discovery: None, }), }) .await @@ -2988,7 +3531,7 @@ mod tests { #[tokio::test] async fn provider_policy_layers_normalize_custom_provider_type_ids() { - let store = Store::connect("sqlite::memory:").await.unwrap(); + let store = test_store().await; store .put_message(&test_provider("work-custom", " Custom-API ")) .await @@ -3000,6 +3543,7 @@ mod tests { name: "custom-api".to_string(), created_at_ms: 1_000_000, labels: HashMap::new(), + resource_version: 0, }), profile: Some(openshell_core::proto::ProviderProfile { id: "custom-api".to_string(), @@ -3014,6 +3558,7 @@ mod tests { }], binaries: Vec::new(), inference_capable: false, + discovery: None, }), }) .await @@ -3029,7 +3574,7 @@ mod tests { #[tokio::test] async fn provider_policy_layers_include_known_provider_profiles() { - let store = Store::connect("sqlite::memory:").await.unwrap(); + let store = test_store().await; store .put_message(&test_provider("work-github", "github")) .await @@ -3041,7 +3586,7 @@ mod tests { assert_eq!(layers.len(), 1); assert_eq!(layers[0].rule_name, "_provider_work_github"); - assert_eq!(layers[0].rule.endpoints.len(), 2); + assert_eq!(layers[0].rule.endpoints.len(), 3); assert!( layers[0] .rule @@ -3049,6 +3594,23 @@ mod tests { .iter() .any(|endpoint| endpoint.host == "api.github.com") ); + assert!( + layers[0].rule.endpoints.iter().any(|endpoint| { + endpoint.host == "api.github.com" + && endpoint.protocol == "graphql" + && endpoint.path == "/graphql" + && endpoint.access == "read-only" + }), + "github provider policy should 
include read-only GraphQL endpoint" + ); + assert!( + layers[0] + .rule + .endpoints + .iter() + .all(|endpoint| endpoint.access == "read-only"), + "github provider policy should be read-only by default" + ); } #[test] @@ -3297,9 +3859,9 @@ mod tests { let legacy_env = handle_get_sandbox_provider_environment( &state, - Request::new(GetSandboxProviderEnvironmentRequest { + with_user(Request::new(GetSandboxProviderEnvironmentRequest { sandbox_id: "sb-provider-env".to_string(), - }), + })), ) .await .unwrap() @@ -3309,9 +3871,9 @@ mod tests { enable_providers_v2(&state).await; let v2_env = handle_get_sandbox_provider_environment( &state, - Request::new(GetSandboxProviderEnvironmentRequest { + with_user(Request::new(GetSandboxProviderEnvironmentRequest { sandbox_id: "sb-provider-env".to_string(), - }), + })), ) .await .unwrap() @@ -3323,92 +3885,426 @@ mod tests { } #[tokio::test] - async fn global_policy_suppresses_provider_profile_layers_when_v2_enabled() { - use openshell_core::proto::{ - GetSandboxConfigRequest, NetworkEndpoint, NetworkPolicyRule, SandboxPhase, - SandboxPolicy, SandboxSpec, - }; + async fn provider_env_revision_changes_when_attached_provider_record_changes() { + use openshell_core::proto::GetSandboxProviderEnvironmentRequest; + use std::time::Duration; let state = test_server_state().await; + let mut provider = test_provider("work-github", "github"); + state.store.put_message(&provider).await.unwrap(); state .store - .put_message(&test_provider("work-github", "github")) + .put_message(&test_sandbox( + "sb-provider-revision", + "provider-revision", + test_policy_with_rule("sandbox_only", "sandbox.example.com"), + vec!["work-github".to_string()], + )) .await .unwrap(); - let sandbox_policy = SandboxPolicy { - network_policies: std::iter::once(( - "sandbox_only".to_string(), - NetworkPolicyRule { - name: "sandbox_only".to_string(), - endpoints: vec![NetworkEndpoint { - host: "sandbox.example.com".to_string(), - port: 443, - ..Default::default() - }], 
- ..Default::default() - }, - )) - .collect(), - ..Default::default() - }; - let sandbox = Sandbox { - metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { - id: "sb-global-profile".to_string(), - name: "global-profile-sandbox".to_string(), - created_at_ms: 1_000_000, - labels: HashMap::new(), - }), - spec: Some(SandboxSpec { - policy: Some(sandbox_policy), - providers: vec!["work-github".to_string()], - ..Default::default() - }), - phase: SandboxPhase::Ready as i32, - ..Default::default() - }; - state.store.put_message(&sandbox).await.unwrap(); + let first = handle_get_sandbox_provider_environment( + &state, + with_user(Request::new(GetSandboxProviderEnvironmentRequest { + sandbox_id: "sb-provider-revision".to_string(), + })), + ) + .await + .unwrap() + .into_inner(); - let global_policy = SandboxPolicy { - network_policies: std::iter::once(( - "global_only".to_string(), - NetworkPolicyRule { - name: "global_only".to_string(), - endpoints: vec![NetworkEndpoint { - host: "global.example.com".to_string(), - port: 443, - ..Default::default() - }], - ..Default::default() - }, - )) - .collect(), - ..Default::default() - }; - let global_settings = StoredSettings { - revision: 1, - settings: [ - ( - settings::PROVIDERS_V2_ENABLED_KEY.to_string(), - StoredSettingValue::Bool(true), - ), - ( - POLICY_SETTING_KEY.to_string(), - StoredSettingValue::Bytes(hex::encode(global_policy.encode_to_vec())), - ), - ] - .into_iter() - .collect(), - }; - save_global_settings(state.store.as_ref(), &global_settings) - .await - .unwrap(); + tokio::time::sleep(Duration::from_millis(2)).await; + provider + .credentials + .insert("GITHUB_TOKEN".to_string(), "rotated".to_string()); + state.store.put_message(&provider).await.unwrap(); - let response = handle_get_sandbox_config( + let second = handle_get_sandbox_provider_environment( &state, - Request::new(GetSandboxConfigRequest { - sandbox_id: "sb-global-profile".to_string(), - }), + 
with_user(Request::new(GetSandboxProviderEnvironmentRequest { + sandbox_id: "sb-provider-revision".to_string(), + })), + ) + .await + .unwrap() + .into_inner(); + + assert_ne!( + first.provider_env_revision, second.provider_env_revision, + "provider object updates must trigger sandbox credential refresh" + ); + assert_eq!( + second.environment.get("GITHUB_TOKEN"), + Some(&"rotated".to_string()) + ); + } + + #[tokio::test] + async fn sandbox_config_and_provider_env_follow_attached_provider_lifecycle() { + use crate::grpc::sandbox::{ + handle_attach_sandbox_provider, handle_detach_sandbox_provider, + }; + use openshell_core::proto::{ + AttachSandboxProviderRequest, DetachSandboxProviderRequest, + GetSandboxProviderEnvironmentRequest, + }; + + let state = test_server_state().await; + enable_providers_v2(&state).await; + state + .store + .put_message(&test_provider("work-github", "github")) + .await + .unwrap(); + state + .store + .put_message(&test_sandbox( + "sb-attach-lifecycle", + "attach-lifecycle", + test_policy_with_rule("sandbox_only", "sandbox.example.com"), + Vec::new(), + )) + .await + .unwrap(); + + let baseline_policy = get_sandbox_policy(&state, "sb-attach-lifecycle").await; + assert!( + !baseline_policy + .network_policies + .contains_key("_provider_work_github") + ); + let baseline_env = handle_get_sandbox_provider_environment( + &state, + with_user(Request::new(GetSandboxProviderEnvironmentRequest { + sandbox_id: "sb-attach-lifecycle".to_string(), + })), + ) + .await + .unwrap() + .into_inner(); + + handle_attach_sandbox_provider( + &state, + with_user(Request::new(AttachSandboxProviderRequest { + sandbox_name: "attach-lifecycle".to_string(), + provider_name: "work-github".to_string(), + expected_resource_version: 0, + })), + ) + .await + .unwrap(); + + let attached_policy = get_sandbox_policy(&state, "sb-attach-lifecycle").await; + assert!( + attached_policy + .network_policies + .contains_key("_provider_work_github") + ); + + let attached_env = 
handle_get_sandbox_provider_environment( + &state, + with_user(Request::new(GetSandboxProviderEnvironmentRequest { + sandbox_id: "sb-attach-lifecycle".to_string(), + })), + ) + .await + .unwrap() + .into_inner(); + assert_ne!( + baseline_env.provider_env_revision, + attached_env.provider_env_revision + ); + assert_eq!( + attached_env.environment.get("GITHUB_TOKEN"), + Some(&"ghp-test".to_string()) + ); + + handle_detach_sandbox_provider( + &state, + Request::new(DetachSandboxProviderRequest { + sandbox_name: "attach-lifecycle".to_string(), + provider_name: "work-github".to_string(), + expected_resource_version: 0, + }), + ) + .await + .unwrap(); + + let detached_policy = get_sandbox_policy(&state, "sb-attach-lifecycle").await; + assert!( + !detached_policy + .network_policies + .contains_key("_provider_work_github") + ); + + let detached_env = handle_get_sandbox_provider_environment( + &state, + with_user(Request::new(GetSandboxProviderEnvironmentRequest { + sandbox_id: "sb-attach-lifecycle".to_string(), + })), + ) + .await + .unwrap() + .into_inner(); + assert_ne!( + attached_env.provider_env_revision, + detached_env.provider_env_revision + ); + assert!(!detached_env.environment.contains_key("GITHUB_TOKEN")); + } + + #[tokio::test] + #[allow(deprecated)] + async fn custom_imported_profile_policy_and_env_follow_attach_detach_lifecycle() { + use crate::grpc::provider::handle_import_provider_profiles; + use crate::grpc::sandbox::{ + handle_attach_sandbox_provider, handle_detach_sandbox_provider, + }; + use openshell_core::proto::{ + AttachSandboxProviderRequest, DetachSandboxProviderRequest, + GetSandboxProviderEnvironmentRequest, ImportProviderProfilesRequest, NetworkBinary, + ProviderProfile, ProviderProfileCategory, ProviderProfileCredential, + ProviderProfileImportItem, + }; + + let state = test_server_state().await; + enable_providers_v2(&state).await; + handle_import_provider_profiles( + &state, + Request::new(ImportProviderProfilesRequest { + profiles: 
vec![ProviderProfileImportItem { + source: "custom-api.yaml".to_string(), + profile: Some(ProviderProfile { + id: "custom-api".to_string(), + display_name: "Custom API".to_string(), + description: String::new(), + category: ProviderProfileCategory::Other as i32, + credentials: vec![ProviderProfileCredential { + name: "api_key".to_string(), + env_vars: vec!["CUSTOM_API_KEY".to_string()], + auth_style: "bearer".to_string(), + header_name: "authorization".to_string(), + required: true, + ..Default::default() + }], + endpoints: vec![NetworkEndpoint { + host: "api.custom.example".to_string(), + port: 443, + protocol: "rest".to_string(), + rules: vec![L7Rule { + allow: Some(openshell_core::proto::L7Allow { + method: "GET".to_string(), + path: "/v1/**".to_string(), + ..Default::default() + }), + }], + ..Default::default() + }], + binaries: vec![NetworkBinary { + path: "/usr/bin/custom".to_string(), + harness: true, + }], + inference_capable: false, + discovery: None, + }), + }], + }), + ) + .await + .unwrap(); + + let mut provider = test_provider("work-custom", "custom-api"); + provider.credentials = + std::iter::once(("CUSTOM_API_KEY".to_string(), "custom-secret".to_string())).collect(); + state.store.put_message(&provider).await.unwrap(); + state + .store + .put_message(&test_sandbox( + "sb-custom-attach-lifecycle", + "custom-attach-lifecycle", + test_policy_with_rule("sandbox_only", "sandbox.example.com"), + Vec::new(), + )) + .await + .unwrap(); + + let baseline_policy = get_sandbox_policy(&state, "sb-custom-attach-lifecycle").await; + assert!( + !baseline_policy + .network_policies + .contains_key("_provider_work_custom") + ); + let baseline_env = handle_get_sandbox_provider_environment( + &state, + with_user(Request::new(GetSandboxProviderEnvironmentRequest { + sandbox_id: "sb-custom-attach-lifecycle".to_string(), + })), + ) + .await + .unwrap() + .into_inner(); + + handle_attach_sandbox_provider( + &state, + with_user(Request::new(AttachSandboxProviderRequest { + 
sandbox_name: "custom-attach-lifecycle".to_string(), + provider_name: "work-custom".to_string(), + expected_resource_version: 0, + })), + ) + .await + .unwrap(); + + let attached_policy = get_sandbox_policy(&state, "sb-custom-attach-lifecycle").await; + let custom_rule = attached_policy + .network_policies + .get("_provider_work_custom") + .expect("custom provider rule should be composed after attach"); + assert_eq!(custom_rule.endpoints[0].host, "api.custom.example"); + assert_eq!(custom_rule.endpoints[0].protocol, "rest"); + assert_eq!(custom_rule.endpoints[0].rules.len(), 1); + assert_eq!(custom_rule.binaries[0].path, "/usr/bin/custom"); + + let attached_env = handle_get_sandbox_provider_environment( + &state, + with_user(Request::new(GetSandboxProviderEnvironmentRequest { + sandbox_id: "sb-custom-attach-lifecycle".to_string(), + })), + ) + .await + .unwrap() + .into_inner(); + assert_ne!( + baseline_env.provider_env_revision, + attached_env.provider_env_revision + ); + assert_eq!( + attached_env.environment.get("CUSTOM_API_KEY"), + Some(&"custom-secret".to_string()) + ); + + handle_detach_sandbox_provider( + &state, + Request::new(DetachSandboxProviderRequest { + sandbox_name: "custom-attach-lifecycle".to_string(), + provider_name: "work-custom".to_string(), + expected_resource_version: 0, + }), + ) + .await + .unwrap(); + + let detached_policy = get_sandbox_policy(&state, "sb-custom-attach-lifecycle").await; + assert!( + !detached_policy + .network_policies + .contains_key("_provider_work_custom") + ); + let detached_env = handle_get_sandbox_provider_environment( + &state, + with_user(Request::new(GetSandboxProviderEnvironmentRequest { + sandbox_id: "sb-custom-attach-lifecycle".to_string(), + })), + ) + .await + .unwrap() + .into_inner(); + assert_ne!( + attached_env.provider_env_revision, + detached_env.provider_env_revision + ); + assert!(!detached_env.environment.contains_key("CUSTOM_API_KEY")); + } + + #[tokio::test] + async fn 
global_policy_suppresses_provider_profile_layers_when_v2_enabled() { + use openshell_core::proto::{ + GetSandboxConfigRequest, NetworkEndpoint, NetworkPolicyRule, SandboxPhase, + SandboxPolicy, SandboxSpec, + }; + + let state = test_server_state().await; + state + .store + .put_message(&test_provider("work-github", "github")) + .await + .unwrap(); + + let sandbox_policy = SandboxPolicy { + network_policies: std::iter::once(( + "sandbox_only".to_string(), + NetworkPolicyRule { + name: "sandbox_only".to_string(), + endpoints: vec![NetworkEndpoint { + host: "sandbox.example.com".to_string(), + port: 443, + ..Default::default() + }], + ..Default::default() + }, + )) + .collect(), + ..Default::default() + }; + let sandbox = Sandbox { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: "sb-global-profile".to_string(), + name: "global-profile-sandbox".to_string(), + created_at_ms: 1_000_000, + labels: HashMap::new(), + resource_version: 0, + }), + spec: Some(SandboxSpec { + policy: Some(sandbox_policy), + providers: vec!["work-github".to_string()], + ..Default::default() + }), + phase: SandboxPhase::Ready as i32, + ..Default::default() + }; + state.store.put_message(&sandbox).await.unwrap(); + + let global_policy = SandboxPolicy { + network_policies: std::iter::once(( + "global_only".to_string(), + NetworkPolicyRule { + name: "global_only".to_string(), + endpoints: vec![NetworkEndpoint { + host: "global.example.com".to_string(), + port: 443, + ..Default::default() + }], + ..Default::default() + }, + )) + .collect(), + ..Default::default() + }; + let global_settings = StoredSettings { + revision: 1, + settings: [ + ( + settings::PROVIDERS_V2_ENABLED_KEY.to_string(), + StoredSettingValue::Bool(true), + ), + ( + POLICY_SETTING_KEY.to_string(), + StoredSettingValue::Bytes(hex::encode(global_policy.encode_to_vec())), + ), + ] + .into_iter() + .collect(), + ..Default::default() + }; + save_global_settings(state.store.as_ref(), &global_settings) + .await + 
.unwrap(); + + let response = handle_get_sandbox_config( + &state, + with_user(Request::new(GetSandboxConfigRequest { + sandbox_id: "sb-global-profile".to_string(), + })), ) .await .unwrap() @@ -3437,7 +4333,7 @@ mod tests { async fn sandbox_policy_backfill_on_update_when_no_baseline() { use openshell_core::proto::{FilesystemPolicy, LandlockPolicy, SandboxPhase, SandboxSpec}; - let store = Store::connect("sqlite::memory:").await.unwrap(); + let store = test_store().await; let sandbox = Sandbox { metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { @@ -3445,6 +4341,7 @@ mod tests { name: "backfill-sandbox".to_string(), created_at_ms: 1_000_000, labels: std::collections::HashMap::new(), + resource_version: 0, }), spec: Some(SandboxSpec { policy: None, @@ -3493,27 +4390,6 @@ mod tests { assert_eq!(policy.process.unwrap().run_as_user, "sandbox"); } - async fn test_server_state() -> Arc { - let store = Arc::new( - Store::connect("sqlite::memory:?cache=shared") - .await - .unwrap(), - ); - let compute = new_test_runtime(store.clone()).await; - Arc::new(ServerState::new( - Config::new(None) - .with_database_url("sqlite::memory:?cache=shared") - .with_ssh_handshake_secret("test-secret"), - store, - compute, - SandboxIndex::new(), - SandboxWatchBus::new(), - TracingLogBus::new(), - Arc::new(SupervisorSessionRegistry::new()), - None, - )) - } - #[tokio::test] async fn draft_chunk_handler_lifecycle_round_trip() { use openshell_core::proto::{ @@ -3527,6 +4403,7 @@ mod tests { name: "draft-flow".to_string(), created_at_ms: 1_000_000, labels: std::collections::HashMap::new(), + resource_version: 0, }), spec: Some(SandboxSpec { policy: None, @@ -3553,7 +4430,7 @@ mod tests { let submit = handle_submit_policy_analysis( &state, - Request::new(SubmitPolicyAnalysisRequest { + with_user(Request::new(SubmitPolicyAnalysisRequest { name: sandbox_name.clone(), proposed_chunks: vec![PolicyChunk { rule_name: "allow_example".to_string(), @@ -3567,20 +4444,22 @@ mod tests { 
..Default::default() }], ..Default::default() - }), + })), ) .await .unwrap() .into_inner(); assert_eq!(submit.accepted_chunks, 1); assert_eq!(submit.rejected_chunks, 0); + assert_eq!(submit.accepted_chunk_ids.len(), 1); + assert!(!submit.accepted_chunk_ids[0].is_empty()); let draft_policy = handle_get_draft_policy( &state, - Request::new(GetDraftPolicyRequest { + with_user(Request::new(GetDraftPolicyRequest { name: sandbox_name.clone(), status_filter: String::new(), - }), + })), ) .await .unwrap() @@ -3647,10 +4526,10 @@ mod tests { let draft_policy_after_undo = handle_get_draft_policy( &state, - Request::new(GetDraftPolicyRequest { + with_user(Request::new(GetDraftPolicyRequest { name: sandbox_name.clone(), status_filter: String::new(), - }), + })), ) .await .unwrap() @@ -3699,10 +4578,10 @@ mod tests { let draft_policy_after_clear = handle_get_draft_policy( &state, - Request::new(GetDraftPolicyRequest { + with_user(Request::new(GetDraftPolicyRequest { name: sandbox_name.clone(), status_filter: String::new(), - }), + })), ) .await .unwrap() @@ -3711,12 +4590,435 @@ mod tests { let history_after_clear = handle_get_draft_history( &state, - Request::new(GetDraftHistoryRequest { name: sandbox_name }), + Request::new(GetDraftHistoryRequest { name: sandbox_name }), + ) + .await + .unwrap() + .into_inner(); + assert!(history_after_clear.entries.is_empty()); + } + + /// A reviewer's free-form rejection reason must round-trip through + /// persistence and surface on the chunk via `GetDraftPolicy`, so the + /// in-sandbox agent can read the guidance and redraft. The MVP-v2 agent + /// feedback loop hangs off this guarantee. 
+ #[tokio::test] + async fn reject_with_reason_persists_into_chunk_for_agent_readback() { + use openshell_core::proto::{NetworkBinary, NetworkEndpoint, SandboxPhase, SandboxSpec}; + + let state = test_server_state().await; + let sandbox_name = "agent-feedback-loop".to_string(); + let sandbox = Sandbox { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: "sb-feedback".to_string(), + name: sandbox_name.clone(), + created_at_ms: 1_000_000, + labels: std::collections::HashMap::new(), + resource_version: 0, + }), + spec: Some(SandboxSpec { + policy: None, + ..Default::default() + }), + phase: SandboxPhase::Ready as i32, + ..Default::default() + }; + state.store.put_message(&sandbox).await.unwrap(); + + let proposed_rule = NetworkPolicyRule { + name: "allow_example".to_string(), + endpoints: vec![NetworkEndpoint { + host: "api.example.com".to_string(), + port: 443, + ..Default::default() + }], + binaries: vec![NetworkBinary { + path: "/usr/bin/curl".to_string(), + ..Default::default() + }], + }; + + let submit = handle_submit_policy_analysis( + &state, + with_user(Request::new(SubmitPolicyAnalysisRequest { + name: sandbox_name.clone(), + proposed_chunks: vec![PolicyChunk { + rule_name: "allow_example".to_string(), + proposed_rule: Some(proposed_rule), + rationale: "agent intent".to_string(), + ..Default::default() + }], + ..Default::default() + })), + ) + .await + .unwrap() + .into_inner(); + let chunk_id = submit.accepted_chunk_ids[0].clone(); + + let guidance = "scope to docs/ paths only, not all repo contents"; + handle_reject_draft_chunk( + &state, + Request::new(RejectDraftChunkRequest { + name: sandbox_name.clone(), + chunk_id: chunk_id.clone(), + reason: guidance.to_string(), + }), + ) + .await + .unwrap(); + + let draft = handle_get_draft_policy( + &state, + with_user(Request::new(GetDraftPolicyRequest { + name: sandbox_name, + status_filter: String::new(), + })), + ) + .await + .unwrap() + .into_inner(); + let rejected = draft + .chunks + 
.iter() + .find(|c| c.id == chunk_id) + .expect("rejected chunk should still be visible"); + assert_eq!(rejected.status, "rejected"); + assert_eq!( + rejected.rejection_reason, guidance, + "reviewer's free-form reason must round-trip into the chunk for agent readback" + ); + // validation_result is unpopulated until the prover runs (#1097). + assert!(rejected.validation_result.is_empty()); + } + + /// Two agent-authored proposals targeting the same host/port/binary must + /// each persist as a distinct chunk. The mechanistic-mode dedup + /// (`host|port|binary`) is wrong for agent intent: the redraft loop + /// relies on the second submission landing as its own chunk so the + /// reviewer can decide on it independently. Regression test for the bug + /// where Flow B of `e2e/policy-advisor/wait-smoke.sh` saw a fresh + /// `chunk_id` returned from submit but `RejectDraftChunk` could not + /// find it because the SQL ON CONFLICT had silently kept the prior row. + #[tokio::test] + async fn agent_authored_submits_for_same_endpoint_do_not_dedup() { + use openshell_core::proto::{NetworkBinary, NetworkEndpoint, SandboxPhase, SandboxSpec}; + + let state = test_server_state().await; + let sandbox_name = "redraft-loop".to_string(); + let sandbox = Sandbox { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: "sb-redraft".to_string(), + name: sandbox_name.clone(), + created_at_ms: 1_000_000, + labels: std::collections::HashMap::new(), + resource_version: 0, + }), + spec: Some(SandboxSpec { + policy: None, + ..Default::default() + }), + phase: SandboxPhase::Ready as i32, + ..Default::default() + }; + state.store.put_message(&sandbox).await.unwrap(); + + // Two proposals with the same host|port|binary (so the mechanistic + // dedup_key would collide) but distinct rule names and L7 paths — + // proves the gateway distinguishes them by intentional act and not + // by payload hash. 
If a future dedup-by-payload-hash regression + // landed, this test would still fail because the chunk_ids would + // still need to be distinct. + let make_rule = |rule_name: &str| NetworkPolicyRule { + name: rule_name.to_string(), + endpoints: vec![NetworkEndpoint { + host: "api.example.com".to_string(), + port: 443, + ..Default::default() + }], + binaries: vec![NetworkBinary { + path: "/usr/bin/curl".to_string(), + ..Default::default() + }], + }; + + let submit_one = |rule_name: &str, rule: NetworkPolicyRule| { + let state = state.clone(); + let sandbox_name = sandbox_name.clone(); + let rule_name = rule_name.to_string(); + async move { + handle_submit_policy_analysis( + &state, + with_user(Request::new(SubmitPolicyAnalysisRequest { + name: sandbox_name, + analysis_mode: "agent_authored".to_string(), + proposed_chunks: vec![PolicyChunk { + rule_name, + proposed_rule: Some(rule), + ..Default::default() + }], + ..Default::default() + })), + ) + .await + .unwrap() + .into_inner() + } + }; + + let first = submit_one("allow_first", make_rule("allow_first")).await; + let second = submit_one("allow_second", make_rule("allow_second")).await; + + assert_eq!(first.accepted_chunk_ids.len(), 1); + assert_eq!(second.accepted_chunk_ids.len(), 1); + assert_ne!( + first.accepted_chunk_ids[0], second.accepted_chunk_ids[0], + "second agent-authored proposal for the same endpoint must get its own chunk_id, not dedup" + ); + + let draft = handle_get_draft_policy( + &state, + with_user(Request::new(GetDraftPolicyRequest { + name: sandbox_name.clone(), + status_filter: String::new(), + })), + ) + .await + .unwrap() + .into_inner(); + let ids: Vec<_> = draft.chunks.iter().map(|c| c.id.as_str()).collect(); + assert!( + ids.contains(&first.accepted_chunk_ids[0].as_str()) + && ids.contains(&second.accepted_chunk_ids[0].as_str()), + "both reported chunk_ids must be persisted; got: {ids:?}" + ); + + // Reject the second by id to prove the gateway can actually find + // what the submit 
response claimed to have created — this is the + // exact path the smoke test exercises end-to-end. + handle_reject_draft_chunk( + &state, + Request::new(RejectDraftChunkRequest { + name: sandbox_name, + chunk_id: second.accepted_chunk_ids[0].clone(), + reason: "redraft test".to_string(), + }), + ) + .await + .expect("reject must find the chunk_id the submit response just promised"); + } + + /// Complement to the agent-authored test above: mechanistic-mode + /// submissions for the same endpoint must STILL dedup. The + /// observation-driven path relies on N denials folding into one chunk + /// instead of N near-identical chunks. Lock the behavior in so a future + /// change to the dedup branch doesn't accidentally also turn off + /// mechanistic dedup. + #[tokio::test] + async fn mechanistic_submits_for_same_endpoint_dedup_into_one_chunk() { + use openshell_core::proto::{NetworkBinary, NetworkEndpoint, SandboxPhase, SandboxSpec}; + + let state = test_server_state().await; + let sandbox_name = "mechanistic-dedup".to_string(); + let sandbox = Sandbox { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: "sb-mech-dedup".to_string(), + name: sandbox_name.clone(), + created_at_ms: 1_000_000, + labels: std::collections::HashMap::new(), + resource_version: 0, + }), + spec: Some(SandboxSpec { + policy: None, + ..Default::default() + }), + phase: SandboxPhase::Ready as i32, + ..Default::default() + }; + state.store.put_message(&sandbox).await.unwrap(); + + let proposed_rule = NetworkPolicyRule { + name: "allow_example".to_string(), + endpoints: vec![NetworkEndpoint { + host: "api.example.com".to_string(), + port: 443, + ..Default::default() + }], + binaries: vec![NetworkBinary { + path: "/usr/bin/curl".to_string(), + ..Default::default() + }], + }; + let submit_one = || { + let state = state.clone(); + let sandbox_name = sandbox_name.clone(); + let rule = proposed_rule.clone(); + async move { + handle_submit_policy_analysis( + &state, + 
with_user(Request::new(SubmitPolicyAnalysisRequest { + name: sandbox_name, + analysis_mode: "mechanistic".to_string(), + proposed_chunks: vec![PolicyChunk { + rule_name: "allow_example".to_string(), + proposed_rule: Some(rule), + ..Default::default() + }], + ..Default::default() + })), + ) + .await + .unwrap() + .into_inner() + } + }; + let first = submit_one().await; + let second = submit_one().await; + assert_eq!(first.accepted_chunk_ids.len(), 1); + assert_eq!(second.accepted_chunk_ids.len(), 1); + + let draft = handle_get_draft_policy( + &state, + with_user(Request::new(GetDraftPolicyRequest { + name: sandbox_name, + status_filter: String::new(), + })), + ) + .await + .unwrap() + .into_inner(); + assert_eq!( + draft.chunks.len(), + 1, + "two mechanistic submits for the same host|port|binary must dedup; got {} chunks", + draft.chunks.len() + ); + // Both submits must report the same effective id — the id of the + // one row that actually exists in the DB. Before the dedup fix the + // second submit would return a freshly-generated UUID that was + // never persisted; this assertion locks the contract down. + let stored_id = &draft.chunks[0].id; + assert_eq!( + &first.accepted_chunk_ids[0], stored_id, + "first submit's reported id must match the stored chunk" + ); + assert_eq!( + &second.accepted_chunk_ids[0], stored_id, + "second submit must report the same id as the first (dedup fold-in), not a fresh UUID" + ); + } + + /// Undo of an approve must clear any `rejection_reason` left over from a + /// prior reject. Without this, the in-sandbox agent reading chunks via + /// `policy.local` cannot tell "pending and never rejected" from "pending + /// but previously rejected with this stale guidance." The only path that + /// lands a non-empty reason on a pending chunk is reject → re-approve → + /// undo, so the test walks that sequence. 
+ #[tokio::test] + async fn undo_after_reject_clears_stale_rejection_reason() { + use openshell_core::proto::{NetworkBinary, NetworkEndpoint, SandboxPhase, SandboxSpec}; + + let state = test_server_state().await; + let sandbox_name = "undo-clears-reason".to_string(); + let sandbox = Sandbox { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: "sb-undo-clears".to_string(), + name: sandbox_name.clone(), + created_at_ms: 1_000_000, + labels: std::collections::HashMap::new(), + resource_version: 0, + }), + spec: Some(SandboxSpec { + policy: None, + ..Default::default() + }), + phase: SandboxPhase::Ready as i32, + ..Default::default() + }; + state.store.put_message(&sandbox).await.unwrap(); + + let proposed_rule = NetworkPolicyRule { + name: "allow_example".to_string(), + endpoints: vec![NetworkEndpoint { + host: "api.example.com".to_string(), + port: 443, + ..Default::default() + }], + binaries: vec![NetworkBinary { + path: "/usr/bin/curl".to_string(), + ..Default::default() + }], + }; + + let submit = handle_submit_policy_analysis( + &state, + with_user(Request::new(SubmitPolicyAnalysisRequest { + name: sandbox_name.clone(), + proposed_chunks: vec![PolicyChunk { + rule_name: "allow_example".to_string(), + proposed_rule: Some(proposed_rule), + ..Default::default() + }], + ..Default::default() + })), + ) + .await + .unwrap() + .into_inner(); + let chunk_id = submit.accepted_chunk_ids[0].clone(); + + handle_reject_draft_chunk( + &state, + Request::new(RejectDraftChunkRequest { + name: sandbox_name.clone(), + chunk_id: chunk_id.clone(), + reason: "scope too broad".to_string(), + }), + ) + .await + .unwrap(); + + handle_approve_draft_chunk( + &state, + Request::new(ApproveDraftChunkRequest { + name: sandbox_name.clone(), + chunk_id: chunk_id.clone(), + }), + ) + .await + .unwrap(); + + handle_undo_draft_chunk( + &state, + Request::new(UndoDraftChunkRequest { + name: sandbox_name.clone(), + chunk_id: chunk_id.clone(), + }), + ) + .await + .unwrap(); + 
+ let draft = handle_get_draft_policy( + &state, + with_user(Request::new(GetDraftPolicyRequest { + name: sandbox_name, + status_filter: String::new(), + })), ) .await .unwrap() .into_inner(); - assert!(history_after_clear.entries.is_empty()); + let restored = draft + .chunks + .iter() + .find(|c| c.id == chunk_id) + .expect("chunk should still be present after undo"); + assert_eq!(restored.status, "pending"); + assert!( + restored.rejection_reason.is_empty(), + "undo must clear stale rejection_reason; got: {:?}", + restored.rejection_reason + ); } #[tokio::test] @@ -3730,6 +5032,7 @@ mod tests { name: "draft-owner".to_string(), created_at_ms: 1_000_000, labels: std::collections::HashMap::new(), + resource_version: 0, }), spec: Some(SandboxSpec { policy: None, @@ -3744,6 +5047,7 @@ mod tests { name: "draft-other".to_string(), created_at_ms: 1_000_001, labels: std::collections::HashMap::new(), + resource_version: 0, }), spec: Some(SandboxSpec { policy: None, @@ -3770,7 +5074,7 @@ mod tests { handle_submit_policy_analysis( &state, - Request::new(SubmitPolicyAnalysisRequest { + with_user(Request::new(SubmitPolicyAnalysisRequest { name: sandbox_a.object_name().to_string(), proposed_chunks: vec![PolicyChunk { rule_name: "allow_example".to_string(), @@ -3784,17 +5088,17 @@ mod tests { ..Default::default() }], ..Default::default() - }), + })), ) .await .unwrap(); let draft_policy = handle_get_draft_policy( &state, - Request::new(GetDraftPolicyRequest { + with_user(Request::new(GetDraftPolicyRequest { name: sandbox_a.object_name().to_string(), status_filter: String::new(), - }), + })), ) .await .unwrap() @@ -3927,13 +5231,69 @@ mod tests { ); } + #[test] + fn summarize_cli_policy_merge_op_formats_websocket_credential_rewrite() { + let operation = PolicyMergeOp::AddRule { + rule_name: "realtime_api".to_string(), + rule: NetworkPolicyRule { + name: "realtime_api".to_string(), + endpoints: vec![NetworkEndpoint { + host: "realtime.example.com".to_string(), + port: 443, + 
protocol: "websocket".to_string(), + access: "read-write".to_string(), + enforcement: "enforce".to_string(), + websocket_credential_rewrite: true, + ..Default::default() + }], + binaries: vec![NetworkBinary { + path: "/usr/bin/node".to_string(), + ..Default::default() + }], + }, + }; + + assert_eq!( + summarize_cli_policy_merge_op(&operation), + "add-endpoint realtime_api endpoints=[realtime.example.com:443 protocol=websocket access=read-write enforcement=enforce websocket_credential_rewrite=true] binaries=[/usr/bin/node]" + ); + } + + #[test] + fn summarize_cli_policy_merge_op_formats_request_body_credential_rewrite() { + let operation = PolicyMergeOp::AddRule { + rule_name: "slack_api".to_string(), + rule: NetworkPolicyRule { + name: "slack_api".to_string(), + endpoints: vec![NetworkEndpoint { + host: "slack.com".to_string(), + port: 443, + protocol: "rest".to_string(), + access: "read-write".to_string(), + enforcement: "enforce".to_string(), + request_body_credential_rewrite: true, + ..Default::default() + }], + binaries: vec![NetworkBinary { + path: "/usr/bin/node".to_string(), + ..Default::default() + }], + }, + }; + + assert_eq!( + summarize_cli_policy_merge_op(&operation), + "add-endpoint slack_api endpoints=[slack.com:443 protocol=rest access=read-write enforcement=enforce request_body_credential_rewrite=true] binaries=[/usr/bin/node]" + ); + } + // ---- merge_chunk_into_policy ---- #[tokio::test] async fn merge_chunk_into_policy_adds_first_network_rule_to_empty_policy() { use openshell_core::proto::{NetworkBinary, NetworkEndpoint, NetworkPolicyRule}; - let store = Store::connect("sqlite::memory:").await.unwrap(); + let store = test_store().await; let rule = NetworkPolicyRule { name: "google".to_string(), endpoints: vec![NetworkEndpoint { @@ -3964,6 +5324,8 @@ mod tests { hit_count: 1, first_seen_ms: 0, last_seen_ms: 0, + validation_result: String::new(), + rejection_reason: String::new(), }; let (version, _) = merge_chunk_into_policy(&store, 
&chunk.sandbox_id, &chunk) @@ -3994,7 +5356,7 @@ mod tests { NetworkBinary, NetworkEndpoint, NetworkPolicyRule, SandboxPolicy, }; - let store = Store::connect("sqlite::memory:").await.unwrap(); + let store = test_store().await; let sandbox_id = "sb-merge"; let initial_policy = SandboxPolicy { @@ -4058,6 +5420,8 @@ mod tests { hit_count: 1, first_seen_ms: 0, last_seen_ms: 0, + validation_result: String::new(), + rejection_reason: String::new(), }; let (version, _) = merge_chunk_into_policy(&store, sandbox_id, &chunk) @@ -4093,7 +5457,7 @@ mod tests { NetworkBinary, NetworkEndpoint, NetworkPolicyRule, SandboxPolicy, }; - let store = Store::connect("sqlite::memory:").await.unwrap(); + let store = test_store().await; let sandbox_id = "sb-new"; let initial_policy = SandboxPolicy { @@ -4157,6 +5521,8 @@ mod tests { hit_count: 1, first_seen_ms: 0, last_seen_ms: 0, + validation_result: String::new(), + rejection_reason: String::new(), }; let (version, _) = merge_chunk_into_policy(&store, sandbox_id, &chunk) @@ -4178,7 +5544,7 @@ mod tests { L7Allow, L7DenyRule, L7Rule, NetworkEndpoint, NetworkPolicyRule, SandboxPolicy, }; - let store = Store::connect("sqlite::memory:").await.unwrap(); + let store = test_store().await; let sandbox_id = "sb-concurrent-merge"; let initial_policy = SandboxPolicy { @@ -4417,6 +5783,7 @@ mod tests { revision: 1, settings: std::iter::once(("policy".to_string(), StoredSettingValue::Bytes(encoded))) .collect(), + ..Default::default() }; let decoded = decode_policy_from_global_settings(&global) @@ -4499,6 +5866,7 @@ mod tests { ] .into_iter() .collect(), + ..Default::default() }; let sandbox = StoredSettings { revision: 1, @@ -4511,6 +5879,7 @@ mod tests { ] .into_iter() .collect(), + ..Default::default() }; let merged = merge_effective_settings(&global, &sandbox).unwrap(); @@ -4540,6 +5909,7 @@ mod tests { )] .into_iter() .collect(), + ..Default::default() }; let merged = merge_effective_settings(&global, &sandbox).unwrap(); @@ -4569,6 +5939,7 @@ 
mod tests { StoredSettingValue::Bytes("deadbeef".to_string()), )) .collect(), + ..Default::default() }; let sandbox = StoredSettings { revision: 1, @@ -4577,6 +5948,7 @@ mod tests { StoredSettingValue::Bytes("cafebabe".to_string()), )) .collect(), + ..Default::default() }; let merged = merge_effective_settings(&global, &sandbox).unwrap(); @@ -4743,9 +6115,7 @@ mod tests { #[tokio::test] async fn global_settings_load_returns_default_when_empty() { - let store = Store::connect("sqlite::memory:?cache=shared") - .await - .unwrap(); + let store = test_store().await; let settings = load_global_settings(&store).await.unwrap(); assert!(settings.settings.is_empty()); assert_eq!(settings.revision, 0); @@ -4753,9 +6123,7 @@ mod tests { #[tokio::test] async fn sandbox_settings_load_returns_default_when_empty() { - let store = Store::connect("sqlite::memory:?cache=shared") - .await - .unwrap(); + let store = test_store().await; let settings = load_sandbox_settings(&store, "nonexistent").await.unwrap(); assert!(settings.settings.is_empty()); assert_eq!(settings.revision, 0); @@ -4763,9 +6131,7 @@ mod tests { #[tokio::test] async fn global_settings_save_and_load_round_trip() { - let store = Store::connect("sqlite::memory:?cache=shared") - .await - .unwrap(); + let store = test_store().await; let mut settings = StoredSettings::default(); settings.settings.insert( @@ -4792,9 +6158,7 @@ mod tests { #[tokio::test] async fn sandbox_settings_save_and_load_round_trip() { - let store = Store::connect("sqlite::memory:?cache=shared") - .await - .unwrap(); + let store = test_store().await; let sandbox_name = "my-sandbox"; let mut settings = StoredSettings::default(); @@ -4816,11 +6180,7 @@ mod tests { #[tokio::test] async fn concurrent_global_setting_mutations_are_serialized() { - let store = Arc::new( - Store::connect("sqlite::memory:?cache=shared") - .await - .unwrap(), - ); + let store = Arc::new(test_store().await); let mutex = Arc::new(tokio::sync::Mutex::new(())); let n = 50; @@ 
-4851,11 +6211,7 @@ mod tests { #[tokio::test] async fn concurrent_global_setting_mutations_without_lock_can_lose_writes() { - let store = Arc::new( - Store::connect("sqlite::memory:?cache=shared") - .await - .unwrap(), - ); + let store = Arc::new(test_store().await); let n = 50; let mut handles = Vec::with_capacity(n); @@ -4869,33 +6225,60 @@ mod tests { .settings .insert(format!("key_{i}"), StoredSettingValue::Int(i as i64)); settings.revision = settings.revision.wrapping_add(1); - save_global_settings(&store, &settings).await.unwrap(); + save_global_settings(&store, &settings).await })); } + let mut succeeded = 0; + let mut cas_conflicts = 0; for h in handles { - h.await.unwrap(); + match h.await.unwrap() { + Ok(()) => succeeded += 1, + Err(e) if e.code() == Code::Aborted => cas_conflicts += 1, + Err(e) => panic!("unexpected error: {e}"), + } } let final_settings = load_global_settings(&store).await.unwrap(); - let lost = (n as u64).saturating_sub(final_settings.revision); - if lost == 0 { - eprintln!( - "note: no lost writes detected in unlocked test (sequential scheduling); \ - the locked test is the authoritative correctness check" - ); - } else { - eprintln!("unlocked test: {lost} lost writes out of {n} (expected behavior)"); - } + + // With single-attempt CAS (no retry), concurrent modifications are properly detected: + // - All tasks read initial state (revision=0, resource_version=0) + // - First write succeeds with resource_version=1 + // - Subsequent writes fail with ABORTED (CAS conflict) because they all have stale resource_version=0 + // - Only the first write succeeds; all others are rejected + // + // This demonstrates that single-attempt CAS prevents lost writes by rejecting stale updates. + // The caller must retry from a fresh read to incorporate concurrent changes. 
+ assert!( + cas_conflicts > 0, + "most concurrent writes should fail with CAS conflict (succeeded={succeeded}, conflicts={cas_conflicts})" + ); + assert!( + succeeded < n, + "not all writes should succeed due to conflicts (succeeded={succeeded}, total={n})" + ); + assert_eq!( + final_settings.revision as usize, succeeded, + "final revision should match number of successful writes" + ); + assert_eq!( + final_settings.settings.len(), + succeeded, + "final settings should contain exactly the keys from successful writes" + ); + + eprintln!( + "unlocked CAS test: {succeeded} succeeded, {cas_conflicts} CAS conflicts, \ + final revision={} (matches succeeded count, demonstrating proper conflict detection)", + final_settings.revision + ); } // ---- Conflict guard tests ---- #[tokio::test] async fn conflict_guard_sandbox_set_blocked_when_global_exists() { - let store = Store::connect("sqlite::memory:?cache=shared") - .await - .unwrap(); + let store = test_store().await; let mut global = StoredSettings::default(); global.settings.insert( @@ -4912,9 +6295,7 @@ mod tests { #[tokio::test] async fn conflict_guard_sandbox_delete_blocked_when_global_exists() { - let store = Store::connect("sqlite::memory:?cache=shared") - .await - .unwrap(); + let store = test_store().await; let mut global = StoredSettings::default(); global @@ -4929,10 +6310,9 @@ mod tests { #[tokio::test] async fn delete_unlock_sandbox_set_succeeds_after_global_delete() { - let store = Store::connect("sqlite::memory:?cache=shared") - .await - .unwrap(); + let store = test_store().await; + // Create initial global settings let mut global = StoredSettings::default(); global.settings.insert( "log_level".to_string(), @@ -4944,6 +6324,8 @@ mod tests { let loaded = load_global_settings(&store).await.unwrap(); assert!(loaded.settings.contains_key("log_level")); + // Load fresh to get current resource_version before updating + let mut global = load_global_settings(&store).await.unwrap(); 
global.settings.remove("log_level"); global.revision = 2; save_global_settings(&store, &global).await.unwrap(); @@ -4987,4 +6369,330 @@ mod tests { assert_eq!(err.code(), Code::InvalidArgument); assert!(err.message().contains("unknown setting key")); } + + #[tokio::test] + async fn save_settings_detects_concurrent_modification() { + let store = test_store().await; + + // Create initial settings + let mut settings = StoredSettings { + revision: 1, + settings: std::iter::once(( + "initial_key".to_string(), + StoredSettingValue::String("initial_value".to_string()), + )) + .collect(), + ..Default::default() + }; + save_global_settings(&store, &settings).await.unwrap(); + + // Load settings (simulating first client read) + let loaded = load_global_settings(&store).await.unwrap(); + assert_eq!(loaded.revision, 1); + + // Simulate concurrent modification: another client updates the settings + let mut concurrent_update = loaded.clone(); + concurrent_update.settings.insert( + "concurrent_key".to_string(), + StoredSettingValue::String("concurrent_value".to_string()), + ); + concurrent_update.revision = 2; + save_global_settings(&store, &concurrent_update) + .await + .unwrap(); + + // Now attempt to save our original modification (which is based on stale revision 1) + settings.settings.insert( + "our_key".to_string(), + StoredSettingValue::String("our_value".to_string()), + ); + settings.revision = 2; // We think we're updating to revision 2 + + let result = save_global_settings(&store, &settings).await; + + // Should fail with ABORTED due to concurrent modification + assert!(result.is_err(), "save with stale revision should fail"); + let err = result.unwrap_err(); + assert_eq!( + err.code(), + Code::Aborted, + "should fail with ABORTED due to version mismatch" + ); + assert!( + err.message().contains("concurrently"), + "error should mention concurrent modification: {}", + err.message() + ); + + // Verify the database contains the concurrent update, not our stale update + let 
final_settings = load_global_settings(&store).await.unwrap(); + assert_eq!(final_settings.revision, 2); + assert!( + final_settings.settings.contains_key("concurrent_key"), + "concurrent update should be preserved" + ); + assert!( + !final_settings.settings.contains_key("our_key"), + "stale update should NOT be in database" + ); + } + + // ---- CAS (Client-driven optimistic concurrency) tests for UpdateConfig ---- + // These test the policy backfill path where spec.policy is None and UpdateConfig + // uses update_message_cas to atomically set it. + + #[tokio::test] + async fn update_config_policy_backfill_cas_succeeds_with_correct_version() { + use openshell_core::proto::{SandboxPhase, SandboxSpec}; + + let state = test_server_state().await; + + // Create a sandbox WITHOUT a policy (spec.policy = None) + // This simulates a sandbox before the supervisor has discovered and synced a policy + let sandbox = Sandbox { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: "sb-1".to_string(), + name: "test-sandbox".to_string(), + created_at_ms: 1_000_000, + labels: HashMap::new(), + resource_version: 0, + }), + spec: Some(SandboxSpec { + policy: None, // No policy yet - will be backfilled + providers: Vec::new(), + ..Default::default() + }), + phase: SandboxPhase::Provisioning as i32, + ..Default::default() + }; + state.store.put_message(&sandbox).await.unwrap(); + + // Fetch the sandbox to get its current resource_version + let current = state + .store + .get_message_by_name::("test-sandbox") + .await + .unwrap() + .unwrap(); + let current_version = current.metadata.as_ref().unwrap().resource_version; + + // Backfill the policy with correct expected_resource_version + let new_policy = ProtoSandboxPolicy::default(); + + let response = handle_update_config( + &state, + Request::new(UpdateConfigRequest { + name: "test-sandbox".to_string(), + policy: Some(new_policy), + setting_key: String::new(), + setting_value: None, + delete_setting: false, + global: 
false, + merge_operations: vec![], + expected_resource_version: current_version, + }), + ) + .await + .unwrap() + .into_inner(); + + // UpdateConfigResponse contains the policy version + assert_eq!(response.version, 1); + + // Verify the resource_version incremented and policy was backfilled + let updated_sandbox = state + .store + .get_message_by_name::("test-sandbox") + .await + .unwrap() + .unwrap(); + assert_eq!( + updated_sandbox.metadata.as_ref().unwrap().resource_version, + current_version + 1, + "resource_version should increment during CAS backfill" + ); + assert!( + updated_sandbox.spec.as_ref().unwrap().policy.is_some(), + "policy should be backfilled" + ); + } + + #[tokio::test] + async fn update_config_policy_backfill_cas_rejects_stale_version() { + use openshell_core::proto::{SandboxPhase, SandboxSpec}; + + let state = test_server_state().await; + + // Create a sandbox WITHOUT a policy + let sandbox = Sandbox { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: "sb-1".to_string(), + name: "test-sandbox".to_string(), + created_at_ms: 1_000_000, + labels: HashMap::new(), + resource_version: 0, + }), + spec: Some(SandboxSpec { + policy: None, + providers: Vec::new(), + ..Default::default() + }), + phase: SandboxPhase::Provisioning as i32, + ..Default::default() + }; + state.store.put_message(&sandbox).await.unwrap(); + + // Get current version + let current = state + .store + .get_message_by_name::("test-sandbox") + .await + .unwrap() + .unwrap(); + let current_version = current.metadata.as_ref().unwrap().resource_version; + + // Try to backfill with a stale version + let new_policy = ProtoSandboxPolicy::default(); + + let err = handle_update_config( + &state, + Request::new(UpdateConfigRequest { + name: "test-sandbox".to_string(), + policy: Some(new_policy), + setting_key: String::new(), + setting_value: None, + delete_setting: false, + global: false, + merge_operations: vec![], + expected_resource_version: 99, // stale version + 
}), + ) + .await + .unwrap_err(); + + // Should get ABORTED status for CAS conflict + assert_eq!(err.code(), Code::Aborted); + assert!( + err.message().contains("modified concurrently") + || err.message().contains("resource_version"), + "error message should mention concurrency conflict: {}", + err.message() + ); + + // Verify the sandbox was not modified (policy still None) + let unchanged = state + .store + .get_message_by_name::("test-sandbox") + .await + .unwrap() + .unwrap(); + assert_eq!( + unchanged.metadata.as_ref().unwrap().resource_version, + current_version, + "resource_version should not change when CAS fails" + ); + assert!( + unchanged.spec.as_ref().unwrap().policy.is_none(), + "policy should still be None after failed backfill" + ); + } + + #[tokio::test] + async fn update_config_policy_backfill_concurrent_with_stale_versions() { + use openshell_core::proto::{SandboxPhase, SandboxSpec}; + use std::sync::Arc; + + let state = Arc::new(test_server_state().await); + + // Create a sandbox WITHOUT a policy + let sandbox = Sandbox { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: "sb-1".to_string(), + name: "test-sandbox".to_string(), + created_at_ms: 1_000_000, + labels: HashMap::new(), + resource_version: 0, + }), + spec: Some(SandboxSpec { + policy: None, + providers: Vec::new(), + ..Default::default() + }), + phase: SandboxPhase::Provisioning as i32, + ..Default::default() + }; + state.store.put_message(&sandbox).await.unwrap(); + + // All three clients fetch the sandbox and see the same version + let initial = state + .store + .get_message_by_name::("test-sandbox") + .await + .unwrap() + .unwrap(); + let initial_version = initial.metadata.as_ref().unwrap().resource_version; + + // Launch 3 concurrent policy backfill attempts, all using the same initial version + let mut handles = vec![]; + for _i in 0..3 { + let state_clone = Arc::clone(&state); + let new_policy = ProtoSandboxPolicy::default(); + + let handle = 
tokio::spawn(async move { + handle_update_config( + &state_clone, + Request::new(UpdateConfigRequest { + name: "test-sandbox".to_string(), + policy: Some(new_policy), + setting_key: String::new(), + setting_value: None, + delete_setting: false, + global: false, + merge_operations: vec![], + expected_resource_version: initial_version, + }), + ) + .await + }); + handles.push(handle); + } + + let results: Vec<_> = futures::future::join_all(handles) + .await + .into_iter() + .map(|r| r.unwrap()) + .collect(); + + // Only one should succeed; others should get ABORTED + let successes = results.iter().filter(|r| r.is_ok()).count(); + let aborted_conflicts = results + .iter() + .filter(|r| r.as_ref().err().is_some_and(|e| e.code() == Code::Aborted)) + .count(); + + assert_eq!( + successes, 1, + "exactly one backfill should succeed with client-driven CAS" + ); + assert_eq!( + aborted_conflicts, 2, + "two backfills should fail with ABORTED due to stale version" + ); + + // Final sandbox should have resource_version = initial_version + 1 and policy backfilled + let final_sandbox = state + .store + .get_message_by_name::("test-sandbox") + .await + .unwrap() + .unwrap(); + assert_eq!( + final_sandbox.metadata.as_ref().unwrap().resource_version, + initial_version + 1 + ); + assert!( + final_sandbox.spec.as_ref().unwrap().policy.is_some(), + "policy should be backfilled after one success" + ); + } } diff --git a/crates/openshell-server/src/grpc/provider.rs b/crates/openshell-server/src/grpc/provider.rs index 2f4893073..3ddaae037 100644 --- a/crates/openshell-server/src/grpc/provider.rs +++ b/crates/openshell-server/src/grpc/provider.rs @@ -5,14 +5,18 @@ #![allow(clippy::result_large_err)] // gRPC handlers return Result, Status> -use crate::persistence::{ObjectName, ObjectType, Store, generate_name}; -use openshell_core::proto::Provider; +use crate::persistence::{ + ObjectId, ObjectLabels, ObjectName, ObjectType, Store, WriteCondition, generate_name, +}; +use 
openshell_core::proto::{Provider, Sandbox}; use prost::Message; use tonic::Status; use tracing::warn; use super::validation::validate_provider_fields; -use super::{MAX_PAGE_SIZE, clamp_limit}; +use super::{ + MAX_MAP_KEY_LEN, MAX_MAP_VALUE_LEN, MAX_PAGE_SIZE, MAX_PROVIDER_CONFIG_ENTRIES, clamp_limit, +}; // --------------------------------------------------------------------------- // CRUD helpers @@ -29,6 +33,29 @@ fn redact_provider_credentials(mut provider: Provider) -> Provider { provider } +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub(super) struct ProviderEnvironment { + pub environment: std::collections::HashMap, + pub credential_expires_at_ms: std::collections::HashMap, +} + +impl ProviderEnvironment { + #[cfg(test)] + fn is_empty(&self) -> bool { + self.environment.is_empty() + } + + #[cfg(test)] + fn get(&self, key: &str) -> Option<&String> { + self.environment.get(key) + } + + #[cfg(test)] + fn contains_key(&self, key: &str) -> bool { + self.environment.contains_key(key) + } +} + pub(super) async fn create_provider_record( store: &Store, mut provider: Provider, @@ -37,13 +64,13 @@ pub(super) async fn create_provider_record( // Initialize metadata if not present if provider.metadata.is_none() { - let now_ms = current_time_ms() - .map_err(|e| Status::internal(format!("failed to get current time: {e}")))?; + let now_ms = current_time_ms(); provider.metadata = Some(openshell_core::proto::datamodel::v1::ObjectMeta { id: uuid::Uuid::new_v4().to_string(), name: generate_name(), created_at_ms: now_ms, labels: std::collections::HashMap::new(), + resource_version: 0, }); } @@ -63,7 +90,9 @@ pub(super) async fn create_provider_record( if provider.r#type.trim().is_empty() { return Err(Status::invalid_argument("provider.type is required")); } - if provider.credentials.is_empty() { + if provider.credentials.is_empty() + && !provider_type_allows_empty_credentials_for_refresh(store, &provider.r#type).await? 
+ { return Err(Status::invalid_argument( "provider.credentials must not be empty", )); @@ -72,19 +101,38 @@ pub(super) async fn create_provider_record( // Validate field sizes before any I/O. validate_provider_fields(&provider)?; - let existing = store - .get_message_by_name::(provider.object_name()) - .await - .map_err(|e| Status::internal(format!("fetch provider failed: {e}")))?; - - if existing.is_some() { - return Err(Status::already_exists("provider already exists")); + // Generate UUID for database row and update metadata.id to match + let provider_id = uuid::Uuid::new_v4().to_string(); + let mut provider = provider; + if let Some(metadata) = provider.metadata.as_mut() { + metadata.id.clone_from(&provider_id); } - store - .put_message(&provider) + // Create with MustCreate condition to prevent duplicate creation race + let result = store + .put_if( + Provider::object_type(), + &provider_id, + provider.object_name(), + &provider.encode_to_vec(), + None, + WriteCondition::MustCreate, + ) .await - .map_err(|e| Status::internal(format!("persist provider failed: {e}")))?; + .map_err(|e| { + if matches!( + e, + crate::persistence::PersistenceError::UniqueViolation { .. 
} + ) { + Status::already_exists("provider already exists") + } else { + Status::internal(format!("persist provider failed: {e}")) + } + })?; + + if let Some(metadata) = provider.metadata.as_mut() { + metadata.resource_version = result.resource_version; + } Ok(redact_provider_credentials(provider)) } @@ -107,31 +155,31 @@ pub(super) async fn list_provider_records( limit: u32, offset: u32, ) -> Result, Status> { - let records = store - .list(Provider::object_type(), limit, offset) + let providers: Vec = store + .list_messages(limit, offset) .await .map_err(|e| Status::internal(format!("list providers failed: {e}")))?; - let mut providers = Vec::with_capacity(records.len()); - for record in records { - let provider = Provider::decode(record.payload.as_slice()) - .map_err(|e| Status::internal(format!("decode provider failed: {e}")))?; - providers.push(redact_provider_credentials(provider)); - } - - Ok(providers) + Ok(providers + .into_iter() + .map(redact_provider_credentials) + .collect()) } pub(super) async fn update_provider_record( store: &Store, provider: Provider, ) -> Result { - use crate::persistence::ObjectName; + use crate::persistence::{ObjectId, ObjectName}; if provider.object_name().is_empty() { return Err(Status::invalid_argument("provider.name is required")); } + // Extract expected version from provider metadata + let expected_resource_version = provider.metadata.as_ref().map_or(0, |m| m.resource_version); + + // Resolve provider ID from name for CAS update let existing = store .get_message_by_name::(provider.object_name()) .await @@ -150,24 +198,75 @@ pub(super) async fn update_provider_record( )); } - let updated = Provider { - metadata: existing.metadata, - r#type: existing.r#type, - credentials: merge_map(existing.credentials, provider.credentials), - config: merge_map(existing.config, provider.config), + let current_version = existing.metadata.as_ref().map_or(0, |m| m.resource_version); + + let cas_version = if expected_resource_version == 0 { + 
current_version + } else { + expected_resource_version }; - // Ensure metadata is valid (defense in depth - existing.metadata should always be valid) - super::validation::validate_object_metadata(updated.metadata.as_ref(), "provider")?; + // Apply merge to create candidate + let mut candidate = existing.clone(); + candidate.credentials = merge_map(candidate.credentials, provider.credentials); + candidate.config = merge_map(candidate.config, provider.config); + candidate.credential_expires_at_ms = merge_i64_map( + candidate.credential_expires_at_ms, + provider.credential_expires_at_ms, + ); - validate_provider_fields(&updated)?; + // Validate BEFORE writing to prevent persisting invalid state + super::validation::validate_object_metadata(candidate.metadata.as_ref(), "provider")?; + validate_provider_fields(&candidate)?; + validate_provider_update_against_attached_sandboxes(store, &candidate).await?; + + // Serialize labels for storage + let labels_map = candidate.object_labels(); + let labels_json = if labels_map + .as_ref() + .is_none_or(std::collections::HashMap::is_empty) + { + None + } else { + Some( + serde_json::to_string(&labels_map) + .map_err(|e| Status::internal(format!("serialize labels failed: {e}")))?, + ) + }; - store - .put_message(&updated) + // Write validated candidate with CAS condition + let result = store + .put_if( + Provider::object_type(), + candidate.object_id(), + candidate.object_name(), + &candidate.encode_to_vec(), + labels_json.as_deref(), + WriteCondition::MatchResourceVersion(cas_version), + ) .await - .map_err(|e| Status::internal(format!("persist provider failed: {e}")))?; + .map_err(|e| { + if matches!(e, crate::persistence::PersistenceError::Conflict { .. 
}) { + Status::aborted(format!( + "provider was modified concurrently (current resource_version: {})", + match e { + crate::persistence::PersistenceError::Conflict { + current_resource_version, + } => current_resource_version.unwrap_or(0), + _ => 0, + } + )) + } else { + Status::internal(format!("update provider failed: {e}")) + } + })?; + + // Update resource_version from successful write + if let Some(metadata) = candidate.metadata.as_mut() { + metadata.resource_version = result.resource_version; + } - Ok(redact_provider_credentials(updated)) + Ok(redact_provider_credentials(candidate)) } pub(super) async fn delete_provider_record(store: &Store, name: &str) -> Result { @@ -175,12 +274,102 @@ pub(super) async fn delete_provider_record(store: &Store, name: &str) -> Result< return Err(Status::invalid_argument("name is required")); } + let Some(provider) = store + .get_message_by_name::(name) + .await + .map_err(|e| Status::internal(format!("fetch provider failed: {e}")))? + else { + return Ok(false); + }; + + let blocking_sandboxes = sandboxes_using_provider(store, name).await?; + if !blocking_sandboxes.is_empty() { + return Err(Status::failed_precondition(format!( + "provider '{name}' is attached to sandbox(es): {}", + blocking_sandboxes.join(", ") + ))); + } + + crate::provider_refresh::delete_refresh_states_for_provider(store, provider.object_id()) + .await?; + store .delete_by_name(Provider::object_type(), name) .await .map_err(|e| Status::internal(format!("delete provider failed: {e}"))) } +/// Iterate over every `Sandbox` in the store and collect items produced by +/// `f`. `f` receives each decoded sandbox; returning `Some(T)` includes the +/// value in the output, `None` skips it. +/// +/// This is the shared pagination kernel used by all sandbox-scan helpers. 
+async fn scan_sandboxes(store: &Store, mut f: F) -> Result, Status> +where + F: FnMut(Sandbox) -> Option, +{ + let mut out = Vec::new(); + let mut offset = 0u32; + loop { + let records = store + .list(Sandbox::object_type(), 1000, offset) + .await + .map_err(|e| Status::internal(format!("list sandboxes failed: {e}")))?; + if records.is_empty() { + break; + } + offset = offset + .checked_add( + u32::try_from(records.len()) + .map_err(|_| Status::internal("sandbox page size exceeded u32"))?, + ) + .ok_or_else(|| Status::internal("sandbox pagination offset overflow"))?; + for record in records { + let sandbox = Sandbox::decode(record.payload.as_slice()) + .map_err(|e| Status::internal(format!("decode sandbox failed: {e}")))?; + if let Some(item) = f(sandbox) { + out.push(item); + } + } + } + Ok(out) +} + +async fn sandboxes_using_provider( + store: &Store, + provider_name: &str, +) -> Result, Status> { + let provider_name = provider_name.to_string(); + let mut names = scan_sandboxes(store, |sandbox| { + let spec = sandbox.spec.as_ref()?; + if spec.providers.iter().any(|n| n == &provider_name) { + Some(sandbox.object_name().to_string()) + } else { + None + } + }) + .await?; + names.sort(); + names.dedup(); + Ok(names) +} + +async fn sandboxes_using_provider_records( + store: &Store, + provider_name: &str, +) -> Result, Status> { + let provider_name = provider_name.to_string(); + scan_sandboxes(store, |sandbox| { + let spec = sandbox.spec.as_ref()?; + if spec.providers.iter().any(|n| n == &provider_name) { + Some(sandbox) + } else { + None + } + }) + .await +} + /// Merge an incoming map into an existing map. /// /// - If `incoming` is empty, return `existing` unchanged (no-op). 
@@ -203,6 +392,23 @@ fn merge_map( existing } +fn merge_i64_map( + mut existing: std::collections::HashMap, + incoming: std::collections::HashMap, +) -> std::collections::HashMap { + if incoming.is_empty() { + return existing; + } + for (key, value) in incoming { + if value <= 0 { + existing.remove(&key); + } else { + existing.insert(key, value); + } + } + existing +} + // --------------------------------------------------------------------------- // Provider environment resolution // --------------------------------------------------------------------------- @@ -211,17 +417,20 @@ fn merge_map( /// /// For each provider name in the list, fetches the provider from the store and /// collects credential key-value pairs. Returns a map of environment variables -/// to inject into the sandbox. When duplicate keys appear across providers, the -/// first provider's value wins. +/// to inject into the sandbox. Credential keys must be unique across attached +/// providers so one provider cannot silently overwrite another provider's token. 
pub(super) async fn resolve_provider_environment( store: &Store, provider_names: &[String], -) -> Result, Status> { +) -> Result { if provider_names.is_empty() { - return Ok(std::collections::HashMap::new()); + return Ok(ProviderEnvironment::default()); } let mut env = std::collections::HashMap::new(); + let mut expires = std::collections::HashMap::new(); + let now_ms = crate::persistence::current_time_ms(); + validate_provider_environment_keys_unique_at(store, provider_names, None, now_ms).await?; for name in provider_names { let provider = store @@ -232,6 +441,23 @@ pub(super) async fn resolve_provider_environment( for (key, value) in &provider.credentials { if is_valid_env_key(key) { + let expires_at_ms = provider + .credential_expires_at_ms + .get(key) + .copied() + .unwrap_or_default(); + if expires_at_ms > 0 && expires_at_ms <= now_ms { + warn!( + provider_name = %name, + key = %key, + expires_at_ms, + "skipping expired provider credential" + ); + continue; + } + if expires_at_ms > 0 { + expires.entry(key.clone()).or_insert(expires_at_ms); + } env.entry(key.clone()).or_insert_with(|| value.clone()); } else { warn!( @@ -243,7 +469,133 @@ pub(super) async fn resolve_provider_environment( } } - Ok(env) + Ok(ProviderEnvironment { + environment: env, + credential_expires_at_ms: expires, + }) +} + +pub async fn validate_provider_environment_keys_unique( + store: &Store, + provider_names: &[String], +) -> Result<(), Status> { + validate_provider_environment_keys_unique_at( + store, + provider_names, + None, + crate::persistence::current_time_ms(), + ) + .await +} + +pub async fn validate_provider_credential_key_available_for_attached_sandboxes( + store: &Store, + provider: &Provider, + credential_key: &str, +) -> Result<(), Status> { + let mut candidate = provider.clone(); + candidate + .credentials + .entry(credential_key.to_string()) + .or_insert_with(|| "pending".to_string()); + candidate.credential_expires_at_ms.remove(credential_key); + 
validate_provider_update_against_attached_sandboxes(store, &candidate).await +} + +pub async fn validate_provider_update_against_attached_sandboxes( + store: &Store, + provider: &Provider, +) -> Result<(), Status> { + let provider_name = provider.object_name().to_string(); + for sandbox in sandboxes_using_provider_records(store, &provider_name).await? { + let sandbox_name = sandbox.object_name().to_string(); + let Some(spec) = sandbox.spec.as_ref() else { + continue; + }; + validate_provider_environment_keys_unique_at( + store, + &spec.providers, + Some(provider), + crate::persistence::current_time_ms(), + ) + .await + .map_err(|err| { + Status::failed_precondition(format!( + "provider update would create credential env key conflict on sandbox '{sandbox_name}': {}", + err.message() + )) + })?; + } + Ok(()) +} + +async fn validate_provider_environment_keys_unique_at( + store: &Store, + provider_names: &[String], + candidate_provider: Option<&Provider>, + now_ms: i64, +) -> Result<(), Status> { + let mut seen = std::collections::HashMap::::new(); + for name in provider_names { + let provider = match candidate_provider { + Some(candidate) if candidate.object_name() == name.as_str() => candidate.clone(), + _ => store + .get_message_by_name::(name) + .await + .map_err(|e| Status::internal(format!("failed to fetch provider '{name}': {e}")))? + .ok_or_else(|| { + Status::failed_precondition(format!("provider '{name}' not found")) + })?, + }; + let provider_name = provider.object_name().to_string(); + for key in active_provider_environment_keys(store, &provider, now_ms).await? 
{ + if let Some(first_provider) = seen.get(&key) { + if first_provider != &provider_name { + return Err(Status::failed_precondition(format!( + "credential env key '{key}' is provided by both provider '{first_provider}' and provider '{provider_name}'; use provider-specific env names" + ))); + } + } else { + seen.insert(key, provider_name.clone()); + } + } + } + Ok(()) +} + +async fn active_provider_environment_keys( + store: &Store, + provider: &Provider, + now_ms: i64, +) -> Result, Status> { + let mut keys = active_provider_credential_keys(provider, now_ms); + if !provider.object_id().is_empty() { + keys.extend( + crate::provider_refresh::list_refresh_states_for_provider(store, provider.object_id()) + .await? + .into_iter() + .map(|state| state.credential_key) + .filter(|key| is_valid_env_key(key)), + ); + } + keys.sort(); + keys.dedup(); + Ok(keys) +} + +fn active_provider_credential_keys(provider: &Provider, now_ms: i64) -> Vec { + provider + .credentials + .keys() + .filter(|key| is_valid_env_key(key)) + .filter(|key| { + provider + .credential_expires_at_ms + .get(*key) + .is_none_or(|expires_at_ms| *expires_at_ms <= 0 || *expires_at_ms > now_ms) + }) + .cloned() + .collect() } pub(super) fn is_valid_env_key(key: &str) -> bool { @@ -273,17 +625,21 @@ impl ObjectType for Provider { use crate::ServerState; use openshell_core::proto::{ - CreateProviderRequest, DeleteProviderProfileRequest, DeleteProviderProfileResponse, - DeleteProviderRequest, DeleteProviderResponse, GetProviderProfileRequest, GetProviderRequest, - ImportProviderProfilesRequest, ImportProviderProfilesResponse, LintProviderProfilesRequest, - LintProviderProfilesResponse, ListProviderProfilesRequest, ListProviderProfilesResponse, - ListProvidersRequest, ListProvidersResponse, ProviderProfile, ProviderProfileDiagnostic, - ProviderProfileImportItem, ProviderProfileResponse, ProviderResponse, Sandbox, - StoredProviderProfile, UpdateProviderRequest, + ConfigureProviderRefreshRequest, 
ConfigureProviderRefreshResponse, CreateProviderRequest, + DeleteProviderProfileRequest, DeleteProviderProfileResponse, DeleteProviderRefreshRequest, + DeleteProviderRefreshResponse, DeleteProviderRequest, DeleteProviderResponse, + GetProviderProfileRequest, GetProviderRefreshStatusRequest, GetProviderRefreshStatusResponse, + GetProviderRequest, ImportProviderProfilesRequest, ImportProviderProfilesResponse, + LintProviderProfilesRequest, LintProviderProfilesResponse, ListProviderProfilesRequest, + ListProviderProfilesResponse, ListProvidersRequest, ListProvidersResponse, + ProviderCredentialRefreshStrategy, ProviderProfile, ProviderProfileDiagnostic, + ProviderProfileImportItem, ProviderProfileResponse, ProviderResponse, + RotateProviderCredentialRequest, RotateProviderCredentialResponse, StoredProviderProfile, + UpdateProviderRequest, }; use openshell_providers::{ - ProfileValidationDiagnostic, ProviderTypeProfile, default_profiles, get_default_profile, - normalize_profile_id, normalize_provider_type, validate_profile_set, + CredentialRefreshProfile, ProfileValidationDiagnostic, ProviderTypeProfile, default_profiles, + get_default_profile, normalize_profile_id, validate_profile_set, }; use std::sync::Arc; use tonic::{Request, Response}; @@ -390,7 +746,14 @@ pub(super) async fn handle_import_provider_profiles( let stored = stored_provider_profile(profile.to_proto()); state .store - .put_message(&stored) + .put_if( + StoredProviderProfile::object_type(), + stored.object_id(), + stored.object_name(), + &stored.encode_to_vec(), + None, + WriteCondition::MustCreate, + ) .await .map_err(|e| Status::internal(format!("persist provider profile failed: {e}")))?; imported.push(stored.profile.unwrap_or_default()); @@ -477,6 +840,72 @@ pub(super) async fn get_provider_type_profile( Ok(profile) } +async fn provider_refresh_defaults( + store: &Store, + provider: &Provider, + credential_key: &str, +) -> Result, Status> { + let Some(profile) = get_provider_type_profile(store, 
&provider.r#type).await? else { + return Ok(None); + }; + Ok(profile + .credentials + .iter() + .find(|credential| { + credential.name == credential_key + || credential + .env_vars + .iter() + .any(|env_var| env_var == credential_key) + }) + .and_then(|credential| credential.refresh.clone())) +} + +fn validate_refresh_material( + material: &std::collections::HashMap, + refresh_defaults: Option<&CredentialRefreshProfile>, +) -> Result<(), Status> { + let Some(refresh_defaults) = refresh_defaults else { + return Ok(()); + }; + for required in refresh_defaults + .material + .iter() + .filter(|item| item.required) + { + if material + .get(&required.name) + .is_none_or(|value| value.trim().is_empty()) + { + return Err(Status::invalid_argument(format!( + "{} material is required by the provider profile", + required.name + ))); + } + } + Ok(()) +} + +async fn provider_type_allows_empty_credentials_for_refresh( + store: &Store, + provider_type: &str, +) -> Result { + let Some(profile) = get_provider_type_profile(store, provider_type).await? 
else { + return Ok(false); + }; + let required_credentials = profile + .credentials + .iter() + .filter(|credential| credential.required) + .collect::>(); + Ok(!required_credentials.is_empty() + && required_credentials.iter().all(|credential| { + credential.refresh.as_ref().is_some_and(|refresh| { + crate::provider_refresh::is_gateway_mintable_strategy(refresh.strategy) + }) + })) +} + async fn merged_provider_profiles(store: &Store) -> Result, Status> { let mut profiles = default_profiles().to_vec(); profiles.extend( @@ -490,17 +919,10 @@ async fn merged_provider_profiles(store: &Store) -> Result Result, Status> { - let records = store - .list(StoredProviderProfile::object_type(), 10_000, 0) + let profiles: Vec = store + .list_messages(10_000, 0) .await .map_err(|e| Status::internal(format!("list provider profiles failed: {e}")))?; - - let mut profiles = Vec::with_capacity(records.len()); - for record in records { - let profile = StoredProviderProfile::decode(record.payload.as_slice()) - .map_err(|e| Status::internal(format!("decode provider profile failed: {e}")))?; - profiles.push(profile); - } Ok(profiles) } @@ -572,18 +994,6 @@ async fn profile_conflict_diagnostics( }); continue; } - if let Some(provider_type) = normalize_provider_type(&id) { - diagnostics.push(ProfileValidationDiagnostic { - source: source.clone(), - profile_id: id.clone(), - field: "id".to_string(), - message: format!( - "provider profile id '{id}' is reserved for legacy provider type '{provider_type}'" - ), - severity: "error".to_string(), - }); - continue; - } if store .get_message_by_name::(&id) .await @@ -604,13 +1014,14 @@ async fn profile_conflict_diagnostics( fn stored_provider_profile(profile: ProviderProfile) -> StoredProviderProfile { use crate::persistence::current_time_ms; - let now_ms = current_time_ms().unwrap_or_default(); + let now_ms = current_time_ms(); StoredProviderProfile { metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { id: 
uuid::Uuid::new_v4().to_string(), name: profile.id.clone(), created_at_ms: now_ms, labels: std::collections::HashMap::new(), + resource_version: 0, }), profile: Some(profile), } @@ -633,41 +1044,31 @@ fn has_errors(diagnostics: &[ProfileValidationDiagnostic]) -> bool { } async fn sandboxes_using_profile(store: &Store, profile_id: &str) -> Result, Status> { - let mut blocking = Vec::new(); - let mut offset = 0; - loop { - let records = store - .list(Sandbox::object_type(), 1000, offset) - .await - .map_err(|e| Status::internal(format!("list sandboxes failed: {e}")))?; - if records.is_empty() { - break; - } - offset = offset - .checked_add( - u32::try_from(records.len()) - .map_err(|_| Status::internal("sandbox page size exceeded u32"))?, - ) - .ok_or_else(|| Status::internal("sandbox pagination offset overflow"))?; + // Collect all sandboxes that reference at least one provider — pagination + // is handled by `scan_sandboxes`; the async provider lookup happens below. + let candidates = scan_sandboxes(store, |sandbox| { + let has_providers = sandbox + .spec + .as_ref() + .is_some_and(|s| !s.providers.is_empty()); + has_providers.then_some(sandbox) + }) + .await?; - for record in records { - let sandbox = Sandbox::decode(record.payload.as_slice()) - .map_err(|e| Status::internal(format!("decode sandbox failed: {e}")))?; - let Some(spec) = sandbox.spec.as_ref() else { + let mut blocking = Vec::new(); + for sandbox in candidates { + let spec = sandbox.spec.as_ref().expect("filtered by scan_sandboxes"); + for provider_name in &spec.providers { + let Some(provider) = store + .get_message_by_name::(provider_name) + .await + .map_err(|e| Status::internal(format!("fetch provider failed: {e}")))? + else { continue; }; - for provider_name in &spec.providers { - let Some(provider) = store - .get_message_by_name::(provider_name) - .await - .map_err(|e| Status::internal(format!("fetch provider failed: {e}")))? 
- else { - continue; - }; - if normalize_profile_id(&provider.r#type).as_deref() == Some(profile_id) { - blocking.push(sandbox.object_name().to_string()); - break; - } + if normalize_profile_id(&provider.r#type).as_deref() == Some(profile_id) { + blocking.push(sandbox.object_name().to_string()); + break; } } } @@ -681,9 +1082,12 @@ pub(super) async fn handle_update_provider( request: Request, ) -> Result, Status> { let req = request.into_inner(); - let provider = req + let mut provider = req .provider .ok_or_else(|| Status::invalid_argument("provider is required"))?; + provider + .credential_expires_at_ms + .extend(req.credential_expires_at_ms); let provider = update_provider_record(state.store.as_ref(), provider).await?; Ok(Response::new(ProviderResponse { @@ -691,66 +1095,398 @@ pub(super) async fn handle_update_provider( })) } -pub(super) async fn handle_delete_provider( +pub(super) async fn handle_get_provider_refresh_status( state: &Arc, - request: Request, -) -> Result, Status> { - let name = request.into_inner().name; - let deleted = delete_provider_record(state.store.as_ref(), &name).await?; - - Ok(Response::new(DeleteProviderResponse { deleted })) -} - -// --------------------------------------------------------------------------- -// Tests -// --------------------------------------------------------------------------- + request: Request, +) -> Result, Status> { + let request = request.into_inner(); + if request.provider.trim().is_empty() { + return Err(Status::invalid_argument("provider is required")); + } + let provider = state + .store + .get_message_by_name::(&request.provider) + .await + .map_err(|e| Status::internal(format!("fetch provider failed: {e}")))? 
+ .ok_or_else(|| Status::not_found("provider not found"))?; -#[cfg(test)] -mod tests { - use super::*; - use crate::ServerState; - use crate::compute::new_test_runtime; - use crate::grpc::MAX_MAP_KEY_LEN; - use crate::sandbox_index::SandboxIndex; - use crate::sandbox_watch::SandboxWatchBus; - use crate::supervisor_session::SupervisorSessionRegistry; - use crate::tracing_bus::TracingLogBus; - use openshell_core::Config; - use openshell_core::proto::{ - DeleteProviderProfileRequest, GetProviderProfileRequest, ImportProviderProfilesRequest, - L7Allow, L7Rule, LintProviderProfilesRequest, ListProviderProfilesRequest, NetworkBinary, - NetworkEndpoint, ProviderProfile, ProviderProfileCategory, ProviderProfileImportItem, - Sandbox, SandboxSpec, + let states = if request.credential_key.trim().is_empty() { + crate::provider_refresh::list_refresh_states_for_provider( + state.store.as_ref(), + provider.object_id(), + ) + .await? + } else { + crate::provider_refresh::get_refresh_state( + state.store.as_ref(), + provider.object_id(), + request.credential_key.trim(), + ) + .await? 
+ .into_iter() + .collect() }; - use openshell_core::{ObjectId, ObjectName}; - use std::collections::HashMap; - use std::sync::Arc; - use tonic::{Code, Request}; - #[test] - fn env_key_validation_accepts_valid_keys() { - assert!(is_valid_env_key("PATH")); - assert!(is_valid_env_key("PYTHONPATH")); - assert!(is_valid_env_key("_OPENSHELL_VALUE_1")); - } + Ok(Response::new(GetProviderRefreshStatusResponse { + credentials: states + .iter() + .map(crate::provider_refresh::refresh_status_from_state) + .collect(), + })) +} - #[test] - fn env_key_validation_rejects_invalid_keys() { - assert!(!is_valid_env_key("")); - assert!(!is_valid_env_key("1PATH")); - assert!(!is_valid_env_key("BAD-KEY")); - assert!(!is_valid_env_key("BAD KEY")); - assert!(!is_valid_env_key("X=Y")); - assert!(!is_valid_env_key("X;rm -rf /")); +pub(super) async fn handle_configure_provider_refresh( + state: &Arc<ServerState>, + request: Request<ConfigureProviderRefreshRequest>, +) -> Result<Response<ConfigureProviderRefreshResponse>, Status> { + let request = request.into_inner(); + let provider_name = request.provider.trim(); + let credential_key = request.credential_key.trim(); + if provider_name.is_empty() { + return Err(Status::invalid_argument("provider is required")); } - - fn provider_with_values(name: &str, provider_type: &str) -> Provider { - Provider { - metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { - id: String::new(), - name: name.to_string(), + if credential_key.is_empty() { + return Err(Status::invalid_argument("credential_key is required")); + } + if !is_valid_env_key(credential_key) { + return Err(Status::invalid_argument( + "credential_key must be a valid environment variable name", + )); + } + let strategy = ProviderCredentialRefreshStrategy::try_from(request.strategy) + .unwrap_or(ProviderCredentialRefreshStrategy::Unspecified); + if strategy == ProviderCredentialRefreshStrategy::Unspecified { + return Err(Status::invalid_argument("refresh strategy is required")); + } + if !crate::provider_refresh::is_gateway_mintable_strategy(strategy) { + return 
Err(Status::invalid_argument(format!( + "refresh strategy '{}' is not gateway-mintable; update current credentials with provider update instead", + crate::provider_refresh::refresh_strategy_name(strategy as i32) + ))); + } + if request.material.len() > MAX_PROVIDER_CONFIG_ENTRIES { + return Err(Status::invalid_argument(format!( + "material exceeds maximum entries ({} > {MAX_PROVIDER_CONFIG_ENTRIES})", + request.material.len() + ))); + } + for (key, value) in &request.material { + if key.len() > MAX_MAP_KEY_LEN { + return Err(Status::invalid_argument(format!( + "material key exceeds maximum length ({} > {MAX_MAP_KEY_LEN})", + key.len() + ))); + } + if value.len() > MAX_MAP_VALUE_LEN { + return Err(Status::invalid_argument(format!( + "material value exceeds maximum length ({} > {MAX_MAP_VALUE_LEN})", + value.len() + ))); + } + } + if request.secret_material_keys.len() > MAX_PROVIDER_CONFIG_ENTRIES { + return Err(Status::invalid_argument(format!( + "secret_material_keys exceeds maximum entries ({} > {MAX_PROVIDER_CONFIG_ENTRIES})", + request.secret_material_keys.len() + ))); + } + for key in &request.secret_material_keys { + if key.len() > MAX_MAP_KEY_LEN { + return Err(Status::invalid_argument(format!( + "secret_material_keys entry exceeds maximum length ({} > {MAX_MAP_KEY_LEN})", + key.len() + ))); + } + } + if request + .material + .get("token_url") + .is_some_and(|value| !value.trim().is_empty()) + || request + .material + .get("token_uri") + .is_some_and(|value| !value.trim().is_empty()) + { + return Err(Status::invalid_argument( + "refresh token endpoints must be defined by the provider profile, not material", + )); + } + if request + .expires_at_ms + .is_some_and(|expires_at_ms| expires_at_ms < 0) + { + return Err(Status::invalid_argument( + "expires_at_ms must be greater than or equal to 0", + )); + } + + let provider = state + .store + .get_message_by_name::(provider_name) + .await + .map_err(|e| Status::internal(format!("fetch provider failed: {e}")))? 
+ .ok_or_else(|| Status::not_found("provider not found"))?; + validate_provider_credential_key_available_for_attached_sandboxes( + state.store.as_ref(), + &provider, + credential_key, + ) + .await?; + let refresh_defaults = + provider_refresh_defaults(state.store.as_ref(), &provider, credential_key).await?; + validate_refresh_material(&request.material, refresh_defaults.as_ref())?; + let material_scopes = crate::provider_refresh::material_scopes(&request.material); + let token_url = refresh_defaults + .as_ref() + .map(|refresh| refresh.token_url.clone()) + .unwrap_or_default(); + let scopes = if material_scopes.is_empty() { + refresh_defaults + .as_ref() + .map(|refresh| refresh.scopes.clone()) + .unwrap_or_default() + } else { + material_scopes + }; + let refresh_before_seconds = + crate::provider_refresh::parse_material_i64(&request.material, "refresh_before_seconds")? + .or_else(|| { + refresh_defaults + .as_ref() + .map(|refresh| refresh.refresh_before_seconds) + }) + .unwrap_or_default(); + let max_lifetime_seconds = + crate::provider_refresh::parse_material_i64(&request.material, "max_lifetime_seconds")? 
+ .or_else(|| { + refresh_defaults + .as_ref() + .map(|refresh| refresh.max_lifetime_seconds) + }) + .unwrap_or_default(); + if refresh_before_seconds < 0 { + return Err(Status::invalid_argument( + "refresh_before_seconds material must be greater than or equal to 0", + )); + } + if max_lifetime_seconds < 0 { + return Err(Status::invalid_argument( + "max_lifetime_seconds material must be greater than or equal to 0", + )); + } + let existing_refresh_state = crate::provider_refresh::get_refresh_state( + state.store.as_ref(), + provider.object_id(), + credential_key, + ) + .await?; + let expires_at_ms = request.expires_at_ms.unwrap_or_else(|| { + existing_refresh_state + .as_ref() + .map(|state| state.expires_at_ms) + .unwrap_or_default() + }); + let mut state_record = crate::provider_refresh::new_refresh_state( + &provider, + credential_key, + crate::provider_refresh::NewRefreshStateConfig { + strategy, + material: request.material, + secret_material_keys: request.secret_material_keys, + expires_at_ms, + token_url, + scopes, + refresh_before_seconds, + max_lifetime_seconds, + }, + )?; + if let Some(existing) = existing_refresh_state { + state_record.metadata = existing.metadata; + state_record.last_refresh_at_ms = existing.last_refresh_at_ms; + } + crate::provider_refresh::put_refresh_state(state.store.as_ref(), &state_record).await?; + + if let Some(expires_at_ms) = request.expires_at_ms { + let updated = Provider { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: String::new(), + name: provider_name.to_string(), + created_at_ms: 0, + labels: std::collections::HashMap::new(), + resource_version: 0, + }), + r#type: String::new(), + credentials: std::collections::HashMap::new(), + config: std::collections::HashMap::new(), + credential_expires_at_ms: std::collections::HashMap::from([( + credential_key.to_string(), + expires_at_ms, + )]), + }; + update_provider_record(state.store.as_ref(), updated).await?; + } + + 
Ok(Response::new(ConfigureProviderRefreshResponse { + status: Some(crate::provider_refresh::refresh_status_from_state( + &state_record, + )), + })) +} + +pub(super) async fn handle_rotate_provider_credential( + state: &Arc<ServerState>, + request: Request<RotateProviderCredentialRequest>, +) -> Result<Response<RotateProviderCredentialResponse>, Status> { + let request = request.into_inner(); + let provider_name = request.provider.trim(); + let credential_key = request.credential_key.trim(); + if provider_name.is_empty() { + return Err(Status::invalid_argument("provider is required")); + } + if credential_key.is_empty() { + return Err(Status::invalid_argument("credential_key is required")); + } + let refresh_state = crate::provider_refresh::refresh_provider_credential( + state.store.as_ref(), + provider_name, + credential_key, + ) + .await?; + + Ok(Response::new(RotateProviderCredentialResponse { + status: Some(crate::provider_refresh::refresh_status_from_state( + &refresh_state, + )), + })) +} + +pub(super) async fn handle_delete_provider_refresh( + state: &Arc<ServerState>, + request: Request<DeleteProviderRefreshRequest>, +) -> Result<Response<DeleteProviderRefreshResponse>, Status> { + let request = request.into_inner(); + let provider_name = request.provider.trim(); + let credential_key = request.credential_key.trim(); + if provider_name.is_empty() { + return Err(Status::invalid_argument("provider is required")); + } + if credential_key.is_empty() { + return Err(Status::invalid_argument("credential_key is required")); + } + let provider = state + .store + .get_message_by_name::<Provider>(provider_name) + .await + .map_err(|e| Status::internal(format!("fetch provider failed: {e}")))? 
+ .ok_or_else(|| Status::not_found("provider not found"))?; + let existing_refresh_state = crate::provider_refresh::get_refresh_state( + state.store.as_ref(), + provider.object_id(), + credential_key, + ) + .await?; + let deleted_refresh_state = crate::provider_refresh::delete_refresh_state( + state.store.as_ref(), + provider.object_id(), + credential_key, + ) + .await?; + + let refresh_owned_expiry = existing_refresh_state + .as_ref() + .is_some_and(|refresh_state| { + refresh_state.expires_at_ms > 0 + && provider + .credential_expires_at_ms + .get(credential_key) + .is_some_and(|expires_at_ms| *expires_at_ms == refresh_state.expires_at_ms) + }); + if refresh_owned_expiry { + let updated = Provider { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: String::new(), + name: provider_name.to_string(), + created_at_ms: 0, + labels: std::collections::HashMap::new(), + resource_version: 0, + }), + r#type: String::new(), + credentials: std::collections::HashMap::new(), + config: std::collections::HashMap::new(), + credential_expires_at_ms: std::collections::HashMap::from([( + credential_key.to_string(), + 0, + )]), + }; + update_provider_record(state.store.as_ref(), updated).await?; + } + + Ok(Response::new(DeleteProviderRefreshResponse { + deleted: deleted_refresh_state, + })) +} + +pub(super) async fn handle_delete_provider( + state: &Arc<ServerState>, + request: Request<DeleteProviderRequest>, +) -> Result<Response<DeleteProviderResponse>, Status> { + let name = request.into_inner().name; + let deleted = delete_provider_record(state.store.as_ref(), &name).await?; + + Ok(Response::new(DeleteProviderResponse { deleted })) +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + use crate::grpc::MAX_MAP_KEY_LEN; + use crate::grpc::test_support::test_server_state; + + async fn test_store() -> Store { + Store::connect("sqlite::memory:?cache=shared") + 
.await + .expect("in-memory SQLite store should connect") + } + use openshell_core::proto::{ + DeleteProviderProfileRequest, GetProviderProfileRequest, ImportProviderProfilesRequest, + L7Allow, L7Rule, LintProviderProfilesRequest, ListProviderProfilesRequest, NetworkBinary, + NetworkEndpoint, ProviderCredentialRefresh, ProviderCredentialRefreshMaterial, + ProviderProfile, ProviderProfileCategory, ProviderProfileCredential, + ProviderProfileImportItem, Sandbox, SandboxSpec, + }; + use openshell_core::{ObjectId, ObjectName}; + use std::collections::HashMap; + use tonic::{Code, Request}; + + #[test] + fn env_key_validation_accepts_valid_keys() { + assert!(is_valid_env_key("PATH")); + assert!(is_valid_env_key("PYTHONPATH")); + assert!(is_valid_env_key("_OPENSHELL_VALUE_1")); + } + + #[test] + fn env_key_validation_rejects_invalid_keys() { + assert!(!is_valid_env_key("")); + assert!(!is_valid_env_key("1PATH")); + assert!(!is_valid_env_key("BAD-KEY")); + assert!(!is_valid_env_key("BAD KEY")); + assert!(!is_valid_env_key("X=Y")); + assert!(!is_valid_env_key("X;rm -rf /")); + } + + fn provider_with_values(name: &str, provider_type: &str) -> Provider { + Provider { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: String::new(), + name: name.to_string(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), r#type: provider_type.to_string(), credentials: [ @@ -765,6 +1501,7 @@ mod tests { ] .into_iter() .collect(), + credential_expires_at_ms: HashMap::new(), } } @@ -778,6 +1515,7 @@ mod tests { endpoints: Vec::new(), binaries: Vec::new(), inference_capable: false, + discovery: None, } } @@ -791,25 +1529,73 @@ mod tests { profile } - async fn test_server_state() -> Arc { - let store = Arc::new( - Store::connect("sqlite::memory:?cache=shared") - .await - .unwrap(), - ); - let compute = new_test_runtime(store.clone()).await; - Arc::new(ServerState::new( - Config::new(None) - .with_database_url("sqlite::memory:?cache=shared") - 
.with_ssh_handshake_secret("test-secret"), - store, - compute, - SandboxIndex::new(), - SandboxWatchBus::new(), - TracingLogBus::new(), - Arc::new(SupervisorSessionRegistry::new()), - None, - )) + fn refreshable_credential(name: &str, env_var: &str) -> ProviderProfileCredential { + ProviderProfileCredential { + name: name.to_string(), + description: String::new(), + env_vars: vec![env_var.to_string()], + required: true, + auth_style: "bearer".to_string(), + header_name: "authorization".to_string(), + query_param: String::new(), + refresh: Some(ProviderCredentialRefresh { + strategy: ProviderCredentialRefreshStrategy::Oauth2ClientCredentials as i32, + token_url: "https://auth.example.com/token".to_string(), + scopes: Vec::new(), + refresh_before_seconds: 300, + max_lifetime_seconds: 3600, + material: vec![ + ProviderCredentialRefreshMaterial { + name: "client_id".to_string(), + description: String::new(), + required: true, + secret: false, + }, + ProviderCredentialRefreshMaterial { + name: "client_secret".to_string(), + description: String::new(), + required: true, + secret: true, + }, + ], + }), + } + } + + async fn import_test_refresh_profile(state: &Arc, id: &str, credential_key: &str) { + let mut profile = custom_profile(id); + profile.category = ProviderProfileCategory::Messaging as i32; + profile.credentials = vec![refreshable_credential("access_token", credential_key)]; + handle_import_provider_profiles( + state, + Request::new(ImportProviderProfilesRequest { + profiles: vec![ProviderProfileImportItem { + profile: Some(profile), + source: format!("{id}.yaml"), + }], + }), + ) + .await + .unwrap(); + } + + const TEST_GRAPH_PROVIDER_TYPE: &str = "test-msgraph"; + + async fn import_test_graph_refresh_profile(state: &Arc) { + import_test_refresh_profile(state, TEST_GRAPH_PROVIDER_TYPE, "MS_GRAPH_ACCESS_TOKEN").await; + } + + fn static_credential(name: &str, env_var: &str, required: bool) -> ProviderProfileCredential { + ProviderProfileCredential { + name: 
name.to_string(), + description: String::new(), + env_vars: vec![env_var.to_string()], + required, + auth_style: "bearer".to_string(), + header_name: "authorization".to_string(), + query_param: String::new(), + refresh: None, + } } #[tokio::test] @@ -826,6 +1612,13 @@ mod tests { .unwrap() .into_inner(); + let ids = response + .profiles + .iter() + .map(|profile| profile.id.as_str()) + .collect::>(); + assert_eq!(ids, vec!["claude-code", "github", "nvidia"]); + let github = response .profiles .iter() @@ -835,13 +1628,6 @@ mod tests { github.category, ProviderProfileCategory::SourceControl as i32 ); - assert!( - response - .profiles - .iter() - .all(|profile| profile.id != "generic"), - "generic remains a legacy provider type without a v2 profile" - ); } #[tokio::test] @@ -951,14 +1737,14 @@ mod tests { } #[tokio::test] - async fn import_provider_profile_rejects_legacy_provider_type_ids() { + async fn import_provider_profile_allows_legacy_provider_type_ids_without_built_in_profiles() { let state = test_server_state().await; let response = handle_import_provider_profiles( &state, Request::new(ImportProviderProfilesRequest { profiles: vec![ProviderProfileImportItem { - profile: Some(custom_profile("generic")), - source: "generic.yaml".to_string(), + profile: Some(custom_profile("codex")), + source: "codex.yaml".to_string(), }], }), ) @@ -966,23 +1752,21 @@ mod tests { .unwrap() .into_inner(); - assert!(!response.imported); - assert!( - response - .diagnostics - .iter() - .any(|diagnostic| diagnostic.message.contains("reserved")) - ); + assert!(response.imported); + assert!(response.diagnostics.is_empty()); - let missing = handle_get_provider_profile( + let imported = handle_get_provider_profile( &state, Request::new(GetProviderProfileRequest { - id: "generic".to_string(), + id: "codex".to_string(), }), ) .await - .unwrap_err(); - assert_eq!(missing.code(), Code::NotFound); + .unwrap() + .into_inner() + .profile + .expect("codex profile should be returned"); + 
assert_eq!(imported.id, "codex"); } #[tokio::test] @@ -1142,6 +1926,7 @@ mod tests { harness: true, }], inference_capable: false, + discovery: None, }), source: "advanced-api.yaml".to_string(), }], @@ -1263,6 +2048,7 @@ mod tests { name: "sandbox-using-custom".to_string(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), spec: Some(SandboxSpec { providers: vec!["custom-provider".to_string()], @@ -1286,75 +2072,584 @@ mod tests { } #[tokio::test] - async fn delete_provider_profile_removes_unused_custom_profile() { + async fn configure_provider_refresh_stores_scoped_status_and_provider_expiry() { let state = test_server_state().await; - handle_import_provider_profiles( - &state, - Request::new(ImportProviderProfilesRequest { - profiles: vec![ProviderProfileImportItem { - profile: Some(custom_profile("custom-api")), - source: "custom-api.yaml".to_string(), - }], - }), + import_test_graph_refresh_profile(&state).await; + create_provider_record( + state.store.as_ref(), + Provider { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: String::new(), + name: "msgraph".to_string(), + created_at_ms: 0, + labels: HashMap::new(), + resource_version: 0, + }), + r#type: TEST_GRAPH_PROVIDER_TYPE.to_string(), + credentials: std::iter::once(( + "MS_GRAPH_ACCESS_TOKEN".to_string(), + "token".to_string(), + )) + .collect(), + config: HashMap::new(), + credential_expires_at_ms: HashMap::new(), + }, ) .await .unwrap(); - let deleted = handle_delete_provider_profile( + let expires_at_ms = crate::persistence::current_time_ms() + 60_000; + let response = handle_configure_provider_refresh( &state, - Request::new(DeleteProviderProfileRequest { - id: "custom-api".to_string(), + Request::new(ConfigureProviderRefreshRequest { + provider: "msgraph".to_string(), + credential_key: "MS_GRAPH_ACCESS_TOKEN".to_string(), + strategy: ProviderCredentialRefreshStrategy::Oauth2ClientCredentials as i32, + material: HashMap::from([ + ("tenant_id".to_string(), 
"tenant".to_string()), + ("client_id".to_string(), "client-id".to_string()), + ("client_secret".to_string(), "client-secret".to_string()), + ]), + secret_material_keys: vec!["client_secret".to_string()], + expires_at_ms: Some(expires_at_ms), }), ) .await .unwrap() - .into_inner(); - assert!(deleted.deleted); + .into_inner() + .status + .expect("status"); + assert_eq!(response.credential_key, "MS_GRAPH_ACCESS_TOKEN"); - let missing = handle_get_provider_profile( + let status = handle_get_provider_refresh_status( &state, - Request::new(GetProviderProfileRequest { - id: "custom-api".to_string(), + Request::new(GetProviderRefreshStatusRequest { + provider: "msgraph".to_string(), + credential_key: "MS_GRAPH_ACCESS_TOKEN".to_string(), }), ) .await - .unwrap_err(); - assert_eq!(missing.code(), Code::NotFound); - } - - #[tokio::test] - async fn provider_crud_round_trip_and_semantics() { - let store = Store::connect("sqlite::memory:?cache=shared") - .await - .unwrap(); + .unwrap() + .into_inner(); + assert_eq!(status.credentials.len(), 1); + assert_eq!(status.credentials[0].expires_at_ms, expires_at_ms); - let created = provider_with_values("gitlab-local", "gitlab"); - let persisted = create_provider_record(&store, created.clone()) + let provider = state + .store + .get_message_by_name::("msgraph") .await - .unwrap(); - assert_eq!(persisted.object_name(), "gitlab-local"); - assert_eq!(persisted.r#type, "gitlab"); - assert!(!persisted.object_id().is_empty()); - let provider_id = persisted.object_id().to_string(); + .unwrap() + .expect("provider"); + assert_eq!( + provider + .credential_expires_at_ms + .get("MS_GRAPH_ACCESS_TOKEN"), + Some(&expires_at_ms) + ); - let duplicate_err = create_provider_record(&store, created).await.unwrap_err(); - assert_eq!(duplicate_err.code(), Code::AlreadyExists); + let deleted = handle_delete_provider_refresh( + &state, + Request::new(DeleteProviderRefreshRequest { + provider: "msgraph".to_string(), + credential_key: 
"MS_GRAPH_ACCESS_TOKEN".to_string(), + }), + ) + .await + .unwrap() + .into_inner(); + assert!(deleted.deleted); - let loaded = get_provider_record(&store, "gitlab-local").await.unwrap(); - assert_eq!(loaded.object_id(), provider_id); + let status_after_delete = handle_get_provider_refresh_status( + &state, + Request::new(GetProviderRefreshStatusRequest { + provider: "msgraph".to_string(), + credential_key: "MS_GRAPH_ACCESS_TOKEN".to_string(), + }), + ) + .await + .unwrap() + .into_inner(); + assert!(status_after_delete.credentials.is_empty()); - let listed = list_provider_records(&store, 100, 0).await.unwrap(); - assert_eq!(listed.len(), 1); - assert_eq!(listed[0].object_name(), "gitlab-local"); + let provider_after_delete = state + .store + .get_message_by_name::("msgraph") + .await + .unwrap() + .expect("provider"); + assert!( + !provider_after_delete + .credential_expires_at_ms + .contains_key("MS_GRAPH_ACCESS_TOKEN") + ); + } - let updated = update_provider_record( - &store, + #[tokio::test] + async fn delete_provider_refresh_preserves_manually_updated_expiry() { + let state = test_server_state().await; + import_test_graph_refresh_profile(&state).await; + create_provider_record( + state.store.as_ref(), + Provider { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: String::new(), + name: "msgraph".to_string(), + created_at_ms: 0, + labels: HashMap::new(), + resource_version: 0, + }), + r#type: TEST_GRAPH_PROVIDER_TYPE.to_string(), + credentials: std::iter::once(( + "MS_GRAPH_ACCESS_TOKEN".to_string(), + "token".to_string(), + )) + .collect(), + config: HashMap::new(), + credential_expires_at_ms: HashMap::new(), + }, + ) + .await + .unwrap(); + + let refresh_expires_at_ms = crate::persistence::current_time_ms() + 60_000; + handle_configure_provider_refresh( + &state, + Request::new(ConfigureProviderRefreshRequest { + provider: "msgraph".to_string(), + credential_key: "MS_GRAPH_ACCESS_TOKEN".to_string(), + strategy: 
ProviderCredentialRefreshStrategy::Oauth2ClientCredentials as i32, + material: HashMap::from([ + ("tenant_id".to_string(), "tenant".to_string()), + ("client_id".to_string(), "client-id".to_string()), + ("client_secret".to_string(), "client-secret".to_string()), + ]), + secret_material_keys: vec!["client_secret".to_string()], + expires_at_ms: Some(refresh_expires_at_ms), + }), + ) + .await + .unwrap(); + + let manual_expires_at_ms = refresh_expires_at_ms + 60_000; + update_provider_record( + state.store.as_ref(), + Provider { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: String::new(), + name: "msgraph".to_string(), + created_at_ms: 0, + labels: HashMap::new(), + resource_version: 0, + }), + r#type: String::new(), + credentials: HashMap::new(), + config: HashMap::new(), + credential_expires_at_ms: HashMap::from([( + "MS_GRAPH_ACCESS_TOKEN".to_string(), + manual_expires_at_ms, + )]), + }, + ) + .await + .unwrap(); + + let deleted = handle_delete_provider_refresh( + &state, + Request::new(DeleteProviderRefreshRequest { + provider: "msgraph".to_string(), + credential_key: "MS_GRAPH_ACCESS_TOKEN".to_string(), + }), + ) + .await + .unwrap() + .into_inner(); + assert!(deleted.deleted); + + let provider_after_delete = state + .store + .get_message_by_name::("msgraph") + .await + .unwrap() + .expect("provider"); + assert_eq!( + provider_after_delete + .credential_expires_at_ms + .get("MS_GRAPH_ACCESS_TOKEN"), + Some(&manual_expires_at_ms) + ); + } + + #[tokio::test] + async fn configure_provider_refresh_rejects_credential_key_collision_for_attached_sandbox() { + let state = test_server_state().await; + import_test_graph_refresh_profile(&state).await; + create_provider_record( + state.store.as_ref(), + Provider { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: String::new(), + name: "existing-graph".to_string(), + created_at_ms: 0, + labels: HashMap::new(), + resource_version: 0, + }), + r#type: 
TEST_GRAPH_PROVIDER_TYPE.to_string(), + credentials: std::iter::once(( + "MS_GRAPH_ACCESS_TOKEN".to_string(), + "existing-token".to_string(), + )) + .collect(), + config: HashMap::new(), + credential_expires_at_ms: HashMap::new(), + }, + ) + .await + .unwrap(); + create_provider_record( + state.store.as_ref(), + Provider { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: String::new(), + name: "refreshing-graph".to_string(), + created_at_ms: 0, + labels: HashMap::new(), + resource_version: 0, + }), + r#type: TEST_GRAPH_PROVIDER_TYPE.to_string(), + credentials: std::iter::once(("OTHER_TOKEN".to_string(), "other".to_string())) + .collect(), + config: HashMap::new(), + credential_expires_at_ms: HashMap::new(), + }, + ) + .await + .unwrap(); + state + .store + .put_message(&Sandbox { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: "sandbox-collision".to_string(), + name: "collision".to_string(), + created_at_ms: 0, + labels: HashMap::new(), + resource_version: 0, + }), + spec: Some(SandboxSpec { + providers: vec!["existing-graph".to_string(), "refreshing-graph".to_string()], + ..SandboxSpec::default() + }), + ..Default::default() + }) + .await + .unwrap(); + + let err = handle_configure_provider_refresh( + &state, + Request::new(ConfigureProviderRefreshRequest { + provider: "refreshing-graph".to_string(), + credential_key: "MS_GRAPH_ACCESS_TOKEN".to_string(), + strategy: ProviderCredentialRefreshStrategy::Oauth2ClientCredentials as i32, + material: HashMap::from([ + ("tenant_id".to_string(), "tenant".to_string()), + ("client_id".to_string(), "client-id".to_string()), + ("client_secret".to_string(), "client-secret".to_string()), + ]), + secret_material_keys: vec!["client_secret".to_string()], + expires_at_ms: None, + }), + ) + .await + .unwrap_err(); + + assert_eq!(err.code(), Code::FailedPrecondition); + assert!(err.message().contains("collision")); + assert!(err.message().contains("MS_GRAPH_ACCESS_TOKEN")); + let 
states = crate::provider_refresh::list_all_refresh_states(state.store.as_ref()) + .await + .unwrap(); + assert!(states.is_empty()); + } + + #[tokio::test] + async fn configure_provider_refresh_treats_existing_refresh_state_keys_as_reserved() { + let state = test_server_state().await; + import_test_graph_refresh_profile(&state).await; + for name in ["first-graph", "second-graph"] { + create_provider_record( + state.store.as_ref(), + Provider { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: String::new(), + name: name.to_string(), + created_at_ms: 0, + labels: HashMap::new(), + resource_version: 0, + }), + r#type: TEST_GRAPH_PROVIDER_TYPE.to_string(), + credentials: HashMap::new(), + config: HashMap::new(), + credential_expires_at_ms: HashMap::new(), + }, + ) + .await + .unwrap(); + } + state + .store + .put_message(&Sandbox { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: "sandbox-refresh-collision".to_string(), + name: "refresh-collision".to_string(), + created_at_ms: 0, + labels: HashMap::new(), + resource_version: 0, + }), + spec: Some(SandboxSpec { + providers: vec!["first-graph".to_string(), "second-graph".to_string()], + ..SandboxSpec::default() + }), + ..Default::default() + }) + .await + .unwrap(); + + handle_configure_provider_refresh( + &state, + Request::new(ConfigureProviderRefreshRequest { + provider: "first-graph".to_string(), + credential_key: "MS_GRAPH_ACCESS_TOKEN".to_string(), + strategy: ProviderCredentialRefreshStrategy::Oauth2ClientCredentials as i32, + material: HashMap::from([ + ("tenant_id".to_string(), "tenant".to_string()), + ("client_id".to_string(), "client-id".to_string()), + ("client_secret".to_string(), "client-secret".to_string()), + ]), + secret_material_keys: vec!["client_secret".to_string()], + expires_at_ms: None, + }), + ) + .await + .unwrap(); + + let err = handle_configure_provider_refresh( + &state, + Request::new(ConfigureProviderRefreshRequest { + provider: 
"second-graph".to_string(), + credential_key: "MS_GRAPH_ACCESS_TOKEN".to_string(), + strategy: ProviderCredentialRefreshStrategy::Oauth2ClientCredentials as i32, + material: HashMap::from([ + ("tenant_id".to_string(), "tenant".to_string()), + ("client_id".to_string(), "client-id".to_string()), + ("client_secret".to_string(), "client-secret".to_string()), + ]), + secret_material_keys: vec!["client_secret".to_string()], + expires_at_ms: None, + }), + ) + .await + .unwrap_err(); + + assert_eq!(err.code(), Code::FailedPrecondition); + assert!(err.message().contains("collision")); + assert!(err.message().contains("MS_GRAPH_ACCESS_TOKEN")); + assert!(err.message().contains("first-graph")); + assert!(err.message().contains("second-graph")); + } + + #[tokio::test] + async fn configure_provider_refresh_rejects_profile_endpoint_override_and_missing_material() { + let state = test_server_state().await; + import_test_graph_refresh_profile(&state).await; + create_provider_record( + state.store.as_ref(), + Provider { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: String::new(), + name: "msgraph".to_string(), + created_at_ms: 0, + labels: HashMap::new(), + resource_version: 0, + }), + r#type: TEST_GRAPH_PROVIDER_TYPE.to_string(), + credentials: std::iter::once(( + "MS_GRAPH_ACCESS_TOKEN".to_string(), + "token".to_string(), + )) + .collect(), + config: HashMap::new(), + credential_expires_at_ms: HashMap::new(), + }, + ) + .await + .unwrap(); + + let endpoint_override = handle_configure_provider_refresh( + &state, + Request::new(ConfigureProviderRefreshRequest { + provider: "msgraph".to_string(), + credential_key: "MS_GRAPH_ACCESS_TOKEN".to_string(), + strategy: ProviderCredentialRefreshStrategy::Oauth2ClientCredentials as i32, + material: HashMap::from([ + ("tenant_id".to_string(), "tenant".to_string()), + ("client_id".to_string(), "client-id".to_string()), + ("client_secret".to_string(), "client-secret".to_string()), + ( + "token_url".to_string(), + 
"https://attacker.example/token".to_string(), + ), + ]), + secret_material_keys: vec!["client_secret".to_string()], + expires_at_ms: None, + }), + ) + .await + .unwrap_err(); + assert_eq!(endpoint_override.code(), Code::InvalidArgument); + assert!(endpoint_override.message().contains("provider profile")); + + let missing_material = handle_configure_provider_refresh( + &state, + Request::new(ConfigureProviderRefreshRequest { + provider: "msgraph".to_string(), + credential_key: "MS_GRAPH_ACCESS_TOKEN".to_string(), + strategy: ProviderCredentialRefreshStrategy::Oauth2ClientCredentials as i32, + material: HashMap::from([("tenant_id".to_string(), "tenant".to_string())]), + secret_material_keys: vec!["client_secret".to_string()], + expires_at_ms: None, + }), + ) + .await + .unwrap_err(); + assert_eq!(missing_material.code(), Code::InvalidArgument); + assert!(missing_material.message().contains("client_id material")); + } + + #[tokio::test] + async fn configure_provider_refresh_rejects_non_gateway_mintable_strategies() { + let state = test_server_state().await; + create_provider_record( + state.store.as_ref(), + Provider { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: String::new(), + name: "msgraph".to_string(), + created_at_ms: 0, + labels: HashMap::new(), + resource_version: 0, + }), + r#type: "outlook".to_string(), + credentials: std::iter::once(( + "MS_GRAPH_ACCESS_TOKEN".to_string(), + "token".to_string(), + )) + .collect(), + config: HashMap::new(), + credential_expires_at_ms: HashMap::new(), + }, + ) + .await + .unwrap(); + + for strategy in [ + ProviderCredentialRefreshStrategy::Static, + ProviderCredentialRefreshStrategy::External, + ] { + let err = handle_configure_provider_refresh( + &state, + Request::new(ConfigureProviderRefreshRequest { + provider: "msgraph".to_string(), + credential_key: "MS_GRAPH_ACCESS_TOKEN".to_string(), + strategy: strategy as i32, + material: HashMap::new(), + secret_material_keys: Vec::new(), + 
expires_at_ms: None, + }), + ) + .await + .unwrap_err(); + assert_eq!(err.code(), Code::InvalidArgument); + assert!( + err.message().contains("not gateway-mintable"), + "unexpected error: {}", + err.message() + ); + } + + let refresh_states = crate::provider_refresh::list_all_refresh_states(state.store.as_ref()) + .await + .unwrap(); + assert!(refresh_states.is_empty()); + } + + #[tokio::test] + async fn delete_provider_profile_removes_unused_custom_profile() { + let state = test_server_state().await; + handle_import_provider_profiles( + &state, + Request::new(ImportProviderProfilesRequest { + profiles: vec![ProviderProfileImportItem { + profile: Some(custom_profile("custom-api")), + source: "custom-api.yaml".to_string(), + }], + }), + ) + .await + .unwrap(); + + let deleted = handle_delete_provider_profile( + &state, + Request::new(DeleteProviderProfileRequest { + id: "custom-api".to_string(), + }), + ) + .await + .unwrap() + .into_inner(); + assert!(deleted.deleted); + + let missing = handle_get_provider_profile( + &state, + Request::new(GetProviderProfileRequest { + id: "custom-api".to_string(), + }), + ) + .await + .unwrap_err(); + assert_eq!(missing.code(), Code::NotFound); + } + + #[tokio::test] + async fn provider_crud_round_trip_and_semantics() { + let store = test_store().await; + + let created = provider_with_values("gitlab-local", "gitlab"); + let persisted = create_provider_record(&store, created.clone()) + .await + .unwrap(); + assert_eq!(persisted.object_name(), "gitlab-local"); + assert_eq!(persisted.r#type, "gitlab"); + assert!(!persisted.object_id().is_empty()); + let provider_id = persisted.object_id().to_string(); + + let duplicate_err = create_provider_record(&store, created).await.unwrap_err(); + assert_eq!(duplicate_err.code(), Code::AlreadyExists); + + let loaded = get_provider_record(&store, "gitlab-local").await.unwrap(); + assert_eq!(loaded.object_id(), provider_id); + + let listed = list_provider_records(&store, 100, 0).await.unwrap(); + 
assert_eq!(listed.len(), 1); + assert_eq!(listed[0].object_name(), "gitlab-local"); + + let updated = update_provider_record( + &store, Provider { metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { id: String::new(), name: "gitlab-local".to_string(), created_at_ms: 1_000_000, labels: HashMap::new(), + resource_version: 0, }), r#type: "gitlab".to_string(), credentials: std::iter::once(( @@ -1364,6 +2659,7 @@ mod tests { .collect(), config: std::iter::once(("endpoint".to_string(), "https://gitlab.com".to_string())) .collect(), + credential_expires_at_ms: HashMap::new(), }, ) .await @@ -1415,47 +2711,368 @@ mod tests { } #[tokio::test] - async fn provider_validation_errors() { - let store = Store::connect("sqlite::memory:?cache=shared") + async fn delete_provider_removes_scoped_refresh_states() { + let store = test_store().await; + + let provider = create_provider_record( + &store, + Provider { + credential_expires_at_ms: HashMap::from([("API_TOKEN".to_string(), 123_456)]), + ..provider_with_values("gitlab-local", "gitlab") + }, + ) + .await + .unwrap(); + let refresh_state = crate::provider_refresh::new_refresh_state( + &provider, + "API_TOKEN", + crate::provider_refresh::NewRefreshStateConfig { + strategy: ProviderCredentialRefreshStrategy::External, + material: HashMap::from([( + "endpoint".to_string(), + "https://refresh.example.com".to_string(), + )]), + secret_material_keys: vec!["client_secret".to_string()], + expires_at_ms: 123_456, + token_url: "https://refresh.example.com/token".to_string(), + scopes: Vec::new(), + refresh_before_seconds: 300, + max_lifetime_seconds: 3600, + }, + ) + .unwrap(); + crate::provider_refresh::put_refresh_state(&store, &refresh_state) .await .unwrap(); - let create_missing_type = create_provider_record( + let deleted = delete_provider_record(&store, "gitlab-local") + .await + .unwrap(); + assert!(deleted); + + let refresh_states = + crate::provider_refresh::list_refresh_states_for_provider(&store, 
provider.object_id()) + .await + .unwrap(); + assert!(refresh_states.is_empty()); + } + + #[tokio::test] + async fn delete_provider_rejects_attached_provider() { + let store = test_store().await; + + create_provider_record(&store, provider_with_values("gitlab-local", "gitlab")) + .await + .unwrap(); + store + .put_message(&Sandbox { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: "sandbox-id".to_string(), + name: "attached-sandbox".to_string(), + created_at_ms: 0, + labels: HashMap::new(), + resource_version: 0, + }), + spec: Some(SandboxSpec { + providers: vec!["gitlab-local".to_string()], + ..Default::default() + }), + ..Default::default() + }) + .await + .unwrap(); + + let err = delete_provider_record(&store, "gitlab-local") + .await + .unwrap_err(); + assert_eq!(err.code(), Code::FailedPrecondition); + assert!( + err.message().contains("attached-sandbox"), + "error should identify blocking sandbox: {}", + err.message() + ); + } + + #[tokio::test] + async fn provider_create_and_update_return_correct_resource_version() { + let store = test_store().await; + + // Create provider and verify resource_version: 1 in response + let created = provider_with_values("test-provider", "openai"); + let persisted = create_provider_record(&store, created).await.unwrap(); + assert_eq!( + persisted.metadata.as_ref().unwrap().resource_version, + 1, + "create_provider_record should return resource_version: 1 after insert" + ); + + // Update provider and verify resource_version: 2 in response + let updated = update_provider_record( + &store, + Provider { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: String::new(), + name: "test-provider".to_string(), + created_at_ms: 1_000_000, + labels: HashMap::new(), + resource_version: 0, + }), + r#type: "openai".to_string(), + credentials: std::iter::once(( + "OPENAI_API_KEY".to_string(), + "updated-key".to_string(), + )) + .collect(), + config: HashMap::new(), + credential_expires_at_ms: 
HashMap::new(), + }, + ) + .await + .unwrap(); + assert_eq!( + updated.metadata.as_ref().unwrap().resource_version, + 2, + "update_provider_record should return resource_version: 2 after first update" + ); + + // Update again and verify resource_version: 3 + let updated_again = update_provider_record( &store, + Provider { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: String::new(), + name: "test-provider".to_string(), + created_at_ms: 1_000_000, + labels: HashMap::new(), + resource_version: 0, + }), + r#type: "openai".to_string(), + credentials: std::iter::once(( + "OPENAI_API_KEY".to_string(), + "third-key".to_string(), + )) + .collect(), + config: HashMap::new(), + credential_expires_at_ms: HashMap::new(), + }, + ) + .await + .unwrap(); + assert_eq!( + updated_again.metadata.as_ref().unwrap().resource_version, + 3, + "update_provider_record should return resource_version: 3 after second update" + ); + } + + #[tokio::test] + async fn provider_validation_errors() { + let state = test_server_state().await; + let store = state.store.as_ref(); + + let create_missing_type = create_provider_record( + store, Provider { metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { id: String::new(), name: "bad-provider".to_string(), created_at_ms: 1_000_000, labels: HashMap::new(), + resource_version: 0, }), r#type: String::new(), credentials: HashMap::new(), config: HashMap::new(), + credential_expires_at_ms: HashMap::new(), }, ) .await .unwrap_err(); assert_eq!(create_missing_type.code(), Code::InvalidArgument); - let get_err = get_provider_record(&store, "").await.unwrap_err(); + let create_missing_credentials = create_provider_record( + store, + Provider { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: String::new(), + name: "gitlab-no-creds".to_string(), + created_at_ms: 1_000_000, + labels: HashMap::new(), + resource_version: 0, + }), + r#type: "gitlab".to_string(), + credentials: HashMap::new(), + config: 
HashMap::new(), + credential_expires_at_ms: HashMap::new(), + }, + ) + .await + .unwrap_err(); + assert_eq!(create_missing_credentials.code(), Code::InvalidArgument); + + handle_import_provider_profiles( + &state, + Request::new(ImportProviderProfilesRequest { + profiles: vec![ProviderProfileImportItem { + profile: Some(ProviderProfile { + id: "delegated-refresh-api".to_string(), + display_name: "Delegated Refresh API".to_string(), + description: String::new(), + category: ProviderProfileCategory::Messaging as i32, + credentials: vec![ProviderProfileCredential { + name: "access_token".to_string(), + description: String::new(), + env_vars: vec!["DELEGATED_ACCESS_TOKEN".to_string()], + required: true, + auth_style: "bearer".to_string(), + header_name: "authorization".to_string(), + query_param: String::new(), + refresh: Some(ProviderCredentialRefresh { + strategy: ProviderCredentialRefreshStrategy::Oauth2RefreshToken + as i32, + token_url: "https://login.example/token".to_string(), + scopes: vec!["https://example.test/.default".to_string()], + refresh_before_seconds: 300, + max_lifetime_seconds: 3600, + material: vec![ + ProviderCredentialRefreshMaterial { + name: "client_id".to_string(), + description: String::new(), + required: true, + secret: false, + }, + ProviderCredentialRefreshMaterial { + name: "refresh_token".to_string(), + description: String::new(), + required: true, + secret: true, + }, + ], + }), + }], + endpoints: vec![], + binaries: vec![], + inference_capable: false, + discovery: None, + }), + source: "delegated-refresh-api.yaml".to_string(), + }], + }), + ) + .await + .unwrap(); + let delegated_refresh_bootstrap_provider = create_provider_record( + store, + Provider { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: String::new(), + name: "delegated-refresh-no-token-yet".to_string(), + created_at_ms: 1_000_000, + labels: HashMap::new(), + resource_version: 0, + }), + r#type: "delegated-refresh-api".to_string(), + credentials: 
HashMap::new(), + config: HashMap::new(), + credential_expires_at_ms: HashMap::new(), + }, + ) + .await + .unwrap(); + assert!(delegated_refresh_bootstrap_provider.credentials.is_empty()); + + let mut mixed_required_profile = custom_profile("mixed-required-api"); + mixed_required_profile.credentials = vec![ + refreshable_credential("access_token", "MIXED_ACCESS_TOKEN"), + static_credential("static_token", "MIXED_STATIC_TOKEN", true), + ]; + handle_import_provider_profiles( + &state, + Request::new(ImportProviderProfilesRequest { + profiles: vec![ProviderProfileImportItem { + profile: Some(mixed_required_profile), + source: "mixed-required-api.yaml".to_string(), + }], + }), + ) + .await + .unwrap(); + let mixed_required_empty = create_provider_record( + store, + Provider { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: String::new(), + name: "mixed-required-no-token-yet".to_string(), + created_at_ms: 1_000_000, + labels: HashMap::new(), + resource_version: 0, + }), + r#type: "mixed-required-api".to_string(), + credentials: HashMap::new(), + config: HashMap::new(), + credential_expires_at_ms: HashMap::new(), + }, + ) + .await + .unwrap_err(); + assert_eq!(mixed_required_empty.code(), Code::InvalidArgument); + + let mut optional_static_profile = custom_profile("optional-static-api"); + optional_static_profile.credentials = vec![ + refreshable_credential("access_token", "OPTIONAL_ACCESS_TOKEN"), + static_credential("static_token", "OPTIONAL_STATIC_TOKEN", false), + ]; + handle_import_provider_profiles( + &state, + Request::new(ImportProviderProfilesRequest { + profiles: vec![ProviderProfileImportItem { + profile: Some(optional_static_profile), + source: "optional-static-api.yaml".to_string(), + }], + }), + ) + .await + .unwrap(); + let optional_static_empty = create_provider_record( + store, + Provider { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: String::new(), + name: "optional-static-no-token-yet".to_string(), 
+ created_at_ms: 1_000_000, + labels: HashMap::new(), + resource_version: 0, + }), + r#type: "optional-static-api".to_string(), + credentials: HashMap::new(), + config: HashMap::new(), + credential_expires_at_ms: HashMap::new(), + }, + ) + .await + .unwrap(); + assert!(optional_static_empty.credentials.is_empty()); + + let get_err = get_provider_record(store, "").await.unwrap_err(); assert_eq!(get_err.code(), Code::InvalidArgument); - let delete_err = delete_provider_record(&store, "").await.unwrap_err(); + let delete_err = delete_provider_record(store, "").await.unwrap_err(); assert_eq!(delete_err.code(), Code::InvalidArgument); let update_missing_err = update_provider_record( - &store, + store, Provider { metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { id: String::new(), name: "missing".to_string(), created_at_ms: 1_000_000, labels: HashMap::new(), + resource_version: 0, }), r#type: String::new(), credentials: HashMap::new(), config: HashMap::new(), + credential_expires_at_ms: HashMap::new(), }, ) .await @@ -1465,9 +3082,7 @@ mod tests { #[tokio::test] async fn update_provider_empty_maps_is_noop() { - let store = Store::connect("sqlite::memory:?cache=shared") - .await - .unwrap(); + let store = test_store().await; let created = provider_with_values("noop-test", "nvidia"); let persisted = create_provider_record(&store, created).await.unwrap(); @@ -1480,10 +3095,12 @@ mod tests { name: "noop-test".to_string(), created_at_ms: 1_000_000, labels: HashMap::new(), + resource_version: 0, }), r#type: String::new(), credentials: HashMap::new(), config: HashMap::new(), + credential_expires_at_ms: HashMap::new(), }, ) .await @@ -1512,9 +3129,7 @@ mod tests { #[tokio::test] async fn update_provider_empty_value_deletes_key() { - let store = Store::connect("sqlite::memory:?cache=shared") - .await - .unwrap(); + let store = test_store().await; let created = provider_with_values("delete-key-test", "openai"); create_provider_record(&store, 
created).await.unwrap(); @@ -1527,10 +3142,12 @@ mod tests { name: "delete-key-test".to_string(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), r#type: String::new(), credentials: std::iter::once(("SECONDARY".to_string(), String::new())).collect(), config: std::iter::once(("region".to_string(), String::new())).collect(), + credential_expires_at_ms: HashMap::new(), }, ) .await @@ -1563,9 +3180,7 @@ mod tests { #[tokio::test] async fn update_provider_empty_type_preserves_existing() { - let store = Store::connect("sqlite::memory:?cache=shared") - .await - .unwrap(); + let store = test_store().await; let created = provider_with_values("type-preserve-test", "anthropic"); create_provider_record(&store, created).await.unwrap(); @@ -1578,10 +3193,12 @@ mod tests { name: "type-preserve-test".to_string(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), r#type: String::new(), credentials: HashMap::new(), config: HashMap::new(), + credential_expires_at_ms: HashMap::new(), }, ) .await @@ -1592,9 +3209,7 @@ mod tests { #[tokio::test] async fn update_provider_rejects_type_change() { - let store = Store::connect("sqlite::memory:?cache=shared") - .await - .unwrap(); + let store = test_store().await; let created = provider_with_values("type-change-test", "nvidia"); create_provider_record(&store, created).await.unwrap(); @@ -1607,10 +3222,12 @@ mod tests { name: "type-change-test".to_string(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), r#type: "openai".to_string(), credentials: HashMap::new(), config: HashMap::new(), + credential_expires_at_ms: HashMap::new(), }, ) .await @@ -1621,10 +3238,8 @@ mod tests { } #[tokio::test] - async fn update_provider_validates_merged_result() { - let store = Store::connect("sqlite::memory:?cache=shared") - .await - .unwrap(); + async fn update_provider_validates_merged_result() { + let store = test_store().await; let created = provider_with_values("validate-merge-test", "gitlab"); 
create_provider_record(&store, created).await.unwrap(); @@ -1638,10 +3253,12 @@ mod tests { name: "validate-merge-test".to_string(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), r#type: String::new(), credentials: std::iter::once((oversized_key, "value".to_string())).collect(), config: HashMap::new(), + credential_expires_at_ms: HashMap::new(), }, ) .await @@ -1652,20 +3269,21 @@ mod tests { #[tokio::test] async fn resolve_provider_env_empty_list_returns_empty() { - let store = Store::connect("sqlite::memory:").await.unwrap(); + let store = test_store().await; let result = resolve_provider_environment(&store, &[]).await.unwrap(); assert!(result.is_empty()); } #[tokio::test] async fn resolve_provider_env_injects_credentials() { - let store = Store::connect("sqlite::memory:").await.unwrap(); + let store = test_store().await; let provider = Provider { metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { id: String::new(), name: "claude-local".to_string(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), r#type: "claude".to_string(), credentials: [ @@ -1679,6 +3297,7 @@ mod tests { "https://api.anthropic.com".to_string(), )) .collect(), + credential_expires_at_ms: HashMap::new(), }; create_provider_record(&store, provider).await.unwrap(); @@ -1690,9 +3309,49 @@ mod tests { assert!(!result.contains_key("endpoint")); } + #[tokio::test] + async fn resolve_provider_env_skips_expired_credentials_and_returns_expiry_metadata() { + let store = test_store().await; + let now_ms = crate::persistence::current_time_ms(); + let provider = Provider { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: String::new(), + name: "expiring-provider".to_string(), + created_at_ms: 0, + labels: HashMap::new(), + resource_version: 0, + }), + r#type: "test".to_string(), + credentials: [ + ("FRESH_TOKEN".to_string(), "fresh".to_string()), + ("STALE_TOKEN".to_string(), "stale".to_string()), + ] + .into_iter() + 
.collect(), + config: HashMap::new(), + credential_expires_at_ms: [ + ("FRESH_TOKEN".to_string(), now_ms + 60_000), + ("STALE_TOKEN".to_string(), now_ms - 60_000), + ] + .into_iter() + .collect(), + }; + create_provider_record(&store, provider).await.unwrap(); + + let result = resolve_provider_environment(&store, &["expiring-provider".to_string()]) + .await + .unwrap(); + assert_eq!(result.get("FRESH_TOKEN"), Some(&"fresh".to_string())); + assert!(!result.contains_key("STALE_TOKEN")); + assert_eq!( + result.credential_expires_at_ms.get("FRESH_TOKEN"), + Some(&(now_ms + 60_000)) + ); + } + #[tokio::test] async fn resolve_provider_env_unknown_name_returns_error() { - let store = Store::connect("sqlite::memory:").await.unwrap(); + let store = test_store().await; let err = resolve_provider_environment(&store, &["nonexistent".to_string()]) .await .unwrap_err(); @@ -1702,13 +3361,14 @@ mod tests { #[tokio::test] async fn resolve_provider_env_skips_invalid_credential_keys() { - let store = Store::connect("sqlite::memory:").await.unwrap(); + let store = test_store().await; let provider = Provider { metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { id: String::new(), name: "test-provider".to_string(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), r#type: "test".to_string(), credentials: [ @@ -1719,6 +3379,7 @@ mod tests { .into_iter() .collect(), config: HashMap::new(), + credential_expires_at_ms: HashMap::new(), }; create_provider_record(&store, provider).await.unwrap(); @@ -1732,7 +3393,7 @@ mod tests { #[tokio::test] async fn resolve_provider_env_multiple_providers_merge() { - let store = Store::connect("sqlite::memory:").await.unwrap(); + let store = test_store().await; create_provider_record( &store, Provider { @@ -1741,6 +3402,7 @@ mod tests { name: "claude-local".to_string(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), r#type: "claude".to_string(), credentials: std::iter::once(( @@ -1749,6 +3411,7 @@ mod 
tests { )) .collect(), config: HashMap::new(), + credential_expires_at_ms: HashMap::new(), }, ) .await @@ -1761,11 +3424,13 @@ mod tests { name: "gitlab-local".to_string(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), r#type: "gitlab".to_string(), credentials: std::iter::once(("GITLAB_TOKEN".to_string(), "glpat-xyz".to_string())) .collect(), config: HashMap::new(), + credential_expires_at_ms: HashMap::new(), }, ) .await @@ -1782,8 +3447,8 @@ mod tests { } #[tokio::test] - async fn resolve_provider_env_first_credential_wins_on_duplicate_key() { - let store = Store::connect("sqlite::memory:").await.unwrap(); + async fn resolve_provider_env_rejects_duplicate_credential_keys() { + let store = test_store().await; create_provider_record( &store, Provider { @@ -1792,11 +3457,13 @@ mod tests { name: "provider-a".to_string(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), r#type: "claude".to_string(), credentials: std::iter::once(("SHARED_KEY".to_string(), "first-value".to_string())) .collect(), config: HashMap::new(), + credential_expires_at_ms: HashMap::new(), }, ) .await @@ -1809,6 +3476,7 @@ mod tests { name: "provider-b".to_string(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), r#type: "gitlab".to_string(), credentials: std::iter::once(( @@ -1817,25 +3485,120 @@ mod tests { )) .collect(), config: HashMap::new(), + credential_expires_at_ms: HashMap::new(), }, ) .await .unwrap(); - let result = resolve_provider_environment( + let err = resolve_provider_environment( &store, &["provider-a".to_string(), "provider-b".to_string()], ) .await + .unwrap_err(); + assert_eq!(err.code(), Code::FailedPrecondition); + assert!(err.message().contains("SHARED_KEY")); + assert!(err.message().contains("provider-a")); + assert!(err.message().contains("provider-b")); + } + + #[tokio::test] + async fn update_provider_rejects_credential_key_collision_for_attached_sandbox() { + let store = test_store().await; + 
create_provider_record( + &store, + Provider { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: String::new(), + name: "provider-a".to_string(), + created_at_ms: 0, + labels: HashMap::new(), + resource_version: 0, + }), + r#type: "outlook".to_string(), + credentials: std::iter::once(( + "MS_GRAPH_ACCESS_TOKEN".to_string(), + "graph-token".to_string(), + )) + .collect(), + config: HashMap::new(), + credential_expires_at_ms: HashMap::new(), + }, + ) + .await + .unwrap(); + create_provider_record( + &store, + Provider { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: String::new(), + name: "provider-b".to_string(), + created_at_ms: 0, + labels: HashMap::new(), + resource_version: 0, + }), + r#type: "google-drive".to_string(), + credentials: std::iter::once(( + "GOOGLE_ACCESS_TOKEN".to_string(), + "google-token".to_string(), + )) + .collect(), + config: HashMap::new(), + credential_expires_at_ms: HashMap::new(), + }, + ) + .await .unwrap(); - assert_eq!(result.get("SHARED_KEY"), Some(&"first-value".to_string())); + let sandbox = Sandbox { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: "sandbox-collision".to_string(), + name: "collision".to_string(), + created_at_ms: 0, + labels: HashMap::new(), + resource_version: 0, + }), + spec: Some(SandboxSpec { + providers: vec!["provider-a".to_string(), "provider-b".to_string()], + ..SandboxSpec::default() + }), + ..Default::default() + }; + store.put_message(&sandbox).await.unwrap(); + + let err = update_provider_record( + &store, + Provider { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: String::new(), + name: "provider-b".to_string(), + created_at_ms: 0, + labels: HashMap::new(), + resource_version: 0, + }), + r#type: String::new(), + credentials: std::iter::once(( + "MS_GRAPH_ACCESS_TOKEN".to_string(), + "wrong-token".to_string(), + )) + .collect(), + config: HashMap::new(), + credential_expires_at_ms: HashMap::new(), + 
}, + ) + .await + .unwrap_err(); + + assert_eq!(err.code(), Code::FailedPrecondition); + assert!(err.message().contains("collision")); + assert!(err.message().contains("MS_GRAPH_ACCESS_TOKEN")); } #[tokio::test] async fn handler_flow_resolves_credentials_from_sandbox_providers() { use openshell_core::proto::{Sandbox, SandboxPhase, SandboxSpec}; - let store = Store::connect("sqlite::memory:").await.unwrap(); + let store = test_store().await; create_provider_record( &store, @@ -1845,6 +3608,7 @@ mod tests { name: "my-claude".to_string(), created_at_ms: 1_000_000, labels: HashMap::new(), + resource_version: 0, }), r#type: "claude".to_string(), credentials: std::iter::once(( @@ -1853,6 +3617,7 @@ mod tests { )) .collect(), config: HashMap::new(), + credential_expires_at_ms: HashMap::new(), }, ) .await @@ -1864,6 +3629,7 @@ mod tests { name: "test-sandbox".to_string(), created_at_ms: 1_000_000, labels: HashMap::new(), + resource_version: 0, }), spec: Some(SandboxSpec { providers: vec!["my-claude".to_string()], @@ -1892,7 +3658,7 @@ mod tests { async fn handler_flow_returns_empty_when_no_providers() { use openshell_core::proto::{Sandbox, SandboxPhase, SandboxSpec}; - let store = Store::connect("sqlite::memory:").await.unwrap(); + let store = test_store().await; let sandbox = Sandbox { metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { @@ -1900,6 +3666,7 @@ mod tests { name: "empty-sandbox".to_string(), created_at_ms: 1_000_000, labels: HashMap::new(), + resource_version: 0, }), spec: Some(SandboxSpec::default()), status: None, @@ -1925,8 +3692,365 @@ mod tests { async fn handler_flow_returns_none_for_unknown_sandbox() { use openshell_core::proto::Sandbox; - let store = Store::connect("sqlite::memory:").await.unwrap(); + let store = test_store().await; let result = store.get_message::("nonexistent").await.unwrap(); assert!(result.is_none()); } + + #[tokio::test] + async fn update_provider_validates_before_write() { + let store = 
Arc::new(test_store().await); + + // Create a valid provider + let provider = provider_with_values("test-validate-provider", "test-type"); + let created = create_provider_record(&store, provider.clone()) + .await + .unwrap(); + + // Build update request with just the name and new credentials + let mut update_req = Provider { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: String::new(), + name: "test-validate-provider".to_string(), + created_at_ms: 0, + labels: HashMap::new(), + resource_version: 0, + }), + r#type: String::new(), // Empty type is ignored in update + credentials: HashMap::new(), + config: HashMap::new(), + credential_expires_at_ms: HashMap::new(), + }; + + // Attempt to update with an oversized credential key (exceeds MAX_MAP_KEY_LEN) + update_req.credentials.insert( + "k".repeat(MAX_MAP_KEY_LEN + 1), + "oversized-key-value".to_string(), + ); + + let result = update_provider_record(&store, update_req).await; + + // Update should fail with InvalidArgument due to oversized key + assert!(result.is_err(), "update with invalid data should fail"); + let err = result.unwrap_err(); + assert_eq!( + err.code(), + Code::InvalidArgument, + "should fail validation with InvalidArgument" + ); + assert!( + err.message().contains("key"), + "error message should mention key: {}", + err.message() + ); + + // Verify database still contains the ORIGINAL valid provider (not the invalid one) + let stored = store + .get_message_by_name::("test-validate-provider") + .await + .unwrap() + .expect("provider should still exist"); + + assert_eq!( + stored.object_id(), + created.object_id(), + "stored provider ID should match original" + ); + assert_eq!( + stored.credentials.len(), + created.credentials.len(), + "credentials count should not have changed" + ); + assert!( + !stored + .credentials + .contains_key(&"k".repeat(MAX_MAP_KEY_LEN + 1)), + "oversized key should NOT be in database" + ); + } + + #[tokio::test] + async fn 
concurrent_create_provider_rejects_duplicate() { + let store = Arc::new(test_store().await); + + let provider = provider_with_values("test-concurrent-provider", "test-type"); + + // Spawn two concurrent creation attempts for the same provider + let store1 = store.clone(); + let provider1 = provider.clone(); + let handle1 = tokio::spawn(async move { create_provider_record(&store1, provider1).await }); + + let store2 = store.clone(); + let provider2 = provider.clone(); + let handle2 = tokio::spawn(async move { create_provider_record(&store2, provider2).await }); + + // Wait for both to complete + let result1 = handle1.await.unwrap(); + let result2 = handle2.await.unwrap(); + + // Exactly one should succeed, one should fail with AlreadyExists + let success_count = [&result1, &result2].iter().filter(|r| r.is_ok()).count(); + let already_exists_count = [&result1, &result2] + .iter() + .filter(|r| { + r.as_ref() + .err() + .is_some_and(|e| e.code() == Code::AlreadyExists) + }) + .count(); + + assert_eq!( + success_count, 1, + "exactly one creation should succeed, got results: {result1:?} {result2:?}" + ); + assert_eq!( + already_exists_count, 1, + "exactly one creation should fail with AlreadyExists, got results: {result1:?} {result2:?}" + ); + + // Verify the successful provider can be retrieved by name + let created_provider = [result1, result2] + .into_iter() + .find_map(Result::ok) + .expect("should have one successful creation"); + let retrieved = store + .get_message_by_name::("test-concurrent-provider") + .await + .unwrap(); + assert!( + retrieved.is_some(), + "created provider should be retrievable by name" + ); + assert_eq!( + retrieved.unwrap().object_id(), + created_provider.object_id(), + "retrieved provider should match created provider" + ); + } + + // ---- CAS (Client-driven optimistic concurrency) tests for UpdateProvider ---- + + #[tokio::test] + async fn update_provider_client_driven_cas_succeeds_with_correct_version() { + let state = 
test_server_state().await; + + // Create a provider + let mut provider = provider_with_values("test-provider", "generic"); + provider.metadata.as_mut().unwrap().id = String::new(); + handle_create_provider( + &state, + Request::new(CreateProviderRequest { + provider: Some(provider.clone()), + }), + ) + .await + .unwrap(); + + // Fetch the provider to get its current resource_version + let current = state + .store + .get_message_by_name::("test-provider") + .await + .unwrap() + .unwrap(); + let current_version = current.metadata.as_ref().unwrap().resource_version; + + // Prepare an update with the correct resource_version + let mut updated_provider = current.clone(); + updated_provider + .credentials + .insert("NEW_KEY".to_string(), "new-value".to_string()); + updated_provider.metadata.as_mut().unwrap().resource_version = current_version; + + // Update should succeed + let response = handle_update_provider( + &state, + Request::new(UpdateProviderRequest { + provider: Some(updated_provider.clone()), + credential_expires_at_ms: HashMap::new(), + }), + ) + .await + .unwrap() + .into_inner(); + + assert_eq!( + response.provider.as_ref().unwrap().object_name(), + "test-provider" + ); + assert_eq!( + response + .provider + .as_ref() + .unwrap() + .metadata + .as_ref() + .unwrap() + .resource_version, + current_version + 1 + ); + assert!( + response + .provider + .unwrap() + .credentials + .contains_key("NEW_KEY") + ); + } + + #[tokio::test] + async fn update_provider_client_driven_cas_rejects_stale_version() { + let state = test_server_state().await; + + // Create a provider + let mut provider = provider_with_values("test-provider", "generic"); + provider.metadata.as_mut().unwrap().id = String::new(); + handle_create_provider( + &state, + Request::new(CreateProviderRequest { + provider: Some(provider.clone()), + }), + ) + .await + .unwrap(); + + // Fetch the current state + let current = state + .store + .get_message_by_name::("test-provider") + .await + .unwrap() + 
.unwrap(); + let current_version = current.metadata.as_ref().unwrap().resource_version; + + // Prepare an update with a stale resource_version + let mut stale_provider = current.clone(); + stale_provider + .credentials + .insert("NEW_KEY".to_string(), "new-value".to_string()); + stale_provider.metadata.as_mut().unwrap().resource_version = 99; // stale version + + // Update should fail with ABORTED + let err = handle_update_provider( + &state, + Request::new(UpdateProviderRequest { + provider: Some(stale_provider), + credential_expires_at_ms: HashMap::new(), + }), + ) + .await + .unwrap_err(); + + assert_eq!(err.code(), Code::Aborted); + assert!( + err.message().contains("modified concurrently") + || err.message().contains("resource_version"), + "error message should mention concurrency conflict: {}", + err.message() + ); + + // Verify the provider was not modified + let unchanged = state + .store + .get_message_by_name::("test-provider") + .await + .unwrap() + .unwrap(); + assert_eq!( + unchanged.metadata.as_ref().unwrap().resource_version, + current_version + ); + assert!(!unchanged.credentials.contains_key("NEW_KEY")); + } + + #[tokio::test] + async fn update_provider_concurrent_updates_with_stale_versions() { + use std::sync::Arc; + + let state = Arc::new(test_server_state().await); + + // Create a provider + let mut provider = provider_with_values("test-provider", "generic"); + provider.metadata.as_mut().unwrap().id = String::new(); + handle_create_provider( + &state, + Request::new(CreateProviderRequest { + provider: Some(provider.clone()), + }), + ) + .await + .unwrap(); + + // All three clients fetch the provider and see the same version + let initial = state + .store + .get_message_by_name::("test-provider") + .await + .unwrap() + .unwrap(); + let initial_version = initial.metadata.as_ref().unwrap().resource_version; + + // Launch 3 concurrent updates, all using the same initial version + let mut handles = vec![]; + for i in 0..3 { + let state_clone = 
Arc::clone(&state); + let mut updated = initial.clone(); + updated + .credentials + .insert(format!("KEY_{i}"), format!("value-{i}")); + updated.metadata.as_mut().unwrap().resource_version = initial_version; + + let handle = tokio::spawn(async move { + handle_update_provider( + &state_clone, + Request::new(UpdateProviderRequest { + provider: Some(updated), + credential_expires_at_ms: HashMap::new(), + }), + ) + .await + }); + handles.push(handle); + } + + let results: Vec<_> = futures::future::join_all(handles) + .await + .into_iter() + .map(|r| r.unwrap()) + .collect(); + + // Only one should succeed; others should get ABORTED + let successes = results.iter().filter(|r| r.is_ok()).count(); + let aborted_conflicts = results + .iter() + .filter(|r| r.as_ref().err().is_some_and(|e| e.code() == Code::Aborted)) + .count(); + + assert_eq!( + successes, 1, + "exactly one update should succeed with client-driven CAS" + ); + assert_eq!( + aborted_conflicts, 2, + "two updates should fail with ABORTED due to stale version" + ); + + // Final provider should have exactly 1 new credential key and resource_version = initial_version + 1 + let final_provider = state + .store + .get_message_by_name::("test-provider") + .await + .unwrap() + .unwrap(); + assert_eq!( + final_provider.metadata.as_ref().unwrap().resource_version, + initial_version + 1 + ); + + // Exactly one of KEY_0, KEY_1, or KEY_2 should be present + let new_keys_count = (0..3) + .filter(|i| final_provider.credentials.contains_key(&format!("KEY_{i}"))) + .count(); + assert_eq!(new_keys_count, 1); + } } diff --git a/crates/openshell-server/src/grpc/sandbox.rs b/crates/openshell-server/src/grpc/sandbox.rs index ed1b4cdfc..1855972d7 100644 --- a/crates/openshell-server/src/grpc/sandbox.rs +++ b/crates/openshell-server/src/grpc/sandbox.rs @@ -10,33 +10,44 @@ #![allow(clippy::cast_possible_wrap)] // Intentional u32->i32 conversions for proto compat use crate::ServerState; -use crate::persistence::{ObjectType, 
generate_name}; +use crate::persistence::{ObjectType, WriteCondition, generate_name}; use futures::future; use openshell_core::proto::{ - CreateSandboxRequest, CreateSshSessionRequest, CreateSshSessionResponse, DeleteSandboxRequest, - DeleteSandboxResponse, ExecSandboxEvent, ExecSandboxExit, ExecSandboxRequest, - ExecSandboxStderr, ExecSandboxStdout, GetSandboxRequest, ListSandboxesRequest, - ListSandboxesResponse, RevokeSshSessionRequest, RevokeSshSessionResponse, SandboxResponse, - SandboxStreamEvent, WatchSandboxRequest, + AttachSandboxProviderRequest, AttachSandboxProviderResponse, CreateSandboxRequest, + CreateSshSessionRequest, CreateSshSessionResponse, DeleteSandboxRequest, DeleteSandboxResponse, + DetachSandboxProviderRequest, DetachSandboxProviderResponse, ExecSandboxEvent, ExecSandboxExit, + ExecSandboxInput, ExecSandboxRequest, ExecSandboxStderr, ExecSandboxStdout, GetSandboxRequest, + ListSandboxProvidersRequest, ListSandboxProvidersResponse, ListSandboxesRequest, + ListSandboxesResponse, Provider, RevokeSshSessionRequest, RevokeSshSessionResponse, + SandboxResponse, SandboxStreamEvent, SshRelayTarget, TcpForwardFrame, TcpForwardInit, + TcpRelayTarget, WatchSandboxRequest, relay_open, tcp_forward_init, }; use openshell_core::proto::{Sandbox, SandboxPhase, SandboxTemplate, SshSession}; +use openshell_core::{ObjectId, ObjectName}; use prost::Message; +use std::net::IpAddr; +use std::pin::Pin; use std::sync::Arc; +use std::sync::atomic::{AtomicBool, Ordering}; use tokio::net::{TcpListener, TcpStream}; -use tokio::sync::mpsc; +use tokio::sync::{mpsc, oneshot}; use tokio_stream::wrappers::ReceiverStream; use tonic::{Request, Response, Status}; -use tracing::{info, warn}; +use tracing::{debug, info, warn}; use russh::ChannelMsg; use russh::client::AuthResult; -use super::provider::is_valid_env_key; +use super::provider::{ + get_provider_record, is_valid_env_key, validate_provider_environment_keys_unique, +}; use super::validation::{ level_matches, 
source_matches, validate_exec_request_fields, validate_policy_safety, validate_sandbox_spec, }; -use super::{MAX_PAGE_SIZE, clamp_limit, current_time_ms}; +use super::{MAX_PAGE_SIZE, MAX_PROVIDERS, clamp_limit, current_time_ms}; + +const TCP_FORWARD_CHUNK_SIZE: usize = 64 * 1024; // --------------------------------------------------------------------------- // Sandbox lifecycle handlers @@ -66,11 +77,12 @@ pub(super) async fn handle_create_sandbox( for name in &spec.providers { state .store - .get_message_by_name::(name) + .get_message_by_name::(name) .await .map_err(|e| Status::internal(format!("fetch provider failed: {e}")))? .ok_or_else(|| Status::failed_precondition(format!("provider '{name}' not found")))?; } + validate_provider_environment_keys_unique(state.store.as_ref(), &spec.providers).await?; // Ensure the template always carries the resolved image. let mut spec = spec; @@ -93,8 +105,7 @@ pub(super) async fn handle_create_sandbox( request.name.clone() }; - let now_ms = current_time_ms() - .map_err(|e| Status::internal(format!("failed to get current time: {e}")))?; + let now_ms = current_time_ms(); let sandbox = Sandbox { metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { @@ -102,6 +113,7 @@ pub(super) async fn handle_create_sandbox( name: name.clone(), created_at_ms: now_ms, labels: request.labels.clone(), + resource_version: 0, }), spec: Some(spec), status: None, @@ -121,7 +133,27 @@ pub(super) async fn handle_create_sandbox( status })?; - let sandbox = state.compute.create_sandbox(sandbox).await?; + // Mint the gateway JWT for singleplayer drivers. K8s sandboxes skip + // this mint and bootstrap via `IssueSandboxToken` at supervisor + // startup; identifying "is this K8s?" lives in the compute layer, so + // we mint unconditionally here when the issuer is configured and let + // the K8s driver simply ignore the field. 
+ let sandbox_token = state.sandbox_jwt_issuer.as_ref().map(|issuer| { + issuer.mint(&id).map(|minted| { + tracing::info!( + sandbox_id = %id, + "minted sandbox JWT" + ); + minted.token + }) + }); + let sandbox_token = match sandbox_token { + Some(Ok(token)) => Some(token), + Some(Err(status)) => return Err(status), + None => None, + }; + + let sandbox = state.compute.create_sandbox(sandbox, sandbox_token).await?; info!( sandbox_id = %id, @@ -161,35 +193,220 @@ pub(super) async fn handle_list_sandboxes( let request = request.into_inner(); let limit = clamp_limit(request.limit, 100, MAX_PAGE_SIZE); - // If no label selector is provided, use the unfiltered list path - let records = if request.label_selector.is_empty() { + let sandboxes: Vec = if request.label_selector.is_empty() { state .store - .list(Sandbox::object_type(), limit, request.offset) + .list_messages(limit, request.offset) .await .map_err(|e| Status::internal(format!("list sandboxes failed: {e}")))? } else { crate::grpc::validation::validate_label_selector(&request.label_selector)?; state .store - .list_with_selector( - Sandbox::object_type(), - &request.label_selector, - limit, - request.offset, - ) + .list_messages_with_selector(&request.label_selector, limit, request.offset) .await .map_err(|e| Status::internal(format!("list sandboxes with selector failed: {e}")))? 
}; - let mut sandboxes = Vec::with_capacity(records.len()); - for record in records { - let sandbox = Sandbox::decode(record.payload.as_slice()) - .map_err(|e| Status::internal(format!("decode sandbox failed: {e}")))?; - sandboxes.push(sandbox); + Ok(Response::new(ListSandboxesResponse { sandboxes })) +} + +pub(super) async fn handle_list_sandbox_providers( + state: &Arc, + request: Request, +) -> Result, Status> { + let sandbox = sandbox_by_name(state, &request.into_inner().sandbox_name).await?; + let providers = providers_for_sandbox(state, &sandbox).await?; + Ok(Response::new(ListSandboxProvidersResponse { providers })) +} + +pub(super) async fn handle_attach_sandbox_provider( + state: &Arc, + request: Request, +) -> Result, Status> { + let request = request.into_inner(); + if request.provider_name.is_empty() { + return Err(Status::invalid_argument("provider_name is required")); } - Ok(Response::new(ListSandboxesResponse { sandboxes })) + // Validate provider name would not violate sandbox spec constraints if added + // (pre-validation ensures CAS mutations preserve invariants) + if request.provider_name.len() > super::MAX_NAME_LEN { + return Err(Status::invalid_argument(format!( + "provider_name exceeds maximum length ({} > {})", + request.provider_name.len(), + super::MAX_NAME_LEN + ))); + } + + get_provider_record(state.store.as_ref(), &request.provider_name) + .await + .map_err(|err| { + if err.code() == tonic::Code::NotFound { + Status::failed_precondition(format!( + "provider '{}' not found", + request.provider_name + )) + } else { + err + } + })?; + + let _sandbox_sync_guard = state.compute.sandbox_sync_guard().await; + let sandbox = sandbox_by_name(state, &request.sandbox_name).await?; + let sandbox_id = sandbox + .metadata + .as_ref() + .ok_or_else(|| Status::internal("sandbox metadata is missing"))? 
+ .id + .clone(); + + // Pre-check: fail fast if sandbox spec is missing (invariant violation) + let spec = sandbox + .spec + .as_ref() + .ok_or_else(|| Status::internal("sandbox spec is missing"))?; + + // Pre-check: fail fast if already at MAX_PROVIDERS limit (avoid spurious CAS conflicts) + // Note: This is an optimization; the CAS closure rechecks after dedupe in case of races + if spec.providers.len() >= MAX_PROVIDERS + && !spec + .providers + .iter() + .any(|name| name == &request.provider_name) + { + return Err(Status::invalid_argument(format!( + "providers list exceeds maximum ({MAX_PROVIDERS})" + ))); + } + let mut candidate_spec = spec.clone(); + dedupe_provider_names(&mut candidate_spec.providers); + if !candidate_spec + .providers + .iter() + .any(|name| name == &request.provider_name) + { + candidate_spec.providers.push(request.provider_name.clone()); + } + validate_sandbox_spec(&request.sandbox_name, &candidate_spec)?; + validate_provider_environment_keys_unique(state.store.as_ref(), &candidate_spec.providers) + .await?; + + let provider_name = request.provider_name.clone(); + let attached = Arc::new(AtomicBool::new(false)); + let attached_clone = attached.clone(); + + let sandbox = state + .store + .update_message_cas::( + &sandbox_id, + request.expected_resource_version, + |sandbox| { + let Some(ref mut spec) = sandbox.spec else { + // Spec should always exist post-creation; if missing, fail CAS to surface error + return; + }; + + dedupe_provider_names(&mut spec.providers); + if !spec.providers.iter().any(|name| name == &provider_name) + && spec.providers.len() < MAX_PROVIDERS + { + spec.providers.push(provider_name.clone()); + attached_clone.store(true, Ordering::Relaxed); + } + }, + ) + .await + .map_err(|e| super::persistence_error_to_status(e, "attach sandbox provider"))?; + + let attached = attached.load(Ordering::Relaxed); + + info!( + sandbox_name = %request.sandbox_name, + provider_name = %request.provider_name, + attached, + 
"AttachSandboxProvider request completed successfully" + ); + + Ok(Response::new(AttachSandboxProviderResponse { + sandbox: Some(sandbox), + attached, + })) +} + +pub(super) async fn handle_detach_sandbox_provider( + state: &Arc, + request: Request, +) -> Result, Status> { + let request = request.into_inner(); + if request.provider_name.is_empty() { + return Err(Status::invalid_argument("provider_name is required")); + } + + // Validate provider name (pre-validation ensures CAS mutations preserve invariants) + if request.provider_name.len() > super::MAX_NAME_LEN { + return Err(Status::invalid_argument(format!( + "provider_name exceeds maximum length ({} > {})", + request.provider_name.len(), + super::MAX_NAME_LEN + ))); + } + + let _sandbox_sync_guard = state.compute.sandbox_sync_guard().await; + let sandbox = sandbox_by_name(state, &request.sandbox_name).await?; + let sandbox_id = sandbox + .metadata + .as_ref() + .ok_or_else(|| Status::internal("sandbox metadata is missing"))? + .id + .clone(); + + // Pre-check: fail fast if sandbox spec is missing (invariant violation) + let _spec = sandbox + .spec + .as_ref() + .ok_or_else(|| Status::internal("sandbox spec is missing"))?; + + let provider_name = request.provider_name.clone(); + let detached = Arc::new(AtomicBool::new(false)); + let detached_clone = detached.clone(); + + let sandbox = state + .store + .update_message_cas::( + &sandbox_id, + request.expected_resource_version, + |sandbox| { + let Some(ref mut spec) = sandbox.spec else { + // Spec should always exist post-creation; if missing, fail CAS to surface error + return; + }; + + let before_len = spec.providers.len(); + spec.providers.retain(|name| name != &provider_name); + if spec.providers.len() != before_len { + detached_clone.store(true, Ordering::Relaxed); + // Only dedupe after making a change + dedupe_provider_names(&mut spec.providers); + } + }, + ) + .await + .map_err(|e| super::persistence_error_to_status(e, "detach sandbox provider"))?; + + let 
detached = detached.load(Ordering::Relaxed); + + info!( + sandbox_name = %request.sandbox_name, + provider_name = %request.provider_name, + detached, + "DetachSandboxProvider request completed successfully" + ); + + Ok(Response::new(DetachSandboxProviderResponse { + sandbox: Some(sandbox), + detached, + })) } pub(super) async fn handle_delete_sandbox( @@ -206,6 +423,56 @@ pub(super) async fn handle_delete_sandbox( Ok(Response::new(DeleteSandboxResponse { deleted })) } +async fn sandbox_by_name(state: &Arc, name: &str) -> Result { + if name.is_empty() { + return Err(Status::invalid_argument("sandbox_name is required")); + } + + state + .store + .get_message_by_name::(name) + .await + .map_err(|e| Status::internal(format!("fetch sandbox failed: {e}")))? + .ok_or_else(|| Status::not_found("sandbox not found")) +} + +async fn providers_for_sandbox( + state: &Arc, + sandbox: &Sandbox, +) -> Result, Status> { + let provider_names = sandbox + .spec + .as_ref() + .map(|spec| spec.providers.as_slice()) + .ok_or_else(|| Status::failed_precondition("sandbox spec is missing"))?; + + let mut providers = Vec::with_capacity(provider_names.len()); + for name in provider_names { + let provider = get_provider_record(state.store.as_ref(), name) + .await + .map_err(|err| { + if err.code() == tonic::Code::NotFound { + Status::failed_precondition(format!("provider '{name}' not found")) + } else { + err + } + })?; + providers.push(provider); + } + Ok(providers) +} + +fn dedupe_provider_names(provider_names: &mut Vec) { + let mut index = 0; + while index < provider_names.len() { + if provider_names[..index].contains(&provider_names[index]) { + provider_names.remove(index); + } else { + index += 1; + } + } +} + // --------------------------------------------------------------------------- // Watch handler // --------------------------------------------------------------------------- @@ -233,6 +500,7 @@ pub(super) async fn handle_watch_sandbox( let log_since_ms = req.log_since_ms; let 
log_sources = req.log_sources; let log_min_level = req.log_min_level; + let event_tail = req.event_tail; let (tx, rx) = mpsc::channel::>(256); let state = state.clone(); @@ -337,7 +605,7 @@ pub(super) async fn handle_watch_sandbox( for evt in state .tracing_log_bus .platform_event_bus - .tail(&sandbox_id, 50) + .tail(&sandbox_id, event_tail as usize) { if tx.send(Ok(evt)).await.is_err() { return; @@ -470,9 +738,8 @@ pub(super) async fn handle_exec_sandbox( } // Open a relay channel through the supervisor session. Use a 15s - // session-wait timeout — enough to cover a transient supervisor - // reconnect, but shorter than `/connect/ssh` since `ExecSandbox` is - // typically called during normal operation (not right after create). + // session-wait timeout, enough to cover a transient supervisor reconnect + // while still failing quickly during normal operation. let (channel_id, relay_rx) = state .supervisor_sessions .open_relay(sandbox.object_id(), std::time::Duration::from_secs(15)) @@ -490,24 +757,10 @@ pub(super) async fn handle_exec_sandbox( let (tx, rx) = mpsc::channel::>(256); tokio::spawn(async move { // Wait for the supervisor's reverse CONNECT to deliver the relay stream. 
- let relay_stream = match tokio::time::timeout(std::time::Duration::from_secs(10), relay_rx) - .await - { - Ok(Ok(stream)) => stream, - Ok(Err(_)) => { - warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, "ExecSandbox: relay channel dropped"); - let _ = tx - .send(Err(Status::unavailable("relay channel dropped"))) - .await; - return; - } - Err(_) => { - warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, "ExecSandbox: relay open timed out"); - let _ = tx - .send(Err(Status::deadline_exceeded("relay open timed out"))) - .await; - return; - } + let Some(relay_stream) = + await_relay_stream(relay_rx, &tx, &sandbox_id, &channel_id, "ExecSandbox").await + else { + return; }; if let Err(err) = stream_exec_over_relay( @@ -530,22 +783,66 @@ pub(super) async fn handle_exec_sandbox( Ok(Response::new(ReceiverStream::new(rx))) } -// --------------------------------------------------------------------------- -// SSH session handlers -// --------------------------------------------------------------------------- +/// Wait for the supervisor's reverse CONNECT to deliver a relay stream. +/// +/// Returns `Some(stream)` on success. On any failure the error is sent on `tx` +/// and `None` is returned; the caller should then `return` immediately. 
+async fn await_relay_stream( + relay_rx: oneshot::Receiver>, + tx: &mpsc::Sender>, + sandbox_id: &str, + channel_id: &str, + context: &str, +) -> Option { + match tokio::time::timeout(std::time::Duration::from_secs(10), relay_rx).await { + Ok(Ok(Ok(stream))) => Some(stream), + Ok(Ok(Err(status))) => { + warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, error = %status.message(), "{context}: relay target open failed"); + let _ = tx.send(Err(status)).await; + None + } + Ok(Err(_)) => { + warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, "{context}: relay channel dropped"); + let _ = tx + .send(Err(Status::unavailable("relay channel dropped"))) + .await; + None + } + Err(_) => { + warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, "{context}: relay open timed out"); + let _ = tx + .send(Err(Status::deadline_exceeded("relay open timed out"))) + .await; + None + } + } +} -pub(super) async fn handle_create_ssh_session( +pub(super) async fn handle_forward_tcp( state: &Arc, - request: Request, -) -> Result, Status> { - let req = request.into_inner(); - if req.sandbox_id.is_empty() { - return Err(Status::invalid_argument("sandbox_id is required")); - } + request: Request>, +) -> Result< + Response< + Pin> + Send + 'static>>, + >, + Status, +> { + let mut inbound = request.into_inner(); + let first = inbound + .message() + .await? + .ok_or_else(|| Status::invalid_argument("empty ForwardTcp stream"))?; + let Some(openshell_core::proto::tcp_forward_frame::Payload::Init(init)) = first.payload else { + return Err(Status::invalid_argument( + "first TcpForwardFrame must be init", + )); + }; + + let target = validate_tcp_forward_init(&init)?; let sandbox = state .store - .get_message::(&req.sandbox_id) + .get_message::(&init.sandbox_id) .await .map_err(|e| Status::internal(format!("fetch sandbox failed: {e}")))? 
.ok_or_else(|| Status::not_found("sandbox not found"))?; @@ -554,116 +851,532 @@ pub(super) async fn handle_create_ssh_session( return Err(Status::failed_precondition("sandbox is not ready")); } - let token = uuid::Uuid::new_v4().to_string(); - let now_ms = current_time_ms() - .map_err(|e| Status::internal(format!("timestamp generation failed: {e}")))?; - let expires_at_ms = if state.config.ssh_session_ttl_secs > 0 { - now_ms + (state.config.ssh_session_ttl_secs as i64 * 1000) - } else { - 0 - }; - let session = SshSession { - metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { - id: token.clone(), - name: generate_name(), - created_at_ms: now_ms, - labels: std::collections::HashMap::new(), - }), - sandbox_id: req.sandbox_id.clone(), - token: token.clone(), - revoked: false, - expires_at_ms, - }; + let connection_guard = acquire_forward_connection_guard(state, &init, &sandbox).await?; + let (channel_id, relay_rx) = state + .supervisor_sessions + .open_relay_with_target( + sandbox.object_id(), + target, + init.service_id.clone(), + std::time::Duration::from_secs(15), + ) + .await + .map_err(|e| Status::unavailable(format!("supervisor relay failed: {e}")))?; - // Ensure metadata is valid (defense in depth - should always be true for server-constructed metadata) - super::validation::validate_object_metadata(session.metadata.as_ref(), "ssh_session")?; + let sandbox_id = sandbox.object_id().to_string(); + let (tx, rx) = mpsc::channel::>(256); + tokio::spawn(async move { + let _connection_guard = connection_guard; + let Some(relay_stream) = + await_relay_stream(relay_rx, &tx, &sandbox_id, &channel_id, "ForwardTcp").await + else { + return; + }; - state - .store - .put_message(&session) - .await - .map_err(|e| Status::internal(format!("persist ssh session failed: {e}")))?; + bridge_forward_tcp_stream(inbound, relay_stream, tx, &sandbox_id, &channel_id).await; + }); - let (gateway_host, gateway_port) = resolve_gateway(&state.config); - let scheme = if 
state.config.tls.is_some() { - "https" - } else { - "http" - }; + let stream: Pin< + Box> + Send + 'static>, + > = Box::pin(ReceiverStream::new(rx)); + Ok(Response::new(stream)) +} - Ok(Response::new(CreateSshSessionResponse { - sandbox_id: req.sandbox_id, - token, - gateway_host, - gateway_port: gateway_port.into(), - gateway_scheme: scheme.to_string(), - connect_path: state.config.ssh_connect_path.clone(), - host_key_fingerprint: String::new(), - expires_at_ms, - })) +struct ForwardConnectionGuard { + state: Arc, + token: Option, + sandbox_id: String, } -pub(super) async fn handle_revoke_ssh_session( +impl Drop for ForwardConnectionGuard { + fn drop(&mut self) { + if let Some(token) = self.token.as_deref() { + decrement_ssh_connection_count(&self.state.ssh_connections_by_token, token); + decrement_ssh_connection_count( + &self.state.ssh_connections_by_sandbox, + &self.sandbox_id, + ); + } + } +} + +async fn acquire_forward_connection_guard( state: &Arc, - request: Request, -) -> Result, Status> { - let token = request.into_inner().token; + init: &TcpForwardInit, + sandbox: &Sandbox, +) -> Result { + let sandbox_id = sandbox.object_id().to_string(); + let token = init.authorization_token.trim(); if token.is_empty() { - return Err(Status::invalid_argument("token is required")); + return Err(Status::unauthenticated( + "authorization_token is required for ForwardTcp", + )); } + validate_ssh_forward_token(state, token, &sandbox_id).await?; + acquire_ssh_connection_slots( + &state.ssh_connections_by_token, + &state.ssh_connections_by_sandbox, + token, + &sandbox_id, + )?; + + Ok(ForwardConnectionGuard { + state: state.clone(), + token: Some(token.to_string()), + sandbox_id, + }) +} + +async fn validate_ssh_forward_token( + state: &Arc, + token: &str, + sandbox_id: &str, +) -> Result<(), Status> { let session = state .store - .get_message::(&token) + .get_message::(token) .await - .map_err(|e| Status::internal(format!("fetch ssh session failed: {e}")))?; + .map_err(|e| 
Status::internal(format!("fetch SSH session failed: {e}")))? + .ok_or_else(|| Status::unauthenticated("SSH session token not found"))?; - let Some(mut session) = session else { - return Ok(Response::new(RevokeSshSessionResponse { revoked: false })); - }; + if session.revoked || session.sandbox_id != sandbox_id { + return Err(Status::unauthenticated("SSH session token is not valid")); + } - session.revoked = true; - state - .store - .put_message(&session) - .await - .map_err(|e| Status::internal(format!("persist ssh session failed: {e}")))?; + if session.expires_at_ms > 0 { + let now_ms = current_time_ms(); + if now_ms > session.expires_at_ms { + return Err(Status::unauthenticated("SSH session token expired")); + } + } - Ok(Response::new(RevokeSshSessionResponse { revoked: true })) + Ok(()) } -// --------------------------------------------------------------------------- -// Exec transport helpers -// --------------------------------------------------------------------------- +fn acquire_ssh_connection_slots( + token_counts: &std::sync::Mutex>, + sandbox_counts: &std::sync::Mutex>, + token: &str, + sandbox_id: &str, +) -> Result<(), Status> { + const MAX_CONNECTIONS_PER_TOKEN: u32 = 3; + const MAX_CONNECTIONS_PER_SANDBOX: u32 = 20; -fn resolve_gateway(config: &openshell_core::Config) -> (String, u16) { - let host = if config.ssh_gateway_host.is_empty() { - config.bind_address.ip().to_string() - } else { - config.ssh_gateway_host.clone() - }; - let port = if config.ssh_gateway_port == 0 { - config.bind_address.port() - } else { - config.ssh_gateway_port - }; - (host, port) -} + { + let mut counts = token_counts.lock().unwrap(); + let count = counts.entry(token.to_string()).or_insert(0); + if *count >= MAX_CONNECTIONS_PER_TOKEN { + return Err(Status::resource_exhausted( + "SSH session connection limit reached", + )); + } + *count += 1; + } -/// Shell-escape a value for embedding in a POSIX shell command. 
-/// -/// Wraps unsafe values in single quotes with the standard `'\''` idiom for -/// embedded single-quote characters. Rejects null bytes which can truncate -/// shell parsing at the C level. -fn shell_escape(value: &str) -> Result { - if value.bytes().any(|b| b == 0) { - return Err("value contains null bytes".to_string()); + { + let mut counts = sandbox_counts.lock().unwrap(); + let count = counts.entry(sandbox_id.to_string()).or_insert(0); + if *count >= MAX_CONNECTIONS_PER_SANDBOX { + decrement_ssh_connection_count(token_counts, token); + return Err(Status::resource_exhausted( + "sandbox SSH connection limit reached", + )); + } + *count += 1; } - if value.bytes().any(|b| b == b'\n' || b == b'\r') { - return Err("value contains newline or carriage return".to_string()); + + Ok(()) +} + +fn decrement_ssh_connection_count( + counts: &std::sync::Mutex>, + key: &str, +) { + let mut counts = counts.lock().unwrap(); + if let Some(count) = counts.get_mut(key) { + *count = count.saturating_sub(1); + if *count == 0 { + counts.remove(key); + } } - if value.is_empty() { - return Ok("''".to_string()); +} + +fn validate_tcp_forward_init(init: &TcpForwardInit) -> Result { + if init.sandbox_id.is_empty() { + return Err(Status::invalid_argument("sandbox_id is required")); + } + + if let Some(target) = init.target.as_ref() { + return match target { + tcp_forward_init::Target::Ssh(_) => { + Ok(relay_open::Target::Ssh(SshRelayTarget::default())) + } + tcp_forward_init::Target::Tcp(target) => Ok(relay_open::Target::Tcp( + validate_tcp_forward_target(target)?, + )), + }; + } + + Err(Status::invalid_argument("tcp forward target is required")) +} + +fn validate_tcp_forward_target(target: &TcpRelayTarget) -> Result { + if target.port == 0 || target.port > u32::from(u16::MAX) { + return Err(Status::invalid_argument( + "tcp target port must be between 1 and 65535", + )); + } + + validate_tcp_target_parts(target.host.trim(), target.port).map(|host| TcpRelayTarget { + host, + port: 
target.port, + }) +} + +fn validate_tcp_target_parts(host: &str, _port: u32) -> Result { + if host.is_empty() { + return Err(Status::invalid_argument("tcp target host is required")); + } + if host.eq_ignore_ascii_case("localhost") { + return Ok("127.0.0.1".to_string()); + } + + let ip: IpAddr = host + .parse() + .map_err(|_| Status::invalid_argument("tcp target host must be loopback"))?; + if ip.is_loopback() { + Ok(ip.to_string()) + } else { + Err(Status::invalid_argument("tcp target host must be loopback")) + } +} + +async fn bridge_forward_tcp_stream( + mut inbound: tonic::Streaming, + relay_stream: tokio::io::DuplexStream, + tx: mpsc::Sender>, + sandbox_id: &str, + channel_id: &str, +) { + let (mut relay_read, mut relay_write) = tokio::io::split(relay_stream); + + let sandbox_id_in = sandbox_id.to_string(); + let channel_id_in = channel_id.to_string(); + tokio::spawn(async move { + loop { + match inbound.message().await { + Ok(Some(frame)) => { + let Some(openshell_core::proto::tcp_forward_frame::Payload::Data(data)) = + frame.payload + else { + warn!(sandbox_id = %sandbox_id_in, channel_id = %channel_id_in, "ForwardTcp: received non-data frame after init"); + break; + }; + if data.is_empty() { + continue; + } + if let Err(err) = + tokio::io::AsyncWriteExt::write_all(&mut relay_write, &data).await + { + warn!(sandbox_id = %sandbox_id_in, channel_id = %channel_id_in, error = %err, "ForwardTcp: write to relay failed"); + break; + } + } + Ok(None) => break, + Err(err) => { + debug!(sandbox_id = %sandbox_id_in, channel_id = %channel_id_in, error = %err, "ForwardTcp: inbound stream ended"); + break; + } + } + } + let _ = tokio::io::AsyncWriteExt::shutdown(&mut relay_write).await; + }); + + let mut buf = vec![0u8; TCP_FORWARD_CHUNK_SIZE]; + loop { + match tokio::io::AsyncReadExt::read(&mut relay_read, &mut buf).await { + Ok(0) => break, + Ok(n) => { + let frame = TcpForwardFrame { + payload: Some(openshell_core::proto::tcp_forward_frame::Payload::Data( + 
buf[..n].to_vec(), + )), + }; + if tx.send(Ok(frame)).await.is_err() { + break; + } + } + Err(err) => { + warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, error = %err, "ForwardTcp: read from relay failed"); + let _ = tx + .send(Err(Status::unavailable(format!( + "relay read failed: {err}" + )))) + .await; + break; + } + } + } +} + +// --------------------------------------------------------------------------- +// Interactive exec handler (bidirectional stdin streaming) +// --------------------------------------------------------------------------- + +fn validate_interactive_exec_start( + msg: Option, +) -> Result { + use openshell_core::proto::exec_sandbox_input::Payload; + + let msg = + msg.ok_or_else(|| Status::invalid_argument("empty stream: expected start message"))?; + + let Some(Payload::Start(req)) = msg.payload else { + return Err(Status::invalid_argument( + "first message must be a start payload", + )); + }; + + if req.sandbox_id.is_empty() { + return Err(Status::invalid_argument("sandbox_id is required")); + } + if req.command.is_empty() { + return Err(Status::invalid_argument("command is required")); + } + if req.environment.keys().any(|key| !is_valid_env_key(key)) { + return Err(Status::invalid_argument( + "environment keys must match ^[A-Za-z_][A-Za-z0-9_]*$", + )); + } + validate_exec_request_fields(&req)?; + + Ok(req) +} + +pub(super) async fn handle_exec_sandbox_interactive( + state: &Arc, + request: Request>, +) -> Result>>, Status> { + use openshell_core::ObjectId; + + let mut input_stream = request.into_inner(); + + let first_msg = input_stream + .message() + .await + .map_err(|e| Status::internal(format!("failed to read first message: {e}")))?; + + let req = validate_interactive_exec_start(first_msg)?; + + let sandbox = state + .store + .get_message::(&req.sandbox_id) + .await + .map_err(|e| Status::internal(format!("fetch sandbox failed: {e}")))? 
+ .ok_or_else(|| Status::not_found("sandbox not found"))?; + + if SandboxPhase::try_from(sandbox.phase).ok() != Some(SandboxPhase::Ready) { + return Err(Status::failed_precondition("sandbox is not ready")); + } + + let (channel_id, relay_rx) = state + .supervisor_sessions + .open_relay(sandbox.object_id(), std::time::Duration::from_secs(15)) + .await + .map_err(|e| Status::unavailable(format!("supervisor relay failed: {e}")))?; + + let command_str = build_remote_exec_command(&req) + .map_err(|e| Status::invalid_argument(format!("command construction failed: {e}")))?; + let timeout_seconds = req.timeout_seconds; + let cols = if req.cols == 0 { 80 } else { req.cols }; + let rows = if req.rows == 0 { 24 } else { req.rows }; + + let sandbox_id = sandbox.object_id().to_string(); + + let (tx, rx) = mpsc::channel::>(256); + tokio::spawn(async move { + let Some(relay_stream) = await_relay_stream( + relay_rx, + &tx, + &sandbox_id, + &channel_id, + "ExecSandboxInteractive", + ) + .await + else { + return; + }; + + if let Err(err) = stream_interactive_exec_over_relay( + tx.clone(), + &sandbox_id, + &channel_id, + relay_stream, + &command_str, + input_stream, + timeout_seconds, + cols, + rows, + ) + .await + { + warn!(sandbox_id = %sandbox_id, error = %err, "ExecSandboxInteractive failed"); + let _ = tx.send(Err(err)).await; + } + }); + + Ok(Response::new(ReceiverStream::new(rx))) +} + +// --------------------------------------------------------------------------- +// SSH session handlers +// --------------------------------------------------------------------------- + +pub(super) async fn handle_create_ssh_session( + state: &Arc, + request: Request, +) -> Result, Status> { + let req = request.into_inner(); + if req.sandbox_id.is_empty() { + return Err(Status::invalid_argument("sandbox_id is required")); + } + + let sandbox = state + .store + .get_message::(&req.sandbox_id) + .await + .map_err(|e| Status::internal(format!("fetch sandbox failed: {e}")))? 
+ .ok_or_else(|| Status::not_found("sandbox not found"))?; + + if SandboxPhase::try_from(sandbox.phase).ok() != Some(SandboxPhase::Ready) { + return Err(Status::failed_precondition("sandbox is not ready")); + } + + let token = uuid::Uuid::new_v4().to_string(); + let now_ms = current_time_ms(); + let expires_at_ms = if state.config.ssh_session_ttl_secs > 0 { + now_ms + (state.config.ssh_session_ttl_secs as i64 * 1000) + } else { + 0 + }; + let session = SshSession { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: token.clone(), + name: generate_name(), + created_at_ms: now_ms, + labels: std::collections::HashMap::new(), + resource_version: 0, + }), + sandbox_id: req.sandbox_id.clone(), + token: token.clone(), + revoked: false, + expires_at_ms, + }; + + // Ensure metadata is valid (defense in depth - should always be true for server-constructed metadata) + super::validation::validate_object_metadata(session.metadata.as_ref(), "ssh_session")?; + + // Use MustCreate to atomically ensure the session token is unique + state + .store + .put_if( + SshSession::object_type(), + &token, + session.object_name(), + &session.encode_to_vec(), + None, + WriteCondition::MustCreate, + ) + .await + .map_err(|e| Status::internal(format!("persist ssh session failed: {e}")))?; + + let (gateway_host, gateway_port) = resolve_gateway(&state.config); + let scheme = if state.config.tls.is_some() { + "https" + } else { + "http" + }; + + Ok(Response::new(CreateSshSessionResponse { + sandbox_id: req.sandbox_id, + token, + gateway_host, + gateway_port: gateway_port.into(), + gateway_scheme: scheme.to_string(), + host_key_fingerprint: String::new(), + expires_at_ms, + })) +} + +pub(super) async fn handle_revoke_ssh_session( + state: &Arc, + request: Request, +) -> Result, Status> { + let token = request.into_inner().token; + if token.is_empty() { + return Err(Status::invalid_argument("token is required")); + } + + let session = state + .store + .get_message::(&token) + 
.await + .map_err(|e| Status::internal(format!("fetch ssh session failed: {e}")))?; + + let Some(mut session) = session else { + return Ok(Response::new(RevokeSshSessionResponse { revoked: false })); + }; + + let resource_version = session + .metadata + .as_ref() + .map_or(0, |metadata| metadata.resource_version); + + session.revoked = true; + + // Use CAS to prevent lost updates from concurrent revocations + state + .store + .put_if( + SshSession::object_type(), + session.object_id(), + session.object_name(), + &session.encode_to_vec(), + None, + WriteCondition::MatchResourceVersion(resource_version), + ) + .await + .map_err(|e| super::persistence_error_to_status(e, "revoke ssh session"))?; + + Ok(Response::new(RevokeSshSessionResponse { revoked: true })) +} + +// --------------------------------------------------------------------------- +// Exec transport helpers +// --------------------------------------------------------------------------- + +fn resolve_gateway(config: &openshell_core::Config) -> (String, u16) { + ( + config.bind_address.ip().to_string(), + config.bind_address.port(), + ) +} + +/// Shell-escape a value for embedding in a POSIX shell command. +/// +/// Wraps unsafe values in single quotes with the standard `'\''` idiom for +/// embedded single-quote characters. Rejects null bytes which can truncate +/// shell parsing at the C level. +fn shell_escape(value: &str) -> Result { + if value.bytes().any(|b| b == 0) { + return Err("value contains null bytes".to_string()); + } + if value.bytes().any(|b| b == b'\n' || b == b'\r') { + return Err("value contains newline or carriage return".to_string()); + } + if value.is_empty() { + return Ok("''".to_string()); } let safe = value .bytes() @@ -706,8 +1419,7 @@ fn build_remote_exec_command(req: &ExecSandboxRequest) -> Result /// /// This is the relay equivalent of `stream_exec_over_ssh`. 
Instead of dialing a /// sandbox endpoint directly, the SSH transport runs over a `DuplexStream` that -/// is bridged to the supervisor's local SSH daemon via a reverse HTTP CONNECT -/// tunnel. +/// is bridged to the supervisor's local SSH daemon via `RelayStream`. #[allow(clippy::too_many_arguments)] async fn stream_exec_over_relay( tx: mpsc::Sender>, @@ -783,14 +1495,216 @@ async fn stream_exec_over_relay( Ok(()) } -/// Create a localhost SSH proxy that bridges to a relay `DuplexStream`. -/// -/// The proxy forwards raw SSH bytes between the `russh` client and the relay. -/// The supervisor bridges the relay to its Unix-socket SSH daemon; filesystem -/// permissions on that socket are the only access-control boundary. -async fn start_single_use_ssh_proxy_over_relay( - mut relay_stream: tokio::io::DuplexStream, -) -> Result<(u16, tokio::task::JoinHandle<()>), Box> { +#[allow(clippy::too_many_arguments)] +async fn stream_interactive_exec_over_relay( + tx: mpsc::Sender>, + sandbox_id: &str, + channel_id: &str, + relay_stream: tokio::io::DuplexStream, + command: &str, + input_stream: tonic::Streaming, + timeout_seconds: u32, + cols: u32, + rows: u32, +) -> Result<(), Status> { + let command_preview: String = command.chars().take(120).collect(); + info!( + sandbox_id = %sandbox_id, + channel_id = %channel_id, + command_len = command.len(), + command_preview = %command_preview, + "ExecSandboxInteractive (relay): command started" + ); + + let (local_proxy_port, proxy_task) = start_single_use_ssh_proxy_over_relay(relay_stream) + .await + .map_err(|e| Status::internal(format!("failed to start relay proxy: {e}")))?; + + let exec = run_interactive_exec_with_russh( + local_proxy_port, + command, + input_stream, + cols, + rows, + tx.clone(), + ); + + let exec_result = if timeout_seconds == 0 { + exec.await + } else if let Ok(r) = tokio::time::timeout( + std::time::Duration::from_secs(u64::from(timeout_seconds)), + exec, + ) + .await + { + r + } else { + let _ = tx + 
.send(Ok(ExecSandboxEvent { + payload: Some(openshell_core::proto::exec_sandbox_event::Payload::Exit( + ExecSandboxExit { exit_code: 124 }, + )), + })) + .await; + let _ = proxy_task.await; + return Ok(()); + }; + + let exit_code = match exec_result { + Ok(code) => code, + Err(status) => { + let _ = proxy_task.await; + return Err(status); + } + }; + + let _ = proxy_task.await; + + let _ = tx + .send(Ok(ExecSandboxEvent { + payload: Some(openshell_core::proto::exec_sandbox_event::Payload::Exit( + ExecSandboxExit { exit_code }, + )), + })) + .await; + + Ok(()) +} + +async fn run_interactive_exec_with_russh( + local_proxy_port: u16, + command: &str, + mut input_stream: tonic::Streaming, + cols: u32, + rows: u32, + tx: mpsc::Sender>, +) -> Result { + use openshell_core::proto::exec_sandbox_input::Payload; + use russh::ChannelMsg; + + if command.as_bytes().contains(&0) { + return Err(Status::invalid_argument( + "command contains null bytes at transport boundary", + )); + } + if command.len() > MAX_COMMAND_STRING_LEN { + return Err(Status::invalid_argument(format!( + "command exceeds {MAX_COMMAND_STRING_LEN} byte limit at transport boundary" + ))); + } + + let stream = TcpStream::connect(("127.0.0.1", local_proxy_port)) + .await + .map_err(|e| Status::internal(format!("failed to connect to ssh proxy: {e}")))?; + + let config = Arc::new(russh::client::Config::default()); + let mut client = russh::client::connect_stream(config, stream, SandboxSshClientHandler) + .await + .map_err(|e| Status::internal(format!("failed to establish ssh transport: {e}")))?; + + match client + .authenticate_none("sandbox") + .await + .map_err(|e| Status::internal(format!("failed to authenticate ssh session: {e}")))? + { + AuthResult::Success => {} + AuthResult::Failure { .. 
} => { + return Err(Status::permission_denied( + "ssh authentication rejected by sandbox", + )); + } + } + + let channel = client + .channel_open_session() + .await + .map_err(|e| Status::internal(format!("failed to open ssh channel: {e}")))?; + + channel + .request_pty(false, "xterm-256color", cols, rows, 0, 0, &[]) + .await + .map_err(|e| Status::internal(format!("failed to allocate PTY: {e}")))?; + + channel + .exec(true, command.as_bytes()) + .await + .map_err(|e| Status::internal(format!("failed to execute command over ssh: {e}")))?; + + let (mut read_half, write_half) = channel.split(); + + let stdin_task = tokio::spawn(async move { + while let Ok(Some(msg)) = input_stream.message().await { + match msg.payload { + Some(Payload::Stdin(data)) => { + if write_half.data(std::io::Cursor::new(data)).await.is_err() { + break; + } + } + Some(Payload::Resize(resize)) => { + let _ = write_half + .window_change(resize.cols, resize.rows, 0, 0) + .await; + } + Some(Payload::Start(_)) | None => {} + } + } + let _ = write_half.eof().await; + let _ = write_half.close().await; + }); + + let mut exit_code: Option = None; + while let Some(msg) = read_half.wait().await { + match msg { + ChannelMsg::Data { data } => { + let event = Ok(ExecSandboxEvent { + payload: Some(openshell_core::proto::exec_sandbox_event::Payload::Stdout( + ExecSandboxStdout { + data: data.to_vec(), + }, + )), + }); + if tx.send(event).await.is_err() { + break; + } + } + ChannelMsg::ExtendedData { data, .. 
} => { + let event = Ok(ExecSandboxEvent { + payload: Some(openshell_core::proto::exec_sandbox_event::Payload::Stderr( + ExecSandboxStderr { + data: data.to_vec(), + }, + )), + }); + if tx.send(event).await.is_err() { + break; + } + } + ChannelMsg::ExitStatus { exit_status } => { + let converted = i32::try_from(exit_status).unwrap_or(i32::MAX); + exit_code = Some(converted); + } + ChannelMsg::Close => break, + _ => {} + } + } + + stdin_task.abort(); + + let _ = client + .disconnect(russh::Disconnect::ByApplication, "exec complete", "en") + .await; + + Ok(exit_code.unwrap_or(1)) +} + +/// Create a localhost SSH proxy that bridges to a relay `DuplexStream`. +/// +/// The proxy forwards raw SSH bytes between the `russh` client and the relay. +/// The supervisor bridges the relay to its Unix-socket SSH daemon; filesystem +/// permissions on that socket are the only access-control boundary. +async fn start_single_use_ssh_proxy_over_relay( + mut relay_stream: tokio::io::DuplexStream, +) -> Result<(u16, tokio::task::JoinHandle<()>), Box> { let listener = TcpListener::bind(("127.0.0.1", 0)).await?; let port = listener.local_addr()?.port(); @@ -938,6 +1852,9 @@ async fn run_exec_with_russh( #[cfg(test)] mod tests { use super::*; + use crate::grpc::test_support::test_server_state; + use openshell_core::proto::datamodel::v1::ObjectMeta; + use std::collections::HashMap; // ---- shell_escape ---- @@ -1034,6 +1951,87 @@ mod tests { assert!(build_remote_exec_command(&req).is_err()); } + #[test] + fn tcp_forward_init_allows_loopback_targets() { + for host in ["127.0.0.1", "::1", "localhost"] { + let init = TcpForwardInit { + sandbox_id: "sbx".to_string(), + service_id: String::new(), + target: Some(tcp_forward_init::Target::Tcp(TcpRelayTarget { + host: host.to_string(), + port: 8080, + })), + authorization_token: String::new(), + }; + validate_tcp_forward_init(&init).expect("loopback target should pass"); + } + } + + #[test] + fn tcp_forward_init_allows_ssh_target() { + let init = 
TcpForwardInit { + sandbox_id: "sbx".to_string(), + target: Some(tcp_forward_init::Target::Ssh(SshRelayTarget::default())), + ..Default::default() + }; + match validate_tcp_forward_init(&init).expect("ssh target should pass") { + relay_open::Target::Ssh(_) => {} + other @ relay_open::Target::Tcp(_) => panic!("expected SSH target, got {other:?}"), + } + } + + #[test] + fn tcp_forward_init_rejects_non_loopback_targets() { + let init = TcpForwardInit { + sandbox_id: "sbx".to_string(), + service_id: String::new(), + target: Some(tcp_forward_init::Target::Tcp(TcpRelayTarget { + host: "example.com".to_string(), + port: 8080, + })), + authorization_token: String::new(), + }; + assert_eq!( + validate_tcp_forward_init(&init) + .expect_err("hostname rejected") + .message(), + "tcp target host must be loopback" + ); + } + + #[test] + fn tcp_forward_init_rejects_invalid_port() { + let init = TcpForwardInit { + sandbox_id: "sbx".to_string(), + service_id: String::new(), + target: Some(tcp_forward_init::Target::Tcp(TcpRelayTarget { + host: "127.0.0.1".to_string(), + port: 0, + })), + authorization_token: String::new(), + }; + assert_eq!( + validate_tcp_forward_init(&init) + .expect_err("zero port rejected") + .message(), + "tcp target port must be between 1 and 65535" + ); + } + + #[test] + fn tcp_forward_init_requires_target() { + let init = TcpForwardInit { + sandbox_id: "sbx".to_string(), + ..Default::default() + }; + assert_eq!( + validate_tcp_forward_init(&init) + .expect_err("missing target rejected") + .message(), + "tcp forward target is required" + ); + } + // ---- petname / generate_name ---- #[test] @@ -1066,4 +2064,1069 @@ mod tests { ); } } + + fn test_provider(name: &str, provider_type: &str) -> Provider { + test_provider_with_credential_key(name, provider_type, "TOKEN") + } + + fn test_provider_with_credential_key( + name: &str, + provider_type: &str, + credential_key: &str, + ) -> Provider { + Provider { + metadata: Some(ObjectMeta { + id: 
format!("provider-{name}"), + name: name.to_string(), + created_at_ms: 1_000_000, + labels: HashMap::new(), + resource_version: 0, + }), + r#type: provider_type.to_string(), + credentials: std::iter::once((credential_key.to_string(), "secret".to_string())) + .collect(), + config: HashMap::new(), + credential_expires_at_ms: HashMap::new(), + } + } + + fn test_sandbox(name: &str, providers: Vec) -> Sandbox { + Sandbox { + metadata: Some(ObjectMeta { + id: format!("sandbox-{name}"), + name: name.to_string(), + created_at_ms: 1_000_000, + labels: std::iter::once(("team".to_string(), "agents".to_string())).collect(), + resource_version: 0, + }), + spec: Some(openshell_core::proto::SandboxSpec { + log_level: "debug".to_string(), + policy: Some(openshell_core::proto::SandboxPolicy::default()), + providers, + ..Default::default() + }), + phase: SandboxPhase::Ready as i32, + current_policy_version: 7, + ..Default::default() + } + } + + #[tokio::test] + async fn attach_sandbox_provider_persists_current_provider_list() { + let state = test_server_state().await; + state + .store + .put_message(&test_provider("work-github", "github")) + .await + .unwrap(); + state + .store + .put_message(&test_sandbox("work", Vec::new())) + .await + .unwrap(); + + let response = handle_attach_sandbox_provider( + &state, + Request::new(AttachSandboxProviderRequest { + sandbox_name: "work".to_string(), + provider_name: "work-github".to_string(), + expected_resource_version: 0, + }), + ) + .await + .unwrap() + .into_inner(); + + assert!(response.attached); + let sandbox = state + .store + .get_message_by_name::("work") + .await + .unwrap() + .unwrap(); + let spec = sandbox.spec.unwrap(); + assert_eq!(spec.providers, vec!["work-github"]); + assert_eq!(spec.log_level, "debug"); + assert_eq!(sandbox.phase, SandboxPhase::Ready as i32); + assert_eq!(sandbox.current_policy_version, 7); + } + + #[tokio::test] + async fn attach_sandbox_provider_is_idempotent_and_avoids_duplicates() { + let state = 
test_server_state().await; + state + .store + .put_message(&test_provider("work-github", "github")) + .await + .unwrap(); + state + .store + .put_message(&test_sandbox( + "work", + vec!["work-github".to_string(), "work-github".to_string()], + )) + .await + .unwrap(); + + let response = handle_attach_sandbox_provider( + &state, + Request::new(AttachSandboxProviderRequest { + sandbox_name: "work".to_string(), + provider_name: "work-github".to_string(), + expected_resource_version: 0, + }), + ) + .await + .unwrap() + .into_inner(); + + assert!(!response.attached); + let providers = state + .store + .get_message_by_name::("work") + .await + .unwrap() + .unwrap() + .spec + .unwrap() + .providers; + assert_eq!(providers, vec!["work-github"]); + } + + #[tokio::test] + async fn detach_sandbox_provider_is_idempotent_and_removes_all_matches() { + let state = test_server_state().await; + state + .store + .put_message(&test_sandbox( + "work", + vec![ + "work-github".to_string(), + "other".to_string(), + "work-github".to_string(), + ], + )) + .await + .unwrap(); + + let response = handle_detach_sandbox_provider( + &state, + Request::new(DetachSandboxProviderRequest { + sandbox_name: "work".to_string(), + provider_name: "work-github".to_string(), + expected_resource_version: 0, + }), + ) + .await + .unwrap() + .into_inner(); + + assert!(response.detached); + let providers = state + .store + .get_message_by_name::("work") + .await + .unwrap() + .unwrap() + .spec + .unwrap() + .providers; + assert_eq!(providers, vec!["other"]); + + let response = handle_detach_sandbox_provider( + &state, + Request::new(DetachSandboxProviderRequest { + sandbox_name: "work".to_string(), + provider_name: "work-github".to_string(), + expected_resource_version: 0, + }), + ) + .await + .unwrap() + .into_inner(); + assert!(!response.detached); + } + + #[tokio::test] + async fn list_sandbox_providers_returns_attached_provider_records() { + let state = test_server_state().await; + state + .store + 
.put_message(&test_provider("work-github", "github")) + .await + .unwrap(); + state + .store + .put_message(&test_sandbox("work", vec!["work-github".to_string()])) + .await + .unwrap(); + + let response = handle_list_sandbox_providers( + &state, + Request::new(ListSandboxProvidersRequest { + sandbox_name: "work".to_string(), + }), + ) + .await + .unwrap() + .into_inner(); + + assert_eq!(response.providers.len(), 1); + assert_eq!(response.providers[0].r#type, "github"); + assert_eq!( + response.providers[0].credentials.get("TOKEN"), + Some(&"REDACTED".to_string()) + ); + } + + #[tokio::test] + async fn attach_sandbox_provider_validates_provider_exists() { + let state = test_server_state().await; + state + .store + .put_message(&test_sandbox("work", Vec::new())) + .await + .unwrap(); + + let err = handle_attach_sandbox_provider( + &state, + Request::new(AttachSandboxProviderRequest { + sandbox_name: "work".to_string(), + provider_name: "missing".to_string(), + expected_resource_version: 0, + }), + ) + .await + .unwrap_err(); + + assert_eq!(err.code(), tonic::Code::FailedPrecondition); + } + + // ---- validate_interactive_exec_start ---- + + #[test] + fn interactive_exec_rejects_empty_stream() { + let err = validate_interactive_exec_start(None).unwrap_err(); + assert_eq!(err.code(), tonic::Code::InvalidArgument); + assert!(err.message().contains("expected start message")); + } + + #[test] + fn interactive_exec_rejects_stdin_as_first_message() { + use openshell_core::proto::exec_sandbox_input; + let msg = ExecSandboxInput { + payload: Some(exec_sandbox_input::Payload::Stdin(b"hello".to_vec())), + }; + let err = validate_interactive_exec_start(Some(msg)).unwrap_err(); + assert_eq!(err.code(), tonic::Code::InvalidArgument); + assert!(err.message().contains("start payload")); + } + + #[test] + fn interactive_exec_rejects_resize_as_first_message() { + use openshell_core::proto::{ExecSandboxWindowResize, exec_sandbox_input}; + let msg = ExecSandboxInput { + payload: 
Some(exec_sandbox_input::Payload::Resize( + ExecSandboxWindowResize { cols: 80, rows: 24 }, + )), + }; + let err = validate_interactive_exec_start(Some(msg)).unwrap_err(); + assert_eq!(err.code(), tonic::Code::InvalidArgument); + assert!(err.message().contains("start payload")); + } + + #[test] + fn interactive_exec_rejects_none_payload() { + let msg = ExecSandboxInput { payload: None }; + let err = validate_interactive_exec_start(Some(msg)).unwrap_err(); + assert_eq!(err.code(), tonic::Code::InvalidArgument); + } + + #[test] + fn interactive_exec_rejects_missing_sandbox_id() { + use openshell_core::proto::exec_sandbox_input; + let msg = ExecSandboxInput { + payload: Some(exec_sandbox_input::Payload::Start(ExecSandboxRequest { + command: vec!["bash".to_string()], + ..Default::default() + })), + }; + let err = validate_interactive_exec_start(Some(msg)).unwrap_err(); + assert_eq!(err.code(), tonic::Code::InvalidArgument); + assert!(err.message().contains("sandbox_id")); + } + + #[test] + fn interactive_exec_rejects_missing_command() { + use openshell_core::proto::exec_sandbox_input; + let msg = ExecSandboxInput { + payload: Some(exec_sandbox_input::Payload::Start(ExecSandboxRequest { + sandbox_id: "test-id".to_string(), + ..Default::default() + })), + }; + let err = validate_interactive_exec_start(Some(msg)).unwrap_err(); + assert_eq!(err.code(), tonic::Code::InvalidArgument); + assert!(err.message().contains("command")); + } + + #[test] + fn interactive_exec_rejects_invalid_env_key() { + use openshell_core::proto::exec_sandbox_input; + let msg = ExecSandboxInput { + payload: Some(exec_sandbox_input::Payload::Start(ExecSandboxRequest { + sandbox_id: "test-id".to_string(), + command: vec!["bash".to_string()], + environment: std::iter::once(("bad key!".to_string(), "val".to_string())).collect(), + ..Default::default() + })), + }; + let err = validate_interactive_exec_start(Some(msg)).unwrap_err(); + assert_eq!(err.code(), tonic::Code::InvalidArgument); + 
assert!(err.message().contains("environment")); + } + + #[test] + fn interactive_exec_accepts_valid_start() { + use openshell_core::proto::exec_sandbox_input; + let msg = ExecSandboxInput { + payload: Some(exec_sandbox_input::Payload::Start(ExecSandboxRequest { + sandbox_id: "test-id".to_string(), + command: vec!["bash".to_string()], + tty: true, + cols: 120, + rows: 40, + ..Default::default() + })), + }; + let req = validate_interactive_exec_start(Some(msg)).unwrap(); + assert_eq!(req.sandbox_id, "test-id"); + assert_eq!(req.command, vec!["bash"]); + assert!(req.tty); + assert_eq!(req.cols, 120); + assert_eq!(req.rows, 40); + } + + #[tokio::test] + async fn interactive_exec_rejects_sandbox_not_found() { + let state = test_server_state().await; + + let req = ExecSandboxRequest { + sandbox_id: "nonexistent".to_string(), + command: vec!["bash".to_string()], + tty: true, + ..Default::default() + }; + let sandbox_result = state + .store + .get_message::(&req.sandbox_id) + .await + .unwrap(); + assert!(sandbox_result.is_none()); + } + + #[tokio::test] + async fn interactive_exec_rejects_sandbox_not_ready() { + let state = test_server_state().await; + let mut sandbox = test_sandbox("not-ready", Vec::new()); + sandbox.phase = SandboxPhase::Provisioning as i32; + state.store.put_message(&sandbox).await.unwrap(); + + let stored = state + .store + .get_message::("sandbox-not-ready") + .await + .unwrap() + .unwrap(); + assert_ne!( + SandboxPhase::try_from(stored.phase).ok(), + Some(SandboxPhase::Ready) + ); + } + + #[tokio::test] + async fn create_sandbox_rejects_provider_credential_key_collisions() { + let state = test_server_state().await; + state + .store + .put_message(&test_provider("provider-a", "outlook")) + .await + .unwrap(); + state + .store + .put_message(&test_provider("provider-b", "google-drive")) + .await + .unwrap(); + + let err = handle_create_sandbox( + &state, + Request::new(CreateSandboxRequest { + name: "collision".to_string(), + spec: 
Some(openshell_core::proto::SandboxSpec { + providers: vec!["provider-a".to_string(), "provider-b".to_string()], + ..Default::default() + }), + labels: HashMap::new(), + }), + ) + .await + .unwrap_err(); + + assert_eq!(err.code(), tonic::Code::FailedPrecondition); + assert!(err.message().contains("TOKEN")); + assert!(err.message().contains("provider-a")); + assert!(err.message().contains("provider-b")); + } + + #[tokio::test] + async fn attach_sandbox_provider_rejects_credential_key_collisions() { + let state = test_server_state().await; + state + .store + .put_message(&test_provider("provider-a", "outlook")) + .await + .unwrap(); + state + .store + .put_message(&test_provider("provider-b", "google-drive")) + .await + .unwrap(); + state + .store + .put_message(&test_sandbox("work", vec!["provider-a".to_string()])) + .await + .unwrap(); + + let err = handle_attach_sandbox_provider( + &state, + Request::new(AttachSandboxProviderRequest { + sandbox_name: "work".to_string(), + provider_name: "provider-b".to_string(), + expected_resource_version: 0, + }), + ) + .await + .unwrap_err(); + + assert_eq!(err.code(), tonic::Code::FailedPrecondition); + assert!(err.message().contains("TOKEN")); + assert!(err.message().contains("provider-a")); + assert!(err.message().contains("provider-b")); + } + + #[tokio::test] + async fn attach_sandbox_provider_accepts_at_max_providers_limit() { + let state = test_server_state().await; + + // Create MAX_PROVIDERS (32) providers + for i in 0..MAX_PROVIDERS { + state + .store + .put_message(&test_provider_with_credential_key( + &format!("provider-{i}"), + "generic", + &format!("TOKEN_{i}"), + )) + .await + .unwrap(); + } + + // Create sandbox with 31 providers already attached + let mut existing_providers = Vec::new(); + for i in 0..(MAX_PROVIDERS - 1) { + existing_providers.push(format!("provider-{i}")); + } + state + .store + .put_message(&test_sandbox("work", existing_providers)) + .await + .unwrap(); + + // Attaching the 32nd provider 
should succeed + let response = handle_attach_sandbox_provider( + &state, + Request::new(AttachSandboxProviderRequest { + sandbox_name: "work".to_string(), + provider_name: "provider-31".to_string(), + expected_resource_version: 0, + }), + ) + .await + .unwrap() + .into_inner(); + + assert!(response.attached); + let providers = state + .store + .get_message_by_name::("work") + .await + .unwrap() + .unwrap() + .spec + .unwrap() + .providers; + assert_eq!(providers.len(), MAX_PROVIDERS); + } + + #[tokio::test] + async fn attach_sandbox_provider_rejects_beyond_max_providers_limit() { + let state = test_server_state().await; + + // Create MAX_PROVIDERS + 1 providers + for i in 0..=MAX_PROVIDERS { + state + .store + .put_message(&test_provider_with_credential_key( + &format!("provider-{i}"), + "generic", + &format!("TOKEN_{i}"), + )) + .await + .unwrap(); + } + + // Create sandbox with MAX_PROVIDERS already attached + let mut existing_providers = Vec::new(); + for i in 0..MAX_PROVIDERS { + existing_providers.push(format!("provider-{i}")); + } + state + .store + .put_message(&test_sandbox("work", existing_providers)) + .await + .unwrap(); + + // Attempting to attach the 33rd provider should fail + let err = handle_attach_sandbox_provider( + &state, + Request::new(AttachSandboxProviderRequest { + sandbox_name: "work".to_string(), + provider_name: "provider-32".to_string(), + expected_resource_version: 0, + }), + ) + .await + .unwrap_err(); + + assert_eq!(err.code(), tonic::Code::InvalidArgument); + assert!(err.message().contains("exceeds maximum")); + + // Verify sandbox was not modified + let providers = state + .store + .get_message_by_name::("work") + .await + .unwrap() + .unwrap() + .spec + .unwrap() + .providers; + assert_eq!(providers.len(), MAX_PROVIDERS); + } + + #[tokio::test] + async fn attach_sandbox_provider_pre_validation_fails_fast() { + let state = test_server_state().await; + + // Provider name that exceeds validation limits + let long_name = 
"a".repeat(1000); + state + .store + .put_message(&test_provider(&long_name, "generic")) + .await + .unwrap(); + + state + .store + .put_message(&test_sandbox("work", Vec::new())) + .await + .unwrap(); + + // Should fail validation before attempting CAS + let err = handle_attach_sandbox_provider( + &state, + Request::new(AttachSandboxProviderRequest { + sandbox_name: "work".to_string(), + provider_name: long_name, + expected_resource_version: 0, + }), + ) + .await + .unwrap_err(); + + assert_eq!(err.code(), tonic::Code::InvalidArgument); + } + + #[tokio::test] + async fn detach_sandbox_provider_pre_validation_rejects_invalid_names() { + let state = test_server_state().await; + state + .store + .put_message(&test_sandbox("work", vec!["valid".to_string()])) + .await + .unwrap(); + + // Provider name that exceeds validation limits + let long_name = "a".repeat(1000); + + let err = handle_detach_sandbox_provider( + &state, + Request::new(DetachSandboxProviderRequest { + sandbox_name: "work".to_string(), + provider_name: long_name, + expected_resource_version: 0, + }), + ) + .await + .unwrap_err(); + + assert_eq!(err.code(), tonic::Code::InvalidArgument); + } + + #[tokio::test] + async fn concurrent_create_ssh_session_prevents_duplicate_tokens() { + let state = test_server_state().await; + state + .store + .put_message(&test_sandbox("work", Vec::new())) + .await + .unwrap(); + + // Both requests try to create sessions for the same sandbox + // The token generation is random, so we can't force a collision, + // but we can verify that both succeed with different tokens + let state1 = state.clone(); + let handle1 = tokio::spawn(async move { + handle_create_ssh_session( + &state1, + Request::new(CreateSshSessionRequest { + sandbox_id: "sandbox-work".to_string(), + }), + ) + .await + }); + + let state2 = state.clone(); + let handle2 = tokio::spawn(async move { + handle_create_ssh_session( + &state2, + Request::new(CreateSshSessionRequest { + sandbox_id: 
"sandbox-work".to_string(), + }), + ) + .await + }); + + let result1 = handle1.await.unwrap(); + let result2 = handle2.await.unwrap(); + + // Both should succeed (tokens are random UUIDs, collision is astronomically unlikely) + assert!(result1.is_ok(), "first create should succeed"); + assert!(result2.is_ok(), "second create should succeed"); + + let token1 = result1.unwrap().into_inner().token; + let token2 = result2.unwrap().into_inner().token; + + // Tokens must be different + assert_ne!(token1, token2, "tokens should be unique"); + + // Both sessions should be in the database + let session1 = state + .store + .get_message::(&token1) + .await + .unwrap(); + let session2 = state + .store + .get_message::(&token2) + .await + .unwrap(); + assert!(session1.is_some()); + assert!(session2.is_some()); + } + + #[tokio::test] + async fn concurrent_revoke_ssh_session_handles_cas_properly() { + let state = test_server_state().await; + state + .store + .put_message(&test_sandbox("work", Vec::new())) + .await + .unwrap(); + + // Create a session first + let response = handle_create_ssh_session( + &state, + Request::new(CreateSshSessionRequest { + sandbox_id: "sandbox-work".to_string(), + }), + ) + .await + .unwrap(); + let token = response.into_inner().token; + + // Spawn two concurrent revocation attempts + let state1 = state.clone(); + let token1 = token.clone(); + let handle1 = tokio::spawn(async move { + handle_revoke_ssh_session( + &state1, + Request::new(RevokeSshSessionRequest { token: token1 }), + ) + .await + }); + + let state2 = state.clone(); + let token2 = token.clone(); + let handle2 = tokio::spawn(async move { + handle_revoke_ssh_session( + &state2, + Request::new(RevokeSshSessionRequest { token: token2 }), + ) + .await + }); + + let result1 = handle1.await.unwrap(); + let result2 = handle2.await.unwrap(); + + // One should succeed, one may fail with ABORTED due to CAS conflict + let successes = [&result1, &result2] + .iter() + .filter(|r| r.is_ok() && 
r.as_ref().unwrap().get_ref().revoked) + .count(); + + // At least one should succeed in revoking + assert!( + successes >= 1, + "at least one revocation should succeed, got: {result1:?}, {result2:?}" + ); + + // The session should be revoked in the database + let session = state.store.get_message::(&token).await.unwrap(); + assert!(session.is_some()); + assert!(session.unwrap().revoked, "session should be revoked"); + } + + // ---- CAS (Client-driven optimistic concurrency) tests ---- + + #[tokio::test] + async fn attach_sandbox_provider_client_driven_cas_succeeds_with_correct_version() { + let state = test_server_state().await; + state + .store + .put_message(&test_provider("github", "github")) + .await + .unwrap(); + state + .store + .put_message(&test_sandbox("work", Vec::new())) + .await + .unwrap(); + + // Fetch the sandbox to get its current resource_version + let sandbox = state + .store + .get_message_by_name::("work") + .await + .unwrap() + .unwrap(); + let current_version = sandbox.metadata.as_ref().unwrap().resource_version; + + // Attach with correct expected_resource_version + let response = handle_attach_sandbox_provider( + &state, + Request::new(AttachSandboxProviderRequest { + sandbox_name: "work".to_string(), + provider_name: "github".to_string(), + expected_resource_version: current_version, + }), + ) + .await + .unwrap() + .into_inner(); + + assert!(response.attached); + + // Verify the resource_version incremented + let updated_sandbox = state + .store + .get_message_by_name::("work") + .await + .unwrap() + .unwrap(); + assert_eq!( + updated_sandbox.metadata.as_ref().unwrap().resource_version, + current_version + 1 + ); + } + + #[tokio::test] + async fn attach_sandbox_provider_client_driven_cas_rejects_stale_version() { + let state = test_server_state().await; + state + .store + .put_message(&test_provider("github", "github")) + .await + .unwrap(); + state + .store + .put_message(&test_sandbox("work", Vec::new())) + .await + .unwrap(); + + // 
Get current version + let sandbox = state + .store + .get_message_by_name::("work") + .await + .unwrap() + .unwrap(); + let current_version = sandbox.metadata.as_ref().unwrap().resource_version; + + // Try to attach with a stale version (current_version - 1 would be 0, use 99 instead) + let err = handle_attach_sandbox_provider( + &state, + Request::new(AttachSandboxProviderRequest { + sandbox_name: "work".to_string(), + provider_name: "github".to_string(), + expected_resource_version: 99, + }), + ) + .await + .unwrap_err(); + + // Should get ABORTED status for CAS conflict + assert_eq!(err.code(), tonic::Code::Aborted); + assert!( + err.message().contains("modified concurrently") + || err.message().contains("resource_version"), + "error message should mention concurrency conflict: {}", + err.message() + ); + + // Verify the sandbox was not modified + let unchanged_sandbox = state + .store + .get_message_by_name::("work") + .await + .unwrap() + .unwrap(); + assert_eq!( + unchanged_sandbox + .metadata + .as_ref() + .unwrap() + .resource_version, + current_version + ); + assert!(unchanged_sandbox.spec.unwrap().providers.is_empty()); + } + + #[tokio::test] + async fn detach_sandbox_provider_client_driven_cas_succeeds_with_correct_version() { + let state = test_server_state().await; + state + .store + .put_message(&test_provider("github", "github")) + .await + .unwrap(); + state + .store + .put_message(&test_sandbox("work", vec!["github".to_string()])) + .await + .unwrap(); + + // Fetch the sandbox to get its current resource_version + let sandbox = state + .store + .get_message_by_name::("work") + .await + .unwrap() + .unwrap(); + let current_version = sandbox.metadata.as_ref().unwrap().resource_version; + + // Detach with correct expected_resource_version + let response = handle_detach_sandbox_provider( + &state, + Request::new(DetachSandboxProviderRequest { + sandbox_name: "work".to_string(), + provider_name: "github".to_string(), + expected_resource_version: 
current_version, + }), + ) + .await + .unwrap() + .into_inner(); + + assert!(response.detached); + + // Verify the resource_version incremented + let updated_sandbox = state + .store + .get_message_by_name::("work") + .await + .unwrap() + .unwrap(); + assert_eq!( + updated_sandbox.metadata.as_ref().unwrap().resource_version, + current_version + 1 + ); + } + + #[tokio::test] + async fn detach_sandbox_provider_client_driven_cas_rejects_stale_version() { + let state = test_server_state().await; + state + .store + .put_message(&test_provider("github", "github")) + .await + .unwrap(); + state + .store + .put_message(&test_sandbox("work", vec!["github".to_string()])) + .await + .unwrap(); + + // Get current version + let sandbox = state + .store + .get_message_by_name::("work") + .await + .unwrap() + .unwrap(); + let current_version = sandbox.metadata.as_ref().unwrap().resource_version; + + // Try to detach with a stale version + let err = handle_detach_sandbox_provider( + &state, + Request::new(DetachSandboxProviderRequest { + sandbox_name: "work".to_string(), + provider_name: "github".to_string(), + expected_resource_version: 99, + }), + ) + .await + .unwrap_err(); + + // Should get ABORTED status for CAS conflict + assert_eq!(err.code(), tonic::Code::Aborted); + assert!( + err.message().contains("modified concurrently") + || err.message().contains("resource_version"), + "error message should mention concurrency conflict: {}", + err.message() + ); + + // Verify the sandbox was not modified + let unchanged_sandbox = state + .store + .get_message_by_name::("work") + .await + .unwrap() + .unwrap(); + assert_eq!( + unchanged_sandbox + .metadata + .as_ref() + .unwrap() + .resource_version, + current_version + ); + assert_eq!(unchanged_sandbox.spec.unwrap().providers, vec!["github"]); + } + + #[tokio::test] + async fn attach_sandbox_provider_concurrent_with_stale_versions() { + use std::sync::Arc; + + let state = Arc::new(test_server_state().await); + + // Create multiple 
providers + for i in 0..3 { + state + .store + .put_message(&test_provider_with_credential_key( + &format!("provider-{i}"), + "generic", + &format!("TOKEN_{i}"), + )) + .await + .unwrap(); + } + + state + .store + .put_message(&test_sandbox("work", Vec::new())) + .await + .unwrap(); + + // All three clients fetch the sandbox and see version 1 + let initial_version = state + .store + .get_message_by_name::("work") + .await + .unwrap() + .unwrap() + .metadata + .as_ref() + .unwrap() + .resource_version; + + // Launch 3 concurrent attach operations, all using the same initial version + let mut handles = vec![]; + for i in 0..3 { + let state_clone = Arc::clone(&state); + let handle = tokio::spawn(async move { + handle_attach_sandbox_provider( + &state_clone, + Request::new(AttachSandboxProviderRequest { + sandbox_name: "work".to_string(), + provider_name: format!("provider-{i}"), + expected_resource_version: initial_version, + }), + ) + .await + }); + handles.push(handle); + } + + let results: Vec<_> = future::join_all(handles) + .await + .into_iter() + .map(|r| r.unwrap()) + .collect(); + + // Only one should succeed; others should get ABORTED + let successes = results.iter().filter(|r| r.is_ok()).count(); + let aborted_conflicts = results + .iter() + .filter(|r| { + r.as_ref() + .err() + .is_some_and(|e| e.code() == tonic::Code::Aborted) + }) + .count(); + + assert_eq!( + successes, 1, + "exactly one attach should succeed with client-driven CAS" + ); + assert_eq!( + aborted_conflicts, 2, + "two attaches should fail with ABORTED due to stale version" + ); + + // Final sandbox should have exactly 1 provider and resource_version = initial_version + 1 + let final_sandbox = state + .store + .get_message_by_name::("work") + .await + .unwrap() + .unwrap(); + assert_eq!(final_sandbox.spec.as_ref().unwrap().providers.len(), 1); + assert_eq!( + final_sandbox.metadata.as_ref().unwrap().resource_version, + initial_version + 1 + ); + } } diff --git 
a/crates/openshell-server/src/grpc/service.rs b/crates/openshell-server/src/grpc/service.rs new file mode 100644 index 000000000..4d73f2279 --- /dev/null +++ b/crates/openshell-server/src/grpc/service.rs @@ -0,0 +1,547 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use std::collections::HashMap; +use std::sync::Arc; + +use openshell_core::ObjectId; +use openshell_core::proto::datamodel::v1::ObjectMeta; +use openshell_core::proto::{ + DeleteServiceRequest, DeleteServiceResponse, ExposeServiceRequest, GetServiceRequest, + ListServicesRequest, ListServicesResponse, Sandbox, ServiceEndpoint, ServiceEndpointResponse, +}; +use prost::Message as _; +use tonic::{Request, Response, Status}; +use uuid::Uuid; + +use crate::ServerState; +use crate::persistence::{ObjectType, WriteCondition}; +use crate::service_routing; + +const MAX_SERVICE_NAME_LEN: usize = 28; +const MAX_SANDBOX_NAME_LEN: usize = 28; + +pub(super) async fn handle_expose_service( + state: &Arc<ServerState>, + request: Request<ExposeServiceRequest>, +) -> Result<Response<ServiceEndpointResponse>, Status> { + let req = request.into_inner(); + validate_endpoint_name("sandbox", &req.sandbox, MAX_SANDBOX_NAME_LEN)?; + validate_optional_endpoint_name("service", &req.service, MAX_SERVICE_NAME_LEN)?; + if req.target_port == 0 || req.target_port > u32::from(u16::MAX) { + return Err(Status::invalid_argument("target_port must be in 1..=65535")); + } + + let sandbox = state + .store + .get_message_by_name::<Sandbox>(&req.sandbox) + .await + .map_err(|e| Status::internal(format!("fetch sandbox failed: {e}")))? + .ok_or_else(|| Status::not_found("sandbox not found"))?; + + let now = super::current_time_ms(); + let key = service_routing::endpoint_key(&req.sandbox, &req.service); + + // Fetch existing endpoint to determine create vs. 
update path + let existing = state + .store + .get_message_by_name::<ServiceEndpoint>(&key) + .await + .map_err(|e| Status::internal(format!("fetch endpoint failed: {e}")))?; + + let (id, created_at_ms, condition, created) = if let Some(existing) = existing { + // Update path: preserve id and created_at, use CAS to prevent conflicts + let resource_version = existing + .metadata + .as_ref() + .map_or(0, |metadata| metadata.resource_version); + ( + existing.object_id().to_string(), + existing + .metadata + .as_ref() + .map_or(now, |metadata| metadata.created_at_ms), + WriteCondition::MatchResourceVersion(resource_version), + false, + ) + } else { + // Create path: new id and created_at, use MustCreate to prevent races + ( + Uuid::new_v4().to_string(), + now, + WriteCondition::MustCreate, + true, + ) + }; + + let labels_json = serde_json::to_string(&HashMap::from([( + "sandbox".to_string(), + req.sandbox.clone(), + )])) + .map_err(|e| Status::internal(format!("serialize labels failed: {e}")))?; + + let endpoint = ServiceEndpoint { + metadata: Some(ObjectMeta { + id: id.clone(), + name: key.clone(), + created_at_ms, + labels: HashMap::from([("sandbox".to_string(), req.sandbox.clone())]), + resource_version: 0, + }), + sandbox_id: sandbox.object_id().to_string(), + sandbox_name: req.sandbox.clone(), + service_name: req.service.clone(), + target_port: req.target_port, + domain: true, + }; + + // Single-attempt CAS write: fails with ABORTED on concurrent modification + let result = state + .store + .put_if( + ServiceEndpoint::object_type(), + &id, + &key, + &endpoint.encode_to_vec(), + Some(&labels_json), + condition, + ) + .await + .map_err(|e| super::persistence_error_to_status(e, "expose service"))?; + + let mut endpoint = endpoint; + if let Some(ref mut meta) = endpoint.metadata { + meta.resource_version = result.resource_version; + } + + let url = service_routing::endpoint_url(&state.config, &req.sandbox, &req.service) + .unwrap_or_default(); + 
service_routing::emit_service_endpoint_config_event(&endpoint, &url, created); + + Ok(Response::new(ServiceEndpointResponse { + endpoint: Some(endpoint), + url, + })) +} + +pub(super) async fn handle_get_service( + state: &Arc<ServerState>, + request: Request<GetServiceRequest>, +) -> Result<Response<ServiceEndpointResponse>, Status> { + let req = request.into_inner(); + validate_endpoint_name("sandbox", &req.sandbox, MAX_SANDBOX_NAME_LEN)?; + validate_optional_endpoint_name("service", &req.service, MAX_SERVICE_NAME_LEN)?; + + let endpoint = get_service_endpoint(state, &req.sandbox, &req.service) + .await? + .ok_or_else(|| Status::not_found("service endpoint not found"))?; + + Ok(Response::new(service_endpoint_response(state, endpoint))) +} + +pub(super) async fn handle_list_services( + state: &Arc<ServerState>, + request: Request<ListServicesRequest>, +) -> Result<Response<ListServicesResponse>, Status> { + let req = request.into_inner(); + if !req.sandbox.is_empty() { + validate_endpoint_name("sandbox", &req.sandbox, MAX_SANDBOX_NAME_LEN)?; + } + + let limit = super::clamp_limit(req.limit, 100, super::MAX_PAGE_SIZE); + let endpoints: Vec<ServiceEndpoint> = if req.sandbox.is_empty() { + state.store.list_messages(limit, req.offset).await + } else { + state + .store + .list_messages_with_selector(&format!("sandbox={}", req.sandbox), limit, req.offset) + .await + } + .map_err(|e| Status::internal(format!("list endpoints failed: {e}")))?; + + let services = endpoints + .into_iter() + .map(|ep| service_endpoint_response(state, ep)) + .collect(); + + Ok(Response::new(ListServicesResponse { services })) +} + +pub(super) async fn handle_delete_service( + state: &Arc<ServerState>, + request: Request<DeleteServiceRequest>, +) -> Result<Response<DeleteServiceResponse>, Status> { + let req = request.into_inner(); + validate_endpoint_name("sandbox", &req.sandbox, MAX_SANDBOX_NAME_LEN)?; + validate_optional_endpoint_name("service", &req.service, MAX_SERVICE_NAME_LEN)?; + + let endpoint = get_service_endpoint(state, &req.sandbox, &req.service).await?; + let Some(endpoint) = endpoint else { + return Ok(Response::new(DeleteServiceResponse { deleted: false })); + }; + + let key = 
service_routing::endpoint_key(&req.sandbox, &req.service); + let deleted = state + .store + .delete_by_name(ServiceEndpoint::object_type(), &key) + .await + .map_err(|e| Status::internal(format!("delete endpoint failed: {e}")))?; + + if deleted { + service_routing::emit_service_endpoint_delete_event(&endpoint); + } + + Ok(Response::new(DeleteServiceResponse { deleted })) +} + +async fn get_service_endpoint( + state: &Arc<ServerState>, + sandbox: &str, + service: &str, +) -> Result<Option<ServiceEndpoint>, Status> { + let key = service_routing::endpoint_key(sandbox, service); + state + .store + .get_message_by_name::<ServiceEndpoint>(&key) + .await + .map_err(|e| Status::internal(format!("fetch endpoint failed: {e}"))) +} + +fn service_endpoint_response( + state: &Arc<ServerState>, + endpoint: ServiceEndpoint, +) -> ServiceEndpointResponse { + let url = service_routing::endpoint_url( + &state.config, + &endpoint.sandbox_name, + &endpoint.service_name, + ) + .unwrap_or_default(); + ServiceEndpointResponse { + endpoint: Some(endpoint), + url, + } +} + +#[allow(clippy::result_large_err)] +fn validate_endpoint_name(field: &str, value: &str, max_len: usize) -> Result<(), Status> { + if value.is_empty() { + return Err(Status::invalid_argument(format!("{field} is required"))); + } + validate_non_empty_endpoint_name(field, value, max_len) +} + +#[allow(clippy::result_large_err)] +fn validate_optional_endpoint_name(field: &str, value: &str, max_len: usize) -> Result<(), Status> { + if value.is_empty() { + return Ok(()); + } + validate_non_empty_endpoint_name(field, value, max_len) +} + +#[allow(clippy::result_large_err)] +fn validate_non_empty_endpoint_name( + field: &str, + value: &str, + max_len: usize, +) -> Result<(), Status> { + if value.len() > max_len { + return Err(Status::invalid_argument(format!( + "{field} must be at most {max_len} characters for sandbox service routing" + ))); + } + if value.contains("--") { + return Err(Status::invalid_argument(format!( + "{field} must not contain '--'" + ))); + } + if !is_dns_label(value) { + 
return Err(Status::invalid_argument(format!( + "{field} must be a lowercase DNS label" + ))); + } + Ok(()) +} + +fn is_dns_label(value: &str) -> bool { + if value.starts_with('-') || value.ends_with('-') { + return false; + } + value + .bytes() + .all(|b| b.is_ascii_lowercase() || b.is_ascii_digit() || b == b'-') +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::grpc::test_support::test_server_state; + use openshell_core::proto::SandboxPhase; + + async fn seed_sandbox(state: &Arc<ServerState>, name: &str) { + state + .store + .put_message(&Sandbox { + metadata: Some(ObjectMeta { + id: format!("sandbox-{name}"), + name: name.to_string(), + created_at_ms: 1_000, + labels: HashMap::new(), + resource_version: 0, + }), + spec: Some(openshell_core::proto::SandboxSpec::default()), + phase: SandboxPhase::Ready as i32, + ..Default::default() + }) + .await + .unwrap(); + } + + #[test] + fn validates_good_endpoint_name() { + validate_endpoint_name("service", "web-api", 28).unwrap(); + } + + #[test] + fn validates_empty_optional_service_name() { + validate_optional_endpoint_name("service", "", 28).unwrap(); + } + + #[test] + fn rejects_separator_in_endpoint_name() { + assert!(validate_endpoint_name("service", "web--api", 28).is_err()); + } + + #[test] + fn rejects_empty_required_endpoint_name() { + assert!(validate_endpoint_name("sandbox", "", 28).is_err()); + } + + #[test] + fn rejects_uppercase_endpoint_name() { + assert!(validate_endpoint_name("service", "Web", 28).is_err()); + } + + #[tokio::test] + async fn endpoint_lifecycle_round_trip() { + let state = test_server_state().await; + seed_sandbox(&state, "my-sandbox").await; + + let exposed = handle_expose_service( + &state, + Request::new(ExposeServiceRequest { + sandbox: "my-sandbox".to_string(), + service: "web".to_string(), + target_port: 8080, + domain: true, + }), + ) + .await + .unwrap() + .into_inner(); + assert_eq!(exposed.endpoint.as_ref().unwrap().target_port, 8080); + + let listed = handle_list_services( + 
&state, + Request::new(ListServicesRequest { + sandbox: "my-sandbox".to_string(), + limit: 0, + offset: 0, + }), + ) + .await + .unwrap() + .into_inner(); + assert_eq!(listed.services.len(), 1); + assert_eq!( + listed.services[0].endpoint.as_ref().unwrap().service_name, + "web" + ); + + let fetched = handle_get_service( + &state, + Request::new(GetServiceRequest { + sandbox: "my-sandbox".to_string(), + service: "web".to_string(), + }), + ) + .await + .unwrap() + .into_inner(); + assert_eq!(fetched.endpoint.as_ref().unwrap().target_port, 8080); + + let deleted = handle_delete_service( + &state, + Request::new(DeleteServiceRequest { + sandbox: "my-sandbox".to_string(), + service: "web".to_string(), + }), + ) + .await + .unwrap() + .into_inner(); + assert!(deleted.deleted); + + let err = handle_get_service( + &state, + Request::new(GetServiceRequest { + sandbox: "my-sandbox".to_string(), + service: "web".to_string(), + }), + ) + .await + .unwrap_err(); + assert_eq!(err.code(), tonic::Code::NotFound); + + let listed = handle_list_services( + &state, + Request::new(ListServicesRequest { + sandbox: "my-sandbox".to_string(), + limit: 0, + offset: 0, + }), + ) + .await + .unwrap() + .into_inner(); + assert!(listed.services.is_empty()); + } + + #[tokio::test] + async fn concurrent_expose_service_handles_cas_properly() { + let state = test_server_state().await; + seed_sandbox(&state, "my-sandbox").await; + + // Spawn two concurrent expose_service calls for the same endpoint + let state1 = state.clone(); + let handle1 = tokio::spawn(async move { + handle_expose_service( + &state1, + Request::new(ExposeServiceRequest { + sandbox: "my-sandbox".to_string(), + service: "web".to_string(), + target_port: 8080, + domain: true, + }), + ) + .await + }); + + let state2 = state.clone(); + let handle2 = tokio::spawn(async move { + handle_expose_service( + &state2, + Request::new(ExposeServiceRequest { + sandbox: "my-sandbox".to_string(), + service: "web".to_string(), + target_port: 9090, 
+ domain: true, + }), + ) + .await + }); + + let result1 = handle1.await.unwrap(); + let result2 = handle2.await.unwrap(); + + // One should succeed with MustCreate, the other may fail with ABORTED or succeed with update + let successes = [&result1, &result2].iter().filter(|r| r.is_ok()).count(); + + // At least one should succeed + assert!( + successes >= 1, + "at least one expose should succeed, got: {result1:?}, {result2:?}" + ); + + // Only one endpoint should exist + let listed = handle_list_services( + &state, + Request::new(ListServicesRequest { + sandbox: "my-sandbox".to_string(), + limit: 0, + offset: 0, + }), + ) + .await + .unwrap() + .into_inner(); + assert_eq!(listed.services.len(), 1); + } + + #[tokio::test] + async fn concurrent_expose_service_update_uses_cas() { + let state = test_server_state().await; + seed_sandbox(&state, "my-sandbox").await; + + // Create an initial endpoint + handle_expose_service( + &state, + Request::new(ExposeServiceRequest { + sandbox: "my-sandbox".to_string(), + service: "web".to_string(), + target_port: 7070, + domain: true, + }), + ) + .await + .unwrap(); + + // Spawn two concurrent updates + let state1 = state.clone(); + let handle1 = tokio::spawn(async move { + handle_expose_service( + &state1, + Request::new(ExposeServiceRequest { + sandbox: "my-sandbox".to_string(), + service: "web".to_string(), + target_port: 8080, + domain: true, + }), + ) + .await + }); + + let state2 = state.clone(); + let handle2 = tokio::spawn(async move { + handle_expose_service( + &state2, + Request::new(ExposeServiceRequest { + sandbox: "my-sandbox".to_string(), + service: "web".to_string(), + target_port: 9090, + domain: true, + }), + ) + .await + }); + + let result1 = handle1.await.unwrap(); + let result2 = handle2.await.unwrap(); + + // One should succeed, one may fail with ABORTED due to CAS conflict + let successes = [&result1, &result2].iter().filter(|r| r.is_ok()).count(); + + assert!( + successes >= 1, + "at least one update should 
succeed, got: {result1:?}, {result2:?}" + ); + + // The endpoint should have one of the new port values + let fetched = handle_get_service( + &state, + Request::new(GetServiceRequest { + sandbox: "my-sandbox".to_string(), + service: "web".to_string(), + }), + ) + .await + .unwrap() + .into_inner(); + let port = fetched.endpoint.as_ref().unwrap().target_port; + assert!( + port == 8080 || port == 9090, + "port should be one of the updated values, got {port}" + ); + assert_ne!(port, 7070, "port should not be the original value"); + } +} diff --git a/crates/openshell-server/src/grpc/validation.rs b/crates/openshell-server/src/grpc/validation.rs index 160b7e031..53f292053 100644 --- a/crates/openshell-server/src/grpc/validation.rs +++ b/crates/openshell-server/src/grpc/validation.rs @@ -267,6 +267,25 @@ pub(super) fn validate_provider_fields(provider: &Provider) -> Result<(), Status MAX_MAP_VALUE_LEN, "provider.config", )?; + if provider.credential_expires_at_ms.len() > MAX_PROVIDER_CREDENTIALS_ENTRIES { + return Err(Status::invalid_argument(format!( + "provider.credential_expires_at_ms exceeds maximum entries ({} > {MAX_PROVIDER_CREDENTIALS_ENTRIES})", + provider.credential_expires_at_ms.len() + ))); + } + for (key, value) in &provider.credential_expires_at_ms { + if key.len() > MAX_MAP_KEY_LEN { + return Err(Status::invalid_argument(format!( + "provider.credential_expires_at_ms key exceeds maximum length ({} > {MAX_MAP_KEY_LEN})", + key.len() + ))); + } + if *value < 0 { + return Err(Status::invalid_argument( + "provider.credential_expires_at_ms value must be greater than or equal to 0", + )); + } + } Ok(()) } @@ -874,10 +893,12 @@ mod tests { name: name.to_string(), created_at_ms: 1_000_000, labels: HashMap::new(), + resource_version: 0, }), r#type: provider_type.to_string(), credentials, config, + credential_expires_at_ms: HashMap::new(), } } diff --git a/crates/openshell-server/src/http.rs b/crates/openshell-server/src/http.rs index 7650c2339..40f0b39d0 100644 --- 
a/crates/openshell-server/src/http.rs +++ b/crates/openshell-server/src/http.rs @@ -3,7 +3,14 @@ //! HTTP health endpoints using Axum. -use axum::{Json, Router, extract::State, http::StatusCode, response::IntoResponse, routing::get}; +use axum::{ + Json, Router, + extract::{Request, State}, + http::{HeaderMap, StatusCode, header}, + middleware::{self, Next}, + response::IntoResponse, + routing::get, +}; use metrics_exporter_prometheus::PrometheusHandle; use serde::Serialize; use std::sync::Arc; @@ -59,7 +66,242 @@ async fn render_metrics(State(handle): State) -> impl IntoResp /// Create the HTTP router. pub fn http_router(state: Arc) -> Router { - crate::ssh_tunnel::router(state.clone()) - .merge(crate::ws_tunnel::router(state.clone())) - .merge(crate::auth::router(state)) + crate::ws_tunnel::router(state.clone()) + .merge(crate::auth::router(state.clone())) + .layer(middleware::from_fn_with_state( + state, + sandbox_service_routing_first, + )) +} + +/// Create the plaintext loopback-only router for browser service endpoints. +/// +/// This router intentionally exposes only sandbox service routing. It does not +/// include gRPC, auth, health, metrics, or WebSocket tunnel routes. 
+pub fn service_http_router(state: Arc) -> Router { + Router::new() + .fallback(sandbox_service_routing_only) + .with_state(state) +} + +async fn sandbox_service_routing_first( + State(state): State>, + req: Request, + next: Next, +) -> impl IntoResponse { + if crate::service_routing::is_sandbox_service_request(&req, &state.config.service_routing) { + return crate::service_routing::proxy_sandbox_service_request(state, req) + .await + .into_response(); + } + next.run(req).await.into_response() +} + +async fn sandbox_service_routing_only( + State(state): State>, + req: Request, +) -> impl IntoResponse { + if !crate::service_routing::is_sandbox_service_request(&req, &state.config.service_routing) { + return StatusCode::NOT_FOUND.into_response(); + } + if !browser_context_allows_plaintext_service_request(&req) { + crate::service_routing::emit_cross_origin_service_http_rejection(&state, &req); + return crate::service_routing::service_error_response( + StatusCode::FORBIDDEN, + "Cross-origin service request rejected", + ); + } + crate::service_routing::proxy_sandbox_service_request(state, req) + .await + .into_response() +} + +fn browser_context_allows_plaintext_service_request(req: &Request) -> bool { + if let Some(fetch_site) = header_str(req.headers(), "sec-fetch-site") + && !matches!( + fetch_site.to_ascii_lowercase().as_str(), + "same-origin" | "none" + ) + { + return false; + } + + if let Some(origin) = header_str(req.headers(), header::ORIGIN.as_str()) { + let Some(request_origin) = request_origin(req) else { + return false; + }; + return parse_origin(origin).is_some_and(|origin| origin == request_origin); + } + + if let Some(referer) = header_str(req.headers(), header::REFERER.as_str()) { + let Some(request_origin) = request_origin(req) else { + return false; + }; + return parse_origin(referer).is_some_and(|origin| origin == request_origin); + } + + true +} + +fn header_str<'a>(headers: &'a HeaderMap, name: &str) -> Option<&'a str> { + 
headers.get(name)?.to_str().ok() +} + +#[derive(Debug, Eq, PartialEq)] +struct Origin { + scheme: String, + host: String, + port: u16, +} + +fn request_origin(req: &Request) -> Option { + let host = crate::service_routing::request_host(req)?; + parse_origin_authority("http", host) +} + +fn parse_origin(value: &str) -> Option { + if value.eq_ignore_ascii_case("null") { + return None; + } + let (scheme, rest) = value.split_once("://")?; + let authority_end = rest.find(['/', '?', '#']).unwrap_or(rest.len()); + parse_origin_authority(scheme, &rest[..authority_end]) +} + +fn parse_origin_authority(scheme: &str, authority: &str) -> Option { + let scheme = scheme.to_ascii_lowercase(); + let default_port = match scheme.as_str() { + "http" => 80, + "https" => 443, + _ => return None, + }; + let authority = authority.trim(); + if authority.is_empty() || authority.contains('@') { + return None; + } + + let (host, port) = split_host_port(authority)?; + let host = normalize_host(host)?; + Some(Origin { + scheme, + host, + port: port.unwrap_or(default_port), + }) +} + +fn split_host_port(authority: &str) -> Option<(&str, Option)> { + if let Some(rest) = authority.strip_prefix('[') { + let (host, rest) = rest.split_once(']')?; + let port = if rest.is_empty() { + None + } else { + Some(rest.strip_prefix(':')?.parse().ok()?) 
+ }; + return Some((host, port)); + } + + match authority.rsplit_once(':') { + Some((host, port)) if !port.is_empty() && port.chars().all(|ch| ch.is_ascii_digit()) => { + Some((host, Some(port.parse().ok()?))) + } + Some(_) if authority.matches(':').count() == 1 => None, + _ => Some((authority, None)), + } +} + +fn normalize_host(host: &str) -> Option { + let host = host.trim().trim_end_matches('.').to_ascii_lowercase(); + (!host.is_empty()).then_some(host) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn service_request(headers: &[(&str, &str)]) -> Request { + let mut builder = Request::builder() + .uri("/some/path") + .header(header::HOST, "sandbox--web.dev.openshell.localhost:8080"); + for (name, value) in headers { + builder = builder.header(*name, *value); + } + builder.body(axum::body::Body::empty()).unwrap() + } + + #[test] + fn plaintext_service_browser_context_allows_direct_tools() { + let req = service_request(&[]); + + assert!(browser_context_allows_plaintext_service_request(&req)); + } + + #[test] + fn plaintext_service_browser_context_allows_same_origin_fetch_metadata() { + let req = service_request(&[("sec-fetch-site", "same-origin")]); + + assert!(browser_context_allows_plaintext_service_request(&req)); + } + + #[test] + fn plaintext_service_browser_context_allows_direct_navigation_fetch_metadata() { + let req = service_request(&[("sec-fetch-site", "none")]); + + assert!(browser_context_allows_plaintext_service_request(&req)); + } + + #[test] + fn plaintext_service_browser_context_rejects_cross_site_fetch_metadata() { + let req = service_request(&[("sec-fetch-site", "cross-site")]); + + assert!(!browser_context_allows_plaintext_service_request(&req)); + } + + #[test] + fn plaintext_service_browser_context_rejects_same_site_sibling_requests() { + let req = service_request(&[("sec-fetch-site", "same-site")]); + + assert!(!browser_context_allows_plaintext_service_request(&req)); + } + + #[test] + fn 
plaintext_service_browser_context_requires_matching_origin() { + let req = + service_request(&[("origin", "http://sandbox--web.dev.openshell.localhost:8080")]); + + assert!(browser_context_allows_plaintext_service_request(&req)); + + let req = service_request(&[( + "origin", + "http://sandbox--other.dev.openshell.localhost:8080", + )]); + + assert!(!browser_context_allows_plaintext_service_request(&req)); + } + + #[test] + fn plaintext_service_browser_context_requires_matching_referer() { + let req = service_request(&[( + "referer", + "http://sandbox--web.dev.openshell.localhost:8080/page", + )]); + + assert!(browser_context_allows_plaintext_service_request(&req)); + + let req = service_request(&[( + "referer", + "http://sandbox--other.dev.openshell.localhost:8080/page", + )]); + + assert!(!browser_context_allows_plaintext_service_request(&req)); + } + + #[test] + fn plaintext_service_browser_context_rejects_mismatched_origin_scheme() { + let req = service_request(&[( + "origin", + "https://sandbox--web.dev.openshell.localhost:8080", + )]); + + assert!(!browser_context_allows_plaintext_service_request(&req)); + } } diff --git a/crates/openshell-server/src/inference.rs b/crates/openshell-server/src/inference.rs index b52700f0d..183a80e74 100644 --- a/crates/openshell-server/src/inference.rs +++ b/crates/openshell-server/src/inference.rs @@ -3,6 +3,7 @@ #![allow(clippy::result_large_err)] // gRPC handlers return Result, Status> +use openshell_core::ObjectId; use openshell_core::proto::{ ClusterInferenceConfig, GetClusterInferenceRequest, GetClusterInferenceResponse, GetInferenceBundleRequest, GetInferenceBundleResponse, InferenceRoute, Provider, ResolvedRoute, @@ -11,13 +12,14 @@ use openshell_core::proto::{ }; use openshell_router::config::ResolvedRoute as RouterResolvedRoute; use openshell_router::{ValidationFailureKind, verify_backend_endpoint}; +use prost::Message as _; use std::sync::Arc; use std::time::Duration; use tonic::{Request, Response, Status}; use 
crate::{ ServerState, - persistence::{ObjectName, ObjectType, Store, current_time_ms}, + persistence::{ObjectName, ObjectType, Store, WriteCondition, current_time_ms}, }; #[derive(Debug)] @@ -57,8 +59,13 @@ impl ObjectType for InferenceRoute { impl Inference for InferenceService { async fn get_inference_bundle( &self, - _request: Request, + request: Request, ) -> Result, Status> { + authorize_inference_bundle( + request + .extensions() + .get::(), + )?; resolve_inference_bundle(self.state.store.as_ref()) .await .map(Response::new) @@ -169,40 +176,57 @@ async fn upsert_cluster_inference_route( let config = build_cluster_inference_config(&provider, model_id, timeout_secs); + // Fetch existing route to determine create vs. update path let existing = store .get_message_by_name::(route_name) .await .map_err(|e| Status::internal(format!("fetch route failed: {e}")))?; - let now_ms = - current_time_ms().map_err(|e| Status::internal(format!("get current time: {e}")))?; + let now_ms = current_time_ms(); - let route = if let Some(existing) = existing { - InferenceRoute { - metadata: existing.metadata.clone(), - config: Some(config), - version: existing.version.saturating_add(1), - } + let (id, metadata, new_version, condition) = if let Some(existing) = existing { + // Update path: preserve metadata, increment version, use CAS + let resource_version = existing.metadata.as_ref().map_or(0, |m| m.resource_version); + ( + existing.object_id().to_string(), + existing.metadata.clone(), + existing.version.saturating_add(1), + WriteCondition::MatchResourceVersion(resource_version), + ) } else { - InferenceRoute { - metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { - id: uuid::Uuid::new_v4().to_string(), - name: route_name.to_string(), - created_at_ms: now_ms, - labels: std::collections::HashMap::new(), - }), - config: Some(config), - version: 1, - } + // Create path: new metadata, version 1, use MustCreate + let new_id = uuid::Uuid::new_v4().to_string(); + let 
new_metadata = Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: new_id.clone(), + name: route_name.to_string(), + created_at_ms: now_ms, + labels: std::collections::HashMap::new(), + resource_version: 0, + }); + (new_id, new_metadata, 1, WriteCondition::MustCreate) + }; + + let route = InferenceRoute { + metadata, + config: Some(config), + version: new_version, }; // Ensure metadata is valid (defense in depth - should always be true for server-constructed metadata) crate::grpc::validate_object_metadata(route.metadata.as_ref(), "inference_route")?; + // Single-attempt CAS write: fails with ABORTED on concurrent modification store - .put_message(&route) + .put_if( + InferenceRoute::object_type(), + &id, + route_name, + &route.encode_to_vec(), + None, + condition, + ) .await - .map_err(|e| Status::internal(format!("persist route failed: {e}")))?; + .map_err(|e| crate::grpc::persistence_error_to_status(e, "upsert inference route"))?; Ok(UpsertedInferenceRoute { route, validation }) } @@ -382,6 +406,20 @@ fn find_provider_config_value(provider: &Provider, preferred_keys: &[&str]) -> O None } +fn authorize_inference_bundle( + principal: Option<&crate::auth::principal::Principal>, +) -> Result<(), Status> { + match principal { + Some(crate::auth::principal::Principal::Sandbox(_)) => Ok(()), + Some(crate::auth::principal::Principal::User(_)) => Err(Status::permission_denied( + "GetInferenceBundle requires a sandbox principal", + )), + Some(crate::auth::principal::Principal::Anonymous) | None => Err(Status::unauthenticated( + "GetInferenceBundle requires an authenticated sandbox principal", + )), + } +} + /// Resolve the inference bundle (all managed routes + revision hash). 
async fn resolve_inference_bundle(store: &Store) -> Result { let mut routes = Vec::new(); @@ -479,10 +517,42 @@ async fn resolve_route_by_name( #[cfg(test)] mod tests { use super::*; + use crate::auth::identity::{Identity, IdentityProvider}; + use crate::auth::principal::{ + Principal, SandboxIdentitySource, SandboxPrincipal, UserPrincipal, + }; use openshell_core::ObjectId; use wiremock::matchers::{body_partial_json, header, method, path}; use wiremock::{Mock, MockServer, ResponseTemplate}; + async fn test_store() -> Store { + Store::connect("sqlite::memory:?cache=shared") + .await + .expect("in-memory SQLite store should connect") + } + + fn test_user_principal() -> Principal { + Principal::User(UserPrincipal { + identity: Identity { + subject: "user-a".to_string(), + display_name: None, + roles: vec!["openshell-user".to_string()], + scopes: vec![], + provider: IdentityProvider::Oidc, + }, + }) + } + + fn test_sandbox_principal() -> Principal { + Principal::Sandbox(SandboxPrincipal { + sandbox_id: "sandbox-a".to_string(), + source: SandboxIdentitySource::BootstrapJwt { + issuer: "openshell-gateway:test".to_string(), + }, + trust_domain: Some("openshell".to_string()), + }) + } + fn make_route(name: &str, provider_name: &str, model_id: &str) -> InferenceRoute { InferenceRoute { metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { @@ -490,6 +560,7 @@ mod tests { name: name.to_string(), created_at_ms: 1_000_000, labels: std::collections::HashMap::new(), + resource_version: 0, }), config: Some(ClusterInferenceConfig { provider_name: provider_name.to_string(), @@ -507,10 +578,12 @@ mod tests { name: name.to_string(), created_at_ms: 1_000_000, labels: std::collections::HashMap::new(), + resource_version: 0, }), r#type: provider_type.to_string(), credentials: std::iter::once((key_name.to_string(), key_value.to_string())).collect(), config: std::collections::HashMap::new(), + credential_expires_at_ms: std::collections::HashMap::new(), } } @@ -528,11 +601,22 
@@ mod tests { } } + #[test] + fn inference_bundle_requires_sandbox_principal() { + let sandbox = test_sandbox_principal(); + assert!(authorize_inference_bundle(Some(&sandbox)).is_ok()); + + let user = test_user_principal(); + let err = authorize_inference_bundle(Some(&user)).expect_err("users cannot fetch bundle"); + assert_eq!(err.code(), tonic::Code::PermissionDenied); + + let err = authorize_inference_bundle(None).expect_err("missing principal rejected"); + assert_eq!(err.code(), tonic::Code::Unauthenticated); + } + #[tokio::test] async fn upsert_cluster_route_creates_and_increments_version() { - let store = Store::connect("sqlite::memory:?cache=shared") - .await - .expect("store should connect"); + let store = test_store().await; let provider = make_provider("openai-dev", "openai", "OPENAI_API_KEY", "sk-test"); store @@ -571,9 +655,7 @@ mod tests { #[tokio::test] async fn resolve_managed_route_returns_none_when_missing() { - let store = Store::connect("sqlite::memory:?cache=shared") - .await - .expect("store should connect"); + let store = test_store().await; let route = resolve_route_by_name(&store, CLUSTER_INFERENCE_ROUTE_NAME) .await @@ -583,9 +665,7 @@ mod tests { #[tokio::test] async fn bundle_happy_path_returns_managed_route() { - let store = Store::connect("sqlite::memory:?cache=shared") - .await - .expect("store"); + let store = test_store().await; let provider = make_provider("openai-dev", "openai", "OPENAI_API_KEY", "sk-test"); store @@ -612,9 +692,7 @@ mod tests { #[tokio::test] async fn bundle_without_cluster_route_returns_empty_routes() { - let store = Store::connect("sqlite::memory:?cache=shared") - .await - .expect("store"); + let store = test_store().await; let resp = resolve_inference_bundle(&store) .await @@ -624,9 +702,7 @@ mod tests { #[tokio::test] async fn bundle_revision_is_stable_for_same_route() { - let store = Store::connect("sqlite::memory:?cache=shared") - .await - .expect("store"); + let store = test_store().await; let provider = 
make_provider("openai-dev", "openai", "OPENAI_API_KEY", "sk-test"); store @@ -656,9 +732,7 @@ mod tests { #[tokio::test] async fn resolve_managed_route_derives_from_provider() { - let store = Store::connect("sqlite::memory:?cache=shared") - .await - .expect("store should connect"); + let store = test_store().await; let provider = Provider { metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { @@ -666,6 +740,7 @@ mod tests { name: "openai-dev".to_string(), created_at_ms: 1_000_000, labels: std::collections::HashMap::new(), + resource_version: 0, }), r#type: "openai".to_string(), credentials: std::iter::once(("OPENAI_API_KEY".to_string(), "sk-test".to_string())) @@ -675,6 +750,7 @@ mod tests { "https://station.example.com/v1".to_string(), )) .collect(), + credential_expires_at_ms: std::collections::HashMap::new(), }; store .put_message(&provider) @@ -687,6 +763,7 @@ mod tests { name: CLUSTER_INFERENCE_ROUTE_NAME.to_string(), created_at_ms: 1_000_000, labels: std::collections::HashMap::new(), + resource_version: 0, }), config: Some(ClusterInferenceConfig { provider_name: "openai-dev".to_string(), @@ -721,9 +798,7 @@ mod tests { #[tokio::test] async fn resolve_managed_route_reflects_provider_key_rotation() { - let store = Store::connect("sqlite::memory:?cache=shared") - .await - .expect("store should connect"); + let store = test_store().await; let provider = make_provider("openai-dev", "openai", "OPENAI_API_KEY", "sk-initial"); store @@ -749,6 +824,7 @@ mod tests { credentials: std::iter::once(("OPENAI_API_KEY".to_string(), "sk-rotated".to_string())) .collect(), config: provider.config.clone(), + credential_expires_at_ms: provider.credential_expires_at_ms.clone(), }; store .put_message(&rotated_provider) @@ -764,9 +840,7 @@ mod tests { #[tokio::test] async fn upsert_system_route_creates_with_correct_name() { - let store = Store::connect("sqlite::memory:?cache=shared") - .await - .expect("store"); + let store = test_store().await; let provider = 
make_provider("anthropic-dev", "anthropic", "ANTHROPIC_API_KEY", "sk-ant"); store.put_message(&provider).await.expect("persist"); @@ -790,9 +864,7 @@ mod tests { #[tokio::test] async fn bundle_includes_both_user_and_system_routes() { - let store = Store::connect("sqlite::memory:?cache=shared") - .await - .expect("store"); + let store = test_store().await; let openai = make_provider("openai-dev", "openai", "OPENAI_API_KEY", "sk-oai"); store.put_message(&openai).await.expect("persist openai"); @@ -830,9 +902,7 @@ mod tests { #[tokio::test] async fn bundle_with_only_system_route() { - let store = Store::connect("sqlite::memory:?cache=shared") - .await - .expect("store"); + let store = test_store().await; let provider = make_provider("openai-dev", "openai", "OPENAI_API_KEY", "sk-test"); store.put_message(&provider).await.expect("persist"); @@ -850,9 +920,7 @@ mod tests { #[tokio::test] async fn get_returns_system_route_when_requested() { - let store = Store::connect("sqlite::memory:?cache=shared") - .await - .expect("store"); + let store = test_store().await; let provider = make_provider("openai-dev", "openai", "OPENAI_API_KEY", "sk-test"); store.put_message(&provider).await.expect("persist"); @@ -881,9 +949,7 @@ mod tests { #[tokio::test] async fn upsert_cluster_route_verifies_endpoint_when_requested() { - let store = Store::connect("sqlite::memory:?cache=shared") - .await - .expect("store"); + let store = test_store().await; let mock_server = MockServer::start().await; Mock::given(method("POST")) @@ -934,9 +1000,7 @@ mod tests { #[tokio::test] async fn upsert_cluster_route_rejects_failed_validation() { - let store = Store::connect("sqlite::memory:?cache=shared") - .await - .expect("store"); + let store = test_store().await; let mock_server = MockServer::start().await; Mock::given(method("POST")) @@ -987,9 +1051,7 @@ mod tests { #[tokio::test] async fn upsert_cluster_route_skips_validation_by_default() { - let store = Store::connect("sqlite::memory:?cache=shared") - 
.await - .expect("store"); + let store = test_store().await; let provider = make_provider_with_base_url( "openai-dev", "openai", @@ -1047,4 +1109,173 @@ mod tests { let err = effective_route_name("unknown-route").unwrap_err(); assert_eq!(err.code(), tonic::Code::InvalidArgument); } + + #[tokio::test] + async fn concurrent_upsert_route_create_uses_must_create() { + let store = test_store().await; + + let provider = make_provider("openai-dev", "openai", "OPENAI_API_KEY", "sk-test"); + store.put_message(&provider).await.expect("persist"); + + // Spawn two concurrent upsert calls for the same route (create path) + let store1 = store.clone(); + let handle1 = tokio::spawn(async move { + upsert_cluster_inference_route( + &store1, + CLUSTER_INFERENCE_ROUTE_NAME, + "openai-dev", + "gpt-4o", + 0, + false, + ) + .await + }); + + let store2 = store.clone(); + let handle2 = tokio::spawn(async move { + upsert_cluster_inference_route( + &store2, + CLUSTER_INFERENCE_ROUTE_NAME, + "openai-dev", + "gpt-4.1", + 0, + false, + ) + .await + }); + + let result1 = handle1.await.unwrap(); + let result2 = handle2.await.unwrap(); + + // If both tasks observe a missing route before either insert commits, MustCreate + // should let exactly one win. If the scheduler serializes them, the second call + // may legitimately observe the new route and take the update path. 
+ let successes = [&result1, &result2].iter().filter(|r| r.is_ok()).count(); + let failures = [&result1, &result2] + .iter() + .filter(|r| { + r.as_ref().is_err_and(|e| { + // Accept either ABORTED (from CAS) or Internal (from DB unique constraint) + e.code() == tonic::Code::Aborted + || (e.code() == tonic::Code::Internal + && e.message().contains("unique violation")) + }) + }) + .count(); + + assert!( + successes == 1 || successes == 2, + "one racing create should succeed, or both serialized upserts should succeed, got: {result1:?}, {result2:?}" + ); + if successes == 1 { + assert_eq!( + failures, 1, + "the losing racing create should fail, got: {result1:?}, {result2:?}" + ); + } else { + assert_eq!( + failures, 0, + "serialized upserts should not fail, got: {result1:?}, {result2:?}" + ); + let mut versions = [&result1, &result2] + .into_iter() + .map(|result| result.as_ref().expect("success").route.version) + .collect::>(); + versions.sort_unstable(); + assert_eq!( + versions, + vec![1, 2], + "serialized create-then-update should return versions 1 and 2" + ); + } + + // Only one route should exist. 
+ let route = store + .get_message_by_name::(CLUSTER_INFERENCE_ROUTE_NAME) + .await + .expect("fetch") + .expect("route should exist"); + let expected_version = if successes == 1 { 1 } else { 2 }; + assert_eq!(route.version, expected_version); + } + + #[tokio::test] + async fn concurrent_upsert_route_update_uses_cas() { + let store = test_store().await; + + let provider = make_provider("openai-dev", "openai", "OPENAI_API_KEY", "sk-test"); + store.put_message(&provider).await.expect("persist"); + + // Create initial route + upsert_cluster_inference_route( + &store, + CLUSTER_INFERENCE_ROUTE_NAME, + "openai-dev", + "gpt-3.5", + 0, + false, + ) + .await + .expect("initial create should succeed"); + + // Spawn two concurrent updates + let store1 = store.clone(); + let handle1 = tokio::spawn(async move { + upsert_cluster_inference_route( + &store1, + CLUSTER_INFERENCE_ROUTE_NAME, + "openai-dev", + "gpt-4o", + 0, + false, + ) + .await + }); + + let store2 = store.clone(); + let handle2 = tokio::spawn(async move { + upsert_cluster_inference_route( + &store2, + CLUSTER_INFERENCE_ROUTE_NAME, + "openai-dev", + "gpt-4.1", + 0, + false, + ) + .await + }); + + let result1 = handle1.await.unwrap(); + let result2 = handle2.await.unwrap(); + + // One should succeed, one may fail with ABORTED due to CAS conflict + let successes = [&result1, &result2].iter().filter(|r| r.is_ok()).count(); + + assert!( + successes >= 1, + "at least one update should succeed, got: {result1:?}, {result2:?}" + ); + + // The route should have one of the new model values and version 2 + let route = store + .get_message_by_name::(CLUSTER_INFERENCE_ROUTE_NAME) + .await + .expect("fetch") + .expect("route should exist"); + let config = route.config.expect("config"); + assert!( + config.model_id == "gpt-4o" || config.model_id == "gpt-4.1", + "model should be one of the updated values, got {}", + config.model_id + ); + assert_ne!( + config.model_id, "gpt-3.5", + "model should not be the original value" + ); + 
assert!( + route.version >= 2 && route.version <= 3, + "version should be 2 (one update won, one conflicted) or 3 (both succeeded sequentially), got {}", + route.version + ); + } } diff --git a/crates/openshell-server/src/lib.rs b/crates/openshell-server/src/lib.rs index a80301c12..b7e145bde 100644 --- a/crates/openshell-server/src/lib.rs +++ b/crates/openshell-server/src/lib.rs @@ -20,17 +20,22 @@ //! [`compute::vm`]; keep this file driver-agnostic going forward. mod auth; +pub mod certgen; pub mod cli; mod compute; +pub mod config_file; +mod defaults; mod grpc; mod http; mod inference; mod multiplex; mod persistence; pub(crate) mod policy_store; +mod provider_refresh; mod sandbox_index; mod sandbox_watch; -mod ssh_tunnel; +mod service_routing; +mod ssh_sessions; pub mod supervisor_session; mod tls; pub mod tracing_bus; @@ -41,15 +46,21 @@ use openshell_core::{ComputeDriverKind, Config, Error, Result}; use std::collections::HashMap; use std::io::ErrorKind; use std::net::SocketAddr; +use std::path::PathBuf; +#[cfg(test)] +use std::sync::LazyLock; use std::sync::{Arc, Mutex}; use std::time::Duration; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::watch; use tracing::{debug, error, info, warn}; +#[cfg(test)] +pub(crate) static TEST_ENV_LOCK: LazyLock> = LazyLock::new(|| Mutex::new(())); + use compute::{ComputeRuntime, DockerComputeConfig, VmComputeConfig}; pub use grpc::OpenShellService; -pub use http::{health_router, http_router, metrics_router}; +pub use http::{health_router, http_router, metrics_router, service_http_router}; pub use multiplex::{MultiplexService, MultiplexedService}; use openshell_driver_kubernetes::KubernetesComputeConfig; use persistence::Store; @@ -100,6 +111,22 @@ pub struct ServerState { /// OIDC JWKS cache for JWT validation. `None` when OIDC is not configured. pub oidc_cache: Option>, + + /// Gateway-minted sandbox JWT issuer. 
`None` when `config.gateway_jwt` + /// is not configured; in that mode `IssueSandboxToken` returns + /// `Status::unavailable`. Populated at startup from the on-disk key + /// material that `certgen` writes. + pub sandbox_jwt_issuer: Option>, + + /// Authenticator that validates gateway-minted sandbox JWTs on every + /// inbound request. Always set when `sandbox_jwt_issuer` is, so callers + /// presenting a freshly minted token are recognized. + pub sandbox_jwt_authenticator: Option>, + + /// Optional K8s `ServiceAccount` authenticator that backs the + /// `IssueSandboxToken` bootstrap path. Only present when the gateway + /// runs in-cluster. + pub k8s_sa_authenticator: Option>, } fn is_benign_tls_handshake_failure(error: &std::io::Error) -> bool { @@ -109,6 +136,15 @@ fn is_benign_tls_handshake_failure(error: &std::io::Error) -> bool { ) } +fn is_benign_connection_close(error: &(dyn std::error::Error + 'static)) -> bool { + let msg = error.to_string(); + msg.contains("connection closed") + || msg.contains("connection reset") + || msg.contains("connection error") + || msg.contains("error reading a body from connection") + || msg.contains("broken pipe") +} + impl ServerState { /// Create new server state. 
#[must_use] @@ -135,6 +171,9 @@ impl ServerState { settings_mutex: tokio::sync::Mutex::new(()), supervisor_sessions, oidc_cache, + sandbox_jwt_issuer: None, + sandbox_jwt_authenticator: None, + k8s_sa_authenticator: None, } } } @@ -150,21 +189,13 @@ pub async fn run_server( config: Config, vm_config: VmComputeConfig, docker_config: DockerComputeConfig, + config_file: Option, tracing_log_bus: TracingLogBus, ) -> Result<()> { let database_url = config.database_url.trim(); if database_url.is_empty() { return Err(Error::config("database_url is required")); } - let driver = configured_compute_driver(&config)?; - if config.ssh_handshake_secret.is_empty() - && !matches!(driver, ComputeDriverKind::Docker | ComputeDriverKind::Vm) - { - return Err(Error::config( - "ssh_handshake_secret is required. Set --ssh-handshake-secret or OPENSHELL_SSH_HANDSHAKE_SECRET", - )); - } - let store = Arc::new(Store::connect(database_url).await?); let oidc_cache = if let Some(ref oidc) = config.oidc { @@ -192,6 +223,7 @@ pub async fn run_server( &config, &vm_config, &docker_config, + config_file.as_ref(), store.clone(), sandbox_index.clone(), sandbox_watch_bus.clone(), @@ -199,7 +231,7 @@ pub async fn run_server( supervisor_sessions.clone(), ) .await?; - let state = Arc::new(ServerState::new( + let mut state = ServerState::new( config.clone(), store.clone(), compute, @@ -208,7 +240,95 @@ pub async fn run_server( tracing_log_bus, supervisor_sessions, oidc_cache, - )); + ); + + // Load the gateway-minted sandbox JWT signing key when configured. + // Optional so single-driver dev deployments without certgen continue + // to start. The helm-deployed gateway and the RPM init script populate + // `gateway_jwt` once `certgen` has produced the on-disk material. 
+ if let Some(ref jwt) = config.gateway_jwt { + let signing_pem = std::fs::read(&jwt.signing_key_path).map_err(|e| { + Error::config(format!( + "failed to read sandbox JWT signing key from {}: {e}", + jwt.signing_key_path.display() + )) + })?; + let public_pem = std::fs::read(&jwt.public_key_path).map_err(|e| { + Error::config(format!( + "failed to read sandbox JWT public key from {}: {e}", + jwt.public_key_path.display() + )) + })?; + let kid = std::fs::read_to_string(&jwt.kid_path) + .map_err(|e| { + Error::config(format!( + "failed to read sandbox JWT kid from {}: {e}", + jwt.kid_path.display() + )) + })? + .trim() + .to_string(); + if kid.is_empty() { + return Err(Error::config(format!( + "sandbox JWT kid file {} is empty", + jwt.kid_path.display() + ))); + } + let issuer = auth::sandbox_jwt::SandboxJwtIssuer::from_pem( + &signing_pem, + kid.clone(), + &jwt.gateway_id, + Duration::from_secs(jwt.ttl_secs), + ) + .map_err(Error::config)?; + let authenticator = + auth::sandbox_jwt::SandboxJwtAuthenticator::from_pem(&public_pem, kid, &jwt.gateway_id) + .map_err(Error::config)?; + info!( + gateway_id = %jwt.gateway_id, + ttl_secs = jwt.ttl_secs, + "gateway-minted sandbox JWT enabled" + ); + state.sandbox_jwt_issuer = Some(Arc::new(issuer)); + state.sandbox_jwt_authenticator = Some(Arc::new(authenticator)); + } + + // K8s ServiceAccount bootstrap authenticator. Only constructed when + // the gateway is running in-cluster (kubelet provides the API host + // env var) and has a sandbox JWT issuer to mint replacements against; + // outside the cluster we can't call the apiserver's TokenReview API, + // and without the issuer there's nothing to exchange the SA token for. + if state.sandbox_jwt_issuer.is_some() && std::env::var_os("KUBERNETES_SERVICE_HOST").is_some() { + // Pod lookups and TokenReview identity checks must match the sandbox + // namespace and service account used by the Kubernetes driver. 
+ let kubernetes_config = kubernetes_config_for_k8s_sa_bootstrap(config_file.as_ref())?; + let sandbox_namespace = kubernetes_config.namespace; + let sandbox_service_account = kubernetes_config.service_account_name; + match kube::Client::try_default().await { + Ok(client) => { + let resolver = Arc::new(auth::k8s_sa::LiveK8sResolver::new( + client, + &sandbox_namespace, + "openshell-gateway".to_string(), + sandbox_service_account.clone(), + )); + let authenticator = auth::k8s_sa::K8sServiceAccountAuthenticator::new(resolver); + state.k8s_sa_authenticator = Some(Arc::new(authenticator)); + info!( + namespace = %sandbox_namespace, + service_account = %sandbox_service_account, + "K8s ServiceAccount bootstrap authenticator enabled" + ); + } + Err(e) => warn!( + error = %e, + "in-cluster K8s client construction failed; \ + K8s ServiceAccount bootstrap is disabled" + ), + } + } + + let state = Arc::new(state); // Resume sandboxes that were stopped during the previous gateway // shutdown so the running compute state matches the persisted store. @@ -219,8 +339,9 @@ pub async fn run_server( } state.compute.spawn_watchers(); - ssh_tunnel::spawn_session_reaper(store.clone(), Duration::from_secs(3600)); + ssh_sessions::spawn_session_reaper(store.clone(), Duration::from_secs(3600)); supervisor_session::spawn_relay_reaper(state.clone(), Duration::from_secs(30)); + provider_refresh::spawn_refresh_worker(state.clone(), Duration::from_secs(60)); // Create the multiplexed service let service = MultiplexService::new(state.clone()); @@ -287,8 +408,8 @@ pub async fn run_server( Some(TlsAcceptor::from_files( &tls.cert_path, &tls.key_path, - &tls.client_ca_path, - tls.allow_unauthenticated, + tls.client_ca_path.as_deref(), + tls.require_client_auth, )?) 
} else { info!("TLS disabled — accepting plaintext connections"); @@ -297,12 +418,14 @@ pub async fn run_server( let (shutdown_tx, shutdown_rx) = watch::channel(false); let mut listener_tasks = Vec::with_capacity(gateway_listeners.len()); + let enable_loopback_service_http = config.service_routing.enable_loopback_service_http; for (listener, listen_addr) in gateway_listeners { listener_tasks.push(tokio::spawn(serve_gateway_listener( listener, listen_addr, service.clone(), tls_acceptor.clone(), + enable_loopback_service_http, shutdown_rx.clone(), ))); } @@ -362,6 +485,7 @@ async fn serve_gateway_listener( listen_addr: SocketAddr, service: MultiplexService, tls_acceptor: Option, + enable_loopback_service_http: bool, mut shutdown: watch::Receiver, ) { loop { @@ -383,37 +507,140 @@ async fn serve_gateway_listener( } }; - spawn_gateway_connection(stream, addr, service.clone(), tls_acceptor.clone()); + spawn_gateway_connection( + stream, + addr, + listen_addr, + service.clone(), + tls_acceptor.clone(), + enable_loopback_service_http, + ); + } +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +enum ConnectionProtocol { + Tls, + PlainHttp, + Unknown, +} + +async fn classify_connection_protocol(stream: &TcpStream) -> std::io::Result { + let mut prefix = [0_u8; 8]; + let read = stream.peek(&mut prefix).await?; + Ok(classify_initial_bytes(&prefix[..read])) +} + +fn classify_initial_bytes(prefix: &[u8]) -> ConnectionProtocol { + if looks_like_tls(prefix) { + ConnectionProtocol::Tls + } else if looks_like_http(prefix) { + ConnectionProtocol::PlainHttp + } else { + ConnectionProtocol::Unknown + } +} + +fn looks_like_tls(prefix: &[u8]) -> bool { + prefix.len() >= 3 && prefix[0] == 0x16 && prefix[1] == 0x03 +} + +fn looks_like_http(prefix: &[u8]) -> bool { + const METHODS: [&[u8]; 10] = [ + b"GET ", + b"POST ", + b"PUT ", + b"PATCH ", + b"DELETE ", + b"HEAD ", + b"OPTIONS ", + b"TRACE ", + b"CONNECT ", + b"PRI ", + ]; + + if prefix.is_empty() { + return false; } + METHODS + 
.iter() + .any(|method| method.starts_with(prefix) || prefix.starts_with(method)) +} + +fn allow_plaintext_service_http( + enabled: bool, + listen_addr: SocketAddr, + peer_addr: SocketAddr, +) -> bool { + enabled && listen_addr.ip().is_loopback() && peer_addr.ip().is_loopback() } fn spawn_gateway_connection( stream: TcpStream, addr: SocketAddr, + listen_addr: SocketAddr, service: MultiplexService, tls_acceptor: Option, + enable_loopback_service_http: bool, ) { if let Some(acceptor) = tls_acceptor { tokio::spawn(async move { - match acceptor.inner().accept(stream).await { - Ok(tls_stream) => { - if let Err(e) = service.serve(tls_stream).await { - error!(error = %e, client = %addr, "Connection error"); + match classify_connection_protocol(&stream).await { + Ok(ConnectionProtocol::PlainHttp) + if allow_plaintext_service_http( + enable_loopback_service_http, + listen_addr, + addr, + ) => + { + if let Err(e) = service.serve_service_http(stream).await { + if is_benign_connection_close(e.as_ref()) { + debug!(error = %e, client = %addr, listen = %listen_addr, "Plaintext service HTTP connection closed"); + } else { + error!(error = %e, client = %addr, listen = %listen_addr, "Plaintext service HTTP connection error"); + } } } - Err(e) => { - if is_benign_tls_handshake_failure(&e) { - debug!(error = %e, client = %addr, "TLS handshake closed early"); - } else { - error!(error = %e, client = %addr, "TLS handshake failed"); + Ok(ConnectionProtocol::PlainHttp) => { + warn!(client = %addr, listen = %listen_addr, "Rejected plaintext HTTP on non-loopback gateway listener"); + } + Ok(ConnectionProtocol::Tls | ConnectionProtocol::Unknown) => { + match acceptor.inner().accept(stream).await { + Ok(tls_stream) => { + let peer_identity = multiplex::extract_peer_identity(&tls_stream); + if let Err(e) = service + .serve_with_peer_identity(tls_stream, peer_identity) + .await + { + if is_benign_connection_close(e.as_ref()) { + debug!(error = %e, client = %addr, "Connection closed"); + } else 
{ + error!(error = %e, client = %addr, "Connection error"); + } + } + } + Err(e) => { + if is_benign_tls_handshake_failure(&e) { + debug!(error = %e, client = %addr, "TLS handshake closed early"); + } else { + error!(error = %e, client = %addr, "TLS handshake failed"); + } + } } } + Err(e) => { + debug!(error = %e, client = %addr, "Failed to inspect connection preface"); + } } }); } else { tokio::spawn(async move { if let Err(e) = service.serve(stream).await { - error!(error = %e, client = %addr, "Connection error"); + if is_benign_connection_close(e.as_ref()) { + debug!(error = %e, client = %addr, "Connection closed"); + } else { + error!(error = %e, client = %addr, "Connection error"); + } } }); } @@ -459,6 +686,7 @@ async fn build_compute_runtime( config: &Config, vm_config: &VmComputeConfig, docker_config: &DockerComputeConfig, + file: Option<&config_file::ConfigFile>, store: Arc, sandbox_index: SandboxIndex, sandbox_watch_bus: SandboxWatchBus, @@ -470,35 +698,12 @@ async fn build_compute_runtime( match driver { ComputeDriverKind::Kubernetes => { - let supervisor_image = std::env::var("OPENSHELL_SUPERVISOR_IMAGE") - .ok() - .filter(|s| !s.is_empty()) - .unwrap_or_else(|| openshell_core::config::DEFAULT_SUPERVISOR_IMAGE.to_string()); - let supervisor_image_pull_policy = - std::env::var("OPENSHELL_SUPERVISOR_IMAGE_PULL_POLICY") - .ok() - .filter(|s| !s.is_empty()) - .unwrap_or_default(); + let mut k8s = kubernetes_config_from_file(file)?; + if let Ok(size) = std::env::var("OPENSHELL_K8S_WORKSPACE_DEFAULT_STORAGE_SIZE") { + k8s.workspace_default_storage_size = size; + } ComputeRuntime::new_kubernetes( - KubernetesComputeConfig { - namespace: config.sandbox_namespace.clone(), - default_image: config.sandbox_image.clone(), - image_pull_policy: config.sandbox_image_pull_policy.clone(), - supervisor_image, - supervisor_image_pull_policy, - grpc_endpoint: config.grpc_endpoint.clone(), - // Filesystem path to the supervisor's Unix-socket SSH daemon. 
- // The path lives in a root-only directory so only the - // supervisor can connect; the gateway reaches it through the - // RelayStream bridge, not directly. Override via - // `sandbox_ssh_socket_path` in the config for deployments - // where multiple supervisors share a filesystem. - ssh_socket_path: config.sandbox_ssh_socket_path.clone(), - ssh_handshake_secret: config.ssh_handshake_secret.clone(), - ssh_handshake_skew_secs: config.ssh_handshake_skew_secs, - client_tls_secret_name: config.client_tls_secret_name.clone(), - host_gateway_ip: config.host_gateway_ip.clone(), - }, + k8s, store, sandbox_index, sandbox_watch_bus, @@ -534,63 +739,15 @@ async fn build_compute_runtime( .map_err(|e| Error::execution(format!("failed to create compute runtime: {e}"))) } ComputeDriverKind::Podman => { - let socket_path = std::env::var("OPENSHELL_PODMAN_SOCKET") - .ok() - .filter(|s| !s.is_empty()) - .map_or_else( - openshell_driver_podman::PodmanComputeConfig::default_socket_path, - std::path::PathBuf::from, - ); - - let network_name = std::env::var("OPENSHELL_NETWORK_NAME") - .ok() - .filter(|s| !s.is_empty()) - .unwrap_or_else(|| openshell_core::config::DEFAULT_NETWORK_NAME.to_string()); - - let stop_timeout_secs: u32 = std::env::var("OPENSHELL_STOP_TIMEOUT") - .ok() - .and_then(|s| s.parse().ok()) - .unwrap_or(openshell_core::config::DEFAULT_STOP_TIMEOUT_SECS); - - let supervisor_image = std::env::var("OPENSHELL_SUPERVISOR_IMAGE") - .ok() - .filter(|s| !s.is_empty()) - .unwrap_or_else(|| openshell_core::config::DEFAULT_SUPERVISOR_IMAGE.to_string()); - - // TLS client cert paths for sandbox mTLS. When all three are - // set, the Podman driver bind-mounts them into sandbox - // containers and switches the endpoint to https://. 
- let podman_tls_ca = std::env::var("OPENSHELL_PODMAN_TLS_CA") - .ok() - .filter(|s| !s.is_empty()) - .map(std::path::PathBuf::from); - let podman_tls_cert = std::env::var("OPENSHELL_PODMAN_TLS_CERT") - .ok() - .filter(|s| !s.is_empty()) - .map(std::path::PathBuf::from); - let podman_tls_key = std::env::var("OPENSHELL_PODMAN_TLS_KEY") - .ok() - .filter(|s| !s.is_empty()) - .map(std::path::PathBuf::from); + let mut podman = podman_config_from_file(file)?; + podman.gateway_port = config.bind_address.port(); + if let Ok(p) = std::env::var("OPENSHELL_PODMAN_SOCKET") { + podman.socket_path = PathBuf::from(p); + } + apply_podman_local_tls_defaults(config, &mut podman)?; ComputeRuntime::new_podman( - openshell_driver_podman::PodmanComputeConfig { - socket_path, - default_image: config.sandbox_image.clone(), - image_pull_policy: config.sandbox_image_pull_policy.parse().unwrap_or_default(), - grpc_endpoint: config.grpc_endpoint.clone(), - gateway_port: config.bind_address.port(), - sandbox_ssh_socket_path: config.sandbox_ssh_socket_path.clone(), - network_name, - ssh_port: config.sandbox_ssh_port, - ssh_handshake_secret: config.ssh_handshake_secret.clone(), - ssh_handshake_skew_secs: config.ssh_handshake_skew_secs, - stop_timeout_secs, - supervisor_image, - guest_tls_ca: podman_tls_ca, - guest_tls_cert: podman_tls_cert, - guest_tls_key: podman_tls_key, - }, + podman, store, sandbox_index, sandbox_watch_bus, @@ -603,14 +760,94 @@ async fn build_compute_runtime( } } +/// Build a [`KubernetesComputeConfig`] from the file's +/// `[openshell.drivers.kubernetes]` table merged with inheritable +/// `[openshell.gateway]` defaults. Falls back to the driver's `Default` +/// when no file is present. 
+fn kubernetes_config_from_file( + file: Option<&config_file::ConfigFile>, +) -> Result { + let Some(file) = file else { + return Ok(KubernetesComputeConfig::default()); + }; + let merged = config_file::driver_table( + ComputeDriverKind::Kubernetes, + &file.openshell.gateway, + file.openshell.drivers.get("kubernetes"), + ); + merged + .try_into() + .map_err(|e| Error::config(format!("invalid [openshell.drivers.kubernetes] table: {e}"))) +} + +fn kubernetes_config_for_k8s_sa_bootstrap( + file: Option<&config_file::ConfigFile>, +) -> Result { + let Some(file) = file else { + return Err(Error::config( + "K8s ServiceAccount bootstrap requires [openshell.drivers.kubernetes] when sandbox JWT issuing is enabled in-cluster", + )); + }; + if !file.openshell.drivers.contains_key("kubernetes") { + return Err(Error::config( + "K8s ServiceAccount bootstrap requires [openshell.drivers.kubernetes] when sandbox JWT issuing is enabled in-cluster", + )); + } + kubernetes_config_from_file(Some(file)) +} + +/// Same pattern as [`kubernetes_config_from_file`] but for Podman. +fn podman_config_from_file( + file: Option<&config_file::ConfigFile>, +) -> Result { + let Some(file) = file else { + return Ok(openshell_driver_podman::PodmanComputeConfig::default()); + }; + let merged = config_file::driver_table( + ComputeDriverKind::Podman, + &file.openshell.gateway, + file.openshell.drivers.get("podman"), + ); + merged + .try_into() + .map_err(|e| Error::config(format!("invalid [openshell.drivers.podman] table: {e}"))) +} + +fn apply_podman_local_tls_defaults( + config: &Config, + podman: &mut openshell_driver_podman::PodmanComputeConfig, +) -> Result<()> { + if config.tls.is_none() + || podman.guest_tls_ca.is_some() + || podman.guest_tls_cert.is_some() + || podman.guest_tls_key.is_some() + { + return Ok(()); + } + + let Some(paths) = defaults::complete_local_tls_paths() + .map_err(|e| Error::config(format!("failed to resolve local TLS defaults: {e}")))? 
+ else { + return Ok(()); + }; + podman.guest_tls_ca = Some(paths.ca); + podman.guest_tls_cert = Some(paths.client_cert); + podman.guest_tls_key = Some(paths.client_key); + Ok(()) +} + fn configured_compute_driver(config: &Config) -> Result { match config.compute_drivers.as_slice() { - [] => openshell_core::config::detect_driver().ok_or_else(|| { - Error::config( + [] => match openshell_core::config::detect_driver() { + Some(ComputeDriverKind::Vm) => Err(Error::config( + "vm compute driver is opt-in only; set --drivers vm or OPENSHELL_DRIVERS=vm", + )), + Some(driver) => Ok(driver), + None => Err(Error::config( "no compute driver configured and auto-detection found no suitable driver; \ set --drivers or OPENSHELL_DRIVERS to kubernetes, podman, docker, or vm", - ) - }), + )), + }, [ driver @ (ComputeDriverKind::Kubernetes | ComputeDriverKind::Vm @@ -631,11 +868,170 @@ fn configured_compute_driver(config: &Config) -> Result { #[cfg(test)] mod tests { use super::{ - configured_compute_driver, gateway_listener_addresses, is_benign_tls_handshake_failure, + ConnectionProtocol, MultiplexService, ServerState, TlsAcceptor, + allow_plaintext_service_http, classify_initial_bytes, configured_compute_driver, + gateway_listener_addresses, is_benign_tls_handshake_failure, + kubernetes_config_for_k8s_sa_bootstrap, serve_gateway_listener, + }; + use openshell_core::{ + ComputeDriverKind, Config, + proto::{HealthRequest, open_shell_client::OpenShellClient}, }; - use openshell_core::{ComputeDriverKind, Config}; - use std::io::{Error, ErrorKind}; + use rcgen::{CertificateParams, IsCa, KeyPair}; + use std::io::{Error, ErrorKind, Write}; use std::net::SocketAddr; + use std::sync::Arc; + use std::time::Duration; + use tempfile::{TempDir, tempdir}; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tokio::net::{TcpListener, TcpStream}; + use tokio::sync::watch; + + fn install_rustls_provider() { + let _ = rustls::crypto::ring::default_provider().install_default(); + } + + fn 
test_tls_acceptor() -> (TempDir, TlsAcceptor) { + install_rustls_provider(); + + let mut ca_params = + CertificateParams::new(Vec::::new()).expect("failed to create CA params"); + ca_params.is_ca = IsCa::Ca(rcgen::BasicConstraints::Unconstrained); + ca_params + .distinguished_name + .push(rcgen::DnType::CommonName, "test-ca"); + let ca_key = KeyPair::generate().expect("failed to generate CA key"); + let ca_cert = ca_params + .self_signed(&ca_key) + .expect("failed to sign CA cert"); + + let server_params = CertificateParams::new(vec!["localhost".to_string()]) + .expect("failed to create server params"); + let server_key = KeyPair::generate().expect("failed to generate server key"); + let server_cert = server_params + .signed_by(&server_key, &ca_cert, &ca_key) + .expect("failed to sign server cert"); + + let dir = tempdir().expect("failed to create tempdir"); + let write_file = |name: &str, data: &[u8]| { + let path = dir.path().join(name); + std::fs::File::create(&path) + .and_then(|mut file| file.write_all(data)) + .expect("failed to write tls test file"); + }; + write_file("ca.pem", ca_cert.pem().as_bytes()); + write_file("server-cert.pem", server_cert.pem().as_bytes()); + write_file("server-key.pem", server_key.serialize_pem().as_bytes()); + + let acceptor = TlsAcceptor::from_files( + &dir.path().join("server-cert.pem"), + &dir.path().join("server-key.pem"), + Some(&dir.path().join("ca.pem")), + false, + ) + .expect("failed to build tls acceptor"); + + (dir, acceptor) + } + + async fn test_state( + bind_addr: SocketAddr, + enable_loopback_service_http: bool, + ) -> Arc { + let store = Arc::new( + crate::persistence::Store::connect("sqlite::memory:?cache=shared") + .await + .expect("failed to create test store"), + ); + let compute = crate::compute::new_test_runtime(store.clone()).await; + Arc::new(ServerState::new( + Config::new(None) + .with_database_url("sqlite::memory:?cache=shared") + .with_bind_address(bind_addr) + 
.with_server_sans(["*.dev.openshell.localhost"]) + .with_loopback_service_http(enable_loopback_service_http), + store, + compute, + crate::sandbox_index::SandboxIndex::new(), + crate::sandbox_watch::SandboxWatchBus::new(), + crate::tracing_bus::TracingLogBus::new(), + Arc::new(crate::supervisor_session::SupervisorSessionRegistry::new()), + None, + )) + } + + async fn start_tls_gateway_listener( + bind_addr: &str, + enable_loopback_service_http: bool, + ) -> ( + SocketAddr, + watch::Sender, + tokio::task::JoinHandle<()>, + TempDir, + ) { + let listener = TcpListener::bind(bind_addr) + .await + .expect("failed to bind test listener"); + let listen_addr = listener.local_addr().expect("failed to read local addr"); + let state = test_state(listen_addr, enable_loopback_service_http).await; + let service = MultiplexService::new(state); + let (tls_dir, tls_acceptor) = test_tls_acceptor(); + let (shutdown_tx, shutdown_rx) = watch::channel(false); + let handle = tokio::spawn(serve_gateway_listener( + listener, + listen_addr, + service, + Some(tls_acceptor), + enable_loopback_service_http, + shutdown_rx, + )); + (listen_addr, shutdown_tx, handle, tls_dir) + } + + async fn send_plain_http(addr: SocketAddr, request: String) -> String { + let connect_addr: SocketAddr = format!("127.0.0.1:{}", addr.port()) + .parse() + .expect("failed to build loopback connect addr"); + let mut stream = TcpStream::connect(connect_addr) + .await + .expect("failed to connect to test listener"); + stream + .write_all(request.as_bytes()) + .await + .expect("failed to write request"); + + let mut response = Vec::new(); + let read_result = + tokio::time::timeout(Duration::from_secs(2), stream.read_to_end(&mut response)) + .await + .expect("timed out reading response"); + if let Err(err) = read_result + && err.kind() != ErrorKind::ConnectionReset + { + panic!("failed to read response: {err}"); + } + String::from_utf8_lossy(&response).into_owned() + } + + fn service_request(addr: SocketAddr, 
extra_headers: &[(&str, &str)]) -> String { + let mut request = format!( + "GET / HTTP/1.1\r\nHost: my-sandbox--web.dev.openshell.localhost:{}\r\nConnection: close\r\n", + addr.port() + ); + for (name, value) in extra_headers { + request.push_str(name); + request.push_str(": "); + request.push_str(value); + request.push_str("\r\n"); + } + request.push_str("\r\n"); + request + } + + async fn stop_listener(shutdown: watch::Sender, handle: tokio::task::JoinHandle<()>) { + let _ = shutdown.send(true); + let _ = tokio::time::timeout(Duration::from_secs(2), handle).await; + } #[test] fn classifies_probe_style_tls_disconnects_as_benign() { @@ -657,6 +1053,159 @@ mod tests { } } + #[test] + fn classifies_tls_and_plain_http_prefaces() { + assert_eq!( + classify_initial_bytes(&[0x16, 0x03, 0x01, 0x00]), + ConnectionProtocol::Tls + ); + assert_eq!( + classify_initial_bytes(b"GET / HTTP/1.1\r\n"), + ConnectionProtocol::PlainHttp + ); + assert_eq!(classify_initial_bytes(b"G"), ConnectionProtocol::PlainHttp); + assert_eq!( + classify_initial_bytes(b"\x00\x01\x02"), + ConnectionProtocol::Unknown + ); + } + + #[test] + fn plaintext_service_http_requires_loopback_listener_and_peer() { + let loopback: SocketAddr = "127.0.0.1:8080".parse().unwrap(); + let peer: SocketAddr = "127.0.0.1:54000".parse().unwrap(); + let wildcard: SocketAddr = "0.0.0.0:8080".parse().unwrap(); + let remote_peer: SocketAddr = "192.0.2.10:54000".parse().unwrap(); + + assert!(allow_plaintext_service_http(true, loopback, peer)); + assert!(!allow_plaintext_service_http(false, loopback, peer)); + assert!(!allow_plaintext_service_http(true, wildcard, peer)); + assert!(!allow_plaintext_service_http(true, loopback, remote_peer)); + } + + #[tokio::test] + async fn plaintext_service_http_listener_rejects_non_loopback_bind() { + let (addr, shutdown, handle, _tls_dir) = + start_tls_gateway_listener("0.0.0.0:0", true).await; + + let response = send_plain_http(addr, service_request(addr, &[])).await; + + assert!( + 
response.is_empty(), + "non-loopback gateway listener should drop plaintext service HTTP, got: {response:?}" + ); + stop_listener(shutdown, handle).await; + } + + #[tokio::test] + async fn plaintext_service_http_rejects_cross_origin_browser_contexts() { + let (addr, shutdown, handle, _tls_dir) = + start_tls_gateway_listener("127.0.0.1:0", true).await; + let cases = [ + ( + "cross-site fetch metadata", + vec![("Sec-Fetch-Site", "cross-site")], + ), + ( + "same-site sibling fetch metadata", + vec![("Sec-Fetch-Site", "same-site")], + ), + ( + "mismatched origin", + vec![( + "Origin", + "http://other-sandbox--web.dev.openshell.localhost:8080", + )], + ), + ( + "mismatched referer", + vec![( + "Referer", + "http://other-sandbox--web.dev.openshell.localhost:8080/page", + )], + ), + ]; + + for (name, headers) in cases { + let response = send_plain_http(addr, service_request(addr, &headers)).await; + + assert!( + response.starts_with("HTTP/1.1 403 Forbidden"), + "{name} should be rejected before service lookup, got: {response:?}" + ); + assert!( + response.contains("Cross-origin service request rejected"), + "{name} should explain the service rejection, got: {response:?}" + ); + } + stop_listener(shutdown, handle).await; + } + + #[tokio::test] + async fn plaintext_service_http_allows_same_origin_browser_context_to_reach_service_lookup() { + let (addr, shutdown, handle, _tls_dir) = + start_tls_gateway_listener("127.0.0.1:0", true).await; + let origin = format!( + "http://my-sandbox--web.dev.openshell.localhost:{}", + addr.port() + ); + let response = send_plain_http( + addr, + service_request( + addr, + &[("Sec-Fetch-Site", "same-origin"), ("Origin", &origin)], + ), + ) + .await; + + assert!( + response.starts_with("HTTP/1.1 404 Not Found"), + "same-origin browser context should pass CSRF guard and miss only because no endpoint exists, got: {response:?}" + ); + assert!( + !response.contains("Cross-origin service request rejected"), + "same-origin browser context should not 
be rejected as cross-origin, got: {response:?}" + ); + stop_listener(shutdown, handle).await; + } + + #[tokio::test] + async fn plaintext_service_http_does_not_expose_grpc_gateway() { + let (addr, shutdown, handle, _tls_dir) = + start_tls_gateway_listener("127.0.0.1:0", true).await; + let grpc_endpoint = format!("http://127.0.0.1:{}", addr.port()); + let grpc_succeeded = tokio::time::timeout(Duration::from_secs(2), async { + match OpenShellClient::connect(grpc_endpoint).await { + Ok(mut client) => client.health(HealthRequest {}).await.is_ok(), + Err(_) => false, + } + }) + .await + .expect("timed out checking plaintext gRPC exposure"); + + assert!( + !grpc_succeeded, + "plaintext service HTTP must not expose successful gateway gRPC" + ); + + let request = format!( + "POST /openshell.v1.OpenShell/Health HTTP/1.1\r\nHost: 127.0.0.1:{}\r\nContent-Type: application/grpc\r\nTE: trailers\r\nContent-Length: 0\r\nConnection: close\r\n\r\n", + addr.port() + ); + + let response = send_plain_http(addr, request).await; + + assert!( + response.starts_with("HTTP/1.1 404 Not Found"), + "plaintext service HTTP router should not serve gateway gRPC, got: {response:?}" + ); + assert!( + !response.contains("grpc-status: 0"), + "plaintext service HTTP must not return a successful gRPC response: {response:?}" + ); + stop_listener(shutdown, handle).await; + } + #[test] fn configured_compute_driver_triggers_auto_detection_when_empty() { let config = Config::new(None).with_compute_drivers([]); @@ -664,7 +1213,7 @@ mod tests { // depending on the environment. This test verifies the auto-detection path // is taken rather than immediately returning an error. let result = configured_compute_driver(&config); - // Either we get a detected driver or an error about none being detected + // Either we get a detected driver or an error about none being detected. 
match result { Ok(driver) => { assert!( @@ -680,7 +1229,7 @@ mod tests { Err(e) => { assert!( e.to_string() - .contains("no compute driver configured and none detected"), + .contains("auto-detection found no suitable driver"), "unexpected error: {e}" ); } @@ -726,6 +1275,35 @@ mod tests { ); } + #[test] + fn k8s_sa_bootstrap_rejects_missing_kubernetes_driver_config() { + let err = kubernetes_config_for_k8s_sa_bootstrap(None).unwrap_err(); + assert!(err.to_string().contains("[openshell.drivers.kubernetes]")); + + let file: crate::config_file::ConfigFile = + toml::from_str("[openshell.gateway]\n").expect("valid config"); + let err = kubernetes_config_for_k8s_sa_bootstrap(Some(&file)).unwrap_err(); + assert!(err.to_string().contains("[openshell.drivers.kubernetes]")); + } + + #[test] + fn k8s_sa_bootstrap_uses_configured_namespace_and_service_account() { + let file: crate::config_file::ConfigFile = toml::from_str( + r#" +[openshell.gateway] + +[openshell.drivers.kubernetes] +namespace = "sandboxes" +service_account_name = "sandbox-sa" +"#, + ) + .expect("valid config"); + + let cfg = kubernetes_config_for_k8s_sa_bootstrap(Some(&file)).unwrap(); + assert_eq!(cfg.namespace, "sandboxes"); + assert_eq!(cfg.service_account_name, "sandbox-sa"); + } + #[test] fn gateway_listener_addresses_skip_driver_address_covered_by_wildcard() { let primary: SocketAddr = "0.0.0.0:8080".parse().unwrap(); diff --git a/crates/openshell-server/src/multiplex.rs b/crates/openshell-server/src/multiplex.rs index 93e58d202..4fcb3993a 100644 --- a/crates/openshell-server/src/multiplex.rs +++ b/crates/openshell-server/src/multiplex.rs @@ -14,6 +14,7 @@ use hyper::body::Incoming; use hyper_util::{ rt::{TokioExecutor, TokioIo}, server::conn::auto::Builder, + service::TowerToHyperService, }; use metrics::{counter, histogram}; use openshell_core::proto::{ @@ -30,8 +31,15 @@ use tower_http::request_id::{MakeRequestId, RequestId}; use tracing::Span; use crate::{ - OpenShellService, ServerState, 
auth::authz::AuthzPolicy, auth::oidc, http_router, + OpenShellService, ServerState, + auth::authenticator::AuthenticatorChain, + auth::authz::AuthzPolicy, + auth::identity::Identity, + auth::oidc::{self, OidcAuthenticator}, + auth::principal::{Principal, UserPrincipal}, + http_router, inference::InferenceService, + service_http_router, }; /// Request-ID generator that produces a UUID v4 for each inbound request. @@ -128,6 +136,18 @@ impl MultiplexService { /// Serve a connection, routing to gRPC or HTTP based on content-type. pub async fn serve(&self, stream: S) -> Result<(), Box> + where + S: AsyncRead + AsyncWrite + Unpin + Send + 'static, + { + self.serve_with_peer_identity(stream, None).await + } + + /// Serve a TLS connection with an optional mTLS peer identity. + pub async fn serve_with_peer_identity( + &self, + stream: S, + peer_identity: Option, + ) -> Result<(), Box> where S: AsyncRead + AsyncWrite + Unpin + Send + 'static, { @@ -140,11 +160,19 @@ impl MultiplexService { user_role: oidc.user_role.clone(), scopes_enabled: !oidc.scopes_claim.is_empty(), }); - let grpc_service = AuthGrpcRouter::new( + let authenticator_chain = build_authenticator_chain(&self.state); + let grpc_service = AuthGrpcRouter::with_peer_identity( GrpcRouter::new(openshell, inference), - self.state.oidc_cache.clone(), + authenticator_chain, authz_policy, - self.state.config.ssh_handshake_secret.clone(), + self.state + .config + .mtls_auth + .enabled + .then_some(peer_identity) + .flatten(), + self.state.config.mtls_auth.enabled, + self.state.config.auth.allow_unauthenticated_users, ); let http_service = http_router(self.state.clone()); @@ -153,13 +181,6 @@ impl MultiplexService { let service = MultiplexedService::new(grpc_service, http_service); - // HTTP/2 adaptive flow control. Default windows (64 KiB / 64 KiB) - // throttle the RelayStream data plane to ~500 Mbps on LAN. 
Instead - // of committing to a fixed large window (which worst-case pins - // `max_concurrent_streams × stream_window` bytes per connection), - // we let hyper/h2 auto-size based on the measured bandwidth-delay - // product. Idle streams stay tiny; busy bulk streams grow as - // needed. Overrides any fixed initial_*_window_size settings. let mut builder = Builder::new(TokioExecutor::new()); builder.http2().adaptive_window(true); @@ -169,6 +190,25 @@ impl MultiplexService { Ok(()) } + + /// Serve a plaintext HTTP connection for sandbox service endpoints only. + pub async fn serve_service_http( + &self, + stream: S, + ) -> Result<(), Box> + where + S: AsyncRead + AsyncWrite + Unpin + Send + 'static, + { + let http_service = TowerToHyperService::new(request_id_middleware!(service_http_router( + self.state.clone() + ))); + + Builder::new(TokioExecutor::new()) + .serve_connection_with_upgrades(TokioIo::new(stream), http_service) + .await?; + + Ok(()) + } } /// Combined gRPC service that routes between `OpenShell` and Inference services @@ -224,40 +264,117 @@ where } } -/// gRPC router wrapper that authenticates and authorizes requests. +/// Assemble the authenticator chain for the gateway. +/// +/// Chain order (first-match-wins): +/// 1. `K8sServiceAccountAuthenticator` (path-scoped to `IssueSandboxToken`) +/// — exchanges a projected SA token for a `Principal::Sandbox` so the +/// `IssueSandboxToken` handler can mint a gateway JWT. No-op on every +/// other path; only present when the gateway runs in-cluster. +/// 2. `SandboxJwtAuthenticator` — validates gateway-minted JWTs. Recognized +/// via a distinctive `kid` so non-matching Bearer tokens fall through. +/// 3. `OidcAuthenticator` — validates user Bearer tokens against the +/// configured OIDC issuer. Returns `Unauthenticated` for missing +/// Bearer headers so non-OIDC clients can't sneak through. 
/// -/// When `oidc_cache` is `Some`, extracts the `authorization: Bearer ` -/// header, validates the JWT (authentication), then checks RBAC roles -/// (authorization) before forwarding to the inner gRPC router. +/// Once sandbox authentication is configured, callers must present an +/// explicit credential for authenticated gRPC methods. Missing bearer auth +/// is promoted to an mTLS user only when `mtls_auth.enabled` is configured +/// for local single-user gateways, or to an unsafe local developer user when +/// `auth.allow_unauthenticated_users` is explicitly enabled. /// -/// Authentication is provider-specific (currently OIDC via `oidc.rs`). -/// Authorization is provider-agnostic (via `authz.rs`). This separation -/// aligns with RFC 0001's control-plane identity design. +/// When neither OIDC nor gateway-minted JWTs are configured (a barebones +/// dev gateway), the chain is left as `None` so the router short-circuits +/// to pass-through. +fn build_authenticator_chain(state: &ServerState) -> Option { + let mut authenticators: Vec> = Vec::new(); + if let Some(k8s) = state.k8s_sa_authenticator.clone() { + authenticators.push(k8s); + } + if let Some(jwt) = state.sandbox_jwt_authenticator.clone() { + authenticators.push(jwt); + } + if let Some(cache) = state.oidc_cache.clone() { + authenticators.push(Arc::new(OidcAuthenticator::new(cache))); + } + if authenticators.is_empty() { + return None; + } + Some(AuthenticatorChain::new(authenticators)) +} + +/// gRPC router wrapper that runs the [`AuthenticatorChain`] and inserts the +/// resulting [`Principal`] into the request's extensions. +/// +/// Behavior: +/// - Strip any external `x-openshell-auth-source` marker first (so callers +/// cannot spoof a sandbox identity). +/// - Health probes / reflection bypass the chain entirely. +/// - When no chain is configured (OIDC not configured), forward without +/// authentication — preserves today's pass-through behavior. +/// - Otherwise, run the chain. 
The first match produces a `Principal`. +/// `Principal::User` is gated by the RBAC `AuthzPolicy`. +/// `Principal::Sandbox` is gated by a supervisor-method allowlist, then +/// handlers enforce same-sandbox scope on request bodies. #[derive(Clone)] pub struct AuthGrpcRouter { inner: S, - oidc_cache: Option>, + authenticator_chain: Option, authz_policy: Option, - /// SSH handshake secret used to validate sandbox-to-server RPCs. - sandbox_secret: String, + /// mTLS peer identity extracted from the TLS handshake. + peer_identity: Option, + mtls_auth_enabled: bool, + allow_unauthenticated_users: bool, } impl AuthGrpcRouter { + #[cfg(test)] fn new( inner: S, - oidc_cache: Option>, + authenticator_chain: Option, authz_policy: Option, - sandbox_secret: String, + ) -> Self { + Self::with_peer_identity(inner, authenticator_chain, authz_policy, None, false, false) + } + + fn with_peer_identity( + inner: S, + authenticator_chain: Option, + authz_policy: Option, + peer_identity: Option, + mtls_auth_enabled: bool, + allow_unauthenticated_users: bool, ) -> Self { Self { inner, - oidc_cache, + authenticator_chain, authz_policy, - sandbox_secret, + peer_identity, + mtls_auth_enabled, + allow_unauthenticated_users, } } } +fn unauthenticated_dev_user_principal() -> Principal { + Principal::User(UserPrincipal { + identity: Identity { + subject: "unauthenticated-local-dev".to_string(), + display_name: Some("Unauthenticated Local Dev".to_string()), + roles: vec!["openshell-user".to_string(), "openshell-admin".to_string()], + scopes: vec!["openshell:all".to_string()], + provider: crate::auth::identity::IdentityProvider::LocalDev, + }, + }) +} + +fn status_response(status: tonic::Status) -> Response { + let response = status.into_http(); + let (parts, body) = response.into_parts(); + let body = tonic::body::BoxBody::new(body); + Response::from_parts(parts, body) +} + impl tower::Service> for AuthGrpcRouter where S: tower::Service, Response = Response> @@ -277,19 +394,15 @@ where } fn 
call(&mut self, req: Request) -> Self::Future { - let oidc_cache = self.oidc_cache.clone(); + let chain = self.authenticator_chain.clone(); let authz_policy = self.authz_policy.clone(); - let sandbox_secret = self.sandbox_secret.clone(); + let peer_identity = self.peer_identity.clone(); + let mtls_auth_enabled = self.mtls_auth_enabled; + let allow_unauthenticated_users = self.allow_unauthenticated_users; let mut inner = self.inner.clone(); Box::pin(async move { let mut req = req; - oidc::clear_internal_auth_markers(req.headers_mut()); - - // If OIDC is not configured, pass through directly. - let Some(cache) = oidc_cache else { - return inner.ready().await?.call(req).await; - }; let path = req.uri().path().to_string(); @@ -298,66 +411,58 @@ where return inner.ready().await?.call(req).await; } - // Sandbox-to-server RPCs — authenticated via shared secret, - // not OIDC Bearer tokens. - if oidc::is_sandbox_secret_method(&path) { - if let Err(status) = oidc::validate_sandbox_secret(req.headers(), &sandbox_secret) { - let response = status.into_http(); - let (parts, body) = response.into_parts(); - let body = tonic::body::BoxBody::new(body); - return Ok(Response::from_parts(parts, body)); + let principal = if let Some(chain) = chain { + match chain.authenticate(req.headers(), &path).await { + Ok(Some(p)) => p, + Ok(None) => match (mtls_auth_enabled, peer_identity) { + (true, Some(identity)) => Principal::User(UserPrincipal { identity }), + _ if allow_unauthenticated_users => unauthenticated_dev_user_principal(), + _ => { + return Ok(status_response(tonic::Status::unauthenticated( + "missing authorization header", + ))); + } + }, + Err(status) => return Ok(status_response(status)), } - oidc::mark_sandbox_secret_authenticated(req.headers_mut()); - return inner.ready().await?.call(req).await; - } - - // Dual-auth methods (e.g. UpdateConfig) — accept either a - // Bearer token (CLI users) or sandbox secret (supervisor). 
- if oidc::is_dual_auth_method(&path) - && oidc::validate_sandbox_secret(req.headers(), &sandbox_secret).is_ok() - { - oidc::mark_sandbox_secret_authenticated(req.headers_mut()); + } else if mtls_auth_enabled { + let Some(identity) = peer_identity else { + return Ok(status_response(tonic::Status::unauthenticated( + "missing client certificate", + ))); + }; + Principal::User(UserPrincipal { identity }) + } else if allow_unauthenticated_users { + unauthenticated_dev_user_principal() + } else { + // No auth configured — pass through for dev / + // fronting-proxy deployments. return inner.ready().await?.call(req).await; - } - // Fall through to Bearer token validation below. - - // Extract Bearer token from the authorization header. - let token = req - .headers() - .get("authorization") - .and_then(|v| v.to_str().ok()) - .and_then(|v| v.strip_prefix("Bearer ")); - - let Some(token) = token else { - let status = tonic::Status::unauthenticated("missing authorization header"); - let response = status.into_http(); - // Convert the response body type. - let (parts, body) = response.into_parts(); - let body = tonic::body::BoxBody::new(body); - return Ok(Response::from_parts(parts, body)); }; - // Authenticate: validate the JWT and produce an Identity. 
- let identity = match cache.validate_token(token).await { - Ok(id) => id, - Err(status) => { - let response = status.into_http(); - let (parts, body) = response.into_parts(); - let body = tonic::body::BoxBody::new(body); - return Ok(Response::from_parts(parts, body)); + match principal { + Principal::User(ref user) => { + if let Some(ref policy) = authz_policy + && let Err(status) = policy.check(&user.identity, &path) + { + return Ok(status_response(status)); + } + } + Principal::Sandbox(_) => { + if !crate::auth::sandbox_methods::is_sandbox_callable(&path) { + return Ok(status_response(tonic::Status::permission_denied( + "sandbox principals may not call this method", + ))); + } + } + Principal::Anonymous => { + return Ok(status_response(tonic::Status::unauthenticated( + "anonymous callers may not call authenticated methods", + ))); } - }; - - // Authorize: check RBAC roles against the method. - if let Some(ref policy) = authz_policy - && let Err(status) = policy.check(&identity, &path) - { - let response = status.into_http(); - let (parts, body) = response.into_parts(); - let body = tonic::body::BoxBody::new(body); - return Ok(Response::from_parts(parts, body)); } + req.extensions_mut().insert(principal); inner.ready().await?.call(req).await }) } @@ -470,13 +575,49 @@ fn grpc_status_from_response(res: &Response) -> String { fn normalize_http_path(path: &str) -> &'static str { match path { - p if p.starts_with("/connect/ssh") => "/connect/ssh", p if p.starts_with("/_ws_tunnel") => "/_ws_tunnel", p if p.starts_with("/auth/") => "/auth", _ => "unknown", } } +/// Extract an [`Identity`] from the peer certificates presented during a TLS +/// handshake. Returns `None` if no client certificate was presented. 
+pub fn extract_peer_identity(tls_stream: &tokio_rustls::server::TlsStream) -> Option +where + S: AsyncRead + AsyncWrite + Unpin, +{ + use crate::auth::identity::IdentityProvider; + use x509_parser::prelude::*; + + let (_, server_conn) = tls_stream.get_ref(); + let certs = server_conn.peer_certificates()?; + let first = certs.first()?; + + let (_, cert) = X509Certificate::from_der(first.as_ref()).ok()?; + let subject = cert.subject(); + + let cn = subject + .iter_common_name() + .next() + .and_then(|attr| attr.as_str().ok()) + .unwrap_or("unknown") + .to_string(); + + let roles: Vec = subject + .iter_organizational_unit() + .filter_map(|attr| attr.as_str().ok().map(String::from)) + .collect(); + + Some(Identity { + subject: cn.clone(), + display_name: Some(cn), + roles, + scopes: Vec::new(), + provider: IdentityProvider::Mtls, + }) +} + /// Boxed body type for uniform handling. pub struct BoxBody( http_body_util::combinators::UnsyncBoxBody>, @@ -724,19 +865,6 @@ mod tests { assert_eq!(grpc_method_from_path(""), ""); } - #[test] - fn normalize_ssh_path() { - assert_eq!(normalize_http_path("/connect/ssh"), "/connect/ssh"); - } - - #[test] - fn normalize_ssh_path_with_trailing_segments() { - assert_eq!( - normalize_http_path("/connect/ssh?token=abc"), - "/connect/ssh" - ); - } - #[test] fn normalize_ws_tunnel() { assert_eq!(normalize_http_path("/_ws_tunnel"), "/_ws_tunnel"); @@ -774,4 +902,366 @@ mod tests { fn normalize_root_path() { assert_eq!(normalize_http_path("/"), "unknown"); } + + mod auth_router { + use super::*; + use crate::auth::authenticator::test_support::MockAuthenticator; + use crate::auth::identity::{Identity, IdentityProvider}; + use crate::auth::principal::{ + Principal, SandboxIdentitySource, SandboxPrincipal, UserPrincipal, + }; + use http_body_util::Full; + use std::sync::Arc; + use std::sync::Mutex; + use tower::Service; + + type RecordedPrincipal = Arc>>; + + /// Service that snapshots the `Principal` from request extensions + /// and returns 
200 OK. Used by router-level tests to assert the + /// chain's effect on the downstream service. + #[derive(Clone)] + struct PrincipalRecorder { + recorded: RecordedPrincipal, + } + + impl PrincipalRecorder { + fn new() -> (Self, RecordedPrincipal) { + let recorded = Arc::new(Mutex::new(None)); + ( + Self { + recorded: recorded.clone(), + }, + recorded, + ) + } + } + + impl Service> for PrincipalRecorder { + type Response = Response; + type Error = std::convert::Infallible; + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: Request) -> Self::Future { + let principal = req.extensions().get::().cloned(); + *self.recorded.lock().unwrap() = principal; + Box::pin(async move { + let body = tonic::body::BoxBody::new( + Full::new(Bytes::new()) + .map_err(|never| match never {}) + .boxed_unsync(), + ); + Ok(Response::new(body)) + }) + } + } + + fn empty_request(path: &str) -> Request> { + Request::builder() + .uri(path) + .body(Full::new(Bytes::new())) + .unwrap() + } + + fn grpc_status(res: &Response) -> Option { + res.headers() + .get("grpc-status") + .map(|v| v.to_str().unwrap().to_string()) + } + + fn user_principal(subject: &str) -> Principal { + Principal::User(UserPrincipal { + identity: Identity { + subject: subject.to_string(), + display_name: None, + roles: vec![], + scopes: vec![], + provider: IdentityProvider::Oidc, + }, + }) + } + + fn mtls_identity(subject: &str) -> Identity { + Identity { + subject: subject.to_string(), + display_name: Some(subject.to_string()), + roles: vec!["openshell-user".to_string()], + scopes: vec![], + provider: IdentityProvider::Mtls, + } + } + + fn sandbox_principal() -> Principal { + Principal::Sandbox(SandboxPrincipal { + sandbox_id: "sandbox-a".to_string(), + source: SandboxIdentitySource::BootstrapJwt { + issuer: "openshell-gateway:test".to_string(), + }, + trust_domain: Some("openshell".to_string()), + }) + } + + #[tokio::test] + 
async fn mtls_peer_identity_fills_missing_principal_when_enabled() { + let mock = Arc::new(MockAuthenticator::returning(Ok(None))); + let chain = AuthenticatorChain::new(vec![mock]); + let (recorder, seen) = PrincipalRecorder::new(); + let mut router = AuthGrpcRouter::with_peer_identity( + recorder, + Some(chain), + None, + Some(mtls_identity("openshell-client")), + true, + false, + ); + + let res = router + .call(empty_request("/openshell.v1.OpenShell/ListSandboxes")) + .await + .unwrap(); + + assert_eq!(res.status(), 200); + let principal = seen.lock().unwrap().clone().expect("principal"); + match principal { + Principal::User(u) => { + assert_eq!(u.identity.subject, "openshell-client"); + assert_eq!(u.identity.provider, IdentityProvider::Mtls); + } + other => panic!("expected mTLS user principal, got {other:?}"), + } + } + + #[tokio::test] + async fn mtls_peer_identity_authenticates_without_chain_when_enabled() { + let (recorder, seen) = PrincipalRecorder::new(); + let mut router = AuthGrpcRouter::with_peer_identity( + recorder, + None, + None, + Some(mtls_identity("openshell-client")), + true, + false, + ); + + let res = router + .call(empty_request("/openshell.v1.OpenShell/ListSandboxes")) + .await + .unwrap(); + + assert_eq!(res.status(), 200); + assert!(matches!( + seen.lock().unwrap().as_ref(), + Some(Principal::User(_)) + )); + } + + #[tokio::test] + async fn mtls_auth_enabled_requires_peer_identity() { + let (recorder, seen) = PrincipalRecorder::new(); + let mut router = + AuthGrpcRouter::with_peer_identity(recorder, None, None, None, true, false); + + let res = router + .call(empty_request("/openshell.v1.OpenShell/ListSandboxes")) + .await + .unwrap(); + + assert!(seen.lock().unwrap().is_none()); + assert_eq!(grpc_status(&res).as_deref(), Some("16")); + } + + #[tokio::test] + async fn unauthenticated_dev_user_fills_missing_principal_when_enabled() { + let mock = Arc::new(MockAuthenticator::returning(Ok(None))); + let chain = 
AuthenticatorChain::new(vec![mock]); + let (recorder, seen) = PrincipalRecorder::new(); + let mut router = + AuthGrpcRouter::with_peer_identity(recorder, Some(chain), None, None, false, true); + + let res = router + .call(empty_request("/openshell.v1.OpenShell/ListSandboxes")) + .await + .unwrap(); + + assert_eq!(res.status(), 200); + let principal = seen.lock().unwrap().clone().expect("principal"); + match principal { + Principal::User(u) => { + assert_eq!(u.identity.subject, "unauthenticated-local-dev"); + assert_eq!(u.identity.provider, IdentityProvider::LocalDev); + } + other => panic!("expected dev user principal, got {other:?}"), + } + } + + #[tokio::test] + async fn unauthenticated_dev_user_authenticates_without_chain_when_enabled() { + let (recorder, seen) = PrincipalRecorder::new(); + let mut router = + AuthGrpcRouter::with_peer_identity(recorder, None, None, None, false, true); + + let res = router + .call(empty_request("/openshell.v1.OpenShell/ListSandboxes")) + .await + .unwrap(); + + assert_eq!(res.status(), 200); + assert!(matches!( + seen.lock().unwrap().as_ref(), + Some(Principal::User(user)) + if user.identity.subject == "unauthenticated-local-dev" + )); + } + + #[tokio::test] + async fn user_principal_lands_in_request_extensions() { + let mock = Arc::new(MockAuthenticator::returning(Ok(Some(user_principal( + "alice", + ))))); + let chain = AuthenticatorChain::new(vec![mock]); + let (recorder, seen) = PrincipalRecorder::new(); + let mut router = AuthGrpcRouter::new(recorder, Some(chain), None); + let _ = router + .call(empty_request("/openshell.v1.OpenShell/ListSandboxes")) + .await + .unwrap(); + let principal = seen.lock().unwrap().clone().expect("principal"); + match principal { + Principal::User(u) => assert_eq!(u.identity.subject, "alice"), + _ => panic!("expected user principal"), + } + } + + #[tokio::test] + async fn sandbox_principal_lands_in_request_extensions() { + let mock = 
Arc::new(MockAuthenticator::returning(Ok(Some(sandbox_principal())))); + let chain = AuthenticatorChain::new(vec![mock]); + let (recorder, seen) = PrincipalRecorder::new(); + let mut router = AuthGrpcRouter::new(recorder, Some(chain), None); + let _ = router + .call(empty_request("/openshell.v1.OpenShell/ReportPolicyStatus")) + .await + .unwrap(); + let captured = seen.lock().unwrap().clone(); + match captured { + Some(Principal::Sandbox(p)) => assert_eq!(p.sandbox_id, "sandbox-a"), + other => panic!("expected sandbox principal, got {other:?}"), + } + } + + #[tokio::test] + async fn sandbox_principal_can_call_allowlisted_method() { + let mock = Arc::new(MockAuthenticator::returning(Ok(Some(sandbox_principal())))); + let chain = AuthenticatorChain::new(vec![mock]); + let (recorder, seen) = PrincipalRecorder::new(); + let mut router = AuthGrpcRouter::new(recorder, Some(chain), None); + + let res = router + .call(empty_request("/openshell.v1.OpenShell/GetSandboxConfig")) + .await + .unwrap(); + + assert_eq!(res.status(), 200); + assert!(matches!( + seen.lock().unwrap().as_ref(), + Some(Principal::Sandbox(_)) + )); + } + + #[tokio::test] + async fn sandbox_principal_can_fetch_inference_bundle() { + let mock = Arc::new(MockAuthenticator::returning(Ok(Some(sandbox_principal())))); + let chain = AuthenticatorChain::new(vec![mock]); + let (recorder, seen) = PrincipalRecorder::new(); + let mut router = AuthGrpcRouter::new(recorder, Some(chain), None); + + let res = router + .call(empty_request( + "/openshell.inference.v1.Inference/GetInferenceBundle", + )) + .await + .unwrap(); + + assert_eq!(res.status(), 200); + assert!(matches!( + seen.lock().unwrap().as_ref(), + Some(Principal::Sandbox(_)) + )); + } + + #[tokio::test] + async fn sandbox_principal_is_denied_on_user_and_admin_methods() { + for path in [ + "/openshell.v1.OpenShell/ListSandboxes", + "/openshell.v1.OpenShell/DeleteSandbox", + "/openshell.v1.OpenShell/CreateProvider", + 
"/openshell.v1.OpenShell/ApproveDraftChunk", + "/openshell.inference.v1.Inference/GetClusterInference", + "/openshell.inference.v1.Inference/SetClusterInference", + ] { + let mock = Arc::new(MockAuthenticator::returning(Ok(Some(sandbox_principal())))); + let chain = AuthenticatorChain::new(vec![mock]); + let (recorder, seen) = PrincipalRecorder::new(); + let mut router = AuthGrpcRouter::new(recorder, Some(chain), None); + + let res = router.call(empty_request(path)).await.unwrap(); + + assert!(seen.lock().unwrap().is_none(), "{path} reached handler"); + assert_eq!(grpc_status(&res).as_deref(), Some("7"), "{path}"); + } + } + + #[tokio::test] + async fn missing_principal_returns_unauthenticated() { + let mock = Arc::new(MockAuthenticator::returning(Ok(None))); + let chain = AuthenticatorChain::new(vec![mock]); + let (recorder, seen) = PrincipalRecorder::new(); + let mut router = AuthGrpcRouter::new(recorder, Some(chain), None); + let res = router + .call(empty_request("/openshell.v1.OpenShell/ListSandboxes")) + .await + .unwrap(); + assert!(seen.lock().unwrap().is_none()); + // tonic sets grpc-status=16 (UNAUTHENTICATED) in trailers. + assert_eq!(grpc_status(&res).as_deref(), Some("16")); + } + + #[tokio::test] + async fn authenticator_error_short_circuits() { + let mock = Arc::new(MockAuthenticator::returning(Err( + tonic::Status::unauthenticated("forged"), + ))); + let chain = AuthenticatorChain::new(vec![mock]); + let (recorder, seen) = PrincipalRecorder::new(); + let mut router = AuthGrpcRouter::new(recorder, Some(chain), None); + let res = router + .call(empty_request("/openshell.v1.OpenShell/ListSandboxes")) + .await + .unwrap(); + assert!(seen.lock().unwrap().is_none()); + assert_eq!(grpc_status(&res).as_deref(), Some("16")); + } + + #[tokio::test] + async fn health_methods_bypass_chain() { + // Authenticator is wired to fail-closed; the request still gets + // through because the path is exempt. 
+ let mock = Arc::new(MockAuthenticator::returning(Err( + tonic::Status::unauthenticated("would reject"), + ))); + let chain = AuthenticatorChain::new(vec![mock.clone()]); + let (recorder, _) = PrincipalRecorder::new(); + let mut router = AuthGrpcRouter::new(recorder, Some(chain), None); + let res = router + .call(empty_request("/openshell.v1.OpenShell/Health")) + .await + .unwrap(); + assert_eq!(res.status(), 200); + assert_eq!(mock.call_count(), 0, "health must not consult the chain"); + } + } } diff --git a/crates/openshell-server/src/persistence/mod.rs b/crates/openshell-server/src/persistence/mod.rs index 1c926bd4a..32875a9f9 100644 --- a/crates/openshell-server/src/persistence/mod.rs +++ b/crates/openshell-server/src/persistence/mod.rs @@ -14,12 +14,16 @@ use openshell_core::{Error as CoreError, Result as CoreResult}; use prost::Message; use rand::Rng; use std::collections::HashMap; -use std::time::{SystemTime, UNIX_EPOCH}; use thiserror::Error; pub use postgres::PostgresStore; pub use sqlite::SqliteStore; +/// Object type string for sandbox policy records. +pub const POLICY_OBJECT_TYPE: &str = "sandbox_policy"; +/// Object type string for draft policy chunk records. +pub const DRAFT_CHUNK_OBJECT_TYPE: &str = "draft_policy_chunk"; + pub type PersistenceResult = Result; /// Persistence-layer error type. @@ -41,6 +45,10 @@ pub enum PersistenceError { detail: Option, constraint_msg: String, }, + #[error("resource version conflict: expected version does not match current")] + Conflict { + current_resource_version: Option, + }, } impl PersistenceError { @@ -78,6 +86,28 @@ pub struct ObjectRecord { pub updated_at_ms: i64, /// JSON-serialized labels (key-value pairs). pub labels: Option, + /// Optimistic concurrency control version. + /// Incremented on each update for compare-and-swap operations. + pub resource_version: u64, +} + +/// Write condition for compare-and-swap operations. 
+#[derive(Debug, Clone, Copy)] +pub enum WriteCondition { + /// Object must not exist (insert only). + MustCreate, + /// Object must exist with the specified resource version (update only). + MatchResourceVersion(u64), + /// Unconditional write (insert or update). + Unconditional, +} + +/// Result of a successful write operation. +#[derive(Debug, Clone)] +pub struct WriteResult { + pub resource_version: u64, + pub created_at_ms: i64, + pub updated_at_ms: i64, } /// Persistence store implementations. @@ -94,7 +124,9 @@ pub trait ObjectType { // Import object metadata accessor traits from openshell-core // (implementations for all proto types are in openshell-core::metadata) -pub use openshell_core::{ObjectId, ObjectLabels, ObjectName}; +pub use openshell_core::{ + GetResourceVersion, ObjectId, ObjectLabels, ObjectName, SetResourceVersion, +}; /// Generate a random 6-character lowercase alphabetic name. pub fn generate_name() -> String { @@ -132,18 +164,95 @@ impl Store { } } - /// Insert or update a generic named object. - pub async fn put( + /// Insert or update a generic object with compare-and-swap support. 
+ /// + /// # Arguments + /// * `object_type` - Type discriminator for the object + /// * `id` - Stable object identifier + /// * `name` - Human-readable object name + /// * `payload` - Serialized object data + /// * `labels` - Optional JSON-serialized labels + /// * `condition` - Write precondition (`MustCreate`, `MatchResourceVersion`, or `Unconditional`) + /// + /// # Returns + /// * `Ok(WriteResult)` - Write succeeded with new `resource_version` and timestamps + /// * `Err(Conflict)` - Resource version mismatch (for `MatchResourceVersion`) + /// * `Err(UniqueViolation)` - Object already exists (for `MustCreate`) or name conflict + pub async fn put_if( + &self, + object_type: &str, + id: &str, + name: &str, + payload: &[u8], + labels: Option<&str>, + condition: WriteCondition, + ) -> PersistenceResult { + match self { + Self::Postgres(store) => { + store + .put_if(object_type, id, name, payload, labels, condition) + .await + } + Self::Sqlite(store) => { + store + .put_if(object_type, id, name, payload, labels, condition) + .await + } + } + } + + /// Delete an object by id with compare-and-swap support. + /// + /// # Arguments + /// * `object_type` - Type discriminator for the object + /// * `id` - Stable object identifier + /// * `expected_resource_version` - Required resource version for the delete to proceed + /// + /// # Returns + /// * `Ok(true)` - Object was deleted + /// * `Ok(false)` - Object not found + /// * `Err(Conflict)` - Resource version mismatch + pub async fn delete_if( + &self, + object_type: &str, + id: &str, + expected_resource_version: u64, + ) -> PersistenceResult { + match self { + Self::Postgres(store) => { + store + .delete_if(object_type, id, expected_resource_version) + .await + } + Self::Sqlite(store) => { + store + .delete_if(object_type, id, expected_resource_version) + .await + } + } + } + + /// Insert or update a generic named object with an application-owned scope. 
+ pub async fn put_scoped( &self, object_type: &str, id: &str, name: &str, + scope: &str, payload: &[u8], labels: Option<&str>, ) -> PersistenceResult<()> { match self { - Self::Postgres(store) => store.put(object_type, id, name, payload, labels).await, - Self::Sqlite(store) => store.put(object_type, id, name, payload, labels).await, + Self::Postgres(store) => { + store + .put_scoped(object_type, id, name, scope, payload, labels) + .await + } + Self::Sqlite(store) => { + store + .put_scoped(object_type, id, name, scope, payload, labels) + .await + } } } @@ -200,6 +309,20 @@ impl Store { } } + /// List objects by type and application-owned scope. + pub async fn list_by_scope( + &self, + object_type: &str, + scope: &str, + limit: u32, + offset: u32, + ) -> PersistenceResult> { + match self { + Self::Postgres(store) => store.list_by_scope(object_type, scope, limit, offset).await, + Self::Sqlite(store) => store.list_by_scope(object_type, scope, limit, offset).await, + } + } + /// List objects by type with label selector filtering. /// Label selector format: "key1=value1,key2=value2" (comma-separated equality matches). pub async fn list_with_selector( @@ -227,12 +350,14 @@ impl Store { // Generic protobuf message helpers // ----------------------------------------------------------------------- - /// Insert or update a protobuf message using its inferred object type, id, and name. - pub async fn put_message( + /// Insert or update a protobuf message under an application-owned scope. + pub async fn put_scoped_message< + T: Message + ObjectType + ObjectId + ObjectName + ObjectLabels, + >( &self, message: &T, + scope: &str, ) -> PersistenceResult<()> { - // Serialize labels to JSON let labels_map = message.object_labels(); let labels_json = if labels_map.as_ref().is_none_or(HashMap::is_empty) { None @@ -242,10 +367,11 @@ impl Store { })?) 
}; - self.put( + self.put_scoped( T::object_type(), message.object_id(), message.object_name(), + scope, &message.encode_to_vec(), labels_json.as_deref(), ) @@ -253,7 +379,7 @@ impl Store { } /// Fetch and decode a protobuf message by id. - pub async fn get_message( + pub async fn get_message( &self, id: &str, ) -> PersistenceResult> { @@ -262,13 +388,17 @@ impl Store { return Ok(None); }; - T::decode(record.payload.as_slice()) - .map(Some) - .map_err(|e| PersistenceError::Decode(format!("protobuf decode error: {e}"))) + let mut message = T::decode(record.payload.as_slice()) + .map_err(|e| PersistenceError::Decode(format!("protobuf decode error: {e}")))?; + + // Hydrate resource_version from DB row (authoritative source) + message.set_resource_version(record.resource_version); + + Ok(Some(message)) } /// Fetch and decode a protobuf message by name. - pub async fn get_message_by_name( + pub async fn get_message_by_name( &self, name: &str, ) -> PersistenceResult> { @@ -277,18 +407,147 @@ impl Store { return Ok(None); }; - T::decode(record.payload.as_slice()) - .map(Some) - .map_err(|e| PersistenceError::Decode(format!("protobuf decode error: {e}"))) + let mut message = T::decode(record.payload.as_slice()) + .map_err(|e| PersistenceError::Decode(format!("protobuf decode error: {e}")))?; + + // Hydrate resource_version from DB row (authoritative source) + message.set_resource_version(record.resource_version); + + Ok(Some(message)) + } + + /// List and decode protobuf messages, hydrating `resource_version` from + /// the authoritative DB row (mirrors `get_message`). 
+ pub async fn list_messages( + &self, + limit: u32, + offset: u32, + ) -> PersistenceResult> { + let records = self.list(T::object_type(), limit, offset).await?; + let mut messages = Vec::with_capacity(records.len()); + for record in records { + let mut message = T::decode(record.payload.as_slice()) + .map_err(|e| PersistenceError::Decode(format!("protobuf decode error: {e}")))?; + message.set_resource_version(record.resource_version); + messages.push(message); + } + Ok(messages) + } + + /// List and decode protobuf messages with label selector filtering, + /// hydrating `resource_version` from the authoritative DB row. + pub async fn list_messages_with_selector< + T: Message + Default + ObjectType + SetResourceVersion, + >( + &self, + label_selector: &str, + limit: u32, + offset: u32, + ) -> PersistenceResult> { + let records = self + .list_with_selector(T::object_type(), label_selector, limit, offset) + .await?; + let mut messages = Vec::with_capacity(records.len()); + for record in records { + let mut message = T::decode(record.payload.as_slice()) + .map_err(|e| PersistenceError::Decode(format!("protobuf decode error: {e}")))?; + message.set_resource_version(record.resource_version); + messages.push(message); + } + Ok(messages) + } + + /// Update a protobuf message using CAS (compare-and-swap). + /// + /// Fetches the current object, validates the expected version, applies the + /// mutation function, and attempts a single CAS write. Returns Conflict on + /// version mismatch for caller-driven retry. + /// + /// # Arguments + /// * `id` - Object ID to update + /// * `expected_version` - Required resource version for the update to proceed. + /// Pass 0 to use the current version (internal operations only). + /// For client-facing operations, pass the client-provided expected version. 
+ /// * `mutate` - Function that modifies the object in place + /// + /// # Returns + /// * `Ok(T)` - Successfully updated object with new `resource_version` + /// * `Err(Conflict)` - Version mismatch; caller should retry + /// * `Err(Database)` - Object not found or other DB error + pub async fn update_message_cas( + &self, + id: &str, + expected_version: u64, + mut mutate: F, + ) -> PersistenceResult + where + T: Message + + Default + + ObjectType + + ObjectId + + ObjectName + + ObjectLabels + + SetResourceVersion + + GetResourceVersion + + Clone, + F: FnMut(&mut T), + { + // Fetch current object with authoritative resource_version + let current = self + .get_message::(id) + .await? + .ok_or_else(|| PersistenceError::Database(format!("object {id} not found")))?; + + let current_version = current.get_resource_version(); + + // Determine the version to use for CAS: + // - If expected_version is 0, use current version (internal operations) + // - Otherwise, validate that expected matches current (client-facing operations) + let cas_version = if expected_version == 0 { + current_version + } else { + if expected_version != current_version { + return Err(PersistenceError::Conflict { + current_resource_version: Some(current_version), + }); + } + expected_version + }; + + // Apply mutation + let mut updated = current.clone(); + mutate(&mut updated); + + // Serialize labels + let labels_map = updated.object_labels(); + let labels_json = if labels_map.as_ref().is_none_or(HashMap::is_empty) { + None + } else { + Some(serde_json::to_string(&labels_map).map_err(|e| { + PersistenceError::Encode(format!("failed to serialize labels: {e}")) + })?) 
+ }; + + // Single-attempt CAS write - fails with Conflict on version mismatch + let result = self + .put_if( + T::object_type(), + updated.object_id(), + updated.object_name(), + &updated.encode_to_vec(), + labels_json.as_deref(), + WriteCondition::MatchResourceVersion(cas_version), + ) + .await?; + + // Success - hydrate the new resource_version and return + updated.set_resource_version(result.resource_version); + Ok(updated) } } -pub fn current_time_ms() -> PersistenceResult { - let now = SystemTime::now() - .duration_since(UNIX_EPOCH) - .map_err(|e| PersistenceError::Database(format!("time error: {e}")))?; - i64::try_from(now.as_millis()) - .map_err(|e| PersistenceError::Database(format!("time conversion error: {e}"))) +pub fn current_time_ms() -> i64 { + openshell_core::time::now_ms() } fn map_db_error(error: &sqlx::Error) -> PersistenceError { @@ -363,5 +622,48 @@ pub fn parse_label_selector(selector: &str) -> PersistenceResult, + ) -> PersistenceResult<()> { + match self { + Self::Postgres(store) => store.put(object_type, id, name, payload, labels).await, + Self::Sqlite(store) => store.put(object_type, id, name, payload, labels).await, + } + } + + pub async fn put_message( + &self, + message: &T, + ) -> PersistenceResult<()> { + let labels_map = message.object_labels(); + let labels_json = if labels_map.as_ref().is_none_or(HashMap::is_empty) { + None + } else { + Some(serde_json::to_string(&labels_map).map_err(|e| { + PersistenceError::Encode(format!("failed to serialize labels: {e}")) + })?) 
+ }; + self.put( + T::object_type(), + message.object_id(), + message.object_name(), + &message.encode_to_vec(), + labels_json.as_deref(), + ) + .await + } +} + #[cfg(test)] mod tests; diff --git a/crates/openshell-server/src/persistence/postgres.rs b/crates/openshell-server/src/persistence/postgres.rs index 2cd6a046f..8399fd734 100644 --- a/crates/openshell-server/src/persistence/postgres.rs +++ b/crates/openshell-server/src/persistence/postgres.rs @@ -2,8 +2,8 @@ // SPDX-License-Identifier: Apache-2.0 use super::{ - DraftChunkRecord, ObjectRecord, PersistenceResult, PolicyRecord, current_time_ms, map_db_error, - map_migrate_error, + DraftChunkRecord, ObjectRecord, PersistenceError, PersistenceResult, PolicyRecord, + WriteCondition, WriteResult, current_time_ms, map_db_error, map_migrate_error, }; use crate::policy_store::{ draft_chunk_payload_from_record, draft_chunk_record_from_parts, policy_payload_from_record, @@ -14,8 +14,7 @@ use sqlx::{PgPool, Row}; static POSTGRES_MIGRATOR: sqlx::migrate::Migrator = sqlx::migrate!("./migrations/postgres"); -const POLICY_OBJECT_TYPE: &str = "sandbox_policy"; -const DRAFT_CHUNK_OBJECT_TYPE: &str = "draft_policy_chunk"; +use super::{DRAFT_CHUNK_OBJECT_TYPE, POLICY_OBJECT_TYPE}; #[derive(Debug, Clone)] pub struct PostgresStore { @@ -48,11 +47,11 @@ impl PostgresStore { payload: &[u8], labels: Option<&str>, ) -> PersistenceResult<()> { - let now_ms = current_time_ms()?; + let now_ms = current_time_ms(); let labels_jsonb: Option = labels .map(serde_json::from_str) .transpose() - .map_err(|e| super::PersistenceError::Encode(format!("invalid labels JSON: {e}")))?; + .map_err(|e| PersistenceError::Encode(format!("invalid labels JSON: {e}")))?; sqlx::query( r" @@ -76,6 +75,197 @@ ON CONFLICT (object_type, name) WHERE name IS NOT NULL DO UPDATE SET Ok(()) } + pub async fn put_if( + &self, + object_type: &str, + id: &str, + name: &str, + payload: &[u8], + labels: Option<&str>, + condition: WriteCondition, + ) -> PersistenceResult { + 
let now_ms = current_time_ms(); + let labels_jsonb: Option = labels + .map(serde_json::from_str) + .transpose() + .map_err(|e| PersistenceError::Encode(format!("invalid labels JSON: {e}")))?; + + match condition { + WriteCondition::MustCreate => { + // Insert only - fail if object exists + let row = sqlx::query( + r" +INSERT INTO objects (object_type, id, name, payload, created_at_ms, updated_at_ms, labels, resource_version) +VALUES ($1, $2, $3, $4, $5, $5, COALESCE($6, '{}'::jsonb), 1) +RETURNING resource_version, created_at_ms, updated_at_ms +", + ) + .bind(object_type) + .bind(id) + .bind(name) + .bind(payload) + .bind(now_ms) + .bind(labels_jsonb) + .fetch_one(&self.pool) + .await + .map_err(|e| map_db_error(&e))?; + + let resource_version_i64: i64 = row.try_get("resource_version").unwrap_or(1); + Ok(WriteResult { + resource_version: resource_version_i64.max(1).cast_unsigned(), + created_at_ms: row.get("created_at_ms"), + updated_at_ms: row.get("updated_at_ms"), + }) + } + WriteCondition::MatchResourceVersion(expected_version) => { + // Update with version check using RETURNING + let row_result = sqlx::query( + r" +UPDATE objects +SET payload = $4, labels = COALESCE($5, '{}'::jsonb), updated_at_ms = $6, resource_version = resource_version + 1 +WHERE object_type = $1 AND id = $2 AND resource_version = $3 +RETURNING resource_version, created_at_ms, updated_at_ms +", + ) + .bind(object_type) + .bind(id) + .bind(i64::try_from(expected_version).unwrap_or(i64::MAX)) + .bind(payload) + .bind(labels_jsonb) + .bind(now_ms) + .fetch_optional(&self.pool) + .await + .map_err(|e| map_db_error(&e))?; + + if let Some(row) = row_result { + let resource_version_i64: i64 = row.try_get("resource_version").unwrap_or(1); + Ok(WriteResult { + resource_version: resource_version_i64.max(1).cast_unsigned(), + created_at_ms: row.get("created_at_ms"), + updated_at_ms: row.get("updated_at_ms"), + }) + } else { + // Check if object exists to distinguish NotFound from Conflict + let 
existing = self.get(object_type, id).await?; + if let Some(record) = existing { + Err(PersistenceError::Conflict { + current_resource_version: Some(record.resource_version), + }) + } else { + Err(PersistenceError::Database(format!( + "object not found: {object_type}/{id}" + ))) + } + } + } + WriteCondition::Unconditional => { + // Unconditional upsert by name + let row = sqlx::query( + r" +INSERT INTO objects (object_type, id, name, payload, created_at_ms, updated_at_ms, labels, resource_version) +VALUES ($1, $2, $3, $4, $5, $5, COALESCE($6, '{}'::jsonb), 1) +ON CONFLICT (object_type, name) WHERE name IS NOT NULL DO UPDATE SET + payload = EXCLUDED.payload, + updated_at_ms = EXCLUDED.updated_at_ms, + labels = EXCLUDED.labels, + resource_version = objects.resource_version + 1 +RETURNING resource_version, created_at_ms, updated_at_ms +", + ) + .bind(object_type) + .bind(id) + .bind(name) + .bind(payload) + .bind(now_ms) + .bind(labels_jsonb) + .fetch_one(&self.pool) + .await + .map_err(|e| map_db_error(&e))?; + + let resource_version_i64: i64 = row.try_get("resource_version").unwrap_or(1); + Ok(WriteResult { + resource_version: resource_version_i64.max(1).cast_unsigned(), + created_at_ms: row.get("created_at_ms"), + updated_at_ms: row.get("updated_at_ms"), + }) + } + } + } + + pub async fn delete_if( + &self, + object_type: &str, + id: &str, + expected_resource_version: u64, + ) -> PersistenceResult { + let result = sqlx::query( + r" +DELETE FROM objects +WHERE object_type = $1 AND id = $2 AND resource_version = $3 +", + ) + .bind(object_type) + .bind(id) + .bind(i64::try_from(expected_resource_version).unwrap_or(i64::MAX)) + .execute(&self.pool) + .await + .map_err(|e| map_db_error(&e))?; + + if result.rows_affected() > 0 { + Ok(true) + } else { + // Check if object exists to distinguish NotFound from Conflict + let existing = self.get(object_type, id).await?; + if let Some(record) = existing { + Err(PersistenceError::Conflict { + current_resource_version: 
Some(record.resource_version), + }) + } else { + Ok(false) + } + } + } + + pub async fn put_scoped( + &self, + object_type: &str, + id: &str, + name: &str, + scope: &str, + payload: &[u8], + labels: Option<&str>, + ) -> PersistenceResult<()> { + let now_ms = current_time_ms(); + let labels_jsonb: Option = labels + .map(serde_json::from_str) + .transpose() + .map_err(|e| PersistenceError::Encode(format!("invalid labels JSON: {e}")))?; + + sqlx::query( + r" +INSERT INTO objects (object_type, id, name, scope, payload, created_at_ms, updated_at_ms, labels, resource_version) +VALUES ($1, $2, $3, $4, $5, $6, $6, COALESCE($7, '{}'::jsonb), 1) +ON CONFLICT (object_type, name) WHERE name IS NOT NULL DO UPDATE SET + scope = EXCLUDED.scope, + payload = EXCLUDED.payload, + updated_at_ms = EXCLUDED.updated_at_ms, + labels = EXCLUDED.labels, + resource_version = objects.resource_version + 1 +", + ) + .bind(object_type) + .bind(id) + .bind(name) + .bind(scope) + .bind(payload) + .bind(now_ms) + .bind(labels_jsonb) + .execute(&self.pool) + .await + .map_err(|e| map_db_error(&e))?; + Ok(()) + } + pub async fn get( &self, object_type: &str, @@ -83,7 +273,7 @@ ON CONFLICT (object_type, name) WHERE name IS NOT NULL DO UPDATE SET ) -> PersistenceResult> { let row = sqlx::query( r" -SELECT object_type, id, name, payload, created_at_ms, updated_at_ms, labels +SELECT object_type, id, name, payload, created_at_ms, updated_at_ms, labels, resource_version FROM objects WHERE object_type = $1 AND id = $2 ", @@ -104,7 +294,7 @@ WHERE object_type = $1 AND id = $2 ) -> PersistenceResult> { let row = sqlx::query( r" -SELECT object_type, id, name, payload, created_at_ms, updated_at_ms, labels +SELECT object_type, id, name, payload, created_at_ms, updated_at_ms, labels, resource_version FROM objects WHERE object_type = $1 AND name = $2 ", @@ -146,7 +336,7 @@ WHERE object_type = $1 AND name = $2 ) -> PersistenceResult> { let rows = sqlx::query( r" -SELECT object_type, id, name, payload, 
created_at_ms, updated_at_ms, labels +SELECT object_type, id, name, payload, created_at_ms, updated_at_ms, labels, resource_version FROM objects WHERE object_type = $1 ORDER BY created_at_ms ASC, name ASC @@ -163,6 +353,33 @@ LIMIT $2 OFFSET $3 Ok(rows.into_iter().map(row_to_object_record).collect()) } + pub async fn list_by_scope( + &self, + object_type: &str, + scope: &str, + limit: u32, + offset: u32, + ) -> PersistenceResult> { + let rows = sqlx::query( + r" +SELECT object_type, id, name, payload, created_at_ms, updated_at_ms, labels +FROM objects +WHERE object_type = $1 AND scope = $2 +ORDER BY created_at_ms ASC, name ASC +LIMIT $3 OFFSET $4 +", + ) + .bind(object_type) + .bind(scope) + .bind(i64::from(limit)) + .bind(i64::from(offset)) + .fetch_all(&self.pool) + .await + .map_err(|e| map_db_error(&e))?; + + Ok(rows.into_iter().map(row_to_object_record).collect()) + } + pub async fn list_with_selector( &self, object_type: &str, @@ -173,13 +390,12 @@ LIMIT $2 OFFSET $3 use super::parse_label_selector; let required_labels = parse_label_selector(label_selector)?; - let labels_jsonb = serde_json::to_value(&required_labels).map_err(|e| { - super::PersistenceError::Encode(format!("failed to serialize labels: {e}")) - })?; + let labels_jsonb = serde_json::to_value(&required_labels) + .map_err(|e| PersistenceError::Encode(format!("failed to serialize labels: {e}")))?; let rows = sqlx::query( r" -SELECT object_type, id, name, payload, created_at_ms, updated_at_ms, labels +SELECT object_type, id, name, payload, created_at_ms, updated_at_ms, labels, resource_version FROM objects WHERE object_type = $1 AND labels @> $2 ORDER BY created_at_ms ASC, name ASC @@ -205,7 +421,7 @@ LIMIT $3 OFFSET $4 payload: &[u8], hash: &str, ) -> PersistenceResult<()> { - let now_ms = current_time_ms()?; + let now_ms = current_time_ms(); let record = PolicyRecord { id: id.to_string(), sandbox_id: sandbox_id.to_string(), @@ -348,7 +564,7 @@ LIMIT $3 OFFSET $4 record.load_error = 
load_error.map(ToOwned::to_owned); record.loaded_at_ms = loaded_at_ms; let payload = policy_payload_from_record(&record)?; - let now_ms = current_time_ms()?; + let now_ms = current_time_ms(); let result = sqlx::query( r" @@ -374,7 +590,7 @@ WHERE object_type = $1 AND scope = $2 AND version = $3 sandbox_id: &str, before_version: i64, ) -> PersistenceResult { - let now_ms = current_time_ms()?; + let now_ms = current_time_ms(); let result = sqlx::query( r" UPDATE objects @@ -395,9 +611,16 @@ WHERE object_type = $1 Ok(result.rows_affected()) } - pub async fn put_draft_chunk(&self, chunk: &DraftChunkRecord) -> PersistenceResult<()> { + pub async fn put_draft_chunk( + &self, + chunk: &DraftChunkRecord, + dedup_key: Option<&str>, + ) -> PersistenceResult { let payload = draft_chunk_payload_from_record(chunk)?; - sqlx::query( + // RETURNING id gives the row's effective id whether INSERT inserted + // a fresh row or ON CONFLICT updated an existing one. See the + // matching sqlite path for the rationale. 
+ let row = sqlx::query( r" INSERT INTO objects ( object_type, id, scope, status, dedup_key, hit_count, payload, created_at_ms, updated_at_ms @@ -406,21 +629,22 @@ VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) ON CONFLICT (object_type, scope, dedup_key) WHERE dedup_key IS NOT NULL DO UPDATE SET hit_count = objects.hit_count + EXCLUDED.hit_count, updated_at_ms = EXCLUDED.updated_at_ms +RETURNING id ", ) .bind(DRAFT_CHUNK_OBJECT_TYPE) .bind(&chunk.id) .bind(&chunk.sandbox_id) .bind(&chunk.status) - .bind(draft_chunk_dedup_key(chunk)) + .bind(dedup_key) .bind(i64::from(chunk.hit_count)) .bind(payload) .bind(chunk.first_seen_ms) .bind(chunk.last_seen_ms) - .execute(&self.pool) + .fetch_one(&self.pool) .await .map_err(|e| map_db_error(&e))?; - Ok(()) + Ok(row.get::<String, _>("id")) } pub async fn get_draft_chunk(&self, id: &str) -> PersistenceResult<Option<DraftChunkRecord>> { @@ -483,6 +707,7 @@ ORDER BY created_at_ms DESC id: &str, status: &str, decided_at_ms: Option<i64>, + rejection_reason: Option<&str>, ) -> PersistenceResult<bool> { let Some(mut record) = self.get_draft_chunk(id).await? 
else { return Ok(false); @@ -490,7 +715,10 @@ ORDER BY created_at_ms DESC record.status = status.to_string(); record.decided_at_ms = decided_at_ms; - record.last_seen_ms = current_time_ms()?; + record.last_seen_ms = current_time_ms(); + if let Some(reason) = rejection_reason { + record.rejection_reason = reason.to_string(); + } let payload = draft_chunk_payload_from_record(&record)?; let result = sqlx::query( @@ -525,7 +753,7 @@ WHERE object_type = $1 AND id = $2 } record.proposed_rule = proposed_rule.to_vec(); - record.last_seen_ms = current_time_ms()?; + record.last_seen_ms = current_time_ms(); let payload = draft_chunk_payload_from_record(&record)?; let result = sqlx::query( @@ -597,12 +825,9 @@ WHERE object_type = $1 AND scope = $2 } } -fn draft_chunk_dedup_key(chunk: &DraftChunkRecord) -> String { - format!("{}|{}|{}", chunk.host, chunk.port, chunk.binary) -} - fn row_to_object_record(row: sqlx::postgres::PgRow) -> ObjectRecord { let labels_jsonb: Option = row.get("labels"); + let resource_version_i64: i64 = row.try_get("resource_version").unwrap_or(1); ObjectRecord { object_type: row.get("object_type"), id: row.get("id"), @@ -611,6 +836,7 @@ fn row_to_object_record(row: sqlx::postgres::PgRow) -> ObjectRecord { created_at_ms: row.get("created_at_ms"), updated_at_ms: row.get("updated_at_ms"), labels: labels_jsonb.map(|value| value.to_string()), + resource_version: resource_version_i64.max(1).cast_unsigned(), } } diff --git a/crates/openshell-server/src/persistence/sqlite.rs b/crates/openshell-server/src/persistence/sqlite.rs index fafb07597..bdfadc8b0 100644 --- a/crates/openshell-server/src/persistence/sqlite.rs +++ b/crates/openshell-server/src/persistence/sqlite.rs @@ -2,21 +2,22 @@ // SPDX-License-Identifier: Apache-2.0 use super::{ - DraftChunkRecord, ObjectRecord, PersistenceResult, PolicyRecord, current_time_ms, map_db_error, - map_migrate_error, + DraftChunkRecord, ObjectRecord, PersistenceError, PersistenceResult, PolicyRecord, + WriteCondition, 
WriteResult, current_time_ms, map_db_error, map_migrate_error, }; use crate::policy_store::{ draft_chunk_payload_from_record, draft_chunk_record_from_parts, policy_payload_from_record, policy_record_from_parts, }; +use openshell_core::paths::set_file_owner_only; use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions}; use sqlx::{Row, SqlitePool}; +use std::path::{Path, PathBuf}; use std::str::FromStr; static SQLITE_MIGRATOR: sqlx::migrate::Migrator = sqlx::migrate!("./migrations/sqlite"); -const POLICY_OBJECT_TYPE: &str = "sandbox_policy"; -const DRAFT_CHUNK_OBJECT_TYPE: &str = "draft_policy_chunk"; +use super::{DRAFT_CHUNK_OBJECT_TYPE, POLICY_OBJECT_TYPE}; #[derive(Debug, Clone)] pub struct SqliteStore { @@ -40,11 +41,20 @@ impl SqliteStore { pool_options = pool_options.idle_timeout(None).max_lifetime(None); } + // Capture the on-disk path before `connect_with` consumes the options + // so we can restrict the permissions after the database is connected. + let db_path = (!is_in_memory).then(|| options.get_filename().to_path_buf()); + let pool = pool_options .connect_with(options) .await .map_err(|e| map_db_error(&e))?; + // Tighten the permissions of the database file to owner-only access (0o600). 
+ if let Some(path) = db_path { + restrict_db_file_permissions(&path)?; + } + Ok(Self { pool }) } @@ -63,7 +73,7 @@ impl SqliteStore { payload: &[u8], labels: Option<&str>, ) -> PersistenceResult<()> { - let now_ms = current_time_ms()?; + let now_ms = current_time_ms(); sqlx::query( r#" @@ -87,6 +97,191 @@ ON CONFLICT ("object_type", "name") WHERE "name" IS NOT NULL DO UPDATE SET Ok(()) } + pub async fn put_if( + &self, + object_type: &str, + id: &str, + name: &str, + payload: &[u8], + labels: Option<&str>, + condition: WriteCondition, + ) -> PersistenceResult { + let now_ms = current_time_ms(); + + match condition { + WriteCondition::MustCreate => { + // Insert only - fail if object exists + sqlx::query( + r#" +INSERT INTO "objects" ("object_type", "id", "name", "payload", "created_at_ms", "updated_at_ms", "labels", "resource_version") +VALUES (?1, ?2, ?3, ?4, ?5, ?5, ?6, 1) +"#, + ) + .bind(object_type) + .bind(id) + .bind(name) + .bind(payload) + .bind(now_ms) + .bind(labels.unwrap_or("{}")) + .execute(&self.pool) + .await + .map_err(|e| map_db_error(&e))?; + + Ok(WriteResult { + resource_version: 1, + created_at_ms: now_ms, + updated_at_ms: now_ms, + }) + } + WriteCondition::MatchResourceVersion(expected_version) => { + // Update with version check + let result = sqlx::query( + r#" +UPDATE "objects" +SET "payload" = ?4, "labels" = ?5, "updated_at_ms" = ?6, "resource_version" = "resource_version" + 1 +WHERE "object_type" = ?1 AND "id" = ?2 AND "resource_version" = ?3 +"#, + ) + .bind(object_type) + .bind(id) + .bind(i64::try_from(expected_version).unwrap_or(i64::MAX)) + .bind(payload) + .bind(labels.unwrap_or("{}")) + .bind(now_ms) + .execute(&self.pool) + .await + .map_err(|e| map_db_error(&e))?; + + if result.rows_affected() == 0 { + // Check if object exists to distinguish NotFound from Conflict + let existing = self.get(object_type, id).await?; + if let Some(record) = existing { + return Err(PersistenceError::Conflict { + current_resource_version: 
Some(record.resource_version), + }); + } + return Err(PersistenceError::Database(format!( + "object not found: {object_type}/{id}" + ))); + } + + // Fetch the updated record to get the new resource_version + let updated = self.get(object_type, id).await?.ok_or_else(|| { + PersistenceError::Database("object disappeared after update".to_string()) + })?; + + Ok(WriteResult { + resource_version: updated.resource_version, + created_at_ms: updated.created_at_ms, + updated_at_ms: updated.updated_at_ms, + }) + } + WriteCondition::Unconditional => { + // Unconditional upsert by name + sqlx::query( + r#" +INSERT INTO "objects" ("object_type", "id", "name", "payload", "created_at_ms", "updated_at_ms", "labels", "resource_version") +VALUES (?1, ?2, ?3, ?4, ?5, ?5, ?6, 1) +ON CONFLICT ("object_type", "name") WHERE "name" IS NOT NULL DO UPDATE SET + "payload" = excluded."payload", + "updated_at_ms" = excluded."updated_at_ms", + "labels" = excluded."labels", + "resource_version" = "objects"."resource_version" + 1 +"#, + ) + .bind(object_type) + .bind(id) + .bind(name) + .bind(payload) + .bind(now_ms) + .bind(labels.unwrap_or("{}")) + .execute(&self.pool) + .await + .map_err(|e| map_db_error(&e))?; + + // Fetch the result to get the resource_version + let record = self.get_by_name(object_type, name).await?.ok_or_else(|| { + PersistenceError::Database("object disappeared after upsert".to_string()) + })?; + + Ok(WriteResult { + resource_version: record.resource_version, + created_at_ms: record.created_at_ms, + updated_at_ms: record.updated_at_ms, + }) + } + } + } + + pub async fn delete_if( + &self, + object_type: &str, + id: &str, + expected_resource_version: u64, + ) -> PersistenceResult { + let result = sqlx::query( + r#" +DELETE FROM "objects" +WHERE "object_type" = ?1 AND "id" = ?2 AND "resource_version" = ?3 +"#, + ) + .bind(object_type) + .bind(id) + .bind(i64::try_from(expected_resource_version).unwrap_or(i64::MAX)) + .execute(&self.pool) + .await + .map_err(|e| 
map_db_error(&e))?; + + if result.rows_affected() > 0 { + Ok(true) + } else { + // Check if object exists to distinguish NotFound from Conflict + let existing = self.get(object_type, id).await?; + if let Some(record) = existing { + return Err(PersistenceError::Conflict { + current_resource_version: Some(record.resource_version), + }); + } + Ok(false) + } + } + + pub async fn put_scoped( + &self, + object_type: &str, + id: &str, + name: &str, + scope: &str, + payload: &[u8], + labels: Option<&str>, + ) -> PersistenceResult<()> { + let now_ms = current_time_ms(); + + sqlx::query( + r#" +INSERT INTO "objects" ("object_type", "id", "name", "scope", "payload", "created_at_ms", "updated_at_ms", "labels", "resource_version") +VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?6, ?7, 1) +ON CONFLICT ("object_type", "name") WHERE "name" IS NOT NULL DO UPDATE SET + "scope" = excluded."scope", + "payload" = excluded."payload", + "updated_at_ms" = excluded."updated_at_ms", + "labels" = excluded."labels", + "resource_version" = "objects"."resource_version" + 1 +"#, + ) + .bind(object_type) + .bind(id) + .bind(name) + .bind(scope) + .bind(payload) + .bind(now_ms) + .bind(labels.unwrap_or("{}")) + .execute(&self.pool) + .await + .map_err(|e| map_db_error(&e))?; + Ok(()) + } + pub async fn get( &self, object_type: &str, @@ -94,7 +289,7 @@ ON CONFLICT ("object_type", "name") WHERE "name" IS NOT NULL DO UPDATE SET ) -> PersistenceResult> { let row = sqlx::query( r#" -SELECT "object_type", "id", "name", "payload", "created_at_ms", "updated_at_ms", "labels" +SELECT "object_type", "id", "name", "payload", "created_at_ms", "updated_at_ms", "labels", "resource_version" FROM "objects" WHERE "object_type" = ?1 AND "id" = ?2 "#, @@ -115,7 +310,7 @@ WHERE "object_type" = ?1 AND "id" = ?2 ) -> PersistenceResult> { let row = sqlx::query( r#" -SELECT "object_type", "id", "name", "payload", "created_at_ms", "updated_at_ms", "labels" +SELECT "object_type", "id", "name", "payload", "created_at_ms", "updated_at_ms", 
"labels", "resource_version" FROM "objects" WHERE "object_type" = ?1 AND "name" = ?2 "#, @@ -167,7 +362,7 @@ WHERE "object_type" = ?1 AND "name" = ?2 ) -> PersistenceResult> { let rows = sqlx::query( r#" -SELECT "object_type", "id", "name", "payload", "created_at_ms", "updated_at_ms", "labels" +SELECT "object_type", "id", "name", "payload", "created_at_ms", "updated_at_ms", "labels", "resource_version" FROM "objects" WHERE "object_type" = ?1 ORDER BY "created_at_ms" ASC, "name" ASC @@ -183,6 +378,33 @@ LIMIT ?2 OFFSET ?3 Ok(rows.into_iter().map(row_to_object_record).collect()) } + + pub async fn list_by_scope( + &self, + object_type: &str, + scope: &str, + limit: u32, + offset: u32, + ) -> PersistenceResult> { + let rows = sqlx::query( + r#" +SELECT "object_type", "id", "name", "payload", "created_at_ms", "updated_at_ms", "labels" +FROM "objects" +WHERE "object_type" = ?1 AND "scope" = ?2 +ORDER BY "created_at_ms" ASC, "name" ASC +LIMIT ?3 OFFSET ?4 +"#, + ) + .bind(object_type) + .bind(scope) + .bind(i64::from(limit)) + .bind(i64::from(offset)) + .fetch_all(&self.pool) + .await + .map_err(|e| map_db_error(&e))?; + + Ok(rows.into_iter().map(row_to_object_record).collect()) + } pub async fn list_with_selector( &self, object_type: &str, @@ -220,7 +442,7 @@ LIMIT ?2 OFFSET ?3 payload: &[u8], hash: &str, ) -> PersistenceResult<()> { - let now_ms = current_time_ms()?; + let now_ms = current_time_ms(); let record = PolicyRecord { id: id.to_string(), sandbox_id: sandbox_id.to_string(), @@ -363,7 +585,7 @@ LIMIT ?3 OFFSET ?4 record.load_error = load_error.map(ToOwned::to_owned); record.loaded_at_ms = loaded_at_ms; let payload = policy_payload_from_record(&record)?; - let now_ms = current_time_ms()?; + let now_ms = current_time_ms(); let result = sqlx::query( r#" @@ -389,7 +611,7 @@ WHERE "object_type" = ?1 AND "scope" = ?2 AND "version" = ?3 sandbox_id: &str, before_version: i64, ) -> PersistenceResult { - let now_ms = current_time_ms()?; + let now_ms = current_time_ms(); 
let result = sqlx::query( r#" UPDATE "objects" @@ -410,9 +632,17 @@ WHERE "object_type" = ?1 Ok(result.rows_affected()) } - pub async fn put_draft_chunk(&self, chunk: &DraftChunkRecord) -> PersistenceResult<()> { + pub async fn put_draft_chunk( + &self, + chunk: &DraftChunkRecord, + dedup_key: Option<&str>, + ) -> PersistenceResult<String> { let payload = draft_chunk_payload_from_record(chunk)?; - sqlx::query( + // RETURNING "id" gives us the row's effective id regardless of + // whether INSERT inserted a fresh row or ON CONFLICT updated an + // existing one. Callers report this id to clients so the response + // can never advertise a chunk_id that isn't actually persisted. + let row = sqlx::query( r#" INSERT INTO "objects" ( "object_type", "id", "scope", "status", "dedup_key", "hit_count", "payload", "created_at_ms", "updated_at_ms" @@ -421,21 +651,22 @@ VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9) ON CONFLICT ("object_type", "scope", "dedup_key") WHERE "dedup_key" IS NOT NULL DO UPDATE SET "hit_count" = "objects"."hit_count" + excluded."hit_count", "updated_at_ms" = excluded."updated_at_ms" +RETURNING "id" "#, ) .bind(DRAFT_CHUNK_OBJECT_TYPE) .bind(&chunk.id) .bind(&chunk.sandbox_id) .bind(&chunk.status) - .bind(draft_chunk_dedup_key(chunk)) + .bind(dedup_key) .bind(i64::from(chunk.hit_count)) .bind(payload) .bind(chunk.first_seen_ms) .bind(chunk.last_seen_ms) - .execute(&self.pool) + .fetch_one(&self.pool) .await .map_err(|e| map_db_error(&e))?; - Ok(()) + Ok(row.get::<String, _>("id")) } pub async fn get_draft_chunk(&self, id: &str) -> PersistenceResult<Option<DraftChunkRecord>> { @@ -498,6 +729,7 @@ ORDER BY "created_at_ms" DESC id: &str, status: &str, decided_at_ms: Option<i64>, + rejection_reason: Option<&str>, ) -> PersistenceResult<bool> { let Some(mut record) = self.get_draft_chunk(id).await? 
else { return Ok(false); @@ -505,7 +737,10 @@ ORDER BY "created_at_ms" DESC record.status = status.to_string(); record.decided_at_ms = decided_at_ms; - record.last_seen_ms = current_time_ms()?; + record.last_seen_ms = current_time_ms(); + if let Some(reason) = rejection_reason { + record.rejection_reason = reason.to_string(); + } let payload = draft_chunk_payload_from_record(&record)?; let result = sqlx::query( @@ -540,7 +775,7 @@ WHERE "object_type" = ?1 AND "id" = ?2 } record.proposed_rule = proposed_rule.to_vec(); - record.last_seen_ms = current_time_ms()?; + record.last_seen_ms = current_time_ms(); let payload = draft_chunk_payload_from_record(&record)?; let result = sqlx::query( @@ -612,11 +847,40 @@ WHERE "object_type" = ?1 AND "scope" = ?2 } } -fn draft_chunk_dedup_key(chunk: &DraftChunkRecord) -> String { - format!("{}|{}|{}", chunk.host, chunk.port, chunk.binary) +/// Restrict the on-disk `SQLite` database file (and its WAL/SHM sidecars, +/// when present) to owner-only read/write (`0o600`). +/// +/// In WAL mode, `SQLite` keeps two sidecars next to +/// the main database file: `-wal` (uncommitted page log) +/// and `-shm` (shared memory index). They mirror the same sensitive data +/// as the main file, so they get the same `0o600` treatment whenever they exist on disk. +/// +/// Delegates to `set_file_owner_only`, which is a no-op on non-Unix platforms. +pub(super) fn restrict_db_file_permissions(path: &Path) -> PersistenceResult<()> { + set_file_owner_only(path).map_err(|err| PersistenceError::Database(err.to_string()))?; + + for sidecar in sqlite_sidecar_paths(path) { + if sidecar.exists() { + set_file_owner_only(&sidecar) + .map_err(|err| PersistenceError::Database(err.to_string()))?; + } + } + Ok(()) +} + +/// Compute the WAL/SHM sidecar paths `SQLite` derives from a main database file +/// (e.g. `foo.db` -> [`foo.db-wal`, `foo.db-shm`]). 
+pub(super) fn sqlite_sidecar_paths(path: &Path) -> [PathBuf; 2] { + let with_suffix = |suffix: &str| -> PathBuf { + let mut buf = path.as_os_str().to_os_string(); + buf.push(suffix); + PathBuf::from(buf) + }; + [with_suffix("-wal"), with_suffix("-shm")] } fn row_to_object_record(row: sqlx::sqlite::SqliteRow) -> ObjectRecord { + let resource_version_i64: i64 = row.try_get("resource_version").unwrap_or(1); ObjectRecord { object_type: row.get("object_type"), id: row.get("id"), @@ -625,6 +889,7 @@ fn row_to_object_record(row: sqlx::sqlite::SqliteRow) -> ObjectRecord { created_at_ms: row.get("created_at_ms"), updated_at_ms: row.get("updated_at_ms"), labels: row.get("labels"), + resource_version: resource_version_i64.max(1).cast_unsigned(), } } diff --git a/crates/openshell-server/src/persistence/tests.rs b/crates/openshell-server/src/persistence/tests.rs index bef95d4b6..123feb862 100644 --- a/crates/openshell-server/src/persistence/tests.rs +++ b/crates/openshell-server/src/persistence/tests.rs @@ -1,16 +1,20 @@ // SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
// SPDX-License-Identifier: Apache-2.0 -use super::{ObjectType, Store, generate_name}; +use super::{ObjectType, PersistenceError, Store, generate_name}; use crate::policy_store::PolicyStoreExt; use openshell_core::proto::{ObjectForTest, SandboxPolicy}; use prost::Message; +async fn test_store() -> Store { + Store::connect("sqlite::memory:?cache=shared") + .await + .expect("in-memory SQLite store should connect") +} + #[tokio::test] async fn sqlite_put_get_round_trip() { - let store = Store::connect("sqlite::memory:?cache=shared") - .await - .unwrap(); + let store = test_store().await; store .put("sandbox", "abc", "my-sandbox", b"payload", None) @@ -26,19 +30,192 @@ async fn sqlite_put_get_round_trip() { #[tokio::test] async fn sqlite_connect_runs_embedded_migrations() { - let store = Store::connect("sqlite::memory:?cache=shared") - .await - .unwrap(); + let store = test_store().await; let records = store.list("sandbox", 10, 0).await.unwrap(); assert!(records.is_empty()); } +#[cfg(unix)] +#[tokio::test] +async fn sqlite_connect_restricts_db_file_permissions() { + use std::os::unix::fs::PermissionsExt; + + let tmp = tempfile::tempdir().expect("tempdir"); + let db_path = tmp.path().join("openshell.db"); + let url = format!("sqlite:{}?mode=rwc", db_path.display()); + + let _store = Store::connect(&url).await.expect("connect to sqlite"); + + let mode = std::fs::metadata(&db_path) + .expect("db file exists") + .permissions() + .mode() + & 0o777; + assert_eq!(mode, 0o600, "expected 0600, got {mode:04o}"); +} + +#[cfg(unix)] +#[tokio::test] +async fn sqlite_connect_tightens_existing_db_file_permissions() { + use std::os::unix::fs::PermissionsExt; + + let tmp = tempfile::tempdir().expect("tempdir"); + let db_path = tmp.path().join("openshell.db"); + let url = format!("sqlite:{}?mode=rwc", db_path.display()); + + // First connect creates the file; close the pool by dropping the store. 
+ { + let _store = Store::connect(&url).await.expect("initial connect"); + } + + // Simulate a pre-existing database left with permissive permissions + // (e.g., from an older gateway version that lacked this hardening). + std::fs::set_permissions(&db_path, std::fs::Permissions::from_mode(0o644)) + .expect("loosen permissions"); + + let _store = Store::connect(&url).await.expect("reconnect to sqlite"); + + let mode = std::fs::metadata(&db_path) + .expect("db file exists") + .permissions() + .mode() + & 0o777; + assert_eq!(mode, 0o600, "expected 0600, got {mode:04o}"); +} + +// The next three tests cover `restrict_db_file_permissions` against the +// WAL/SHM sidecars at increasing levels of fidelity: +// +// 1. `_tightens_main_and_wal_and_shm_files`: synthetic empty files, proves +// the chmod loop walks all three paths. +// 2. `_skips_missing_sidecars`: proves the `exists()` guard, which is the +// actual production path today (sqlx 0.8 doesn't default to WAL and +// doesn't accept `journal_mode` as a URL parameter). +// 3. `_handles_real_sqlite_wal_files`: opens a real sqlx pool with +// `SqliteJournalMode::Wal` via the builder API so SQLite materializes +// real `-wal` and `-shm` files, then checks the helper tightens them. + +#[cfg(unix)] +#[test] +fn restrict_db_file_permissions_tightens_main_and_wal_and_shm_files() { + use std::os::unix::fs::PermissionsExt; + + let tmp = tempfile::tempdir().expect("tempdir"); + let db_path = tmp.path().join("openshell.db"); + let [wal_path, shm_path] = super::sqlite::sqlite_sidecar_paths(&db_path); + + // Simulate a SQLite database in WAL mode whose three files were left + // world-readable (older gateway version, or non-zero umask at creation). 
+ for path in [&db_path, &wal_path, &shm_path] { + std::fs::write(path, b"").expect("create file"); + std::fs::set_permissions(path, std::fs::Permissions::from_mode(0o644)).expect("set 0o644"); + } + + super::sqlite::restrict_db_file_permissions(&db_path).expect("restrict permissions"); + + for path in [&db_path, &wal_path, &shm_path] { + let mode = std::fs::metadata(path) + .expect("file exists") + .permissions() + .mode() + & 0o777; + assert_eq!( + mode, + 0o600, + "expected 0600 on {}, got {mode:04o}", + path.display() + ); + } +} + +#[cfg(unix)] +#[test] +fn restrict_db_file_permissions_skips_missing_sidecars() { + use std::os::unix::fs::PermissionsExt; + + let tmp = tempfile::tempdir().expect("tempdir"); + let db_path = tmp.path().join("openshell.db"); + let [wal_path, shm_path] = super::sqlite::sqlite_sidecar_paths(&db_path); + + // Only the main DB file exists (non-WAL journal mode, or pre-write WAL). + std::fs::write(&db_path, b"").expect("create file"); + std::fs::set_permissions(&db_path, std::fs::Permissions::from_mode(0o644)).expect("set 0o644"); + + super::sqlite::restrict_db_file_permissions(&db_path).expect("restrict permissions"); + + assert!(!wal_path.exists(), "WAL sidecar should not be created"); + assert!(!shm_path.exists(), "SHM sidecar should not be created"); + + let mode = std::fs::metadata(&db_path) + .expect("db file exists") + .permissions() + .mode() + & 0o777; + assert_eq!(mode, 0o600, "expected 0600, got {mode:04o}"); +} + +#[cfg(unix)] +#[tokio::test] +async fn restrict_db_file_permissions_handles_real_sqlite_wal_files() { + use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePoolOptions}; + use std::os::unix::fs::PermissionsExt; + use std::str::FromStr; + + let tmp = tempfile::tempdir().expect("tempdir"); + let db_path = tmp.path().join("openshell.db"); + let url = format!("sqlite:{}", db_path.display()); + + // sqlx does not parse `journal_mode` from the connection URL — callers + // must opt into WAL via the builder 
API. + let options = SqliteConnectOptions::from_str(&url) + .expect("parse url") + .create_if_missing(true) + .journal_mode(SqliteJournalMode::Wal); + + let pool = SqlitePoolOptions::new() + .max_connections(1) + .connect_with(options) + .await + .expect("connect with WAL"); + + // Force a write so SQLite definitely materializes a non-empty WAL on disk. + sqlx::query("CREATE TABLE _hardening_probe (x INTEGER)") + .execute(&pool) + .await + .expect("write"); + + let [wal_path, shm_path] = super::sqlite::sqlite_sidecar_paths(&db_path); + assert!(wal_path.exists(), "WAL should exist after write"); + assert!(shm_path.exists(), "SHM should exist after WAL write"); + + // Loosen permissions on every file to simulate what an older gateway + // version (or a non-zero default umask) would have left behind. + for path in [&db_path, &wal_path, &shm_path] { + std::fs::set_permissions(path, std::fs::Permissions::from_mode(0o644)) + .expect("loosen permissions"); + } + + super::sqlite::restrict_db_file_permissions(&db_path).expect("restrict permissions"); + + for path in [&db_path, &wal_path, &shm_path] { + let mode = std::fs::metadata(path) + .expect("metadata") + .permissions() + .mode() + & 0o777; + assert_eq!( + mode, + 0o600, + "expected 0600 on {}, got {mode:04o}", + path.display() + ); + } +} + #[tokio::test] async fn sqlite_updates_timestamp() { - let store = Store::connect("sqlite::memory:?cache=shared") - .await - .unwrap(); + let store = test_store().await; store .put("sandbox", "abc", "my-sandbox", b"payload", None) @@ -59,9 +236,7 @@ async fn sqlite_updates_timestamp() { #[tokio::test] async fn sqlite_list_paging() { - let store = Store::connect("sqlite::memory:?cache=shared") - .await - .unwrap(); + let store = test_store().await; for idx in 0..5 { let id = format!("id-{idx}"); @@ -81,9 +256,7 @@ async fn sqlite_list_paging() { #[tokio::test] async fn sqlite_delete_behavior() { - let store = Store::connect("sqlite::memory:?cache=shared") - .await - .unwrap(); + let 
store = test_store().await; store .put("sandbox", "abc", "my-sandbox", b"payload", None) @@ -99,9 +272,7 @@ async fn sqlite_delete_behavior() { #[tokio::test] async fn sqlite_protobuf_round_trip() { - let store = Store::connect("sqlite::memory:?cache=shared") - .await - .unwrap(); + let store = test_store().await; let object = ObjectForTest { id: "abc".to_string(), @@ -124,9 +295,7 @@ async fn sqlite_protobuf_round_trip() { #[tokio::test] async fn sqlite_get_by_name() { - let store = Store::connect("sqlite::memory:?cache=shared") - .await - .unwrap(); + let store = test_store().await; store .put("sandbox", "id-1", "my-sandbox", b"payload", None) @@ -148,9 +317,7 @@ async fn sqlite_get_by_name() { #[tokio::test] async fn sqlite_get_message_by_name() { - let store = Store::connect("sqlite::memory:?cache=shared") - .await - .unwrap(); + let store = test_store().await; let object = ObjectForTest { id: "uid-1".to_string(), @@ -178,9 +345,7 @@ async fn sqlite_get_message_by_name() { #[tokio::test] async fn sqlite_delete_by_name() { - let store = Store::connect("sqlite::memory:?cache=shared") - .await - .unwrap(); + let store = test_store().await; store .put("sandbox", "id-1", "my-sandbox", b"payload", None) @@ -199,9 +364,7 @@ async fn sqlite_delete_by_name() { #[tokio::test] async fn sqlite_name_unique_per_object_type() { - let store = Store::connect("sqlite::memory:?cache=shared") - .await - .unwrap(); + let store = test_store().await; store .put("sandbox", "id-1", "shared-name", b"payload1", None) @@ -231,9 +394,7 @@ async fn sqlite_name_unique_per_object_type() { #[tokio::test] async fn sqlite_id_globally_unique() { - let store = Store::connect("sqlite::memory:?cache=shared") - .await - .unwrap(); + let store = test_store().await; store .put("sandbox", "same-id", "name-a", b"payload1", None) @@ -281,9 +442,7 @@ impl ObjectType for ObjectForTest { #[tokio::test] async fn labels_round_trip() { - let store = Store::connect("sqlite::memory:?cache=shared") - .await - 
.unwrap(); + let store = test_store().await; let labels = r#"{"env":"production","team":"platform"}"#; store @@ -303,9 +462,7 @@ async fn labels_round_trip() { #[tokio::test] async fn label_selector_single_match() { - let store = Store::connect("sqlite::memory:?cache=shared") - .await - .unwrap(); + let store = test_store().await; store .put("sandbox", "id-1", "s1", b"p1", Some(r#"{"env":"prod"}"#)) @@ -339,9 +496,7 @@ async fn label_selector_single_match() { #[tokio::test] async fn label_selector_multiple_labels() { - let store = Store::connect("sqlite::memory:?cache=shared") - .await - .unwrap(); + let store = test_store().await; store .put( @@ -385,9 +540,7 @@ async fn label_selector_multiple_labels() { #[tokio::test] async fn label_selector_no_match() { - let store = Store::connect("sqlite::memory:?cache=shared") - .await - .unwrap(); + let store = test_store().await; store .put("sandbox", "id-1", "s1", b"p1", Some(r#"{"env":"prod"}"#)) @@ -404,9 +557,7 @@ async fn label_selector_no_match() { #[tokio::test] async fn label_selector_respects_paging() { - let store = Store::connect("sqlite::memory:?cache=shared") - .await - .unwrap(); + let store = test_store().await; for idx in 0..5 { let id = format!("id-{idx}"); @@ -438,9 +589,7 @@ async fn label_selector_respects_paging() { #[tokio::test] async fn empty_labels_not_matched_by_selector() { - let store = Store::connect("sqlite::memory:?cache=shared") - .await - .unwrap(); + let store = test_store().await; store .put("sandbox", "id-1", "s1", b"p1", None) @@ -466,9 +615,7 @@ async fn empty_labels_not_matched_by_selector() { #[tokio::test] async fn policy_put_and_get_latest() { - let store = Store::connect("sqlite::memory:?cache=shared") - .await - .unwrap(); + let store = test_store().await; let policy_v1 = SandboxPolicy::default().encode_to_vec(); store @@ -500,9 +647,7 @@ async fn policy_put_and_get_latest() { #[tokio::test] async fn policy_get_by_version() { - let store = 
Store::connect("sqlite::memory:?cache=shared") - .await - .unwrap(); + let store = test_store().await; let policy_v1 = SandboxPolicy::default().encode_to_vec(); let policy_v2 = SandboxPolicy { @@ -541,9 +686,7 @@ async fn policy_get_by_version() { #[tokio::test] async fn policy_update_status_and_get_loaded() { - let store = Store::connect("sqlite::memory:?cache=shared") - .await - .unwrap(); + let store = test_store().await; let payload = SandboxPolicy::default().encode_to_vec(); store @@ -574,9 +717,7 @@ async fn policy_update_status_and_get_loaded() { #[tokio::test] async fn policy_status_failed_with_error() { - let store = Store::connect("sqlite::memory:?cache=shared") - .await - .unwrap(); + let store = test_store().await; let payload = SandboxPolicy::default().encode_to_vec(); store @@ -600,9 +741,7 @@ async fn policy_status_failed_with_error() { #[tokio::test] async fn policy_supersede_older() { - let store = Store::connect("sqlite::memory:?cache=shared") - .await - .unwrap(); + let store = test_store().await; let payload = SandboxPolicy::default().encode_to_vec(); store @@ -655,9 +794,7 @@ async fn policy_supersede_older() { #[tokio::test] async fn policy_list_ordered_by_version_desc() { - let store = Store::connect("sqlite::memory:?cache=shared") - .await - .unwrap(); + let store = test_store().await; let payload = SandboxPolicy::default().encode_to_vec(); store @@ -688,9 +825,7 @@ async fn policy_list_ordered_by_version_desc() { #[tokio::test] async fn policy_isolation_between_sandboxes() { - let store = Store::connect("sqlite::memory:?cache=shared") - .await - .unwrap(); + let store = test_store().await; let policy_s1 = SandboxPolicy::default().encode_to_vec(); let policy_s2 = SandboxPolicy { @@ -785,3 +920,428 @@ fn parse_label_selector_handles_whitespace() { assert_eq!(result.get("env"), Some(&"prod".to_string())); assert_eq!(result.get("tier"), Some(&"frontend".to_string())); } + +// 
--------------------------------------------------------------------------- +// CAS (compare-and-swap) tests +// --------------------------------------------------------------------------- + +#[tokio::test] +async fn cas_put_if_must_create_succeeds() { + use super::WriteCondition; + + let store = test_store().await; + + let result = store + .put_if( + "sandbox", + "id-1", + "new-sandbox", + b"payload", + None, + WriteCondition::MustCreate, + ) + .await + .unwrap(); + + assert_eq!(result.resource_version, 1); + + let record = store.get("sandbox", "id-1").await.unwrap().unwrap(); + assert_eq!(record.resource_version, 1); + assert_eq!(record.payload, b"payload"); +} + +#[tokio::test] +async fn cas_put_if_must_create_fails_on_duplicate() { + use super::{PersistenceError, WriteCondition}; + + let store = test_store().await; + + // First insert succeeds + store + .put_if( + "sandbox", + "id-1", + "sandbox-1", + b"payload1", + None, + WriteCondition::MustCreate, + ) + .await + .unwrap(); + + // Second insert with same ID fails + let result = store + .put_if( + "sandbox", + "id-1", + "sandbox-2", + b"payload2", + None, + WriteCondition::MustCreate, + ) + .await; + + assert!(matches!( + result, + Err(PersistenceError::UniqueViolation { .. 
}) + )); +} + +#[tokio::test] +async fn cas_put_if_match_version_succeeds() { + use super::WriteCondition; + + let store = test_store().await; + + // Create initial object + store + .put_if( + "sandbox", + "id-1", + "sandbox-1", + b"v1", + None, + WriteCondition::MustCreate, + ) + .await + .unwrap(); + + // Update with correct version + let result = store + .put_if( + "sandbox", + "id-1", + "sandbox-1", + b"v2", + None, + WriteCondition::MatchResourceVersion(1), + ) + .await + .unwrap(); + + assert_eq!(result.resource_version, 2); + + let record = store.get("sandbox", "id-1").await.unwrap().unwrap(); + assert_eq!(record.resource_version, 2); + assert_eq!(record.payload, b"v2"); +} + +#[tokio::test] +async fn cas_put_if_match_version_fails_on_mismatch() { + use super::{PersistenceError, WriteCondition}; + + let store = test_store().await; + + // Create initial object + store + .put_if( + "sandbox", + "id-1", + "sandbox-1", + b"v1", + None, + WriteCondition::MustCreate, + ) + .await + .unwrap(); + + // Update with wrong version + let result = store + .put_if( + "sandbox", + "id-1", + "sandbox-1", + b"v2", + None, + WriteCondition::MatchResourceVersion(99), + ) + .await; + + assert!(matches!( + result, + Err(PersistenceError::Conflict { + current_resource_version: Some(1) + }) + )); + + // Original payload unchanged + let record = store.get("sandbox", "id-1").await.unwrap().unwrap(); + assert_eq!(record.resource_version, 1); + assert_eq!(record.payload, b"v1"); +} + +#[tokio::test] +async fn cas_delete_if_succeeds_with_correct_version() { + use super::WriteCondition; + + let store = test_store().await; + + store + .put_if( + "sandbox", + "id-1", + "sandbox-1", + b"payload", + None, + WriteCondition::MustCreate, + ) + .await + .unwrap(); + + let deleted = store.delete_if("sandbox", "id-1", 1).await.unwrap(); + assert!(deleted); + + let record = store.get("sandbox", "id-1").await.unwrap(); + assert!(record.is_none()); +} + +#[tokio::test] +async fn 
cas_delete_if_fails_with_wrong_version() { + use super::{PersistenceError, WriteCondition}; + + let store = test_store().await; + + store + .put_if( + "sandbox", + "id-1", + "sandbox-1", + b"payload", + None, + WriteCondition::MustCreate, + ) + .await + .unwrap(); + + let result = store.delete_if("sandbox", "id-1", 99).await; + assert!(matches!( + result, + Err(PersistenceError::Conflict { + current_resource_version: Some(1) + }) + )); + + // Object still exists + let record = store.get("sandbox", "id-1").await.unwrap().unwrap(); + assert_eq!(record.resource_version, 1); +} + +#[tokio::test] +async fn cas_resource_version_increments() { + use super::WriteCondition; + + let store = test_store().await; + + // Create + let r1 = store + .put_if( + "sandbox", + "id-1", + "sandbox-1", + b"v1", + None, + WriteCondition::MustCreate, + ) + .await + .unwrap(); + assert_eq!(r1.resource_version, 1); + + // Update 1 + let r2 = store + .put_if( + "sandbox", + "id-1", + "sandbox-1", + b"v2", + None, + WriteCondition::MatchResourceVersion(1), + ) + .await + .unwrap(); + assert_eq!(r2.resource_version, 2); + + // Update 2 + let r3 = store + .put_if( + "sandbox", + "id-1", + "sandbox-1", + b"v3", + None, + WriteCondition::MatchResourceVersion(2), + ) + .await + .unwrap(); + assert_eq!(r3.resource_version, 3); + + let record = store.get("sandbox", "id-1").await.unwrap().unwrap(); + assert_eq!(record.resource_version, 3); +} + +#[tokio::test] +async fn cas_concurrent_updates_one_succeeds() { + use super::WriteCondition; + use std::sync::Arc; + + let store = Arc::new(test_store().await); + + // Create initial object + store + .put_if( + "sandbox", + "id-1", + "sandbox-1", + b"initial", + None, + WriteCondition::MustCreate, + ) + .await + .unwrap(); + + // Spawn 10 concurrent updates trying to update from version 1 + let mut handles = vec![]; + for i in 0..10 { + let store = Arc::clone(&store); + let handle = tokio::spawn(async move { + store + .put_if( + "sandbox", + "id-1", + 
"sandbox-1", + format!("update-{i}").as_bytes(), + None, + WriteCondition::MatchResourceVersion(1), + ) + .await + }); + handles.push(handle); + } + + let results: Vec<_> = futures::future::join_all(handles) + .await + .into_iter() + .map(|r| r.unwrap()) + .collect(); + + // Exactly one should succeed, rest should conflict + let successes = results.iter().filter(|r| r.is_ok()).count(); + let conflicts = results.iter().filter(|r| r.is_err()).count(); + + assert_eq!(successes, 1); + assert_eq!(conflicts, 9); + + // Final version should be 2 + let record = store.get("sandbox", "id-1").await.unwrap().unwrap(); + assert_eq!(record.resource_version, 2); +} + +#[tokio::test] +async fn cas_update_message_cas_succeeds() { + use openshell_core::proto::Sandbox; + + let store = test_store().await; + + // Create a sandbox + let sandbox = Sandbox { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: "test-id".to_string(), + name: "test-sandbox".to_string(), + created_at_ms: 1000, + labels: std::collections::HashMap::new(), + resource_version: 0, + }), + spec: None, + status: None, + phase: 0, + current_policy_version: 0, + }; + + store.put_message(&sandbox).await.unwrap(); + + // Update using CAS with expected_version = 0 (use current version) + let updated = store + .update_message_cas::("test-id", 0, |s| { + s.phase = 2; // Set to Ready + s.current_policy_version = 42; + }) + .await + .unwrap(); + + assert_eq!(updated.phase, 2); + assert_eq!(updated.current_policy_version, 42); + assert_eq!( + updated.metadata.as_ref().map_or(0, |m| m.resource_version), + 2 + ); +} + +#[tokio::test] +async fn cas_update_message_cas_conflicts_on_concurrent_updates() { + use openshell_core::proto::Sandbox; + use std::sync::Arc; + + let store = Arc::new(test_store().await); + + // Create a sandbox + let sandbox = Sandbox { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: "test-id".to_string(), + name: "test-sandbox".to_string(), + created_at_ms: 
1000, + labels: std::collections::HashMap::new(), + resource_version: 0, + }), + spec: None, + status: None, + phase: 0, + current_policy_version: 0, + }; + + store.put_message(&sandbox).await.unwrap(); + + // Spawn 5 concurrent CAS updates using the same observed version. Passing an + // explicit version makes this deterministic: later tasks cannot re-read the + // latest committed version and legitimately succeed. + let mut handles = vec![]; + for i in 0..5 { + let store = Arc::clone(&store); + let handle = tokio::spawn(async move { + store + .update_message_cas::("test-id", 1, |s| { + s.current_policy_version = i; + }) + .await + }); + handles.push(handle); + } + + let results: Vec<_> = futures::future::join_all(handles) + .await + .into_iter() + .map(|r| r.unwrap()) + .collect(); + + // Only one should succeed; others fail with Conflict due to single-attempt CAS. + let successes = results.iter().filter(|r| r.is_ok()).count(); + let conflicts = results + .iter() + .filter(|r| matches!(r, Err(PersistenceError::Conflict { .. }))) + .count(); + assert_eq!(successes, 1, "exactly one concurrent update should succeed"); + assert_eq!(conflicts, 4, "four updates should fail with Conflict"); + + // Final version should be 2 (initial 1 + 1 successful update) + let final_sandbox = store + .get_message::("test-id") + .await + .unwrap() + .unwrap(); + assert_eq!( + final_sandbox + .metadata + .as_ref() + .map_or(0, |m| m.resource_version), + 2, + "resource_version should be 2 (initial 1 + 1 successful update)" + ); +} diff --git a/crates/openshell-server/src/policy_store.rs b/crates/openshell-server/src/policy_store.rs index f0a43698e..9a6333543 100644 --- a/crates/openshell-server/src/policy_store.rs +++ b/crates/openshell-server/src/policy_store.rs @@ -54,7 +54,25 @@ pub trait PolicyStoreExt { before_version: i64, ) -> PersistenceResult; - async fn put_draft_chunk(&self, chunk: &DraftChunkRecord) -> PersistenceResult<()>; + /// Persist a draft chunk. 
When `dedup_key` is `Some`, duplicate inserts + /// for the same `(sandbox, dedup_key)` fold into the existing row's + /// `hit_count` instead of creating a second chunk — appropriate for + /// observation-driven proposals from the mechanistic mapper. When + /// `None`, the chunk is inserted unconditionally — appropriate for + /// agent-authored proposals where each submission is an intentional + /// act and the redraft loop relies on every proposal getting its own + /// `chunk_id` even when the target endpoint is unchanged. + /// + /// Returns the **effective** row id. On a fresh insert that equals + /// `chunk.id`; on dedup fold-in it is the existing row's id. Callers + /// must use the returned id (not `chunk.id`) when reporting the chunk + /// to clients — otherwise the response advertises an id that was never + /// persisted. + async fn put_draft_chunk( + &self, + chunk: &DraftChunkRecord, + dedup_key: Option<&str>, + ) -> PersistenceResult; async fn get_draft_chunk(&self, id: &str) -> PersistenceResult>; @@ -64,11 +82,16 @@ pub trait PolicyStoreExt { status_filter: Option<&str>, ) -> PersistenceResult>; + /// Update a draft chunk's status, optionally recording a free-form + /// `rejection_reason` for the reviewer's note. Pass `Some` only on the + /// reject path; other status transitions pass `None` to leave any prior + /// reason intact. 
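The `rejection_reason` contract in the doc comment above (pass `Some` only on the reject path, pass `None` to keep any earlier reason) could look roughly like this against an in-memory row. This is a sketch under assumptions: `ChunkRow` and `apply_status` are hypothetical names introduced for illustration, not types from this crate, and the real Postgres and SQLite backends presumably express the same rule in SQL.

```rust
/// Hypothetical in-memory stand-in for a persisted draft chunk row.
#[derive(Debug, Clone, PartialEq)]
struct ChunkRow {
    status: String,
    rejection_reason: String,
}

/// Apply a status transition to the row.
/// `Some(reason)` overwrites the stored reason; `None` leaves whatever
/// was recorded earlier untouched.
fn apply_status(row: &mut ChunkRow, status: &str, rejection_reason: Option<&str>) {
    row.status = status.to_string();
    if let Some(reason) = rejection_reason {
        row.rejection_reason = reason.to_string();
    }
}
```

Calling `apply_status(&mut row, "rejected", Some("too broad"))` and then `apply_status(&mut row, "pending", None)` leaves `rejection_reason` at `"too broad"`, which is the "leave any prior reason intact" behavior the comment describes.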
async fn update_draft_chunk_status( &self, id: &str, status: &str, decided_at_ms: Option, + rejection_reason: Option<&str>, ) -> PersistenceResult; async fn update_draft_chunk_rule( @@ -186,10 +209,14 @@ impl PolicyStoreExt for Store { } } - async fn put_draft_chunk(&self, chunk: &DraftChunkRecord) -> PersistenceResult<()> { + async fn put_draft_chunk( + &self, + chunk: &DraftChunkRecord, + dedup_key: Option<&str>, + ) -> PersistenceResult { match self { - Self::Postgres(store) => store.put_draft_chunk(chunk).await, - Self::Sqlite(store) => store.put_draft_chunk(chunk).await, + Self::Postgres(store) => store.put_draft_chunk(chunk, dedup_key).await, + Self::Sqlite(store) => store.put_draft_chunk(chunk, dedup_key).await, } } @@ -216,16 +243,17 @@ impl PolicyStoreExt for Store { id: &str, status: &str, decided_at_ms: Option, + rejection_reason: Option<&str>, ) -> PersistenceResult { match self { Self::Postgres(store) => { store - .update_draft_chunk_status(id, status, decided_at_ms) + .update_draft_chunk_status(id, status, decided_at_ms, rejection_reason) .await } Self::Sqlite(store) => { store - .update_draft_chunk_status(id, status, decided_at_ms) + .update_draft_chunk_status(id, status, decided_at_ms, rejection_reason) .await } } @@ -301,6 +329,16 @@ pub fn policy_record_from_parts( }) } +/// Observation-mode dedup key: `host|port|binary`. Used by the mechanistic +/// mapper path where N denials targeting the same endpoint should fold into +/// one chunk instead of N near-identical chunks. Agent-authored proposals +/// pass `None` for the `dedup_key` argument to `put_draft_chunk` so each +/// proposal lands as its own chunk regardless of target — the redraft loop +/// depends on this. 
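The fold-in behavior described above (observation-mode inserts sharing a `host|port|binary` key collapse into one row, while `None` always inserts) can be sketched with two in-memory maps. Everything here is illustrative: `Chunk` and `DraftChunks` are hypothetical stand-ins rather than the crate's `DraftChunkRecord` or store, the `u16` port type is an assumption, and adding the incoming `hit_count` to the existing row is one plausible reading of "fold into the existing row's `hit_count`".

```rust
use std::collections::HashMap;

/// Hypothetical, trimmed-down stand-in for `DraftChunkRecord`.
#[derive(Debug, Clone)]
struct Chunk {
    id: String,
    host: String,
    port: u16, // assumed type, for illustration only
    binary: String,
    hit_count: i32,
}

/// Same `host|port|binary` shape as the observation-mode dedup key.
fn dedup_key(chunk: &Chunk) -> String {
    format!("{}|{}|{}", chunk.host, chunk.port, chunk.binary)
}

/// In-memory sketch of the draft chunks belonging to one sandbox.
#[derive(Default)]
struct DraftChunks {
    rows: HashMap<String, Chunk>,          // chunk id -> row
    by_dedup_key: HashMap<String, String>, // dedup key -> chunk id
}

impl DraftChunks {
    /// Returns the effective row id: `chunk.id` on a fresh insert, or the
    /// existing row's id when the insert folds into it.
    fn put(&mut self, chunk: Chunk, dedup_key: Option<&str>) -> String {
        if let Some(key) = dedup_key {
            if let Some(existing_id) = self.by_dedup_key.get(key).cloned() {
                if let Some(row) = self.rows.get_mut(&existing_id) {
                    // Fold-in: bump the existing row instead of adding a new one.
                    row.hit_count = row.hit_count.saturating_add(chunk.hit_count);
                    return existing_id;
                }
            }
            self.by_dedup_key.insert(key.to_string(), chunk.id.clone());
        }
        // Fresh insert: either no dedup key was given or the key is new.
        let id = chunk.id.clone();
        self.rows.insert(id.clone(), chunk);
        id
    }
}
```

In this sketch a second `put` with the same key returns the first chunk's id, which is why callers have to report the returned id rather than the id they submitted.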
+pub fn observation_dedup_key(chunk: &DraftChunkRecord) -> String { + format!("{}|{}|{}", chunk.host, chunk.port, chunk.binary) +} + pub fn draft_chunk_payload_from_record(chunk: &DraftChunkRecord) -> PersistenceResult<Vec<u8>> { let proposed_rule = if chunk.proposed_rule.is_empty() { None @@ -325,6 +363,8 @@ pub fn draft_chunk_payload_from_record(chunk: &DraftChunkRecord) -> PersistenceR port: chunk.port, binary: chunk.binary.clone(), draft_version: chunk.draft_version, + validation_result: chunk.validation_result.clone(), + rejection_reason: chunk.rejection_reason.clone(), } .encode_to_vec()) } @@ -365,5 +405,7 @@ pub fn draft_chunk_record_from_parts( hit_count: i32::try_from(hit_count).unwrap_or(i32::MAX), first_seen_ms: created_at_ms, last_seen_ms: updated_at_ms, + validation_result: wrapper.validation_result, + rejection_reason: wrapper.rejection_reason, }) } diff --git a/crates/openshell-server/src/provider_refresh.rs b/crates/openshell-server/src/provider_refresh.rs new file mode 100644 index 000000000..161daeb7f --- /dev/null +++ b/crates/openshell-server/src/provider_refresh.rs @@ -0,0 +1,1217 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Provider credential refresh state. 
+ +#![allow(clippy::result_large_err)] + +use crate::persistence::{ObjectType, Store, current_time_ms}; +use openshell_core::proto::{ + Provider, ProviderCredentialRefreshStatus, ProviderCredentialRefreshStrategy, + StoredProviderCredentialRefreshState, +}; +use prost::Message; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::time::Duration; +use tonic::Status; +use tracing::{info, warn}; + +const DEFAULT_REFRESH_BEFORE_SECONDS: i64 = 300; +const DEFAULT_MAX_LIFETIME_SECONDS: i64 = 3600; +const REFRESH_ERROR_RETRY_SECONDS: i64 = 60; +const REFRESH_WORKER_PAGE_SIZE: u32 = 1000; + +impl ObjectType for StoredProviderCredentialRefreshState { + fn object_type() -> &'static str { + "provider_credential_refresh_state" + } +} + +pub fn refresh_state_name(provider_id: &str, credential_key: &str) -> String { + let mut key = String::with_capacity(credential_key.len() * 2); + for byte in credential_key.as_bytes() { + use std::fmt::Write as _; + write!(&mut key, "{byte:02x}").expect("writing to String cannot fail"); + } + format!("provider-refresh-{provider_id}-{key}") +} + +pub async fn put_refresh_state( + store: &Store, + state: &StoredProviderCredentialRefreshState, +) -> Result<(), Status> { + store + .put_scoped_message(state, &state.provider_id) + .await + .map_err(|e| Status::internal(format!("persist provider refresh state failed: {e}"))) +} + +pub async fn list_refresh_states_for_provider( + store: &Store, + provider_id: &str, +) -> Result, Status> { + let records = store + .list_by_scope( + StoredProviderCredentialRefreshState::object_type(), + provider_id, + 1000, + 0, + ) + .await + .map_err(|e| Status::internal(format!("list provider refresh states failed: {e}")))?; + + let mut states = Vec::with_capacity(records.len()); + for record in records { + states.push( + StoredProviderCredentialRefreshState::decode(record.payload.as_slice()).map_err( + |e| Status::internal(format!("decode provider refresh state failed: {e}")), + )?, + ); + } 
+ Ok(states) +} + +pub async fn list_all_refresh_states( + store: &Store, +) -> Result, Status> { + let mut states = Vec::new(); + let mut offset = 0; + loop { + let records = store + .list( + StoredProviderCredentialRefreshState::object_type(), + REFRESH_WORKER_PAGE_SIZE, + offset, + ) + .await + .map_err(|e| Status::internal(format!("list provider refresh states failed: {e}")))?; + if records.is_empty() { + break; + } + offset = offset + .checked_add( + u32::try_from(records.len()) + .map_err(|_| Status::internal("provider refresh page size exceeded u32"))?, + ) + .ok_or_else(|| Status::internal("provider refresh pagination offset overflow"))?; + for record in records { + states.push( + StoredProviderCredentialRefreshState::decode(record.payload.as_slice()).map_err( + |e| Status::internal(format!("decode provider refresh state failed: {e}")), + )?, + ); + } + } + Ok(states) +} + +pub async fn get_refresh_state( + store: &Store, + provider_id: &str, + credential_key: &str, +) -> Result, Status> { + let name = refresh_state_name(provider_id, credential_key); + store + .get_message_by_name::(&name) + .await + .map_err(|e| Status::internal(format!("fetch provider refresh state failed: {e}"))) +} + +pub async fn delete_refresh_state( + store: &Store, + provider_id: &str, + credential_key: &str, +) -> Result { + let name = refresh_state_name(provider_id, credential_key); + store + .delete_by_name(StoredProviderCredentialRefreshState::object_type(), &name) + .await + .map_err(|e| Status::internal(format!("delete provider refresh state failed: {e}"))) +} + +pub async fn delete_refresh_states_for_provider( + store: &Store, + provider_id: &str, +) -> Result { + let states = list_refresh_states_for_provider(store, provider_id).await?; + let mut deleted = 0; + for state in states { + if store + .delete_by_name( + StoredProviderCredentialRefreshState::object_type(), + state.object_name(), + ) + .await + .map_err(|e| Status::internal(format!("delete provider refresh state 
failed: {e}")))? + { + deleted += 1; + } + } + Ok(deleted) +} + +pub fn refresh_status_from_state( + state: &StoredProviderCredentialRefreshState, +) -> ProviderCredentialRefreshStatus { + ProviderCredentialRefreshStatus { + provider_name: state.provider_name.clone(), + provider_id: state.provider_id.clone(), + credential_key: state.credential_key.clone(), + strategy: state.strategy, + status: state.status.clone(), + expires_at_ms: state.expires_at_ms, + next_refresh_at_ms: state.next_refresh_at_ms, + last_refresh_at_ms: state.last_refresh_at_ms, + last_error: state.last_error.clone(), + } +} + +pub struct NewRefreshStateConfig { + pub strategy: ProviderCredentialRefreshStrategy, + pub material: HashMap, + pub secret_material_keys: Vec, + pub expires_at_ms: i64, + pub token_url: String, + pub scopes: Vec, + pub refresh_before_seconds: i64, + pub max_lifetime_seconds: i64, +} + +#[allow(clippy::unnecessary_wraps)] +pub fn new_refresh_state( + provider: &Provider, + credential_key: &str, + config: NewRefreshStateConfig, +) -> Result { + let provider_id = provider.object_id().to_string(); + let provider_name = provider.object_name().to_string(); + let now_ms = current_time_ms(); + let next_refresh_at_ms = next_refresh_at_ms( + config.expires_at_ms, + config.refresh_before_seconds, + config.max_lifetime_seconds, + now_ms, + ); + Ok(StoredProviderCredentialRefreshState { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: uuid::Uuid::new_v4().to_string(), + name: refresh_state_name(&provider_id, credential_key), + created_at_ms: now_ms, + labels: HashMap::new(), + resource_version: 0, + }), + provider_id, + provider_name, + credential_key: credential_key.to_string(), + strategy: config.strategy as i32, + material: config.material, + secret_material_keys: config.secret_material_keys, + expires_at_ms: config.expires_at_ms, + next_refresh_at_ms, + last_refresh_at_ms: 0, + status: "configured".to_string(), + last_error: String::new(), + token_url: 
config.token_url, + scopes: config.scopes, + refresh_before_seconds: config.refresh_before_seconds, + max_lifetime_seconds: config.max_lifetime_seconds, + }) +} + +use openshell_core::{ObjectId, ObjectName}; + +#[derive(Debug)] +struct MintedCredential { + access_token: String, + expires_at_ms: i64, + refresh_token: Option, +} + +#[derive(Debug, Deserialize)] +struct TokenResponse { + access_token: String, + expires_in: Option, + refresh_token: Option, +} + +#[derive(Debug, Serialize)] +struct GoogleServiceAccountClaims<'a> { + iss: &'a str, + scope: String, + aud: &'a str, + iat: i64, + exp: i64, + #[serde(skip_serializing_if = "Option::is_none")] + sub: Option<&'a str>, +} + +pub fn next_refresh_at_ms( + expires_at_ms: i64, + refresh_before_seconds: i64, + _max_lifetime_seconds: i64, + _now_ms: i64, +) -> i64 { + let refresh_before_seconds = if refresh_before_seconds > 0 { + refresh_before_seconds + } else { + DEFAULT_REFRESH_BEFORE_SECONDS + }; + if expires_at_ms > 0 { + return expires_at_ms.saturating_sub(refresh_before_seconds.saturating_mul(1000)); + } + 0 +} + +fn seconds_until_ms(now_ms: i64, target_ms: i64) -> i64 { + if target_ms <= 0 { + return 0; + } + target_ms.saturating_sub(now_ms).max(0) / 1000 +} + +pub fn refresh_strategy_name(strategy: i32) -> &'static str { + match ProviderCredentialRefreshStrategy::try_from(strategy) + .unwrap_or(ProviderCredentialRefreshStrategy::Unspecified) + { + ProviderCredentialRefreshStrategy::Static => "static", + ProviderCredentialRefreshStrategy::External => "external", + ProviderCredentialRefreshStrategy::Oauth2RefreshToken => "oauth2_refresh_token", + ProviderCredentialRefreshStrategy::Oauth2ClientCredentials => "oauth2_client_credentials", + ProviderCredentialRefreshStrategy::GoogleServiceAccountJwt => "google_service_account_jwt", + ProviderCredentialRefreshStrategy::Unspecified => "unspecified", + } +} + +pub fn is_gateway_mintable_strategy(strategy: ProviderCredentialRefreshStrategy) -> bool { + matches!( + 
strategy, + ProviderCredentialRefreshStrategy::Oauth2RefreshToken + | ProviderCredentialRefreshStrategy::Oauth2ClientCredentials + | ProviderCredentialRefreshStrategy::GoogleServiceAccountJwt + ) +} + +pub async fn refresh_provider_credential( + store: &Store, + provider_name: &str, + credential_key: &str, +) -> Result { + let provider = store + .get_message_by_name::(provider_name) + .await + .map_err(|e| Status::internal(format!("fetch provider failed: {e}")))? + .ok_or_else(|| Status::not_found("provider not found"))?; + let Some(mut state) = get_refresh_state(store, provider.object_id(), credential_key).await? + else { + return Err(Status::not_found("provider refresh state not found")); + }; + + info!( + provider = %state.provider_name, + credential_key = %state.credential_key, + strategy = %refresh_strategy_name(state.strategy), + status = %state.status, + expires_at_ms = state.expires_at_ms, + next_refresh_at_ms = state.next_refresh_at_ms, + "provider credential refresh started" + ); + + match mint_credential(&state).await { + Ok(minted) => { + let now_ms = current_time_ms(); + if let Err(err) = + apply_minted_credential(store, &provider, credential_key, &minted).await + { + state.status = "error".to_string(); + state.last_error = err.message().to_string(); + state.next_refresh_at_ms = + now_ms.saturating_add(REFRESH_ERROR_RETRY_SECONDS.saturating_mul(1000)); + put_refresh_state(store, &state).await?; + warn!( + provider = %state.provider_name, + credential_key = %state.credential_key, + strategy = %refresh_strategy_name(state.strategy), + status = %state.status, + next_refresh_at_ms = state.next_refresh_at_ms, + seconds_until_refresh = seconds_until_ms(now_ms, state.next_refresh_at_ms), + error = %err, + "provider credential refresh errored" + ); + return Err(err); + } + if let Some(refresh_token) = minted.refresh_token { + state + .material + .insert("refresh_token".to_string(), refresh_token); + if !state + .secret_material_keys + .iter() + .any(|key| key 
== "refresh_token") + { + state.secret_material_keys.push("refresh_token".to_string()); + } + } + state.expires_at_ms = minted.expires_at_ms; + state.next_refresh_at_ms = next_refresh_at_ms( + minted.expires_at_ms, + state.refresh_before_seconds, + state.max_lifetime_seconds, + now_ms, + ); + state.last_refresh_at_ms = now_ms; + state.status = "refreshed".to_string(); + state.last_error.clear(); + put_refresh_state(store, &state).await?; + info!( + provider = %state.provider_name, + credential_key = %state.credential_key, + strategy = %refresh_strategy_name(state.strategy), + status = %state.status, + expires_at_ms = state.expires_at_ms, + next_refresh_at_ms = state.next_refresh_at_ms, + seconds_until_refresh = seconds_until_ms(now_ms, state.next_refresh_at_ms), + "provider credential refresh completed" + ); + Ok(state) + } + Err(err) => { + let now_ms = current_time_ms(); + state.status = "error".to_string(); + state.last_error = err.message().to_string(); + state.next_refresh_at_ms = + now_ms.saturating_add(REFRESH_ERROR_RETRY_SECONDS.saturating_mul(1000)); + put_refresh_state(store, &state).await?; + warn!( + provider = %state.provider_name, + credential_key = %state.credential_key, + strategy = %refresh_strategy_name(state.strategy), + status = %state.status, + next_refresh_at_ms = state.next_refresh_at_ms, + seconds_until_refresh = seconds_until_ms(now_ms, state.next_refresh_at_ms), + error = %err, + "provider credential refresh errored" + ); + Err(err) + } + } +} + +async fn apply_minted_credential( + store: &Store, + provider: &Provider, + credential_key: &str, + minted: &MintedCredential, +) -> Result<(), Status> { + let mut updated = provider.clone(); + updated + .credentials + .insert(credential_key.to_string(), minted.access_token.clone()); + if minted.expires_at_ms > 0 { + updated + .credential_expires_at_ms + .insert(credential_key.to_string(), minted.expires_at_ms); + } else { + updated.credential_expires_at_ms.remove(credential_key); + } + 
crate::grpc::provider::validate_provider_update_against_attached_sandboxes(store, &updated) + .await?; + store + .update_message_cas::(provider.object_id(), 0, |current| { + current + .credentials + .insert(credential_key.to_string(), minted.access_token.clone()); + if minted.expires_at_ms > 0 { + current + .credential_expires_at_ms + .insert(credential_key.to_string(), minted.expires_at_ms); + } else { + current.credential_expires_at_ms.remove(credential_key); + } + }) + .await + .map(|_| ()) + .map_err(|e| Status::internal(format!("persist refreshed provider credential failed: {e}"))) +} + +async fn mint_credential( + state: &StoredProviderCredentialRefreshState, +) -> Result { + let strategy = ProviderCredentialRefreshStrategy::try_from(state.strategy) + .unwrap_or(ProviderCredentialRefreshStrategy::Unspecified); + match strategy { + ProviderCredentialRefreshStrategy::Oauth2RefreshToken => { + mint_oauth2_refresh_token(state).await + } + ProviderCredentialRefreshStrategy::Oauth2ClientCredentials => { + mint_oauth2_client_credentials(state).await + } + ProviderCredentialRefreshStrategy::GoogleServiceAccountJwt => { + mint_google_service_account_jwt(state).await + } + ProviderCredentialRefreshStrategy::External + | ProviderCredentialRefreshStrategy::Static + | ProviderCredentialRefreshStrategy::Unspecified => Err(Status::failed_precondition( + format!("refresh strategy '{strategy:?}' cannot be minted by the gateway"), + )), + } +} + +async fn mint_oauth2_refresh_token( + state: &StoredProviderCredentialRefreshState, +) -> Result { + let token_url = oauth2_token_url(state)?; + let client_id = required_material(&state.material, "client_id")?; + let refresh_token = required_material(&state.material, "refresh_token")?; + let mut form = vec![ + ("grant_type".to_string(), "refresh_token".to_string()), + ("client_id".to_string(), client_id), + ("refresh_token".to_string(), refresh_token), + ]; + if let Some(client_secret) = material_value(&state.material, 
&["client_secret"]) { + form.push(("client_secret".to_string(), client_secret)); + } + let scope = refresh_scopes(state).join(" "); + if !scope.is_empty() { + form.push(("scope".to_string(), scope)); + } + + request_token(&token_url, &form, state.max_lifetime_seconds).await +} + +async fn mint_oauth2_client_credentials( + state: &StoredProviderCredentialRefreshState, +) -> Result { + let token_url = oauth2_token_url(state)?; + let client_id = required_material(&state.material, "client_id")?; + let client_secret = required_material(&state.material, "client_secret")?; + let mut form = vec![ + ("grant_type".to_string(), "client_credentials".to_string()), + ("client_id".to_string(), client_id), + ("client_secret".to_string(), client_secret), + ]; + let scope = refresh_scopes(state).join(" "); + if !scope.is_empty() { + form.push(("scope".to_string(), scope)); + } + + request_token(&token_url, &form, state.max_lifetime_seconds).await +} + +async fn mint_google_service_account_jwt( + state: &StoredProviderCredentialRefreshState, +) -> Result { + let token_url = google_token_url(state); + let client_email = required_material(&state.material, "client_email")?; + let private_key = required_material(&state.material, "private_key")?; + let scopes = refresh_scopes(state); + if scopes.is_empty() { + return Err(Status::invalid_argument( + "google_service_account_jwt requires at least one scope", + )); + } + let now_ms = current_time_ms(); + let now_secs = now_ms / 1000; + let lifetime_secs = if state.max_lifetime_seconds > 0 { + state.max_lifetime_seconds.min(DEFAULT_MAX_LIFETIME_SECONDS) + } else { + DEFAULT_MAX_LIFETIME_SECONDS + }; + let subject = material_value(&state.material, &["subject", "sub"]); + let claims = GoogleServiceAccountClaims { + iss: &client_email, + scope: scopes.join(" "), + aud: &token_url, + iat: now_secs, + exp: now_secs.saturating_add(lifetime_secs), + sub: subject.as_deref(), + }; + let assertion = jsonwebtoken::encode( + 
&jsonwebtoken::Header::new(jsonwebtoken::Algorithm::RS256), + &claims, + &jsonwebtoken::EncodingKey::from_rsa_pem(private_key.as_bytes()).map_err(|_| { + Status::invalid_argument("google_service_account_jwt private_key must be RSA PEM") + })?, + ) + .map_err(|_| Status::internal("sign google service account jwt failed"))?; + let form = vec![ + ( + "grant_type".to_string(), + "urn:ietf:params:oauth:grant-type:jwt-bearer".to_string(), + ), + ("assertion".to_string(), assertion), + ]; + request_token(&token_url, &form, lifetime_secs).await +} + +async fn request_token( + token_url: &str, + form: &[(String, String)], + max_lifetime_seconds: i64, +) -> Result { + let parsed = reqwest::Url::parse(token_url) + .map_err(|_| Status::invalid_argument("token_url must be an absolute URL"))?; + match parsed.scheme() { + "https" => {} + "http" if parsed.host_str().is_some_and(is_loopback_host) => {} + _ => { + return Err(Status::invalid_argument( + "token_url must use https, except loopback http for local tests", + )); + } + } + + let client = reqwest::Client::builder() + .timeout(Duration::from_secs(30)) + .build() + .map_err(|e| Status::internal(format!("build refresh HTTP client failed: {e}")))?; + let response = client + .post(parsed) + .form(form) + .send() + .await + .map_err(|e| Status::unavailable(format!("token endpoint request failed: {e}")))?; + let status = response.status(); + if !status.is_success() { + return Err(Status::failed_precondition(format!( + "token endpoint returned HTTP {status}" + ))); + } + let token = response + .json::() + .await + .map_err(|_| Status::failed_precondition("token endpoint returned invalid JSON"))?; + if token.access_token.trim().is_empty() { + return Err(Status::failed_precondition( + "token endpoint returned empty access_token", + )); + } + let now_ms = current_time_ms(); + let lifetime_cap_seconds = if max_lifetime_seconds > 0 { + max_lifetime_seconds + } else { + DEFAULT_MAX_LIFETIME_SECONDS + }; + let lifetime_seconds = token + 
.expires_in + .filter(|value| *value > 0) + .unwrap_or(lifetime_cap_seconds); + let lifetime_seconds = lifetime_seconds.min(lifetime_cap_seconds); + Ok(MintedCredential { + access_token: token.access_token, + expires_at_ms: now_ms.saturating_add(lifetime_seconds.saturating_mul(1000)), + refresh_token: token + .refresh_token + .filter(|refresh_token| !refresh_token.trim().is_empty()), + }) +} + +pub fn refresh_scopes(state: &StoredProviderCredentialRefreshState) -> Vec<String> { + if !state.scopes.is_empty() { + return state.scopes.clone(); + } + material_scopes(&state.material) +} + +pub fn material_scopes(material: &HashMap<String, String>) -> Vec<String> { + material_value(material, &["scope", "scopes"]) + .map(|raw| { + raw.split(|ch: char| ch == ',' || ch.is_ascii_whitespace()) + .map(str::trim) + .filter(|scope| !scope.is_empty()) + .map(ToString::to_string) + .collect() + }) + .unwrap_or_default() +} + +pub fn parse_material_i64( + material: &HashMap<String, String>, + key: &str, +) -> Result<Option<i64>, Status> { + let Some(value) = material_value(material, &[key]) else { + return Ok(None); + }; + value + .parse::<i64>() + .map(Some) + .map_err(|_| Status::invalid_argument(format!("{key} material must be a signed integer"))) +} + +fn oauth2_token_url(state: &StoredProviderCredentialRefreshState) -> Result<String, Status> { + if let Some(tenant_id) = material_value(&state.material, &["tenant_id"]) { + return Ok(format!( + "https://login.microsoftonline.com/{tenant_id}/oauth2/v2.0/token" + )); + } + if !state.token_url.trim().is_empty() { + return Ok(state.token_url.clone()); + } + Err(Status::invalid_argument( + "oauth2_client_credentials requires token_url or tenant_id material", + )) +} + +fn google_token_url(state: &StoredProviderCredentialRefreshState) -> String { + if state.token_url.trim().is_empty() { + "https://oauth2.googleapis.com/token".to_string() + } else { + state.token_url.clone() + } +} + +fn required_material(material: &HashMap<String, String>, key: &str) -> Result<String, Status> { + material_value(material, &[key]) + .ok_or_else(|| 
Status::invalid_argument(format!("{key} material is required"))) +} + +fn material_value(material: &HashMap, keys: &[&str]) -> Option { + for key in keys { + if let Some(value) = material.get(*key).map(|value| value.trim()) + && !value.is_empty() + { + return Some(value.to_string()); + } + } + None +} + +fn is_loopback_host(host: &str) -> bool { + matches!(host, "localhost" | "127.0.0.1" | "::1") +} + +pub fn spawn_refresh_worker(state: std::sync::Arc, interval: Duration) { + info!( + interval_seconds = interval.as_secs(), + "provider credential refresh worker started" + ); + tokio::spawn(async move { + let mut ticker = tokio::time::interval(interval); + ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); + loop { + ticker.tick().await; + if let Err(err) = run_refresh_worker_tick(state.store.as_ref()).await { + warn!(error = %err, "provider credential refresh worker tick failed"); + } + } + }); +} + +async fn run_refresh_worker_tick(store: &Store) -> Result<(), Status> { + let now_ms = current_time_ms(); + let states = list_all_refresh_states(store).await?; + let watched_count = states.len(); + let due_count = states + .iter() + .filter(|state| state.next_refresh_at_ms <= 0 || state.next_refresh_at_ms <= now_ms) + .count(); + let rotation_requested_count = states + .iter() + .filter(|state| state.status == "rotation_requested") + .count(); + info!( + watched_count, + due_count, rotation_requested_count, "provider credential refresh worker sweep" + ); + for state in states { + let strategy = ProviderCredentialRefreshStrategy::try_from(state.strategy) + .unwrap_or(ProviderCredentialRefreshStrategy::Unspecified); + let due = state.next_refresh_at_ms <= 0 || state.next_refresh_at_ms <= now_ms; + let rotation_requested = state.status == "rotation_requested"; + info!( + provider = %state.provider_name, + credential_key = %state.credential_key, + strategy = %refresh_strategy_name(state.strategy), + status = %state.status, + expires_at_ms = 
state.expires_at_ms, + seconds_until_expiry = seconds_until_ms(now_ms, state.expires_at_ms), + next_refresh_at_ms = state.next_refresh_at_ms, + last_refresh_at_ms = state.last_refresh_at_ms, + seconds_until_refresh = seconds_until_ms(now_ms, state.next_refresh_at_ms), + due, + rotation_requested, + "provider credential refresh watch" + ); + if !due && !rotation_requested { + continue; + } + if !is_gateway_mintable_strategy(strategy) { + warn!( + provider = %state.provider_name, + credential_key = %state.credential_key, + strategy = %refresh_strategy_name(state.strategy), + status = %state.status, + "skipping non-gateway-mintable provider credential refresh state" + ); + continue; + } + info!( + provider = %state.provider_name, + credential_key = %state.credential_key, + strategy = %refresh_strategy_name(state.strategy), + status = %state.status, + "refreshing provider credential" + ); + if let Err(err) = + refresh_provider_credential(store, &state.provider_name, &state.credential_key).await + { + warn!( + provider = %state.provider_name, + credential_key = %state.credential_key, + strategy = %refresh_strategy_name(state.strategy), + status = %state.status, + next_refresh_at_ms = state.next_refresh_at_ms, + error = %err, + "provider credential refresh failed" + ); + } + } + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::{ + NewRefreshStateConfig, get_refresh_state, new_refresh_state, put_refresh_state, + refresh_provider_credential, refresh_state_name, refresh_strategy_name, + run_refresh_worker_tick, seconds_until_ms, + }; + use crate::persistence::Store; + use openshell_core::ObjectId; + use openshell_core::proto::datamodel::v1::ObjectMeta; + use openshell_core::proto::{ + Provider, ProviderCredentialRefreshStrategy, Sandbox, SandboxSpec, + }; + use std::collections::HashMap; + use wiremock::matchers::{body_string_contains, method, path}; + use wiremock::{Mock, MockServer, ResponseTemplate}; + + async fn test_store() -> Store { + 
Store::connect("sqlite::memory:?cache=shared") + .await + .expect("in-memory SQLite store should connect") + } + + #[test] + fn refresh_state_name_preserves_distinct_credential_keys() { + let provider_id = "provider-id"; + + assert_ne!( + refresh_state_name(provider_id, "API_KEY"), + refresh_state_name(provider_id, "api_key") + ); + assert_ne!( + refresh_state_name(provider_id, " alex-api "), + refresh_state_name(provider_id, " alex_api") + ); + assert_ne!( + refresh_state_name(provider_id, "Alex-API"), + refresh_state_name(provider_id, "alex-api") + ); + } + + #[test] + fn refresh_log_helpers_format_safe_operational_fields() { + assert_eq!(seconds_until_ms(1_000, 61_000), 60); + assert_eq!(seconds_until_ms(61_000, 1_000), 0); + assert_eq!(seconds_until_ms(1_000, 0), 0); + assert_eq!( + refresh_strategy_name(ProviderCredentialRefreshStrategy::Oauth2RefreshToken as i32), + "oauth2_refresh_token" + ); + assert_eq!( + refresh_strategy_name( + ProviderCredentialRefreshStrategy::Oauth2ClientCredentials as i32 + ), + "oauth2_client_credentials" + ); + assert_eq!( + refresh_strategy_name( + ProviderCredentialRefreshStrategy::GoogleServiceAccountJwt as i32 + ), + "google_service_account_jwt" + ); + assert_eq!(refresh_strategy_name(i32::MAX), "unspecified"); + } + + #[tokio::test] + async fn oauth2_client_credentials_refresh_mints_and_persists_access_token() { + let mock_server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/token")) + .and(body_string_contains("grant_type=client_credentials")) + .and(body_string_contains("client_id=client-id")) + .and(body_string_contains( + "scope=https%3A%2F%2Fgraph.microsoft.com%2F.default", + )) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "access_token": "minted-graph-token", + "expires_in": 3600, + "token_type": "Bearer" + }))) + .mount(&mock_server) + .await; + + let store = test_store().await; + let provider = provider("my-graph", "outlook"); + 
store.put_message(&provider).await.unwrap(); + let before_refresh_ms = crate::persistence::current_time_ms(); + let state = new_refresh_state( + &provider, + "MS_GRAPH_ACCESS_TOKEN", + NewRefreshStateConfig { + strategy: ProviderCredentialRefreshStrategy::Oauth2ClientCredentials, + material: HashMap::from([ + ("client_id".to_string(), "client-id".to_string()), + ("client_secret".to_string(), "client-secret".to_string()), + ]), + secret_material_keys: vec!["client_secret".to_string()], + expires_at_ms: 0, + token_url: format!("{}/token", mock_server.uri()), + scopes: vec!["https://graph.microsoft.com/.default".to_string()], + refresh_before_seconds: 30, + max_lifetime_seconds: 60, + }, + ) + .unwrap(); + put_refresh_state(&store, &state).await.unwrap(); + + let refreshed = refresh_provider_credential(&store, "my-graph", "MS_GRAPH_ACCESS_TOKEN") + .await + .unwrap(); + assert_eq!(refreshed.status, "refreshed"); + assert!(refreshed.expires_at_ms > 0); + assert!(refreshed.next_refresh_at_ms > 0); + assert!(refreshed.expires_at_ms <= before_refresh_ms + 120_000); + assert!(refreshed.last_error.is_empty()); + + let stored = store + .get_message_by_name::("my-graph") + .await + .unwrap() + .unwrap(); + assert_eq!( + stored.credentials.get("MS_GRAPH_ACCESS_TOKEN"), + Some(&"minted-graph-token".to_string()) + ); + assert_eq!( + stored.credential_expires_at_ms.get("MS_GRAPH_ACCESS_TOKEN"), + Some(&refreshed.expires_at_ms) + ); + } + + #[tokio::test] + async fn refresh_rejects_minted_credential_key_collision_for_attached_sandbox() { + let mock_server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/token")) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "access_token": "minted-graph-token", + "expires_in": 3600, + "token_type": "Bearer" + }))) + .mount(&mock_server) + .await; + + let store = test_store().await; + let mut provider_a = provider("existing-graph", "outlook"); + provider_a.credentials.insert( + 
"MS_GRAPH_ACCESS_TOKEN".to_string(), + "existing-token".to_string(), + ); + store.put_message(&provider_a).await.unwrap(); + let provider_b = provider("refreshing-graph", "outlook"); + store.put_message(&provider_b).await.unwrap(); + store + .put_message(&Sandbox { + metadata: Some(ObjectMeta { + id: "sandbox-collision".to_string(), + name: "collision".to_string(), + created_at_ms: 1, + labels: HashMap::new(), + resource_version: 0, + }), + spec: Some(SandboxSpec { + providers: vec!["existing-graph".to_string(), "refreshing-graph".to_string()], + ..SandboxSpec::default() + }), + ..Default::default() + }) + .await + .unwrap(); + let state = new_refresh_state( + &provider_b, + "MS_GRAPH_ACCESS_TOKEN", + NewRefreshStateConfig { + strategy: ProviderCredentialRefreshStrategy::Oauth2ClientCredentials, + material: HashMap::from([ + ("client_id".to_string(), "client-id".to_string()), + ("client_secret".to_string(), "client-secret".to_string()), + ]), + secret_material_keys: vec!["client_secret".to_string()], + expires_at_ms: 0, + token_url: format!("{}/token", mock_server.uri()), + scopes: Vec::new(), + refresh_before_seconds: 30, + max_lifetime_seconds: 60, + }, + ) + .unwrap(); + put_refresh_state(&store, &state).await.unwrap(); + + let err = refresh_provider_credential(&store, "refreshing-graph", "MS_GRAPH_ACCESS_TOKEN") + .await + .unwrap_err(); + + assert_eq!(err.code(), tonic::Code::FailedPrecondition); + assert!(err.message().contains("MS_GRAPH_ACCESS_TOKEN")); + let stored_state = + get_refresh_state(&store, provider_b.object_id(), "MS_GRAPH_ACCESS_TOKEN") + .await + .unwrap() + .unwrap(); + assert_eq!(stored_state.status, "error"); + assert!(stored_state.last_error.contains("MS_GRAPH_ACCESS_TOKEN")); + let stored_provider = store + .get_message_by_name::("refreshing-graph") + .await + .unwrap() + .unwrap(); + assert!( + !stored_provider + .credentials + .contains_key("MS_GRAPH_ACCESS_TOKEN") + ); + } + + #[tokio::test] + async fn 
oauth2_refresh_token_refresh_mints_access_token_and_persists_rotated_refresh_token() { + let mock_server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/token")) + .and(body_string_contains("grant_type=refresh_token")) + .and(body_string_contains("client_id=client-id")) + .and(body_string_contains("refresh_token=old-refresh-token")) + .and(body_string_contains( + "scope=https%3A%2F%2Fgraph.microsoft.com%2F.default", + )) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "access_token": "delegated-graph-token", + "refresh_token": "rotated-refresh-token", + "expires_in": 3600, + "token_type": "Bearer" + }))) + .mount(&mock_server) + .await; + + let store = test_store().await; + let provider = provider("my-delegated-graph", "outlook"); + store.put_message(&provider).await.unwrap(); + let state = new_refresh_state( + &provider, + "MS_GRAPH_ACCESS_TOKEN", + NewRefreshStateConfig { + strategy: ProviderCredentialRefreshStrategy::Oauth2RefreshToken, + material: HashMap::from([ + ("client_id".to_string(), "client-id".to_string()), + ("refresh_token".to_string(), "old-refresh-token".to_string()), + ]), + secret_material_keys: vec!["refresh_token".to_string()], + expires_at_ms: 0, + token_url: format!("{}/token", mock_server.uri()), + scopes: vec!["https://graph.microsoft.com/.default".to_string()], + refresh_before_seconds: 30, + max_lifetime_seconds: 60, + }, + ) + .unwrap(); + put_refresh_state(&store, &state).await.unwrap(); + + let refreshed = + refresh_provider_credential(&store, "my-delegated-graph", "MS_GRAPH_ACCESS_TOKEN") + .await + .unwrap(); + assert_eq!(refreshed.status, "refreshed"); + assert!(refreshed.expires_at_ms > 0); + + let stored_provider = store + .get_message_by_name::("my-delegated-graph") + .await + .unwrap() + .unwrap(); + assert_eq!( + stored_provider.credentials.get("MS_GRAPH_ACCESS_TOKEN"), + Some(&"delegated-graph-token".to_string()) + ); + assert_eq!( + stored_provider + 
.credential_expires_at_ms + .get("MS_GRAPH_ACCESS_TOKEN"), + Some(&refreshed.expires_at_ms) + ); + + let stored_state = get_refresh_state(&store, provider.object_id(), "MS_GRAPH_ACCESS_TOKEN") + .await + .unwrap() + .unwrap(); + assert_eq!( + stored_state.material.get("refresh_token"), + Some(&"rotated-refresh-token".to_string()) + ); + assert!( + stored_state + .secret_material_keys + .iter() + .any(|key| key == "refresh_token") + ); + } + + #[tokio::test] + async fn google_service_account_refresh_mints_and_persists_access_token() { + let mock_server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/token")) + .and(body_string_contains( + "grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Ajwt-bearer", + )) + .and(body_string_contains("assertion=")) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "access_token": "minted-drive-token", + "expires_in": 1800, + "token_type": "Bearer" + }))) + .mount(&mock_server) + .await; + + let store = test_store().await; + let provider = provider("my-drive", "google-drive"); + store.put_message(&provider).await.unwrap(); + let state = new_refresh_state( + &provider, + "GOOGLE_DRIVE_ACCESS_TOKEN", + NewRefreshStateConfig { + strategy: ProviderCredentialRefreshStrategy::GoogleServiceAccountJwt, + material: HashMap::from([ + ( + "client_email".to_string(), + "svc@example.iam.gserviceaccount.com".to_string(), + ), + ("private_key".to_string(), TEST_RSA_PRIVATE_KEY.to_string()), + ]), + secret_material_keys: vec!["private_key".to_string()], + expires_at_ms: 0, + token_url: format!("{}/token", mock_server.uri()), + scopes: vec!["https://www.googleapis.com/auth/drive.readonly".to_string()], + refresh_before_seconds: 300, + max_lifetime_seconds: 3600, + }, + ) + .unwrap(); + put_refresh_state(&store, &state).await.unwrap(); + + let refreshed = + refresh_provider_credential(&store, "my-drive", "GOOGLE_DRIVE_ACCESS_TOKEN") + .await + .unwrap(); + assert_eq!(refreshed.status, 
"refreshed"); + assert!(refreshed.expires_at_ms > 0); + + let stored = store + .get_message_by_name::("my-drive") + .await + .unwrap() + .unwrap(); + assert_eq!( + stored.credentials.get("GOOGLE_DRIVE_ACCESS_TOKEN"), + Some(&"minted-drive-token".to_string()) + ); + } + + #[tokio::test] + async fn refresh_worker_skips_non_gateway_mintable_strategies() { + let store = test_store().await; + let provider = provider("my-external", "outlook"); + store.put_message(&provider).await.unwrap(); + let state = new_refresh_state( + &provider, + "MS_GRAPH_ACCESS_TOKEN", + NewRefreshStateConfig { + strategy: ProviderCredentialRefreshStrategy::External, + material: HashMap::new(), + secret_material_keys: Vec::new(), + expires_at_ms: 0, + token_url: String::new(), + scopes: Vec::new(), + refresh_before_seconds: 0, + max_lifetime_seconds: 0, + }, + ) + .unwrap(); + put_refresh_state(&store, &state).await.unwrap(); + + run_refresh_worker_tick(&store).await.unwrap(); + + let stored_state = get_refresh_state(&store, provider.object_id(), "MS_GRAPH_ACCESS_TOKEN") + .await + .unwrap() + .unwrap(); + assert_ne!(stored_state.status, "error"); + assert!(stored_state.last_error.is_empty()); + + let stored_provider = store + .get_message_by_name::("my-external") + .await + .unwrap() + .unwrap(); + assert!( + !stored_provider + .credentials + .contains_key("MS_GRAPH_ACCESS_TOKEN") + ); + } + + fn provider(name: &str, provider_type: &str) -> Provider { + Provider { + metadata: Some(ObjectMeta { + id: format!("{name}-id"), + name: name.to_string(), + created_at_ms: 1, + labels: HashMap::new(), + resource_version: 0, + }), + r#type: provider_type.to_string(), + credentials: HashMap::new(), + config: HashMap::new(), + credential_expires_at_ms: HashMap::new(), + } + } + + const TEST_RSA_PRIVATE_KEY: &str = r"-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCvCoZ0mVHpCHsF +zeeqw2caNIe/eb4BQUccFPhZfRnF7sCfyB84zTBmuwG2umRBdjFnVsfIIZRp2HcD 
+OESrRYYiE1RGfjBXImGVg2Wtza0HYhL1sLyX1eaEefylxoilmApAgWDh9p36h8J2 +s5YHwyXPTttx4DpdWDnxju1iNmwoIB8uVE/5amWgbNvlETMBOcB1RxDHtnVy+xJz +jjjrzK4Qz9WsUTHAvngdi4Yyxvci+yKpjYTg5+UWxmAN6iW522TpLe32MDb5Ug1d +trBvvepWmdQ6CBwPhBHCt/sMoSJAYSO4RKeBnBjeLQBXFTxaOv5iTGIsRTX3K471 +epHp3cT5AgMBAAECggEASQlRv/4nZN5SgsH/K8v7zb3kdHsmUly8AJYpaCGgauvr +uN/mUyueyga2uNl+MqhQBef6VWHZjO6y/gdw86v/Q2GgVQebQQhKAnpAp2w+Ceoc +siKMFqi8VkOWLU+xPbM6d97kH3TpRxt1g1T8wYFmWeF0BEiE4eUJzGaQW14M9BJ+ +G0QxmP/zjX9cNpVeApKTjBWKiH4CXG3DuI3pJ93VOMpUlOsrdLXvKGTze0e01itr +MX/MHHTE+VXB4FB+/zKSA4c36egi676OSXrGC/GDmM8ntJ4CUGeD5uZsMSADiAUn +iccv5iGRWVMIKxUS5Q4k0jy8uWuK+QVP4Y6cQWYArwKBgQDhuSNORBNpIGRfsKGN +iJo/h+qinz6pEIpa3D3oVl7rpkyvgIyaTwfXvC1vfdS9V5VIel2gV2Cx0OrI8yrr +nQu1JuNV/rLmtvqX321fgBLRdoiqF3pAy1gbmdUz1elerAIYL578gXQ6jg1bbdic +kJpn0MsoDUJGwvJnXcgLqG7q3wKBgQDGhRIa4oJsj1vqICc8zt8YsCAcot3vjWLH +588X7JdBGOWJdWxfdmGXQRn5Zw9UhMQnYa3uyTBPeVcXopThlPotYeuFhLSU856T +IJzfpzCJzC4zIQayoyvJFrKe7N70iUQ986dewYy9oxQhHvFKd/qe4ylbzZJXpthX +eWEuuBSjJwKBgGkqXt6qLPj/1IQYwUw15tfOtW0LEKCoSi3HCzjidNsJ4hSqqdeD +Fr5WuDyHvcRxt+XKzTBVRYHTOnBhiw+3XasK8UQxpJyFh/+WY1jpTNs2hLnqslTZ +6LUDWSgLc+1d6qPmHAa9Ma/OWz7L0O4xGR9hUiXY95YMYe/y668yzGq1AoGBAJyU +Gsqfu7U6gYmxoKEine6QBFPx1dD7GF2KJdq93jMXGvyHZFoLOkAdtgnz0rCcI0bY +kWKUxwj4MMxQjNM8OPMQl75xBCmz2XA8Od9htDQLmqjzNKAzePabc3lMZTJFDlE6 +29kuGf79IIRbLn/JECDAFT/2baW60Ep2T0OVJ5njAoGAfaCaQ4aVgjI027q7Y5qP +KfNSI8uuA8PLqmUY30I9KFWzN6VDLu00eKa90F4w3CeWRRQWXW1+007tTz3V1mNw +20A24Fi3HGQmXc7NyuLDODTJsWBICuOemCnRkvcxIlxb+ec7jp+XRmzDwKkzSnVN +pM2zFU8SeVkvHKlEuoHaP0s= +-----END PRIVATE KEY-----"; +} diff --git a/crates/openshell-server/src/service_routing.rs b/crates/openshell-server/src/service_routing.rs new file mode 100644 index 000000000..4b99a8ef7 --- /dev/null +++ b/crates/openshell-server/src/service_routing.rs @@ -0,0 +1,1103 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! 
Browser-facing HTTP routing for sandbox service endpoints. + +use axum::{ + body::Body, + response::{IntoResponse, Response as AxumResponse}, +}; +use http::{HeaderMap, HeaderValue, Method, Request, Response, StatusCode, header}; +use hyper_util::rt::TokioIo; +use openshell_core::config::ServiceRoutingConfig; +use openshell_core::proto::{Sandbox, SandboxPhase, ServiceEndpoint, TcpRelayTarget, relay_open}; +use openshell_core::{ObjectId, VERSION}; +use openshell_ocsf::{ + ActionId, ActivityId, ConfigStateChangeBuilder, DispositionId, Endpoint, HttpActivityBuilder, + HttpRequest, HttpResponse as OcsfHttpResponse, NetworkActivityBuilder, OCSF_TARGET, OcsfEvent, + SandboxContext, SeverityId, StateId, StatusId, Url as OcsfUrl, +}; +use std::net::{IpAddr, Ipv4Addr}; +use std::sync::Arc; +use std::time::Duration; +use tokio::io::AsyncWriteExt; +use tracing::{info, warn}; + +use crate::ServerState; +use crate::persistence::{ObjectType, Store}; + +const ENDPOINT_OBJECT_TYPE: &str = "service_endpoint"; +const ROUTING_RULE_NAME: &str = "sandbox_service_routing"; +const ROUTING_RULE_TYPE: &str = "gateway"; +const RELAY_RULE_NAME: &str = "sandbox_service_relay"; +const RELAY_TARGET_HOST: &str = "127.0.0.1"; + +impl ObjectType for ServiceEndpoint { + fn object_type() -> &'static str { + ENDPOINT_OBJECT_TYPE + } +} + +pub fn endpoint_key(sandbox: &str, service: &str) -> String { + if service.is_empty() { + sandbox.to_string() + } else { + format!("{sandbox}--{service}") + } +} + +pub fn endpoint_url( + config: &openshell_core::Config, + sandbox: &str, + service: &str, +) -> Option { + let host = endpoint_host(&config.service_routing, sandbox, service)?; + let scheme = endpoint_scheme(config); + let port = config.bind_address.port(); + let include_port = !matches!((scheme, port), ("https", 443) | ("http", 80)); + Some(if include_port { + format!("{scheme}://{host}:{port}/") + } else { + format!("{scheme}://{host}/") + }) +} + +fn endpoint_scheme(config: &openshell_core::Config) -> 
&'static str { + if config.tls.is_none() + || (config.bind_address.ip().is_loopback() + && config.service_routing.enable_loopback_service_http) + { + "http" + } else { + "https" + } +} + +fn endpoint_host(config: &ServiceRoutingConfig, sandbox: &str, service: &str) -> Option { + let base_domain = config.base_domains.first()?; + Some(if service.is_empty() { + format!("{sandbox}.{base_domain}") + } else { + format!("{sandbox}--{service}.{base_domain}") + }) +} + +pub fn parse_host(host: &str, config: &ServiceRoutingConfig) -> Option<(String, String)> { + let host = host.split_once(':').map_or(host, |(name, _)| name); + for base_domain in &config.base_domains { + let expected_suffix = format!(".{base_domain}"); + let Some(encoded) = host.strip_suffix(&expected_suffix) else { + continue; + }; + let (sandbox, service) = if let Some((sandbox, service)) = encoded.split_once("--") { + if service.is_empty() || service.contains("--") { + return None; + } + (sandbox, service) + } else { + (encoded, "") + }; + if sandbox.is_empty() || sandbox.contains("--") { + return None; + } + return Some((sandbox.to_string(), service.to_string())); + } + None +} + +pub fn is_sandbox_service_request(req: &Request, config: &ServiceRoutingConfig) -> bool { + request_host(req).is_some_and(|host| parse_host(host, config).is_some()) +} + +pub async fn proxy_sandbox_service_request( + state: Arc, + req: Request, +) -> impl IntoResponse { + let Some(host) = request_host(&req) else { + return StatusCode::NOT_FOUND.into_response(); + }; + let Some((sandbox_name, service_name)) = parse_host(host, &state.config.service_routing) else { + return StatusCode::NOT_FOUND.into_response(); + }; + + match proxy_to_endpoint(state, req, sandbox_name, service_name).await { + Ok(response) => response.into_response(), + Err(err) => err.into_response(), + } +} + +#[derive(Debug, Clone)] +struct ServiceRouteError { + status: StatusCode, + message: &'static str, + reason: &'static str, +} + +impl ServiceRouteError { + 
const fn new(status: StatusCode, message: &'static str, reason: &'static str) -> Self { + Self { + status, + message, + reason, + } + } + + const fn endpoint_not_found() -> Self { + Self::new( + StatusCode::NOT_FOUND, + "Service endpoint not found", + "service endpoint not found", + ) + } + + const fn endpoint_unavailable() -> Self { + Self::new( + StatusCode::NOT_FOUND, + "Service endpoint is not available", + "service endpoint unavailable", + ) + } + + const fn sandbox_not_ready() -> Self { + Self::new( + StatusCode::PRECONDITION_FAILED, + "Sandbox is not ready", + "sandbox not ready", + ) + } + + const fn service_unreachable() -> Self { + Self::new( + StatusCode::BAD_GATEWAY, + "Service endpoint is not reachable", + "service endpoint unreachable", + ) + } + + const fn invalid_request() -> Self { + Self::new( + StatusCode::BAD_REQUEST, + "Invalid service request", + "invalid service request", + ) + } + + const fn internal_error() -> Self { + Self::new( + StatusCode::INTERNAL_SERVER_ERROR, + "Service endpoint is not available", + "service endpoint internal error", + ) + } +} + +impl IntoResponse for ServiceRouteError { + fn into_response(self) -> AxumResponse { + service_error_response(self.status, self.message) + } +} + +pub fn service_error_response(status: StatusCode, message: &'static str) -> AxumResponse { + ( + status, + [(header::CONTENT_TYPE, "text/plain; charset=utf-8")], + message, + ) + .into_response() +} + +async fn proxy_to_endpoint( + state: Arc, + mut req: Request, + sandbox_name: String, + service_name: String, +) -> Result, ServiceRouteError> { + let endpoint = match load_endpoint(&state.store, &sandbox_name, &service_name).await { + Ok(endpoint) => endpoint, + Err(err) => { + emit_service_http_failure(&state, &req, &sandbox_name, &service_name, None, &err); + return Err(err); + } + }; + if !endpoint.domain || endpoint.target_port == 0 || endpoint.target_port > u32::from(u16::MAX) { + let err = ServiceRouteError::endpoint_unavailable(); + 
emit_service_http_failure( + &state, + &req, + &sandbox_name, + &service_name, + Some(&endpoint), + &err, + ); + return Err(err); + } + + let sandbox = match state + .store + .get_message::(&endpoint.sandbox_id) + .await + { + Ok(Some(sandbox)) => sandbox, + Ok(None) => { + let err = ServiceRouteError::endpoint_unavailable(); + emit_service_http_failure( + &state, + &req, + &sandbox_name, + &service_name, + Some(&endpoint), + &err, + ); + return Err(err); + } + Err(err) => { + warn!(error = %err, sandbox_id = %endpoint.sandbox_id, "sandbox service routing: failed to load sandbox"); + let route_err = ServiceRouteError::internal_error(); + emit_service_http_failure( + &state, + &req, + &sandbox_name, + &service_name, + Some(&endpoint), + &route_err, + ); + return Err(route_err); + } + }; + if SandboxPhase::try_from(sandbox.phase).ok() != Some(SandboxPhase::Ready) { + let err = ServiceRouteError::sandbox_not_ready(); + emit_service_http_failure( + &state, + &req, + &sandbox_name, + &service_name, + Some(&endpoint), + &err, + ); + return Err(err); + } + let Ok(target_port) = u16::try_from(endpoint.target_port) else { + let err = ServiceRouteError::endpoint_unavailable(); + emit_service_http_failure( + &state, + &req, + &sandbox_name, + &service_name, + Some(&endpoint), + &err, + ); + return Err(err); + }; + if upstream_uri_path(&req).is_err() { + let err = ServiceRouteError::invalid_request(); + emit_service_http_failure( + &state, + &req, + &sandbox_name, + &service_name, + Some(&endpoint), + &err, + ); + return Err(err); + } + + let websocket_upgrade = is_websocket_upgrade(&req); + let downstream_upgrade = websocket_upgrade.then(|| hyper::upgrade::on(&mut req)); + + let (_channel_id, relay_rx) = state + .supervisor_sessions + .open_relay_with_target( + sandbox.object_id(), + relay_open::Target::Tcp(TcpRelayTarget { + host: RELAY_TARGET_HOST.to_string(), + port: u32::from(target_port), + }), + endpoint.object_id().to_string(), + Duration::from_secs(15), + ) + .await + 
.map_err(|err| { + warn!(error = %err, sandbox_id = %endpoint.sandbox_id, "sandbox service routing: supervisor relay unavailable"); + let route_err = ServiceRouteError::service_unreachable(); + emit_service_relay_failure(&endpoint, target_port, route_err.reason); + route_err + })?; + + let relay = tokio::time::timeout(Duration::from_secs(10), relay_rx) + .await + .map_err(|_| { + let err = ServiceRouteError::service_unreachable(); + emit_service_relay_failure(&endpoint, target_port, "relay claim timed out"); + err + })? + .map_err(|_| { + let err = ServiceRouteError::service_unreachable(); + emit_service_relay_failure(&endpoint, target_port, "relay claim canceled"); + err + })? + .map_err(|err| { + warn!(error = %err, "sandbox service routing: relay target open failed"); + let route_err = ServiceRouteError::service_unreachable(); + emit_service_relay_failure(&endpoint, target_port, route_err.reason); + route_err + })?; + + let (mut sender, conn) = hyper::client::conn::http1::Builder::new() + .handshake(TokioIo::new(relay)) + .await + .map_err(|err| { + warn!(error = %err, "sandbox service routing: failed to start upstream HTTP client"); + let route_err = ServiceRouteError::service_unreachable(); + emit_service_relay_failure(&endpoint, target_port, route_err.reason); + route_err + })?; + + if websocket_upgrade { + tokio::spawn(async move { + if let Err(err) = conn.with_upgrades().await { + warn!(error = %err, "sandbox service routing: upstream WebSocket connection failed"); + } + }); + } else { + tokio::spawn(async move { + if let Err(err) = conn.await { + warn!(error = %err, "sandbox service routing: upstream HTTP connection failed"); + } + }); + } + + let upstream = build_upstream_request(req, target_port, websocket_upgrade)?; + let mut response = sender.send_request(upstream).await.map_err(|err| { + warn!(error = %err, "sandbox service routing: upstream HTTP request failed"); + let route_err = ServiceRouteError::service_unreachable(); + 
emit_service_relay_failure(&endpoint, target_port, route_err.reason); + route_err + })?; + + if websocket_upgrade && response.status() == StatusCode::SWITCHING_PROTOCOLS { + let upstream_upgrade = hyper::upgrade::on(&mut response); + let downstream_upgrade = downstream_upgrade.ok_or_else(|| { + let err = ServiceRouteError::service_unreachable(); + emit_service_relay_failure(&endpoint, target_port, "websocket upgrade unavailable"); + err + })?; + tokio::spawn(async move { + match (downstream_upgrade.await, upstream_upgrade.await) { + (Ok(downstream), Ok(upstream)) => { + let mut downstream = TokioIo::new(downstream); + let mut upstream = TokioIo::new(upstream); + let _ = tokio::io::copy_bidirectional(&mut downstream, &mut upstream).await; + let _ = downstream.shutdown().await; + let _ = upstream.shutdown().await; + } + (Err(err), _) => { + warn!(error = %err, "sandbox service routing: downstream WebSocket upgrade failed"); + } + (_, Err(err)) => { + warn!(error = %err, "sandbox service routing: upstream WebSocket upgrade failed"); + } + } + }); + + let (parts, _) = response.into_parts(); + return Ok(Response::from_parts(parts, Body::empty())); + } + + let (parts, body) = response.into_parts(); + Ok(Response::from_parts(parts, Body::new(body))) +} + +async fn load_endpoint( + store: &Store, + sandbox_name: &str, + service_name: &str, +) -> Result { + let key = endpoint_key(sandbox_name, service_name); + store + .get_message_by_name::(&key) + .await + .map_err(|err| { + warn!(error = %err, endpoint = %key, "sandbox service routing: failed to load service endpoint"); + ServiceRouteError::internal_error() + })? 
+ .ok_or_else(ServiceRouteError::endpoint_not_found) +} + +fn build_upstream_request( + req: Request, + target_port: u16, + preserve_upgrade_headers: bool, +) -> Result, ServiceRouteError> { + let (parts, body) = req.into_parts(); + let path = parts.uri.path_and_query().map_or("/", |path| path.as_str()); + let uri = path + .parse::() + .map_err(|_| ServiceRouteError::invalid_request())?; + + let mut builder = Request::builder() + .method(parts.method) + .uri(uri) + .version(http::Version::HTTP_11); + + let headers = builder + .headers_mut() + .ok_or_else(ServiceRouteError::internal_error)?; + for (name, value) in &parts.headers { + if (is_hop_by_hop_header(name) + && !(preserve_upgrade_headers && is_websocket_hop_by_hop_header(name))) + || is_gateway_auth_header(name) + { + continue; + } + if name == header::COOKIE { + if let Some(cookie) = sanitize_cookie_header(value) { + headers.append(name, cookie); + } + continue; + } + headers.append(name, value.clone()); + } + headers.insert( + header::HOST, + format!("127.0.0.1:{target_port}").parse().unwrap(), + ); + + builder + .body(body) + .map_err(|_| ServiceRouteError::internal_error()) +} + +fn upstream_uri_path(req: &Request) -> Result<&str, ServiceRouteError> { + let path = req.uri().path_and_query().map_or("/", |path| path.as_str()); + path.parse::() + .map(|_| path) + .map_err(|_| ServiceRouteError::invalid_request()) +} + +fn host_header(headers: &HeaderMap) -> Option<&str> { + headers.get(header::HOST)?.to_str().ok() +} + +pub fn request_host(req: &Request) -> Option<&str> { + host_header(req.headers()).or_else(|| req.uri().authority().map(http::uri::Authority::as_str)) +} + +fn is_websocket_upgrade(req: &Request) -> bool { + req.method() == Method::GET + && header_value_is(req.headers(), header::UPGRADE, "websocket") + && header_contains_token(req.headers(), header::CONNECTION, "upgrade") +} + +fn header_value_is(headers: &HeaderMap, name: header::HeaderName, expected: &str) -> bool { + headers + .get(name) + 
.and_then(|value| value.to_str().ok()) + .is_some_and(|value| value.eq_ignore_ascii_case(expected)) +} + +fn header_contains_token(headers: &HeaderMap, name: header::HeaderName, token: &str) -> bool { + headers + .get(name) + .and_then(|value| value.to_str().ok()) + .is_some_and(|value| { + value + .split(',') + .any(|part| part.trim().eq_ignore_ascii_case(token)) + }) +} + +fn is_hop_by_hop_header(name: &header::HeaderName) -> bool { + matches!( + name.as_str(), + "connection" + | "host" + | "keep-alive" + | "proxy-authenticate" + | "proxy-authorization" + | "te" + | "trailer" + | "transfer-encoding" + | "upgrade" + ) +} + +fn is_websocket_hop_by_hop_header(name: &header::HeaderName) -> bool { + matches!(name.as_str(), "connection" | "upgrade") +} + +fn is_gateway_auth_header(name: &header::HeaderName) -> bool { + matches!( + name.as_str(), + "authorization" + | "cf-access-jwt-assertion" + | "x-forwarded-client-cert" + | "x-ssl-client-cert" + | "x-client-cert" + ) +} + +fn sanitize_cookie_header(value: &HeaderValue) -> Option { + let value = value.to_str().ok()?; + let cookies = value + .split(';') + .filter_map(|cookie| { + let cookie = cookie.trim(); + let (name, _) = cookie.split_once('=')?; + (!is_gateway_auth_cookie(name.trim())).then_some(cookie) + }) + .collect::>(); + + if cookies.is_empty() { + return None; + } + + HeaderValue::from_str(&cookies.join("; ")).ok() +} + +fn is_gateway_auth_cookie(name: &str) -> bool { + name.eq_ignore_ascii_case("CF_Authorization") || name.eq_ignore_ascii_case("cf-authorization") +} + +pub fn emit_service_endpoint_config_event(endpoint: &ServiceEndpoint, url: &str, created: bool) { + let event = build_service_endpoint_config_event(endpoint, url, created); + emit_gateway_ocsf_event(&endpoint.sandbox_id, event); +} + +pub fn emit_service_endpoint_delete_event(endpoint: &ServiceEndpoint) { + let event = build_service_endpoint_delete_event(endpoint); + emit_gateway_ocsf_event(&endpoint.sandbox_id, event); +} + +pub fn 
emit_cross_origin_service_http_rejection(state: &ServerState, req: &Request) { + let Some(host) = request_host(req) else { + return; + }; + let Some((sandbox_name, service_name)) = parse_host(host, &state.config.service_routing) else { + return; + }; + let err = ServiceRouteError::new( + StatusCode::FORBIDDEN, + "Cross-origin service request rejected", + "cross-origin service request rejected", + ); + emit_service_http_failure(state, req, &sandbox_name, &service_name, None, &err); +} + +fn emit_service_http_failure( + state: &ServerState, + req: &Request, + sandbox_name: &str, + service_name: &str, + endpoint: Option<&ServiceEndpoint>, + err: &ServiceRouteError, +) { + let event = build_service_http_failure_event( + state.config.bind_address.port(), + req, + sandbox_name, + service_name, + endpoint, + err, + ); + let sandbox_id = endpoint.map_or("", |endpoint| endpoint.sandbox_id.as_str()); + emit_gateway_ocsf_event(sandbox_id, event); +} + +fn emit_service_relay_failure(endpoint: &ServiceEndpoint, target_port: u16, reason: &str) { + let event = build_service_relay_failure_event(endpoint, target_port, reason); + emit_gateway_ocsf_event(&endpoint.sandbox_id, event); +} + +fn build_service_endpoint_config_event( + endpoint: &ServiceEndpoint, + url: &str, + created: bool, +) -> OcsfEvent { + let service_label = service_display_name(&endpoint.sandbox_name, &endpoint.service_name); + let state_label = if created { + "service_endpoint_created" + } else { + "service_endpoint_updated" + }; + let ctx = gateway_ocsf_ctx(&endpoint.sandbox_id, &endpoint.sandbox_name); + let mut builder = ConfigStateChangeBuilder::new(&ctx) + .state(StateId::Enabled, state_label) + .severity(SeverityId::Informational) + .status(StatusId::Success) + .message(format!( + "Service endpoint exposed {service_label} -> {RELAY_TARGET_HOST}:{}", + endpoint.target_port + )) + .unmapped("endpoint_name", endpoint_name(endpoint)) + .unmapped("service_name", endpoint.service_name.clone()) + 
.unmapped("target_port", u64::from(endpoint.target_port)); + + if !url.is_empty() { + builder = builder.unmapped("url", url.to_string()); + } + + builder.build() +} + +fn build_service_endpoint_delete_event(endpoint: &ServiceEndpoint) -> OcsfEvent { + let service_label = service_display_name(&endpoint.sandbox_name, &endpoint.service_name); + ConfigStateChangeBuilder::new(&gateway_ocsf_ctx( + &endpoint.sandbox_id, + &endpoint.sandbox_name, + )) + .state(StateId::Disabled, "service_endpoint_deleted") + .severity(SeverityId::Informational) + .status(StatusId::Success) + .message(format!("Service endpoint deleted {service_label}")) + .unmapped("endpoint_name", endpoint_name(endpoint)) + .unmapped("service_name", endpoint.service_name.clone()) + .unmapped("target_port", u64::from(endpoint.target_port)) + .build() +} + +fn build_service_http_failure_event( + bind_port: u16, + req: &Request, + sandbox_name: &str, + service_name: &str, + endpoint: Option<&ServiceEndpoint>, + err: &ServiceRouteError, +) -> OcsfEvent { + let host = request_host(req).unwrap_or("unknown"); + let (hostname, port) = split_authority_for_event(host, bind_port); + let ctx = gateway_ocsf_ctx( + endpoint.map_or("", |endpoint| endpoint.sandbox_id.as_str()), + sandbox_name, + ); + HttpActivityBuilder::new(&ctx) + .activity(http_activity_for_method(req.method())) + .action(ActionId::Denied) + .disposition(if err.status.is_server_error() { + DispositionId::Error + } else { + DispositionId::Blocked + }) + .severity(if err.status.is_server_error() { + SeverityId::Low + } else { + SeverityId::Medium + }) + .status(StatusId::Failure) + .http_request(HttpRequest::new( + req.method().as_str(), + OcsfUrl::new("http", &hostname, req.uri().path(), port), + )) + .http_response(OcsfHttpResponse { + code: err.status.as_u16(), + }) + .dst_endpoint(Endpoint::from_domain(&hostname, port)) + .firewall_rule(ROUTING_RULE_NAME, ROUTING_RULE_TYPE) + .status_detail(err.reason) + .message(format!( + "{}: {}", + err.message, + 
service_display_name(sandbox_name, service_name) + )) + .build() +} + +fn build_service_relay_failure_event( + endpoint: &ServiceEndpoint, + target_port: u16, + reason: &str, +) -> OcsfEvent { + NetworkActivityBuilder::new(&gateway_ocsf_ctx( + &endpoint.sandbox_id, + &endpoint.sandbox_name, + )) + .activity(ActivityId::Open) + .action(ActionId::Denied) + .disposition(DispositionId::Error) + .severity(SeverityId::Low) + .status(StatusId::Failure) + .dst_endpoint(Endpoint::from_ip_str(RELAY_TARGET_HOST, target_port)) + .firewall_rule(RELAY_RULE_NAME, ROUTING_RULE_TYPE) + .status_detail(reason) + .message(format!( + "Service endpoint is not reachable: {}", + service_display_name(&endpoint.sandbox_name, &endpoint.service_name) + )) + .unmapped("endpoint_name", endpoint_name(endpoint)) + .unmapped("service_name", endpoint.service_name.clone()) + .build() +} + +fn emit_gateway_ocsf_event(sandbox_id: &str, event: OcsfEvent) { + let message = event.format_shorthand(); + info!( + target: OCSF_TARGET, + sandbox_id = %sandbox_id, + message = %message + ); +} + +fn gateway_ocsf_ctx(sandbox_id: &str, sandbox_name: &str) -> SandboxContext { + SandboxContext { + sandbox_id: sandbox_id.to_string(), + sandbox_name: sandbox_name.to_string(), + container_image: "openshell/gateway".to_string(), + hostname: "openshell-gateway".to_string(), + product_version: VERSION.to_string(), + proxy_ip: IpAddr::V4(Ipv4Addr::LOCALHOST), + proxy_port: 0, + } +} + +fn endpoint_name(endpoint: &ServiceEndpoint) -> String { + endpoint.metadata.as_ref().map_or_else( + || endpoint_key(&endpoint.sandbox_name, &endpoint.service_name), + |metadata| metadata.name.clone(), + ) +} + +fn service_display_name(sandbox_name: &str, service_name: &str) -> String { + if service_name.is_empty() { + sandbox_name.to_string() + } else { + format!("{sandbox_name}/{service_name}") + } +} + +fn split_authority_for_event(authority: &str, default_port: u16) -> (String, u16) { + let authority = authority.trim(); + match 
authority.rsplit_once(':') { + Some((host, port)) if !host.is_empty() && port.chars().all(|ch| ch.is_ascii_digit()) => ( + host.trim_end_matches('.').to_ascii_lowercase(), + port.parse().unwrap_or(default_port), + ), + _ => ( + authority.trim_end_matches('.').to_ascii_lowercase(), + default_port, + ), + } +} + +fn http_activity_for_method(method: &Method) -> ActivityId { + match method.as_str() { + "CONNECT" => ActivityId::Open, + "DELETE" => ActivityId::Close, + "GET" => ActivityId::Reset, + "HEAD" => ActivityId::Fail, + "OPTIONS" => ActivityId::Refuse, + "POST" => ActivityId::Traffic, + "PUT" => ActivityId::Listen, + "TRACE" => ActivityId::Trace, + "PATCH" => ActivityId::Patch, + _ => ActivityId::Other, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn endpoint() -> ServiceEndpoint { + ServiceEndpoint { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: "endpoint-id".to_string(), + name: "my-sandbox--web".to_string(), + created_at_ms: 1_700_000_000_000, + labels: std::collections::HashMap::default(), + resource_version: 0, + }), + sandbox_id: "sandbox-id".to_string(), + sandbox_name: "my-sandbox".to_string(), + service_name: "web".to_string(), + target_port: 8080, + domain: true, + } + } + + fn config() -> ServiceRoutingConfig { + ServiceRoutingConfig { + base_domains: vec![ + "dev.openshell.localhost".to_string(), + "svc.gateway.localhost".to_string(), + ], + ..ServiceRoutingConfig::default() + } + } + + fn tls_config() -> openshell_core::TlsConfig { + openshell_core::TlsConfig { + cert_path: "server.crt".into(), + key_path: "server.key".into(), + client_ca_path: Some("ca.crt".into()), + require_client_auth: false, + } + } + + #[test] + fn endpoint_url_uses_plain_http_for_loopback_tls_gateway() { + let cfg = openshell_core::Config::new(Some(tls_config())) + .with_bind_address("127.0.0.1:8080".parse().unwrap()) + .with_server_sans(["*.dev.openshell.localhost"]); + + assert_eq!( + endpoint_url(&cfg, "my-sandbox", "web").as_deref(), 
+ Some("http://my-sandbox--web.dev.openshell.localhost:8080/") + ); + } + + #[test] + fn endpoint_url_omits_service_label_for_empty_service_name() { + let cfg = openshell_core::Config::new(Some(tls_config())) + .with_bind_address("127.0.0.1:8080".parse().unwrap()) + .with_server_sans(["*.dev.openshell.localhost"]); + + assert_eq!( + endpoint_url(&cfg, "my-sandbox", "").as_deref(), + Some("http://my-sandbox.dev.openshell.localhost:8080/") + ); + } + + #[test] + fn endpoint_url_keeps_https_for_non_loopback_tls_gateway() { + let cfg = openshell_core::Config::new(Some(tls_config())) + .with_bind_address("0.0.0.0:8080".parse().unwrap()) + .with_server_sans(["*.dev.openshell.localhost"]); + + assert_eq!( + endpoint_url(&cfg, "my-sandbox", "web").as_deref(), + Some("https://my-sandbox--web.dev.openshell.localhost:8080/") + ); + } + + #[test] + fn endpoint_url_keeps_https_when_loopback_plaintext_http_is_disabled() { + let cfg = openshell_core::Config::new(Some(tls_config())) + .with_bind_address("127.0.0.1:8080".parse().unwrap()) + .with_server_sans(["*.dev.openshell.localhost"]) + .with_loopback_service_http(false); + + assert_eq!( + endpoint_url(&cfg, "my-sandbox", "web").as_deref(), + Some("https://my-sandbox--web.dev.openshell.localhost:8080/") + ); + } + + #[test] + fn parses_sandbox_service_host() { + assert_eq!( + parse_host("my-sandbox--web.dev.openshell.localhost", &config()), + Some(("my-sandbox".to_string(), "web".to_string())) + ); + } + + #[test] + fn parses_sandbox_host_without_service_label() { + assert_eq!( + parse_host("my-sandbox.dev.openshell.localhost", &config()), + Some(("my-sandbox".to_string(), String::new())) + ); + } + + #[test] + fn rejects_empty_service_label_separator() { + assert_eq!( + parse_host("my-sandbox--.dev.openshell.localhost", &config()), + None + ); + } + + #[test] + fn parses_sandbox_service_host_with_port() { + assert_eq!( + parse_host("my-sandbox--web.dev.openshell.localhost:8080", &config()), + Some(("my-sandbox".to_string(), 
"web".to_string())) + ); + } + + #[test] + fn parses_alternate_service_routing_domain() { + assert_eq!( + parse_host("my-sandbox--web.svc.gateway.localhost", &config()), + Some(("my-sandbox".to_string(), "web".to_string())) + ); + } + + #[test] + fn rejects_unknown_base_domain() { + assert_eq!( + parse_host("my-sandbox--web.prod.openshell.localhost", &config()), + None + ); + } + + #[test] + fn identifies_sandbox_service_request_from_host_header() { + let request = Request::builder() + .uri("/") + .header(header::HOST, "my-sandbox--web.dev.openshell.localhost") + .body(Body::empty()) + .unwrap(); + assert!(is_sandbox_service_request(&request, &config())); + } + + #[test] + fn identifies_sandbox_service_request_from_http2_authority() { + let request = Request::builder() + .uri("https://my-sandbox--web.dev.openshell.localhost/") + .body(Body::empty()) + .unwrap(); + assert!(is_sandbox_service_request(&request, &config())); + } + + #[test] + fn ignores_non_sandbox_service_request() { + let request = Request::builder() + .uri("/") + .header(header::HOST, "127.0.0.1:8080") + .body(Body::empty()) + .unwrap(); + assert!(!is_sandbox_service_request(&request, &config())); + } + + #[test] + fn service_route_errors_return_plain_text() { + let response = ServiceRouteError::sandbox_not_ready().into_response(); + + assert_eq!(response.status(), StatusCode::PRECONDITION_FAILED); + assert_eq!( + response.headers()[header::CONTENT_TYPE], + "text/plain; charset=utf-8" + ); + } + + #[test] + fn service_endpoint_config_event_includes_endpoint_metadata() { + let event = + build_service_endpoint_config_event(&endpoint(), "http://my-sandbox--web.local/", true); + let json = event.to_json().unwrap(); + + assert_eq!(json["class_uid"], 5019); + assert_eq!(json["unmapped"]["endpoint_name"], "my-sandbox--web"); + assert_eq!(json["unmapped"]["service_name"], "web"); + assert_eq!(json["unmapped"]["target_port"], 8080); + assert!( + event + .format_shorthand() + .contains("Service endpoint 
exposed my-sandbox/web") + ); + } + + #[test] + fn service_endpoint_delete_event_includes_endpoint_metadata() { + let event = build_service_endpoint_delete_event(&endpoint()); + let json = event.to_json().unwrap(); + + assert_eq!(json["class_uid"], 5019); + assert_eq!(json["unmapped"]["endpoint_name"], "my-sandbox--web"); + assert_eq!(json["unmapped"]["service_name"], "web"); + assert_eq!(json["unmapped"]["target_port"], 8080); + assert!( + event + .format_shorthand() + .contains("Service endpoint deleted my-sandbox/web") + ); + } + + #[test] + fn service_http_failure_event_omits_query_strings() { + let request = Request::builder() + .method(Method::GET) + .uri("/secret?token=should-not-log") + .header( + header::HOST, + "my-sandbox--web.dev.openshell.localhost:18080", + ) + .body(Body::empty()) + .unwrap(); + + let err = ServiceRouteError::new( + StatusCode::FORBIDDEN, + "Cross-origin service request rejected", + "cross-origin service request rejected", + ); + let event = + build_service_http_failure_event(18080, &request, "my-sandbox", "web", None, &err); + let json = event.to_json().unwrap(); + + assert_eq!(json["class_uid"], 4002); + assert_eq!(json["http_request"]["url"]["path"], "/secret"); + assert_eq!(json["http_response"]["code"], 403); + assert!(!event.format_shorthand().contains("should-not-log")); + } + + #[test] + fn service_relay_failure_event_records_loopback_target() { + let event = build_service_relay_failure_event(&endpoint(), 8080, "relay unavailable"); + let json = event.to_json().unwrap(); + + assert_eq!(json["class_uid"], 4001); + assert_eq!(json["dst_endpoint"]["ip"], RELAY_TARGET_HOST); + assert_eq!(json["dst_endpoint"]["port"], 8080); + assert_eq!(json["unmapped"]["endpoint_name"], "my-sandbox--web"); + } + + #[test] + fn strips_gateway_auth_headers_from_upstream_request() { + let request = Request::builder() + .uri("https://my-sandbox--web.dev.openshell.localhost/path") + .header(header::AUTHORIZATION, "Bearer gateway-token") + 
.header("cf-access-jwt-assertion", "edge-token") + .header("x-forwarded-client-cert", "cert") + .header( + header::COOKIE, + "theme=dark; CF_Authorization=edge-cookie; app=session", + ) + .header("x-app-header", "kept") + .body(Body::empty()) + .unwrap(); + + let upstream = build_upstream_request(request, 8080, false).unwrap(); + + assert_eq!(upstream.uri(), "/path"); + assert!(!upstream.headers().contains_key(header::AUTHORIZATION)); + assert!(!upstream.headers().contains_key("cf-access-jwt-assertion")); + assert!(!upstream.headers().contains_key("x-forwarded-client-cert")); + assert_eq!( + upstream.headers()[header::COOKIE], + "theme=dark; app=session" + ); + assert_eq!(upstream.headers()["x-app-header"], "kept"); + } + + #[test] + fn detects_websocket_upgrade_request() { + let request = Request::builder() + .method(Method::GET) + .uri("/chat?session=main") + .header(header::CONNECTION, "keep-alive, Upgrade") + .header(header::UPGRADE, "websocket") + .body(Body::empty()) + .unwrap(); + + assert!(is_websocket_upgrade(&request)); + } + + #[test] + fn preserves_websocket_upgrade_headers_for_upstream_request() { + let request = Request::builder() + .method(Method::GET) + .uri("https://my-sandbox--web.dev.openshell.localhost/chat?session=main") + .header(header::CONNECTION, "Upgrade") + .header(header::UPGRADE, "websocket") + .header("sec-websocket-key", "abc") + .body(Body::empty()) + .unwrap(); + + let upstream = build_upstream_request(request, 8080, true).unwrap(); + + assert_eq!(upstream.uri(), "/chat?session=main"); + assert_eq!(upstream.headers()[header::CONNECTION], "Upgrade"); + assert_eq!(upstream.headers()[header::UPGRADE], "websocket"); + assert_eq!(upstream.headers()["sec-websocket-key"], "abc"); + assert_eq!(upstream.headers()[header::HOST], "127.0.0.1:8080"); + } +} diff --git a/crates/openshell-server/src/ssh_sessions.rs b/crates/openshell-server/src/ssh_sessions.rs new file mode 100644 index 000000000..3f1f24a7d --- /dev/null +++ 
b/crates/openshell-server/src/ssh_sessions.rs @@ -0,0 +1,190 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! SSH session token storage and cleanup. + +use openshell_core::ObjectId; +use openshell_core::proto::SshSession; +use prost::Message; +use std::sync::Arc; +use std::time::Duration; +use tracing::{info, warn}; + +use crate::persistence::{ObjectType, Store}; + +impl ObjectType for SshSession { + fn object_type() -> &'static str { + "ssh_session" + } +} + +/// Spawn a background task that periodically reaps expired and revoked SSH sessions. +pub fn spawn_session_reaper(store: Arc<Store>, interval: Duration) { + tokio::spawn(async move { + tokio::time::sleep(interval).await; + + loop { + if let Err(e) = reap_expired_sessions(&store).await { + warn!(error = %e, "SSH session reaper sweep failed"); + } + tokio::time::sleep(interval).await; + } + }); +} + +async fn reap_expired_sessions(store: &Store) -> Result<(), String> { + let now_ms = unix_epoch_millis(); + + let records = store + .list(SshSession::object_type(), 1000, 0) + .await + .map_err(|e| e.to_string())?; + + let mut reaped = 0u32; + for record in records { + let session: SshSession = match Message::decode(record.payload.as_slice()) { + Ok(s) => s, + Err(_) => continue, + }; + + let should_delete = + (session.expires_at_ms > 0 && now_ms > session.expires_at_ms) || session.revoked; + + if should_delete { + if let Err(e) = store + .delete(SshSession::object_type(), session.object_id()) + .await + { + warn!(session_id = %session.object_id(), error = %e, "Failed to reap SSH session"); + } else { + reaped += 1; + } + } + } + + if reaped > 0 { + info!(count = reaped, "SSH session reaper: cleaned up sessions"); + } + Ok(()) +} + +fn unix_epoch_millis() -> i64 { + i64::try_from( + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_millis(), + ) + .unwrap_or(i64::MAX) 
+} + +#[cfg(test)] +mod tests { + use super::*; + use std::collections::HashMap; + + async fn test_store() -> Store { + Store::connect("sqlite::memory:?cache=shared") + .await + .expect("in-memory SQLite store should connect") + } + + fn make_session(id: &str, sandbox_id: &str, expires_at_ms: i64, revoked: bool) -> SshSession { + SshSession { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: id.to_string(), + name: format!("session-{id}"), + created_at_ms: 1000, + labels: HashMap::new(), + resource_version: 0, + }), + sandbox_id: sandbox_id.to_string(), + token: id.to_string(), + expires_at_ms, + revoked, + } + } + + fn now_ms() -> i64 { + unix_epoch_millis() + } + + #[tokio::test] + async fn reaper_deletes_expired_sessions() { + let store = test_store().await; + + let expired = make_session("expired1", "sbx1", now_ms() - 60_000, false); + store.put_message(&expired).await.unwrap(); + + let valid = make_session("valid1", "sbx1", now_ms() + 3_600_000, false); + store.put_message(&valid).await.unwrap(); + + reap_expired_sessions(&store).await.unwrap(); + + assert!( + store + .get_message::<SshSession>("expired1") + .await + .unwrap() + .is_none(), + "expired session should be reaped" + ); + assert!( + store + .get_message::<SshSession>("valid1") + .await + .unwrap() + .is_some(), + "valid session should be kept" + ); + } + + #[tokio::test] + async fn reaper_deletes_revoked_sessions() { + let store = test_store().await; + + let revoked = make_session("revoked1", "sbx1", 0, true); + store.put_message(&revoked).await.unwrap(); + + let active = make_session("active1", "sbx1", 0, false); + store.put_message(&active).await.unwrap(); + + reap_expired_sessions(&store).await.unwrap(); + + assert!( + store + .get_message::<SshSession>("revoked1") + .await + .unwrap() + .is_none(), + "revoked session should be reaped" + ); + assert!( + store + .get_message::<SshSession>("active1") + .await + .unwrap() + .is_some(), + "active session should be kept" + ); + } + + #[tokio::test] + async fn 
reaper_preserves_zero_expiry_sessions() { + let store = test_store().await; + + let no_expiry = make_session("noexpiry1", "sbx1", 0, false); + store.put_message(&no_expiry).await.unwrap(); + + reap_expired_sessions(&store).await.unwrap(); + + assert!( + store + .get_message::<SshSession>("noexpiry1") + .await + .unwrap() + .is_some(), + "session with no expiry should be preserved" + ); + } +} diff --git a/crates/openshell-server/src/ssh_tunnel.rs b/crates/openshell-server/src/ssh_tunnel.rs deleted file mode 100644 index bd317d53f..000000000 --- a/crates/openshell-server/src/ssh_tunnel.rs +++ /dev/null @@ -1,541 +0,0 @@ -// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -// SPDX-License-Identifier: Apache-2.0 - -//! SSH tunnel handler for the multiplexed gateway. - -use axum::{Router, extract::State, http::Method, response::IntoResponse, routing::any}; -use http::StatusCode; -use hyper::Request; -use hyper_util::rt::TokioIo; -use openshell_core::proto::{Sandbox, SandboxPhase, SshSession}; -use prost::Message; -use std::sync::Arc; -use std::time::Duration; -use tokio::io::AsyncWriteExt; -use tracing::{info, warn}; - -use crate::ServerState; -use crate::persistence::{ObjectType, Store}; - -const HEADER_SANDBOX_ID: &str = "x-sandbox-id"; -const HEADER_TOKEN: &str = "x-sandbox-token"; - -/// Maximum concurrent SSH tunnel connections per session token. -const MAX_CONNECTIONS_PER_TOKEN: u32 = 3; - -/// Redact a bearer token for safe logging — show only the last 4 characters. -fn redact_token(token: &str) -> String { - if token.len() <= 4 { - "****".to_string() - } else { - format!("****{}", &token[token.len() - 4..]) - } -} - -/// Maximum concurrent SSH tunnel connections per sandbox. 
-const MAX_CONNECTIONS_PER_SANDBOX: u32 = 20; - -fn acquire_connection_slots( - token_counts: &std::sync::Mutex>, - sandbox_counts: &std::sync::Mutex>, - token: &str, - sandbox_id: &str, -) -> Result<(), ConnectionLimit> { - { - let mut counts = token_counts.lock().unwrap(); - let count = counts.entry(token.to_string()).or_insert(0); - if *count >= MAX_CONNECTIONS_PER_TOKEN { - return Err(ConnectionLimit::PerToken); - } - *count += 1; - } - - { - let mut counts = sandbox_counts.lock().unwrap(); - let count = counts.entry(sandbox_id.to_string()).or_insert(0); - if *count >= MAX_CONNECTIONS_PER_SANDBOX { - decrement_connection_count(token_counts, token); - return Err(ConnectionLimit::PerSandbox); - } - *count += 1; - } - - Ok(()) -} - -enum ConnectionLimit { - PerToken, - PerSandbox, -} - -pub fn router(state: Arc) -> Router { - Router::new() - .route("/connect/ssh", any(ssh_connect)) - .with_state(state) -} - -async fn ssh_connect( - State(state): State>, - req: Request, -) -> impl IntoResponse { - if req.method() != Method::CONNECT { - return StatusCode::METHOD_NOT_ALLOWED.into_response(); - } - - let sandbox_id = match header_value(req.headers(), HEADER_SANDBOX_ID) { - Ok(value) => value, - Err(status) => return status.into_response(), - }; - let token = match header_value(req.headers(), HEADER_TOKEN) { - Ok(value) => value, - Err(status) => return status.into_response(), - }; - - let session = match state.store.get_message::(&token).await { - Ok(Some(session)) => session, - Ok(None) => return StatusCode::UNAUTHORIZED.into_response(), - Err(err) => { - warn!(error = %err, "Failed to fetch SSH session"); - return StatusCode::INTERNAL_SERVER_ERROR.into_response(); - } - }; - - if session.revoked || session.sandbox_id != sandbox_id { - return StatusCode::UNAUTHORIZED.into_response(); - } - - // Check token expiry (0 means no expiry for backward compatibility). 
- if session.expires_at_ms > 0 { - let now_ms = i64::try_from( - std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap_or_default() - .as_millis(), - ) - .unwrap_or(i64::MAX); - if now_ms > session.expires_at_ms { - return StatusCode::UNAUTHORIZED.into_response(); - } - } - - let sandbox = match state.store.get_message::(&sandbox_id).await { - Ok(Some(sandbox)) => sandbox, - Ok(None) => return StatusCode::NOT_FOUND.into_response(), - Err(err) => { - warn!(error = %err, "Failed to fetch sandbox"); - return StatusCode::INTERNAL_SERVER_ERROR.into_response(); - } - }; - - if SandboxPhase::try_from(sandbox.phase).ok() != Some(SandboxPhase::Ready) { - return StatusCode::PRECONDITION_FAILED.into_response(); - } - - // Enforce connection caps *before* opening a relay — otherwise denied - // calls churn pending relay slots and wake the supervisor until the relay - // timeout elapses. - if let Err(limit) = acquire_connection_slots( - &state.ssh_connections_by_token, - &state.ssh_connections_by_sandbox, - &token, - &sandbox_id, - ) { - match limit { - ConnectionLimit::PerToken => { - warn!(token = %redact_token(&token), "SSH tunnel: per-token connection limit reached"); - } - ConnectionLimit::PerSandbox => { - warn!(sandbox_id = %sandbox_id, "SSH tunnel: per-sandbox connection limit reached"); - } - } - return StatusCode::TOO_MANY_REQUESTS.into_response(); - } - - // Open a relay channel through the supervisor session. Use a generous - // 30s session-wait timeout because `/connect/ssh` is typically called - // immediately after `sandbox create`, so we need to cover the supervisor's - // initial TLS + gRPC handshake on a cold-started pod. The old - // direct-connect path tolerated ~34s here for similar reasons. 
- let (channel_id, relay_rx) = match state - .supervisor_sessions - .open_relay(&sandbox_id, Duration::from_secs(30)) - .await - { - Ok(pair) => pair, - Err(status) => { - warn!(sandbox_id = %sandbox_id, error = %status.message(), "SSH tunnel: supervisor session not available"); - decrement_connection_count(&state.ssh_connections_by_token, &token); - decrement_connection_count(&state.ssh_connections_by_sandbox, &sandbox_id); - return StatusCode::BAD_GATEWAY.into_response(); - } - }; - - let sandbox_id_clone = sandbox_id.clone(); - let token_clone = token.clone(); - let state_clone = state.clone(); - - let upgrade = hyper::upgrade::on(req); - tokio::spawn(async move { - // Wait for the supervisor to open its `RelayStream` and deliver the - // bridge half of the relay. - let mut relay = match tokio::time::timeout(Duration::from_secs(10), relay_rx).await { - Ok(Ok(stream)) => stream, - Ok(Err(_)) => { - warn!(sandbox_id = %sandbox_id_clone, channel_id = %channel_id, "SSH tunnel: relay channel dropped"); - decrement_connection_count(&state_clone.ssh_connections_by_token, &token_clone); - decrement_connection_count( - &state_clone.ssh_connections_by_sandbox, - &sandbox_id_clone, - ); - return; - } - Err(_) => { - warn!(sandbox_id = %sandbox_id_clone, channel_id = %channel_id, "SSH tunnel: relay open timed out"); - decrement_connection_count(&state_clone.ssh_connections_by_token, &token_clone); - decrement_connection_count( - &state_clone.ssh_connections_by_sandbox, - &sandbox_id_clone, - ); - return; - } - }; - - info!(sandbox_id = %sandbox_id_clone, channel_id = %channel_id, "SSH tunnel: relay established, bridging client"); - - match upgrade.await { - Ok(upgraded) => { - let mut upgraded = TokioIo::new(upgraded); - let _ = tokio::io::copy_bidirectional(&mut upgraded, &mut relay).await; - let _ = AsyncWriteExt::shutdown(&mut upgraded).await; - } - Err(err) => { - warn!(error = %err, "SSH upgrade failed"); - } - } - - // Decrement connection counts on tunnel completion. 
- decrement_connection_count(&state_clone.ssh_connections_by_token, &token_clone); - decrement_connection_count(&state_clone.ssh_connections_by_sandbox, &sandbox_id_clone); - }); - - StatusCode::OK.into_response() -} - -fn header_value(headers: &http::HeaderMap, name: &str) -> Result { - let value = headers - .get(name) - .ok_or(StatusCode::UNAUTHORIZED)? - .to_str() - .map_err(|_| StatusCode::BAD_REQUEST)? - .trim() - .to_string(); - if value.is_empty() { - return Err(StatusCode::BAD_REQUEST); - } - Ok(value) -} - -impl ObjectType for SshSession { - fn object_type() -> &'static str { - "ssh_session" - } -} - -/// Decrement a connection count entry, removing it if it reaches zero. -fn decrement_connection_count( - counts: &std::sync::Mutex>, - key: &str, -) { - let mut map = counts.lock().unwrap(); - if let Some(count) = map.get_mut(key) { - *count = count.saturating_sub(1); - if *count == 0 { - map.remove(key); - } - } -} - -/// Spawn a background task that periodically reaps expired and revoked SSH sessions. -pub fn spawn_session_reaper(store: Arc, interval: Duration) { - tokio::spawn(async move { - // Initial delay to let startup settle. - tokio::time::sleep(interval).await; - - loop { - if let Err(e) = reap_expired_sessions(&store).await { - warn!(error = %e, "SSH session reaper sweep failed"); - } - tokio::time::sleep(interval).await; - } - }); -} - -async fn reap_expired_sessions(store: &Store) -> Result<(), String> { - let now_ms = i64::try_from( - std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap_or_default() - .as_millis(), - ) - .unwrap_or(i64::MAX); - - let records = store - .list(SshSession::object_type(), 1000, 0) - .await - .map_err(|e| e.to_string())?; - - let mut reaped = 0u32; - for record in records { - let session: SshSession = match Message::decode(record.payload.as_slice()) { - Ok(s) => s, - Err(_) => continue, - }; - - let should_delete = - // Expired sessions (expires_at_ms > 0 means expiry is set). 
- (session.expires_at_ms > 0 && now_ms > session.expires_at_ms) - // Revoked sessions — already invalidated, just cleaning up storage. - || session.revoked; - - if should_delete { - use openshell_core::ObjectId; - if let Err(e) = store - .delete(SshSession::object_type(), session.object_id()) - .await - { - warn!(session_id = %session.object_id(), error = %e, "Failed to reap SSH session"); - } else { - reaped += 1; - } - } - } - - if reaped > 0 { - info!(count = reaped, "SSH session reaper: cleaned up sessions"); - } - Ok(()) -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::persistence::Store; - use std::collections::HashMap; - use std::sync::Mutex; - - fn make_session(id: &str, sandbox_id: &str, expires_at_ms: i64, revoked: bool) -> SshSession { - SshSession { - metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { - id: id.to_string(), - name: format!("session-{id}"), - created_at_ms: 1000, - labels: HashMap::new(), - }), - sandbox_id: sandbox_id.to_string(), - token: id.to_string(), - expires_at_ms, - revoked, - } - } - - fn now_ms() -> i64 { - i64::try_from( - std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap() - .as_millis(), - ) - .unwrap_or(i64::MAX) - } - - // ---- Connection limit tests ---- - - #[test] - fn decrement_removes_entry_at_zero() { - let counts: Mutex> = Mutex::new(HashMap::new()); - counts.lock().unwrap().insert("tok1".to_string(), 1); - decrement_connection_count(&counts, "tok1"); - assert!(counts.lock().unwrap().is_empty()); - } - - #[test] - fn decrement_reduces_count() { - let counts: Mutex> = Mutex::new(HashMap::new()); - counts.lock().unwrap().insert("tok1".to_string(), 5); - decrement_connection_count(&counts, "tok1"); - assert_eq!(*counts.lock().unwrap().get("tok1").unwrap(), 4); - } - - #[test] - fn decrement_missing_key_is_noop() { - let counts: Mutex> = Mutex::new(HashMap::new()); - decrement_connection_count(&counts, "nonexistent"); - assert!(counts.lock().unwrap().is_empty()); 
- } - - #[test] - fn per_token_connection_limit_enforced() { - let counts: Mutex> = Mutex::new(HashMap::new()); - counts - .lock() - .unwrap() - .insert("tok1".to_string(), MAX_CONNECTIONS_PER_TOKEN); - let current = *counts.lock().unwrap().get("tok1").unwrap(); - assert!(current >= MAX_CONNECTIONS_PER_TOKEN); - } - - #[test] - fn per_sandbox_connection_limit_enforced() { - let counts: Mutex> = Mutex::new(HashMap::new()); - counts - .lock() - .unwrap() - .insert("sbx1".to_string(), MAX_CONNECTIONS_PER_SANDBOX); - let current = *counts.lock().unwrap().get("sbx1").unwrap(); - assert!(current >= MAX_CONNECTIONS_PER_SANDBOX); - } - - #[test] - fn acquire_connection_slots_rejects_per_token_limit_without_touching_sandbox() { - let token_counts: Mutex> = Mutex::new(HashMap::new()); - let sandbox_counts: Mutex> = Mutex::new(HashMap::new()); - token_counts - .lock() - .unwrap() - .insert("tok1".to_string(), MAX_CONNECTIONS_PER_TOKEN); - - let result = acquire_connection_slots(&token_counts, &sandbox_counts, "tok1", "sbx1"); - - assert!(matches!(result, Err(ConnectionLimit::PerToken))); - assert!(sandbox_counts.lock().unwrap().is_empty()); - } - - #[test] - fn acquire_connection_slots_rolls_back_token_increment_on_sandbox_limit() { - let token_counts: Mutex> = Mutex::new(HashMap::new()); - let sandbox_counts: Mutex> = Mutex::new(HashMap::new()); - sandbox_counts - .lock() - .unwrap() - .insert("sbx1".to_string(), MAX_CONNECTIONS_PER_SANDBOX); - - let result = acquire_connection_slots(&token_counts, &sandbox_counts, "tok1", "sbx1"); - - assert!(matches!(result, Err(ConnectionLimit::PerSandbox))); - assert!(token_counts.lock().unwrap().is_empty()); - } - - // ---- Session reaper tests ---- - - #[tokio::test] - async fn reaper_deletes_expired_sessions() { - let store = Store::connect("sqlite::memory:?cache=shared") - .await - .unwrap(); - - let expired = make_session("expired1", "sbx1", now_ms() - 60_000, false); - store.put_message(&expired).await.unwrap(); - - let valid = 
make_session("valid1", "sbx1", now_ms() + 3_600_000, false); - store.put_message(&valid).await.unwrap(); - - reap_expired_sessions(&store).await.unwrap(); - - assert!( - store - .get_message::("expired1") - .await - .unwrap() - .is_none(), - "expired session should be reaped" - ); - assert!( - store - .get_message::("valid1") - .await - .unwrap() - .is_some(), - "valid session should be kept" - ); - } - - #[tokio::test] - async fn reaper_deletes_revoked_sessions() { - let store = Store::connect("sqlite::memory:?cache=shared") - .await - .unwrap(); - - let revoked = make_session("revoked1", "sbx1", 0, true); - store.put_message(&revoked).await.unwrap(); - - let active = make_session("active1", "sbx1", 0, false); - store.put_message(&active).await.unwrap(); - - reap_expired_sessions(&store).await.unwrap(); - - assert!( - store - .get_message::("revoked1") - .await - .unwrap() - .is_none(), - "revoked session should be reaped" - ); - assert!( - store - .get_message::("active1") - .await - .unwrap() - .is_some(), - "active session should be kept" - ); - } - - #[tokio::test] - async fn reaper_preserves_zero_expiry_sessions() { - let store = Store::connect("sqlite::memory:?cache=shared") - .await - .unwrap(); - - // expires_at_ms = 0 means no expiry (backward compatible). 
- let no_expiry = make_session("noexpiry1", "sbx1", 0, false); - store.put_message(&no_expiry).await.unwrap(); - - reap_expired_sessions(&store).await.unwrap(); - - assert!( - store - .get_message::("noexpiry1") - .await - .unwrap() - .is_some(), - "session with no expiry should be preserved" - ); - } - - // ---- Expiry validation logic tests ---- - - #[test] - fn expired_session_is_detected() { - let session = make_session("tok1", "sbx1", now_ms() - 1000, false); - let is_expired = session.expires_at_ms > 0 && now_ms() > session.expires_at_ms; - assert!(is_expired, "session in the past should be expired"); - } - - #[test] - fn future_session_is_not_expired() { - let session = make_session("tok1", "sbx1", now_ms() + 3_600_000, false); - let is_expired = session.expires_at_ms > 0 && now_ms() > session.expires_at_ms; - assert!(!is_expired, "session in the future should not be expired"); - } - - #[test] - fn zero_expiry_is_not_expired() { - let session = make_session("tok1", "sbx1", 0, false); - let is_expired = session.expires_at_ms > 0 && now_ms() > session.expires_at_ms; - assert!( - !is_expired, - "session with zero expiry should never be expired" - ); - } -} diff --git a/crates/openshell-server/src/supervisor_session.rs b/crates/openshell-server/src/supervisor_session.rs index 94c352ba5..8f186dcac 100644 --- a/crates/openshell-server/src/supervisor_session.rs +++ b/crates/openshell-server/src/supervisor_session.rs @@ -13,11 +13,12 @@ use tracing::{info, warn}; use uuid::Uuid; use openshell_core::proto::{ - GatewayMessage, RelayFrame, RelayInit, RelayOpen, Sandbox, SessionAccepted, SupervisorMessage, - gateway_message, supervisor_message, + GatewayMessage, RelayFrame, RelayInit, RelayOpen, Sandbox, SessionAccepted, SshRelayTarget, + SupervisorMessage, gateway_message, relay_open, supervisor_message, }; use crate::ServerState; +use crate::auth::principal::Principal; const HEARTBEAT_INTERVAL_SECS: u32 = 15; const RELAY_PENDING_TIMEOUT: Duration = 
Duration::from_secs(10); @@ -58,8 +59,9 @@ struct LiveSession { connected_at: Instant, } -/// Holds a oneshot sender that will deliver the upgraded relay stream. -type RelayStreamSender = oneshot::Sender; +/// Holds a oneshot sender that will deliver the upgraded relay stream or a +/// target-open failure reported by the supervisor. +type RelayStreamSender = oneshot::Sender>; impl openshell_driver_docker::SupervisorReadiness for SupervisorSessionRegistry { fn is_supervisor_connected(&self, sandbox_id: &str) -> bool { @@ -79,6 +81,7 @@ pub struct SupervisorSessionRegistry { struct PendingRelay { sender: RelayStreamSender, sandbox_id: String, + relay_open: RelayOpen, created_at: Instant, } @@ -234,12 +237,45 @@ impl SupervisorSessionRegistry { &self, sandbox_id: &str, session_wait_timeout: Duration, - ) -> Result<(String, oneshot::Receiver), Status> { + ) -> Result< + ( + String, + oneshot::Receiver>, + ), + Status, + > { + self.open_relay_with_target( + sandbox_id, + relay_open::Target::Ssh(SshRelayTarget {}), + String::new(), + session_wait_timeout, + ) + .await + } + + pub async fn open_relay_with_target( + &self, + sandbox_id: &str, + target: relay_open::Target, + service_id: String, + session_wait_timeout: Duration, + ) -> Result< + ( + String, + oneshot::Receiver>, + ), + Status, + > { let tx = self .wait_for_session(sandbox_id, session_wait_timeout) .await?; let channel_id = Uuid::new_v4().to_string(); + let relay_open = RelayOpen { + channel_id: channel_id.clone(), + target: Some(target), + service_id, + }; // Register the pending relay before sending RelayOpen to avoid a race. 
// Both caps are checked and the insert happens under a single lock hold @@ -267,15 +303,14 @@ impl SupervisorSessionRegistry { PendingRelay { sender: relay_tx, sandbox_id: sandbox_id.to_string(), + relay_open: relay_open.clone(), created_at: Instant::now(), }, ); } let msg = GatewayMessage { - payload: Some(gateway_message::Payload::RelayOpen(RelayOpen { - channel_id: channel_id.clone(), - })), + payload: Some(gateway_message::Payload::RelayOpen(relay_open)), }; if tx.send(msg).await.is_err() { @@ -287,29 +322,62 @@ impl SupervisorSessionRegistry { Ok((channel_id, relay_rx)) } + pub fn fail_pending_relay(&self, channel_id: &str, error: String) -> bool { + let pending = self.pending_relays.lock().unwrap().remove(channel_id); + if let Some(pending) = pending { + let _ = pending.sender.send(Err(Status::unavailable(error))); + true + } else { + false + } + } + /// Claim a pending relay channel. Called by the `/relay/{channel_id}` HTTP handler /// when the supervisor's reverse CONNECT arrives. /// /// Returns the `DuplexStream` half that the supervisor side should read/write. // `tonic::Status` is large but is the API surface of gRPC handlers. 
#[allow(clippy::result_large_err)] - pub fn claim_relay(&self, channel_id: &str) -> Result { + pub fn claim_relay( + &self, + channel_id: &str, + principal: Option<&Principal>, + ) -> Result { let pending = { let mut map = self.pending_relays.lock().unwrap(); + let pending = map + .get(channel_id) + .ok_or_else(|| Status::not_found("unknown or expired relay channel"))?; + + if let Some(principal) = principal + && let Err(status) = crate::auth::guard::ensure_sandbox_principal_scope( + principal, + &pending.sandbox_id, + ) + { + info!( + channel_id = %channel_id, + sandbox_id = %pending.sandbox_id, + "relay stream: rejecting cross-sandbox claim" + ); + return Err(status); + } + + if pending.created_at.elapsed() > RELAY_PENDING_TIMEOUT { + map.remove(channel_id); + return Err(Status::deadline_exceeded("relay channel timed out")); + } + map.remove(channel_id) - .ok_or_else(|| Status::not_found("unknown or expired relay channel"))? + .expect("pending relay existed before removal") }; - if pending.created_at.elapsed() > RELAY_PENDING_TIMEOUT { - return Err(Status::deadline_exceeded("relay channel timed out")); - } - // Create a duplex stream pair: one end for the gateway bridge, one for // the supervisor HTTP CONNECT handler. let (gateway_stream, supervisor_stream) = tokio::io::duplex(64 * 1024); - // Send the gateway-side stream to the waiter (ssh_tunnel or exec handler). - if pending.sender.send(gateway_stream).is_err() { + // Send the gateway-side stream to the waiter (exec handler or forward handler). 
+ if pending.sender.send(Ok(gateway_stream)).is_err() { return Err(Status::internal("relay requester dropped")); } @@ -329,10 +397,17 @@ impl SupervisorSessionRegistry { pub async fn replay_pending_relays(&self, sandbox_id: &str, tx: &mpsc::Sender) { for channel_id in self.pending_channel_ids(sandbox_id) { + let relay_open = { + let pending = self.pending_relays.lock().unwrap(); + pending + .get(&channel_id) + .map(|pending| pending.relay_open.clone()) + }; + let Some(relay_open) = relay_open else { + continue; + }; let msg = GatewayMessage { - payload: Some(gateway_message::Payload::RelayOpen(RelayOpen { - channel_id: channel_id.clone(), - })), + payload: Some(gateway_message::Payload::RelayOpen(relay_open)), }; if tx.send(msg).await.is_err() { warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, "supervisor session: failed to replay pending relay to superseding session"); @@ -398,6 +473,7 @@ pub async fn handle_relay_stream( >, Status, > { + let principal = request.extensions().get::().cloned(); let mut inbound = request.into_inner(); // First frame must identify the channel. @@ -419,7 +495,7 @@ pub async fn handle_relay_stream( }; // Claim the pending relay. Consumes the entry — it cannot be reused. - let supervisor_side = registry.claim_relay(&channel_id)?; + let supervisor_side = registry.claim_relay(&channel_id, principal.as_ref())?; info!(channel_id = %channel_id, "relay stream: claimed pending relay, bridging"); let (mut read_half, mut write_half) = tokio::io::split(supervisor_side); @@ -503,6 +579,7 @@ pub async fn handle_connect_supervisor( >, Status, > { + let principal = request.extensions().get::().cloned(); let mut inbound = request.into_inner(); // Step 1: Wait for SupervisorHello. 
@@ -518,6 +595,9 @@ pub async fn handle_connect_supervisor( if sandbox_id.is_empty() { return Err(Status::invalid_argument("sandbox_id is required")); } + if let Some(principal) = principal.as_ref() { + crate::auth::guard::ensure_sandbox_principal_scope(principal, &sandbox_id)?; + } require_persisted_sandbox(&state.store, &sandbox_id).await?; let session_id = Uuid::new_v4().to_string(); @@ -626,7 +706,7 @@ pub async fn handle_connect_supervisor( } async fn run_session_loop( - _state: &Arc, + state: &Arc, sandbox_id: &str, session_id: &str, tx: &mpsc::Sender, @@ -647,7 +727,7 @@ async fn run_session_loop( msg = inbound.message() => { match msg { Ok(Some(msg)) => { - handle_supervisor_message(sandbox_id, session_id, msg); + handle_supervisor_message(state, sandbox_id, session_id, msg); } Ok(None) => { info!(sandbox_id = %sandbox_id, session_id = %session_id, "supervisor session: stream closed by supervisor"); @@ -674,7 +754,12 @@ async fn run_session_loop( } } -fn handle_supervisor_message(sandbox_id: &str, session_id: &str, msg: SupervisorMessage) { +fn handle_supervisor_message( + state: &Arc, + sandbox_id: &str, + session_id: &str, + msg: SupervisorMessage, +) { match msg.payload { Some(supervisor_message::Payload::Heartbeat(_)) => { // Heartbeat received — nothing to do for now. 
@@ -688,11 +773,15 @@ fn handle_supervisor_message(sandbox_id: &str, session_id: &str, msg: Supervisor "supervisor session: relay opened successfully" ); } else { + let failed = state + .supervisor_sessions + .fail_pending_relay(&result.channel_id, result.error.clone()); warn!( sandbox_id = %sandbox_id, session_id = %session_id, channel_id = %result.channel_id, error = %result.error, + pending_relay_failed = failed, "supervisor session: relay open failed" ); } @@ -723,9 +812,19 @@ fn handle_supervisor_message(sandbox_id: &str, session_id: &str, msg: Supervisor #[cfg(test)] mod tests { use super::*; + use crate::auth::identity::{Identity, IdentityProvider}; + use crate::auth::principal::{SandboxIdentitySource, SandboxPrincipal, UserPrincipal}; use crate::persistence::Store; use tokio::io::{AsyncReadExt, AsyncWriteExt}; + async fn test_store() -> Arc { + Arc::new( + Store::connect("sqlite::memory:?cache=shared") + .await + .expect("in-memory SQLite store should connect"), + ) + } + /// Returns a shutdown sender with its receiver immediately dropped. Tests /// that don't observe the shutdown signal can use this to satisfy the /// `register` signature without the receiver noise. 
@@ -740,11 +839,51 @@ mod tests { name: name.to_string(), created_at_ms: 1_000_000, labels: HashMap::new(), + resource_version: 0, }), ..Default::default() } } + fn pending_relay( + sandbox_id: &str, + relay_tx: RelayStreamSender, + created_at: Instant, + ) -> PendingRelay { + PendingRelay { + sender: relay_tx, + sandbox_id: sandbox_id.to_string(), + relay_open: RelayOpen { + channel_id: "ch-test".to_string(), + target: Some(relay_open::Target::Ssh(SshRelayTarget {})), + service_id: String::new(), + }, + created_at, + } + } + + fn sandbox_principal(sandbox_id: &str) -> Principal { + Principal::Sandbox(SandboxPrincipal { + sandbox_id: sandbox_id.to_string(), + source: SandboxIdentitySource::BootstrapJwt { + issuer: "openshell-gateway:test".to_string(), + }, + trust_domain: Some("openshell".to_string()), + }) + } + + fn user_principal(subject: &str) -> Principal { + Principal::User(UserPrincipal { + identity: Identity { + subject: subject.to_string(), + display_name: None, + roles: vec![], + scopes: vec![], + provider: IdentityProvider::Oidc, + }, + }) + } + // ---- registry: register / remove ---- #[test] @@ -863,6 +1002,7 @@ mod tests { match msg.payload { Some(gateway_message::Payload::RelayOpen(open)) => { assert_eq!(open.channel_id, channel_id); + assert!(matches!(open.target, Some(relay_open::Target::Ssh(_)))); } other => panic!("expected RelayOpen, got {other:?}"), } @@ -944,11 +1084,7 @@ mod tests { let sandbox_id = if i % 2 == 0 { "sbx-a" } else { "sbx-b" }; pending.insert( format!("channel-{i}"), - PendingRelay { - sender: oneshot_tx, - sandbox_id: sandbox_id.to_string(), - created_at: Instant::now(), - }, + pending_relay(sandbox_id, oneshot_tx, Instant::now()), ); } } @@ -973,11 +1109,7 @@ mod tests { let (oneshot_tx, _) = oneshot::channel(); pending.insert( format!("channel-{i}"), - PendingRelay { - sender: oneshot_tx, - sandbox_id: "sbx".to_string(), - created_at: Instant::now(), - }, + pending_relay("sbx", oneshot_tx, Instant::now()), ); } } @@ -1129,11 
+1261,7 @@ mod tests { #[tokio::test] async fn require_persisted_sandbox_rejects_missing_sandbox() { - let store = Arc::new( - Store::connect("sqlite::memory:?cache=shared") - .await - .unwrap(), - ); + let store = test_store().await; let err = require_persisted_sandbox(&store, "missing") .await @@ -1144,11 +1272,7 @@ mod tests { #[tokio::test] async fn require_persisted_sandbox_accepts_existing_sandbox() { - let store = Arc::new( - Store::connect("sqlite::memory:?cache=shared") - .await - .unwrap(), - ); + let store = test_store().await; store .put_message(&sandbox_record("sbx-1", "sandbox-one")) .await @@ -1164,7 +1288,10 @@ mod tests { #[test] fn claim_relay_unknown_channel() { let registry = SupervisorSessionRegistry::new(); - let err = registry.claim_relay("nonexistent").expect_err("should err"); + let principal = sandbox_principal("sbx-test"); + let err = registry + .claim_relay("nonexistent", Some(&principal)) + .expect_err("should err"); assert_eq!(err.code(), tonic::Code::NotFound); } @@ -1174,35 +1301,95 @@ mod tests { let (relay_tx, _relay_rx) = oneshot::channel(); registry.pending_relays.lock().unwrap().insert( "ch-1".to_string(), - PendingRelay { - sender: relay_tx, - sandbox_id: "sbx-test".to_string(), - created_at: Instant::now(), - }, + pending_relay("sbx-test", relay_tx, Instant::now()), ); - let result = registry.claim_relay("ch-1"); + let principal = sandbox_principal("sbx-test"); + let result = registry.claim_relay("ch-1", Some(&principal)); assert!(result.is_ok()); assert!(!registry.pending_relays.lock().unwrap().contains_key("ch-1")); } + #[test] + fn claim_relay_rejects_cross_sandbox_principal_without_consuming_channel() { + let registry = SupervisorSessionRegistry::new(); + let (relay_tx, _relay_rx) = oneshot::channel(); + registry.pending_relays.lock().unwrap().insert( + "ch-cross".to_string(), + pending_relay("sbx-owner", relay_tx, Instant::now()), + ); + + let attacker = sandbox_principal("sbx-attacker"); + let err = registry + 
.claim_relay("ch-cross", Some(&attacker)) + .expect_err("cross-sandbox relay claim must fail"); + assert_eq!(err.code(), tonic::Code::PermissionDenied); + assert!( + registry + .pending_relays + .lock() + .unwrap() + .contains_key("ch-cross"), + "failed cross-sandbox claim must not consume the channel" + ); + } + + #[test] + fn claim_relay_rejects_user_principal() { + let registry = SupervisorSessionRegistry::new(); + let (relay_tx, _relay_rx) = oneshot::channel(); + registry.pending_relays.lock().unwrap().insert( + "ch-user".to_string(), + pending_relay("sbx-owner", relay_tx, Instant::now()), + ); + + let err = registry + .claim_relay("ch-user", Some(&user_principal("alice"))) + .expect_err("users are not supervisor identities"); + assert_eq!(err.code(), tonic::Code::PermissionDenied); + } + + #[tokio::test] + async fn relay_open_failure_completes_pending_waiter() { + let registry = SupervisorSessionRegistry::new(); + let (relay_tx, relay_rx) = oneshot::channel(); + registry.pending_relays.lock().unwrap().insert( + "ch-fail".to_string(), + pending_relay("sbx-test", relay_tx, Instant::now()), + ); + + assert!(registry.fail_pending_relay("ch-fail", "target refused".to_string())); + assert!( + !registry + .pending_relays + .lock() + .unwrap() + .contains_key("ch-fail") + ); + + let result = relay_rx.await.expect("failure should wake waiter"); + let status = result.expect_err("waiter should receive status failure"); + assert_eq!(status.code(), tonic::Code::Unavailable); + assert_eq!(status.message(), "target refused"); + } + #[test] fn claim_relay_expired_returns_deadline_exceeded() { let registry = SupervisorSessionRegistry::new(); let (relay_tx, _relay_rx) = oneshot::channel(); registry.pending_relays.lock().unwrap().insert( "ch-old".to_string(), - PendingRelay { - sender: relay_tx, - sandbox_id: "sbx-test".to_string(), - created_at: Instant::now() + pending_relay( + "sbx-test", + relay_tx, + Instant::now() .checked_sub(Duration::from_secs(60)) - .expect("test 
instant subtraction underflow"), - }, + .expect("test duration should be before now"), + ), ); let err = registry - .claim_relay("ch-old") + .claim_relay("ch-old", Some(&sandbox_principal("sbx-test"))) .expect_err("expired entry must fail"); assert_eq!(err.code(), tonic::Code::DeadlineExceeded); // Entry must have been consumed regardless. @@ -1218,19 +1405,15 @@ mod tests { #[test] fn claim_relay_receiver_dropped_returns_internal() { let registry = SupervisorSessionRegistry::new(); - let (relay_tx, relay_rx) = oneshot::channel::(); + let (relay_tx, relay_rx) = oneshot::channel::>(); drop(relay_rx); // Gateway-side waiter has given up already. registry.pending_relays.lock().unwrap().insert( "ch-1".to_string(), - PendingRelay { - sender: relay_tx, - sandbox_id: "sbx-test".to_string(), - created_at: Instant::now(), - }, + pending_relay("sbx-test", relay_tx, Instant::now()), ); let err = registry - .claim_relay("ch-1") + .claim_relay("ch-1", Some(&sandbox_principal("sbx-test"))) .expect_err("should err when receiver is gone"); assert_eq!(err.code(), tonic::Code::Internal); } @@ -1238,18 +1421,19 @@ mod tests { #[tokio::test] async fn claim_relay_connects_both_ends() { let registry = SupervisorSessionRegistry::new(); - let (relay_tx, relay_rx) = oneshot::channel::(); + let (relay_tx, relay_rx) = oneshot::channel::>(); registry.pending_relays.lock().unwrap().insert( "ch-io".to_string(), - PendingRelay { - sender: relay_tx, - sandbox_id: "sbx-test".to_string(), - created_at: Instant::now(), - }, + pending_relay("sbx-test", relay_tx, Instant::now()), ); - let mut supervisor_side = registry.claim_relay("ch-io").expect("claim should succeed"); - let mut gateway_side = relay_rx.await.expect("gateway side should receive stream"); + let mut supervisor_side = registry + .claim_relay("ch-io", Some(&sandbox_principal("sbx-test"))) + .expect("claim should succeed"); + let mut gateway_side = relay_rx + .await + .expect("gateway side should receive result") + .expect("gateway side 
should receive stream"); // Supervisor side writes → gateway side reads. supervisor_side.write_all(b"hello").await.unwrap(); @@ -1272,13 +1456,13 @@ mod tests { let (relay_tx, _relay_rx) = oneshot::channel(); registry.pending_relays.lock().unwrap().insert( "ch-old".to_string(), - PendingRelay { - sender: relay_tx, - sandbox_id: "sbx-test".to_string(), - created_at: Instant::now() + pending_relay( + "sbx-test", + relay_tx, + Instant::now() .checked_sub(Duration::from_secs(60)) - .expect("test instant subtraction underflow"), - }, + .expect("test duration should be before now"), + ), ); registry.reap_expired_relays(); @@ -1297,11 +1481,7 @@ mod tests { let (relay_tx, _relay_rx) = oneshot::channel(); registry.pending_relays.lock().unwrap().insert( "ch-fresh".to_string(), - PendingRelay { - sender: relay_tx, - sandbox_id: "sbx-test".to_string(), - created_at: Instant::now(), - }, + pending_relay("sbx-test", relay_tx, Instant::now()), ); registry.reap_expired_relays(); diff --git a/crates/openshell-server/src/tls.rs b/crates/openshell-server/src/tls.rs index 95c18608f..1af1ce0cd 100644 --- a/crates/openshell-server/src/tls.rs +++ b/crates/openshell-server/src/tls.rs @@ -19,17 +19,19 @@ pub struct TlsAcceptor { } impl TlsAcceptor { - /// Create a new TLS acceptor from certificate, key, and client CA files. + /// Create a new TLS acceptor from certificate and key files. /// - /// When `allow_unauthenticated` is `false` (the default), the server - /// enforces mTLS — all clients must present a valid certificate signed - /// by the given CA. + /// When `client_ca_path` is `Some` and `require_client_auth` is `true`, + /// the TLS handshake rejects connections that do not present a valid + /// client certificate signed by the given CA. /// - /// When `allow_unauthenticated` is `true`, the TLS handshake succeeds - /// even without a client certificate. This is required when the server - /// sits behind a reverse proxy (e.g. 
Cloudflare Tunnel) that terminates - /// TLS and cannot forward client certificates. Application-layer - /// middleware must then enforce authentication (e.g. via a JWT header). + /// When `client_ca_path` is `Some` and `require_client_auth` is `false`, + /// client certificates are validated against the CA but not required. + /// Clients may connect without a certificate; presented certs from an + /// unknown CA are still rejected. + /// + /// When `client_ca_path` is `None`, the server does not request client + /// certificates at all (HTTPS-only). /// /// # Errors /// @@ -37,33 +39,40 @@ impl TlsAcceptor { pub fn from_files( cert_path: &Path, key_path: &Path, - client_ca_path: &Path, - allow_unauthenticated: bool, + client_ca_path: Option<&Path>, + require_client_auth: bool, ) -> Result { let certs = load_certs(cert_path)?; let key = load_key(key_path)?; - let ca_certs = load_certs(client_ca_path)?; - let mut root_store = rustls::RootCertStore::empty(); - for cert in ca_certs { - root_store - .add(cert) - .map_err(|e| Error::tls(format!("failed to add CA certificate: {e}")))?; - } - - let verifier_builder = WebPkiClientVerifier::builder(Arc::new(root_store)); - let verifier = if allow_unauthenticated { - verifier_builder.allow_unauthenticated() + let mut config = if let Some(ca_path) = client_ca_path { + let ca_certs = load_certs(ca_path)?; + let mut root_store = rustls::RootCertStore::empty(); + for cert in ca_certs { + root_store + .add(cert) + .map_err(|e| Error::tls(format!("failed to add CA certificate: {e}")))?; + } + + let verifier_builder = WebPkiClientVerifier::builder(Arc::new(root_store)); + let verifier = if require_client_auth { + verifier_builder + } else { + verifier_builder.allow_unauthenticated() + } + .build() + .map_err(|e| Error::tls(format!("failed to build client verifier: {e}")))?; + + ServerConfig::builder() + .with_client_cert_verifier(verifier) + .with_single_cert(certs, key) + .map_err(|e| Error::tls(format!("failed to create TLS 
config: {e}")))? } else { - verifier_builder - } - .build() - .map_err(|e| Error::tls(format!("failed to build client verifier: {e}")))?; - - let mut config = ServerConfig::builder() - .with_client_cert_verifier(verifier) - .with_single_cert(certs, key) - .map_err(|e| Error::tls(format!("failed to create TLS config: {e}")))?; + ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(certs, key) + .map_err(|e| Error::tls(format!("failed to create TLS config: {e}")))? + }; config .alpn_protocols diff --git a/crates/openshell-server/src/tracing_bus.rs b/crates/openshell-server/src/tracing_bus.rs index cf168e306..cc7b64ad3 100644 --- a/crates/openshell-server/src/tracing_bus.rs +++ b/crates/openshell-server/src/tracing_bus.rs @@ -5,7 +5,6 @@ use std::collections::{HashMap, VecDeque}; use std::sync::{Arc, Mutex}; -use std::time::{SystemTime, UNIX_EPOCH}; use openshell_core::proto::{SandboxLogLine, SandboxStreamEvent}; use openshell_ocsf::OCSF_TARGET; @@ -150,7 +149,7 @@ where let msg = visitor.message.unwrap_or_else(|| meta.name().to_string()); let level = display_level(meta.target(), &meta.level().to_string()); - let ts = current_time_ms().unwrap_or(0); + let ts = openshell_core::time::now_ms(); let log = SandboxLogLine { sandbox_id: sandbox_id.clone(), timestamp_ms: ts, @@ -193,11 +192,6 @@ impl tracing::field::Visit for LogVisitor { } } -fn current_time_ms() -> Option { - let now = SystemTime::now().duration_since(UNIX_EPOCH).ok()?; - i64::try_from(now.as_millis()).ok() -} - fn display_level(target: &str, level: &str) -> String { if target == OCSF_TARGET { "OCSF".to_string() diff --git a/crates/openshell-server/tests/auth_endpoint_integration.rs b/crates/openshell-server/tests/auth_endpoint_integration.rs index 7b16ee991..c1ea74b9b 100644 --- a/crates/openshell-server/tests/auth_endpoint_integration.rs +++ b/crates/openshell-server/tests/auth_endpoint_integration.rs @@ -11,6 +11,8 @@ //! a full `ServerState`. 
The test handler mirrors the production logic in //! `auth.rs` but uses a simple `SocketAddr` as state. +mod common; + use axum::{ Router, extract::{Query, State}, @@ -378,353 +380,7 @@ async fn auth_connect_falls_back_to_bind_address() { server.abort(); } -// --------------------------------------------------------------------------- -// Minimal OpenShell for test 7 (plaintext gRPC+HTTP) -// --------------------------------------------------------------------------- - -#[derive(Clone, Default)] -struct TestOpenShell; - -#[tonic::async_trait] -impl openshell_core::proto::open_shell_server::OpenShell for TestOpenShell { - async fn health( - &self, - _request: tonic::Request, - ) -> Result, tonic::Status> { - Ok(tonic::Response::new( - openshell_core::proto::HealthResponse { - status: openshell_core::proto::ServiceStatus::Healthy.into(), - version: "test".to_string(), - }, - )) - } - - async fn create_sandbox( - &self, - _: tonic::Request, - ) -> Result, tonic::Status> { - Ok(tonic::Response::new( - openshell_core::proto::SandboxResponse::default(), - )) - } - - async fn get_sandbox( - &self, - _: tonic::Request, - ) -> Result, tonic::Status> { - Ok(tonic::Response::new( - openshell_core::proto::SandboxResponse::default(), - )) - } - - async fn list_sandboxes( - &self, - _: tonic::Request, - ) -> Result, tonic::Status> { - Ok(tonic::Response::new( - openshell_core::proto::ListSandboxesResponse::default(), - )) - } - - async fn delete_sandbox( - &self, - _: tonic::Request, - ) -> Result, tonic::Status> { - Ok(tonic::Response::new( - openshell_core::proto::DeleteSandboxResponse { deleted: true }, - )) - } - - async fn get_sandbox_config( - &self, - _: tonic::Request, - ) -> Result, tonic::Status> - { - Ok(tonic::Response::new( - openshell_core::proto::GetSandboxConfigResponse::default(), - )) - } - - async fn get_gateway_config( - &self, - _: tonic::Request, - ) -> Result, tonic::Status> - { - Ok(tonic::Response::new( - 
openshell_core::proto::GetGatewayConfigResponse::default(), - )) - } - - async fn get_sandbox_provider_environment( - &self, - _: tonic::Request, - ) -> Result< - tonic::Response, - tonic::Status, - > { - Ok(tonic::Response::new( - openshell_core::proto::GetSandboxProviderEnvironmentResponse::default(), - )) - } - - async fn create_ssh_session( - &self, - _: tonic::Request, - ) -> Result, tonic::Status> - { - Ok(tonic::Response::new( - openshell_core::proto::CreateSshSessionResponse::default(), - )) - } - - async fn revoke_ssh_session( - &self, - _: tonic::Request, - ) -> Result, tonic::Status> - { - Ok(tonic::Response::new( - openshell_core::proto::RevokeSshSessionResponse::default(), - )) - } - - async fn create_provider( - &self, - _: tonic::Request, - ) -> Result, tonic::Status> { - Err(tonic::Status::unimplemented("test")) - } - - async fn get_provider( - &self, - _: tonic::Request, - ) -> Result, tonic::Status> { - Err(tonic::Status::unimplemented("test")) - } - - async fn list_providers( - &self, - _: tonic::Request, - ) -> Result, tonic::Status> { - Err(tonic::Status::unimplemented("test")) - } - - async fn list_provider_profiles( - &self, - _: tonic::Request, - ) -> Result, tonic::Status> - { - Err(tonic::Status::unimplemented("test")) - } - - async fn get_provider_profile( - &self, - _: tonic::Request, - ) -> Result, tonic::Status> - { - Err(tonic::Status::unimplemented("test")) - } - - async fn import_provider_profiles( - &self, - _: tonic::Request, - ) -> Result, tonic::Status> - { - Err(tonic::Status::unimplemented("test")) - } - - async fn lint_provider_profiles( - &self, - _: tonic::Request, - ) -> Result, tonic::Status> - { - Err(tonic::Status::unimplemented("test")) - } - - async fn delete_provider_profile( - &self, - _: tonic::Request, - ) -> Result, tonic::Status> - { - Err(tonic::Status::unimplemented("test")) - } - - async fn update_provider( - &self, - _: tonic::Request, - ) -> Result, tonic::Status> { - 
Err(tonic::Status::unimplemented("test")) - } - - async fn delete_provider( - &self, - _: tonic::Request, - ) -> Result, tonic::Status> { - Err(tonic::Status::unimplemented("test")) - } - - type WatchSandboxStream = tokio_stream::wrappers::ReceiverStream< - Result, - >; - type ExecSandboxStream = tokio_stream::wrappers::ReceiverStream< - Result, - >; - type ConnectSupervisorStream = tokio_stream::wrappers::ReceiverStream< - Result, - >; - - async fn watch_sandbox( - &self, - _: tonic::Request, - ) -> Result, tonic::Status> { - let (_tx, rx) = tokio::sync::mpsc::channel(1); - Ok(tonic::Response::new( - tokio_stream::wrappers::ReceiverStream::new(rx), - )) - } - - async fn exec_sandbox( - &self, - _: tonic::Request, - ) -> Result, tonic::Status> { - let (_tx, rx) = tokio::sync::mpsc::channel(1); - Ok(tonic::Response::new( - tokio_stream::wrappers::ReceiverStream::new(rx), - )) - } - - async fn update_config( - &self, - _: tonic::Request, - ) -> Result, tonic::Status> { - Err(tonic::Status::unimplemented("test")) - } - - async fn get_sandbox_policy_status( - &self, - _: tonic::Request, - ) -> Result, tonic::Status> - { - Err(tonic::Status::unimplemented("test")) - } - - async fn list_sandbox_policies( - &self, - _: tonic::Request, - ) -> Result, tonic::Status> - { - Err(tonic::Status::unimplemented("test")) - } - - async fn report_policy_status( - &self, - _: tonic::Request, - ) -> Result, tonic::Status> - { - Err(tonic::Status::unimplemented("test")) - } - - async fn get_sandbox_logs( - &self, - _: tonic::Request, - ) -> Result, tonic::Status> { - Err(tonic::Status::unimplemented("test")) - } - - async fn push_sandbox_logs( - &self, - _: tonic::Request>, - ) -> Result, tonic::Status> - { - Err(tonic::Status::unimplemented("test")) - } - - async fn submit_policy_analysis( - &self, - _request: tonic::Request, - ) -> Result, tonic::Status> - { - Err(tonic::Status::unimplemented("not implemented in test")) - } - - async fn get_draft_policy( - &self, - _request: 
tonic::Request, - ) -> Result, tonic::Status> { - Err(tonic::Status::unimplemented("not implemented in test")) - } - - async fn approve_draft_chunk( - &self, - _request: tonic::Request, - ) -> Result, tonic::Status> - { - Err(tonic::Status::unimplemented("not implemented in test")) - } - - async fn reject_draft_chunk( - &self, - _request: tonic::Request, - ) -> Result, tonic::Status> - { - Err(tonic::Status::unimplemented("not implemented in test")) - } - - async fn approve_all_draft_chunks( - &self, - _request: tonic::Request, - ) -> Result, tonic::Status> - { - Err(tonic::Status::unimplemented("not implemented in test")) - } - - async fn edit_draft_chunk( - &self, - _request: tonic::Request, - ) -> Result, tonic::Status> { - Err(tonic::Status::unimplemented("not implemented in test")) - } - - async fn undo_draft_chunk( - &self, - _request: tonic::Request, - ) -> Result, tonic::Status> { - Err(tonic::Status::unimplemented("not implemented in test")) - } - - async fn clear_draft_chunks( - &self, - _request: tonic::Request, - ) -> Result, tonic::Status> - { - Err(tonic::Status::unimplemented("not implemented in test")) - } - - async fn get_draft_history( - &self, - _request: tonic::Request, - ) -> Result, tonic::Status> - { - Err(tonic::Status::unimplemented("not implemented in test")) - } - - async fn connect_supervisor( - &self, - _request: tonic::Request>, - ) -> Result, tonic::Status> { - Err(tonic::Status::unimplemented("not implemented in test")) - } - - type RelayStreamStream = tokio_stream::wrappers::ReceiverStream< - Result, - >; - - async fn relay_stream( - &self, - _request: tonic::Request>, - ) -> Result, tonic::Status> { - Err(tonic::Status::unimplemented("not implemented in test")) - } -} +use common::TestOpenShell; /// Test 7: Plaintext server (no TLS) accepts both gRPC and HTTP. 
/// diff --git a/crates/openshell-server/tests/common/mod.rs b/crates/openshell-server/tests/common/mod.rs new file mode 100644 index 000000000..178ea99ab --- /dev/null +++ b/crates/openshell-server/tests/common/mod.rs @@ -0,0 +1,612 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Shared helpers for openshell-server integration tests. +//! +//! Include with `mod common;` at the top of each integration test file. +//! Items may not be used by every test file; the blanket `#[allow]` prevents +//! spurious dead-code warnings. + +#![allow(dead_code)] + +use hyper_util::{ + rt::{TokioExecutor, TokioIo}, + server::conn::auto::Builder, +}; +use openshell_core::proto::{ + CreateProviderRequest, CreateSandboxRequest, CreateSshSessionRequest, CreateSshSessionResponse, + DeleteProviderRequest, DeleteProviderResponse, DeleteSandboxRequest, DeleteSandboxResponse, + ExecSandboxEvent, ExecSandboxInput, ExecSandboxRequest, GatewayMessage, + GetGatewayConfigRequest, GetGatewayConfigResponse, GetProviderRequest, GetSandboxConfigRequest, + GetSandboxConfigResponse, GetSandboxProviderEnvironmentRequest, + GetSandboxProviderEnvironmentResponse, GetSandboxRequest, HealthRequest, HealthResponse, + IssueSandboxTokenRequest, IssueSandboxTokenResponse, ListProvidersRequest, + ListProvidersResponse, ListSandboxesRequest, ListSandboxesResponse, ProviderResponse, + RefreshSandboxTokenRequest, RefreshSandboxTokenResponse, RelayFrame, RevokeSshSessionRequest, + RevokeSshSessionResponse, SandboxResponse, SandboxStreamEvent, ServiceStatus, + SupervisorMessage, TcpForwardFrame, UpdateProviderRequest, WatchSandboxRequest, + open_shell_server::{OpenShell, OpenShellServer}, +}; +use openshell_server::{MultiplexedService, TlsAcceptor, health_router}; +use rcgen::{CertificateParams, IsCa, KeyPair}; +use std::io::Write; +use std::net::SocketAddr; +use tempfile::tempdir; +use 
tokio::net::TcpListener; +use tokio::sync::mpsc; +use tokio_stream::wrappers::ReceiverStream; +use tonic::{Response, Status}; + +// --------------------------------------------------------------------------- +// Minimal OpenShell stub: all methods return defaults or Unimplemented. +// --------------------------------------------------------------------------- + +#[derive(Clone, Default)] +pub struct TestOpenShell; + +#[tonic::async_trait] +impl OpenShell for TestOpenShell { + async fn health( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(HealthResponse { + status: ServiceStatus::Healthy.into(), + version: "test".to_string(), + })) + } + + async fn create_sandbox( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(SandboxResponse::default())) + } + + async fn get_sandbox( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(SandboxResponse::default())) + } + + async fn list_sandboxes( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(ListSandboxesResponse::default())) + } + + async fn list_sandbox_providers( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new( + openshell_core::proto::ListSandboxProvidersResponse::default(), + )) + } + + async fn attach_sandbox_provider( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new( + openshell_core::proto::AttachSandboxProviderResponse::default(), + )) + } + + async fn detach_sandbox_provider( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new( + openshell_core::proto::DetachSandboxProviderResponse::default(), + )) + } + + async fn delete_sandbox( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(DeleteSandboxResponse { deleted: true })) + } + + async fn get_sandbox_config( + &self, + _request: tonic::Request, + ) -> Result, Status> { + 
Ok(Response::new(GetSandboxConfigResponse::default())) + } + + async fn get_gateway_config( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(GetGatewayConfigResponse::default())) + } + + async fn get_sandbox_provider_environment( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new( + GetSandboxProviderEnvironmentResponse::default(), + )) + } + + async fn create_ssh_session( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(CreateSshSessionResponse::default())) + } + + async fn expose_service( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new( + openshell_core::proto::ServiceEndpointResponse::default(), + )) + } + + async fn get_service( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn list_services( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn delete_service( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn revoke_ssh_session( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(RevokeSshSessionResponse::default())) + } + + async fn create_provider( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "create_provider not implemented in test", + )) + } + + async fn get_provider( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "get_provider not implemented in test", + )) + } + + async fn list_providers( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "list_providers not implemented in test", + )) + } + + async fn list_provider_profiles( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + + async fn get_provider_profile( + &self, + 
_request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + + async fn import_provider_profiles( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + + async fn lint_provider_profiles( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + + async fn delete_provider_profile( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + + async fn update_provider( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "update_provider not implemented in test", + )) + } + + async fn get_provider_refresh_status( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn configure_provider_refresh( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn rotate_provider_credential( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn delete_provider_refresh( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn delete_provider( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "delete_provider not implemented in test", + )) + } + + type WatchSandboxStream = ReceiverStream>; + type ExecSandboxStream = ReceiverStream>; + type ConnectSupervisorStream = ReceiverStream>; + + async fn watch_sandbox( + &self, + _request: tonic::Request, + ) -> Result, Status> { + let (_tx, rx) = mpsc::channel(1); + Ok(Response::new(ReceiverStream::new(rx))) + } + + async fn exec_sandbox( + &self, + _request: tonic::Request, + ) -> Result, Status> { + let (_tx, rx) = mpsc::channel(1); + Ok(Response::new(ReceiverStream::new(rx))) + } + + type 
ExecSandboxInteractiveStream = ReceiverStream>; + + async fn exec_sandbox_interactive( + &self, + _request: tonic::Request>, + ) -> Result, Status> { + let (_tx, rx) = mpsc::channel(1); + Ok(Response::new(ReceiverStream::new(rx))) + } + + async fn update_config( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + + async fn get_sandbox_policy_status( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + + async fn list_sandbox_policies( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + + async fn report_policy_status( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + + async fn get_sandbox_logs( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + + async fn push_sandbox_logs( + &self, + _request: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + + async fn submit_policy_analysis( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + + async fn get_draft_policy( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + + async fn approve_draft_chunk( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + + async fn reject_draft_chunk( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + + async fn approve_all_draft_chunks( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + + async fn edit_draft_chunk( + &self, + 
_request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + + async fn undo_draft_chunk( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + + async fn clear_draft_chunks( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + + async fn get_draft_history( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + + async fn issue_sandbox_token( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + + async fn refresh_sandbox_token( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + + async fn connect_supervisor( + &self, + _request: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + + type RelayStreamStream = ReceiverStream>; + + async fn relay_stream( + &self, + _request: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + + type ForwardTcpStream = + std::pin::Pin> + Send>>; + + async fn forward_tcp( + &self, + _request: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } +} + +// --------------------------------------------------------------------------- +// TLS / PKI helpers (used by TLS integration tests) +// --------------------------------------------------------------------------- + +/// Initialise the rustls crypto provider (idempotent). +pub fn install_rustls_provider() { + let _ = rustls::crypto::ring::default_provider().install_default(); +} + +/// PKI bundle: CA cert, server cert+key, client cert+key (all PEM). 
+#[allow(clippy::struct_field_names)] +pub struct PkiBundle { + pub ca_cert_pem: Vec<u8>, + pub server_cert_pem: Vec<u8>, + pub server_key_pem: Vec<u8>, + pub client_cert_pem: Vec<u8>, + pub client_key_pem: Vec<u8>, +} + +/// Generate a full PKI: CA → server cert (for `localhost`) + client cert. +/// Returns a `TempDir` that must be kept alive while the paths are in use. +pub fn generate_pki() -> (tempfile::TempDir, PkiBundle) { + // Generate CA + let mut ca_params = + CertificateParams::new(Vec::<String>::new()).expect("failed to create CA params"); + ca_params.is_ca = IsCa::Ca(rcgen::BasicConstraints::Unconstrained); + ca_params + .distinguished_name + .push(rcgen::DnType::CommonName, "test-ca"); + let ca_key = KeyPair::generate().expect("failed to generate CA key"); + let ca_cert = ca_params + .self_signed(&ca_key) + .expect("failed to sign CA cert"); + + // Generate server cert signed by CA + let server_params = CertificateParams::new(vec!["localhost".to_string()]) + .expect("failed to create server params"); + let server_key = KeyPair::generate().expect("failed to generate server key"); + let server_cert = server_params + .signed_by(&server_key, &ca_cert, &ca_key) + .expect("failed to sign server cert"); + + // Generate client cert signed by CA + let mut client_params = + CertificateParams::new(Vec::<String>::new()).expect("failed to create client params"); + client_params + .distinguished_name + .push(rcgen::DnType::CommonName, "test-client"); + let client_key = KeyPair::generate().expect("failed to generate client key"); + let client_cert = client_params + .signed_by(&client_key, &ca_cert, &ca_key) + .expect("failed to sign client cert"); + + let dir = tempdir().expect("failed to create tempdir"); + let write_file = |name: &str, data: &[u8]| { + let path = dir.path().join(name); + std::fs::File::create(&path) + .and_then(|mut f| f.write_all(data)) + .expect("failed to write file"); + }; + + write_file("ca.pem", ca_cert.pem().as_bytes()); + write_file("server-cert.pem",
server_cert.pem().as_bytes()); + write_file("server-key.pem", server_key.serialize_pem().as_bytes()); + write_file("client-cert.pem", client_cert.pem().as_bytes()); + write_file("client-key.pem", client_key.serialize_pem().as_bytes()); + + let bundle = PkiBundle { + ca_cert_pem: ca_cert.pem().into_bytes(), + server_cert_pem: server_cert.pem().into_bytes(), + server_key_pem: server_key.serialize_pem().into_bytes(), + client_cert_pem: client_cert.pem().into_bytes(), + client_key_pem: client_key.serialize_pem().into_bytes(), + }; + + (dir, bundle) +} + +/// Start a TLS-wrapped test server using the given `TlsAcceptor`. +/// Returns the bound address and a task handle (abort to stop). +pub async fn start_test_server( + tls_acceptor: TlsAcceptor, +) -> (SocketAddr, tokio::task::JoinHandle<()>) { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let grpc_service = OpenShellServer::new(TestOpenShell); + let http_service = health_router(); + let service = MultiplexedService::new(grpc_service, http_service); + + let handle = tokio::spawn(async move { + loop { + let Ok((stream, _)) = listener.accept().await else { + continue; + }; + let svc = service.clone(); + let tls = tls_acceptor.clone(); + tokio::spawn(async move { + let Ok(tls_stream) = tls.inner().accept(stream).await else { + return; + }; + let _ = Builder::new(TokioExecutor::new()) + .serve_connection(TokioIo::new(tls_stream), svc) + .await; + }); + } + }); + + (addr, handle) +} + +/// Rogue PKI bundle: client cert + key not signed by the server's CA. +pub struct RoguePkiBundle { + pub client_cert_pem: String, + pub client_key_pem: String, +} + +/// Generate a rogue CA and a client certificate signed by that CA. +/// +/// Used to verify that the server rejects mTLS connections from clients whose +/// certificate chain does not trace back to the trusted CA. 
+pub fn generate_rogue_pki() -> RoguePkiBundle { + let mut rogue_ca_params = + CertificateParams::new(Vec::<String>::new()).expect("failed to create rogue CA params"); + rogue_ca_params.is_ca = IsCa::Ca(rcgen::BasicConstraints::Unconstrained); + rogue_ca_params + .distinguished_name + .push(rcgen::DnType::CommonName, "rogue-ca"); + let rogue_ca_key = KeyPair::generate().expect("failed to generate rogue CA key"); + let rogue_ca_cert = rogue_ca_params + .self_signed(&rogue_ca_key) + .expect("failed to sign rogue CA cert"); + + let mut rogue_client_params = + CertificateParams::new(Vec::<String>::new()).expect("failed to create rogue client params"); + rogue_client_params + .distinguished_name + .push(rcgen::DnType::CommonName, "rogue-client"); + let rogue_client_key = KeyPair::generate().expect("failed to generate rogue client key"); + let rogue_client_cert = rogue_client_params + .signed_by(&rogue_client_key, &rogue_ca_cert, &rogue_ca_key) + .expect("failed to sign rogue client cert"); + + RoguePkiBundle { + client_cert_pem: rogue_client_cert.pem(), + client_key_pem: rogue_client_key.serialize_pem(), + } +} diff --git a/crates/openshell-server/tests/edge_tunnel_auth.rs b/crates/openshell-server/tests/edge_tunnel_auth.rs index 39df0819f..c49953b23 100644 --- a/crates/openshell-server/tests/edge_tunnel_auth.rs +++ b/crates/openshell-server/tests/edge_tunnel_auth.rs @@ -12,458 +12,40 @@ //! //! Test matrix: //! //! -//! | `allow_unauthenticated` | client cert | bearer auth header | expected | -//! |-----------------------|-------------|--------------------|----------| -//! | false | valid | — | OK | -//! | false | none | — | rejected | -//! | true | valid | — | OK | -//! | true | none | present | OK (*) | -//! | true | none | absent | OK (**) | +//! | `client_ca` | client cert | bearer header | expected | +//! |-------------|-------------|---------------|---------------------------| +//! | Some | valid | — | OK (cert validated) | +//! | Some | none | — | OK (cert optional) | +//!
| Some | none | present | OK (bearer auth) | +//! | Some | rogue CA | — | rejected (bad cert) | +//! | None | none | — | OK (HTTPS-only) | //! -//! (*) Simulates the edge tunnel path: no client cert but a JWT header. -//! (**) TLS handshake succeeds, but in production the auth middleware (not yet -//! implemented) would reject. This test proves the TLS layer alone does -//! not block unauthenticated connections when the flag is set. +//! Client certificates are always optional when a CA is configured. They are +//! validated when present (rogue-CA certs are rejected) but never required. +//! Authentication is handled at the application layer (OIDC bearer tokens). + +mod common; use bytes::Bytes; +use common::{ + PkiBundle, generate_pki, generate_rogue_pki, install_rustls_provider, start_test_server, +}; use http_body_util::Empty; use hyper::{Request, StatusCode}; use hyper_rustls::HttpsConnectorBuilder; -use hyper_util::{ - client::legacy::Client, - rt::{TokioExecutor, TokioIo}, - server::conn::auto::Builder, -}; -use openshell_core::proto::{ - CreateProviderRequest, CreateSandboxRequest, CreateSshSessionRequest, CreateSshSessionResponse, - DeleteProviderRequest, DeleteProviderResponse, DeleteSandboxRequest, DeleteSandboxResponse, - ExecSandboxEvent, ExecSandboxRequest, GatewayMessage, GetGatewayConfigRequest, - GetGatewayConfigResponse, GetProviderRequest, GetSandboxConfigRequest, - GetSandboxConfigResponse, GetSandboxProviderEnvironmentRequest, - GetSandboxProviderEnvironmentResponse, GetSandboxRequest, HealthRequest, HealthResponse, - ListProvidersRequest, ListProvidersResponse, ListSandboxesRequest, ListSandboxesResponse, - ProviderResponse, RevokeSshSessionRequest, RevokeSshSessionResponse, SandboxResponse, - SandboxStreamEvent, ServiceStatus, SupervisorMessage, UpdateProviderRequest, - WatchSandboxRequest, - open_shell_client::OpenShellClient, - open_shell_server::{OpenShell, OpenShellServer}, -}; -use openshell_server::{MultiplexedService, TlsAcceptor, 
health_router}; -use rcgen::{CertificateParams, IsCa, KeyPair}; +use hyper_util::{client::legacy::Client, rt::TokioExecutor}; +use openshell_core::proto::{HealthRequest, ServiceStatus, open_shell_client::OpenShellClient}; +use openshell_server::TlsAcceptor; use rustls::RootCertStore; use rustls::pki_types::CertificateDer; use rustls_pemfile::certs; -use std::io::Write; -use tempfile::tempdir; -use tokio::net::TcpListener; -use tokio::sync::mpsc; -use tokio_stream::wrappers::ReceiverStream; +use tonic::Status; use tonic::transport::{Channel, ClientTlsConfig, Endpoint}; -use tonic::{Response, Status}; - -// --------------------------------------------------------------------------- -// Helpers -// --------------------------------------------------------------------------- - -fn install_rustls_provider() { - let _ = rustls::crypto::ring::default_provider().install_default(); -} - -/// Minimal `OpenShell` implementation for testing. -#[derive(Clone, Default)] -struct TestOpenShell; - -#[tonic::async_trait] -impl OpenShell for TestOpenShell { - async fn health( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Ok(Response::new(HealthResponse { - status: ServiceStatus::Healthy.into(), - version: "test".to_string(), - })) - } - - async fn create_sandbox( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Ok(Response::new(SandboxResponse::default())) - } - - async fn get_sandbox( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Ok(Response::new(SandboxResponse::default())) - } - - async fn list_sandboxes( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Ok(Response::new(ListSandboxesResponse::default())) - } - - async fn delete_sandbox( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Ok(Response::new(DeleteSandboxResponse { deleted: true })) - } - - async fn get_sandbox_config( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Ok(Response::new(GetSandboxConfigResponse::default())) - } - 
- async fn get_gateway_config( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Ok(Response::new(GetGatewayConfigResponse::default())) - } - - async fn get_sandbox_provider_environment( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Ok(Response::new( - GetSandboxProviderEnvironmentResponse::default(), - )) - } - - async fn create_ssh_session( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Ok(Response::new(CreateSshSessionResponse::default())) - } - - async fn revoke_ssh_session( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Ok(Response::new(RevokeSshSessionResponse::default())) - } - - async fn create_provider( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn get_provider( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn list_providers( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn list_provider_profiles( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn get_provider_profile( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn import_provider_profiles( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn lint_provider_profiles( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn delete_provider_profile( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn update_provider( - &self, - _request: tonic::Request, - ) -> Result, Status> { - 
Err(Status::unimplemented("not implemented in test")) - } - - async fn delete_provider( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - type WatchSandboxStream = ReceiverStream>; - type ExecSandboxStream = ReceiverStream>; - type ConnectSupervisorStream = ReceiverStream>; - - async fn watch_sandbox( - &self, - _request: tonic::Request, - ) -> Result, Status> { - let (_tx, rx) = mpsc::channel(1); - Ok(Response::new(ReceiverStream::new(rx))) - } - async fn exec_sandbox( - &self, - _request: tonic::Request, - ) -> Result, Status> { - let (_tx, rx) = mpsc::channel(1); - Ok(Response::new(ReceiverStream::new(rx))) - } - - async fn update_config( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn get_sandbox_policy_status( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn list_sandbox_policies( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn report_policy_status( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn get_sandbox_logs( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn push_sandbox_logs( - &self, - _request: tonic::Request>, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn submit_policy_analysis( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn get_draft_policy( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn approve_draft_chunk( - &self, - _request: 
tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn reject_draft_chunk( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn approve_all_draft_chunks( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn edit_draft_chunk( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn undo_draft_chunk( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn clear_draft_chunks( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn get_draft_history( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn connect_supervisor( - &self, - _request: tonic::Request>, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - type RelayStreamStream = ReceiverStream>; - - async fn relay_stream( - &self, - _request: tonic::Request>, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } -} - -// --------------------------------------------------------------------------- -// PKI generation // --------------------------------------------------------------------------- - -#[allow(dead_code, clippy::struct_field_names)] -struct PkiBundle { - ca_cert_pem: Vec, - server_cert_pem: Vec, - server_key_pem: Vec, - client_cert_pem: Vec, - client_key_pem: Vec, -} - -fn generate_pki() -> (tempfile::TempDir, PkiBundle) { - let mut ca_params = - CertificateParams::new(Vec::::new()).expect("failed to create CA params"); - ca_params.is_ca = IsCa::Ca(rcgen::BasicConstraints::Unconstrained); - ca_params - 
.distinguished_name - .push(rcgen::DnType::CommonName, "test-ca"); - let ca_key = KeyPair::generate().expect("failed to generate CA key"); - let ca_cert = ca_params - .self_signed(&ca_key) - .expect("failed to sign CA cert"); - - let server_params = CertificateParams::new(vec!["localhost".to_string()]) - .expect("failed to create server params"); - let server_key = KeyPair::generate().expect("failed to generate server key"); - let server_cert = server_params - .signed_by(&server_key, &ca_cert, &ca_key) - .expect("failed to sign server cert"); - - let mut client_params = - CertificateParams::new(Vec::<String>::new()).expect("failed to create client params"); - client_params - .distinguished_name - .push(rcgen::DnType::CommonName, "test-client"); - let client_key = KeyPair::generate().expect("failed to generate client key"); - let client_cert = client_params - .signed_by(&client_key, &ca_cert, &ca_key) - .expect("failed to sign client cert"); - - let dir = tempdir().expect("failed to create tempdir"); - let write_file = |name: &str, data: &[u8]| { - let path = dir.path().join(name); - std::fs::File::create(&path) - .and_then(|mut f| f.write_all(data)) - .expect("failed to write file"); - }; - - write_file("ca.pem", ca_cert.pem().as_bytes()); - write_file("server-cert.pem", server_cert.pem().as_bytes()); - write_file("server-key.pem", server_key.serialize_pem().as_bytes()); - write_file("client-cert.pem", client_cert.pem().as_bytes()); - write_file("client-key.pem", client_key.serialize_pem().as_bytes()); - - let bundle = PkiBundle { - ca_cert_pem: ca_cert.pem().into_bytes(), - server_cert_pem: server_cert.pem().into_bytes(), - server_key_pem: server_key.serialize_pem().into_bytes(), - client_cert_pem: client_cert.pem().into_bytes(), - client_key_pem: client_key.serialize_pem().into_bytes(), - }; - - (dir, bundle) -} - -// --------------------------------------------------------------------------- -// Server + client helpers +// Client helpers //
--------------------------------------------------------------------------- -async fn start_test_server( - tls_acceptor: TlsAcceptor, -) -> (std::net::SocketAddr, tokio::task::JoinHandle<()>) { - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = listener.local_addr().unwrap(); - - let grpc_service = OpenShellServer::new(TestOpenShell); - let http_service = health_router(); - let service = MultiplexedService::new(grpc_service, http_service); - - let handle = tokio::spawn(async move { - loop { - let Ok((stream, _)) = listener.accept().await else { - continue; - }; - let svc = service.clone(); - let tls = tls_acceptor.clone(); - tokio::spawn(async move { - let Ok(tls_stream) = tls.inner().accept(stream).await else { - return; - }; - let _ = Builder::new(TokioExecutor::new()) - .serve_connection(TokioIo::new(tls_stream), svc) - .await; - }); - } - }); - - (addr, handle) -} - /// Build a gRPC client with mTLS (CA + client cert). async fn grpc_client_mtls( addr: std::net::SocketAddr, @@ -608,17 +190,17 @@ fn https_client_no_cert( // Tests // =========================================================================== -/// Baseline: with `allow_unauthenticated=false` (default), mTLS connections work. +/// Valid client cert is accepted when a CA is configured. #[tokio::test] -async fn baseline_mtls_works_with_mandatory_client_certs() { +async fn mtls_valid_client_cert_accepted() { install_rustls_provider(); let (temp, pki) = generate_pki(); let tls_acceptor = TlsAcceptor::from_files( &temp.path().join("server-cert.pem"), &temp.path().join("server-key.pem"), - &temp.path().join("ca.pem"), - false, // mandatory mTLS + Some(temp.path().join("ca.pem").as_path()), + false, ) .unwrap(); @@ -648,102 +230,18 @@ async fn baseline_mtls_works_with_mandatory_client_certs() { server.abort(); } -/// Baseline: with `allow_unauthenticated=false`, no-client-cert connections are -/// rejected at the TLS layer. 
-#[tokio::test] -async fn baseline_no_cert_rejected_with_mandatory_mtls() { - install_rustls_provider(); - let (temp, pki) = generate_pki(); - - let tls_acceptor = TlsAcceptor::from_files( - &temp.path().join("server-cert.pem"), - &temp.path().join("server-key.pem"), - &temp.path().join("ca.pem"), - false, // mandatory mTLS - ) - .unwrap(); - - let (addr, server) = start_test_server(tls_acceptor).await; - - let ca_cert = tonic::transport::Certificate::from_pem(pki.ca_cert_pem.clone()); - let tls = ClientTlsConfig::new() - .ca_certificate(ca_cert) - .domain_name("localhost"); - let endpoint = Endpoint::from_shared(format!("https://localhost:{}", addr.port())) - .expect("invalid endpoint") - .tls_config(tls) - .expect("failed to set tls"); - - let result = endpoint.connect().await; - if let Ok(channel) = result { - let mut client = OpenShellClient::new(channel); - let rpc_result = client.health(HealthRequest {}).await; - assert!( - rpc_result.is_err(), - "expected RPC to fail without client cert when mTLS is mandatory" - ); - } - // If connect() itself failed, that's also correct — TLS handshake rejected. - - server.abort(); -} - -/// With `allow_unauthenticated=true`, mTLS connections still work (dual-auth). 
-#[tokio::test] -async fn dual_auth_mtls_still_accepted() { - install_rustls_provider(); - let (temp, pki) = generate_pki(); - - let tls_acceptor = TlsAcceptor::from_files( - &temp.path().join("server-cert.pem"), - &temp.path().join("server-key.pem"), - &temp.path().join("ca.pem"), - true, // allow unauthenticated (tunnel mode) - ) - .unwrap(); - - let (addr, server) = start_test_server(tls_acceptor).await; - - // gRPC with mTLS should still work - let mut grpc = grpc_client_mtls( - addr, - pki.ca_cert_pem.clone(), - pki.client_cert_pem.clone(), - pki.client_key_pem.clone(), - ) - .await; - let resp = grpc.health(HealthRequest {}).await.unwrap(); - assert_eq!(resp.get_ref().status, ServiceStatus::Healthy as i32); - - // HTTP with mTLS should still work - let client = https_client_mtls(&pki); - let req = Request::builder() - .method("GET") - .uri(format!("https://localhost:{}/healthz", addr.port())) - .body(Empty::<Bytes>::new()) - .unwrap(); - let resp = client.request(req).await.unwrap(); - assert_eq!(resp.status(), StatusCode::OK); - - server.abort(); -} - -/// With `allow_unauthenticated=true`, no-client-cert connections pass the TLS -/// handshake. This simulates Cloudflare Tunnel re-originating a connection. -/// -/// The gRPC health check succeeds because there is no auth middleware yet — -/// this proves the TLS layer is no longer the gate. When auth middleware is -/// added, the test should be updated to expect 401 without a valid JWT.
#[tokio::test] -async fn tunnel_mode_no_cert_passes_tls_handshake() { +async fn no_client_cert_accepted_with_ca_configured() { install_rustls_provider(); let (temp, pki) = generate_pki(); let tls_acceptor = TlsAcceptor::from_files( &temp.path().join("server-cert.pem"), &temp.path().join("server-key.pem"), - &temp.path().join("ca.pem"), - true, // allow unauthenticated (tunnel mode) + Some(temp.path().join("ca.pem").as_path()), + false, ) .unwrap(); @@ -755,7 +253,7 @@ async fn tunnel_mode_no_cert_passes_tls_handshake() { assert_eq!( resp.get_ref().status, ServiceStatus::Healthy as i32, - "gRPC health check should succeed without client cert in tunnel mode" + "gRPC health check should succeed without client cert" ); // HTTP without client cert @@ -769,28 +267,24 @@ async fn tunnel_mode_no_cert_passes_tls_handshake() { assert_eq!( resp.status(), StatusCode::OK, - "HTTP health check should succeed without client cert in tunnel mode" + "HTTP health check should succeed without client cert" ); server.abort(); } -/// Simulate the steady-state Cloudflare tunnel flow: no client cert, but the -/// `cf-authorization` header carries a token. At the TLS level this must -/// succeed; the header is passed through to the gRPC handler. -/// -/// Note: We use a dummy token value here. When real JWT verification middleware -/// is added, this test should use a properly-signed test JWT. +/// Bearer auth header passes through to the gRPC handler when no client +/// cert is presented. 
#[tokio::test] -async fn tunnel_mode_cf_authorization_header_reaches_server() { +async fn bearer_header_reaches_server_without_client_cert() { install_rustls_provider(); let (temp, pki) = generate_pki(); let tls_acceptor = TlsAcceptor::from_files( &temp.path().join("server-cert.pem"), &temp.path().join("server-key.pem"), - &temp.path().join("ca.pem"), - true, + Some(temp.path().join("ca.pem").as_path()), + false, ) .unwrap(); @@ -804,56 +298,35 @@ async fn tunnel_mode_cf_authorization_header_reaches_server() { assert_eq!( resp.get_ref().status, ServiceStatus::Healthy as i32, - "gRPC with cf-authorization header should succeed in tunnel mode" + "gRPC with bearer header should succeed without client cert" ); server.abort(); } -/// With `allow_unauthenticated=true`, a client cert from a rogue CA is still -/// rejected by the TLS layer — the verifier still validates presented certs. +/// A client cert from a rogue CA is rejected at the TLS layer even though +/// client certs are optional — presented certs are still validated. 
#[tokio::test] -async fn tunnel_mode_rogue_cert_still_rejected() { +async fn rogue_cert_rejected() { install_rustls_provider(); let (temp, pki) = generate_pki(); let tls_acceptor = TlsAcceptor::from_files( &temp.path().join("server-cert.pem"), &temp.path().join("server-key.pem"), - &temp.path().join("ca.pem"), - true, + Some(temp.path().join("ca.pem").as_path()), + false, ) .unwrap(); let (addr, server) = start_test_server(tls_acceptor).await; // Generate a rogue CA + client cert - let mut rogue_ca_params = - CertificateParams::new(Vec::<String>::new()).expect("failed to create rogue CA params"); - rogue_ca_params.is_ca = IsCa::Ca(rcgen::BasicConstraints::Unconstrained); - rogue_ca_params - .distinguished_name - .push(rcgen::DnType::CommonName, "rogue-ca"); - let rogue_ca_key = KeyPair::generate().expect("failed to generate rogue CA key"); - let rogue_ca_cert = rogue_ca_params - .self_signed(&rogue_ca_key) - .expect("failed to sign rogue CA cert"); - - let mut rogue_client_params = - CertificateParams::new(Vec::<String>::new()).expect("failed to create rogue client params"); - rogue_client_params - .distinguished_name - .push(rcgen::DnType::CommonName, "rogue-client"); - let rogue_client_key = KeyPair::generate().expect("failed to generate rogue client key"); - let rogue_client_cert = rogue_client_params - .signed_by(&rogue_client_key, &rogue_ca_cert, &rogue_ca_key) - .expect("failed to sign rogue client cert"); + let rogue = generate_rogue_pki(); let ca_cert = tonic::transport::Certificate::from_pem(pki.ca_cert_pem.clone()); - let identity = tonic::transport::Identity::from_pem( - rogue_client_cert.pem(), - rogue_client_key.serialize_pem(), - ); + let identity = + tonic::transport::Identity::from_pem(rogue.client_cert_pem, rogue.client_key_pem); let tls = ClientTlsConfig::new() .ca_certificate(ca_cert) .identity(identity) @@ -869,10 +342,53 @@ async fn tunnel_mode_rogue_cert_still_rejected() { let rpc_result = client.health(HealthRequest {}).await; assert!( rpc_result.is_err(), -
"expected RPC to fail with rogue client cert even in tunnel mode" + "expected RPC to fail with rogue client cert" ); } // If connect() itself failed, that's also correct. server.abort(); } + +/// HTTPS-only mode: no client CA configured, so the server never requests +/// client certificates. Clients connect with server-only TLS. +#[tokio::test] +async fn https_only_no_client_cert_required() { + install_rustls_provider(); + let (temp, pki) = generate_pki(); + + let tls_acceptor = TlsAcceptor::from_files( + &temp.path().join("server-cert.pem"), + &temp.path().join("server-key.pem"), + None, + false, + ) + .unwrap(); + + let (addr, server) = start_test_server(tls_acceptor).await; + + // gRPC without client cert — should succeed (no client certs requested) + let mut grpc = grpc_client_no_cert(addr, pki.ca_cert_pem.clone()).await; + let resp = grpc.health(HealthRequest {}).await.unwrap(); + assert_eq!( + resp.get_ref().status, + ServiceStatus::Healthy as i32, + "gRPC health check should succeed in HTTPS-only mode" + ); + + // HTTP without client cert + let client = https_client_no_cert(&pki.ca_cert_pem); + let req = Request::builder() + .method("GET") + .uri(format!("https://localhost:{}/healthz", addr.port())) + .body(Empty::::new()) + .unwrap(); + let resp = client.request(req).await.unwrap(); + assert_eq!( + resp.status(), + StatusCode::OK, + "HTTP health check should succeed in HTTPS-only mode" + ); + + server.abort(); +} diff --git a/crates/openshell-server/tests/multiplex_integration.rs b/crates/openshell-server/tests/multiplex_integration.rs index 49c6f9c92..9ca1ee3ee 100644 --- a/crates/openshell-server/tests/multiplex_integration.rs +++ b/crates/openshell-server/tests/multiplex_integration.rs @@ -1,7 +1,10 @@ // SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
// SPDX-License-Identifier: Apache-2.0 +mod common; + use bytes::Bytes; +use common::TestOpenShell; use http_body_util::Empty; use hyper::{Request, StatusCode}; use hyper_util::{ @@ -9,326 +12,11 @@ use hyper_util::{ server::conn::auto::Builder, }; use openshell_core::proto::{ - CreateProviderRequest, CreateSandboxRequest, CreateSshSessionRequest, CreateSshSessionResponse, - DeleteProviderRequest, DeleteProviderResponse, DeleteSandboxRequest, DeleteSandboxResponse, - ExecSandboxEvent, ExecSandboxRequest, GatewayMessage, GetGatewayConfigRequest, - GetGatewayConfigResponse, GetProviderRequest, GetSandboxConfigRequest, - GetSandboxConfigResponse, GetSandboxProviderEnvironmentRequest, - GetSandboxProviderEnvironmentResponse, GetSandboxRequest, HealthRequest, HealthResponse, - ListProvidersRequest, ListProvidersResponse, ListSandboxesRequest, ListSandboxesResponse, - ProviderResponse, RevokeSshSessionRequest, RevokeSshSessionResponse, SandboxResponse, - SandboxStreamEvent, ServiceStatus, SupervisorMessage, UpdateProviderRequest, - WatchSandboxRequest, - open_shell_client::OpenShellClient, - open_shell_server::{OpenShell, OpenShellServer}, + HealthRequest, ServiceStatus, open_shell_client::OpenShellClient, + open_shell_server::OpenShellServer, }; use openshell_server::{MultiplexedService, health_router}; use tokio::net::TcpListener; -use tokio::sync::mpsc; -use tokio_stream::wrappers::ReceiverStream; -use tonic::{Response, Status}; - -#[derive(Clone, Default)] -struct TestOpenShell; - -#[tonic::async_trait] -impl OpenShell for TestOpenShell { - async fn health( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Ok(Response::new(HealthResponse { - status: ServiceStatus::Healthy.into(), - version: "test".to_string(), - })) - } - - async fn create_sandbox( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Ok(Response::new(SandboxResponse::default())) - } - - async fn get_sandbox( - &self, - _request: tonic::Request, - ) -> Result, Status> { - 
Ok(Response::new(SandboxResponse::default())) - } - - async fn list_sandboxes( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Ok(Response::new(ListSandboxesResponse::default())) - } - - async fn delete_sandbox( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Ok(Response::new(DeleteSandboxResponse { deleted: true })) - } - - async fn get_sandbox_config( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Ok(Response::new(GetSandboxConfigResponse::default())) - } - - async fn get_gateway_config( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Ok(Response::new(GetGatewayConfigResponse::default())) - } - - async fn get_sandbox_provider_environment( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Ok(Response::new( - GetSandboxProviderEnvironmentResponse::default(), - )) - } - - async fn create_ssh_session( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Ok(Response::new(CreateSshSessionResponse::default())) - } - - async fn revoke_ssh_session( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Ok(Response::new(RevokeSshSessionResponse::default())) - } - - async fn create_provider( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented( - "create_provider not implemented in test", - )) - } - - async fn get_provider( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented( - "get_provider not implemented in test", - )) - } - - async fn list_providers( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented( - "list_providers not implemented in test", - )) - } - - async fn list_provider_profiles( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn get_provider_profile( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - 
} - - async fn import_provider_profiles( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn lint_provider_profiles( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn delete_provider_profile( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn update_provider( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented( - "update_provider not implemented in test", - )) - } - - async fn delete_provider( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented( - "delete_provider not implemented in test", - )) - } - - type WatchSandboxStream = ReceiverStream>; - type ExecSandboxStream = ReceiverStream>; - type ConnectSupervisorStream = ReceiverStream>; - - async fn watch_sandbox( - &self, - _request: tonic::Request, - ) -> Result, Status> { - let (_tx, rx) = mpsc::channel(1); - Ok(Response::new(ReceiverStream::new(rx))) - } - - async fn exec_sandbox( - &self, - _request: tonic::Request, - ) -> Result, Status> { - let (_tx, rx) = mpsc::channel(1); - Ok(Response::new(ReceiverStream::new(rx))) - } - - async fn update_config( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn get_sandbox_policy_status( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn list_sandbox_policies( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn report_policy_status( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn get_sandbox_logs( - &self, - 
_request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn push_sandbox_logs( - &self, - _request: tonic::Request>, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn submit_policy_analysis( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn get_draft_policy( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn approve_draft_chunk( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn reject_draft_chunk( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn approve_all_draft_chunks( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn edit_draft_chunk( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn undo_draft_chunk( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn clear_draft_chunks( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn get_draft_history( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn connect_supervisor( - &self, - _request: tonic::Request>, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - type RelayStreamStream = ReceiverStream>; - - async fn relay_stream( - &self, - _request: tonic::Request>, - ) -> Result, Status> { - Err(Status::unimplemented("not 
implemented in test")) - } -} #[tokio::test] async fn serves_grpc_and_http_on_same_port() { diff --git a/crates/openshell-server/tests/multiplex_tls_integration.rs b/crates/openshell-server/tests/multiplex_tls_integration.rs index d6a244e49..31ece9ed6 100644 --- a/crates/openshell-server/tests/multiplex_tls_integration.rs +++ b/crates/openshell-server/tests/multiplex_tls_integration.rs @@ -1,447 +1,33 @@ // SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 +mod common; + use bytes::Bytes; +use common::{ + PkiBundle, generate_pki, generate_rogue_pki, install_rustls_provider, start_test_server, +}; use http_body_util::Empty; -use hyper::{Request, StatusCode}; +use hyper::Request; +use hyper::StatusCode; use hyper_rustls::HttpsConnectorBuilder; -use hyper_util::{ - client::legacy::Client, - rt::{TokioExecutor, TokioIo}, - server::conn::auto::Builder, -}; -use openshell_core::proto::{ - CreateProviderRequest, CreateSandboxRequest, CreateSshSessionRequest, CreateSshSessionResponse, - DeleteProviderRequest, DeleteProviderResponse, DeleteSandboxRequest, DeleteSandboxResponse, - ExecSandboxEvent, ExecSandboxRequest, GatewayMessage, GetGatewayConfigRequest, - GetGatewayConfigResponse, GetProviderRequest, GetSandboxConfigRequest, - GetSandboxConfigResponse, GetSandboxProviderEnvironmentRequest, - GetSandboxProviderEnvironmentResponse, GetSandboxRequest, HealthRequest, HealthResponse, - ListProvidersRequest, ListProvidersResponse, ListSandboxesRequest, ListSandboxesResponse, - ProviderResponse, RevokeSshSessionRequest, RevokeSshSessionResponse, SandboxResponse, - SandboxStreamEvent, ServiceStatus, SupervisorMessage, UpdateProviderRequest, - WatchSandboxRequest, - open_shell_client::OpenShellClient, - open_shell_server::{OpenShell, OpenShellServer}, -}; -use openshell_server::{MultiplexedService, TlsAcceptor, health_router}; -use rcgen::{CertificateParams, IsCa, KeyPair}; +use 
hyper_util::{client::legacy::Client, rt::TokioExecutor}; +use openshell_core::proto::{HealthRequest, ServiceStatus, open_shell_client::OpenShellClient}; use rustls::RootCertStore; use rustls::pki_types::CertificateDer; use rustls_pemfile::certs; -use std::io::Write; -use tempfile::tempdir; -use tokio::net::TcpListener; -use tokio::sync::mpsc; -use tokio_stream::wrappers::ReceiverStream; use tonic::transport::{Channel, ClientTlsConfig, Endpoint}; -use tonic::{Response, Status}; - -fn install_rustls_provider() { - let _ = rustls::crypto::ring::default_provider().install_default(); -} - -#[derive(Clone, Default)] -struct TestOpenShell; - -#[tonic::async_trait] -impl OpenShell for TestOpenShell { - async fn health( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Ok(Response::new(HealthResponse { - status: ServiceStatus::Healthy.into(), - version: "test".to_string(), - })) - } - - async fn create_sandbox( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Ok(Response::new(SandboxResponse::default())) - } - - async fn get_sandbox( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Ok(Response::new(SandboxResponse::default())) - } - - async fn list_sandboxes( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Ok(Response::new(ListSandboxesResponse::default())) - } - - async fn delete_sandbox( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Ok(Response::new(DeleteSandboxResponse { deleted: true })) - } - - async fn get_sandbox_config( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Ok(Response::new(GetSandboxConfigResponse::default())) - } - - async fn get_gateway_config( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Ok(Response::new(GetGatewayConfigResponse::default())) - } - - async fn get_sandbox_provider_environment( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Ok(Response::new( - GetSandboxProviderEnvironmentResponse::default(), - )) - } - - 
async fn create_ssh_session( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Ok(Response::new(CreateSshSessionResponse::default())) - } - - async fn revoke_ssh_session( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Ok(Response::new(RevokeSshSessionResponse::default())) - } - - async fn create_provider( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented( - "create_provider not implemented in test", - )) - } - - async fn get_provider( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented( - "get_provider not implemented in test", - )) - } - - async fn list_providers( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented( - "list_providers not implemented in test", - )) - } - - async fn list_provider_profiles( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn get_provider_profile( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn import_provider_profiles( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn lint_provider_profiles( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn delete_provider_profile( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn update_provider( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented( - "update_provider not implemented in test", - )) - } - - async fn delete_provider( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented( - "delete_provider not implemented in test", - )) - } - - type WatchSandboxStream 
= ReceiverStream>; - type ExecSandboxStream = ReceiverStream>; - type ConnectSupervisorStream = ReceiverStream>; - - async fn watch_sandbox( - &self, - _request: tonic::Request, - ) -> Result, Status> { - let (_tx, rx) = mpsc::channel(1); - Ok(Response::new(ReceiverStream::new(rx))) - } - - async fn exec_sandbox( - &self, - _request: tonic::Request, - ) -> Result, Status> { - let (_tx, rx) = mpsc::channel(1); - Ok(Response::new(ReceiverStream::new(rx))) - } - - async fn update_config( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn get_sandbox_policy_status( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn list_sandbox_policies( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn report_policy_status( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn get_sandbox_logs( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn push_sandbox_logs( - &self, - _request: tonic::Request>, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn submit_policy_analysis( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn get_draft_policy( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn approve_draft_chunk( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn reject_draft_chunk( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not 
implemented in test")) - } - async fn approve_all_draft_chunks( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn edit_draft_chunk( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn undo_draft_chunk( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn clear_draft_chunks( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn get_draft_history( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn connect_supervisor( - &self, - _request: tonic::Request>, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - type RelayStreamStream = ReceiverStream>; - - async fn relay_stream( - &self, - _request: tonic::Request>, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) +fn build_tls_root(cert_pem: &[u8]) -> RootCertStore { + let mut roots = RootCertStore::empty(); + let mut cursor = std::io::Cursor::new(cert_pem); + let parsed = certs(&mut cursor) + .collect::>, _>>() + .expect("failed to parse cert pem"); + for cert in parsed { + roots.add(cert).expect("failed to add cert"); } -} - -/// PKI bundle: CA cert, server cert+key, client cert+key. -#[allow(dead_code, clippy::struct_field_names)] -struct PkiBundle { - ca_cert_pem: Vec, - server_cert_pem: Vec, - server_key_pem: Vec, - client_cert_pem: Vec, - client_key_pem: Vec, -} - -/// Generate a full PKI: CA -> server cert (for localhost) + client cert. 
-fn generate_pki() -> (tempfile::TempDir, PkiBundle) { - // Generate CA - let mut ca_params = - CertificateParams::new(Vec::<String>::new()).expect("failed to create CA params"); - ca_params.is_ca = IsCa::Ca(rcgen::BasicConstraints::Unconstrained); - ca_params - .distinguished_name - .push(rcgen::DnType::CommonName, "test-ca"); - let ca_key = KeyPair::generate().expect("failed to generate CA key"); - let ca_cert = ca_params - .self_signed(&ca_key) - .expect("failed to sign CA cert"); - - // Generate server cert signed by CA - let server_params = CertificateParams::new(vec!["localhost".to_string()]) - .expect("failed to create server params"); - let server_key = KeyPair::generate().expect("failed to generate server key"); - let server_cert = server_params - .signed_by(&server_key, &ca_cert, &ca_key) - .expect("failed to sign server cert"); - - // Generate client cert signed by CA - let mut client_params = - CertificateParams::new(Vec::<String>::new()).expect("failed to create client params"); - client_params - .distinguished_name - .push(rcgen::DnType::CommonName, "test-client"); - let client_key = KeyPair::generate().expect("failed to generate client key"); - let client_cert = client_params - .signed_by(&client_key, &ca_cert, &ca_key) - .expect("failed to sign client cert"); - - let dir = tempdir().expect("failed to create tempdir"); - let write = |name: &str, data: &[u8]| { - let path = dir.path().join(name); - std::fs::File::create(&path) - .and_then(|mut f| f.write_all(data)) - .expect("failed to write file"); - }; - - write("ca.pem", ca_cert.pem().as_bytes()); - write("server-cert.pem", server_cert.pem().as_bytes()); - write("server-key.pem", server_key.serialize_pem().as_bytes()); - write("client-cert.pem", client_cert.pem().as_bytes()); - write("client-key.pem", client_key.serialize_pem().as_bytes()); - - let bundle = PkiBundle { - ca_cert_pem: ca_cert.pem().into_bytes(), - server_cert_pem: server_cert.pem().into_bytes(), - server_key_pem:
server_key.serialize_pem().into_bytes(), - client_cert_pem: client_cert.pem().into_bytes(), - client_key_pem: client_key.serialize_pem().into_bytes(), - }; - - (dir, bundle) -} - -/// Start a test server with the given TLS acceptor, returning its address and -/// a handle that aborts the server on drop. -async fn start_test_server( - tls_acceptor: TlsAcceptor, -) -> (std::net::SocketAddr, tokio::task::JoinHandle<()>) { - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = listener.local_addr().unwrap(); - - let grpc_service = OpenShellServer::new(TestOpenShell); - let http_service = health_router(); - let service = MultiplexedService::new(grpc_service, http_service); - - let handle = tokio::spawn(async move { - loop { - let Ok((stream, _)) = listener.accept().await else { - continue; - }; - let svc = service.clone(); - let tls = tls_acceptor.clone(); - tokio::spawn(async move { - let Ok(tls_stream) = tls.inner().accept(stream).await else { - return; - }; - let _ = Builder::new(TokioExecutor::new()) - .serve_connection(TokioIo::new(tls_stream), svc) - .await; - }); - } - }); - - (addr, handle) + roots } /// Build a gRPC client with mTLS (CA + client cert). @@ -465,18 +51,6 @@ async fn grpc_client_mtls( OpenShellClient::new(channel) } -fn build_tls_root(cert_pem: &[u8]) -> RootCertStore { - let mut roots = RootCertStore::empty(); - let mut cursor = std::io::Cursor::new(cert_pem); - let parsed = certs(&mut cursor) - .collect::<Result<Vec<CertificateDer<'static>>, _>>() - .expect("failed to parse cert pem"); - for cert in parsed { - roots.add(cert).expect("failed to add cert"); - } - roots -} - /// Build an HTTPS client with mTLS (CA trust + client cert/key).
fn https_client_mtls( pki: &PkiBundle, @@ -516,10 +90,10 @@ async fn serves_grpc_and_http_over_tls_on_same_port() { install_rustls_provider(); let (temp, pki) = generate_pki(); - let tls_acceptor = TlsAcceptor::from_files( + let tls_acceptor = openshell_server::TlsAcceptor::from_files( &temp.path().join("server-cert.pem"), &temp.path().join("server-key.pem"), - &temp.path().join("ca.pem"), + Some(temp.path().join("ca.pem").as_path()), false, ) .unwrap(); @@ -555,10 +129,10 @@ async fn mtls_valid_client_cert_accepted() { install_rustls_provider(); let (temp, pki) = generate_pki(); - let tls_acceptor = TlsAcceptor::from_files( + let tls_acceptor = openshell_server::TlsAcceptor::from_files( &temp.path().join("server-cert.pem"), &temp.path().join("server-key.pem"), - &temp.path().join("ca.pem"), + Some(temp.path().join("ca.pem").as_path()), false, ) .unwrap(); @@ -579,21 +153,56 @@ async fn mtls_valid_client_cert_accepted() { } #[tokio::test] -async fn mtls_no_client_cert_rejected() { +async fn no_client_cert_accepted_with_ca() { install_rustls_provider(); let (temp, pki) = generate_pki(); - let tls_acceptor = TlsAcceptor::from_files( + let tls_acceptor = openshell_server::TlsAcceptor::from_files( &temp.path().join("server-cert.pem"), &temp.path().join("server-key.pem"), - &temp.path().join("ca.pem"), + Some(temp.path().join("ca.pem").as_path()), false, ) .unwrap(); let (addr, server) = start_test_server(tls_acceptor).await; - // Connect with CA trust but no client cert -- should be rejected. + // Connect with CA trust but no client cert — should succeed (certs are optional). 
+ let ca_cert = tonic::transport::Certificate::from_pem(pki.ca_cert_pem.clone()); + let tls = ClientTlsConfig::new() + .ca_certificate(ca_cert) + .domain_name("localhost"); + let endpoint = Endpoint::from_shared(format!("https://localhost:{}", addr.port())) + .expect("invalid endpoint") + .tls_config(tls) + .expect("failed to set tls"); + + let channel = endpoint + .connect() + .await + .expect("should connect without client cert"); + let mut client = OpenShellClient::new(channel); + let response = client.health(HealthRequest {}).await.unwrap(); + assert_eq!(response.get_ref().status, ServiceStatus::Healthy as i32); + + server.abort(); +} + +#[tokio::test] +async fn no_client_cert_rejected_when_required() { + install_rustls_provider(); + let (temp, pki) = generate_pki(); + + let tls_acceptor = openshell_server::TlsAcceptor::from_files( + &temp.path().join("server-cert.pem"), + &temp.path().join("server-key.pem"), + Some(temp.path().join("ca.pem").as_path()), + true, + ) + .unwrap(); + + let (addr, server) = start_test_server(tls_acceptor).await; + let ca_cert = tonic::transport::Certificate::from_pem(pki.ca_cert_pem.clone()); let tls = ClientTlsConfig::new() .ca_certificate(ca_cert) @@ -604,14 +213,12 @@ async fn mtls_no_client_cert_rejected() { .expect("failed to set tls"); let result = endpoint.connect().await; - // Connection should fail at the TLS handshake level or shortly after. - // The exact error depends on timing -- it may fail on connect or on first RPC. 
if let Ok(channel) = result { let mut client = OpenShellClient::new(channel); let rpc_result = client.health(HealthRequest {}).await; assert!( rpc_result.is_err(), - "expected RPC to fail without client cert" + "expected RPC to fail without client cert when mTLS is required" ); } @@ -623,10 +230,10 @@ async fn mtls_wrong_ca_client_cert_rejected() { install_rustls_provider(); let (temp, pki) = generate_pki(); - let tls_acceptor = TlsAcceptor::from_files( + let tls_acceptor = openshell_server::TlsAcceptor::from_files( &temp.path().join("server-cert.pem"), &temp.path().join("server-key.pem"), - &temp.path().join("ca.pem"), + Some(temp.path().join("ca.pem").as_path()), false, ) .unwrap(); @@ -634,33 +241,12 @@ async fn mtls_wrong_ca_client_cert_rejected() { let (addr, server) = start_test_server(tls_acceptor).await; // Generate a rogue CA + client cert not signed by the server's CA. - let mut rogue_ca_params = - CertificateParams::new(Vec::::new()).expect("failed to create rogue CA params"); - rogue_ca_params.is_ca = IsCa::Ca(rcgen::BasicConstraints::Unconstrained); - rogue_ca_params - .distinguished_name - .push(rcgen::DnType::CommonName, "rogue-ca"); - let rogue_ca_key = KeyPair::generate().expect("failed to generate rogue CA key"); - let rogue_ca_cert = rogue_ca_params - .self_signed(&rogue_ca_key) - .expect("failed to sign rogue CA cert"); - - let mut rogue_client_params = - CertificateParams::new(Vec::::new()).expect("failed to create rogue client params"); - rogue_client_params - .distinguished_name - .push(rcgen::DnType::CommonName, "rogue-client"); - let rogue_client_key = KeyPair::generate().expect("failed to generate rogue client key"); - let rogue_client_cert = rogue_client_params - .signed_by(&rogue_client_key, &rogue_ca_cert, &rogue_ca_key) - .expect("failed to sign rogue client cert"); + let rogue = generate_rogue_pki(); // Connect with rogue client cert -- server should reject it. 
let ca_cert = tonic::transport::Certificate::from_pem(pki.ca_cert_pem.clone()); - let identity = tonic::transport::Identity::from_pem( - rogue_client_cert.pem(), - rogue_client_key.serialize_pem(), - ); + let identity = + tonic::transport::Identity::from_pem(rogue.client_cert_pem, rogue.client_key_pem); let tls = ClientTlsConfig::new() .ca_certificate(ca_cert) .identity(identity) diff --git a/crates/openshell-server/tests/supervisor_relay_integration.rs b/crates/openshell-server/tests/supervisor_relay_integration.rs index f8519cdc7..7f6bab7e9 100644 --- a/crates/openshell-server/tests/supervisor_relay_integration.rs +++ b/crates/openshell-server/tests/supervisor_relay_integration.rs @@ -23,7 +23,7 @@ use hyper_util::{ server::conn::auto::Builder, }; use openshell_core::proto::{ - GatewayMessage, RelayFrame, RelayInit, SupervisorMessage, + GatewayMessage, RelayFrame, RelayInit, SupervisorMessage, TcpForwardFrame, open_shell_client::OpenShellClient, open_shell_server::{OpenShell, OpenShellServer}, }; @@ -87,6 +87,24 @@ impl OpenShell for RelayGateway { Err(Status::unimplemented("unused")) } + type ForwardTcpStream = + std::pin::Pin> + Send>>; + async fn forward_tcp( + &self, + _: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + type ExecSandboxInteractiveStream = + ReceiverStream>; + async fn exec_sandbox_interactive( + &self, + _: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + async fn health( &self, _: tonic::Request, @@ -111,6 +129,24 @@ impl OpenShell for RelayGateway { ) -> Result, Status> { Err(Status::unimplemented("unused")) } + async fn list_sandbox_providers( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + async fn attach_sandbox_provider( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + async fn detach_sandbox_provider( + &self, + _: tonic::Request, + ) -> Result, Status> { + 
Err(Status::unimplemented("unused")) + } async fn delete_sandbox( &self, _: tonic::Request, @@ -142,6 +178,33 @@ impl OpenShell for RelayGateway { ) -> Result, Status> { Err(Status::unimplemented("unused")) } + async fn expose_service( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + async fn get_service( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn list_services( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn delete_service( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + async fn revoke_ssh_session( &self, _: tonic::Request, @@ -207,6 +270,33 @@ impl OpenShell for RelayGateway { ) -> Result, Status> { Err(Status::unimplemented("unused")) } + async fn get_provider_refresh_status( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn configure_provider_refresh( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn rotate_provider_credential( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn delete_provider_refresh( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } async fn delete_provider( &self, @@ -304,6 +394,18 @@ impl OpenShell for RelayGateway { ) -> Result, Status> { Err(Status::unimplemented("unused")) } + async fn issue_sandbox_token( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + async fn refresh_sandbox_token( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } } // --------------------------------------------------------------------------- @@ -421,7 +523,7 @@ async fn relay_round_trips_bytes() { 
tokio::spawn(run_echo_supervisor(channel, channel_id)); - let relay = relay_rx.await.expect("relay duplex"); + let relay = relay_rx.await.expect("relay result").expect("relay duplex"); let (mut read_half, mut write_half) = tokio::io::split(relay); write_half.write_all(b"hello relay").await.expect("write"); @@ -446,7 +548,7 @@ async fn relay_closes_cleanly_when_gateway_drops() { let supervisor = tokio::spawn(run_echo_supervisor(channel, channel_id)); - let relay = relay_rx.await.expect("relay duplex"); + let relay = relay_rx.await.expect("relay result").expect("relay duplex"); drop(relay); // The supervisor's inbound stream should terminate shortly after the @@ -491,7 +593,7 @@ async fn relay_sees_eof_when_supervisor_closes() { }) }; - let relay = relay_rx.await.expect("relay duplex"); + let relay = relay_rx.await.expect("relay result").expect("relay duplex"); let (mut read_half, _write_half) = tokio::io::split(relay); let mut buf = [0u8; 16]; let n = tokio::time::timeout(Duration::from_secs(5), read_half.read(&mut buf)) @@ -537,8 +639,8 @@ async fn concurrent_relays_multiplex_independently() { tokio::spawn(run_echo_supervisor(channel.clone(), id_a)); tokio::spawn(run_echo_supervisor(channel, id_b)); - let relay_a = rx_a.await.expect("relay a"); - let relay_b = rx_b.await.expect("relay b"); + let relay_a = rx_a.await.expect("relay a result").expect("relay a"); + let relay_b = rx_b.await.expect("relay b result").expect("relay b"); let (mut ra, mut wa) = tokio::io::split(relay_a); let (mut rb, mut wb) = tokio::io::split(relay_b); diff --git a/crates/openshell-server/tests/ws_tunnel_integration.rs b/crates/openshell-server/tests/ws_tunnel_integration.rs index 173a7225d..ee253e9dd 100644 --- a/crates/openshell-server/tests/ws_tunnel_integration.rs +++ b/crates/openshell-server/tests/ws_tunnel_integration.rs @@ -23,6 +23,8 @@ //! The WS tunnel handler is kept standalone so it stays isolated from the full //! 
`ServerState` dependency while still matching the production bridge logic. +mod common; + use axum::{ Router, extract::{State, WebSocketUpgrade, ws::Message}, @@ -30,6 +32,7 @@ use axum::{ routing::get, }; use bytes::Bytes; +use common::TestOpenShell; use futures_util::{SinkExt, StreamExt}; use http_body_util::Empty; use hyper::{Request, StatusCode}; @@ -38,323 +41,14 @@ use hyper_util::{ server::conn::auto::Builder, }; use openshell_core::proto::{ - CreateProviderRequest, CreateSandboxRequest, CreateSshSessionRequest, CreateSshSessionResponse, - DeleteProviderRequest, DeleteProviderResponse, DeleteSandboxRequest, DeleteSandboxResponse, - ExecSandboxEvent, ExecSandboxRequest, GatewayMessage, GetGatewayConfigRequest, - GetGatewayConfigResponse, GetProviderRequest, GetSandboxConfigRequest, - GetSandboxConfigResponse, GetSandboxProviderEnvironmentRequest, - GetSandboxProviderEnvironmentResponse, GetSandboxRequest, HealthRequest, HealthResponse, - ListProvidersRequest, ListProvidersResponse, ListSandboxesRequest, ListSandboxesResponse, - ProviderResponse, RevokeSshSessionRequest, RevokeSshSessionResponse, SandboxResponse, - SandboxStreamEvent, ServiceStatus, SupervisorMessage, UpdateProviderRequest, - WatchSandboxRequest, - open_shell_client::OpenShellClient, - open_shell_server::{OpenShell, OpenShellServer}, + HealthRequest, ServiceStatus, open_shell_client::OpenShellClient, + open_shell_server::OpenShellServer, }; use openshell_server::{MultiplexedService, health_router}; use std::net::SocketAddr; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream}; -use tokio::sync::mpsc; -use tokio_stream::wrappers::ReceiverStream; use tokio_tungstenite::tungstenite; -use tonic::{Response, Status}; - -// --------------------------------------------------------------------------- -// Minimal OpenShell implementation (same as other integration tests) -// 
--------------------------------------------------------------------------- - -#[derive(Clone, Default)] -struct TestOpenShell; - -#[tonic::async_trait] -impl OpenShell for TestOpenShell { - async fn health( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Ok(Response::new(HealthResponse { - status: ServiceStatus::Healthy.into(), - version: "test".to_string(), - })) - } - - async fn create_sandbox( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Ok(Response::new(SandboxResponse::default())) - } - - async fn get_sandbox( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Ok(Response::new(SandboxResponse::default())) - } - - async fn list_sandboxes( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Ok(Response::new(ListSandboxesResponse::default())) - } - - async fn delete_sandbox( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Ok(Response::new(DeleteSandboxResponse { deleted: true })) - } - - async fn get_sandbox_config( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Ok(Response::new(GetSandboxConfigResponse::default())) - } - - async fn get_gateway_config( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Ok(Response::new(GetGatewayConfigResponse::default())) - } - - async fn get_sandbox_provider_environment( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Ok(Response::new( - GetSandboxProviderEnvironmentResponse::default(), - )) - } - - async fn create_ssh_session( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Ok(Response::new(CreateSshSessionResponse::default())) - } - - async fn revoke_ssh_session( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Ok(Response::new(RevokeSshSessionResponse::default())) - } - - async fn create_provider( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn get_provider( - &self, - _request: 
tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn list_providers( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn list_provider_profiles( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn get_provider_profile( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn import_provider_profiles( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn lint_provider_profiles( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn delete_provider_profile( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn update_provider( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn delete_provider( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - type WatchSandboxStream = ReceiverStream>; - type ExecSandboxStream = ReceiverStream>; - type ConnectSupervisorStream = ReceiverStream>; - - async fn watch_sandbox( - &self, - _request: tonic::Request, - ) -> Result, Status> { - let (_tx, rx) = mpsc::channel(1); - Ok(Response::new(ReceiverStream::new(rx))) - } - - async fn exec_sandbox( - &self, - _request: tonic::Request, - ) -> Result, Status> { - let (_tx, rx) = mpsc::channel(1); - Ok(Response::new(ReceiverStream::new(rx))) - } - - async fn update_config( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - 
async fn get_sandbox_policy_status( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn list_sandbox_policies( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn report_policy_status( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn get_sandbox_logs( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn push_sandbox_logs( - &self, - _request: tonic::Request>, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn submit_policy_analysis( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn get_draft_policy( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn approve_draft_chunk( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn reject_draft_chunk( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn approve_all_draft_chunks( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn edit_draft_chunk( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn undo_draft_chunk( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn clear_draft_chunks( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not 
implemented in test")) - } - - async fn get_draft_history( - &self, - _request: tonic::Request, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - async fn connect_supervisor( - &self, - _request: tonic::Request>, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } - - type RelayStreamStream = ReceiverStream>; - - async fn relay_stream( - &self, - _request: tonic::Request>, - ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) - } -} // --------------------------------------------------------------------------- // Test WS tunnel handler (standalone, no ServerState dependency) diff --git a/crates/openshell-tui/src/app.rs b/crates/openshell-tui/src/app.rs index 1cab7127c..ba817bcf8 100644 --- a/crates/openshell-tui/src/app.rs +++ b/crates/openshell-tui/src/app.rs @@ -5,9 +5,11 @@ use std::collections::HashMap; use std::time::{Duration, Instant}; use crossterm::event::{KeyCode, KeyEvent, KeyModifiers}; +use openshell_core::auth::EdgeAuthInterceptor; use openshell_core::proto::open_shell_client::OpenShellClient; use openshell_core::proto::setting_value; use openshell_core::settings::{self, SettingValueKind}; +use tonic::service::interceptor::InterceptedService; use tonic::transport::Channel; // --------------------------------------------------------------------------- @@ -413,7 +415,7 @@ pub struct App { // Active gateway connection pub gateway_name: String, pub endpoint: String, - pub client: OpenShellClient, + pub client: OpenShellClient>, pub status_text: String, // Gateway list @@ -580,7 +582,7 @@ pub fn format_labels(labels: &HashMap) -> String { impl App { #[allow(clippy::large_types_passed_by_value)] // Theme is Copy; one-shot ctor pub fn new( - client: OpenShellClient, + client: OpenShellClient>, gateway_name: String, endpoint: String, theme: crate::theme::Theme, diff --git a/crates/openshell-tui/src/lib.rs b/crates/openshell-tui/src/lib.rs index 
8571ebbe1..1969715ce 100644 --- a/crates/openshell-tui/src/lib.rs +++ b/crates/openshell-tui/src/lib.rs @@ -18,6 +18,7 @@ use crossterm::terminal::{ EnterAlternateScreen, LeaveAlternateScreen, disable_raw_mode, enable_raw_mode, }; use miette::{IntoDiagnostic, Result}; +use openshell_core::auth::EdgeAuthInterceptor; use openshell_core::metadata::{ObjectId, ObjectLabels, ObjectName}; use openshell_core::proto::open_shell_client::OpenShellClient; use ratatui::Terminal; @@ -41,6 +42,7 @@ pub use theme::ThemeMode; /// background, `Dark`/`Light` forces a specific palette. pub async fn run( channel: Channel, + interceptor: EdgeAuthInterceptor, gateway_name: &str, endpoint: &str, theme_mode: ThemeMode, @@ -50,7 +52,7 @@ pub async fn run( // after our own enable_raw_mode() would conflict. let detected_theme = theme::detect(theme_mode); - let client = OpenShellClient::new(channel); + let client = OpenShellClient::with_interceptor(channel, interceptor); let mut app = App::new( client, gateway_name.to_string(), @@ -491,8 +493,8 @@ async fn handle_gateway_switch(app: &mut App) { }; match connect_to_gateway(&name, &endpoint).await { - Ok(channel) => { - app.client = OpenShellClient::new(channel); + Ok((channel, interceptor)) => { + app.client = OpenShellClient::with_interceptor(channel, interceptor); app.gateway_name = name; app.endpoint = endpoint; app.reset_sandbox_state(); @@ -505,8 +507,76 @@ async fn handle_gateway_switch(app: &mut App) { } } -/// Build a gRPC channel to a gateway using its mTLS certs on disk. -async fn connect_to_gateway(name: &str, endpoint: &str) -> Result { +/// Build a gRPC channel and auth interceptor for a gateway. +/// +/// Checks gateway metadata for the auth mode and loads the appropriate +/// credentials (mTLS certs or OIDC bearer token). 
+async fn connect_to_gateway(name: &str, endpoint: &str) -> Result<(Channel, EdgeAuthInterceptor)> { + let meta = openshell_bootstrap::get_gateway_metadata(name); + + if meta.as_ref().and_then(|m| m.auth_mode.as_deref()) == Some("oidc") { + let bundle = openshell_bootstrap::oidc_token::load_oidc_token(name).ok_or_else(|| { + miette::miette!( + "No OIDC token for gateway '{name}'.\n\ + Authenticate with: openshell gateway login" + ) + })?; + if openshell_bootstrap::oidc_token::is_token_expired(&bundle) { + miette::bail!( + "OIDC token for gateway '{name}' has expired.\n\ + Re-authenticate with: openshell gateway login" + ); + } + let interceptor = EdgeAuthInterceptor::new(Some(&bundle.access_token), None)?; + let channel = build_oidc_channel(name, endpoint).await?; + Ok((channel, interceptor)) + } else { + let channel = build_mtls_channel(name, endpoint).await?; + Ok((channel, EdgeAuthInterceptor::noop())) + } +} + +/// Build an HTTPS channel for OIDC-authenticated gateways. +/// +/// Tries mTLS client certs for the transport layer when available (the server +/// may still require them alongside the bearer token), falls back to CA-only +/// or system roots. +async fn build_oidc_channel(name: &str, endpoint: &str) -> Result { + let mtls_dir = gateway_mtls_dir(name); + + let tls_config = mtls_dir.as_ref().map_or_else( + || ClientTlsConfig::new().with_enabled_roots(), + |dir| { + let ca = std::fs::read(dir.join("ca.crt")).ok(); + let cert = std::fs::read(dir.join("tls.crt")).ok(); + let key = std::fs::read(dir.join("tls.key")).ok(); + + match (ca, cert, key) { + (Some(ca), Some(cert), Some(key)) => ClientTlsConfig::new() + .ca_certificate(Certificate::from_pem(ca)) + .identity(Identity::from_pem(cert, key)), + (Some(ca), _, _) => { + ClientTlsConfig::new().ca_certificate(Certificate::from_pem(ca)) + } + _ => ClientTlsConfig::new().with_enabled_roots(), + } + }, + ); + + Endpoint::from_shared(endpoint.to_string()) + .into_diagnostic()? 
+ .connect_timeout(Duration::from_secs(10)) + .http2_keep_alive_interval(Duration::from_secs(10)) + .keep_alive_while_idle(true) + .tls_config(tls_config) + .into_diagnostic()? + .connect() + .await + .into_diagnostic() +} + +/// Build a gRPC channel using mTLS client certificates. +async fn build_mtls_channel(name: &str, endpoint: &str) -> Result { let mtls_dir = gateway_mtls_dir(name) .ok_or_else(|| miette::miette!("cannot determine config directory for gateway {name}"))?; @@ -524,7 +594,7 @@ async fn connect_to_gateway(name: &str, endpoint: &str) -> Result { .ca_certificate(Certificate::from_pem(ca)) .identity(Identity::from_pem(cert, key)); - let channel = Endpoint::from_shared(endpoint.to_string()) + Endpoint::from_shared(endpoint.to_string()) .into_diagnostic()? .connect_timeout(Duration::from_secs(10)) .http2_keep_alive_interval(Duration::from_secs(10)) @@ -533,9 +603,7 @@ async fn connect_to_gateway(name: &str, endpoint: &str) -> Result { .into_diagnostic()? .connect() .await - .into_diagnostic()?; - - Ok(channel) + .into_diagnostic() } /// Resolve the mTLS cert directory for a gateway. @@ -839,10 +907,7 @@ async fn handle_shell_connect( let gateway_port_u16 = session.gateway_port as u16; let (gateway_host, gateway_port) = resolve_ssh_gateway(&session.gateway_host, gateway_port_u16, &app.endpoint); - let gateway_url = format!( - "{}://{}:{gateway_port}{}", - session.gateway_scheme, gateway_host, session.connect_path - ); + let gateway_url = format_gateway_url(&session.gateway_scheme, &gateway_host, gateway_port); // Step 4: Build the ProxyCommand using our own binary. 
let exe = match std::env::current_exe() { @@ -988,10 +1053,7 @@ async fn handle_exec_command( let gateway_port_u16 = session.gateway_port as u16; let (gateway_host, gateway_port) = resolve_ssh_gateway(&session.gateway_host, gateway_port_u16, &app.endpoint); - let gateway_url = format!( - "{}://{}:{gateway_port}{}", - session.gateway_scheme, gateway_host, session.connect_path - ); + let gateway_url = format_gateway_url(&session.gateway_scheme, &gateway_host, gateway_port); let exe = match std::env::current_exe() { Ok(p) => p, @@ -1080,7 +1142,8 @@ async fn handle_exec_command( // SSH utility functions are shared via openshell_core::forward. use openshell_core::forward::{ - build_proxy_command, resolve_ssh_gateway, shell_escape, validate_ssh_session_response, + build_proxy_command, format_gateway_url, resolve_ssh_gateway, shell_escape, + validate_ssh_session_response, }; /// Convert a `SandboxPolicy` proto into styled ratatui lines for the policy viewer. @@ -1390,7 +1453,9 @@ fn spawn_create_sandbox(app: &mut App, tx: mpsc::UnboundedSender) { /// This is called from within the create-sandbox task so the pacman animation /// keeps running while forwards are being established. async fn start_port_forwards( - client: &mut OpenShellClient, + client: &mut OpenShellClient< + tonic::service::interceptor::InterceptedService, + >, endpoint: &str, gateway_name: &str, sandbox_name: &str, @@ -1424,10 +1489,7 @@ async fn start_port_forwards( let gateway_port_u16 = session.gateway_port as u16; let (gateway_host, gateway_port) = resolve_ssh_gateway(&session.gateway_host, gateway_port_u16, endpoint); - let gateway_url = format!( - "{}://{}:{gateway_port}{}", - session.gateway_scheme, gateway_host, session.connect_path - ); + let gateway_url = format_gateway_url(&session.gateway_scheme, &gateway_host, gateway_port); // Build ProxyCommand. 
let exe = match std::env::current_exe() { @@ -1565,10 +1627,12 @@ fn spawn_create_provider(app: &App, tx: mpsc::UnboundedSender) { name: provider_name.clone(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), r#type: ptype.clone(), credentials: credentials.clone(), config: HashMap::default(), + credential_expires_at_ms: HashMap::default(), }), }; @@ -1655,11 +1719,14 @@ fn spawn_update_provider(app: &App, tx: mpsc::UnboundedSender) { name: name.clone(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), r#type: ptype, credentials, config: HashMap::default(), + credential_expires_at_ms: HashMap::default(), }), + credential_expires_at_ms: HashMap::default(), }; match tokio::time::timeout(Duration::from_secs(5), client.update_provider(req)).await { @@ -1998,6 +2065,7 @@ fn spawn_set_global_setting(app: &App, tx: mpsc::UnboundedSender) { delete_setting: false, global: true, merge_operations: vec![], + expected_resource_version: 0, }; let result = tokio::time::timeout(Duration::from_secs(5), client.update_config(req)).await; @@ -2033,6 +2101,7 @@ fn spawn_delete_global_setting(app: &App, tx: mpsc::UnboundedSender) { delete_setting: true, global: true, merge_operations: vec![], + expected_resource_version: 0, }; let result = tokio::time::timeout(Duration::from_secs(5), client.update_config(req)).await; @@ -2102,6 +2171,7 @@ fn spawn_set_sandbox_setting(app: &App, tx: mpsc::UnboundedSender) { delete_setting: false, global: false, merge_operations: vec![], + expected_resource_version: 0, }; let result = tokio::time::timeout(Duration::from_secs(5), client.update_config(req)).await; @@ -2141,6 +2211,7 @@ fn spawn_delete_sandbox_setting(app: &App, tx: mpsc::UnboundedSender) { delete_setting: true, global: false, merge_operations: vec![], + expected_resource_version: 0, }; let result = tokio::time::timeout(Duration::from_secs(5), client.update_config(req)).await; diff --git a/crates/openshell-tui/src/ui/create_provider.rs 
b/crates/openshell-tui/src/ui/create_provider.rs index 9f1cc6d83..3df8b818f 100644 --- a/crates/openshell-tui/src/ui/create_provider.rs +++ b/crates/openshell-tui/src/ui/create_provider.rs @@ -8,6 +8,8 @@ use ratatui::widgets::{Block, Borders, Clear, Padding, Paragraph}; use crate::app::{App, CreateProviderPhase, ProviderKeyField}; +use super::centered_rect; + /// Draw the create provider modal overlay. pub fn draw(frame: &mut Frame<'_>, app: &App, area: Rect) { let t = &app.theme; @@ -743,23 +745,3 @@ fn draw_secret_field( }; frame.render_widget(Paragraph::new(display), chunks[1]); } - -fn centered_rect(width: u16, height: u16, area: Rect) -> Rect { - let vert = Layout::default() - .direction(Direction::Vertical) - .constraints([ - Constraint::Length((area.height.saturating_sub(height)) / 2), - Constraint::Length(height), - Constraint::Min(0), - ]) - .split(area); - let horiz = Layout::default() - .direction(Direction::Horizontal) - .constraints([ - Constraint::Length((area.width.saturating_sub(width)) / 2), - Constraint::Length(width), - Constraint::Min(0), - ]) - .split(vert[1]); - horiz[1] -} diff --git a/crates/openshell-tui/src/ui/create_sandbox.rs b/crates/openshell-tui/src/ui/create_sandbox.rs index cb90e244d..a120036cf 100644 --- a/crates/openshell-tui/src/ui/create_sandbox.rs +++ b/crates/openshell-tui/src/ui/create_sandbox.rs @@ -8,6 +8,8 @@ use ratatui::widgets::{Block, Borders, Clear, Padding, Paragraph}; use crate::app::{App, CreateFormField, CreatePhase}; +use super::centered_rect; + /// Draw the create sandbox modal overlay. 
pub fn draw(frame: &mut Frame<'_>, app: &App, area: Rect) { let Some(form) = &app.create_form else { @@ -418,23 +420,3 @@ fn draw_text_field( }; frame.render_widget(Paragraph::new(display), chunks[1]); } - -fn centered_rect(width: u16, height: u16, area: Rect) -> Rect { - let vert = Layout::default() - .direction(Direction::Vertical) - .constraints([ - Constraint::Length((area.height.saturating_sub(height)) / 2), - Constraint::Length(height), - Constraint::Min(0), - ]) - .split(area); - let horiz = Layout::default() - .direction(Direction::Horizontal) - .constraints([ - Constraint::Length((area.width.saturating_sub(width)) / 2), - Constraint::Length(width), - Constraint::Min(0), - ]) - .split(vert[1]); - horiz[1] -} diff --git a/crates/openshell-tui/src/ui/mod.rs b/crates/openshell-tui/src/ui/mod.rs index 13ac94c10..98c8badb5 100644 --- a/crates/openshell-tui/src/ui/mod.rs +++ b/crates/openshell-tui/src/ui/mod.rs @@ -453,6 +453,28 @@ fn draw_command_bar(frame: &mut Frame<'_>, app: &App, area: Rect) { frame.render_widget(bar, area); } +/// Center a popup rectangle within `area` using absolute width and height (in +/// terminal columns/rows). +pub fn centered_rect(width: u16, height: u16, area: Rect) -> Rect { + let vert = Layout::default() + .direction(Direction::Vertical) + .constraints([ + Constraint::Length((area.height.saturating_sub(height)) / 2), + Constraint::Length(height), + Constraint::Min(0), + ]) + .split(area); + let horiz = Layout::default() + .direction(Direction::Horizontal) + .constraints([ + Constraint::Length((area.width.saturating_sub(width)) / 2), + Constraint::Length(width), + Constraint::Min(0), + ]) + .split(vert[1]); + horiz[1] +} + /// Center a popup rectangle within `area` using percentage-based width and /// an absolute height (in rows). 
pub fn centered_popup(percent_x: u16, height: u16, area: Rect) -> Rect { diff --git a/crates/openshell-tui/src/ui/sandbox_draft.rs b/crates/openshell-tui/src/ui/sandbox_draft.rs index a9cd6b0c9..38214d6da 100644 --- a/crates/openshell-tui/src/ui/sandbox_draft.rs +++ b/crates/openshell-tui/src/ui/sandbox_draft.rs @@ -6,11 +6,13 @@ use crate::app::App; use openshell_core::proto::PolicyChunk; use ratatui::Frame; -use ratatui::layout::{Constraint, Direction, Layout, Rect}; +use ratatui::layout::Rect; use ratatui::style::Modifier; use ratatui::text::{Line, Span}; use ratatui::widgets::{Block, Borders, Clear, Padding, Paragraph, Wrap}; +use super::centered_rect; + /// Draw the network rules panel (list view with highlight bar). pub fn draw(frame: &mut Frame<'_>, app: &mut App, area: Rect) { let t = &app.theme; @@ -441,23 +443,3 @@ fn format_short_time(epoch_ms: i64) -> String { let seconds = time_of_day % 60; format!("{hours:02}:{minutes:02}:{seconds:02}") } - -fn centered_rect(width: u16, height: u16, area: Rect) -> Rect { - let vert = Layout::default() - .direction(Direction::Vertical) - .constraints([ - Constraint::Length((area.height.saturating_sub(height)) / 2), - Constraint::Length(height), - Constraint::Min(0), - ]) - .split(area); - let horiz = Layout::default() - .direction(Direction::Horizontal) - .constraints([ - Constraint::Length((area.width.saturating_sub(width)) / 2), - Constraint::Length(width), - Constraint::Min(0), - ]) - .split(vert[1]); - horiz[1] -} diff --git a/crates/openshell-tui/src/ui/sandbox_logs.rs b/crates/openshell-tui/src/ui/sandbox_logs.rs index da20b40c8..45f548e8b 100644 --- a/crates/openshell-tui/src/ui/sandbox_logs.rs +++ b/crates/openshell-tui/src/ui/sandbox_logs.rs @@ -2,12 +2,14 @@ // SPDX-License-Identifier: Apache-2.0 use ratatui::Frame; -use ratatui::layout::{Constraint, Direction, Layout, Rect}; +use ratatui::layout::Rect; use ratatui::text::{Line, Span}; use ratatui::widgets::{Block, Borders, Clear, Padding, Paragraph, Wrap}; 
use crate::app::{App, LogLine}; +use super::centered_rect; + pub fn draw(frame: &mut Frame<'_>, app: &mut App, area: Rect) { let t = &app.theme; let name = app @@ -202,26 +204,6 @@ pub fn draw_detail_popup( ); } -fn centered_rect(width: u16, height: u16, area: Rect) -> Rect { - let vert = Layout::default() - .direction(Direction::Vertical) - .constraints([ - Constraint::Length((area.height.saturating_sub(height)) / 2), - Constraint::Length(height), - Constraint::Min(0), - ]) - .split(area); - let horiz = Layout::default() - .direction(Direction::Horizontal) - .constraints([ - Constraint::Length((area.width.saturating_sub(width)) / 2), - Constraint::Length(width), - Constraint::Min(0), - ]) - .split(vert[1]); - horiz[1] -} - // --------------------------------------------------------------------------- // Log line rendering (compact, truncated) // --------------------------------------------------------------------------- diff --git a/crates/openshell-tui/src/ui/splash.rs b/crates/openshell-tui/src/ui/splash.rs index 1a5595d1f..7f889f995 100644 --- a/crates/openshell-tui/src/ui/splash.rs +++ b/crates/openshell-tui/src/ui/splash.rs @@ -11,6 +11,8 @@ use ratatui::layout::{Alignment, Constraint, Direction, Layout, Rect}; use ratatui::text::{Line, Span}; use ratatui::widgets::{Block, BorderType, Borders, Clear, Padding, Paragraph}; +use super::centered_rect; + // --------------------------------------------------------------------------- // ANSI Shadow figlet art — OPEN (6 lines, 35 display cols) // --------------------------------------------------------------------------- @@ -117,23 +119,3 @@ pub fn draw(frame: &mut Frame<'_>, area: Rect, theme: &crate::theme::Theme) { frame.render_widget(footer, chunks[2]); } - -fn centered_rect(width: u16, height: u16, area: Rect) -> Rect { - let vert = Layout::default() - .direction(Direction::Vertical) - .constraints([ - Constraint::Length((area.height.saturating_sub(height)) / 2), - Constraint::Length(height), - 
Constraint::Min(0), - ]) - .split(area); - let horiz = Layout::default() - .direction(Direction::Horizontal) - .constraints([ - Constraint::Length((area.width.saturating_sub(width)) / 2), - Constraint::Length(width), - Constraint::Min(0), - ]) - .split(vert[1]); - horiz[1] -} diff --git a/deploy/deb/openshell-gateway.service b/deploy/deb/openshell-gateway.service index 26e3c07be..1ed112b05 100644 --- a/deploy/deb/openshell-gateway.service +++ b/deploy/deb/openshell-gateway.service @@ -6,17 +6,8 @@ After=default.target [Service] Type=simple StateDirectory=openshell/gateway -# %S resolves to $XDG_STATE_HOME for user services. -Environment=OPENSHELL_BIND_ADDRESS=127.0.0.1 -Environment=OPENSHELL_SERVER_PORT=17670 -Environment=OPENSHELL_DISABLE_TLS=true -Environment=OPENSHELL_DISABLE_GATEWAY_AUTH=true -Environment=OPENSHELL_DB_URL=sqlite:%S/openshell/gateway/openshell.db -Environment=OPENSHELL_GRPC_ENDPOINT=http://127.0.0.1:17670 -Environment=OPENSHELL_SSH_GATEWAY_HOST=127.0.0.1 -Environment=OPENSHELL_SSH_GATEWAY_PORT=17670 -Environment=OPENSHELL_VM_DRIVER_STATE_DIR=%S/openshell/vm-driver -EnvironmentFile=-%h/.config/openshell/gateway.env +EnvironmentFile=-%E/openshell/gateway.env +ExecStartPre=/usr/bin/openshell-gateway generate-certs --output-dir %S/openshell/tls --server-san host.openshell.internal ExecStart=/usr/bin/openshell-gateway Restart=on-failure RestartSec=5s diff --git a/deploy/docker/Dockerfile.ci b/deploy/docker/Dockerfile.ci index 3c669a96f..77a8c94e2 100644 --- a/deploy/docker/Dockerfile.ci +++ b/deploy/docker/Dockerfile.ci @@ -4,12 +4,12 @@ # SPDX-License-Identifier: Apache-2.0 # CI runner image with all development tools pre-installed -# Rebuild triggered automatically when mise.toml or this file changes +# Rebuild triggered automatically when mise.toml, mise.lock, tasks, or this file changes FROM nvcr.io/nvidia/base/ubuntu:noble-20251013 -ARG DOCKER_VERSION=29.4.1 -ARG BUILDX_VERSION=v0.33.0 +ARG DOCKER_VERSION=29.5.1 +ARG BUILDX_VERSION=v0.34.0 ARG 
NPM_VERSION=11.13.0 ARG TARGETARCH @@ -29,6 +29,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ libz3-dev \ pkg-config \ libssl-dev \ + musl-tools \ openssh-client \ python3 \ python3-venv \ @@ -56,7 +57,7 @@ RUN case "$TARGETARCH" in \ && chmod +x /usr/local/lib/docker/cli-plugins/docker-buildx # Install GitHub CLI used by install.sh and CI jobs -ARG GH_VERSION=2.91.0 +ARG GH_VERSION=2.92.0 RUN case "$TARGETARCH" in \ amd64) gh_arch=amd64 ;; \ arm64) gh_arch=arm64 ;; \ @@ -82,7 +83,8 @@ RUN --mount=type=secret,id=MISE_GITHUB_TOKEN \ npm install -g "npm@${NPM_VERSION}" && \ mise reshim && \ (/root/.cargo/bin/rustup component remove rust-docs || true) && \ - rm -rf /root/.rustup/toolchains/*/share/doc /root/.rustup/toolchains/*/share/man + rm -rf /root/.rustup/toolchains/*/share/doc /root/.rustup/toolchains/*/share/man && \ + helm plugin install https://github.com/helm-unittest/helm-unittest --verify=false # Set working directory for CI jobs WORKDIR /builds diff --git a/deploy/docker/Dockerfile.gateway b/deploy/docker/Dockerfile.gateway new file mode 100644 index 000000000..9dd7ed8b9 --- /dev/null +++ b/deploy/docker/Dockerfile.gateway @@ -0,0 +1,36 @@ +# syntax=docker/dockerfile:1.4 + +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +# Gateway image build. +# +# The Rust binary is built natively before this image build runs and staged at: +# deploy/docker/.build/prebuilt-binaries//openshell-gateway +# +# Use tasks/scripts/docker-build-image.sh gateway (or `mise run build:docker:gateway`) +# to stage the binary and build the image in one step. CI builds the binary +# per-architecture via the `rust-native-build.yml` workflow and uploads it as +# an artifact, which is downloaded into the same staging directory before the +# image build job runs. 
+# +# The runtime is distroless Debian 13, which provides glibc and the dynamic +# loader needed by the GNU-linked gateway binary while keeping the attack +# surface small. The default digest currently carries Debian glibc +# 2.41-12+deb13u3. + +ARG GATEWAY_BASE_IMAGE=gcr.io/distroless/cc-debian13:nonroot@sha256:e1fd250ce83d94603e9887ec991156a6c26905a6b0001039b7a43699018c0733 + +FROM ${GATEWAY_BASE_IMAGE} AS gateway + +ARG TARGETARCH + +WORKDIR /app + +COPY deploy/docker/.build/prebuilt-binaries/${TARGETARCH}/openshell-gateway /usr/local/bin/openshell-gateway + +USER 1000:1000 +EXPOSE 8080 + +ENTRYPOINT ["/usr/local/bin/openshell-gateway"] +CMD ["--bind-address", "0.0.0.0", "--port", "8080"] diff --git a/deploy/docker/Dockerfile.gateway-macos b/deploy/docker/Dockerfile.gateway-macos index 4cae2f0e7..27d2ffbba 100644 --- a/deploy/docker/Dockerfile.gateway-macos +++ b/deploy/docker/Dockerfile.gateway-macos @@ -94,6 +94,7 @@ RUN touch crates/openshell-core/src/lib.rs \ proto/*.proto ARG OPENSHELL_CARGO_VERSION +ARG OPENSHELL_IMAGE_TAG RUN --mount=type=cache,id=cargo-registry-gateway-macos,sharing=locked,target=/root/.cargo/registry \ --mount=type=cache,id=cargo-git-gateway-macos,sharing=locked,target=/root/.cargo/git \ --mount=type=cache,id=cargo-target-gateway-macos-${CARGO_TARGET_CACHE_SCOPE},sharing=locked,target=/build/target \ diff --git a/deploy/docker/Dockerfile.images b/deploy/docker/Dockerfile.images deleted file mode 100644 index 62662fa93..000000000 --- a/deploy/docker/Dockerfile.images +++ /dev/null @@ -1,122 +0,0 @@ -# syntax=docker/dockerfile:1.4 - -# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -# Shared OpenShell image build graph. 
-# -# Targets: -# gateway Final gateway image -# supervisor Final supervisor image (Ubuntu base, supervisor binary) -# -# Rust binaries are built natively before the image build and staged at: -# deploy/docker/.build/prebuilt-binaries//openshell-{gateway,sandbox} -# -# For local dev (Skaffold), pass --build-arg BUILD_FROM_SOURCE=1 to compile -# binaries inside Docker instead. BuildKit only executes the selected binary -# staging stage, so missing prebuilt files do not cause a build failure. - -# Controls binary source: 0 = prebuilt (release), 1 = compile in Docker (local dev). -# Must be declared here (global scope) so it can be used in FROM instructions below. -ARG BUILD_FROM_SOURCE=0 - -# --------------------------------------------------------------------------- -# Optional in-Docker Rust build (BUILD_FROM_SOURCE=1, local dev only) -# --------------------------------------------------------------------------- -FROM rust:1.95.0-slim-bookworm AS rust-builder - -RUN apt-get update && apt-get install -y --no-install-recommends \ - build-essential \ - cmake \ - pkg-config \ - libssl-dev \ - ca-certificates \ - && rm -rf /var/lib/apt/lists/* - -WORKDIR /build - -COPY Cargo.toml Cargo.lock ./ -COPY crates/ crates/ -COPY proto/ proto/ -COPY providers/ providers/ - -RUN --mount=type=cache,target=/usr/local/cargo/registry \ - --mount=type=cache,target=/build/target \ - cargo build --release \ - --features "openshell-core/dev-settings" \ - --bin openshell-gateway \ - --bin openshell-sandbox \ - && mkdir -p /build/out \ - && install -m 0755 target/release/openshell-gateway /build/out/openshell-gateway \ - && install -m 0755 target/release/openshell-sandbox /build/out/openshell-sandbox - -# --------------------------------------------------------------------------- -# Per-arch binary stages -# --------------------------------------------------------------------------- - -# Prebuilt path (release default, BUILD_FROM_SOURCE=0) -FROM scratch AS gateway-binary-0 -ARG TARGETARCH 
-# --chmod=755 preserves the executable bit through actions/upload-artifact + -# download-artifact, which strip exec perms during the roundtrip. -COPY --chmod=755 deploy/docker/.build/prebuilt-binaries/${TARGETARCH}/openshell-gateway /build/out/openshell-gateway - -# Source-built path (local dev, BUILD_FROM_SOURCE=1) -FROM rust-builder AS gateway-binary-1 - -FROM gateway-binary-${BUILD_FROM_SOURCE} AS gateway-binary - -# Prebuilt path (release default, BUILD_FROM_SOURCE=0) -FROM scratch AS supervisor-binary-0 -ARG TARGETARCH -# --chmod=755 preserves the executable bit through actions/upload-artifact + -# download-artifact, which strip exec perms during the roundtrip. -COPY --chmod=755 deploy/docker/.build/prebuilt-binaries/${TARGETARCH}/openshell-sandbox /build/out/openshell-sandbox - -# Source-built path (local dev, BUILD_FROM_SOURCE=1) -FROM rust-builder AS supervisor-binary-1 - -FROM supervisor-binary-${BUILD_FROM_SOURCE} AS supervisor-binary - -# --------------------------------------------------------------------------- -# Final gateway image -# --------------------------------------------------------------------------- -FROM nvcr.io/nvidia/base/ubuntu:noble-20251013 AS gateway - -RUN apt-get update && apt-get install -y --no-install-recommends \ - ca-certificates && \ - apt-get install -y --only-upgrade gpgv && \ - rm -rf /var/lib/apt/lists/* - -RUN useradd --create-home --user-group openshell - -WORKDIR /app - -COPY --from=gateway-binary /build/out/openshell-gateway /usr/local/bin/ - -RUN mkdir -p /build/crates/openshell-server -COPY --chmod=755 crates/openshell-server/migrations /build/crates/openshell-server/migrations - -USER openshell -EXPOSE 8080 - -ENTRYPOINT ["openshell-gateway"] -CMD ["--bind-address", "0.0.0.0", "--port", "8080"] - -# --------------------------------------------------------------------------- -# Final supervisor image -# --------------------------------------------------------------------------- -# Supervisor image based on the same 
NVIDIA Ubuntu base used by the gateway. -# -# Used by: -# - Docker driver: binary is extracted from the image and run inside the -# agent container. -# - Podman driver: image is mounted as an OCI volume at /opt/openshell/bin. -# - Kubernetes driver: image runs as an init container that invokes the -# binary's `copy-self` subcommand to seed an emptyDir volume. -# -# An Ubuntu base provides glibc and the dynamic loader needed to exec the -# dynamically linked binary. `FROM scratch` would be smaller but cannot run -# the binary, breaking the Kubernetes init-container path. -FROM nvcr.io/nvidia/base/ubuntu:noble-20251013 AS supervisor -COPY --from=supervisor-binary /build/out/openshell-sandbox /openshell-sandbox diff --git a/deploy/docker/Dockerfile.supervisor b/deploy/docker/Dockerfile.supervisor new file mode 100644 index 000000000..c84cc70e9 --- /dev/null +++ b/deploy/docker/Dockerfile.supervisor @@ -0,0 +1,35 @@ +# syntax=docker/dockerfile:1.4 + +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +# Supervisor image build. +# +# The final image is `scratch`: it only carries the static `openshell-sandbox` +# binary used by Docker extraction, Podman image volumes, and the Kubernetes +# init container copy-self path. A static musl binary lets the image stay +# `scratch` while still being executable as an init container. +# +# The Rust binary is built natively before this image build runs and staged at: +# deploy/docker/.build/prebuilt-binaries//openshell-sandbox +# +# Use tasks/scripts/docker-build-image.sh supervisor (or `mise run build:docker:supervisor`) +# to stage the binary and build the image in one step. CI builds the binary +# per-architecture via the `rust-native-build.yml` workflow (with the musl +# target) and uploads it as an artifact, which is downloaded into the same +# staging directory before the image build job runs. 
+ +FROM scratch AS supervisor + +ARG TARGETARCH + +# --chmod=0550 drops world-execute and survives the actions/upload-artifact +# + download-artifact roundtrip (which strips exec perms). Ownership is left +# at root (0:0) deliberately: the Podman driver mounts this image as a +# read-only image volume into the sandbox container and drops DAC_OVERRIDE, +# so the container's UID 0 must own the binary to read+exec it. Mode 0550 +# (r-xr-x---) is the security win; the chown to a non-root UID was breaking +# Podman without buying anything since the container is always UID 0. +COPY --chmod=0550 deploy/docker/.build/prebuilt-binaries/${TARGETARCH}/openshell-sandbox /openshell-sandbox + +ENTRYPOINT ["/openshell-sandbox"] diff --git a/deploy/helm/openshell/.helmignore b/deploy/helm/openshell/.helmignore index a12325802..0aecc346a 100644 --- a/deploy/helm/openshell/.helmignore +++ b/deploy/helm/openshell/.helmignore @@ -18,5 +18,6 @@ .vscode/ # Ignore development files +README.md.gotmpl skaffold.yaml ci/ diff --git a/deploy/helm/openshell/README.md b/deploy/helm/openshell/README.md index ee7565f29..9df0b91a0 100644 --- a/deploy/helm/openshell/README.md +++ b/deploy/helm/openshell/README.md @@ -1,6 +1,11 @@ # OpenShell Helm Chart -> **Experimental** — the Kubernetes deployment path is under active development. Expect rough edges and breaking changes. + + +> **Experimental** - the Kubernetes deployment path is under active development. Expect rough edges and breaking changes. This chart deploys the OpenShell gateway into a Kubernetes cluster. It is published as an OCI artifact to GHCR at `oci://ghcr.io/nvidia/openshell/helm-chart`. @@ -8,24 +13,24 @@ This chart deploys the OpenShell gateway into a Kubernetes cluster. It is publis The Kubernetes Agent Sandbox CRDs and controller must be installed on the cluster before deploying OpenShell. 
Install them with: -```bash +```shell kubectl apply -f https://github.com/kubernetes-sigs/agent-sandbox/releases/latest/download/manifest.yaml ``` ## Install on Kubernetes -```bash +```shell helm install openshell oci://ghcr.io/nvidia/openshell/helm-chart --version ``` ## Install on OpenShift -```bash +```shell # Precreate the openshell namespace so we can create the SCC cluster role oc create ns openshell -# Sandboxes are deployed into the openshell namespace and use the default service account for now -oc adm policy add-scc-to-user privileged -z default -n openshell +# Sandboxes are deployed into the openshell namespace and use the openshell-sandbox service account +oc adm policy add-scc-to-user privileged -z openshell-sandbox -n openshell # Deploy openshell with overrides to allow SCC assignment of fsGroup and runAsUser for the gateway helm install openshell oci://ghcr.io/nvidia/openshell/helm-chart --version -n openshell \ @@ -47,8 +52,124 @@ The `dev` tags are intended for testing changes ahead of a release. Production d ## Configuration -See [`values.yaml`](values.yaml) for configurable values. Selected overlays: - -- [`ci/values-gateway.yaml`](ci/values-gateway.yaml) — gateway-only configuration -- [`ci/values-cert-manager.yaml`](ci/values-cert-manager.yaml) — cert-manager integration -- [`ci/values-keycloak.yaml`](ci/values-keycloak.yaml) — Keycloak OIDC integration +See [`values.yaml`](values.yaml) for source defaults. Selected overlays: + +- [`ci/values-gateway.yaml`](ci/values-gateway.yaml) - gateway-only configuration +- [`ci/values-cert-manager.yaml`](ci/values-cert-manager.yaml) - cert-manager integration +- [`ci/values-keycloak.yaml`](ci/values-keycloak.yaml) - Keycloak OIDC integration + +## PKI bootstrap + +By default, a pre-install/pre-upgrade hook Job runs `openshell-gateway generate-certs` +to create the gateway's server and client mTLS Secrets. 
The Job uses the gateway image +itself, so air-gapped environments only need to mirror that one image (no separate +openssl/alpine sidecar). + +The Job is idempotent: + +- Both target Secrets exist: log and exit 0. +- Exactly one exists: fail with `kubectl delete secret -n ` recovery hint. +- Neither exists: generate a CA, server cert, and client cert; POST both `kubernetes.io/tls` Secrets (`tls.crt`, `tls.key`, `ca.crt`). + +Disable with `--set pkiInitJob.enabled=false` when bringing your own PKI (cert-manager, +external CA, or pre-created Secrets). See `certManager.*` in `values.yaml` for the +cert-manager alternative. + +## Values + +| Key | Type | Default | Description | +|-----|------|---------|-------------| +| affinity | object | `{}` | Affinity rules for the gateway pod. | +| certManager.caSecretName | string | `"openshell-ca-tls"` | Secret created for the intermediate CA (Certificate with isCA: true). | +| certManager.certificateDuration | string | `"8760h"` | Duration for cert-manager-issued certificates. | +| certManager.certificateRenewBefore | string | `"720h"` | Renewal window for cert-manager-issued certificates. | +| certManager.clientCaFromServerTlsSecret | bool | `true` | Mount gateway client CA from the server TLS secret's ca.crt (populated by cert-manager for certs issued by a CA Issuer). Avoids a separate openshell-server-client-ca Secret. | +| certManager.enabled | bool | `false` | Create cert-manager Issuer and Certificate resources instead of using the PKI bootstrap Job. | +| certManager.serverDnsNames | list | `["openshell","openshell.openshell.svc","openshell.openshell.svc.cluster.local","localhost","openshell.localhost","*.openshell.localhost","host.docker.internal"]` | DNS SANs on the cert-manager-issued server certificate. | +| certManager.serverIpAddresses | list | `["127.0.0.1"]` | IP SANs on the cert-manager-issued server certificate. | +| fullnameOverride | string | `""` | Override the full generated resource name. 
| +| grpcRoute.enabled | bool | `false` | Create a Gateway API GRPCRoute for the gateway service. | +| grpcRoute.gateway.className | string | `"eg"` | GatewayClass to reference. Envoy Gateway installs one named "eg". | +| grpcRoute.gateway.create | bool | `false` | When true, a Gateway resource is created in the release namespace. Set to false and provide name/namespace to attach to a pre-existing Gateway. | +| grpcRoute.gateway.listener.allowedRoutes | string | `"Same"` | "Same" restricts attached routes to the release namespace; "All" allows any namespace. | +| grpcRoute.gateway.listener.port | int | `80` | Listener port for the generated Gateway resource. | +| grpcRoute.gateway.listener.protocol | string | `"HTTP"` | Listener protocol for the generated Gateway resource. | +| grpcRoute.gateway.name | string | `""` | Name of the Gateway resource. Defaults to the chart fullname. | +| grpcRoute.gateway.namespace | string | `""` | Namespace of the Gateway referenced by the GRPCRoute parentRef. Defaults to the release namespace. | +| grpcRoute.hostnames | list | `[]` | Hostnames the GRPCRoute matches on. Leave empty to match all hosts. | +| image.pullPolicy | string | `"IfNotPresent"` | Gateway image pull policy. | +| image.repository | string | `"ghcr.io/nvidia/openshell/gateway"` | Gateway image repository. | +| image.tag | string | `""` | Gateway image tag. Defaults to the chart appVersion when empty. | +| imagePullSecrets | list | `[]` | Image pull secrets attached to gateway and helper pods. | +| nameOverride | string | `"openshell"` | Override the chart name used in generated resource names. | +| networkPolicy.enabled | bool | `true` | Create a NetworkPolicy restricting SSH ingress on sandbox pods to the gateway. | +| nodeSelector | object | `{}` | Node selector for the gateway pod. | +| pkiInitJob.enabled | bool | `true` | Run a pre-install/pre-upgrade Job that creates gateway and client mTLS Secrets. 
| +| pkiInitJob.serverDnsNames | list | `[]` | Extra DNS SANs to append to the server certificate. | +| pkiInitJob.serverIpAddresses | list | `[]` | Extra IP SANs to append to the server certificate. | +| podAnnotations | object | `{}` | Extra annotations to add to the gateway pod. | +| podLabels | object | `{}` | Extra labels to add to the gateway pod. | +| podLifecycle.terminationGracePeriodSeconds | int | `5` | Grace period, in seconds, before Kubernetes terminates the gateway pod. | +| podSecurityContext.fsGroup | int | `1000` | fsGroup assigned to the gateway pod. | +| probes.liveness.failureThreshold | int | `3` | Liveness probe failure threshold before the container is restarted. | +| probes.liveness.initialDelaySeconds | int | `2` | Liveness probe initial delay, in seconds. | +| probes.liveness.periodSeconds | int | `5` | Liveness probe period, in seconds. | +| probes.liveness.timeoutSeconds | int | `1` | Liveness probe timeout, in seconds. | +| probes.readiness.failureThreshold | int | `3` | Readiness probe failure threshold before the pod is marked not ready. | +| probes.readiness.initialDelaySeconds | int | `1` | Readiness probe initial delay, in seconds. | +| probes.readiness.periodSeconds | int | `2` | Readiness probe period, in seconds. | +| probes.readiness.timeoutSeconds | int | `1` | Readiness probe timeout, in seconds. | +| probes.startup.failureThreshold | int | `30` | Startup probe failure threshold before the container is killed. | +| probes.startup.periodSeconds | int | `2` | Startup probe period, in seconds. | +| probes.startup.timeoutSeconds | int | `1` | Startup probe timeout, in seconds. | +| replicaCount | int | `1` | Number of OpenShell gateway replicas. | +| resources | object | `{}` | Gateway pod resource requests and limits. | +| sandboxServiceAccount.annotations | object | `{}` | Annotations to add to the generated sandbox service account. | +| sandboxServiceAccount.create | bool | `true` | Create a service account for sandbox pods. 
| +| sandboxServiceAccount.name | string | `""` | Existing service account name for sandbox pods when sandboxServiceAccount.create is false. | +| securityContext.allowPrivilegeEscalation | bool | `false` | Whether the gateway container can gain additional privileges. | +| securityContext.capabilities.drop | list | `["ALL"]` | Linux capabilities dropped from the gateway container. | +| securityContext.runAsNonRoot | bool | `true` | Require the gateway container to run as a non-root user. | +| securityContext.runAsUser | int | `1000` | UID assigned to the gateway container. | +| server.auth.allowUnauthenticatedUsers | bool | `false` | UNSAFE: accept unauthenticated CLI/user requests as a local developer principal. Intended only for trusted local Skaffold/k3d development or a fully trusted fronting proxy. Leave false for shared or production clusters. | +| server.dbUrl | string | `"sqlite:/var/openshell/openshell.db"` | Gateway database URL. | +| server.disableTls | bool | `false` | Disable TLS entirely - the server listens on plaintext HTTP. Set to true when a reverse proxy / tunnel terminates TLS at the edge. | +| server.enableLoopbackServiceHttp | bool | `true` | Enable plaintext HTTP routing for loopback sandbox service URLs on TLS-enabled gateways. | +| server.enableUserNamespaces | bool | `false` | Enable Kubernetes user namespace isolation (hostUsers: false) for sandbox pods. Requires Kubernetes 1.33+ with user namespace support available (beta through 1.35, GA in 1.36+), plus a supporting container runtime and Linux 5.12+. When enabled, container UID 0 maps to an unprivileged host UID and capabilities become namespaced. | +| server.grpcEndpoint | string | `""` | gRPC endpoint sandboxes call back into the gateway. Leave empty to derive it from the chart fullname, release namespace, service port, and disableTls flag, for example https://openshell.openshell.svc.cluster.local:8080. Override only when sandboxes must reach the gateway via a different hostname (e.g. 
an external ingress or a host alias). | +| server.hostGatewayIP | string | `""` | Host gateway IP for sandbox pod hostAliases. When set, sandbox pods get hostAliases entries mapping host.docker.internal and host.openshell.internal to this IP, allowing them to reach services running on the Docker host. Auto-detected by the cluster entrypoint script. | +| server.logLevel | string | `"info"` | Gateway log level. | +| server.oidc.adminRole | string | `""` | Role name for admin access. Leave empty (with userRole also empty) for authentication-only mode. Both must be set or both empty. | +| server.oidc.audience | string | `"openshell-cli"` | Expected audience claim for the API resource server. This should match the server's --oidc-audience, NOT the CLI client ID. | +| server.oidc.caConfigMapName | string | `""` | Name of a ConfigMap containing a CA certificate bundle (key: ca.crt) for verifying the OIDC issuer's TLS certificate. Required when the issuer uses a non-public CA (e.g. OpenShift ingress, private PKI). | +| server.oidc.issuer | string | `""` | OIDC issuer URL (e.g. https://keycloak.example.com/realms/openshell). | +| server.oidc.jwksTtl | int | `3600` | JWKS key cache TTL in seconds. | +| server.oidc.rolesClaim | string | `""` | Dot-separated path to the roles array in the JWT claims. Keycloak: "realm_access.roles", Entra ID: "roles", Okta: "groups". | +| server.oidc.scopesClaim | string | `""` | Dot-separated path to the scopes array in the JWT claims. | +| server.oidc.userRole | string | `""` | Role name for standard user access. | +| server.sandboxImage | string | `"ghcr.io/nvidia/openshell-community/sandboxes/base:latest"` | Default sandbox image used when requests do not specify one. | +| server.sandboxImagePullPolicy | string | `""` | Kubernetes imagePullPolicy for sandbox pods. Empty = Kubernetes default (Always for :latest, IfNotPresent otherwise). Set to "Always" for dev clusters so new images are picked up without manual eviction. 
| +| server.sandboxJwt.gatewayId | string | `""` | Stable gateway identity embedded in iss/aud of every minted token. Defaults to the release name so HA replicas share identity. | +| server.sandboxJwt.k8sSaTokenTtlSecs | int | `3600` | Lifetime (seconds) of the projected ServiceAccount token kubelet writes into each sandbox pod for the IssueSandboxToken bootstrap exchange. Kubelet enforces a minimum of 600s; the driver clamps values outside [600, 86400]. Default 3600 — generous, since the supervisor consumes the token within seconds of pod start. | +| server.sandboxJwt.signingSecretName | string | `""` | Name of the Opaque Secret holding the signing key material. Empty falls back to the chart fullname with "-jwt-keys" appended. | +| server.sandboxJwt.ttlSecs | int | `3600` | Token TTL in seconds. Defaults to 3600 (1h). | +| server.sandboxNamespace | string | `""` | Namespace where sandbox pods are created. Defaults to the Helm release namespace (.Release.Namespace) when left empty. | +| server.tls.certSecretName | string | `"openshell-server-tls"` | K8s secret (type kubernetes.io/tls) with tls.crt and tls.key for the server. | +| server.tls.clientCaSecretName | string | `"openshell-server-client-ca"` | K8s secret with ca.crt for client certificate verification (mTLS). Set to "" to disable mTLS and run HTTPS-only (use OIDC for auth instead). | +| server.tls.clientTlsSecretName | string | `"openshell-client-tls"` | K8s secret mounted into sandbox pods for mTLS to the server. | +| server.workspaceDefaultStorageSize | string | `""` | Default storage size for the workspace PVC in sandbox pods. Uses Kubernetes quantity syntax (e.g. "2Gi", "10Gi", "500Mi"). Empty = built-in default (2Gi). | +| service.healthPort | int | `8081` | Gateway health service port. | +| service.metricsPort | int | `9090` | Gateway metrics service port. | +| service.port | int | `8080` | Gateway gRPC/HTTP service port. 
| +| service.type | string | `"ClusterIP"` | Kubernetes Service type for the gateway. | +| serviceAccount.annotations | object | `{}` | Annotations to add to the generated service account. | +| serviceAccount.create | bool | `true` | Create a service account for the gateway. | +| serviceAccount.name | string | `""` | Existing service account name to use when serviceAccount.create is false. | +| supervisor.image.pullPolicy | string | `""` | Supervisor image pull policy. Defaults to the gateway image pull policy when empty. | +| supervisor.image.repository | string | `"ghcr.io/nvidia/openshell/supervisor"` | Supervisor image repository. | +| supervisor.image.tag | string | `""` | Supervisor image tag. Defaults to the chart appVersion when empty. | +| supervisor.sideloadMethod | string | `""` | How the supervisor binary is delivered into sandbox pods. Empty (default) = auto-detect from cluster version: K8s >= v1.35 -> "image-volume" (ImageVolume enabled by default; GA in v1.36) K8s < v1.35 -> "init-container" (copies via init container + emptyDir) On K8s v1.33-v1.34 with the ImageVolume feature gate manually enabled, set this to "image-volume" explicitly. | +| tolerations | list | `[]` | Tolerations for the gateway pod. | + +---------------------------------------------- +Autogenerated from chart metadata using [helm-docs v1.14.2](https://github.com/norwoodj/helm-docs/releases/v1.14.2) diff --git a/deploy/helm/openshell/README.md.gotmpl b/deploy/helm/openshell/README.md.gotmpl new file mode 100644 index 000000000..9e6a0ec65 --- /dev/null +++ b/deploy/helm/openshell/README.md.gotmpl @@ -0,0 +1,79 @@ +# OpenShell Helm Chart + + + +> **Experimental** - the Kubernetes deployment path is under active development. Expect rough edges and breaking changes. + +This chart deploys the OpenShell gateway into a Kubernetes cluster. It is published as an OCI artifact to GHCR at `oci://ghcr.io/nvidia/openshell/helm-chart`. 
+ +## Prerequisites + +The Kubernetes Agent Sandbox CRDs and controller must be installed on the cluster before deploying OpenShell. Install them with: + +```shell +kubectl apply -f https://github.com/kubernetes-sigs/agent-sandbox/releases/latest/download/manifest.yaml +``` + +## Install on Kubernetes + +```shell +helm install openshell oci://ghcr.io/nvidia/openshell/helm-chart --version +``` + +## Install on OpenShift + +```shell +# Precreate the openshell namespace so we can create the SCC cluster role +oc create ns openshell + +# Sandboxes are deployed into the openshell namespace and use the openshell-sandbox service account +oc adm policy add-scc-to-user privileged -z openshell-sandbox -n openshell + +# Deploy openshell with overrides to allow SCC assignment of fsGroup and runAsUser for the gateway +helm install openshell oci://ghcr.io/nvidia/openshell/helm-chart --version -n openshell \ + --set pkiInitJob.enabled=false \ + --set server.disableTls=true \ + --set podSecurityContext.fsGroup=null \ + --set securityContext.runAsUser=null +``` + +## Available versions + +| Tag | Source | Notes | +| --- | --- | --- | +| `` (e.g. `0.6.0`) | Tagged GitHub release | Tracks the matching gateway and supervisor image versions. Recommended for production. | +| `0.0.0-dev` | Latest commit on `main` | Floating tag, overwritten on every push. `appVersion` is `dev`, so images resolve to the `:dev` tag. | +| `0.0.0-dev.` | A specific commit on `main` | Per-commit pin. Chart version and `appVersion` both use the full 40-character commit SHA, which matches the image tag pushed by CI. | + +The `dev` tags are intended for testing changes ahead of a release. Production deployments should pin to a tagged release. + +## Configuration + +See [`values.yaml`](values.yaml) for source defaults. 
Selected overlays: + +- [`ci/values-gateway.yaml`](ci/values-gateway.yaml) - gateway-only configuration +- [`ci/values-cert-manager.yaml`](ci/values-cert-manager.yaml) - cert-manager integration +- [`ci/values-keycloak.yaml`](ci/values-keycloak.yaml) - Keycloak OIDC integration + +## PKI bootstrap + +By default, a pre-install/pre-upgrade hook Job runs `openshell-gateway generate-certs` +to create the gateway's server and client mTLS Secrets. The Job uses the gateway image +itself, so air-gapped environments only need to mirror that one image (no separate +openssl/alpine sidecar). + +The Job is idempotent: + +- Both target Secrets exist: log and exit 0. +- Exactly one exists: fail with `kubectl delete secret -n ` recovery hint. +- Neither exists: generate a CA, server cert, and client cert; POST both `kubernetes.io/tls` Secrets (`tls.crt`, `tls.key`, `ca.crt`). + +Disable with `--set pkiInitJob.enabled=false` when bringing your own PKI (cert-manager, +external CA, or pre-created Secrets). See `certManager.*` in `values.yaml` for the +cert-manager alternative. + +{{ template "chart.valuesSection" . }} +{{ template "helm-docs.versionFooter" . }} diff --git a/deploy/helm/openshell/ci/values-skaffold.yaml b/deploy/helm/openshell/ci/values-skaffold.yaml index 24b60e1c6..795c056f6 100644 --- a/deploy/helm/openshell/ci/values-skaffold.yaml +++ b/deploy/helm/openshell/ci/values-skaffold.yaml @@ -6,6 +6,8 @@ server: sandboxImagePullPolicy: IfNotPresent # Comment out to enforce mTLS (uses PKI secrets generated by pkiInitJob). disableTls: true + auth: + allowUnauthenticatedUsers: true supervisor: image: diff --git a/deploy/helm/openshell/ci/values-tls-disabled.yaml b/deploy/helm/openshell/ci/values-tls-disabled.yaml index ea7c7900c..7a771a178 100644 --- a/deploy/helm/openshell/ci/values-tls-disabled.yaml +++ b/deploy/helm/openshell/ci/values-tls-disabled.yaml @@ -5,6 +5,5 @@ # Typical when a reverse proxy or tunnel terminates TLS at the edge. 
server: disableTls: true - disableGatewayAuth: true pkiInitJob: enabled: false diff --git a/deploy/helm/openshell/skaffold.yaml b/deploy/helm/openshell/skaffold.yaml index 2de9ee4e6..779211877 100644 --- a/deploy/helm/openshell/skaffold.yaml +++ b/deploy/helm/openshell/skaffold.yaml @@ -1,12 +1,15 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -# Local dev: builds gateway + supervisor images using Dockerfile.images with -# BUILD_FROM_SOURCE=1, which compiles Rust binaries inside Docker without -# requiring pre-staged artifacts. +# Local dev: builds gateway + supervisor images via tasks/scripts/docker-build-image.sh, +# which first stages Rust binaries natively on the host (using cargo / cargo-zigbuild +# when cross-compiling) and then builds the image from the prebuilt binary. This +# mirrors CI and is faster than compiling inside Docker on every rebuild because +# the host's cargo target cache and sccache are reused across iterations. # # Run from repo root: -# skaffold dev -f deploy/helm/openshell/skaffold.yaml +# mise run helm:skaffold:dev +# mise run helm:skaffold:run # # See https://skaffold.dev/docs/deployers/helm/ (setValueTemplates, IMAGE_* fields). apiVersion: skaffold/v4beta14 @@ -23,39 +26,34 @@ build: context: ../../.. custom: buildCommand: | - docker buildx build \ - --build-arg BUILD_FROM_SOURCE=1 \ - --target gateway \ - --tag "$IMAGE" \ - --load \ - --file deploy/docker/Dockerfile.images \ - . + IMAGE_NAME="${IMAGE%:*}" \ + IMAGE_TAG="${IMAGE##*:}" \ + tasks/scripts/docker-build-image.sh gateway dependencies: paths: - Cargo.toml - Cargo.lock - crates/** - proto/** - - deploy/docker/Dockerfile.images - - crates/openshell-server/migrations/** + - deploy/docker/Dockerfile.gateway + - tasks/scripts/docker-build-image.sh + - tasks/scripts/stage-prebuilt-binaries.sh - image: openshell/supervisor context: ../../.. 
custom: buildCommand: | - docker buildx build \ - --build-arg BUILD_FROM_SOURCE=1 \ - --target supervisor \ - --tag "$IMAGE" \ - --load \ - --file deploy/docker/Dockerfile.images \ - . + IMAGE_NAME="${IMAGE%:*}" \ + IMAGE_TAG="${IMAGE##*:}" \ + tasks/scripts/docker-build-image.sh supervisor dependencies: paths: - Cargo.toml - Cargo.lock - crates/** - proto/** - - deploy/docker/Dockerfile.images + - deploy/docker/Dockerfile.supervisor + - tasks/scripts/docker-build-image.sh + - tasks/scripts/stage-prebuilt-binaries.sh deploy: helm: releases: diff --git a/deploy/helm/openshell/templates/_helpers.tpl b/deploy/helm/openshell/templates/_helpers.tpl index 09159340d..3e375a54a 100644 --- a/deploy/helm/openshell/templates/_helpers.tpl +++ b/deploy/helm/openshell/templates/_helpers.tpl @@ -59,6 +59,17 @@ Create the name of the service account to use {{- end }} {{- end }} +{{/* +Create the name of the service account assigned to sandbox pods +*/}} +{{- define "openshell.sandboxServiceAccountName" -}} +{{- if .Values.sandboxServiceAccount.create }} +{{- default (printf "%s-sandbox" (include "openshell.fullname" .) | trunc 63 | trimSuffix "-") .Values.sandboxServiceAccount.name }} +{{- else }} +{{- default "default" .Values.sandboxServiceAccount.name }} +{{- end }} +{{- end }} + {{/* Gateway image reference. Uses image.tag when set; falls back to .Chart.AppVersion so a released chart automatically pulls the matching image without extra overrides. @@ -81,3 +92,47 @@ Namespaced Issuer (selfSigned) for cert-manager CA bootstrap. {{- define "openshell.issuerSelfSigned" -}} {{- printf "%s-selfsigned" (include "openshell.fullname" .) | trunc 63 | trimSuffix "-" }} {{- end }} + +{{/* +Namespace where sandbox pods are created. An explicit +.Values.server.sandboxNamespace is used verbatim. Otherwise it defaults to +.Release.Namespace so `helm install -n my-ns` works without extra overrides. 
+*/}} +{{- define "openshell.sandboxNamespace" -}} +{{- .Values.server.sandboxNamespace | default .Release.Namespace -}} +{{- end }} + +{{/* +gRPC endpoint sandbox pods use to call back into the gateway. An explicit +.Values.server.grpcEndpoint is used verbatim. Otherwise it is derived from +the in-cluster Service DNS, release namespace, service port, and disableTls +flag — so the default value works for any release name or namespace without +override. +*/}} +{{/* +Supervisor sideload method. When supervisor.sideloadMethod is set, use it +verbatim. Otherwise auto-detect from the cluster version: the ImageVolume +feature gate is enabled by default starting in K8s v1.35 (GA in v1.36). +Clusters on v1.33-v1.34 can opt in by setting sideloadMethod explicitly +after enabling the feature gate. +*/}} +{{- define "openshell.supervisorSideloadMethod" -}} +{{- if .Values.supervisor.sideloadMethod -}} +{{- .Values.supervisor.sideloadMethod -}} +{{- else -}} +{{- if semverCompare ">=1.35-0" .Capabilities.KubeVersion.Version -}} +image-volume +{{- else -}} +init-container +{{- end -}} +{{- end -}} +{{- end }} + +{{- define "openshell.grpcEndpoint" -}} +{{- if .Values.server.grpcEndpoint -}} +{{- .Values.server.grpcEndpoint -}} +{{- else -}} +{{- $scheme := ternary "http" "https" (default false .Values.server.disableTls) -}} +{{- printf "%s://%s.%s.svc.cluster.local:%d" $scheme (include "openshell.fullname" .) .Release.Namespace (int .Values.service.port) -}} +{{- end -}} +{{- end }} diff --git a/deploy/helm/openshell/templates/certgen.yaml b/deploy/helm/openshell/templates/certgen.yaml new file mode 100644 index 000000000..61203760b --- /dev/null +++ b/deploy/helm/openshell/templates/certgen.yaml @@ -0,0 +1,110 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +{{- if and .Values.pkiInitJob.enabled .Values.certManager.enabled }} +{{- fail "pkiInitJob.enabled and certManager.enabled cannot both be true; disable one to avoid conflicting PKI sources." }} +{{- end }} +{{- if .Values.pkiInitJob.enabled }} +{{- $hookName := printf "%s-certgen" (include "openshell.fullname" .) }} +{{- $ns := .Release.Namespace }} +apiVersion: v1 +kind: ServiceAccount +metadata: + name: {{ $hookName }} + namespace: {{ $ns }} + labels: + {{- include "openshell.labels" . | nindent 4 }} + annotations: + helm.sh/hook: pre-install,pre-upgrade + helm.sh/hook-weight: "-30" + helm.sh/hook-delete-policy: before-hook-creation +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: Role +metadata: + name: {{ $hookName }} + namespace: {{ $ns }} + labels: + {{- include "openshell.labels" . | nindent 4 }} + annotations: + helm.sh/hook: pre-install,pre-upgrade + helm.sh/hook-weight: "-30" + helm.sh/hook-delete-policy: before-hook-creation +rules: + - apiGroups: [""] + resources: ["secrets"] + verbs: ["get", "create"] +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: RoleBinding +metadata: + name: {{ $hookName }} + namespace: {{ $ns }} + labels: + {{- include "openshell.labels" . | nindent 4 }} + annotations: + helm.sh/hook: pre-install,pre-upgrade + helm.sh/hook-weight: "-30" + helm.sh/hook-delete-policy: before-hook-creation +roleRef: + apiGroup: rbac.authorization.k8s.io + kind: Role + name: {{ $hookName }} +subjects: + - kind: ServiceAccount + name: {{ $hookName }} + namespace: {{ $ns }} +--- +apiVersion: batch/v1 +kind: Job +metadata: + name: {{ $hookName }} + namespace: {{ $ns }} + labels: + {{- include "openshell.labels" . 
| nindent 4 }} + annotations: + helm.sh/hook: pre-install,pre-upgrade + helm.sh/hook-weight: "-20" + helm.sh/hook-delete-policy: before-hook-creation,hook-succeeded +spec: + backoffLimit: 3 + activeDeadlineSeconds: 120 + ttlSecondsAfterFinished: 300 + template: + metadata: + labels: + {{- include "openshell.selectorLabels" . | nindent 8 }} + spec: + restartPolicy: OnFailure + serviceAccountName: {{ $hookName }} + {{- with .Values.imagePullSecrets }} + imagePullSecrets: + {{- toYaml . | nindent 8 }} + {{- end }} + containers: + - name: certgen + image: {{ include "openshell.image" . | quote }} + imagePullPolicy: {{ .Values.image.pullPolicy }} + securityContext: + allowPrivilegeEscalation: false + capabilities: + drop: + - ALL + env: + - name: POD_NAMESPACE + valueFrom: + fieldRef: + fieldPath: metadata.namespace + command: ["/usr/local/bin/openshell-gateway"] + args: + - generate-certs + - --server-secret-name={{ .Values.server.tls.certSecretName }} + - --client-secret-name={{ .Values.server.tls.clientTlsSecretName }} + - --jwt-secret-name={{ .Values.server.sandboxJwt.signingSecretName | default (printf "%s-jwt-keys" (include "openshell.fullname" .)) }} + {{- range .Values.pkiInitJob.serverDnsNames }} + - --server-san={{ . }} + {{- end }} + {{- range .Values.pkiInitJob.serverIpAddresses }} + - --server-san={{ . }} + {{- end }} +{{- end }} diff --git a/deploy/helm/openshell/templates/clusterrole.yaml b/deploy/helm/openshell/templates/clusterrole.yaml index a660aee75..30a192fc3 100644 --- a/deploy/helm/openshell/templates/clusterrole.yaml +++ b/deploy/helm/openshell/templates/clusterrole.yaml @@ -8,6 +8,14 @@ metadata: labels: {{- include "openshell.labels" . | nindent 4 }} rules: + # Validate projected sandbox ServiceAccount tokens during the + # IssueSandboxToken bootstrap exchange. 
+ - apiGroups: + - authentication.k8s.io + resources: + - tokenreviews + verbs: + - create - apiGroups: - "" resources: diff --git a/deploy/helm/openshell/templates/gateway-config.yaml b/deploy/helm/openshell/templates/gateway-config.yaml new file mode 100644 index 000000000..bd74664c5 --- /dev/null +++ b/deploy/helm/openshell/templates/gateway-config.yaml @@ -0,0 +1,112 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +{{/* +ConfigMap holding the gateway TOML config file (RFC 0003). + +The gateway reads `/etc/openshell/gateway.toml` (mounted from this ConfigMap) +at startup. CLI flags and OPENSHELL_* env vars on the StatefulSet container +still override anything in this file. + +One value is intentionally NOT rendered here: + - server.dbUrl → passed via --db-url in the StatefulSet args +*/}} +apiVersion: v1 +kind: ConfigMap +metadata: + name: {{ include "openshell.fullname" . }}-config + labels: + {{- include "openshell.labels" . | nindent 4 }} +data: + gateway.toml: | + [openshell] + version = 1 + + [openshell.gateway] + bind_address = "0.0.0.0:{{ .Values.service.port }}" + {{- if .Values.service.healthPort }} + health_bind_address = "0.0.0.0:{{ .Values.service.healthPort }}" + {{- end }} + {{- if .Values.service.metricsPort }} + metrics_bind_address = "0.0.0.0:{{ .Values.service.metricsPort }}" + {{- end }} + log_level = {{ .Values.server.logLevel | quote }} + sandbox_namespace = {{ include "openshell.sandboxNamespace" . | quote }} + default_image = {{ .Values.server.sandboxImage | quote }} + supervisor_image = {{ include "openshell.supervisorImage" . 
| quote }} + {{- if .Values.server.hostGatewayIP }} + host_gateway_ip = {{ .Values.server.hostGatewayIP | quote }} + {{- end }} + {{- if .Values.server.enableUserNamespaces }} + enable_user_namespaces = true + {{- end }} + {{- if .Values.server.disableTls }} + disable_tls = true + {{- else }} + client_tls_secret_name = {{ .Values.server.tls.clientTlsSecretName | quote }} + {{- end }} + enable_loopback_service_http = {{ .Values.server.enableLoopbackServiceHttp }} + {{- $sans := list -}} + {{- if and .Values.certManager.enabled .Values.certManager.serverDnsNames }} + {{- $sans = .Values.certManager.serverDnsNames }} + {{- else if and .Values.pkiInitJob.enabled .Values.pkiInitJob.serverDnsNames }} + {{- $sans = .Values.pkiInitJob.serverDnsNames }} + {{- end }} + {{- if $sans }} + server_sans = [{{- range $i, $san := $sans }}{{ if $i }}, {{ end }}{{ $san | quote }}{{- end }}] + {{- end }} + + {{- if not .Values.server.disableTls }} + + [openshell.gateway.tls] + cert_path = "/etc/openshell-tls/server/tls.crt" + key_path = "/etc/openshell-tls/server/tls.key" + client_ca_path = "/etc/openshell-tls/client-ca/ca.crt" + {{- end }} + + {{- if .Values.server.auth.allowUnauthenticatedUsers }} + + [openshell.gateway.auth] + allow_unauthenticated_users = true + {{- end }} + + [openshell.gateway.gateway_jwt] + signing_key_path = "/etc/openshell-jwt/signing.pem" + public_key_path = "/etc/openshell-jwt/public.pem" + kid_path = "/etc/openshell-jwt/kid" + gateway_id = {{ .Values.server.sandboxJwt.gatewayId | default (include "openshell.fullname" .) 
| quote }} + ttl_secs = {{ .Values.server.sandboxJwt.ttlSecs | default 3600 }} + + {{- if .Values.server.oidc.issuer }} + + [openshell.gateway.oidc] + issuer = {{ .Values.server.oidc.issuer | quote }} + audience = {{ .Values.server.oidc.audience | quote }} + jwks_ttl_secs = {{ .Values.server.oidc.jwksTtl }} + {{- if .Values.server.oidc.rolesClaim }} + roles_claim = {{ .Values.server.oidc.rolesClaim | quote }} + {{- end }} + {{- if .Values.server.oidc.adminRole }} + admin_role = {{ .Values.server.oidc.adminRole | quote }} + {{- end }} + {{- if .Values.server.oidc.userRole }} + user_role = {{ .Values.server.oidc.userRole | quote }} + {{- end }} + {{- if .Values.server.oidc.scopesClaim }} + scopes_claim = {{ .Values.server.oidc.scopesClaim | quote }} + {{- end }} + {{- end }} + + [openshell.drivers.kubernetes] + grpc_endpoint = {{ include "openshell.grpcEndpoint" . | quote }} + service_account_name = {{ include "openshell.sandboxServiceAccountName" . | quote }} + supervisor_sideload_method = {{ include "openshell.supervisorSideloadMethod" . | quote }} + sa_token_ttl_secs = {{ .Values.server.sandboxJwt.k8sSaTokenTtlSecs | default 3600 }} + {{- if .Values.server.sandboxImagePullPolicy }} + image_pull_policy = {{ .Values.server.sandboxImagePullPolicy | quote }} + {{- end }} + {{- if .Values.server.workspaceDefaultStorageSize }} + workspace_default_storage_size = {{ .Values.server.workspaceDefaultStorageSize | quote }} + {{- end }} + {{- if .Values.supervisor.image.pullPolicy }} + supervisor_image_pull_policy = {{ .Values.supervisor.image.pullPolicy | quote }} + {{- end }} diff --git a/deploy/helm/openshell/templates/networkpolicy.yaml b/deploy/helm/openshell/templates/networkpolicy.yaml index 3e0b6f504..e85571e5f 100644 --- a/deploy/helm/openshell/templates/networkpolicy.yaml +++ b/deploy/helm/openshell/templates/networkpolicy.yaml @@ -11,7 +11,7 @@ apiVersion: networking.k8s.io/v1 kind: NetworkPolicy metadata: name: {{ include "openshell.fullname" . 
}}-sandbox-ssh - namespace: {{ .Values.server.sandboxNamespace }} + namespace: {{ include "openshell.sandboxNamespace" . }} labels: {{- include "openshell.labels" . | nindent 4 }} spec: diff --git a/deploy/helm/openshell/templates/pki-hook.yaml b/deploy/helm/openshell/templates/pki-hook.yaml deleted file mode 100644 index c5e83c734..000000000 --- a/deploy/helm/openshell/templates/pki-hook.yaml +++ /dev/null @@ -1,191 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -{{- if and .Values.pkiInitJob.enabled .Values.certManager.enabled }} -{{- fail "pkiInitJob.enabled and certManager.enabled cannot both be true; disable one to avoid conflicting PKI sources." }} -{{- end }} -{{- if .Values.pkiInitJob.enabled }} -{{- $hookName := printf "%s-pki-hook" (include "openshell.fullname" .) }} -{{- $ns := .Release.Namespace }} -{{- $serverSecret := .Values.server.tls.certSecretName }} -{{- $clientSecret := .Values.server.tls.clientTlsSecretName }} -{{- $sanParts := list }} -{{- range .Values.pkiInitJob.serverDnsNames }}{{- $sanParts = append $sanParts (printf "DNS:%s" .) }}{{- end }} -{{- range .Values.pkiInitJob.serverIpAddresses }}{{- $sanParts = append $sanParts (printf "IP:%s" .) }}{{- end }} -{{- $serverSans := join "," $sanParts }} -apiVersion: v1 -kind: ServiceAccount -metadata: - name: {{ $hookName }} - namespace: {{ $ns }} - labels: - {{- include "openshell.labels" . | nindent 4 }} - annotations: - helm.sh/hook: pre-install,pre-upgrade - helm.sh/hook-weight: "-30" - helm.sh/hook-delete-policy: before-hook-creation ---- -apiVersion: rbac.authorization.k8s.io/v1 -kind: Role -metadata: - name: {{ $hookName }} - namespace: {{ $ns }} - labels: - {{- include "openshell.labels" . 
| nindent 4 }} - annotations: - helm.sh/hook: pre-install,pre-upgrade - helm.sh/hook-weight: "-30" - helm.sh/hook-delete-policy: before-hook-creation -rules: - - apiGroups: [""] - resources: ["secrets"] - verbs: ["get", "create"] ---- -apiVersion: rbac.authorization.k8s.io/v1 -kind: RoleBinding -metadata: - name: {{ $hookName }} - namespace: {{ $ns }} - labels: - {{- include "openshell.labels" . | nindent 4 }} - annotations: - helm.sh/hook: pre-install,pre-upgrade - helm.sh/hook-weight: "-30" - helm.sh/hook-delete-policy: before-hook-creation -roleRef: - apiGroup: rbac.authorization.k8s.io - kind: Role - name: {{ $hookName }} -subjects: - - kind: ServiceAccount - name: {{ $hookName }} - namespace: {{ $ns }} ---- -apiVersion: batch/v1 -kind: Job -metadata: - name: {{ $hookName }} - namespace: {{ $ns }} - labels: - {{- include "openshell.labels" . | nindent 4 }} - annotations: - helm.sh/hook: pre-install,pre-upgrade - helm.sh/hook-weight: "-20" - helm.sh/hook-delete-policy: before-hook-creation,hook-succeeded -spec: - backoffLimit: 3 - activeDeadlineSeconds: 120 - ttlSecondsAfterFinished: 300 - template: - metadata: - labels: - {{- include "openshell.selectorLabels" . 
| nindent 8 }} - spec: - restartPolicy: OnFailure - serviceAccountName: {{ $hookName }} - containers: - - name: pki-gen - image: {{ .Values.pkiInitJob.image.repository }}:{{ .Values.pkiInitJob.image.tag }} - imagePullPolicy: {{ .Values.pkiInitJob.image.pullPolicy }} - securityContext: - allowPrivilegeEscalation: false - capabilities: - drop: - - ALL - env: - - name: NAMESPACE - valueFrom: - fieldRef: - fieldPath: metadata.namespace - - name: SERVER_SECRET - value: {{ $serverSecret | quote }} - - name: CLIENT_SECRET - value: {{ $clientSecret | quote }} - - name: CA_DAYS - value: {{ .Values.pkiInitJob.caValidityDays | quote }} - - name: CERT_DAYS - value: {{ .Values.pkiInitJob.certValidityDays | quote }} - - name: SERVER_SANS - value: {{ $serverSans | quote }} - command: - - /bin/sh - - -c - - | - set -eu - apk add --no-cache openssl curl >/dev/null 2>&1 - - TOKEN=$(cat /var/run/secrets/kubernetes.io/serviceaccount/token) - K8S_CA=/var/run/secrets/kubernetes.io/serviceaccount/ca.crt - API=https://kubernetes.default.svc - - # Idempotency: skip only when both TLS secrets already exist. - # Checking one is insufficient — a partial cleanup can leave one half - # of the pair behind, which would cause mTLS to fail at runtime. - HTTP_SERVER=$(curl -s -o /dev/null -w "%{http_code}" \ - -H "Authorization: Bearer $TOKEN" --cacert "$K8S_CA" \ - "$API/api/v1/namespaces/$NAMESPACE/secrets/$SERVER_SECRET") - HTTP_CLIENT=$(curl -s -o /dev/null -w "%{http_code}" \ - -H "Authorization: Bearer $TOKEN" --cacert "$K8S_CA" \ - "$API/api/v1/namespaces/$NAMESPACE/secrets/$CLIENT_SECRET") - if [ "$HTTP_SERVER" = "200" ] && [ "$HTTP_CLIENT" = "200" ]; then - echo "PKI secrets already exist, skipping." - exit 0 - fi - if [ "$HTTP_SERVER" = "200" ] || [ "$HTTP_CLIENT" = "200" ]; then - echo "ERROR: partial PKI state — one secret exists but not both." 
>&2 - echo "To recover: kubectl delete secret -n $NAMESPACE $SERVER_SECRET $CLIENT_SECRET" >&2 - exit 1 - fi - - cd /tmp - - # CA (ECDSA P-256) - openssl genpkey -algorithm EC -pkeyopt ec_paramgen_curve:P-256 -out ca.key 2>/dev/null - openssl req -new -x509 -sha256 -key ca.key -out ca.crt \ - -days "$CA_DAYS" -subj "/O=openshell/CN=openshell-ca" \ - -addext "basicConstraints=critical,CA:TRUE,pathlen:0" \ - -addext "keyUsage=critical,keyCertSign,cRLSign" - - # Server cert (ECDSA P-256) - printf "[ext]\nsubjectAltName=%s\nextendedKeyUsage=serverAuth\nkeyUsage=digitalSignature,keyEncipherment\n" \ - "$SERVER_SANS" > server.ext - openssl genpkey -algorithm EC -pkeyopt ec_paramgen_curve:P-256 -out server.key 2>/dev/null - openssl req -new -sha256 -key server.key -out server.csr -subj "/CN=openshell-server" - openssl x509 -req -sha256 -in server.csr -CA ca.crt -CAkey ca.key \ - -CAcreateserial -days "$CERT_DAYS" -extensions ext -extfile server.ext -out server.crt - - # Client cert (ECDSA P-256) - printf "[ext]\nextendedKeyUsage=clientAuth\nkeyUsage=digitalSignature,keyEncipherment\n" \ - > client.ext - openssl genpkey -algorithm EC -pkeyopt ec_paramgen_curve:P-256 -out client.key 2>/dev/null - openssl req -new -sha256 -key client.key -out client.csr -subj "/CN=openshell-client" - openssl x509 -req -sha256 -in client.csr -CA ca.crt -CAkey ca.key \ - -CAcreateserial -days "$CERT_DAYS" -extensions ext -extfile client.ext -out client.crt - - CA_B64=$(base64 -w0 ca.crt) - SERVER_CRT_B64=$(base64 -w0 server.crt) - SERVER_KEY_B64=$(base64 -w0 server.key) - CLIENT_CRT_B64=$(base64 -w0 client.crt) - CLIENT_KEY_B64=$(base64 -w0 client.key) - - # Create server TLS secret - printf '{"apiVersion":"v1","kind":"Secret","metadata":{"name":"%s","namespace":"%s"},"type":"kubernetes.io/tls","data":{"tls.crt":"%s","tls.key":"%s","ca.crt":"%s"}}\n' \ - "$SERVER_SECRET" "$NAMESPACE" \ - "$SERVER_CRT_B64" "$SERVER_KEY_B64" "$CA_B64" > server-secret.json - curl -sf -X POST \ - -H 
"Authorization: Bearer $TOKEN" -H "Content-Type: application/json" \ - --cacert "$K8S_CA" "$API/api/v1/namespaces/$NAMESPACE/secrets" \ - -d @server-secret.json - - # Create client TLS secret - printf '{"apiVersion":"v1","kind":"Secret","metadata":{"name":"%s","namespace":"%s"},"type":"kubernetes.io/tls","data":{"tls.crt":"%s","tls.key":"%s","ca.crt":"%s"}}\n' \ - "$CLIENT_SECRET" "$NAMESPACE" \ - "$CLIENT_CRT_B64" "$CLIENT_KEY_B64" "$CA_B64" > client-secret.json - curl -sf -X POST \ - -H "Authorization: Bearer $TOKEN" -H "Content-Type: application/json" \ - --cacert "$K8S_CA" "$API/api/v1/namespaces/$NAMESPACE/secrets" \ - -d @client-secret.json - - rm -f *.key *.csr *.crt *.ext *.srl *.json - echo "PKI secrets created." -{{- end }} diff --git a/deploy/helm/openshell/templates/role.yaml b/deploy/helm/openshell/templates/role.yaml index 1d756117c..5ecc4428a 100644 --- a/deploy/helm/openshell/templates/role.yaml +++ b/deploy/helm/openshell/templates/role.yaml @@ -5,6 +5,7 @@ apiVersion: rbac.authorization.k8s.io/v1 kind: Role metadata: name: {{ include "openshell.fullname" . }}-sandbox + namespace: {{ include "openshell.sandboxNamespace" . }} labels: {{- include "openshell.labels" . | nindent 4 }} rules: @@ -29,3 +30,15 @@ rules: - get - list - watch + # Per-sandbox identity: TokenReview authenticates the projected token from + # the configured sandbox service account, then the gateway resolves the + # returned pod name and UID to the pod's `openshell.io/sandbox-id` + # annotation. patch is intentionally NOT granted — the annotation is set + # once at pod create and must remain immutable for the lifetime of the + # sandbox. 
+ - apiGroups: + - "" + resources: + - pods + verbs: + - get diff --git a/deploy/helm/openshell/templates/rolebinding.yaml b/deploy/helm/openshell/templates/rolebinding.yaml index 2bb3c7d08..e5233f753 100644 --- a/deploy/helm/openshell/templates/rolebinding.yaml +++ b/deploy/helm/openshell/templates/rolebinding.yaml @@ -5,6 +5,7 @@ apiVersion: rbac.authorization.k8s.io/v1 kind: RoleBinding metadata: name: {{ include "openshell.fullname" . }}-sandbox + namespace: {{ include "openshell.sandboxNamespace" . }} labels: {{- include "openshell.labels" . | nindent 4 }} roleRef: diff --git a/deploy/helm/openshell/templates/serviceaccount.yaml b/deploy/helm/openshell/templates/serviceaccount.yaml index 1f03f8e94..a98ad5363 100644 --- a/deploy/helm/openshell/templates/serviceaccount.yaml +++ b/deploy/helm/openshell/templates/serviceaccount.yaml @@ -13,3 +13,19 @@ metadata: {{- toYaml . | nindent 4 }} {{- end }} {{- end }} +{{- if and .Values.serviceAccount.create .Values.sandboxServiceAccount.create }} +--- +{{- end }} +{{- if .Values.sandboxServiceAccount.create }} +apiVersion: v1 +kind: ServiceAccount +metadata: + name: {{ include "openshell.sandboxServiceAccountName" . }} + namespace: {{ include "openshell.sandboxNamespace" . }} + labels: + {{- include "openshell.labels" . | nindent 4 }} + {{- with .Values.sandboxServiceAccount.annotations }} + annotations: + {{- toYaml . | nindent 4 }} + {{- end }} +{{- end }} diff --git a/deploy/helm/openshell/templates/ssh-handshake-secret-hook.yaml b/deploy/helm/openshell/templates/ssh-handshake-secret-hook.yaml deleted file mode 100644 index ad444847b..000000000 --- a/deploy/helm/openshell/templates/ssh-handshake-secret-hook.yaml +++ /dev/null @@ -1,27 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# SPDX-License-Identifier: Apache-2.0 - -{{- if .Values.sshHandshake.hook.enabled }} -{{- $name := .Values.server.sshHandshakeSecretName }} -{{- $ns := .Release.Namespace }} -{{- $existing := lookup "v1" "Secret" $ns $name }} -{{- if not $existing }} -{{- $hex := .Values.sshHandshake.value }} -{{- if not $hex }} -{{- $hex = printf "%s%s" (uuidv4 | replace "-" "") (uuidv4 | replace "-" "") }} -{{- end }} -apiVersion: v1 -kind: Secret -metadata: - name: {{ $name }} - namespace: {{ $ns }} - labels: - {{- include "openshell.labels" . | nindent 4 }} - annotations: - helm.sh/hook: pre-install,pre-upgrade - helm.sh/hook-weight: "-20" -type: Opaque -stringData: - secret: {{ $hex | quote }} -{{- end }} -{{- end }} diff --git a/deploy/helm/openshell/templates/statefulset.yaml b/deploy/helm/openshell/templates/statefulset.yaml index 2db3a0c5f..5dd4f1caf 100644 --- a/deploy/helm/openshell/templates/statefulset.yaml +++ b/deploy/helm/openshell/templates/statefulset.yaml @@ -15,10 +15,15 @@ spec: {{- include "openshell.selectorLabels" . | nindent 6 }} template: metadata: - {{- with .Values.podAnnotations }} annotations: + # Roll the StatefulSet when the rendered gateway TOML changes — the + # gateway only reads /etc/openshell/gateway.toml at startup, so + # without this annotation a `helm upgrade` that only mutates the + # ConfigMap would leave pods running with stale config. + checksum/gateway-config: {{ include (print $.Template.BasePath "/gateway-config.yaml") . | sha256sum }} + {{- with .Values.podAnnotations }} {{- toYaml . | nindent 8 }} - {{- end }} + {{- end }} labels: {{- include "openshell.labels" . | nindent 8 }} {{- with .Values.podLabels }} @@ -47,106 +52,47 @@ spec: image: {{ include "openshell.image" . 
| quote }} imagePullPolicy: {{ .Values.image.pullPolicy }} args: - - --bind-address - - "0.0.0.0" - - --port - - {{ .Values.service.port | quote }} - - --health-port - - {{ .Values.service.healthPort | quote }} - {{- if .Values.service.metricsPort }} - - --metrics-port - - {{ .Values.service.metricsPort | quote }} - {{- end }} - - --log-level - - {{ .Values.server.logLevel }} + - --config + - /etc/openshell/gateway.toml - --db-url - {{ .Values.server.dbUrl | quote }} env: - - name: OPENSHELL_SANDBOX_NAMESPACE - value: {{ .Values.server.sandboxNamespace | quote }} - - name: OPENSHELL_SANDBOX_IMAGE - value: {{ .Values.server.sandboxImage | quote }} - {{- if .Values.server.sandboxImagePullPolicy }} - - name: OPENSHELL_SANDBOX_IMAGE_PULL_POLICY - value: {{ .Values.server.sandboxImagePullPolicy | quote }} - {{- end }} - - name: OPENSHELL_SUPERVISOR_IMAGE - value: {{ include "openshell.supervisorImage" . | quote }} - {{- if .Values.supervisor.image.pullPolicy }} - - name: OPENSHELL_SUPERVISOR_IMAGE_PULL_POLICY - value: {{ .Values.supervisor.image.pullPolicy | quote }} - {{- end }} - - name: OPENSHELL_GRPC_ENDPOINT - value: {{ if .Values.server.disableTls }}{{ .Values.server.grpcEndpoint | replace "https://" "http://" | quote }}{{ else }}{{ .Values.server.grpcEndpoint | quote }}{{ end }} - {{- if .Values.server.sshGatewayHost }} - - name: OPENSHELL_SSH_GATEWAY_HOST - value: {{ .Values.server.sshGatewayHost | quote }} - {{- end }} - {{- if .Values.server.sshGatewayPort }} - - name: OPENSHELL_SSH_GATEWAY_PORT - value: {{ .Values.server.sshGatewayPort | quote }} - {{- end }} - {{- if .Values.server.hostGatewayIP }} - - name: OPENSHELL_HOST_GATEWAY_IP - value: {{ .Values.server.hostGatewayIP | quote }} - {{- end }} - - name: OPENSHELL_SSH_HANDSHAKE_SECRET - valueFrom: - secretKeyRef: - name: {{ .Values.server.sshHandshakeSecretName | quote }} - key: secret - {{- if .Values.server.disableTls }} - - name: OPENSHELL_DISABLE_TLS - value: "true" - {{- else }} - - name: 
OPENSHELL_TLS_CERT - value: /etc/openshell-tls/server/tls.crt - - name: OPENSHELL_TLS_KEY - value: /etc/openshell-tls/server/tls.key - - name: OPENSHELL_TLS_CLIENT_CA - value: /etc/openshell-tls/client-ca/ca.crt - - name: OPENSHELL_CLIENT_TLS_SECRET_NAME - value: {{ .Values.server.tls.clientTlsSecretName | quote }} - {{- if .Values.server.disableGatewayAuth }} - - name: OPENSHELL_DISABLE_GATEWAY_AUTH - value: "true" - {{- end }} - {{- end }} - {{- if .Values.server.oidc.issuer }} - - name: OPENSHELL_OIDC_ISSUER - value: {{ .Values.server.oidc.issuer | quote }} - - name: OPENSHELL_OIDC_AUDIENCE - value: {{ .Values.server.oidc.audience | quote }} - - name: OPENSHELL_OIDC_JWKS_TTL - value: {{ .Values.server.oidc.jwksTtl | quote }} - {{- if .Values.server.oidc.rolesClaim }} - - name: OPENSHELL_OIDC_ROLES_CLAIM - value: {{ .Values.server.oidc.rolesClaim | quote }} - {{- end }} - {{- if .Values.server.oidc.adminRole }} - - name: OPENSHELL_OIDC_ADMIN_ROLE - value: {{ .Values.server.oidc.adminRole | quote }} - {{- end }} - {{- if .Values.server.oidc.userRole }} - - name: OPENSHELL_OIDC_USER_ROLE - value: {{ .Values.server.oidc.userRole | quote }} - {{- end }} - {{- if .Values.server.oidc.scopesClaim }} - - name: OPENSHELL_OIDC_SCOPES_CLAIM - value: {{ .Values.server.oidc.scopesClaim | quote }} - {{- end }} + # All gateway settings live in the ConfigMap-backed TOML file + # mounted at /etc/openshell/gateway.toml. The only env var below + # is a process-level setting consumed by libraries outside + # gateway code (currently just SSL_CERT_FILE for OIDC issuer TLS). + {{- if and .Values.server.oidc.issuer .Values.server.oidc.caConfigMapName }} + # OIDC issuer custom-CA: rustls/reqwest read SSL_CERT_FILE for + # outbound TLS verification. This is a process-level env var + # consumed by the TLS stack itself, not by gateway code, so it + # cannot be represented in the gateway TOML schema. 
+ - name: SSL_CERT_FILE + value: /etc/openshell-tls/oidc-ca/ca.crt {{- end }} volumeMounts: - name: openshell-data mountPath: /var/openshell + - name: gateway-config + mountPath: /etc/openshell + readOnly: true + - name: sandbox-jwt + mountPath: /etc/openshell-jwt + readOnly: true {{- if not .Values.server.disableTls }} - name: tls-cert mountPath: /etc/openshell-tls/server readOnly: true + {{- if or .Values.server.tls.clientCaSecretName .Values.pkiInitJob.enabled (and .Values.certManager.enabled .Values.certManager.clientCaFromServerTlsSecret) }} - name: tls-client-ca mountPath: /etc/openshell-tls/client-ca readOnly: true {{- end }} + {{- end }} + {{- if and .Values.server.oidc.issuer .Values.server.oidc.caConfigMapName }} + - name: oidc-ca + mountPath: /etc/openshell-tls/oidc-ca + readOnly: true + {{- end }} ports: - name: grpc containerPort: {{ .Values.service.port }} @@ -185,10 +131,18 @@ spec: resources: {{- toYaml .Values.resources | nindent 12 }} volumes: + - name: gateway-config + configMap: + name: {{ include "openshell.fullname" . 
}}-config + - name: sandbox-jwt + secret: + secretName: {{ .Values.server.sandboxJwt.signingSecretName | default (printf "%s-jwt-keys" (include "openshell.fullname" .)) }} + defaultMode: 0400 {{- if not .Values.server.disableTls }} - name: tls-cert secret: secretName: {{ .Values.server.tls.certSecretName }} + {{- if or .Values.server.tls.clientCaSecretName .Values.pkiInitJob.enabled (and .Values.certManager.enabled .Values.certManager.clientCaFromServerTlsSecret) }} - name: tls-client-ca secret: {{- if or .Values.pkiInitJob.enabled (and .Values.certManager.enabled .Values.certManager.clientCaFromServerTlsSecret) }} @@ -200,6 +154,12 @@ spec: secretName: {{ .Values.server.tls.clientCaSecretName }} {{- end }} {{- end }} + {{- end }} + {{- if and .Values.server.oidc.issuer .Values.server.oidc.caConfigMapName }} + - name: oidc-ca + configMap: + name: {{ .Values.server.oidc.caConfigMapName }} + {{- end }} {{- with .Values.nodeSelector }} nodeSelector: {{- toYaml . | nindent 8 }} diff --git a/deploy/helm/openshell/tests/gateway_config_test.yaml b/deploy/helm/openshell/tests/gateway_config_test.yaml new file mode 100644 index 000000000..f17203c6f --- /dev/null +++ b/deploy/helm/openshell/tests/gateway_config_test.yaml @@ -0,0 +1,127 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +suite: gateway TOML config shape +templates: + - templates/gateway-config.yaml + - templates/statefulset.yaml +release: + name: openshell + namespace: my-namespace + +tests: + # Regression for Drew's P2: a ConfigMap-only mutation in `helm upgrade` + # must roll the StatefulSet, otherwise pods keep running with stale config. 
+ - it: annotates the StatefulSet pod template with a ConfigMap checksum + template: templates/statefulset.yaml + asserts: + - exists: + path: spec.template.metadata.annotations["checksum/gateway-config"] + + - it: mounts the OIDC CA bundle when TLS is disabled + template: templates/statefulset.yaml + set: + server.disableTls: true + server.oidc.issuer: https://issuer.example.com + server.oidc.caConfigMapName: openshell-oidc-ca + asserts: + - equal: + path: spec.template.spec.containers[0].volumeMounts[3].name + value: oidc-ca + - equal: + path: spec.template.spec.containers[0].volumeMounts[3].mountPath + value: /etc/openshell-tls/oidc-ca + - equal: + path: spec.template.spec.volumes[2].name + value: oidc-ca + - equal: + path: spec.template.spec.volumes[2].configMap.name + value: openshell-oidc-ca + + # Regression for the P1 bug Drew flagged: grpc_endpoint MUST live in the + # Kubernetes driver table, not in [openshell.gateway]. The gateway-side + # schema has `deny_unknown_fields` and no `grpc_endpoint` field, so writing + # it at gateway scope makes `config_file::load` reject the default install. 
+ - it: renders grpc_endpoint under [openshell.drivers.kubernetes], not [openshell.gateway] + template: templates/gateway-config.yaml + asserts: + - matchRegex: + path: data["gateway.toml"] + pattern: '(?ms)\[openshell\.drivers\.kubernetes\].*?grpc_endpoint' + - notMatchRegex: + path: data["gateway.toml"] + pattern: '(?ms)\[openshell\.gateway\][^\[]*?grpc_endpoint' + + - it: renders the sandbox service account name under [openshell.drivers.kubernetes] + template: templates/gateway-config.yaml + asserts: + - matchRegex: + path: data["gateway.toml"] + pattern: '(?ms)\[openshell\.drivers\.kubernetes\].*?service_account_name\s*=\s*"openshell-sandbox"' + + - it: does not render local mTLS user auth for Kubernetes deployments + template: templates/gateway-config.yaml + asserts: + - notMatchRegex: + path: data["gateway.toml"] + pattern: '\[openshell\.gateway\.mtls_auth\]' + + - it: does not allow unauthenticated users by default + template: templates/gateway-config.yaml + asserts: + - notMatchRegex: + path: data["gateway.toml"] + pattern: '\[openshell\.gateway\.auth\]' + + - it: renders explicit unauthenticated user dev mode when enabled + template: templates/gateway-config.yaml + set: + server.auth.allowUnauthenticatedUsers: true + asserts: + - matchRegex: + path: data["gateway.toml"] + pattern: '(?ms)\[openshell\.gateway\.auth\].*?allow_unauthenticated_users\s*=\s*true' + + - it: uses the configured existing sandbox service account name + template: templates/gateway-config.yaml + set: + sandboxServiceAccount.create: false + sandboxServiceAccount.name: precreated-sandbox + asserts: + - matchRegex: + path: data["gateway.toml"] + pattern: '(?ms)\[openshell\.drivers\.kubernetes\].*?service_account_name\s*=\s*"precreated-sandbox"' + + - it: omits server_sans when no DNS SANs are configured + template: templates/gateway-config.yaml + asserts: + - notMatchRegex: + path: data["gateway.toml"] + pattern: 'server_sans\s*=' + + - it: emits disable_tls=true and omits the 
[openshell.gateway.tls] section when disableTls is set + set: + server.disableTls: true + certManager.enabled: false + pkiInitJob.enabled: false + template: templates/gateway-config.yaml + asserts: + - matchRegex: + path: data["gateway.toml"] + pattern: 'disable_tls\s*=\s*true' + - notMatchRegex: + path: data["gateway.toml"] + pattern: '\[openshell\.gateway\.tls\]' + + - it: renders server_sans from certManager.serverDnsNames + set: + certManager.enabled: true + certManager.serverDnsNames: + - openshell + - "*.dev.openshell.localhost" + pkiInitJob.enabled: false + template: templates/gateway-config.yaml + asserts: + - matchRegex: + path: data["gateway.toml"] + pattern: 'server_sans\s*=\s*\["openshell", "\*\.dev\.openshell\.localhost"\]' diff --git a/deploy/helm/openshell/tests/sandbox_namespace_test.yaml b/deploy/helm/openshell/tests/sandbox_namespace_test.yaml new file mode 100644 index 000000000..ee89fce53 --- /dev/null +++ b/deploy/helm/openshell/tests/sandbox_namespace_test.yaml @@ -0,0 +1,77 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +suite: sandboxNamespace defaulting +templates: + - templates/gateway-config.yaml + - templates/networkpolicy.yaml + - templates/role.yaml + - templates/rolebinding.yaml + - templates/serviceaccount.yaml +release: + name: openshell + namespace: my-namespace + +tests: + - it: defaults sandbox_namespace to release namespace in the TOML config + template: templates/gateway-config.yaml + asserts: + - matchRegex: + path: data["gateway.toml"] + pattern: 'sandbox_namespace\s*=\s*"my-namespace"' + + - it: uses explicit sandboxNamespace when set + template: templates/gateway-config.yaml + set: + server.sandboxNamespace: other-ns + asserts: + - matchRegex: + path: data["gateway.toml"] + pattern: 'sandbox_namespace\s*=\s*"other-ns"' + + - it: defaults NetworkPolicy namespace to release namespace + template: templates/networkpolicy.yaml + set: + networkPolicy.enabled: true + asserts: + - equal: + path: metadata.namespace + value: my-namespace + + - it: uses explicit sandboxNamespace for NetworkPolicy + template: templates/networkpolicy.yaml + set: + networkPolicy.enabled: true + server.sandboxNamespace: other-ns + asserts: + - equal: + path: metadata.namespace + value: other-ns + + - it: uses explicit sandboxNamespace for sandbox RBAC + template: templates/role.yaml + set: + server.sandboxNamespace: other-ns + asserts: + - equal: + path: metadata.namespace + value: other-ns + + - it: uses explicit sandboxNamespace for sandbox RoleBinding + template: templates/rolebinding.yaml + set: + server.sandboxNamespace: other-ns + asserts: + - equal: + path: metadata.namespace + value: other-ns + + - it: uses explicit sandboxNamespace for sandbox ServiceAccount + template: templates/serviceaccount.yaml + set: + server.sandboxNamespace: other-ns + asserts: + - equal: + path: metadata.namespace + value: other-ns + documentIndex: 1 diff --git a/deploy/helm/openshell/tests/sandbox_service_account_test.yaml 
b/deploy/helm/openshell/tests/sandbox_service_account_test.yaml new file mode 100644 index 000000000..c42641582 --- /dev/null +++ b/deploy/helm/openshell/tests/sandbox_service_account_test.yaml @@ -0,0 +1,31 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +suite: sandbox service account +templates: + - templates/serviceaccount.yaml +release: + name: openshell + namespace: my-namespace + +tests: + - it: creates gateway and sandbox service accounts by default + asserts: + - hasDocuments: + count: 2 + - equal: + path: metadata.name + value: openshell + documentIndex: 0 + - equal: + path: metadata.name + value: openshell-sandbox + documentIndex: 1 + + - it: uses the configured existing sandbox service account name + set: + sandboxServiceAccount.create: false + sandboxServiceAccount.name: precreated-sandbox + asserts: + - hasDocuments: + count: 1 diff --git a/deploy/helm/openshell/values.yaml b/deploy/helm/openshell/values.yaml index f8e090721..2d707168c 100644 --- a/deploy/helm/openshell/values.yaml +++ b/deploy/helm/openshell/values.yaml @@ -3,228 +3,315 @@ # Default values for OpenShell +# -- Number of OpenShell gateway replicas. replicaCount: 1 image: + # -- Gateway image repository. repository: ghcr.io/nvidia/openshell/gateway + # -- Gateway image pull policy. pullPolicy: IfNotPresent + # -- Gateway image tag. Defaults to the chart appVersion when empty. tag: "" -# Supervisor image — provides the openshell-sandbox binary that is copied into -# sandbox pods via an init container. tag defaults to appVersion (same as the -# gateway image) so both stay in sync when the chart is released. +# Supervisor image - provides the openshell-sandbox binary injected into sandbox +# pods. tag defaults to appVersion (same as the gateway image) so both stay in +# sync when the chart is released. supervisor: image: + # -- Supervisor image repository. 
repository: ghcr.io/nvidia/openshell/supervisor + # -- Supervisor image pull policy. Defaults to the gateway image pull policy when empty. pullPolicy: "" + # -- Supervisor image tag. Defaults to the chart appVersion when empty. tag: "" + # -- How the supervisor binary is delivered into sandbox pods. + # Empty (default) = auto-detect from cluster version: + # K8s >= v1.35 -> "image-volume" (ImageVolume enabled by default; GA in v1.36) + # K8s < v1.35 -> "init-container" (copies via init container + emptyDir) + # On K8s v1.33-v1.34 with the ImageVolume feature gate manually enabled, + # set this to "image-volume" explicitly. + sideloadMethod: "" +# -- Image pull secrets attached to gateway and helper pods. imagePullSecrets: [] +# -- Override the chart name used in generated resource names. nameOverride: "openshell" +# -- Override the full generated resource name. fullnameOverride: "" serviceAccount: + # -- Create a service account for the gateway. create: true + # -- Annotations to add to the generated service account. annotations: {} + # -- Existing service account name to use when serviceAccount.create is false. name: "" +sandboxServiceAccount: + # -- Create a service account for sandbox pods. + create: true + # -- Annotations to add to the generated sandbox service account. + annotations: {} + # -- Existing service account name for sandbox pods when sandboxServiceAccount.create is false. + name: "" + +# -- Extra annotations to add to the gateway pod. podAnnotations: {} +# -- Extra labels to add to the gateway pod. podLabels: {} podSecurityContext: + # -- fsGroup assigned to the gateway pod. fsGroup: 1000 securityContext: + # -- Require the gateway container to run as a non-root user. runAsNonRoot: true + # -- UID assigned to the gateway container. runAsUser: 1000 + # -- Whether the gateway container can gain additional privileges. allowPrivilegeEscalation: false capabilities: + # -- Linux capabilities dropped from the gateway container. 
drop: - ALL service: + # -- Kubernetes Service type for the gateway. type: ClusterIP + # -- Gateway gRPC/HTTP service port. port: 8080 + # -- Gateway health service port. healthPort: 8081 + # -- Gateway metrics service port. metricsPort: 9090 # Pod restart behavior and health probe tuning. podLifecycle: + # -- Grace period, in seconds, before Kubernetes terminates the gateway pod. terminationGracePeriodSeconds: 5 probes: startup: + # -- Startup probe period, in seconds. periodSeconds: 2 + # -- Startup probe timeout, in seconds. timeoutSeconds: 1 + # -- Startup probe failure threshold before the container is killed. failureThreshold: 30 liveness: + # -- Liveness probe initial delay, in seconds. initialDelaySeconds: 2 + # -- Liveness probe period, in seconds. periodSeconds: 5 + # -- Liveness probe timeout, in seconds. timeoutSeconds: 1 + # -- Liveness probe failure threshold before the container is restarted. failureThreshold: 3 readiness: + # -- Readiness probe initial delay, in seconds. initialDelaySeconds: 1 + # -- Readiness probe period, in seconds. periodSeconds: 2 + # -- Readiness probe timeout, in seconds. timeoutSeconds: 1 + # -- Readiness probe failure threshold before the pod is marked not ready. failureThreshold: 3 +# -- Gateway pod resource requests and limits. resources: {} +# -- Node selector for the gateway pod. nodeSelector: {} +# -- Tolerations for the gateway pod. tolerations: [] +# -- Affinity rules for the gateway pod. affinity: {} # Server configuration server: + # -- Gateway log level. logLevel: info - sandboxNamespace: openshell + # -- Namespace where sandbox pods are created. Defaults to the Helm release + # namespace (.Release.Namespace) when left empty. + sandboxNamespace: "" + # -- Gateway database URL. dbUrl: "sqlite:/var/openshell/openshell.db" + # -- Default sandbox image used when requests do not specify one. sandboxImage: "ghcr.io/nvidia/openshell-community/sandboxes/base:latest" - # Kubernetes imagePullPolicy for sandbox pods. 
Empty = Kubernetes default - # (Always for :latest, IfNotPresent otherwise). Set to "Always" for dev + # -- Kubernetes imagePullPolicy for sandbox pods. Empty = Kubernetes default + # (Always for :latest, IfNotPresent otherwise). Set to "Always" for dev # clusters so new images are picked up without manual eviction. sandboxImagePullPolicy: "" - # gRPC endpoint for sandboxes to callback to OpenShell (must be reachable from pods) - grpcEndpoint: "https://openshell.openshell.svc.cluster.local:8080" - # Public host/port returned to CLI clients for SSH proxy CONNECT requests. - # For local clusters the default 127.0.0.1:8080 is correct; for remote - # clusters these should be set to the externally reachable host and port. - sshGatewayHost: "" - sshGatewayPort: 0 + # -- Default storage size for the workspace PVC in sandbox pods. + # Uses Kubernetes quantity syntax (e.g. "2Gi", "10Gi", "500Mi"). + # Empty = built-in default (2Gi). + workspaceDefaultStorageSize: "" + # -- gRPC endpoint sandboxes call back into the gateway. Leave empty to derive + # it from the chart fullname, release namespace, service port, and + # disableTls flag, for example https://openshell.openshell.svc.cluster.local:8080. + # Override only when sandboxes must reach the gateway via a different + # hostname (e.g. an external ingress or a host alias). + grpcEndpoint: "" # TLS configuration for the server. The server always terminates mTLS # directly and requires client certificates. - # Name of the Kubernetes Secret holding the NSSH1 HMAC handshake key. - # The secret must contain a `secret` key with the hex-encoded HMAC key. - # By default a pre-install/pre-upgrade hook creates it when missing (see sshHandshake). - sshHandshakeSecretName: "openshell-ssh-handshake" - # Host gateway IP for sandbox pod hostAliases. When set, sandbox pods get + # -- Host gateway IP for sandbox pod hostAliases. 
When set, sandbox pods get # hostAliases entries mapping host.docker.internal and host.openshell.internal # to this IP, allowing them to reach services running on the Docker host. # Auto-detected by the cluster entrypoint script. hostGatewayIP: "" - # Disable gateway authentication (mTLS client certificate requirement). - # Set to true when the gateway sits behind a reverse proxy (e.g. - # Cloudflare Tunnel) that terminates TLS. - disableGatewayAuth: false - # Disable TLS entirely — the server listens on plaintext HTTP. + # -- Enable Kubernetes user namespace isolation (hostUsers: false) for sandbox + # pods. Requires Kubernetes 1.33+ with user namespace support available + # (beta through 1.35, GA in 1.36+), plus a supporting container runtime and + # Linux 5.12+. When enabled, container UID 0 maps to an unprivileged host + # UID and capabilities become namespaced. + enableUserNamespaces: false + # -- Disable TLS entirely - the server listens on plaintext HTTP. # Set to true when a reverse proxy / tunnel terminates TLS at the edge. disableTls: false + # -- Enable plaintext HTTP routing for loopback sandbox service URLs on + # TLS-enabled gateways. + enableLoopbackServiceHttp: true + auth: + # -- UNSAFE: accept unauthenticated CLI/user requests as a local developer + # principal. Intended only for trusted local Skaffold/k3d development or a + # fully trusted fronting proxy. Leave false for shared or production clusters. + allowUnauthenticatedUsers: false tls: - # K8s secret (type kubernetes.io/tls) with tls.crt and tls.key for the server + # -- K8s secret (type kubernetes.io/tls) with tls.crt and tls.key for the server. certSecretName: openshell-server-tls - # K8s secret with ca.crt for client certificate verification + # -- K8s secret with ca.crt for client certificate verification (mTLS). + # Set to "" to disable mTLS and run HTTPS-only (use OIDC for auth instead). 
clientCaSecretName: openshell-server-client-ca - # K8s secret mounted into sandbox pods for mTLS to the server + # -- K8s secret mounted into sandbox pods for mTLS to the server. clientTlsSecretName: openshell-client-tls + # Gateway-minted sandbox JWT signing keys. The pre-install certgen hook + # generates an Ed25519 keypair and writes it to a secret containing + # signing.pem (PKCS#8), public.pem (SPKI), and kid (plain text). + sandboxJwt: + # -- Name of the Opaque Secret holding the signing key material. Empty + # falls back to the chart fullname with "-jwt-keys" appended. + signingSecretName: "" + # -- Stable gateway identity embedded in iss/aud of every minted token. + # Defaults to the release name so HA replicas share identity. + gatewayId: "" + # -- Token TTL in seconds. Defaults to 3600 (1h). + ttlSecs: 3600 + # -- Lifetime (seconds) of the projected ServiceAccount token kubelet + # writes into each sandbox pod for the IssueSandboxToken bootstrap + # exchange. Kubelet enforces a minimum of 600s; the driver clamps + # values outside [600, 86400]. Default 3600 — generous, since the + # supervisor consumes the token within seconds of pod start. + k8sSaTokenTtlSecs: 3600 # OIDC (OpenID Connect) configuration for JWT-based authentication. # When issuer is set, the server validates Bearer tokens on gRPC requests. oidc: - # OIDC issuer URL (e.g. https://keycloak.example.com/realms/openshell). + # -- OIDC issuer URL (e.g. https://keycloak.example.com/realms/openshell). issuer: "" - # Expected audience claim for the API resource server. + # -- Expected audience claim for the API resource server. # This should match the server's --oidc-audience, NOT the CLI client ID. audience: "openshell-cli" - # JWKS key cache TTL in seconds. + # -- JWKS key cache TTL in seconds. jwksTtl: 3600 - # Dot-separated path to the roles array in the JWT claims. + # -- Dot-separated path to the roles array in the JWT claims. 
# Keycloak: "realm_access.roles", Entra ID: "roles", Okta: "groups". rolesClaim: "" - # Role name for admin access. Leave empty (with userRole also empty) for + # -- Role name for admin access. Leave empty (with userRole also empty) for # authentication-only mode. Both must be set or both empty. adminRole: "" - # Role name for standard user access. + # -- Role name for standard user access. userRole: "" - # Dot-separated path to the scopes array in the JWT claims. + # -- Dot-separated path to the scopes array in the JWT claims. scopesClaim: "" + # -- Name of a ConfigMap containing a CA certificate bundle (key: ca.crt) + # for verifying the OIDC issuer's TLS certificate. Required when the + # issuer uses a non-public CA (e.g. OpenShift ingress, private PKI). + caConfigMapName: "" # NetworkPolicy restricting SSH ingress on sandbox pods to the gateway only. networkPolicy: + # -- Create a NetworkPolicy restricting SSH ingress on sandbox pods to the gateway. enabled: true -# NSSH1 SSH gateway handshake Secret (`server.sshHandshakeSecretName`). -# Helm hook creates it only when the Secret does not already exist (safe upgrades). -# Set sshHandshake.value from a gitignored values file for a stable dev secret. -sshHandshake: - hook: - enabled: true - # 64 hex chars (32 bytes), matching openshell-bootstrap. If empty, Helm generates - # a random value at install template time (two UUIDs, dashes stripped). - value: "" - # PKI bootstrap via a pre-install/pre-upgrade hook Job. -# Generates a self-signed CA, server TLS secret, and client TLS secret using -# openssl (ECDSA P-256) inside the cluster. Key material is written directly to -# K8s Secrets and never appears in Helm release history. Idempotent: existing -# secrets are left untouched on upgrade. -# Air-gapped environments should override pkiInitJob.image with an image that has -# openssl and curl pre-installed (the default alpine image fetches them at runtime). 
+# Runs `openshell-gateway generate-certs` to create the server and client TLS +# Secrets in-cluster. Key material is written directly to K8s Secrets and +# never appears in Helm release history. Idempotent: existing secrets are +# left untouched on upgrade. Reuses the gateway image - no extra image to +# mirror in air-gapped environments. +# +# The server certificate already includes the built-in cluster SANs +# (`openshell`, `openshell.openshell.svc`, the cluster.local FQDN, `localhost`, +# `openshell.localhost`, `*.openshell.localhost`, `host.docker.internal`, and +# `127.0.0.1`) baked into the gateway binary. The lists below are additional +# SANs appended on top. Wildcard DNS SANs also enable sandbox service URLs under +# that domain, for example `*.apps.example.com` enables +# `--.apps.example.com`. pkiInitJob: + # -- Run a pre-install/pre-upgrade Job that creates gateway and client mTLS Secrets. enabled: true - image: - repository: alpine - tag: "3" - pullPolicy: IfNotPresent - # Days until the CA certificate expires. - caValidityDays: 3650 - # Days until server and client certificates expire. - certValidityDays: 3650 - # DNS SANs for the server certificate. - serverDnsNames: - - openshell - - openshell.openshell.svc - - openshell.openshell.svc.cluster.local - - localhost - - host.docker.internal - # IP SANs for the server certificate. - serverIpAddresses: - - 127.0.0.1 + # -- Extra DNS SANs to append to the server certificate. + serverDnsNames: [] + # -- Extra IP SANs to append to the server certificate. + serverIpAddresses: [] # cert-manager Certificate/Issuer resources (requires cert-manager CRDs in-cluster). # Uses namespaced Issuers only (no ClusterIssuer). Does not install cert-manager itself. certManager: + # -- Create cert-manager Issuer and Certificate resources instead of using the PKI bootstrap Job. enabled: false - # Secret created for the intermediate CA (Certificate with isCA: true). 
+ # -- Secret created for the intermediate CA (Certificate with isCA: true). caSecretName: openshell-ca-tls - # Mount gateway client CA from the server TLS secret's ca.crt (populated by + # -- Mount gateway client CA from the server TLS secret's ca.crt (populated by # cert-manager for certs issued by a CA Issuer). Avoids a separate # openshell-server-client-ca Secret. clientCaFromServerTlsSecret: true + # -- Duration for cert-manager-issued certificates. certificateDuration: 8760h + # -- Renewal window for cert-manager-issued certificates. certificateRenewBefore: 720h + # -- DNS SANs on the cert-manager-issued server certificate. serverDnsNames: - openshell - openshell.openshell.svc - openshell.openshell.svc.cluster.local - localhost + - openshell.localhost + - "*.openshell.localhost" - host.docker.internal + # -- IP SANs on the cert-manager-issued server certificate. serverIpAddresses: - 127.0.0.1 -# Kubernetes Gateway API — HTTPRoute and Gateway resources. +# Kubernetes Gateway API - HTTPRoute and Gateway resources. # Requires a Gateway API controller in the cluster. Install Envoy Gateway via # the skaffold.yaml releases or independently: # helm install eg oci://docker.io/envoyproxy/gateway-helm \ # --version v1.4.1 -n envoy-gateway-system --create-namespace grpcRoute: + # -- Create a Gateway API GRPCRoute for the gateway service. enabled: false - # Hostnames the GRPCRoute matches on. Leave empty to match all hosts. + # -- Hostnames the GRPCRoute matches on. Leave empty to match all hosts. hostnames: [] gateway: - # When true, a Gateway resource is created in the release namespace. + # -- When true, a Gateway resource is created in the release namespace. # Set to false and provide name/namespace to attach to a pre-existing Gateway. create: false - # GatewayClass to reference. Envoy Gateway installs one named "eg". + # -- GatewayClass to reference. Envoy Gateway installs one named "eg". className: "eg" - # Name of the Gateway resource. 
Defaults to the chart fullname. + # -- Name of the Gateway resource. Defaults to the chart fullname. name: "" - # Namespace of the Gateway referenced by the GRPCRoute parentRef. + # -- Namespace of the Gateway referenced by the GRPCRoute parentRef. # Defaults to the release namespace. namespace: "" # Listener settings (only used when gateway.create is true). listener: + # -- Listener port for the generated Gateway resource. port: 80 + # -- Listener protocol for the generated Gateway resource. protocol: HTTP - # "Same" restricts attached routes to the release namespace; "All" allows any namespace. + # -- "Same" restricts attached routes to the release namespace; "All" allows any namespace. allowedRoutes: Same diff --git a/deploy/kube/manifests/openshell-helmchart.yaml b/deploy/kube/manifests/openshell-helmchart.yaml index eba79364c..3ca6e3b90 100644 --- a/deploy/kube/manifests/openshell-helmchart.yaml +++ b/deploy/kube/manifests/openshell-helmchart.yaml @@ -33,11 +33,7 @@ spec: sandboxImagePullPolicy: __SANDBOX_IMAGE_PULL_POLICY__ supervisorImage: ghcr.io/nvidia/openshell/supervisor:latest dbUrl: __DB_URL__ - sshGatewayHost: __SSH_GATEWAY_HOST__ - sshGatewayPort: __SSH_GATEWAY_PORT__ - grpcEndpoint: "https://openshell.openshell.svc.cluster.local:8080" hostGatewayIP: __HOST_GATEWAY_IP__ - disableGatewayAuth: __DISABLE_GATEWAY_AUTH__ disableTls: __DISABLE_TLS__ oidc: issuer: "__OIDC_ISSUER__" diff --git a/deploy/man/openshell-gateway.8.md b/deploy/man/openshell-gateway.8.md index 5e3c4eef2..82025bc18 100644 --- a/deploy/man/openshell-gateway.8.md +++ b/deploy/man/openshell-gateway.8.md @@ -22,12 +22,13 @@ network and filesystem policies to sandboxes, routes inference requests, and provides the SSH tunnel endpoint for CLI-to-sandbox connections. -When installed via RPM, the gateway runs as a systemd user service -with the Podman compute driver. Sandboxes are rootless Podman -containers on the host. 
+When installed via a Linux package, the gateway runs as a systemd user +service. The packaged service starts from built-in defaults and reads +the default gateway TOML path only when that file exists. -The gateway exposes a single port (default 8080) with multiplexed -gRPC and HTTP, secured by mutual TLS (mTLS) by default. +The gateway exposes a single port with multiplexed gRPC and HTTP, +secured by mutual TLS (mTLS) by default unless the TOML config disables +TLS. # OPTIONS @@ -36,7 +37,7 @@ gRPC and HTTP, secured by mutual TLS (mTLS) by default. Environment: **OPENSHELL_BIND_ADDRESS**. **--port** *PORT* -: Port for the gRPC/HTTP API. Default: **8080**. +: Port for the gRPC/HTTP API. Default: **17670**. Environment: **OPENSHELL_SERVER_PORT**. **--health-port** *PORT* @@ -53,28 +54,44 @@ gRPC and HTTP, secured by mutual TLS (mTLS) by default. Environment: **OPENSHELL_LOG_LEVEL**. **--db-url** *URL* -: SQLite database URL for state persistence. Required. +: SQLite database URL for state persistence. When unset, the gateway + stores SQLite state under *~/.local/state/openshell/gateway/*. Environment: **OPENSHELL_DB_URL**. **--drivers** *DRIVER*\[,*DRIVER*\] : Compute driver. Accepts a comma-delimited list. The gateway currently requires exactly one driver. Options: **podman**, - **docker**, **kubernetes**. Default: **kubernetes**. + **docker**, **kubernetes**, **vm**. When unset, the gateway + auto-detects Kubernetes, then Podman, then Docker. VM is opt-in. Environment: **OPENSHELL_DRIVERS**. **--tls-cert** *PATH* -: Path to server TLS certificate file. Required unless - **--disable-tls** is set. Environment: **OPENSHELL_TLS_CERT**. +: Path to server TLS certificate file. Defaults to the local generated + TLS bundle when present. Required unless **--disable-tls** is set. + Environment: **OPENSHELL_TLS_CERT**. **--tls-key** *PATH* -: Path to server TLS private key file. Required unless - **--disable-tls** is set. Environment: **OPENSHELL_TLS_KEY**. 
+: Path to server TLS private key file. Defaults to the local generated + TLS bundle when present. Required unless **--disable-tls** is set. + Environment: **OPENSHELL_TLS_KEY**. **--tls-client-ca** *PATH* : Path to CA certificate for client certificate verification (mTLS). - Required unless **--disable-tls** is set. + When set without **--oidc-issuer**, client certificates are required + and the TLS handshake rejects unauthenticated connections. When set + together with **--oidc-issuer**, client certificates are accepted + but not required. Client certificates can authenticate local + single-user CLI callers when mTLS auth is enabled; sandbox + supervisors still authenticate with gateway-minted bearer tokens. Environment: **OPENSHELL_TLS_CLIENT_CA**. +**--enable-mtls-auth** *BOOL* +: Enable mTLS client certificate authentication for local single-user + Docker, Podman, and VM gateways. Defaults on for local gateways with + client certificate verification and no OIDC issuer. Not supported with + the Kubernetes compute driver. + Environment: **OPENSHELL_ENABLE_MTLS_AUTH**. + **--disable-tls** : Disable TLS entirely and listen on plaintext HTTP. When the bind address is **0.0.0.0** (the RPM default), disabling TLS exposes the @@ -83,44 +100,20 @@ gRPC and HTTP, secured by mutual TLS (mTLS) by default. **--bind-address** to **127.0.0.1**. Environment: **OPENSHELL_DISABLE_TLS**. -**--disable-gateway-auth** -: Disable mTLS client certificate requirement. The TLS handshake - accepts connections without a client certificate. Ignored when - **--disable-tls** is set. - Environment: **OPENSHELL_DISABLE_GATEWAY_AUTH**. - -**--sandbox-image** *IMAGE* -: Default container image for sandboxes. - Environment: **OPENSHELL_SANDBOX_IMAGE**. - -**--sandbox-image-pull-policy** *POLICY* -: Image pull policy: Always, IfNotPresent, Never. - Environment: **OPENSHELL_SANDBOX_IMAGE_PULL_POLICY**. - -**--ssh-handshake-secret** *SECRET* -: Shared secret for gateway-to-sandbox SSH handshake. 
- Environment: **OPENSHELL_SSH_HANDSHAKE_SECRET**. - -**--ssh-handshake-skew-secs** *SECONDS* -: Allowed clock skew in seconds for SSH handshake. Default: **30**. - Environment: **OPENSHELL_SSH_HANDSHAKE_SKEW_SECS**. - -**--ssh-gateway-host** *HOST* -: Public host for the SSH gateway endpoint. Default: **127.0.0.1**. - Environment: **OPENSHELL_SSH_GATEWAY_HOST**. - -**--ssh-gateway-port** *PORT* -: Public port for the SSH gateway endpoint. Default: **8080**. - Environment: **OPENSHELL_SSH_GATEWAY_PORT**. +**--server-san** *SAN* +: Subject Alternative Name configured on the gateway server + certificate. Repeat or pass a comma-separated value through + **OPENSHELL_SERVER_SAN**. Wildcard DNS SANs also enable sandbox + service URLs under that domain. + Environment: **OPENSHELL_SERVER_SAN**. -**--grpc-endpoint** *URL* -: gRPC endpoint for sandbox callbacks. Should be reachable from - within sandbox containers. - Environment: **OPENSHELL_GRPC_ENDPOINT**. +Compute driver settings such as sandbox image, callback endpoint, image +pull policy, network name, VM state directory, and guest TLS material are +configured in the TOML file passed with **--config**. # SYSTEMD INTEGRATION -The RPM installs a systemd user unit at +The package installs a systemd user unit at */usr/lib/systemd/user/openshell-gateway.service*. Manage the gateway with standard systemd commands: @@ -134,14 +127,13 @@ View logs: journalctl --user -u openshell-gateway journalctl --user -u openshell-gateway -f -The unit runs two **ExecStartPre** scripts on first start: +The unit runs **openshell-gateway generate-certs** as an **ExecStartPre** +step on first start. This generates a self-signed PKI bundle for mTLS +and sandbox JWT signing material, adding missing JWT files to older +TLS-only installs when needed. -1. **init-pki.sh** generates a self-signed PKI bundle for mTLS. -2. **init-gateway-env.sh** generates the environment configuration - file with an auto-generated SSH handshake secret. 
- -Both scripts are idempotent and skip generation if their output files -already exist. +The gateway then starts from built-in defaults and reads +*~/.config/openshell/gateway.toml* when that file exists. To persist the service across logouts: @@ -149,11 +141,16 @@ To persist the service across logouts: # CONFIGURATION -The systemd user unit reads configuration from -*~/.config/openshell/gateway.env*. See **openshell-gateway.env**(5) -for the full variable reference. +The systemd user unit launches the gateway with: + + openshell-gateway -To override individual settings without modifying gateway.env: +Gateway listener, TLS, database, and compute driver settings have local +defaults. Create *~/.config/openshell/gateway.toml* when you need to +override them. The gateway rejects `database_url` in TOML; set +**OPENSHELL_DB_URL** when you need a different database. + +To override individual settings without creating TOML: systemctl --user edit openshell-gateway @@ -167,19 +164,13 @@ This creates a drop-in override that persists across package upgrades. */usr/lib/systemd/user/openshell-gateway.service* : Systemd user unit file. -*/usr/libexec/openshell/init-pki.sh* -: PKI bootstrap script. - -*/usr/libexec/openshell/init-gateway-env.sh* -: Gateway environment file generator. - -*~/.config/openshell/gateway.env* -: Gateway environment configuration (generated on first start). +*~/.config/openshell/gateway.toml* +: Optional gateway TOML configuration. *~/.local/state/openshell/tls/* -: Auto-generated TLS certificates. +: Auto-generated TLS certificates and sandbox JWT signing keys. -*~/.local/state/openshell/gateway.db* +*~/.local/state/openshell/gateway/openshell.db* : SQLite database for gateway state. 
*~/.config/openshell/gateways/openshell/mtls/* @@ -193,18 +184,17 @@ Start the gateway as a systemd user service: Check gateway health from the CLI: - openshell gateway add --local https://127.0.0.1:8080 + openshell gateway add --local https://127.0.0.1:17670 openshell status -Override the API port via a systemd drop-in: +Override the API port in TOML: - systemctl --user edit openshell-gateway - # Add: [Service] - # Add: Environment=OPENSHELL_SERVER_PORT=9090 + $EDITOR ~/.config/openshell/gateway.toml + systemctl --user restart openshell-gateway # SEE ALSO -**openshell**(1), **openshell-gateway.env**(5), **systemctl**(1), -**journalctl**(1), **loginctl**(1), **podman**(1) +**openshell**(1), **systemctl**(1), **journalctl**(1), **loginctl**(1), +**podman**(1) Full documentation: *https://docs.nvidia.com/openshell/* diff --git a/deploy/man/openshell-gateway.env.5.md b/deploy/man/openshell-gateway.env.5.md deleted file mode 100644 index a4e715edd..000000000 --- a/deploy/man/openshell-gateway.env.5.md +++ /dev/null @@ -1,168 +0,0 @@ ---- -title: OPENSHELL-GATEWAY.ENV -section: 5 -header: OpenShell Manual -footer: openshell-gateway -date: 2025 ---- - -# NAME - -openshell-gateway.env - OpenShell gateway environment configuration - -# DESCRIPTION - -The **openshell-gateway.env** file contains environment variables that -configure the OpenShell gateway server when running as a systemd user -service. It is generated automatically on first start by -**init-gateway-env.sh** and is not overwritten on subsequent starts or -package upgrades. - -The file uses the standard systemd **EnvironmentFile** format: one -**KEY=VALUE** pair per line. Lines beginning with **#** are comments. -Shell variable expansion is not performed. 
- -# LOCATION - -The file is located at: - - ~/.config/openshell/gateway.env - -The systemd user unit reads it via: - - EnvironmentFile=-~/.config/openshell/gateway.env - -The **-** prefix means the service starts normally if the file does not -exist (the unit has built-in defaults for all required settings except -the SSH handshake secret). - -# VARIABLES - -## Required - -**OPENSHELL_SSH_HANDSHAKE_SECRET** -: Shared HMAC secret for gateway-to-sandbox SSH handshake - authentication. Auto-generated as a 32-byte hex string on first - start. To regenerate: **openssl rand -hex 32**. - -## Gateway - -**OPENSHELL_BIND_ADDRESS** (default: 0.0.0.0) -: IP address to bind all listeners to. The RPM default of **0.0.0.0** - exposes the gateway on all network interfaces; mTLS must remain - enabled to prevent unauthenticated access. Set to **127.0.0.1** for - local-only access. - -**OPENSHELL_SERVER_PORT** (default: 8080) -: Port for the multiplexed gRPC/HTTP API. - -**OPENSHELL_HEALTH_PORT** (default: 0) -: Port for unauthenticated health endpoints (/healthz, /readyz). - Set to a non-zero value to enable a dedicated health listener. - -**OPENSHELL_METRICS_PORT** (default: 0) -: Port for Prometheus metrics endpoint (/metrics). Set to a - non-zero value to enable a dedicated metrics listener. - -**OPENSHELL_LOG_LEVEL** (default: info) -: Log verbosity: **trace**, **debug**, **info**, **warn**, **error**. - -**OPENSHELL_DRIVERS** (default: podman) -: Compute driver for sandbox management. Options: **podman**, - **docker**, **kubernetes**. The RPM unit defaults to **podman**. - -**OPENSHELL_DB_URL** (default: sqlite://$XDG_STATE_HOME/openshell/gateway.db) -: SQLite database URL for gateway state persistence. - -**OPENSHELL_DISABLE_GATEWAY_AUTH** (default: unset) -: Set to **true** to disable mTLS client certificate verification. - -## TLS - -**OPENSHELL_TLS_CERT** (default: auto-generated path) -: Path to server TLS certificate. 
- -**OPENSHELL_TLS_KEY** (default: auto-generated path) -: Path to server TLS private key. - -**OPENSHELL_TLS_CLIENT_CA** (default: auto-generated path) -: Path to CA certificate for client certificate verification. - -**OPENSHELL_DISABLE_TLS** (default: unset) -: Set to **true** to disable TLS entirely and listen on plaintext - HTTP. Not recommended for production. When the bind address is - **0.0.0.0** (the RPM default), disabling TLS exposes the API to the - entire network without authentication. Restrict - **OPENSHELL_BIND_ADDRESS** to **127.0.0.1** or place the gateway - behind a TLS-terminating reverse proxy. - -**OPENSHELL_PODMAN_TLS_CA** (default: auto-generated path) -: CA certificate bind-mounted into sandbox containers. - -**OPENSHELL_PODMAN_TLS_CERT** (default: auto-generated path) -: Client certificate bind-mounted into sandbox containers. - -**OPENSHELL_PODMAN_TLS_KEY** (default: auto-generated path) -: Client private key bind-mounted into sandbox containers. - -## Images - -**OPENSHELL_SUPERVISOR_IMAGE** (default: ghcr.io/nvidia/openshell/supervisor:latest) -: OCI image containing the supervisor binary, mounted read-only - into sandbox containers. - -**OPENSHELL_SANDBOX_IMAGE** (default: ghcr.io/nvidia/openshell-community/sandboxes/base:latest) -: Default OCI image for sandbox containers. - -**OPENSHELL_SANDBOX_IMAGE_PULL_POLICY** (default: missing) -: When to pull sandbox images: **always** (every sandbox creation), - **missing** (only if not cached locally), **never** (use cached - only), **newer** (pull if a newer version exists). - -## Podman Driver - -**OPENSHELL_PODMAN_SOCKET** (default: $XDG_RUNTIME_DIR/podman/podman.sock) -: Path to the Podman API Unix socket. - -**OPENSHELL_NETWORK_NAME** (default: openshell) -: Name of the Podman bridge network for sandbox containers. Created - automatically if it does not exist. 
- -**OPENSHELL_STOP_TIMEOUT** (default: 10) -: Seconds to wait after SIGTERM before sending SIGKILL when stopping - a sandbox container. - -# EXAMPLES - -Change the API port to 9090: - - OPENSHELL_SERVER_PORT=9090 - -Pin sandbox images to a specific version: - - OPENSHELL_SUPERVISOR_IMAGE=ghcr.io/nvidia/openshell/supervisor:v0.0.37 - OPENSHELL_SANDBOX_IMAGE=ghcr.io/nvidia/openshell-community/sandboxes/base:v0.0.37 - -Air-gapped deployment (pre-loaded images, no registry access): - - OPENSHELL_SANDBOX_IMAGE_PULL_POLICY=never - -Enable debug logging: - - OPENSHELL_LOG_LEVEL=debug - -Use externally-managed TLS certificates: - - OPENSHELL_TLS_CERT=/etc/pki/tls/certs/openshell.crt - OPENSHELL_TLS_KEY=/etc/pki/tls/private/openshell.key - OPENSHELL_TLS_CLIENT_CA=/etc/pki/tls/certs/openshell-ca.crt - -Disable TLS (behind a reverse proxy): - - OPENSHELL_DISABLE_TLS=true - -# SEE ALSO - -**openshell-gateway**(8), **openshell**(1), **systemd.exec**(5) - -Full documentation: *https://docs.nvidia.com/openshell/* diff --git a/deploy/man/openshell.1.md b/deploy/man/openshell.1.md index 98a683ec5..6ba6f4afb 100644 --- a/deploy/man/openshell.1.md +++ b/deploy/man/openshell.1.md @@ -190,7 +190,7 @@ development task, or behind a cloud reverse proxy. Register the local RPM gateway and create a sandbox: - openshell gateway add --local https://127.0.0.1:8080 + openshell gateway add --local https://127.0.0.1:17670 openshell sandbox create -- claude List sandboxes and connect to one: @@ -208,7 +208,7 @@ Check gateway health: # SEE ALSO -**openshell-gateway**(8), **openshell-gateway.env**(5) +**openshell-gateway**(8) Full documentation: *https://docs.nvidia.com/openshell/* diff --git a/deploy/rpm/CONFIGURATION.md b/deploy/rpm/CONFIGURATION.md index 04caaaae5..f48b7d158 100644 --- a/deploy/rpm/CONFIGURATION.md +++ b/deploy/rpm/CONFIGURATION.md @@ -6,16 +6,72 @@ the RPM package on Fedora and RHEL systems. For first-time setup, see QUICKSTART.md. For troubleshooting, see TROUBLESHOOTING.md. 
+## Default configuration + +The RPM ships a default TOML configuration template at +`/usr/share/openshell-gateway/gateway.toml.default`. On first start of +`openshell-gateway.service`, the systemd unit copies this template to +`~/.config/openshell/gateway.toml` if no config file exists there yet. + +The defaults are tuned for rootless Podman use: + +```toml +[openshell] +version = 1 + +[openshell.gateway] +bind_address = "0.0.0.0:17670" +compute_drivers = ["podman"] +``` + +`bind_address = "0.0.0.0:17670"` is required because Podman sandbox +containers reach the gateway over the host network bridge and cannot +connect to `127.0.0.1` inside the gateway's network namespace. mTLS is +enabled by default and protects all connections. + +`compute_drivers = ["podman"]` pins the compute driver to Podman. Without +this, the gateway auto-detects in order: Kubernetes, Podman, Docker. Pinning +prevents unexpected driver selection if Docker is also installed on the host. + +### Customizing the configuration + +Edit `~/.config/openshell/gateway.toml` directly. The template at +`/usr/share/openshell-gateway/gateway.toml.default` is not read at runtime +and is not overwritten by RPM upgrades. + +To apply environment variable overrides that persist across upgrades without +editing the TOML file, add them to `~/.config/openshell/gateway.env`: + +```shell +# Example: restrict to loopback only +OPENSHELL_BIND_ADDRESS=127.0.0.1 +``` + +To override the path to the TOML config file entirely: + +```shell +# In ~/.config/openshell/gateway.env +OPENSHELL_GATEWAY_CONFIG=/path/to/custom/gateway.toml +``` + +For one-off service overrides that persist across package upgrades: + +```shell +systemctl --user edit openshell-gateway +``` + ## TLS (mTLS) The RPM enables mutual TLS by default. The gateway requires a valid -client certificate for all API connections, protecting the API even -though it listens on all interfaces (`0.0.0.0`). 
+client certificate for all API connections and listens on +`0.0.0.0:17670` by default (see "Default configuration" above). ### Auto-generated certificates -On first start, the `init-pki.sh` script generates certificates using -OpenSSL: +On first start, the systemd user service runs +`openshell-gateway generate-certs --output-dir ~/.local/state/openshell/tls --server-san host.openshell.internal` +to generate certificates with `rcgen` (the same routine the CLI uses for +local mTLS bundles): | File | Purpose | Location | |------|---------|----------| @@ -49,6 +105,7 @@ Names: - `openshell.openshell.svc.cluster.local` - `host.containers.internal` - `host.docker.internal` +- `host.openshell.internal` - `127.0.0.1` To connect from a remote machine, you need externally-managed @@ -61,13 +118,13 @@ To use certificates from an external CA or cert-manager: 1. Place the server cert, key, and CA cert on the filesystem. -1. Edit `~/.config/openshell/gateway.env` or use - `systemctl --user edit openshell-gateway` to override: +1. Edit `~/.config/openshell/gateway.toml`: - ```shell - OPENSHELL_TLS_CERT=/path/to/server/tls.crt - OPENSHELL_TLS_KEY=/path/to/server/tls.key - OPENSHELL_TLS_CLIENT_CA=/path/to/ca.crt + ```toml + [openshell.gateway.tls] + cert_path = "/path/to/server/tls.crt" + key_path = "/path/to/server/tls.key" + client_ca_path = "/path/to/ca.crt" ``` 1. Place the client cert where the CLI expects it: @@ -92,25 +149,21 @@ The gateway regenerates the PKI on next start. ### Disabling TLS -> **WARNING:** The RPM gateway binds to all interfaces (`0.0.0.0`) by -> default. With TLS disabled, the gateway API is exposed to the entire -> network with **no authentication**. Any host that can reach the -> gateway port has full access, including the ability to create -> sandboxes, execute arbitrary code, and access configured credentials. -> Only disable TLS when the gateway is behind a TLS-terminating reverse -> proxy that enforces its own authentication. 
When disabling TLS without -> a reverse proxy, restrict `OPENSHELL_BIND_ADDRESS` to `127.0.0.1`. +> **WARNING:** With TLS disabled, the gateway API has no authentication. +> Keep the bind address on `127.0.0.1`, or place the gateway behind a +> TLS-terminating reverse proxy that enforces its own authentication. To disable TLS (not recommended for production): -1. Edit `~/.config/openshell/gateway.env`: +1. Edit `~/.config/openshell/gateway.toml`: - ```shell - OPENSHELL_DISABLE_TLS=true + ```toml + [openshell.gateway] + disable_tls = true ``` -1. Comment out the `OPENSHELL_TLS_*` and `OPENSHELL_PODMAN_TLS_*` - variables if they are set. +1. Remove or comment out the `guest_tls_*` entries in + `~/.config/openshell/gateway.toml` if they are set. 1. Restart the gateway. @@ -120,14 +173,15 @@ When mTLS is enabled, the Podman driver bind-mounts the client certificates into each sandbox container so the supervisor process can establish an mTLS connection back to the gateway. -The following environment variables control the host-side paths of the -client certificates that are mounted into sandbox containers: +The following TOML fields control the host-side paths of the client +certificates that are mounted into sandbox containers: -| Variable | Description | -|----------|-------------| -| `OPENSHELL_PODMAN_TLS_CA` | CA certificate (host path) | -| `OPENSHELL_PODMAN_TLS_CERT` | Client certificate (host path) | -| `OPENSHELL_PODMAN_TLS_KEY` | Client private key (host path) | +```toml +[openshell.gateway] +guest_tls_ca = "/home/user/.local/state/openshell/tls/ca.crt" +guest_tls_cert = "/home/user/.local/state/openshell/tls/client/tls.crt" +guest_tls_key = "/home/user/.local/state/openshell/tls/client/tls.key" +``` Inside the container, the supervisor reads them from: @@ -141,55 +195,52 @@ configuration is required. ## Configuration reference -All settings are controlled via environment variables. 
The user unit -reads from `~/.config/openshell/gateway.env` (generated on first start) -and from `Environment=` directives in the systemd unit. +> **Upgrading from a previous release?** See the +> ["Migrating from gateway.env"](TROUBLESHOOTING.md#migrating-from-gatewayenv) +> section in TROUBLESHOOTING.md for the env-to-TOML mapping and notes on +> the default port, bind address, and database path changes. -Values in `gateway.env` override the unit defaults. Use -`systemctl --user edit openshell-gateway` to add overrides that persist -across package upgrades. +Gateway and driver settings have local runtime defaults. The gateway reads +`~/.config/openshell/gateway.toml` when that file exists. Set +`OPENSHELL_GATEWAY_CONFIG` in the launch environment to use a different file. + +Use `systemctl --user edit openshell-gateway` for service environment +overrides that persist across package upgrades. ### Gateway settings -| Variable | Default | Description | -|----------|---------|-------------| -| `OPENSHELL_BIND_ADDRESS` | `0.0.0.0` | IP address to bind all listeners to. The default exposes the gateway on all interfaces; mTLS must remain enabled to prevent unauthenticated access. Set to `127.0.0.1` for local-only access. | -| `OPENSHELL_SERVER_PORT` | `8080` | Port for the gRPC/HTTP API | -| `OPENSHELL_HEALTH_PORT` | `0` (disabled) | Port for unauthenticated health endpoints (`/healthz`, `/readyz`). Set to a non-zero value to enable. | -| `OPENSHELL_METRICS_PORT` | `0` (disabled) | Port for Prometheus metrics (`/metrics`). Set to a non-zero value to enable. 
| -| `OPENSHELL_LOG_LEVEL` | `info` | Log level: `trace`, `debug`, `info`, `warn`, `error` | -| `OPENSHELL_DRIVERS` | `podman` | Compute driver (`podman`, `docker`, `kubernetes`) | -| `OPENSHELL_DB_URL` | `sqlite://$XDG_STATE_HOME/openshell/gateway.db` | SQLite database URL for state persistence | -| `OPENSHELL_SSH_HANDSHAKE_SECRET` | (auto-generated) | Shared secret for sandbox SSH authentication | -| `OPENSHELL_DISABLE_GATEWAY_AUTH` | (unset) | Set to `true` to skip mTLS client certificate checks | - -### TLS settings - -| Variable | Default | Description | -|----------|---------|-------------| -| `OPENSHELL_TLS_CERT` | (auto-generated path) | Server TLS certificate | -| `OPENSHELL_TLS_KEY` | (auto-generated path) | Server TLS private key | -| `OPENSHELL_TLS_CLIENT_CA` | (auto-generated path) | CA for client certificate verification | -| `OPENSHELL_DISABLE_TLS` | (unset) | Set to `true` to disable TLS | -| `OPENSHELL_PODMAN_TLS_CA` | (auto-generated path) | CA cert mounted into sandbox containers | -| `OPENSHELL_PODMAN_TLS_CERT` | (auto-generated path) | Client cert mounted into sandbox containers | -| `OPENSHELL_PODMAN_TLS_KEY` | (auto-generated path) | Client key mounted into sandbox containers | - -### Sandbox settings - -| Variable | Default | Description | -|----------|---------|-------------| -| `OPENSHELL_SUPERVISOR_IMAGE` | `ghcr.io/nvidia/openshell/supervisor:latest` | Supervisor binary OCI image | -| `OPENSHELL_SANDBOX_IMAGE` | `ghcr.io/nvidia/openshell-community/sandboxes/base:latest` | Default sandbox base image | -| `OPENSHELL_SANDBOX_IMAGE_PULL_POLICY` | `missing` | Image pull policy: `always`, `missing`, `never`, `newer` | - -### Podman driver settings - -| Variable | Default | Description | -|----------|---------|-------------| -| `OPENSHELL_PODMAN_SOCKET` | `$XDG_RUNTIME_DIR/podman/podman.sock` | Podman API Unix socket path | -| `OPENSHELL_NETWORK_NAME` | `openshell` | Podman bridge network name for sandbox containers | -| 
`OPENSHELL_STOP_TIMEOUT` | `10` | Container stop timeout in seconds (SIGTERM then SIGKILL) | +| TOML option | Default | Description | +|-------------|---------|-------------| +| `bind_address` | `0.0.0.0:17670` (RPM default) | Address for the gRPC/HTTP API. | +| `compute_drivers` | `["podman"]` (RPM default) | When unset, the gateway auto-detects Kubernetes, then Podman, then Docker. The RPM default pins to Podman. | +| `default_image` | `ghcr.io/nvidia/openshell-community/sandboxes/base:latest` | Default sandbox image. | +| `supervisor_image` | `ghcr.io/nvidia/openshell/supervisor:latest` | Supervisor image mounted into Podman sandboxes. | +| `guest_tls_ca`, `guest_tls_cert`, `guest_tls_key` | auto-generated paths | Client TLS material bind-mounted into sandbox containers. | +| `[openshell.gateway.tls]` paths | auto-generated paths | Server TLS certificate, key, and client CA. | +| `disable_tls` | unset | Set to `true` to disable TLS. | + +The database URL is not accepted in TOML. When `OPENSHELL_DB_URL` is unset, +the gateway uses `sqlite:$XDG_STATE_HOME/openshell/gateway/openshell.db`. + +### Driver TOML settings + +Create `~/.config/openshell/gateway.toml` when you need to customize driver +settings: + +```toml +[openshell] +version = 1 + +[openshell.gateway] +bind_address = "0.0.0.0:17670" +compute_drivers = ["podman"] +default_image = "ghcr.io/nvidia/openshell-community/sandboxes/base:latest" + +[openshell.drivers.podman] +image_pull_policy = "missing" +network_name = "openshell" +stop_timeout_secs = 10 +``` ### Image management @@ -204,14 +255,14 @@ podman pull ghcr.io/nvidia/openshell/supervisor:latest podman pull ghcr.io/nvidia/openshell-community/sandboxes/base:latest ``` -Or set `OPENSHELL_SANDBOX_IMAGE_PULL_POLICY=always` to pull on every -sandbox creation. +Or set `image_pull_policy = "always"` in +`[openshell.drivers.podman]` to pull on every sandbox creation. 
To pin specific image versions instead of `:latest`: ```shell -OPENSHELL_SUPERVISOR_IMAGE=ghcr.io/nvidia/openshell/supervisor:v0.0.37 -OPENSHELL_SANDBOX_IMAGE=ghcr.io/nvidia/openshell-community/sandboxes/base:v0.0.37 +supervisor_image = "ghcr.io/nvidia/openshell/supervisor:v0.0.37" +default_image = "ghcr.io/nvidia/openshell-community/sandboxes/base:v0.0.37" ``` For air-gapped environments: @@ -234,8 +285,9 @@ For air-gapped environments: 1. Set pull policy to `never`: - ```shell - OPENSHELL_SANDBOX_IMAGE_PULL_POLICY=never + ```toml + [openshell.drivers.podman] + image_pull_policy = "never" ``` ## File locations @@ -245,9 +297,9 @@ For air-gapped environments: | Gateway binary | `/usr/bin/openshell-gateway` | | CLI binary | `/usr/bin/openshell` | | Systemd user unit | `/usr/lib/systemd/user/openshell-gateway.service` | -| PKI bootstrap script | `/usr/libexec/openshell/init-pki.sh` | -| Env generator script | `/usr/libexec/openshell/init-gateway-env.sh` | +| Default TOML config template (read-only) | `/usr/share/openshell-gateway/gateway.toml.default` | +| Active gateway TOML configuration | `~/.config/openshell/gateway.toml` | +| Optional environment variable overrides | `~/.config/openshell/gateway.env` | | TLS certificates | `~/.local/state/openshell/tls/` | | CLI client certs | `~/.config/openshell/gateways/openshell/mtls/` | -| Gateway database | `~/.local/state/openshell/gateway.db` | -| Gateway configuration | `~/.config/openshell/gateway.env` | +| Gateway database | `~/.local/state/openshell/gateway/openshell.db` | diff --git a/deploy/rpm/QUICKSTART.md b/deploy/rpm/QUICKSTART.md index b25be7d77..c6634ced9 100644 --- a/deploy/rpm/QUICKSTART.md +++ b/deploy/rpm/QUICKSTART.md @@ -51,8 +51,9 @@ The gateway pulls container images from ghcr.io on first sandbox creation. Ensure the host can reach ghcr.io over HTTPS (port 443). 
For air-gapped environments, pre-load images with `podman pull` and -set `OPENSHELL_SANDBOX_IMAGE_PULL_POLICY=never` in -`~/.config/openshell/gateway.env`. See CONFIGURATION.md for details. +set `image_pull_policy = "never"` in +`~/.config/openshell/gateway.toml`. See CONFIGURATION.md for +details. ## Start the gateway @@ -63,13 +64,11 @@ systemctl --user enable --now openshell-gateway On first start, the gateway automatically generates: - A self-signed PKI bundle (CA, server cert, client cert) for mTLS -- An SSH handshake secret for sandbox authentication -- A commented configuration file at `~/.config/openshell/gateway.env` -> **Note:** The gateway binds to all interfaces (`0.0.0.0`) by default. -> Mutual TLS (mTLS) is enabled automatically on first start, requiring a -> valid client certificate for every connection. Do not disable TLS -> without restricting the bind address to `127.0.0.1`. See +> **Note:** The RPM default configuration binds to `0.0.0.0:17670` so +> Podman sandbox containers can reach the gateway over the host network +> bridge. Mutual TLS (mTLS) is enabled automatically on first start, +> requiring a valid client certificate for every connection. See > CONFIGURATION.md for details. Verify the service is running: @@ -83,7 +82,7 @@ systemctl --user status openshell-gateway The CLI needs to know where the gateway is. Register it: ```shell -openshell gateway add --local https://127.0.0.1:8080 +openshell gateway add --local https://127.0.0.1:17670 ``` This discovers the pre-provisioned mTLS certificates at diff --git a/deploy/rpm/TROUBLESHOOTING.md b/deploy/rpm/TROUBLESHOOTING.md index 2c33e1a57..68a1f4946 100644 --- a/deploy/rpm/TROUBLESHOOTING.md +++ b/deploy/rpm/TROUBLESHOOTING.md @@ -5,10 +5,11 @@ and upgrade procedures for the RPM deployment. ## CLI compatibility -The RPM installs the gateway as a systemd user service with the Podman -compute driver. The published online docs and some CLI commands assume -a Docker/K3s deployment model. 
This section clarifies which commands -work, which do not, and what to use instead. +The RPM installs the gateway as a systemd user service. On a standard RPM +install the gateway auto-detects Podman because the package depends on it. +The published online docs and some CLI commands assume a Docker/K3s +deployment model. This section clarifies which commands work, which do not, +and what to use instead. ### Commands that work normally @@ -67,14 +68,14 @@ Forward the gateway port over SSH and connect via localhost: ```shell # On the remote CLI machine: -ssh -L 8080:127.0.0.1:8080 user@gateway-host +ssh -L 17670:127.0.0.1:17670 user@gateway-host # In another terminal on the same machine: # Copy the client certs from the gateway host first: scp -r user@gateway-host:~/.config/openshell/gateways/openshell/mtls/ \ ~/.config/openshell/gateways/openshell/mtls/ -openshell gateway add --local https://127.0.0.1:8080 +openshell gateway add --local https://127.0.0.1:17670 openshell status ``` @@ -82,6 +83,9 @@ openshell status Generate certificates that include the server's hostname or IP in the SANs. See "Using externally-managed certificates" in CONFIGURATION.md. +Then change `bind_address` in +`~/.config/openshell/gateway.toml` to the interface the remote CLI +can reach, for example `0.0.0.0:17670`, and restart the gateway. After placing the server and client certs, register from the remote CLI: @@ -91,7 +95,7 @@ CLI: mkdir -p ~/.config/openshell/gateways/openshell/mtls/ cp ca.crt tls.crt tls.key ~/.config/openshell/gateways/openshell/mtls/ -openshell gateway add --local https://:8080 +openshell gateway add --local https://:17670 ``` ### Firewall @@ -99,7 +103,7 @@ openshell gateway add --local https://:8080 For remote access, open the gateway port in firewalld: ```shell -sudo firewall-cmd --add-port=8080/tcp --permanent +sudo firewall-cmd --add-port=17670/tcp --permanent sudo firewall-cmd --reload ``` @@ -117,7 +121,7 @@ The CLI cannot find a registered gateway. 
This happens when the gateway is running but has not been registered with the CLI. ```shell -openshell gateway add --local https://127.0.0.1:8080 +openshell gateway add --local https://127.0.0.1:17670 ``` ### Gateway fails to start @@ -186,8 +190,8 @@ podman pull ghcr.io/nvidia/openshell-community/sandboxes/base:latest podman pull ghcr.io/nvidia/openshell/supervisor:latest ``` -Or set `OPENSHELL_SANDBOX_IMAGE_PULL_POLICY=always` in -`~/.config/openshell/gateway.env` and restart the gateway. +Or set `image_pull_policy = "always"` in +`~/.config/openshell/gateway.toml` and restart the gateway. ### Gateway stops on logout @@ -210,17 +214,22 @@ After upgrading the RPM packages: ```shell sudo dnf update openshell openshell-gateway +systemctl --user restart podman.socket systemctl --user restart openshell-gateway ``` The SQLite database schema is auto-migrated on startup. Running sandboxes are stopped during the restart. -The `gateway.env` file is not overwritten during upgrades. The -`init-gateway-env.sh` script is idempotent and only generates the file -on first start. New configuration options from newer versions can be -added manually by referencing CONFIGURATION.md or running -`openshell-gateway --help`. +Restarting `podman.socket` after a package upgrade is recommended: if the +unit file changed on disk during the upgrade, the running socket may become +non-functional until restarted, causing the gateway to fail with a +connection error on `/run/user//podman/podman.sock`. The gateway +retries briefly on startup, but a stale socket will not recover on its own. + +Package upgrades do not overwrite `~/.config/openshell/gateway.toml` when you +create one. New gateway process options can be added manually by referencing +CONFIGURATION.md or running `openshell-gateway --help`. 
To pick up new container images after an upgrade: @@ -228,3 +237,65 @@ To pick up new container images after an upgrade: podman pull ghcr.io/nvidia/openshell/supervisor:latest podman pull ghcr.io/nvidia/openshell-community/sandboxes/base:latest ``` + +### Migrating from gateway.env + +Previous releases generated `~/.config/openshell/gateway.env` on first +start and used it to configure the gateway at launch. The gateway now +starts from built-in runtime defaults and reads +`~/.config/openshell/gateway.toml` when that file exists. + +If you have a `gateway.env` file it is still honored: the systemd unit +reads it via `EnvironmentFile` on every start. You can leave it in place +or delete it. New installs no longer generate one. + +To migrate settings to TOML, create `~/.config/openshell/gateway.toml` +and map the relevant variables: + +| Environment variable | TOML equivalent | +|---|---| +| `OPENSHELL_BIND_ADDRESS=A` + `OPENSHELL_SERVER_PORT=P` | `bind_address = "A:P"` under `[openshell.gateway]` | +| `OPENSHELL_DRIVERS=podman` | `compute_drivers = ["podman"]` under `[openshell.gateway]` | +| `OPENSHELL_DISABLE_TLS=true` | `disable_tls = true` under `[openshell.gateway]` | +| `OPENSHELL_TLS_CERT=PATH` | `cert_path = "PATH"` under `[openshell.gateway.tls]` | +| `OPENSHELL_TLS_KEY=PATH` | `key_path = "PATH"` under `[openshell.gateway.tls]` | +| `OPENSHELL_TLS_CLIENT_CA=PATH` | `client_ca_path = "PATH"` under `[openshell.gateway.tls]` | +| `OPENSHELL_DB_URL=URL` | env-only — not accepted in TOML; keep in env or drop-in override | +| `OPENSHELL_LOG_LEVEL=debug` | env-only — keep as `Environment=OPENSHELL_LOG_LEVEL=debug` in a drop-in | + +Other breaking changes in this release: + +- **Default port changed from 8080 to 17670.** If you registered the + gateway at `https://127.0.0.1:8080`, re-register it: + + ```shell + openshell gateway add --local https://127.0.0.1:17670 + ``` + +- **Default bind address changed from `0.0.0.0` to `127.0.0.1`.** If + you relied on 
network-accessible access without an explicit bind + address, add the following to `~/.config/openshell/gateway.toml`: + + ```toml + [openshell.gateway] + bind_address = "0.0.0.0:17670" + ``` + + Also update your firewall rule if applicable: + + ```shell + sudo firewall-cmd --remove-port=8080/tcp --permanent + sudo firewall-cmd --add-port=17670/tcp --permanent + sudo firewall-cmd --reload + ``` + +- **Database path changed** from `~/.local/state/openshell/gateway.db` + to `~/.local/state/openshell/gateway/openshell.db`. Existing gateway + state (registered sandboxes, etc.) is not migrated automatically. To + preserve state across the upgrade, move the file before restarting: + + ```shell + mkdir -p ~/.local/state/openshell/gateway + mv ~/.local/state/openshell/gateway.db \ + ~/.local/state/openshell/gateway/openshell.db + ``` diff --git a/deploy/rpm/gateway.toml.default b/deploy/rpm/gateway.toml.default new file mode 100644 index 000000000..d85379964 --- /dev/null +++ b/deploy/rpm/gateway.toml.default @@ -0,0 +1,30 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Default gateway configuration for RPM installs. +# +# This file is seeded to ~/.config/openshell/gateway.toml on first start +# of the openshell-gateway.service systemd user unit. Edit that copy to +# customize. This file is not read directly at runtime. +# +# Configuration precedence (highest to lowest): +# CLI flag > OPENSHELL_* env var > TOML file > built-in default +# +# To override settings without editing this file, set OPENSHELL_* variables +# in ~/.config/openshell/gateway.env or run: +# systemctl --user edit openshell-gateway + +[openshell] +version = 1 + +[openshell.gateway] +# Podman sandbox containers reach the gateway over the host network bridge, +# which requires binding to all interfaces. Override to 127.0.0.1:17670 if +# you don't use Podman or want loopback-only access (e.g. 
behind a reverse +# proxy). mTLS is enabled by default and protects all connections. +bind_address = "0.0.0.0:17670" + +# Pin to the Podman compute driver. Without this, the gateway auto-detects +# in order: Kubernetes, Podman, Docker. Pinning prevents unexpected driver +# selection if Docker is also installed on the host. +compute_drivers = ["podman"] diff --git a/deploy/rpm/init-gateway-env.sh b/deploy/rpm/init-gateway-env.sh deleted file mode 100644 index 299a19041..000000000 --- a/deploy/rpm/init-gateway-env.sh +++ /dev/null @@ -1,115 +0,0 @@ -#!/bin/bash -# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -# Generate the gateway environment configuration file on first start. -# -# Called from the systemd ExecStartPre directive to bootstrap the -# gateway configuration. Idempotent: exits immediately if the file -# already exists. -# -# Usage: -# init-gateway-env.sh -# -# The generated file contains an auto-generated SSH handshake secret -# and commented defaults for all gateway environment variables. - -set -euo pipefail - -ENV_FILE="${1:?Usage: init-gateway-env.sh }" - -# ── Idempotent: skip if env file already exists ───────────────────── -if [ -f "${ENV_FILE}" ]; then - exit 0 -fi - -# ── Create parent directory ───────────────────────────────────────── -mkdir -p "$(dirname "${ENV_FILE}")" - -# ── Generate SSH handshake secret ─────────────────────────────────── -SECRET=$(od -An -tx1 -N32 /dev/urandom | tr -dc 0-9a-f) - -# ── Write environment file ────────────────────────────────────────── -cat > "${ENV_FILE}" << EOF -# OpenShell Gateway Environment Configuration -# Generated on first start. Edit freely; this file is not overwritten. -# -# Run 'openshell-gateway --help' for the full list of options. -# See /usr/share/doc/openshell-gateway/ for guides. - -# ---- Required ---- - -# Shared secret for gateway-to-sandbox SSH handshake authentication. 
-# Auto-generated on first start. To regenerate: -# openssl rand -hex 32 -OPENSHELL_SSH_HANDSHAKE_SECRET=${SECRET} - -# ---- Optional (uncomment to override defaults) ---- - -# Database URL for gateway state persistence. -# Default for the user unit: sqlite://\$XDG_STATE_HOME/openshell/gateway.db -#OPENSHELL_DB_URL=sqlite:///path/to/gateway.db - -# Compute driver: podman (default for RPM), docker, kubernetes. -#OPENSHELL_DRIVERS=podman - -# Bind address. 0.0.0.0 listens on all interfaces; mTLS prevents -# unauthenticated access. -#OPENSHELL_BIND_ADDRESS=0.0.0.0 - -# API port (default: 8080). -#OPENSHELL_SERVER_PORT=8080 - -# Log level: trace, debug, info, warn, error. -#OPENSHELL_LOG_LEVEL=info - -# ---- Images ---- - -# Supervisor binary OCI image (mounted read-only into sandboxes). -#OPENSHELL_SUPERVISOR_IMAGE=ghcr.io/nvidia/openshell/supervisor:latest - -# Default sandbox base image. -#OPENSHELL_SANDBOX_IMAGE=ghcr.io/nvidia/openshell-community/sandboxes/base:latest - -# Image pull policy: always, missing (default), never, newer. -# Use 'always' to pick up new tags automatically. -# Use 'never' for air-gapped environments with pre-loaded images. -#OPENSHELL_SANDBOX_IMAGE_PULL_POLICY=missing - -# ---- TLS (mTLS enabled by default) ---- -# PKI is auto-generated by init-pki.sh on first start. Client certs are -# placed in ~/.config/openshell/gateways/openshell/mtls/ so the CLI -# discovers them automatically. -# -# To use externally-managed certs, uncomment and edit the paths below. -# To rotate certs, delete ~/.local/state/openshell/tls/ and restart. -# WARNING: Disabling TLS with the default bind address (0.0.0.0) exposes -# the gateway API to the entire network with NO authentication. Only -# disable TLS when behind a TLS-terminating reverse proxy, or restrict -# OPENSHELL_BIND_ADDRESS to 127.0.0.1. -#OPENSHELL_DISABLE_TLS=true - -# Server TLS (gateway listens with these certs). 
-#OPENSHELL_TLS_CERT=\$XDG_STATE_HOME/openshell/tls/server/tls.crt -#OPENSHELL_TLS_KEY=\$XDG_STATE_HOME/openshell/tls/server/tls.key -#OPENSHELL_TLS_CLIENT_CA=\$XDG_STATE_HOME/openshell/tls/ca.crt - -# Podman driver: client certs bind-mounted into sandbox containers. -#OPENSHELL_PODMAN_TLS_CA=\$XDG_STATE_HOME/openshell/tls/ca.crt -#OPENSHELL_PODMAN_TLS_CERT=\$XDG_STATE_HOME/openshell/tls/client/tls.crt -#OPENSHELL_PODMAN_TLS_KEY=\$XDG_STATE_HOME/openshell/tls/client/tls.key - -# ---- Podman driver ---- - -# Podman API Unix socket path. -#OPENSHELL_PODMAN_SOCKET=\$XDG_RUNTIME_DIR/podman/podman.sock - -# Podman bridge network name for sandbox containers. -#OPENSHELL_NETWORK_NAME=openshell - -# Container stop timeout in seconds (SIGTERM then SIGKILL). -#OPENSHELL_STOP_TIMEOUT=10 -EOF - -chmod 600 "${ENV_FILE}" -echo "Gateway environment generated: ${ENV_FILE}" diff --git a/deploy/rpm/init-pki.sh b/deploy/rpm/init-pki.sh deleted file mode 100755 index 900ec20a6..000000000 --- a/deploy/rpm/init-pki.sh +++ /dev/null @@ -1,197 +0,0 @@ -#!/bin/bash -# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -# Generate a self-signed PKI bundle for the OpenShell gateway. -# -# Called from the systemd ExecStartPre directive to bootstrap mTLS on -# first start. Idempotent: exits immediately if all cert files exist. -# Detects and recovers from partial PKI state (e.g. interrupted runs). -# -# All files are generated in a temporary staging directory first and -# moved into place only after the full PKI is complete, preventing -# partial state from persisting across failures. 
-# -# Usage: -# init-pki.sh -# -# Output layout: -# /ca.crt CA certificate -# /ca.key CA private key (mode 0600) -# /server/tls.crt Server certificate -# /server/tls.key Server private key (mode 0600) -# /client/tls.crt Client certificate -# /client/tls.key Client private key (mode 0600) -# -# Client certs are also copied to the CLI's auto-discovery directory: -# $XDG_CONFIG_HOME/openshell/gateways/openshell/mtls/{ca.crt,tls.crt,tls.key} - -set -euo pipefail - -PKI_DIR="${1:?Usage: init-pki.sh }" - -# ── Resolve CLI cert directory ─────────────────────────────────────── -CLI_MTLS_DIR="${XDG_CONFIG_HOME:-${HOME}/.config}/openshell/gateways/openshell/mtls" - -# ── Required PKI files ─────────────────────────────────────────────── -PKI_FILES=( - "${PKI_DIR}/ca.crt" - "${PKI_DIR}/ca.key" - "${PKI_DIR}/server/tls.crt" - "${PKI_DIR}/server/tls.key" - "${PKI_DIR}/client/tls.crt" - "${PKI_DIR}/client/tls.key" -) - -CLI_FILES=( - "${CLI_MTLS_DIR}/ca.crt" - "${CLI_MTLS_DIR}/tls.crt" - "${CLI_MTLS_DIR}/tls.key" -) - -# ── Idempotent: skip if all PKI files exist ────────────────────────── -all_pki_exist=true -for f in "${PKI_FILES[@]}"; do - if [ ! -f "$f" ]; then - all_pki_exist=false - break - fi -done - -if [ "$all_pki_exist" = true ]; then - # PKI is complete. Ensure CLI copies also exist (they may have been - # deleted independently, e.g. user cleared their config directory). - cli_ok=true - for f in "${CLI_FILES[@]}"; do - if [ ! -f "$f" ]; then - cli_ok=false - break - fi - done - if [ "$cli_ok" = false ]; then - echo "PKI exists but CLI auto-discovery certs missing; re-copying..." 
- mkdir -p "${CLI_MTLS_DIR}" - cp "${PKI_DIR}/ca.crt" "${CLI_MTLS_DIR}/ca.crt" - cp "${PKI_DIR}/client/tls.crt" "${CLI_MTLS_DIR}/tls.crt" - cp "${PKI_DIR}/client/tls.key" "${CLI_MTLS_DIR}/tls.key" - chmod 600 "${CLI_MTLS_DIR}/tls.key" - fi - exit 0 -fi - -# ── Partial state recovery ─────────────────────────────────────────── -# If some PKI files exist but not all, a previous run was interrupted. -# Remove the partial state so we can regenerate cleanly. -partial=false -for f in "${PKI_FILES[@]}"; do - if [ -f "$f" ]; then - partial=true - break - fi -done -if [ "$partial" = true ]; then - echo "WARNING: Partial PKI detected in ${PKI_DIR}, regenerating..." - rm -f "${PKI_DIR}/ca.crt" "${PKI_DIR}/ca.key" "${PKI_DIR}/ca.srl" - rm -rf "${PKI_DIR}/server" "${PKI_DIR}/client" -fi - -# ── Temporary workspace (cleaned up on exit) ───────────────────────── -WORK=$(mktemp -d) -trap 'rm -rf "${WORK}"' EXIT - -# Stage directory mirrors the final PKI layout. -STAGE="${WORK}/pki" -mkdir -p "${STAGE}/server" "${STAGE}/client" - -# ── Server certificate SANs ───────────────────────────────────────── -# These must match what the supervisor connects to. The CLI also -# connects using localhost/127.0.0.1 by default. 
-cat > "${WORK}/server-san.cnf" <<'EOF' -[req] -distinguished_name = req_dn -req_extensions = v3_req -prompt = no - -[req_dn] -O = openshell -CN = openshell-server - -[v3_req] -subjectAltName = @alt_names - -[alt_names] -DNS.1 = localhost -DNS.2 = openshell -DNS.3 = openshell.openshell.svc -DNS.4 = openshell.openshell.svc.cluster.local -DNS.5 = host.containers.internal -DNS.6 = host.docker.internal -IP.1 = 127.0.0.1 -EOF - -# ── Generate CA (into staging) ─────────────────────────────────────── -openssl req -x509 -newkey ec -pkeyopt ec_paramgen_curve:prime256v1 \ - -keyout "${STAGE}/ca.key" \ - -out "${STAGE}/ca.crt" \ - -days 3650 -nodes \ - -subj "/O=openshell/CN=openshell-ca" \ - 2>/dev/null -chmod 600 "${STAGE}/ca.key" - -# ── Generate server certificate (into staging) ─────────────────────── -openssl req -newkey ec -pkeyopt ec_paramgen_curve:prime256v1 \ - -keyout "${STAGE}/server/tls.key" \ - -out "${WORK}/server.csr" \ - -nodes \ - -config "${WORK}/server-san.cnf" \ - 2>/dev/null - -openssl x509 -req \ - -in "${WORK}/server.csr" \ - -CA "${STAGE}/ca.crt" -CAkey "${STAGE}/ca.key" -CAcreateserial \ - -out "${STAGE}/server/tls.crt" \ - -days 3650 \ - -extensions v3_req \ - -extfile "${WORK}/server-san.cnf" \ - 2>/dev/null -chmod 600 "${STAGE}/server/tls.key" - -# ── Generate client certificate (into staging) ─────────────────────── -openssl req -newkey ec -pkeyopt ec_paramgen_curve:prime256v1 \ - -keyout "${STAGE}/client/tls.key" \ - -out "${WORK}/client.csr" \ - -nodes \ - -subj "/O=openshell/CN=openshell-client" \ - 2>/dev/null - -openssl x509 -req \ - -in "${WORK}/client.csr" \ - -CA "${STAGE}/ca.crt" -CAkey "${STAGE}/ca.key" -CAcreateserial \ - -out "${STAGE}/client/tls.crt" \ - -days 3650 \ - 2>/dev/null -chmod 600 "${STAGE}/client/tls.key" - -# ── Move staged PKI into final location ────────────────────────────── -# Create parent directories and move files individually. 
Using mv on -# individual files rather than whole directories so we do not clobber -# the target directory if it already exists. -mkdir -p "${PKI_DIR}/server" "${PKI_DIR}/client" -mv "${STAGE}/ca.crt" "${PKI_DIR}/ca.crt" -mv "${STAGE}/ca.key" "${PKI_DIR}/ca.key" -mv "${STAGE}/server/tls.crt" "${PKI_DIR}/server/tls.crt" -mv "${STAGE}/server/tls.key" "${PKI_DIR}/server/tls.key" -mv "${STAGE}/client/tls.crt" "${PKI_DIR}/client/tls.crt" -mv "${STAGE}/client/tls.key" "${PKI_DIR}/client/tls.key" - -# ── Copy client certs to CLI auto-discovery directory ──────────────── -# The CLI automatically looks for certs at: -# $XDG_CONFIG_HOME/openshell/gateways//mtls/{ca.crt,tls.crt,tls.key} -# For localhost gateways, defaults to "openshell". -mkdir -p "${CLI_MTLS_DIR}" -cp "${PKI_DIR}/ca.crt" "${CLI_MTLS_DIR}/ca.crt" -cp "${PKI_DIR}/client/tls.crt" "${CLI_MTLS_DIR}/tls.crt" -cp "${PKI_DIR}/client/tls.key" "${CLI_MTLS_DIR}/tls.key" -chmod 600 "${CLI_MTLS_DIR}/tls.key" - -echo "PKI bootstrap complete: ${PKI_DIR}" diff --git a/deploy/snap/README.md b/deploy/snap/README.md new file mode 100644 index 000000000..419aacaa4 --- /dev/null +++ b/deploy/snap/README.md @@ -0,0 +1,177 @@ +# Building a snap package + +OpenShell snap packages are defined by the root `snapcraft.yaml` and built with +Snapcraft from source. + +The helper task under `tasks/` still stages the same payload from pre-built +binaries when you want to inspect the snap root or produce local artifacts. + +## Prerequisites + +- Linux on `amd64` or `arm64` +- `snap` from `snapd` +- `snapcraft` +- Docker from the Docker snap (`sudo snap install docker`) + +## Build with Snapcraft + +Build the snap from source with the root manifest: + +```shell +snapcraft pack +``` + +The manifest builds the Rust binaries inside Snapcraft, installs the CLI, +gateway, and sandbox supervisor into the snap, and keeps the same runtime +environment as the current deployment logic. 
+ +## Staged helper flow + +The helper task under `tasks/` still stages the same payload from pre-built +binaries when you want to inspect the snap root or produce local artifacts. + +For that flow, install `mise` and build: + +- `openshell` +- `openshell-gateway` +- `openshell-sandbox` + +## Build helper binaries + +Build the release binaries used by the staged helper flow: + +```shell +mise run build:rust:snap +``` + +This convenience target builds the CLI with `bundled-z3`, the gateway, and +`openshell-sandbox` for the Docker driver to bind-mount into sandbox containers. + +## Pack the snap + +Run the packaging hook through mise: + +```shell +VERSION="$(uv run python tasks/scripts/release.py get-version --snap)" + +OPENSHELL_CLI_BINARY="$PWD/target/release/openshell" \ +OPENSHELL_GATEWAY_BINARY="$PWD/target/release/openshell-gateway" \ +OPENSHELL_DOCKER_SUPERVISOR_BINARY="$PWD/target/release/openshell-sandbox" \ +OPENSHELL_SNAP_VERSION="$VERSION" \ +OPENSHELL_OUTPUT_DIR=artifacts \ + mise run package:snap +``` + +The artifact is written to `artifacts/openshell_${VERSION}_${ARCH}.snap`. The +packaging hook fails before `snap pack` if `openshell-sandbox` is missing or not +executable. + +## Stage without packing + +To inspect the snap root without running `snap pack`: + +```shell +VERSION="$(uv run python tasks/scripts/release.py get-version --snap)" + +OPENSHELL_CLI_BINARY="$PWD/target/release/openshell" \ +OPENSHELL_GATEWAY_BINARY="$PWD/target/release/openshell-gateway" \ +OPENSHELL_DOCKER_SUPERVISOR_BINARY="$PWD/target/release/openshell-sandbox" \ +OPENSHELL_SNAP_VERSION="$VERSION" \ + mise run package:snap:stage +``` + +The staged root is written to `artifacts/snap-root`. + +## Commands in the snap + +The snap exposes the CLI: + +- `openshell` + +It also defines a system service with packaged Docker driver settings. + +- `openshell.gateway` + +The gateway service uses `refresh-mode: endure` so snap refreshes do not restart +it while sandboxes are active. 
Restart the service manually when you are ready +to move the gateway to the refreshed snap revision. + +`openshell-sandbox` is staged next to `openshell-gateway` as the Docker +supervisor binary. The gateway app starts through a small wrapper that sets +Snap-specific defaults and reads `$SNAP_COMMON/gateway.toml` when that file +exists. The service stores its gateway database under `$SNAP_COMMON`. + +## Interfaces + +The `openshell` CLI app plugs: + +- `home` +- `network` +- `ssh-keys` +- `system-observe` + +The `openshell.gateway` service plugs: + +- `docker` +- `log-observe` +- `network` +- `network-bind` +- `ssh-keys` +- `system-observe` + +## Start a Docker gateway from the snap + +The snapped gateway talks to Docker through the Docker snap's +`docker:docker-daemon` slot. The snap declares `default-provider: docker` on +its Docker plug so snapd can install the Docker snap when OpenShell is +installed. Connect the interface before using the Docker driver: + +```shell +sudo snap connect openshell:docker docker:docker-daemon +sudo snap connect openshell:log-observe +sudo snap connect openshell:system-observe +sudo snap connect openshell:ssh-keys +``` + +The gateway uses Docker's default Unix socket location. The Docker snap exposes +that socket through the connected `docker` interface, so no `DOCKER_HOST` +override is required. The OpenShell snap still requires the Docker snap because +it relies on the `docker:docker-daemon` slot; it does not work with Docker +installed from a Debian package or Docker's upstream packages. + +The service runs the gateway with Snap-specific environment defaults: + +```shell +OPENSHELL_DISABLE_TLS=true \ +OPENSHELL_DB_URL="sqlite:$SNAP_COMMON/gateway.db?mode=rwc" \ +openshell.gateway +``` + +This stores the gateway SQLite database at +`/var/snap/openshell/common/gateway.db`. Create +`/var/snap/openshell/common/gateway.toml` when you need to override gateway or +Docker driver settings. 
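As a rough illustration of such an override file, the sketch below prints a minimal `gateway.toml` to stdout. It assumes the snap-run gateway accepts the same `[openshell]` and `[openshell.gateway]` keys shown in the RPM default config, and that `"docker"` is a valid `compute_drivers` value by analogy with `OPENSHELL_DRIVERS=docker`; `render_gateway_toml` is a hypothetical helper, not something the snap ships.

```shell
#!/bin/sh
# Sketch only: render_gateway_toml is a hypothetical helper, and the key
# names are assumed to match the RPM gateway.toml layout.

# Print a minimal gateway.toml for the given bind address and driver.
render_gateway_toml() {
    bind_address="$1"
    driver="$2"
    cat <<EOF
[openshell]
version = 1

[openshell.gateway]
bind_address = "${bind_address}"
compute_drivers = ["${driver}"]
EOF
}

# Preview on stdout; nothing is written to disk here.
render_gateway_toml "127.0.0.1:17670" "docker"
```

To apply it, the output could be piped through `sudo tee /var/snap/openshell/common/gateway.toml` and the service restarted, for example with `sudo snap restart openshell.gateway`. Printing to stdout first keeps the step reviewable before anything under `$SNAP_COMMON` changes.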
+ +## Connect with the OpenShell CLI + +Register the snap-run gateway as a local plaintext gateway: + +```shell +openshell gateway add http://127.0.0.1:17670 --local --name snap-docker +openshell gateway select snap-docker +openshell status +``` + +Then use normal sandbox commands: + +```shell +openshell sandbox create --name demo +openshell sandbox connect demo +``` + +To avoid changing the default gateway, pass the gateway name per command: + +```shell +openshell --gateway snap-docker status +openshell --gateway snap-docker sandbox create --name demo +``` diff --git a/deploy/snap/bin/openshell-gateway-wrapper b/deploy/snap/bin/openshell-gateway-wrapper new file mode 100755 index 000000000..cfba8db36 --- /dev/null +++ b/deploy/snap/bin/openshell-gateway-wrapper @@ -0,0 +1,15 @@ +#!/bin/sh +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +set -eu + +CANONICAL_CONFIG_FILE="${SNAP_COMMON}/gateway.toml" +export OPENSHELL_DB_URL="${OPENSHELL_DB_URL:-sqlite:${SNAP_COMMON}/gateway.db?mode=rwc}" +export OPENSHELL_DISABLE_TLS="${OPENSHELL_DISABLE_TLS:-true}" + +if [ -z "${OPENSHELL_GATEWAY_CONFIG:-}" ] && [ -f "$CANONICAL_CONFIG_FILE" ]; then + exec "${SNAP}/bin/openshell-gateway" --config "$CANONICAL_CONFIG_FILE" "$@" +fi + +exec "${SNAP}/bin/openshell-gateway" "$@" diff --git a/deploy/snap/meta/snap.yaml.in b/deploy/snap/meta/snap.yaml.in new file mode 100644 index 000000000..920dd9141 --- /dev/null +++ b/deploy/snap/meta/snap.yaml.in @@ -0,0 +1,53 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +name: openshell +title: OpenShell +version: "@VERSION@" +summary: Safe, sandboxed runtimes for autonomous AI agents +description: | + OpenShell provides safe, sandboxed runtimes for autonomous AI agents. 
+ It offers a CLI for managing gateways, sandboxes, and providers with + policy-enforced egress routing, credential proxying, and privacy-aware + LLM inference routing. + +base: "@BASE@" +grade: "@GRADE@" +confinement: strict +license: Apache-2.0 +website: https://docs.nvidia.com/openshell/latest/index.html +source-code: https://github.com/NVIDIA/OpenShell +issues: https://github.com/NVIDIA/OpenShell/issues +contact: https://github.com/NVIDIA/OpenShell/security/policy +architectures: + - "@ARCH@" + +apps: + openshell: + command: bin/openshell + plugs: + - home + - network + - ssh-keys + - system-observe + gateway: + command: bin/openshell-gateway-wrapper + daemon: simple + refresh-mode: endure + environment: + XDG_DATA_HOME: "$SNAP_COMMON" + # Used for creating and locating certain sockets. + XDG_RUNTIME_DIR: "$SNAP_COMMON" + + plugs: + - docker + - log-observe + - network + - network-bind + - ssh-keys + - system-observe + +plugs: + docker: + interface: docker + default-provider: docker diff --git a/docs/about/container-gateway.mdx b/docs/about/container-gateway.mdx new file mode 100644 index 000000000..370d7f146 --- /dev/null +++ b/docs/about/container-gateway.mdx @@ -0,0 +1,150 @@ +--- +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +title: "Running the Gateway as a Container" +sidebar-title: "Container Gateway" +description: "Run the OpenShell gateway using docker run or docker-compose without the installer." +keywords: "Generative AI, Cybersecurity, AI Agents, Sandboxing, Docker, Podman, docker-compose, container, immutable OS, bootc, rpm-ostree" +position: 4 +--- + +Use this approach when you want to run the OpenShell gateway as a container instead of installing it with the system package manager. 
This is useful on immutable OS distributions (Fedora CoreOS, bootc-based images, Silverblue) where the standard installer is not appropriate, or anywhere you prefer a container-first workflow. + +The gateway image is published at `ghcr.io/nvidia/openshell/gateway`. + +## Quick Start + +This example runs the gateway locally with TLS disabled. It is suitable for development on a single machine. Binding to `127.0.0.1` prevents remote access without authentication. + +```shell +docker run -d \ + --name openshell-gateway \ + --restart unless-stopped \ + -p 127.0.0.1:8080:8080 \ + -v openshell-state:/var/openshell \ + -v /var/run/docker.sock:/var/run/docker.sock \ + -e OPENSHELL_DRIVERS=docker \ + -e OPENSHELL_DB_URL=sqlite:/var/openshell/openshell.db \ + -e OPENSHELL_DISABLE_TLS=true \ + ghcr.io/nvidia/openshell/gateway:latest +``` + +Register the gateway with the CLI: + +```shell +openshell gateway add http://127.0.0.1:8080 --local --name local +``` + +Confirm the CLI can reach the gateway: + +```shell +openshell status +``` + + +Disabling TLS removes authentication. Binding to `127.0.0.1` limits access to the local machine. If you expose the port on `0.0.0.0`, enable TLS and local mTLS user authentication, or put the gateway behind a trusted proxy with its own authentication. + + +## Full mTLS Setup + +To run the gateway with mutual TLS, generate the PKI bundle first, then start the gateway with the cert paths configured. 
+ +Bootstrap the PKI into a local state directory: + +```shell +mkdir -p ~/.local/state/openshell/tls + +docker run --rm \ + -v "$HOME/.local/state/openshell:/home/openshell/.local/state/openshell" \ + -v "$HOME/.config/openshell:/home/openshell/.config/openshell" \ + ghcr.io/nvidia/openshell/gateway:latest \ + generate-certs --output-dir /home/openshell/.local/state/openshell/tls +``` + +This writes the server and client certificates under `~/.local/state/openshell/tls/`, writes sandbox JWT signing keys under `~/.local/state/openshell/tls/jwt/`, and copies the client bundle to `~/.config/openshell/gateways/openshell/mtls/` so the CLI picks it up automatically. + +Start the gateway with mTLS enabled: + +```shell +docker run -d \ + --name openshell-gateway \ + --restart unless-stopped \ + -p 127.0.0.1:8080:8080 \ + -v "$HOME/.local/state/openshell:/home/openshell/.local/state/openshell" \ + -v /var/run/docker.sock:/var/run/docker.sock \ + -e OPENSHELL_DRIVERS=docker \ + -e OPENSHELL_DB_URL=sqlite:/home/openshell/.local/state/openshell/openshell.db \ + -e OPENSHELL_LOCAL_TLS_DIR=/home/openshell/.local/state/openshell/tls \ + -e OPENSHELL_TLS_CERT=/home/openshell/.local/state/openshell/tls/server/tls.crt \ + -e OPENSHELL_TLS_KEY=/home/openshell/.local/state/openshell/tls/server/tls.key \ + -e OPENSHELL_TLS_CLIENT_CA=/home/openshell/.local/state/openshell/tls/ca.crt \ + -e OPENSHELL_ENABLE_MTLS_AUTH=true \ + -e OPENSHELL_DOCKER_TLS_CA=/home/openshell/.local/state/openshell/tls/ca.crt \ + -e OPENSHELL_DOCKER_TLS_CERT=/home/openshell/.local/state/openshell/tls/client/tls.crt \ + -e OPENSHELL_DOCKER_TLS_KEY=/home/openshell/.local/state/openshell/tls/client/tls.key \ + ghcr.io/nvidia/openshell/gateway:latest +``` + +Register the gateway with mTLS: + +```shell +openshell gateway add https://127.0.0.1:8080 --local --name local +``` + +## Docker Compose + +Save the following as `compose.yml`. 
This uses the TLS-disabled configuration bound to localhost, suitable for local development. + +```yaml +services: + gateway: + image: ghcr.io/nvidia/openshell/gateway:latest + restart: unless-stopped + ports: + - "127.0.0.1:8080:8080" + volumes: + - openshell-state:/var/openshell + - /var/run/docker.sock:/var/run/docker.sock + environment: + OPENSHELL_DRIVERS: docker + OPENSHELL_DB_URL: "sqlite:/var/openshell/openshell.db" + OPENSHELL_DISABLE_TLS: "true" + +volumes: + openshell-state: +``` + +Start the gateway: + +```shell +docker compose up -d +``` + +Register the gateway with the CLI: + +```shell +openshell gateway add http://127.0.0.1:8080 --local --name local +``` + +## Using Podman + +Replace `docker` with `podman` in the commands above. Mount the Podman socket instead of the Docker socket and set the driver to `podman`: + +```shell +podman run -d \ + --name openshell-gateway \ + -p 127.0.0.1:8080:8080 \ + -v openshell-state:/var/openshell \ + -v "$XDG_RUNTIME_DIR/podman/podman.sock:/var/run/podman.sock" \ + -e OPENSHELL_DRIVERS=podman \ + -e OPENSHELL_PODMAN_SOCKET=/var/run/podman.sock \ + -e OPENSHELL_DB_URL=sqlite:/var/openshell/openshell.db \ + -e OPENSHELL_DISABLE_TLS=true \ + ghcr.io/nvidia/openshell/gateway:latest +``` + +## Next Steps + +- To create your first sandbox, refer to the [Quickstart](/get-started/quickstart). +- To control what the agent can access, refer to [Policies](/sandboxes/policies). +- For environment variable reference, refer to [Sandbox Compute Drivers](/reference/sandbox-compute-drivers). diff --git a/docs/about/installation.mdx b/docs/about/installation.mdx index 378439758..cd9973f13 100644 --- a/docs/about/installation.mdx +++ b/docs/about/installation.mdx @@ -24,7 +24,7 @@ Use `openshell status` to confirm the CLI can reach the gateway. ## Supported Compute Drivers -OpenShell supports several local compute drivers. 
The installer chooses a default driver for your platform, and the gateway reads the driver choice from its startup configuration. Sandbox commands use the same CLI workflow after the gateway is running. +OpenShell supports several local compute drivers. Package-managed gateways leave the driver unset by default so the gateway can auto-detect an available driver. Set `compute_drivers` in the gateway TOML when you need to pin a specific driver. | Compute Driver | How It Is Configured | System Requirements | |---|---|---| @@ -38,6 +38,10 @@ For detailed driver behavior, refer to [Sandbox Compute Drivers](/reference/sand On macOS, the install script uses Homebrew. The Homebrew package installs the `openshell` CLI, the gateway binary, and a Homebrew-managed gateway service. +The Homebrew service listens on `https://127.0.0.1:17670` and generates a local mTLS bundle on install. The gateway starts from built-in defaults and reads `~/.config/openshell/gateway.toml` when that file exists. If that file is absent, the Homebrew service also falls back to a Homebrew prefix config when present, such as `/opt/homebrew/var/openshell/gateway.toml`. + +The CLI reads the client bundle from `~/.config/openshell/gateways/openshell/mtls/`. + The installer starts the service for you. Use Homebrew service commands when you need to inspect, restart, or stop the gateway service: ```shell @@ -51,6 +55,12 @@ On Fedora and RHEL, the install script uses RPM packages. The RPM installs the ` On Debian and Ubuntu, the install script uses a Debian package. The Debian package installs the `openshell` CLI, the `openshell-gateway` daemon, VM sandbox support, and a systemd user service. +Linux packages require glibc 2.31 or newer. The installer checks libc before downloading packages and exits with an error on older glibc versions, Alpine, musl-based distributions, or unknown libc environments. 
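To check a host by hand before running the installer, a version comparison along these lines may help. This is a sketch under stated assumptions, not the installer's actual check: it relies on `getconf GNU_LIBC_VERSION` printing a string such as `glibc 2.35` on glibc systems, and `glibc_at_least` is a hypothetical helper name.

```shell
#!/bin/sh
# Sketch only: not the installer's actual logic; glibc_at_least is hypothetical.

# Succeed when a "glibc X.Y" string reports at least MIN_MAJOR.MIN_MINOR.
# Anything that does not look like glibc (musl, empty, unknown) fails.
glibc_at_least() {
    min_major="$1"
    min_minor="$2"
    reported="$3"

    # Require the "glibc <major>.<minor>" shape before parsing.
    case "$reported" in
        "glibc "*.*) ;;
        *) return 1 ;;
    esac

    version="${reported#glibc }"
    major="${version%%.*}"
    rest="${version#*.}"
    minor="${rest%%.*}"

    # Reject empty or non-numeric components.
    for part in "$major" "$minor"; do
        case "$part" in
            ''|*[!0-9]*) return 1 ;;
        esac
    done

    [ "$major" -gt "$min_major" ] && return 0
    [ "$major" -eq "$min_major" ] && [ "$minor" -ge "$min_minor" ]
}

# getconf fails or is absent on non-glibc systems; treat that as unknown.
reported="$(getconf GNU_LIBC_VERSION 2>/dev/null || true)"
if glibc_at_least 2 31 "$reported"; then
    echo "glibc OK: $reported"
else
    echo "unsupported libc: ${reported:-unknown}"
fi
```

Failing closed on anything that is not a recognizable glibc version string mirrors the documented behavior for Alpine, musl-based, and unknown libc environments.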
+ +The Linux user service listens on `https://127.0.0.1:17670`, starts from built-in defaults, and generates a local mTLS bundle before the gateway starts. Create `~/.config/openshell/gateway.toml` only when you need to override those defaults. + +The CLI reads the client bundle from `~/.config/openshell/gateways/openshell/mtls/`. + The installer starts the service for you. Use systemd user commands when you need to inspect, restart, or stop the gateway service: ```shell @@ -72,6 +82,7 @@ Kubernetes deployments use the OpenShell Helm chart. For step-by-step installati ## Next Steps - To create your first sandbox, refer to the [Quickstart](/get-started/quickstart). +- To run the gateway as a container without the installer, refer to [Running the Gateway as a Container](/about/container-gateway). - To register, select, and inspect gateways, refer to [Gateways](/sandboxes/manage-gateways). - To supply API keys or tokens, refer to [Manage Providers](/sandboxes/manage-providers). - To control what the agent can access, refer to [Policies](/sandboxes/policies). diff --git a/docs/about/overview.mdx b/docs/about/overview.mdx index 3d23c8794..6ef41b34a 100644 --- a/docs/about/overview.mdx +++ b/docs/about/overview.mdx @@ -44,7 +44,7 @@ OpenShell supports a range of agent deployment patterns. | Use Case | Description | |-----------------------------|----------------------------------------------------------------------------------------------------------| -| Secure coding agents | Run Claude Code, OpenCode, or OpenClaw with constrained file and network access. | +| Secure coding agents | Run Claude Code, OpenCode, Codex, or GitHub Copilot CLI with constrained file and network access. | | Private enterprise development | Route inference to self-hosted or private backends while keeping sensitive context under your control. | | Compliance and audit | Treat policy YAML as version-controlled security controls that can be reviewed and audited. 
| | Reusable environments | Use community sandbox images or bring your own containerized runtime. | diff --git a/docs/about/release-notes.mdx b/docs/about/release-notes.mdx index 2efa2b82c..f763c8995 100644 --- a/docs/about/release-notes.mdx +++ b/docs/about/release-notes.mdx @@ -5,7 +5,7 @@ title: "NVIDIA OpenShell Release Notes" sidebar-title: "Release Notes" description: "Track the latest changes and improvements to NVIDIA OpenShell." keywords: "Generative AI, Cybersecurity, Release Notes, Changelog, AI Agents" -position: 5 +position: 6 --- NVIDIA OpenShell follows a frequent release cadence. Use the following GitHub resources directly. diff --git a/docs/about/supported-agents.mdx b/docs/about/supported-agents.mdx index ec03cf755..7eeb3db5c 100644 --- a/docs/about/supported-agents.mdx +++ b/docs/about/supported-agents.mdx @@ -4,9 +4,9 @@ title: "Supported Agents" description: "AI agent frameworks and runtimes compatible with OpenShell sandboxes." keywords: "Generative AI, Cybersecurity, AI Agents, Sandboxing, Claude, Codex, Cursor" -position: 4 +position: 5 --- -The following table summarizes the agents that run in OpenShell sandboxes. All agent sandbox images are maintained in the [OpenShell Community](https://github.com/NVIDIA/OpenShell-Community) repository. Agents in the base image are auto-configured when passed as the trailing command to `openshell sandbox create`. +The following table summarizes the agents that run in OpenShell sandboxes. Most agent sandbox images are maintained in the [OpenShell Community](https://github.com/NVIDIA/OpenShell-Community) repository. Agents in the base image are auto-configured when passed as the trailing command to `openshell sandbox create`. | Agent | Source | Default Policy | Notes | |---|---|---|---| @@ -14,8 +14,9 @@ The following table summarizes the agents that run in OpenShell sandboxes. 
All a | [OpenCode](https://opencode.ai/) | [`base`](https://github.com/NVIDIA/OpenShell-Community/tree/main/sandboxes/base) | Partial coverage | Pre-installed. Add `opencode.ai` endpoint and OpenCode binary paths to the policy for full functionality. | | [Codex](https://developers.openai.com/codex) | [`base`](https://github.com/NVIDIA/OpenShell-Community/tree/main/sandboxes/base) | No coverage | Pre-installed. Requires a custom policy with OpenAI endpoints and Codex binary paths. Requires `OPENAI_API_KEY`. | | [GitHub Copilot CLI](https://docs.github.com/en/copilot/github-copilot-in-the-cli) | [`base`](https://github.com/NVIDIA/OpenShell-Community/tree/main/sandboxes/base) | Full coverage | Pre-installed. Works out of the box. Requires `GITHUB_TOKEN` or `COPILOT_GITHUB_TOKEN`. | -| [OpenClaw](https://openclaw.ai/) | [`openclaw`](https://github.com/NVIDIA/OpenShell-Community/tree/main/sandboxes/openclaw) | Bundled | Agent orchestration layer. Launch with `openshell sandbox create --from openclaw`. | +| [OpenClaw](https://openclaw.ai/) | [NemoClaw](https://github.com/NVIDIA/NemoClaw) | Blueprint-managed | Run OpenClaw more securely inside NVIDIA OpenShell with managed inference using NemoClaw. | | [Ollama](https://ollama.com/) | [`ollama`](https://github.com/NVIDIA/OpenShell-Community/tree/main/sandboxes/ollama) | Bundled | Run cloud and local models. Includes Claude Code, Codex, and OpenCode. Launch with `openshell sandbox create --from ollama`. | +| [Pi](https://pi.dev/) | [`pi`](https://github.com/NVIDIA/OpenShell-Community/tree/main/sandboxes/pi) | Bundled | Comes with Pi pre-installed. Launch with `openshell sandbox create --from pi`. | For base image details and `--from` usage, refer to [Sandboxes](/sandboxes/manage-sandboxes#base-sandbox-container). 
diff --git a/docs/get-started/quickstart.mdx b/docs/get-started/quickstart.mdx index fcfbc945b..0955c7673 100644 --- a/docs/get-started/quickstart.mdx +++ b/docs/get-started/quickstart.mdx @@ -36,7 +36,7 @@ If you prefer [uv](https://docs.astral.sh/uv/): uv tool install -U openshell ``` -After installing the CLI, run `openshell --help` in your terminal to see the full CLI reference. +After installing the CLI, run `openshell --help` in your terminal to view the full CLI reference. You can also clone the [NVIDIA OpenShell GitHub repository](https://github.com/NVIDIA/OpenShell) and use the `/openshell-cli` skill to load the CLI reference into your agent. @@ -90,17 +90,6 @@ If `OPENAI_API_KEY` is set in your environment, the CLI picks it up automaticall If not, you can configure it from inside the sandbox after it launches. - - -Run the following command to create a sandbox with OpenClaw: - -```shell -openshell sandbox create --from openclaw -``` - -The `--from` flag pulls a pre-built sandbox container with its bundled policy and optional skills. - - Use the `--from` flag to create a sandbox from the base container: diff --git a/docs/get-started/tutorials/first-network-policy.mdx b/docs/get-started/tutorials/first-network-policy.mdx index effe7c445..3e4593308 100644 --- a/docs/get-started/tutorials/first-network-policy.mdx +++ b/docs/get-started/tutorials/first-network-policy.mdx @@ -4,7 +4,7 @@ title: "Write Your First Sandbox Network Policy" sidebar-title: "First Network Policy" slug: "get-started/tutorials/first-network-policy" -description: "See how OpenShell network policies work by creating a sandbox, observing default-deny in action, and applying a fine-grained L7 read-only rule." +description: "Learn how OpenShell network policies work by creating a sandbox, observing default-deny in action, and applying a fine-grained L7 read-only rule." 
keywords: "Generative AI, Cybersecurity, Tutorial, Policy, Network Policy, Sandbox, Security" --- @@ -117,7 +117,7 @@ network_policies: - { path: /usr/bin/curl } ``` -The `filesystem_policy`, `landlock`, and `process` sections preserve the default sandbox settings. This is required because `policy set` replaces the entire policy. The `network_policies` section is the key part: `curl` may make GET, HEAD, and OPTIONS requests to `api.github.com` over HTTPS. Everything else is denied. The proxy auto-detects TLS on HTTPS endpoints and terminates it to inspect each HTTP request and enforce the `read-only` access preset at the method level. +The `filesystem_policy`, `landlock`, and `process` sections preserve the default sandbox settings. This is required because `policy set` replaces the entire policy. The `network_policies` section is the key part: `curl` can make GET, HEAD, and OPTIONS requests to `api.github.com` over HTTPS. Everything else is denied. The proxy auto-detects TLS on HTTPS endpoints and terminates it to inspect each HTTP request and enforce the `read-only` access preset at the method level. Apply it: diff --git a/docs/get-started/tutorials/github-sandbox.mdx b/docs/get-started/tutorials/github-sandbox.mdx index 1f225fb60..7c76d4e41 100644 --- a/docs/get-started/tutorials/github-sandbox.mdx +++ b/docs/get-started/tutorials/github-sandbox.mdx @@ -11,7 +11,7 @@ keywords: "Generative AI, Cybersecurity, Tutorial, GitHub, Sandbox, Policy, Clau This tutorial walks through an iterative sandbox policy workflow. You launch a sandbox, ask Claude Code to push code to GitHub, and observe the default network policy denying the request. You then diagnose the denial from your machine and from inside the sandbox, apply a policy update, and verify that the policy update to the sandbox takes effect. -After completing this tutorial, you will have: +After completing this tutorial, you have: - A running sandbox with Claude Code that can push to a GitHub repository. 
- A custom network policy that grants GitHub access for a specific repository. @@ -131,7 +131,7 @@ The sandbox runs a proxy that enforces policies on outbound traffic. The `github_rest_api` policy allows GET requests (used to read the file) but blocks PUT/write requests to GitHub. This is a sandbox-level restriction, not a token issue. No matter what token you provide, pushes through the API -will be blocked until the policy is updated. +are blocked until you update the policy. Both perspectives confirm the same thing: the proxy is doing its job. The default policy is designed to be restrictive. To allow GitHub pushes, you need to update the network policy. @@ -162,7 +162,7 @@ Refer to the following policy example to compare with the generated policy befor The following YAML shows a complete policy that extends the [default policy](/reference/default-policy) with GitHub access for a single repository. Replace `` with your GitHub organization or username and `` with your repository name. -The `filesystem_policy`, `landlock`, and `process` sections are static. They are read once at sandbox creation and cannot be changed by a hot-reload. They are included here for completeness so the file is self-contained, but only the `network_policies` section takes effect when you apply this to a running sandbox. +The `filesystem_policy`, `landlock`, and `process` sections are static. OpenShell reads them at sandbox creation, and a hot reload cannot change them. They are included here for completeness so the file is self-contained, but only the `network_policies` section takes effect when you apply this to a running sandbox. 
```yaml version: 1 diff --git a/docs/get-started/tutorials/index.mdx b/docs/get-started/tutorials/index.mdx index f5a79b543..0d82509ad 100644 --- a/docs/get-started/tutorials/index.mdx +++ b/docs/get-started/tutorials/index.mdx @@ -22,6 +22,11 @@ Create a sandbox, observe default-deny networking, apply a read-only L7 policy, Launch Claude Code in a sandbox, diagnose a policy denial, and iterate on a custom GitHub policy from outside the sandbox. + + +Configure a Providers v2 Microsoft Graph provider with gateway-managed OAuth2 refresh-token rotation. + + Route inference through Ollama using cloud-hosted or local models, and verify it from a sandbox. @@ -29,6 +34,6 @@ Route inference through Ollama using cloud-hosted or local models, and verify it -Route inference to a local LM Studio server via the OpenAI or Anthropic compatible APIs. +Route inference to a local LM Studio server using the OpenAI-compatible or Anthropic-compatible APIs. diff --git a/docs/get-started/tutorials/inference-ollama.mdx b/docs/get-started/tutorials/inference-ollama.mdx index 4a16eee01..4f46b847e 100644 --- a/docs/get-started/tutorials/inference-ollama.mdx +++ b/docs/get-started/tutorials/inference-ollama.mdx @@ -8,12 +8,12 @@ description: "Run local and cloud models inside an OpenShell sandbox using the O keywords: "Generative AI, Cybersecurity, Tutorial, Inference Routing, Ollama, Local Inference, Sandbox" --- -This tutorial covers two ways to use Ollama with OpenShell: +This tutorial covers two ways of running Ollama with OpenShell: -1. **Ollama sandbox (recommended)** — a self-contained sandbox with Ollama, Claude Code, and Codex pre-installed. One command to start. -2. **Host-level Ollama** — run Ollama on the gateway host and route sandbox inference to it. Useful when you want a single Ollama instance shared across multiple sandboxes. +1. Ollama sandbox. This is the recommended way to run Ollama. A self-contained sandbox with Ollama, Claude Code, and Codex pre-installed. 
One command starts it. +2. Host-level Ollama. This is an alternative way to run Ollama. Run Ollama on the gateway host and route sandbox inference to it. Use this option when you want a single Ollama instance shared across multiple sandboxes. -After completing this tutorial, you will know how to: +After completing this tutorial, you know how to: - Launch the Ollama community sandbox for a batteries-included experience. - Use `ollama launch` to start coding agents inside a sandbox. @@ -190,11 +190,11 @@ The response should be JSON from the model. Common issues and fixes: -- **Ollama not reachable from sandbox** — Ollama must be bound to `0.0.0.0`, not `127.0.0.1`. This applies to host-level Ollama only; the community sandbox handles this automatically. -- **`OPENAI_BASE_URL` wrong** — Use `http://host.openshell.internal:11434/v1`, not `localhost` or `127.0.0.1`. -- **Model not found** — Run `ollama ps` to confirm the model is loaded. Run `ollama pull ` if needed. -- **HTTPS vs HTTP** — Code inside sandboxes must call `https://inference.local`, not `http://`. -- **AMD GPU driver issues** — Ollama v0.18+ requires ROCm 7 drivers for AMD GPUs. Update your drivers if you see GPU detection failures. +- **Ollama not reachable from sandbox:** Ollama must be bound to `0.0.0.0`, not `127.0.0.1`. This applies to host-level Ollama only; the community sandbox handles this automatically. +- **`OPENAI_BASE_URL` wrong:** Use `http://host.openshell.internal:11434/v1`, not `localhost` or `127.0.0.1`. +- **Model not found:** Run `ollama ps` to confirm the model is loaded. Run `ollama pull ` if needed. +- **HTTPS instead of HTTP:** Code inside sandboxes must call `https://inference.local`, not `http://`. +- **AMD GPU driver issues:** Ollama v0.18+ requires ROCm 7 drivers for AMD GPUs. Update your drivers if you see GPU detection failures. 
Useful commands: diff --git a/docs/get-started/tutorials/local-inference-lmstudio.mdx b/docs/get-started/tutorials/local-inference-lmstudio.mdx index 2d1246459..7ce604009 100644 --- a/docs/get-started/tutorials/local-inference-lmstudio.mdx +++ b/docs/get-started/tutorials/local-inference-lmstudio.mdx @@ -15,7 +15,7 @@ The LM Studio server provides easy setup with both OpenAI and Anthropic compatib -This tutorial will cover: +This tutorial covers: - Expose a local inference server to OpenShell sandboxes. - Verify end-to-end inference from inside a sandbox. @@ -54,11 +54,11 @@ lms daemon up Start the LM Studio local server from the Developer tab, and verify the OpenAI-compatible endpoint is enabled. -LM Studio will listen to `127.0.0.1:1234` by default. For use with OpenShell, you'll need to configure LM Studio to listen on all interfaces (`0.0.0.0`). +LM Studio listens to `127.0.0.1:1234` by default. For use with OpenShell, configure LM Studio to listen on all interfaces (`0.0.0.0`). -If you're using the GUI, go to the Developer Tab, select Server Settings, then enable Serve on Local Network. +If you use the GUI, go to the Developer Tab, select Server Settings, then enable Serve on Local Network. -If you're using llmster in headless mode, run `lms server start --bind 0.0.0.0`. +If you use llmster in headless mode, run `lms server start --bind 0.0.0.0`. ## Test with a small model diff --git a/docs/get-started/tutorials/microsoft-graph-provider-refresh.mdx b/docs/get-started/tutorials/microsoft-graph-provider-refresh.mdx new file mode 100644 index 000000000..eb68b147c --- /dev/null +++ b/docs/get-started/tutorials/microsoft-graph-provider-refresh.mdx @@ -0,0 +1,189 @@ +--- +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +title: "Refresh Microsoft Graph Credentials with Providers v2" +sidebar-title: "Microsoft Graph Provider Refresh" +slug: "get-started/tutorials/microsoft-graph-provider-refresh" +description: "Configure a Providers v2 Microsoft Graph profile with gateway-managed OAuth2 refresh-token rotation." +keywords: "Generative AI, Cybersecurity, Tutorial, Providers, Microsoft Graph, OAuth2, Credential Refresh, Sandbox" +--- + +Use Providers v2 to keep Microsoft Graph access tokens short lived while sandboxes receive a stable `MS_GRAPH_ACCESS_TOKEN` placeholder. OpenShell stores the non-injectable refresh material at the gateway, refreshes the Microsoft Graph access token before it expires, updates the provider record, and injects the current credential into newly launched sandbox processes. + +After completing this tutorial, you have: + +- A custom Microsoft Graph mail provider profile. +- A provider instance configured with `oauth2-refresh-token`. +- A sandbox that can use `curl` to read Microsoft Graph mail through provider-owned policy. + + +This tutorial starts after your OAuth client has already completed the initial Microsoft sign-in flow. It does not publish a token bootstrap script. Use the Microsoft identity platform documentation for the [device authorization grant flow](https://learn.microsoft.com/en-ie/entra/identity-platform/v2-oauth2-device-code) or [authorization code flow](https://learn.microsoft.com/en-us/entra/identity-platform/v2-oauth2-auth-code-flow), and use any standards-compliant client that returns an access token, refresh token, and expiry. + + +## Prerequisites + +- A working OpenShell installation with an active gateway. Complete the [Quickstart](/get-started/quickstart) before proceeding. +- A Microsoft Entra app registration that can acquire delegated Microsoft Graph mail access. +- Delegated Microsoft Graph mail permission for the signed-in user. 
`Mail.Read` allows reading the signed-in user's mailbox; see the [Microsoft Graph permissions reference](https://learn.microsoft.com/en-us/graph/permissions-reference). +OAuth material from your initial Microsoft sign-in flow: + +| Variable | Value | +|---|---| +| `MS_TENANT_ID` | Microsoft Entra tenant ID, domain, or `common`. | +| `MS_CLIENT_ID` | Microsoft Entra application client ID. | +| `MS_GRAPH_ACCESS_TOKEN` | Current delegated Microsoft Graph access token. | +| `MS_GRAPH_REFRESH_TOKEN` | Delegated OAuth refresh token. | +| `MS_GRAPH_ACCESS_TOKEN_EXPIRES_AT` | Absolute expiry for the current access token. | + +`MS_GRAPH_ACCESS_TOKEN_EXPIRES_AT` can be an RFC3339 timestamp such as `2026-01-01T00:00:00Z` or a Unix epoch millisecond timestamp. + + +Do not commit access tokens, refresh tokens, or local `.env` files. The commands below pass token material to the gateway; they are not examples of values to store in source control. + + + + +## Enable Providers v2 + +Enable provider profile policy composition on the active gateway: + +```shell +openshell settings set --global --key providers_v2_enabled --value true --yes +``` + +## Create a Microsoft Graph Provider Profile + +Create `microsoft-graph-mail.yaml` with this profile: + +```yaml showLineNumbers={false} +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +id: microsoft-graph-mail +display_name: Microsoft Graph Mail +description: Delegated Microsoft Graph mail read access +category: messaging +credentials: + - name: graph_access_token + description: Microsoft Graph delegated access token + env_vars: [MS_GRAPH_ACCESS_TOKEN] + required: true + auth_style: bearer + header_name: authorization + refresh: + strategy: oauth2_refresh_token + token_url: https://login.microsoftonline.com/common/oauth2/v2.0/token + scopes: [https://graph.microsoft.com/.default] + refresh_before_seconds: 600 + max_lifetime_seconds: 3600 + material: + - name: tenant_id + description: Microsoft Entra tenant ID + required: true + - name: client_id + description: Microsoft Entra application client ID + required: true + - name: refresh_token + description: Delegated OAuth refresh token + required: true + secret: true +endpoints: + - host: graph.microsoft.com + port: 443 + protocol: rest + access: read-only + enforcement: enforce +binaries: + - /usr/bin/curl + - /usr/local/bin/curl +``` + +Lint and import the profile: + +```shell +openshell provider profile lint -f microsoft-graph-mail.yaml +openshell provider profile import -f microsoft-graph-mail.yaml +``` + +The profile defines the refresh strategy and Graph network policy. The `tenant_id` refresh material selects the Microsoft token endpoint during gateway-managed refresh. + +## Create the Provider + +Create the provider with the current Microsoft Graph access token: + +```shell +openshell provider create \ + --name microsoft-mail \ + --type microsoft-graph-mail \ + --credential MS_GRAPH_ACCESS_TOKEN="$MS_GRAPH_ACCESS_TOKEN" +``` + +The current CLI requires an initial credential at provider creation time. Refresh material is configured separately and is not injected into the sandbox. 
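
The `MS_GRAPH_ACCESS_TOKEN_EXPIRES_AT` value listed in the prerequisites can be derived from the `expires_in` field that a standard OAuth2 token response carries. The following is a minimal sketch that produces the Unix epoch millisecond form: `expires_at_ms` is a hypothetical helper, and `3600` stands in for whatever `expires_in` value your OAuth client returned.

```shell
# Hypothetical helper (not part of the OpenShell CLI): convert an OAuth2
# expires_in value (seconds) into an absolute Unix epoch millisecond timestamp.
expires_at_ms() {
  # $1 = expires_in in seconds
  # $2 = optional token issue time in epoch seconds (defaults to the current time)
  now="${2:-$(date +%s)}"
  echo $(( (now + $1) * 1000 ))
}

# Assumption: the token was issued just now with expires_in=3600.
MS_GRAPH_ACCESS_TOKEN_EXPIRES_AT="$(expires_at_ms 3600)"
echo "$MS_GRAPH_ACCESS_TOKEN_EXPIRES_AT"
```

Pass the token issue time as the second argument when you compute the value later than the sign-in flow. Otherwise the recorded expiry drifts later than the real one.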
+ +## Configure Refresh + +Configure gateway-managed OAuth2 refresh-token rotation: + +```shell +openshell provider refresh configure microsoft-mail \ + --credential-key MS_GRAPH_ACCESS_TOKEN \ + --strategy oauth2-refresh-token \ + --material tenant_id="$MS_TENANT_ID" \ + --material client_id="$MS_CLIENT_ID" \ + --material refresh_token="$MS_GRAPH_REFRESH_TOKEN" \ + --secret-material-key refresh_token \ + --credential-expires-at "$MS_GRAPH_ACCESS_TOKEN_EXPIRES_AT" +``` + +`--secret-material-key refresh_token` names the material key to mark as sensitive. It is not the refresh-token value. If Microsoft returns a rotated refresh token during refresh, OpenShell stores the new `refresh_token` material and marks it secret automatically. + +Force the first refresh immediately: + +```shell +openshell provider refresh rotate microsoft-mail \ + --credential-key MS_GRAPH_ACCESS_TOKEN +``` + +Check refresh status: + +```shell +openshell provider refresh status microsoft-mail \ + --credential-key MS_GRAPH_ACCESS_TOKEN +``` + +The status output shows refresh state, expiry, next refresh, and last refresh timing. It does not print access-token values or refresh material. + +## Launch a Sandbox + +Launch a sandbox with the Microsoft Graph provider attached: + +```shell +openshell sandbox create \ + --name microsoft-graph-mail \ + --keep \ + --provider microsoft-mail \ + --no-auto-providers \ + -- /bin/sh +``` + +Provider policy allows `curl` to reach `graph.microsoft.com:443`. The sandbox process receives `MS_GRAPH_ACCESS_TOKEN` as an OpenShell placeholder, and the proxy resolves that placeholder to the current gateway-managed access token when `curl` sends it in the authorization header. 
+ +## Verify Microsoft Graph Access + +Inside the sandbox, list a small page of mailbox messages: + +```shell +curl -sS \ + -H "Authorization: Bearer $MS_GRAPH_ACCESS_TOKEN" \ + 'https://graph.microsoft.com/v1.0/me/messages?$select=sender,subject&$top=5' +``` + +The request uses the [Microsoft Graph list messages API](https://learn.microsoft.com/en-us/graph/api/user-list-messages?view=graph-rest-1.0). If the token has delegated mail read permission, Microsoft Graph returns message metadata for the signed-in user's mailbox. + +## Update Running Sandboxes + +Provider refresh updates the provider record at the gateway. Running sandboxes poll for provider environment revisions, but already-running processes keep the environment they started with. + +If you attach this provider to an existing sandbox or update provider credentials after a process has already started, launch a new process inside the sandbox before expecting `MS_GRAPH_ACCESS_TOKEN` to appear in that process environment. + + diff --git a/docs/index.mdx b/docs/index.mdx index 4d358a827..91aba803d 100644 --- a/docs/index.mdx +++ b/docs/index.mdx @@ -38,7 +38,7 @@ uncontrolled network activity. Install OpenShell and create your first sandbox in two commands. -{/*Terminal demo styles live in fern/main.css — inline