From 084c93b6adb81f91ec3f598ae446d913a1f7d397 Mon Sep 17 00:00:00 2001 From: Drew Newberry Date: Thu, 7 May 2026 16:05:40 -0700 Subject: [PATCH 001/157] fix(installer): repair dev install package and service setup (#1252) --- crates/openshell-core/src/config.rs | 90 ++++++++++++++++++++-- crates/openshell-server/src/cli.rs | 4 +- docs/reference/sandbox-compute-drivers.mdx | 2 +- install-dev.sh | 24 +++++- 4 files changed, 108 insertions(+), 12 deletions(-) diff --git a/crates/openshell-core/src/config.rs b/crates/openshell-core/src/config.rs index 1ec06677b..3f0d34b0f 100644 --- a/crates/openshell-core/src/config.rs +++ b/crates/openshell-core/src/config.rs @@ -6,7 +6,9 @@ use serde::{Deserialize, Serialize}; use std::fmt; use std::net::SocketAddr; -use std::path::PathBuf; +#[cfg(unix)] +use std::os::unix::fs::FileTypeExt; +use std::path::{Path, PathBuf}; use std::process::Command; use std::str::FromStr; @@ -108,8 +110,8 @@ pub fn detect_driver() -> Option { return Some(ComputeDriverKind::Podman); } - // Docker: check if docker binary is available - if is_binary_available("docker") { + // Docker: check if the CLI is available or a local Docker socket exists. + if is_docker_available() { return Some(ComputeDriverKind::Docker); } @@ -124,6 +126,55 @@ fn is_binary_available(name: &str) -> bool { .is_ok_and(|output| output.status.success()) } +fn is_docker_available() -> bool { + is_binary_available("docker") || docker_socket_available() +} + +fn docker_socket_available() -> bool { + docker_socket_candidates() + .iter() + .any(|path| is_unix_socket(path)) +} + +fn docker_socket_candidates() -> Vec { + let mut candidates = Vec::new(); + + if let Ok(host) = std::env::var("DOCKER_HOST") + && let Some(path) = docker_host_unix_socket_path(&host) + { + candidates.push(path); + } + + candidates.push(PathBuf::from("/var/run/docker.sock")); + + if let Some(home) = std::env::var_os("HOME") { + candidates.push(PathBuf::from(home).join(".docker/run/docker.sock")); + } + + if let Some(runtime_dir) = std::env::var_os("XDG_RUNTIME_DIR") { + candidates.push(PathBuf::from(runtime_dir).join("docker.sock")); + } + + candidates +} + +fn docker_host_unix_socket_path(host: &str) -> Option { + let path = host.trim().strip_prefix("unix://")?; + (!path.is_empty()).then(|| PathBuf::from(path)) +} + +#[cfg(unix)] +fn is_unix_socket(path: &Path) -> bool { + path.metadata() + .is_ok_and(|metadata| metadata.file_type().is_socket()) +} + +#[cfg(not(unix))] +fn is_unix_socket(path: &Path) -> bool { + let _ = path; + false +} + /// Server configuration. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Config { @@ -563,8 +614,13 @@ const fn default_ssh_session_ttl_secs() -> u64 { #[cfg(test)] mod tests { - use super::{ComputeDriverKind, Config, detect_driver}; + use super::{ + ComputeDriverKind, Config, detect_driver, docker_host_unix_socket_path, is_unix_socket, + }; use std::net::SocketAddr; + #[cfg(unix)] + use std::os::unix::net::UnixListener; + use std::path::PathBuf; #[test] fn compute_driver_kind_parses_supported_values() { @@ -614,12 +670,36 @@ mod tests { #[test] fn detect_driver_returns_none_without_k8s_env_or_binaries() { // When KUBERNETES_SERVICE_HOST is not set and no docker/podman binaries - // are available, detect_driver should return None. + // or Docker socket are available, detect_driver should return None. // This test may pass or fail depending on the test environment, // but it documents the expected behavior. let _ = detect_driver(); // Returns Some or None based on environment } + #[test] + fn docker_host_unix_socket_path_parses_unix_hosts() { + assert_eq!( + docker_host_unix_socket_path("unix:///var/run/docker.sock"), + Some(PathBuf::from("/var/run/docker.sock")) + ); + assert_eq!(docker_host_unix_socket_path("tcp://127.0.0.1:2375"), None); + assert_eq!(docker_host_unix_socket_path("unix://"), None); + } + + #[cfg(unix)] + #[test] + fn is_unix_socket_detects_socket_files() { + let temp_dir = tempfile::tempdir().expect("create temp dir"); + let socket_path = temp_dir.path().join("docker.sock"); + let _listener = UnixListener::bind(&socket_path).expect("bind unix socket"); + + assert!(is_unix_socket(&socket_path)); + + let regular_file = temp_dir.path().join("not-a-socket"); + std::fs::write(®ular_file, b"not a socket").expect("write regular file"); + assert!(!is_unix_socket(®ular_file)); + } + #[test] #[allow(unsafe_code)] // std::env::set_var/remove_var require unsafe in Rust 2024 fn detect_driver_prefers_kubernetes_when_k8s_env_is_set() { diff --git a/crates/openshell-server/src/cli.rs b/crates/openshell-server/src/cli.rs index 6eb1ab2db..ccc08cf2b 100644 --- a/crates/openshell-server/src/cli.rs +++ b/crates/openshell-server/src/cli.rs @@ -67,8 +67,8 @@ struct Args { /// `kubernetes,podman`. The configuration format is future-proofed for /// multiple drivers, but the gateway currently requires exactly one. /// When unset, the gateway auto-detects the driver based on the runtime - /// environment (Kubernetes → Podman → Docker). VM is never auto-detected - /// and requires explicit configuration. + /// environment (Kubernetes → Podman → Docker CLI or socket). VM is never + /// auto-detected and requires explicit configuration. #[arg( long, alias = "driver", diff --git a/docs/reference/sandbox-compute-drivers.mdx b/docs/reference/sandbox-compute-drivers.mdx index cc78b3b80..4db867135 100644 --- a/docs/reference/sandbox-compute-drivers.mdx +++ b/docs/reference/sandbox-compute-drivers.mdx @@ -22,7 +22,7 @@ openshell-gateway --drivers docker You can also set the driver with `OPENSHELL_DRIVERS`. Supported values are `docker`, `podman`, `kubernetes`, and `vm`. -When `--drivers` and `OPENSHELL_DRIVERS` are unset, the gateway auto-detects Kubernetes, then Podman, then Docker. The VM driver is never auto-detected; configure it explicitly with `--drivers vm`. +When `--drivers` and `OPENSHELL_DRIVERS` are unset, the gateway auto-detects Kubernetes, then Podman, then Docker by CLI availability or a local Unix socket. The VM driver is never auto-detected; configure it explicitly with `--drivers vm`. Common gateway options: diff --git a/install-dev.sh b/install-dev.sh index b0cf88f07..edacd2b8d 100755 --- a/install-dev.sh +++ b/install-dev.sh @@ -227,10 +227,26 @@ find_deb_asset() { _arch="$2" awk -v arch="$_arch" ' - $2 ~ "^\\*?openshell_.*_" arch "\\.deb$" { - sub("^\\*", "", $2) - print $2 - exit + { + name = $2 + sub("^\\*", "", name) + + if (name == "openshell-dev-" arch ".deb") { + selected = name + found = 1 + exit + } + + if (fallback == "" && name ~ "^openshell_.*_" arch "\\.deb$") { + fallback = name + } + } + END { + if (found) { + print selected + } else if (fallback != "") { + print fallback + } } ' "$_checksums" } From 62619eefc71a9ecdb6b1cd7b86ef138049451e62 Mon Sep 17 00:00:00 2001 From: Drew Newberry Date: Thu, 7 May 2026 17:39:44 -0700 Subject: [PATCH 002/157] fix(docker): use supervisor image entrypoint path (#1259) Signed-off-by: Drew Newberry --- .../skills/debug-openshell-cluster/SKILL.md | 2 ++ .github/workflows/docker-build.yml | 4 ++-- .github/workflows/release-dev.yml | 2 ++ .github/workflows/release-tag.yml | 2 ++ .../workflows/shadow-rust-native-build.yml | 11 +++++++++ architecture/build.md | 4 +++- architecture/compute-runtimes.md | 2 +- crates/openshell-driver-docker/README.md | 15 ++++++++++++ crates/openshell-driver-docker/src/lib.rs | 2 +- .../openshell-driver-kubernetes/src/driver.rs | 24 +++++++++---------- crates/openshell-sandbox/src/main.rs | 11 ++++----- tasks/docker.toml | 2 +- 12 files changed, 57 insertions(+), 24 deletions(-) diff --git a/.agents/skills/debug-openshell-cluster/SKILL.md b/.agents/skills/debug-openshell-cluster/SKILL.md index 408ef85c7..16158c0dc 100644 --- a/.agents/skills/debug-openshell-cluster/SKILL.md +++ b/.agents/skills/debug-openshell-cluster/SKILL.md @@ -63,6 +63,7 @@ Use gateway metadata, deployment values, or the user's setup notes to identify t docker info docker ps --filter name=openshell docker logs --tail=200 +docker run --rm --entrypoint /openshell-sandbox "${OPENSHELL_DOCKER_SUPERVISOR_IMAGE:-ghcr.io/nvidia/openshell/supervisor:latest}" --version openshell status ``` @@ -71,6 +72,7 @@ Common findings: - Docker daemon unavailable: start Docker Desktop or Docker Engine. - Gateway process stopped: inspect exit status and logs. - Sandbox image missing or pull denied: verify image reference and registry credentials. +- Docker driver cannot initialize because it cannot find `openshell-sandbox`: verify `OPENSHELL_DOCKER_SUPERVISOR_BIN`, the sibling binary next to `openshell-gateway`, or the configured supervisor image contains `/openshell-sandbox`. - Sandbox never registers: check gateway logs and supervisor callback endpoint. For source checkout development, restart the local gateway with: diff --git a/.github/workflows/docker-build.yml b/.github/workflows/docker-build.yml index 6c3807858..42d991b60 100644 --- a/.github/workflows/docker-build.yml +++ b/.github/workflows/docker-build.yml @@ -155,6 +155,7 @@ jobs: component: ${{ needs.resolve.outputs.binary_component }} arch: ${{ matrix.arch }} cargo-version: ${{ inputs['cargo-version'] }} + image-tag: ${{ needs.resolve.outputs.image_tag_base }} checkout-ref: ${{ inputs['checkout-ref'] }} features: openshell-core/dev-settings artifact-name: ${{ needs.resolve.outputs.artifact_prefix }}-linux-${{ matrix.arch }} @@ -238,7 +239,6 @@ jobs: --cache-to "type=gha,mode=max,scope=${{ inputs.component }}-${{ matrix.arch }}" - name: Smoke check ${{ inputs.component }} image - if: ${{ !inputs.push }} run: | set -euo pipefail image="${IMAGE_REGISTRY}/${{ inputs.component }}:${IMAGE_TAG}" @@ -249,7 +249,7 @@ jobs: grep -q '^openshell-gateway ' <<<"$output" ;; supervisor) - output="$(docker run --rm --platform "${{ matrix.platform }}" "$image" --version)" + output="$(docker run --rm --platform "${{ matrix.platform }}" --entrypoint /openshell-sandbox "$image" --version)" echo "$output" grep -q '^openshell-sandbox ' <<<"$output" ;; diff --git a/.github/workflows/release-dev.yml b/.github/workflows/release-dev.yml index 0385930bd..7b9bb2f92 100644 --- a/.github/workflows/release-dev.yml +++ b/.github/workflows/release-dev.yml @@ -432,6 +432,8 @@ jobs: sed -i -E '/^\[workspace\.package\]/,/^\[/{s/^version[[:space:]]*=[[:space:]]*".*"/version = "'"${{ needs.compute-versions.outputs.cargo_version }}"'"/}' Cargo.toml - name: Build ${{ matrix.target }} + env: + OPENSHELL_IMAGE_TAG: ${{ github.sha }} run: | set -euo pipefail mise x -- cargo build --release --target ${{ matrix.target }} -p openshell-server diff --git a/.github/workflows/release-tag.yml b/.github/workflows/release-tag.yml index 97c8422a2..60966b3b6 100644 --- a/.github/workflows/release-tag.yml +++ b/.github/workflows/release-tag.yml @@ -466,6 +466,8 @@ jobs: sed -i -E '/^\[workspace\.package\]/,/^\[/{s/^version[[:space:]]*=[[:space:]]*".*"/version = "'"${{ needs.compute-versions.outputs.cargo_version }}"'"/}' Cargo.toml - name: Build ${{ matrix.target }} + env: + OPENSHELL_IMAGE_TAG: ${{ needs.compute-versions.outputs.source_sha }} run: | set -euo pipefail mise x -- cargo build --release --target ${{ matrix.target }} -p openshell-server diff --git a/.github/workflows/shadow-rust-native-build.yml b/.github/workflows/shadow-rust-native-build.yml index b943a1ddb..7113d260f 100644 --- a/.github/workflows/shadow-rust-native-build.yml +++ b/.github/workflows/shadow-rust-native-build.yml @@ -42,6 +42,11 @@ on: required: false type: string default: "" + image-tag: + description: "Supervisor image tag to bake into gateway binaries" + required: false + type: string + default: "" workflow_dispatch: inputs: component: @@ -85,6 +90,11 @@ on: required: false type: string default: "" + image-tag: + description: "Supervisor image tag to bake into gateway binaries" + required: false + type: string + default: "" permissions: contents: read @@ -207,6 +217,7 @@ jobs: # Preserve the release-codegen setting used by the old Dockerfile # Rust build path so image artifacts keep the same release profile. CARGO_PROFILE_RELEASE_CODEGEN_UNITS: "1" + OPENSHELL_IMAGE_TAG: ${{ inputs['image-tag'] }} run: | set -euo pipefail args=( diff --git a/architecture/build.md b/architecture/build.md index baf44eba9..266575efb 100644 --- a/architecture/build.md +++ b/architecture/build.md @@ -12,7 +12,7 @@ OpenShell builds these main artifacts: |---|---| | Gateway binary | `crates/openshell-server` | | CLI package and Python SDK | `python/openshell` plus Rust binaries where packaged | -| Gateway container image | `deploy/docker/Dockerfile.images` | +| Gateway and supervisor container images | `deploy/docker/Dockerfile.images` | | Helm chart | `deploy/helm/openshell` | | VM driver/runtime assets | `crates/openshell-driver-vm` | | Published docs site | `docs/` rendered by Fern config in `fern/` | @@ -25,6 +25,8 @@ The Docker image pipeline stages prebuilt Rust binaries, then builds container images from `deploy/docker/Dockerfile.images`. CI builds native artifacts on the target architecture, stages them under `deploy/docker/.build/`, and then uses Buildx to publish per-architecture images and multi-architecture tags. +Gateway image builds bake the corresponding supervisor image tag into the +gateway binary so Docker sandboxes do not depend on `:latest` by default. Local image work should use `mise` tasks rather than direct Docker commands so the same staging and tagging assumptions are used locally and in CI. diff --git a/architecture/compute-runtimes.md b/architecture/compute-runtimes.md index 095b7d020..33917a28f 100644 --- a/architecture/compute-runtimes.md +++ b/architecture/compute-runtimes.md @@ -38,7 +38,7 @@ The supervisor must be available inside each sandbox workload: | Runtime | Delivery model | |---|---| -| Docker | Bind-mounted or extracted supervisor binary configured by the gateway. | +| Docker | Bind-mounted local supervisor binary, or a binary extracted from the configured supervisor image. | | Podman | Read-only OCI image volume containing the supervisor binary. | | Kubernetes | Sandbox pod image or pod template configuration. | | VM | Embedded in the guest rootfs bundle. | diff --git a/crates/openshell-driver-docker/README.md b/crates/openshell-driver-docker/README.md index 7bc8048b2..99b6e1385 100644 --- a/crates/openshell-driver-docker/README.md +++ b/crates/openshell-driver-docker/README.md @@ -34,6 +34,21 @@ contract: The agent child process does not retain these supervisor privileges. +## Supervisor Binary Resolution + +The Docker driver bind-mounts a host-side Linux `openshell-sandbox` binary into +each sandbox container. Resolution order is: + +1. `--docker-supervisor-bin` / `OPENSHELL_DOCKER_SUPERVISOR_BIN`. +2. A sibling `openshell-sandbox` next to the running `openshell-gateway` binary. +3. A local Linux cargo target build for the Docker daemon architecture. +4. `--docker-supervisor-image` / `OPENSHELL_DOCKER_SUPERVISOR_IMAGE`, or the + release-matched default supervisor image, extracting `/openshell-sandbox`. + +Release and Docker-image gateway builds bake the matching supervisor image tag +into the binary at compile time. The default Docker supervisor image is not +`:latest` unless a custom build explicitly sets that tag. + ## Callback and TLS `OPENSHELL_ENDPOINT` is injected from the gateway's configured gRPC endpoint diff --git a/crates/openshell-driver-docker/src/lib.rs b/crates/openshell-driver-docker/src/lib.rs index 0eaef3bce..a864a3eb6 100644 --- a/crates/openshell-driver-docker/src/lib.rs +++ b/crates/openshell-driver-docker/src/lib.rs @@ -1759,7 +1759,7 @@ async fn extract_supervisor_binary_bytes(docker: &Docker, image: &str) -> CoreRe ), ContainerCreateBody { image: Some(image.to_string()), - entrypoint: Some(vec!["/openshell-sandbox".to_string()]), + entrypoint: Some(vec![SUPERVISOR_IMAGE_BINARY_PATH.to_string()]), cmd: Some(Vec::new()), ..Default::default() }, diff --git a/crates/openshell-driver-kubernetes/src/driver.rs b/crates/openshell-driver-kubernetes/src/driver.rs index e2d06044d..6c855f63e 100644 --- a/crates/openshell-driver-kubernetes/src/driver.rs +++ b/crates/openshell-driver-kubernetes/src/driver.rs @@ -688,19 +688,19 @@ fn supervisor_volume_mount() -> serde_json::Value { /// Path of the supervisor binary inside the supervisor image. /// -/// The supervisor image places the binary at the filesystem root and ships -/// nothing else. We invoke it directly — there is no shell, `cp`, or PATH -/// resolution available inside the image. +/// The supervisor image places the binary at the filesystem root. We invoke +/// it directly so the init path does not depend on shell utilities or PATH +/// resolution inside the image. const SUPERVISOR_IMAGE_BINARY_PATH: &str = "/openshell-sandbox"; /// Build the init container that copies the supervisor binary into the emptyDir. /// -/// The supervisor image contains only the supervisor binary at -/// `/openshell-sandbox`. We invoke that binary with the `copy-self` -/// subcommand so it copies itself into the shared emptyDir volume, where the -/// agent container then executes it from a fixed, writable path. This pattern -/// (binary self-copy) avoids requiring `sh`/`cp` in the supervisor image and -/// mirrors the approach used by argoexec's emissary executor. +/// The supervisor image contains the supervisor binary at `/openshell-sandbox`. +/// We invoke that binary with the `copy-self` subcommand so it copies itself +/// into the shared emptyDir volume, where the agent container then executes it +/// from a fixed, writable path. This pattern (binary self-copy) avoids requiring +/// `sh`/`cp` in the supervisor image and mirrors the approach used by argoexec's +/// emissary executor. fn supervisor_init_container( supervisor_image: &str, supervisor_image_pull_policy: &str, @@ -1559,8 +1559,8 @@ mod tests { assert_eq!(init_containers[0]["image"], "supervisor-image:latest"); assert_eq!(init_containers[0]["imagePullPolicy"], "IfNotPresent"); - // The supervisor image ships only the binary (no shell). The init - // container must invoke the binary directly with `copy-self `. + // The init container must invoke the binary directly with + // `copy-self ` rather than depending on shell utilities. let init_command = init_containers[0]["command"] .as_array() .expect("init container command should be set"); @@ -1573,7 +1573,7 @@ mod tests { ); assert!( !init_command.iter().any(|v| v == "sh"), - "init container must not depend on a shell (supervisor image ships only the binary)" + "init container must not depend on a shell" ); // Agent container command should be overridden to the emptyDir path diff --git a/crates/openshell-sandbox/src/main.rs b/crates/openshell-sandbox/src/main.rs index 20d455663..6ae1bd5fe 100644 --- a/crates/openshell-sandbox/src/main.rs +++ b/crates/openshell-sandbox/src/main.rs @@ -19,9 +19,9 @@ use openshell_sandbox::run_sandbox; /// Subcommand name used to self-copy the supervisor binary into a shared volume. /// -/// The supervisor image only ships the binary itself, so init containers -/// cannot rely on `sh`/`cp` to copy the binary out. Invoking the binary itself -/// with this argument performs the copy in pure Rust. +/// Init containers invoke the binary directly instead of relying on `sh`/`cp` +/// to copy the binary out. Invoking the binary itself with this argument +/// performs the copy in pure Rust. const COPY_SELF_SUBCOMMAND: &str = "copy-self"; /// `OpenShell` Sandbox - process isolation and monitoring. @@ -148,9 +148,8 @@ fn copy_self(dest: &str) -> Result<()> { fn main() -> Result<()> { // Handle `copy-self ` before clap so it works without any of the - // sandbox flags. The supervisor image only ships the binary itself, and - // Kubernetes init containers invoke this path to seed an emptyDir volume - // that the agent container then executes from. + // sandbox flags. Kubernetes init containers invoke this path to seed an + // emptyDir volume that the agent container then executes from. let raw_args: Vec = std::env::args().collect(); if raw_args.get(1).map(String::as_str) == Some(COPY_SELF_SUBCOMMAND) { let dest = raw_args.get(2).ok_or_else(|| { diff --git a/tasks/docker.toml b/tasks/docker.toml index a58fdcf86..502b2363c 100644 --- a/tasks/docker.toml +++ b/tasks/docker.toml @@ -27,7 +27,7 @@ run = "tasks/scripts/docker-build-image.sh gateway" hide = true ["build:docker:supervisor"] -description = "Build the supervisor image (FROM scratch, binary only)" +description = "Build the supervisor image" run = "tasks/scripts/docker-build-image.sh supervisor" hide = true From 8ab5ee875b09d2bca469baec3fa6a6fe5f630fd3 Mon Sep 17 00:00:00 2001 From: Drew Newberry Date: Thu, 7 May 2026 20:03:16 -0700 Subject: [PATCH 003/157] fix(vm): harden compute driver socket (#1248) --- Cargo.lock | 1 + architecture/compute-runtimes.md | 8 +- crates/openshell-driver-vm/README.md | 6 +- crates/openshell-driver-vm/src/driver.rs | 217 +++++++++- crates/openshell-driver-vm/src/main.rs | 467 +++++++++++++++++++-- crates/openshell-server/Cargo.toml | 1 + crates/openshell-server/src/compute/vm.rs | 351 +++++++++++++++- docs/reference/sandbox-compute-drivers.mdx | 4 +- 8 files changed, 973 insertions(+), 82 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a1c390646..28d86ea93 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3633,6 +3633,7 @@ dependencies = [ "metrics", "metrics-exporter-prometheus", "miette", + "nix", "openshell-core", "openshell-driver-docker", "openshell-driver-kubernetes", diff --git a/architecture/compute-runtimes.md b/architecture/compute-runtimes.md index 33917a28f..9ab512f98 100644 --- a/architecture/compute-runtimes.md +++ b/architecture/compute-runtimes.md @@ -23,7 +23,13 @@ Each runtime receives a sandbox spec from the gateway and is responsible for: | Docker | Local development with Docker available. | Container plus nested sandbox namespace. | Uses host networking so loopback gateway endpoints work from the supervisor. | | Podman | Rootless or single-machine deployments. | Container plus nested sandbox namespace. | Uses the Podman REST API, OCI image volumes, and CDI GPU devices when available. | | Kubernetes | Cluster deployment through Helm. | Pod plus nested sandbox namespace. | Uses Kubernetes API objects, service accounts, secrets, PVC-backed workspace storage, and GPU resources. | -| VM | Experimental microVM isolation. | Per-sandbox libkrun VM. | Gateway spawns `openshell-driver-vm` as a subprocess over a Unix socket. | +| VM | Experimental microVM isolation. | Per-sandbox libkrun VM. | Gateway spawns `openshell-driver-vm` as a subprocess over a private, state-local Unix socket. | + +VM runtime state paths are derived only from driver-validated sandbox IDs +matching `[A-Za-z0-9._-]{1,128}`. The gateway-owned VM driver socket uses a +private `run/` directory plus Unix peer UID/PID checks. Standalone +unauthenticated TCP mode is disabled unless explicitly enabled for local +development. Runtime-specific implementation notes belong in the driver crate README: diff --git a/crates/openshell-driver-vm/README.md b/crates/openshell-driver-vm/README.md index a3bdf9822..0a11ceb0a 100644 --- a/crates/openshell-driver-vm/README.md +++ b/crates/openshell-driver-vm/README.md @@ -42,7 +42,7 @@ By default `mise run gateway:vm`: - Listens on plaintext HTTP at `127.0.0.1:18081`. - Registers the CLI gateway `vm-dev` by writing `~/.config/openshell/gateways/vm-dev/metadata.json`. It does not modify the workspace `.env`. - Persists the gateway SQLite DB under `.cache/gateway-vm/gateway.db`. -- Places the VM driver state (per-sandbox rootfs + `compute-driver.sock`) under `/tmp/openshell-vm-driver-$USER-vm-dev/` so the AF_UNIX socket path stays under macOS `SUN_LEN`. +- Places the VM driver state (per-sandbox rootfs plus `run/compute-driver.sock`) under `/tmp/openshell-vm-driver-$USER-vm-dev/` so the AF_UNIX socket path stays under macOS `SUN_LEN`. - Passes `--driver-dir $PWD/target/debug` so the freshly built `openshell-driver-vm` is used instead of an older installed copy from `~/.local/libexec/openshell`, `/usr/libexec/openshell`, or `/usr/local/libexec`. For GPU passthrough (VFIO), pass `-- --gpu` and run with root privileges: @@ -124,7 +124,7 @@ The gateway resolves `openshell-driver-vm` in this order: `--driver-dir`, conven |---|---|---|---| | `--drivers vm` | `OPENSHELL_DRIVERS` | `kubernetes` | Select the VM compute driver. | | `--grpc-endpoint URL` | `OPENSHELL_GRPC_ENDPOINT` | — | Required. URL the sandbox guest dials to reach the gateway. Use `http://host.containers.internal:` (or `host.docker.internal` / `host.openshell.internal`) so traffic flows through gvproxy's host-loopback NAT (HostIP `192.168.127.254` → host `127.0.0.1`). Loopback URLs like `http://127.0.0.1:` are rewritten automatically by the driver. The bare gateway IP (`192.168.127.1`) only carries gvproxy's own services and will not reach host-bound ports. | -| `--vm-driver-state-dir DIR` | `OPENSHELL_VM_DRIVER_STATE_DIR` | `target/openshell-vm-driver` | Per-sandbox rootfs, console logs, and the `compute-driver.sock` UDS. | +| `--vm-driver-state-dir DIR` | `OPENSHELL_VM_DRIVER_STATE_DIR` | `target/openshell-vm-driver` | Per-sandbox rootfs, console logs, image cache, and private `run/compute-driver.sock` UDS. | | `--driver-dir DIR` | `OPENSHELL_DRIVER_DIR` | unset | Override the directory searched for `openshell-driver-vm`. | | `--vm-driver-vcpus N` | `OPENSHELL_VM_DRIVER_VCPUS` | `2` | vCPUs per sandbox. | | `--vm-driver-mem-mib N` | `OPENSHELL_VM_DRIVER_MEM_MIB` | `2048` | Memory per sandbox, in MiB. | @@ -156,7 +156,7 @@ RUST_LOG=openshell_server=debug,openshell_driver_vm=debug \ mise run gateway:vm ``` -The VM guest's serial console is appended to `//console.log`. The `compute-driver.sock` lives at `/compute-driver.sock`; the gateway removes it on clean shutdown via `ManagedDriverProcess::drop`. +The VM guest's serial console is appended to `//console.log`. Sandbox IDs must match `[A-Za-z0-9._-]{1,128}` before the driver uses them in host paths. The gateway-owned compute-driver socket lives at `/run/compute-driver.sock`; OpenShell creates `run/` with owner-only permissions, removes same-owner stale sockets, and the gateway removes the socket on clean shutdown via `ManagedDriverProcess::drop`. UDS clients must match the driver UID and provide the expected gateway process PID by default. Standalone same-UID UDS mode requires the explicit `--allow-same-uid-peer` development flag. TCP mode is disabled by default because it is unauthenticated; use `--allow-unauthenticated-tcp --bind-address 127.0.0.1:50061` only for local development. ## Prerequisites diff --git a/crates/openshell-driver-vm/src/driver.rs b/crates/openshell-driver-vm/src/driver.rs index 92cab23af..b797f4835 100644 --- a/crates/openshell-driver-vm/src/driver.rs +++ b/crates/openshell-driver-vm/src/driver.rs @@ -38,7 +38,7 @@ use std::fs; use std::io::Read; use std::net::Ipv4Addr; use std::os::unix::fs::PermissionsExt; -use std::path::{Path, PathBuf}; +use std::path::{Component, Path, PathBuf}; use std::pin::Pin; use std::process::Stdio; use std::sync::Arc; @@ -362,7 +362,7 @@ impl VmDriver { let is_gpu = spec.is_some_and(|s| s.gpu); let gpu_device = spec.map_or("", |s| s.gpu_device.as_str()); - let state_dir = sandbox_state_dir(&self.config.state_dir, &sandbox.id); + let state_dir = sandbox_state_dir(&self.config.state_dir, &sandbox.id)?; let rootfs = state_dir.join("rootfs"); let image_ref = self.resolved_sandbox_image(sandbox).ok_or_else(|| { Status::failed_precondition( @@ -620,6 +620,10 @@ impl VmDriver { sandbox_id: &str, sandbox_name: &str, ) -> Result { + if !sandbox_id.is_empty() { + validate_sandbox_id(sandbox_id)?; + } + let record = { let registry = self.registry.lock().await; if let Some((id, record)) = registry.get_key_value(sandbox_id) { @@ -670,13 +674,7 @@ impl VmDriver { self.release_gpu_and_subnet(&record_id); } - if let Err(err) = tokio::fs::remove_dir_all(&state_dir).await - && err.kind() != std::io::ErrorKind::NotFound - { - return Err(Status::internal(format!( - "failed to remove state dir: {err}" - ))); - } + remove_sandbox_state_dir(&self.config.state_dir, &state_dir).await?; { let mut registry = self.registry.lock().await; @@ -692,6 +690,10 @@ impl VmDriver { sandbox_id: &str, sandbox_name: &str, ) -> Result, Status> { + if !sandbox_id.is_empty() { + validate_sandbox_id(sandbox_id)?; + } + let registry = self.registry.lock().await; let sandbox = if sandbox_id.is_empty() { registry @@ -1452,6 +1454,8 @@ fn check_gpu_privileges() -> Result<(), String> { // gRPC API surface, so boxing here would diverge from every other handler. #[allow(clippy::result_large_err)] fn validate_vm_sandbox(sandbox: &Sandbox, gpu_enabled: bool) -> Result<(), Status> { + validate_sandbox_id(&sandbox.id)?; + let spec = sandbox .spec .as_ref() @@ -1487,6 +1491,32 @@ fn validate_vm_sandbox(sandbox: &Sandbox, gpu_enabled: bool) -> Result<(), Statu Ok(()) } +#[allow(clippy::result_large_err)] +fn validate_sandbox_id(sandbox_id: &str) -> Result<(), Status> { + if sandbox_id.is_empty() { + return Err(Status::invalid_argument("sandbox id is required")); + } + if sandbox_id.len() > 128 { + return Err(Status::invalid_argument( + "sandbox id exceeds maximum length (128 bytes)", + )); + } + if matches!(sandbox_id, "." | "..") { + return Err(Status::invalid_argument( + "sandbox id must match [A-Za-z0-9._-]{1,128}", + )); + } + if !sandbox_id + .bytes() + .all(|b| b.is_ascii_alphanumeric() || matches!(b, b'.' | b'_' | b'-')) + { + return Err(Status::invalid_argument( + "sandbox id must match [A-Za-z0-9._-]{1,128}", + )); + } + Ok(()) +} + #[allow(clippy::result_large_err)] fn parse_registry_reference(image_ref: &str) -> Result { Reference::try_from(image_ref).map_err(|err| { @@ -2236,8 +2266,71 @@ fn sandboxes_root_dir(root: &Path) -> PathBuf { root.join("sandboxes") } -fn sandbox_state_dir(root: &Path, sandbox_id: &str) -> PathBuf { - sandboxes_root_dir(root).join(sandbox_id) +#[allow(clippy::result_large_err)] +fn sandbox_state_dir(root: &Path, sandbox_id: &str) -> Result { + validate_sandbox_id(sandbox_id)?; + Ok(sandboxes_root_dir(root).join(sandbox_id)) +} + +#[allow(clippy::result_large_err)] +fn validate_sandbox_state_dir(root: &Path, state_dir: &Path) -> Result<(), Status> { + let sandboxes_root = sandboxes_root_dir(root); + let relative = state_dir.strip_prefix(&sandboxes_root).map_err(|_| { + Status::internal(format!( + "refusing to use sandbox state path outside vm state root: {}", + state_dir.display() + )) + })?; + + let mut components = relative.components(); + match components.next() { + Some(Component::Normal(_)) => {} + _ => { + return Err(Status::internal(format!( + "refusing to use malformed sandbox state path: {}", + state_dir.display() + ))); + } + } + if components.next().is_some() { + return Err(Status::internal(format!( + "refusing to use nested sandbox state path: {}", + state_dir.display() + ))); + } + + Ok(()) +} + +async fn remove_sandbox_state_dir(root: &Path, state_dir: &Path) -> Result<(), Status> { + validate_sandbox_state_dir(root, state_dir)?; + + let metadata = match tokio::fs::symlink_metadata(state_dir).await { + Ok(metadata) => metadata, + Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(()), + Err(err) => { + return Err(Status::internal(format!( + "failed to stat sandbox state dir: {err}" + ))); + } + }; + let file_type = metadata.file_type(); + if file_type.is_symlink() { + return Err(Status::internal(format!( + "refusing to remove symlinked sandbox state dir: {}", + state_dir.display() + ))); + } + if !file_type.is_dir() { + return Err(Status::internal(format!( + "sandbox state path is not a directory: {}", + state_dir.display() + ))); + } + + tokio::fs::remove_dir_all(state_dir) + .await + .map_err(|err| Status::internal(format!("failed to remove state dir: {err}"))) } fn image_cache_root_dir(root: &Path) -> PathBuf { @@ -2430,6 +2523,7 @@ mod tests { }; use prost_types::{Struct, Value, value::Kind}; use std::fs; + use std::path::Path; use std::sync::atomic::{AtomicU64, Ordering}; use std::time::{SystemTime, UNIX_EPOCH}; use tonic::Code; @@ -2437,6 +2531,7 @@ mod tests { #[test] fn validate_vm_sandbox_rejects_gpu_when_not_enabled() { let sandbox = Sandbox { + id: "sandbox-123".to_string(), spec: Some(SandboxSpec { gpu: true, ..Default::default() @@ -2452,6 +2547,7 @@ mod tests { #[test] fn validate_vm_sandbox_accepts_gpu_when_enabled() { let sandbox = Sandbox { + id: "sandbox-123".to_string(), spec: Some(SandboxSpec { gpu: true, ..Default::default() @@ -2464,6 +2560,7 @@ mod tests { #[test] fn validate_vm_sandbox_rejects_gpu_device_without_gpu() { let sandbox = Sandbox { + id: "sandbox-123".to_string(), spec: Some(SandboxSpec { gpu: false, gpu_device: "0000:2d:00.0".to_string(), @@ -2480,6 +2577,7 @@ mod tests { #[test] fn validate_vm_sandbox_rejects_platform_config() { let sandbox = Sandbox { + id: "sandbox-123".to_string(), spec: Some(SandboxSpec { template: Some(SandboxTemplate { platform_config: Some(Struct { @@ -2506,6 +2604,7 @@ mod tests { #[test] fn validate_vm_sandbox_accepts_template_image() { let sandbox = Sandbox { + id: "sandbox-123".to_string(), spec: Some(SandboxSpec { template: Some(SandboxTemplate { image: "ghcr.io/example/sandbox:latest".to_string(), @@ -2518,6 +2617,51 @@ mod tests { validate_vm_sandbox(&sandbox, false).expect("template.image should be accepted"); } + #[test] + fn validate_vm_sandbox_rejects_path_unsafe_ids() { + let mut unsafe_ids = [ + "", + ".", + "..", + "../escape", + "/tmp/escape", + "nested/path", + "nested\\path", + "bad\nid", + "bad id", + "unicodé", + ] + .into_iter() + .map(str::to_string) + .collect::>(); + unsafe_ids.push("a".repeat(129)); + + for sandbox_id in unsafe_ids { + let sandbox = Sandbox { + id: sandbox_id.clone(), + spec: Some(SandboxSpec { + template: Some(SandboxTemplate { + image: "ghcr.io/example/sandbox:latest".to_string(), + ..Default::default() + }), + ..Default::default() + }), + ..Default::default() + }; + let err = validate_vm_sandbox(&sandbox, false) + .expect_err("path-unsafe sandbox id should be rejected"); + assert_eq!(err.code(), Code::InvalidArgument, "id={sandbox_id:?}"); + assert!(err.message().contains("sandbox id"), "id={sandbox_id:?}"); + } + } + + #[test] + fn sandbox_state_dir_rejects_path_unsafe_ids() { + let err = sandbox_state_dir(Path::new("/tmp/openshell-vm"), "../escape") + .expect_err("path traversal should be rejected"); + assert_eq!(err.code(), Code::InvalidArgument); + } + #[test] fn capabilities_report_configured_default_image() { let driver = VmDriver { @@ -2919,9 +3063,14 @@ mod tests { #[tokio::test] async fn delete_sandbox_keeps_registry_entry_when_cleanup_fails() { + let base = unique_temp_dir(); + let driver_state = base.join("driver-state"); let (events, _) = broadcast::channel(WATCH_BUFFER); let driver = VmDriver { - config: VmDriverConfig::default(), + config: VmDriverConfig { + state_dir: driver_state.clone(), + ..Default::default() + }, launcher_bin: PathBuf::from("openshell-driver-vm"), registry: Arc::new(Mutex::new(HashMap::new())), image_cache_lock: Arc::new(Mutex::new(())), @@ -2933,9 +3082,8 @@ mod tests { ))), }; - let base = unique_temp_dir(); - std::fs::create_dir_all(&base).unwrap(); - let state_file = base.join("state-file"); + let state_file = sandbox_state_dir(&driver_state, "sandbox-123").unwrap(); + std::fs::create_dir_all(state_file.parent().unwrap()).unwrap(); std::fs::write(&state_file, "not a directory").unwrap(); insert_test_record( @@ -2950,10 +3098,11 @@ mod tests { .delete_sandbox("sandbox-123", "sandbox-123") .await .expect_err("state dir cleanup should fail for a file path"); - assert!(err.message().contains("failed to remove state dir")); + assert!(err.message().contains("not a directory")); assert!(driver.registry.lock().await.contains_key("sandbox-123")); - let retry_state_dir = base.join("state-dir"); + std::fs::remove_file(&state_file).unwrap(); + let retry_state_dir = sandbox_state_dir(&driver_state, "sandbox-123").unwrap(); std::fs::create_dir_all(&retry_state_dir).unwrap(); { let mut registry = driver.registry.lock().await; @@ -2975,6 +3124,40 @@ mod tests { let _ = std::fs::remove_dir_all(base); } + #[tokio::test] + async fn remove_sandbox_state_dir_rejects_paths_outside_state_root() { + let base = unique_temp_dir(); + let state_root = base.join("driver-state"); + let outside = base.join("outside"); + std::fs::create_dir_all(&outside).unwrap(); + + let err = remove_sandbox_state_dir(&state_root, &outside) + .await + .expect_err("outside state paths should be rejected"); + assert!(err.message().contains("outside vm state root")); + + let _ = std::fs::remove_dir_all(base); + } + + #[cfg(unix)] + #[tokio::test] + async fn remove_sandbox_state_dir_rejects_symlinked_state_dir() { + let base = unique_temp_dir(); + let state_root = base.join("driver-state"); + let target = base.join("target"); + let state_dir = sandbox_state_dir(&state_root, "sandbox-123").unwrap(); + std::fs::create_dir_all(&target).unwrap(); + std::fs::create_dir_all(state_dir.parent().unwrap()).unwrap(); + std::os::unix::fs::symlink(&target, &state_dir).unwrap(); + + let err = remove_sandbox_state_dir(&state_root, &state_dir) + .await + .expect_err("symlinked state dir should be rejected"); + assert!(err.message().contains("symlinked sandbox state dir")); + + let _ = std::fs::remove_dir_all(base); + } + #[test] fn validate_openshell_endpoint_accepts_loopback_hosts() { validate_openshell_endpoint("http://127.0.0.1:8080") diff --git a/crates/openshell-driver-vm/src/main.rs b/crates/openshell-driver-vm/src/main.rs index 596e6c88d..ed9967f4a 100644 --- a/crates/openshell-driver-vm/src/main.rs +++ b/crates/openshell-driver-vm/src/main.rs @@ -2,22 +2,27 @@ // SPDX-License-Identifier: Apache-2.0 use clap::Parser; +use futures::Stream; use miette::{IntoDiagnostic, Result}; use openshell_core::VERSION; use openshell_core::proto::compute::v1::compute_driver_server::ComputeDriverServer; #[cfg(target_os = "macos")] use openshell_driver_vm::{VM_RUNTIME_DIR_ENV, configured_runtime_dir}; use openshell_driver_vm::{VmBackend, VmDriver, VmDriverConfig, VmLaunchConfig, procguard, run_vm}; +use std::io; use std::net::SocketAddr; -use std::path::PathBuf; -use tokio::net::UnixListener; -use tokio_stream::wrappers::UnixListenerStream; +use std::os::unix::fs::{FileTypeExt, MetadataExt, PermissionsExt}; +use std::path::{Path, PathBuf}; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::net::{UnixListener, UnixStream}; use tracing::info; use tracing_subscriber::EnvFilter; #[derive(Parser, Debug)] #[command(name = "openshell-driver-vm")] #[command(version = VERSION)] +#[allow(clippy::struct_excessive_bools)] struct Args { #[arg(long, hide = true, default_value_t = false)] internal_run_vm: bool, @@ -46,15 +51,28 @@ struct Args { #[arg(long, hide = true, default_value_t = 1)] vm_krun_log_level: u32, + #[arg(long, env = "OPENSHELL_COMPUTE_DRIVER_BIND")] + bind_address: Option, + + #[arg(long, env = "OPENSHELL_COMPUTE_DRIVER_SOCKET")] + bind_socket: Option, + + #[arg(long, hide = true)] + expected_peer_pid: Option, + #[arg( long, - env = "OPENSHELL_COMPUTE_DRIVER_BIND", - default_value = "127.0.0.1:50061" + env = "OPENSHELL_COMPUTE_DRIVER_ALLOW_UNAUTHENTICATED_TCP", + default_value_t = false )] - bind_address: SocketAddr, + allow_unauthenticated_tcp: bool, - #[arg(long, env = "OPENSHELL_COMPUTE_DRIVER_SOCKET")] - bind_socket: Option, + #[arg( + long, + env = "OPENSHELL_COMPUTE_DRIVER_ALLOW_SAME_UID_PEER", + default_value_t = false + )] + allow_same_uid_peer: bool, #[arg(long, env = "OPENSHELL_LOG_LEVEL", default_value = "info")] log_level: String, @@ -154,6 +172,8 @@ async fn main() -> Result<()> { ) .init(); + let listen_mode = compute_driver_listen_mode(&args).map_err(|err| miette::miette!("{err}"))?; + // Arm procguard so that if the gateway is killed (SIGKILL or crash) // we also die. Without this the driver is reparented to init and // keeps its per-sandbox VM launchers alive forever. Launchers have @@ -170,18 +190,18 @@ async fn main() -> Result<()> { openshell_endpoint: args .openshell_endpoint .ok_or_else(|| miette::miette!("OPENSHELL_GRPC_ENDPOINT is required"))?, - state_dir: args.state_dir, + state_dir: args.state_dir.clone(), launcher_bin: None, - default_image: args.default_image, - ssh_handshake_secret: args.ssh_handshake_secret.unwrap_or_default(), + default_image: args.default_image.clone(), + ssh_handshake_secret: args.ssh_handshake_secret.clone().unwrap_or_default(), ssh_handshake_skew_secs: args.ssh_handshake_skew_secs, - log_level: args.log_level, + log_level: args.log_level.clone(), krun_log_level: args.krun_log_level, vcpus: args.vcpus, mem_mib: args.mem_mib, - guest_tls_ca: args.guest_tls_ca, - guest_tls_cert: args.guest_tls_cert, - guest_tls_key: args.guest_tls_key, + guest_tls_ca: args.guest_tls_ca.clone(), + guest_tls_cert: args.guest_tls_cert.clone(), + guest_tls_key: args.guest_tls_key.clone(), gpu_enabled: args.gpu, gpu_mem_mib: args.gpu_mem_mib, gpu_vcpus: args.gpu_vcpus, @@ -189,32 +209,241 @@ async fn main() -> Result<()> { .await .map_err(|err| miette::miette!("{err}"))?; - if let Some(socket_path) = args.bind_socket { - if let Some(parent) = socket_path.parent() { - std::fs::create_dir_all(parent).into_diagnostic()?; + match listen_mode { + ComputeDriverListenMode::Unix { + socket_path, + expected_peer_pid, + } => { + prepare_compute_driver_socket(&socket_path).map_err(|err| miette::miette!("{err}"))?; + + info!(socket = %socket_path.display(), "Starting vm compute driver"); + let listener = UnixListener::bind(&socket_path).into_diagnostic()?; + restrict_socket_permissions(&socket_path).map_err(|err| miette::miette!("{err}"))?; + let result = tonic::transport::Server::builder() + .add_service(ComputeDriverServer::new(driver)) + .serve_with_incoming(AuthenticatedUnixIncoming::new(listener, expected_peer_pid)) + .await + .into_diagnostic(); + let _ = std::fs::remove_file(&socket_path); + result + } + ComputeDriverListenMode::Tcp(bind_address) => { + info!(address = %bind_address, "Starting unauthenticated dev vm compute driver"); + tonic::transport::Server::builder() + .add_service(ComputeDriverServer::new(driver)) + .serve(bind_address) + .await + .into_diagnostic() } - match std::fs::remove_file(&socket_path) { - Ok(()) => {} - Err(err) if err.kind() == std::io::ErrorKind::NotFound => {} - Err(err) => return Err(err).into_diagnostic(), + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +enum ComputeDriverListenMode { + Unix { + socket_path: PathBuf, + expected_peer_pid: Option, + }, + Tcp(SocketAddr), +} + +fn compute_driver_listen_mode(args: &Args) -> std::result::Result { + if let Some(socket_path) = args.bind_socket.clone() { + if args.expected_peer_pid.is_none() && !args.allow_same_uid_peer { + return Err( + "--expected-peer-pid is required with --bind-socket; use --allow-same-uid-peer only for local development" + .to_string(), + ); + } + return Ok(ComputeDriverListenMode::Unix { + socket_path, + expected_peer_pid: args.expected_peer_pid, + }); + } + + if !args.allow_unauthenticated_tcp { + return Err( + "--bind-socket is required; unauthenticated TCP mode is disabled unless --allow-unauthenticated-tcp is set for local development" + .to_string(), + ); + } + + let Some(bind_address) = args.bind_address else { + return Err("--bind-address is required with --allow-unauthenticated-tcp".to_string()); + }; + + Ok(ComputeDriverListenMode::Tcp(bind_address)) +} + +fn prepare_compute_driver_socket(socket_path: &Path) -> std::result::Result<(), String> { + let Some(parent) = socket_path.parent() else { + return Err(format!( + "vm compute driver socket path '{}' has no parent directory", + socket_path.display() + )); + }; + let expected_uid = current_euid(); + prepare_private_socket_dir(parent, expected_uid)?; + remove_stale_socket(socket_path, expected_uid) +} + +fn current_euid() -> u32 { + nix::unistd::Uid::effective().as_raw() +} + +fn prepare_private_socket_dir( + socket_dir: &Path, + expected_uid: u32, +) -> std::result::Result<(), String> { + std::fs::create_dir_all(socket_dir) + .map_err(|err| format!("create socket dir {}: {err}", socket_dir.display()))?; + let metadata = std::fs::symlink_metadata(socket_dir) + .map_err(|err| format!("stat socket dir {}: {err}", socket_dir.display()))?; + let file_type = metadata.file_type(); + if file_type.is_symlink() { + return Err(format!( + "socket dir {} is a symlink; refusing to use it", + socket_dir.display() + )); + } + if !file_type.is_dir() { + return Err(format!( + "socket dir {} is not a directory", + socket_dir.display() + )); + } + if metadata.uid() != expected_uid { + return Err(format!( + "socket dir {} is owned by uid {} but current euid is {}", + socket_dir.display(), + metadata.uid(), + expected_uid + )); + } + std::fs::set_permissions(socket_dir, std::fs::Permissions::from_mode(0o700)) + .map_err(|err| format!("chmod socket dir {}: {err}", socket_dir.display())) +} + +fn remove_stale_socket(socket_path: &Path, expected_uid: u32) -> std::result::Result<(), String> { + let metadata = match std::fs::symlink_metadata(socket_path) { + Ok(metadata) => metadata, + Err(err) if err.kind() == io::ErrorKind::NotFound => return Ok(()), + Err(err) => return Err(format!("stat socket {}: {err}", socket_path.display())), + }; + let file_type = metadata.file_type(); + if file_type.is_symlink() { + return Err(format!( + "socket {} is a symlink; refusing to remove it", + socket_path.display() + )); + } + if metadata.uid() != expected_uid { + return Err(format!( + "socket {} is owned by uid {} but current euid is {}", + socket_path.display(), + metadata.uid(), + expected_uid + )); + } + if !file_type.is_socket() { + return Err(format!( + "socket path {} exists but is not a Unix socket", + socket_path.display() + )); + } + std::fs::remove_file(socket_path) + .map_err(|err| format!("remove stale socket {}: {err}", socket_path.display())) +} + +fn restrict_socket_permissions(socket_path: &Path) -> std::result::Result<(), String> { + std::fs::set_permissions(socket_path, std::fs::Permissions::from_mode(0o600)) + .map_err(|err| format!("chmod socket {}: {err}", socket_path.display())) +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +struct PeerCredentials { + uid: u32, + pid: Option, +} + +fn peer_credentials(stream: &UnixStream) -> std::result::Result { + let credentials = stream + .peer_cred() + .map_err(|err| format!("read peer credentials: {err}"))?; + Ok(PeerCredentials { + uid: credentials.uid(), + pid: credentials.pid(), + }) +} + +fn authorize_peer_credentials( + peer: PeerCredentials, + driver_uid: u32, + gateway_pid: Option, +) -> std::result::Result<(), String> { + if peer.uid != driver_uid { + return Err(format!( + "peer uid {} does not match current euid {}", + peer.uid, driver_uid + )); + } + let Some(gateway_pid) = gateway_pid else { + return Ok(()); + }; + let Some(peer_process_id) = peer.pid.and_then(|pid| u32::try_from(pid).ok()) else { + return Err(format!( + "peer pid is unavailable; expected gateway pid {gateway_pid}" + )); + }; + if peer_process_id != gateway_pid { + return Err(format!( + "peer pid {peer_process_id} does not match expected gateway pid {gateway_pid}" + )); + } + Ok(()) +} + +struct AuthenticatedUnixIncoming { + listener: UnixListener, + expected_uid: u32, + expected_peer_pid: Option, +} + +impl AuthenticatedUnixIncoming { + fn new(listener: UnixListener, expected_peer_pid: Option) -> Self { + Self { + listener, + expected_uid: current_euid(), + expected_peer_pid, } + } +} - info!(socket = %socket_path.display(), "Starting vm compute driver"); - let listener = UnixListener::bind(&socket_path).into_diagnostic()?; - let result = tonic::transport::Server::builder() - .add_service(ComputeDriverServer::new(driver)) - .serve_with_incoming(UnixListenerStream::new(listener)) - .await - .into_diagnostic(); - let _ = std::fs::remove_file(&socket_path); - result - } else { - info!(address = %args.bind_address, "Starting vm compute driver"); - tonic::transport::Server::builder() - .add_service(ComputeDriverServer::new(driver)) - .serve(args.bind_address) - .await - .into_diagnostic() +impl Stream for AuthenticatedUnixIncoming { + type Item = io::Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + loop { + match this.listener.poll_accept(cx) { + Poll::Ready(Ok((stream, _addr))) => { + let authorized = peer_credentials(&stream).and_then(|peer| { + authorize_peer_credentials(peer, this.expected_uid, this.expected_peer_pid) + }); + match authorized { + Ok(()) => return Poll::Ready(Some(Ok(stream))), + Err(err) => { + tracing::warn!( + error = %err, + "rejected vm compute driver UDS client" + ); + } + } + } + Poll::Ready(Err(err)) => return Poll::Ready(Some(Err(err))), + Poll::Pending => return Poll::Pending, + } + } } } @@ -310,3 +539,165 @@ fn maybe_reexec_internal_vm_with_runtime_env() -> Result<()> { fn maybe_reexec_internal_vm_with_runtime_env() -> Result<()> { Ok(()) } + +#[cfg(test)] +mod tests { + use super::{ + Args, ComputeDriverListenMode, PeerCredentials, authorize_peer_credentials, + compute_driver_listen_mode, + }; + use clap::Parser; + use std::path::PathBuf; + + #[test] + fn peer_authorization_accepts_matching_uid_and_pid() { + authorize_peer_credentials( + PeerCredentials { + uid: 1000, + pid: Some(42), + }, + 1000, + Some(42), + ) + .unwrap(); + } + + #[test] + fn peer_authorization_rejects_wrong_pid() { + let err = authorize_peer_credentials( + PeerCredentials { + uid: 1000, + pid: Some(7), + }, + 1000, + Some(42), + ) + .expect_err("wrong pid should be rejected"); + assert!(err.contains("does not match expected gateway pid")); + } + + #[test] + fn peer_authorization_rejects_wrong_uid() { + let err = authorize_peer_credentials( + PeerCredentials { + uid: 1001, + pid: Some(42), + }, + 1000, + Some(42), + ) + .expect_err("wrong uid should be rejected"); + assert!(err.contains("does not match current euid")); + } + + #[test] + fn peer_authorization_rejects_missing_pid_when_expected() { + let err = authorize_peer_credentials( + PeerCredentials { + uid: 1000, + pid: None, + }, + 1000, + Some(42), + ) + .expect_err("missing pid should be rejected"); + assert!(err.contains("peer pid is unavailable")); + } + + #[test] + fn peer_authorization_accepts_matching_uid_without_expected_pid() { + authorize_peer_credentials( + PeerCredentials { + uid: 1000, + pid: None, + }, + 1000, + None, + ) + .unwrap(); + } + + #[test] + fn listen_mode_rejects_default_tcp() { + let args = Args::parse_from(["openshell-driver-vm"]); + let err = compute_driver_listen_mode(&args).expect_err("default TCP should be disabled"); + assert!(err.contains("--bind-socket is required")); + } + + #[test] + fn listen_mode_rejects_bind_address_without_tcp_opt_in() { + let args = Args::parse_from(["openshell-driver-vm", "--bind-address", "127.0.0.1:50061"]); + let err = + compute_driver_listen_mode(&args).expect_err("TCP bind should require explicit opt-in"); + assert!(err.contains("--allow-unauthenticated-tcp")); + } + + #[test] + fn listen_mode_requires_bind_address_with_tcp_opt_in() { + let args = Args::parse_from(["openshell-driver-vm", "--allow-unauthenticated-tcp"]); + let err = + compute_driver_listen_mode(&args).expect_err("TCP opt-in should require an address"); + assert!(err.contains("--bind-address is required")); + } + + #[test] + fn listen_mode_accepts_explicit_unauthenticated_tcp() { + let args = Args::parse_from([ + "openshell-driver-vm", + "--allow-unauthenticated-tcp", + "--bind-address", + "127.0.0.1:50061", + ]); + assert_eq!( + compute_driver_listen_mode(&args).unwrap(), + ComputeDriverListenMode::Tcp("127.0.0.1:50061".parse().unwrap()) + ); + } + + #[test] + fn listen_mode_requires_expected_peer_pid_for_uds() { + let args = Args::parse_from([ + "openshell-driver-vm", + "--bind-socket", + "/tmp/compute-driver.sock", + ]); + let err = compute_driver_listen_mode(&args) + .expect_err("UDS should require gateway peer pid by default"); + assert!(err.contains("--expected-peer-pid is required")); + } + + #[test] + fn listen_mode_accepts_uds_with_expected_peer_pid() { + let args = Args::parse_from([ + "openshell-driver-vm", + "--bind-socket", + "/tmp/compute-driver.sock", + "--expected-peer-pid", + "42", + ]); + assert_eq!( + compute_driver_listen_mode(&args).unwrap(), + ComputeDriverListenMode::Unix { + socket_path: PathBuf::from("/tmp/compute-driver.sock"), + expected_peer_pid: Some(42), + } + ); + } + + #[test] + fn listen_mode_accepts_explicit_same_uid_uds_dev_mode() { + let args = Args::parse_from([ + "openshell-driver-vm", + "--bind-socket", + "/tmp/compute-driver.sock", + "--allow-same-uid-peer", + ]); + assert_eq!( + compute_driver_listen_mode(&args).unwrap(), + ComputeDriverListenMode::Unix { + socket_path: PathBuf::from("/tmp/compute-driver.sock"), + expected_peer_pid: None, + } + ); + } +} diff --git a/crates/openshell-server/Cargo.toml b/crates/openshell-server/Cargo.toml index cb6561f3e..fab20186c 100644 --- a/crates/openshell-server/Cargo.toml +++ b/crates/openshell-server/Cargo.toml @@ -82,6 +82,7 @@ rand = { workspace = true } petname = "2" ipnet = "2" tempfile = "3" +nix = { workspace = true } [features] dev-settings = ["openshell-core/dev-settings"] diff --git a/crates/openshell-server/src/compute/vm.rs b/crates/openshell-server/src/compute/vm.rs index e5b974f74..1e62d4942 100644 --- a/crates/openshell-server/src/compute/vm.rs +++ b/crates/openshell-server/src/compute/vm.rs @@ -38,6 +38,10 @@ use openshell_core::proto::compute::v1::{ GetCapabilitiesRequest, compute_driver_client::ComputeDriverClient, }; use openshell_core::{Config, Error, Result}; +#[cfg(unix)] +use std::os::unix::fs::{FileTypeExt, MetadataExt, PermissionsExt}; +#[cfg(unix)] +use std::path::Path; use std::path::PathBuf; #[cfg(unix)] use std::{io::ErrorKind, process::Stdio, sync::Arc, time::Duration}; @@ -52,6 +56,8 @@ use tonic::transport::Endpoint; use tower::service_fn; const DRIVER_BIN_NAME: &str = "openshell-driver-vm"; +const COMPUTE_DRIVER_SOCKET_RUN_DIR: &str = "run"; +const COMPUTE_DRIVER_SOCKET_NAME: &str = "compute-driver.sock"; /// Configuration for launching and talking to the VM compute driver. #[derive(Debug, Clone)] @@ -210,7 +216,145 @@ fn push_unique_path(paths: &mut Vec, path: PathBuf) { /// Path of the Unix domain socket the driver will listen on. pub fn compute_driver_socket_path(vm_config: &VmComputeConfig) -> PathBuf { - vm_config.state_dir.join("compute-driver.sock") + vm_config + .state_dir + .join(COMPUTE_DRIVER_SOCKET_RUN_DIR) + .join(COMPUTE_DRIVER_SOCKET_NAME) +} + +#[cfg(unix)] +fn prepare_compute_driver_socket_path( + vm_config: &VmComputeConfig, + socket_path: &Path, +) -> Result<()> { + let expected_uid = current_euid(); + prepare_vm_state_dir(&vm_config.state_dir, expected_uid)?; + let parent = socket_path.parent().ok_or_else(|| { + Error::execution(format!( + "vm compute driver socket path '{}' has no parent directory", + socket_path.display() + )) + })?; + prepare_private_socket_dir(parent, expected_uid)?; + remove_stale_socket(socket_path, expected_uid) +} + +#[cfg(unix)] +fn current_euid() -> u32 { + nix::unistd::Uid::effective().as_raw() +} + +#[cfg(unix)] +fn prepare_vm_state_dir(state_dir: &Path, expected_uid: u32) -> Result<()> { + std::fs::create_dir_all(state_dir).map_err(|err| { + Error::execution(format!( + "failed to create vm driver state dir '{}': {err}", + state_dir.display() + )) + })?; + let metadata = checked_directory_metadata(state_dir, expected_uid, "vm driver state dir")?; + let mode = metadata.permissions().mode() & 0o777; + if mode & 0o022 != 0 { + return Err(Error::execution(format!( + "vm driver state dir '{}' must not be group/world-writable (mode {mode:03o})", + state_dir.display() + ))); + } + Ok(()) +} + +#[cfg(unix)] +fn prepare_private_socket_dir(socket_dir: &Path, expected_uid: u32) -> Result<()> { + std::fs::create_dir_all(socket_dir).map_err(|err| { + Error::execution(format!( + "failed to create vm compute driver socket dir '{}': {err}", + socket_dir.display() + )) + })?; + let _ = checked_directory_metadata(socket_dir, expected_uid, "vm compute driver socket dir")?; + std::fs::set_permissions(socket_dir, std::fs::Permissions::from_mode(0o700)).map_err(|err| { + Error::execution(format!( + "failed to restrict vm compute driver socket dir '{}': {err}", + socket_dir.display() + )) + }) +} + +#[cfg(unix)] +fn checked_directory_metadata( + path: &Path, + expected_uid: u32, + label: &str, +) -> Result { + let metadata = std::fs::symlink_metadata(path).map_err(|err| { + Error::execution(format!( + "failed to stat {label} '{}': {err}", + path.display() + )) + })?; + let file_type = metadata.file_type(); + if file_type.is_symlink() { + return Err(Error::execution(format!( + "{label} '{}' is a symlink; refusing to use it", + path.display() + ))); + } + if !file_type.is_dir() { + return Err(Error::execution(format!( + "{label} '{}' is not a directory", + path.display() + ))); + } + if metadata.uid() != expected_uid { + return Err(Error::execution(format!( + "{label} '{}' is owned by uid {} but current euid is {}", + path.display(), + metadata.uid(), + expected_uid + ))); + } + Ok(metadata) +} + +#[cfg(unix)] +fn remove_stale_socket(socket_path: &Path, expected_uid: u32) -> Result<()> { + let metadata = match std::fs::symlink_metadata(socket_path) { + Ok(metadata) => metadata, + Err(err) if err.kind() == ErrorKind::NotFound => return Ok(()), + Err(err) => { + return Err(Error::execution(format!( + "failed to stat vm compute driver socket '{}': {err}", + socket_path.display() + ))); + } + }; + let file_type = metadata.file_type(); + if file_type.is_symlink() { + return Err(Error::execution(format!( + "vm compute driver socket '{}' is a symlink; refusing to remove it", + socket_path.display() + ))); + } + if metadata.uid() != expected_uid { + return Err(Error::execution(format!( + "vm compute driver socket '{}' is owned by uid {} but current euid is {}", + socket_path.display(), + metadata.uid(), + expected_uid + ))); + } + if !file_type.is_socket() { + return Err(Error::execution(format!( + "vm compute driver socket path '{}' exists but is not a Unix socket", + socket_path.display() + ))); + } + std::fs::remove_file(socket_path).map_err(|err| { + Error::execution(format!( + "failed to remove stale vm compute driver socket '{}': {err}", + socket_path.display() + )) + }) } #[cfg(unix)] @@ -278,24 +422,7 @@ pub async fn spawn( let driver_bin = resolve_compute_driver_bin(vm_config)?; let socket_path = compute_driver_socket_path(vm_config); let guest_tls_paths = compute_driver_guest_tls_paths(config, vm_config)?; - if let Some(parent) = socket_path.parent() { - std::fs::create_dir_all(parent).map_err(|e| { - Error::execution(format!( - "failed to create vm compute driver socket dir '{}': {e}", - parent.display() - )) - })?; - } - match std::fs::remove_file(&socket_path) { - Ok(()) => {} - Err(err) if err.kind() == ErrorKind::NotFound => {} - Err(err) => { - return Err(Error::execution(format!( - "failed to remove stale vm compute driver socket '{}': {err}", - socket_path.display() - ))); - } - } + prepare_compute_driver_socket_path(vm_config, &socket_path)?; let mut command = Command::new(&driver_bin); command.kill_on_drop(true); @@ -303,6 +430,9 @@ pub async fn spawn( command.stdout(Stdio::inherit()); command.stderr(Stdio::inherit()); command.arg("--bind-socket").arg(&socket_path); + command + .arg("--expected-peer-pid") + .arg(std::process::id().to_string()); command.arg("--log-level").arg(&config.log_level); command .arg("--openshell-endpoint") @@ -356,7 +486,7 @@ pub async fn spawn( #[cfg(unix)] async fn wait_for_compute_driver( - socket_path: &std::path::Path, + socket_path: &Path, child: &mut tokio::process::Child, ) -> Result { let mut last_error: Option = None; @@ -395,7 +525,7 @@ async fn wait_for_compute_driver( } #[cfg(unix)] -async fn connect_compute_driver(socket_path: &std::path::Path) -> Result { +async fn connect_compute_driver(socket_path: &Path) -> Result { let socket_path = socket_path.to_path_buf(); let display_path = socket_path.clone(); Endpoint::from_static("http://[::]:50051") @@ -415,11 +545,13 @@ async fn connect_compute_driver(socket_path: &std::path::Path) -> Result` | `OPENSHELL_DRIVER_DIR` | Search a custom directory for `openshell-driver-vm`. | -| `--vm-driver-state-dir ` | `OPENSHELL_VM_DRIVER_STATE_DIR` | Store VM rootfs, console logs, runtime state, and image-rootfs cache under this directory. | +| `--vm-driver-state-dir ` | `OPENSHELL_VM_DRIVER_STATE_DIR` | Store VM rootfs, console logs, runtime state, image-rootfs cache, and the private `run/compute-driver.sock` socket under this directory. | | `--vm-driver-vcpus ` | `OPENSHELL_VM_DRIVER_VCPUS` | Set the default vCPU count for VM sandboxes. | | `--vm-driver-mem-mib ` | `OPENSHELL_VM_DRIVER_MEM_MIB` | Set the default memory allocation for VM sandboxes in MiB. | | `--vm-krun-log-level ` | `OPENSHELL_VM_KRUN_LOG_LEVEL` | Set the libkrun log level for VM helper processes. | | `--vm-tls-ca`, `--vm-tls-cert`, `--vm-tls-key` | `OPENSHELL_VM_TLS_CA`, `OPENSHELL_VM_TLS_CERT`, `OPENSHELL_VM_TLS_KEY` | Copy sandbox client TLS materials into VM guests for mTLS callback to the gateway. | +The gateway starts `openshell-driver-vm` over a private Unix socket and passes its process ID so the driver can reject unexpected local clients. The driver's standalone TCP listener is disabled unless `--allow-unauthenticated-tcp` is set for local development. + ## Kubernetes Driver Kubernetes-backed sandboxes run as pods in the configured sandbox namespace. Use Kubernetes for shared clusters, remote compute, GPU scheduling, and operator-managed environments. From 52097f2d4c89f612e93c4dbc7babc46268729279 Mon Sep 17 00:00:00 2001 From: Drew Newberry Date: Thu, 7 May 2026 21:57:55 -0700 Subject: [PATCH 004/157] ci(release): run package release canaries (#1256) --- .github/workflows/release-canary.yml | 292 +++------------------------ install-dev.sh | 97 +++++++-- 2 files changed, 99 insertions(+), 290 deletions(-) diff --git a/.github/workflows/release-canary.yml b/.github/workflows/release-canary.yml index defe6f32a..415a4f597 100644 --- a/.github/workflows/release-canary.yml +++ b/.github/workflows/release-canary.yml @@ -2,293 +2,47 @@ name: Release Canary on: workflow_dispatch: - inputs: - tag: - description: "Release tag to test (e.g. dev, v1.2.3)" - required: true - type: string workflow_run: - workflows: ["Release Dev", "Release Tag"] + workflows: ["Release Dev"] types: [completed] permissions: contents: read - packages: read defaults: run: shell: bash jobs: - # --------------------------------------------------------------------------- - # Verify the default install path (no OPENSHELL_VERSION) resolves to latest - # --------------------------------------------------------------------------- - install-default: - name: Install default (${{ matrix.arch }}) - if: >- - github.event.workflow_run.conclusion == 'success' - && github.event.workflow_run.name == 'Release Tag' - strategy: - fail-fast: false - matrix: - include: - - arch: amd64 - runner: linux-amd64-cpu8 - - arch: arm64 - runner: linux-arm64-cpu8 - runs-on: ${{ matrix.runner }} - timeout-minutes: 10 - container: - image: ghcr.io/nvidia/openshell/ci:latest - credentials: - username: ${{ github.actor }} - password: ${{ secrets.GITHUB_TOKEN }} - steps: - - name: Install CLI (default / latest) - run: | - set -euo pipefail - curl -LsSf https://raw.githubusercontent.com/NVIDIA/OpenShell/main/install.sh | sh - - - name: Verify CLI installation - run: | - set -euo pipefail - command -v openshell - ACTUAL="$(openshell --version)" - echo "Installed: $ACTUAL" - # This job only runs after Release Tag, so the triggering tag - # should match the latest release the default installer resolves to. - TAG="${{ github.event.workflow_run.head_branch }}" - if [[ "$TAG" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then - EXPECTED="${TAG#v}" - if [[ "$ACTUAL" != *"$EXPECTED"* ]]; then - echo "::error::Version mismatch: expected '$EXPECTED' in '$ACTUAL'" - exit 1 - fi - echo "Version check passed: found $EXPECTED in output" - fi - - install-dev: - name: Install Debian package (${{ matrix.arch }}) + macos: + name: macOS Homebrew if: ${{ github.event_name == 'workflow_dispatch' || github.event.workflow_run.conclusion == 'success' }} - strategy: - fail-fast: false - matrix: - include: - - arch: amd64 - runner: linux-amd64-cpu8 - - arch: arm64 - runner: linux-arm64-cpu8 - runs-on: ${{ matrix.runner }} - timeout-minutes: 10 - container: - image: ghcr.io/nvidia/openshell/ci:latest - credentials: - username: ${{ github.actor }} - password: ${{ secrets.GITHUB_TOKEN }} + runs-on: macos-latest-xlarge + timeout-minutes: 20 steps: - - name: Determine release tag - id: release + - name: Install dev and check status run: | - set -euo pipefail - if [ "${{ github.event_name }}" = "workflow_dispatch" ]; then - echo "tag=${{ inputs.tag }}" >> "$GITHUB_OUTPUT" - else - WORKFLOW_NAME="${{ github.event.workflow_run.name }}" - if [ "$WORKFLOW_NAME" = "Release Dev" ]; then - echo "tag=dev" >> "$GITHUB_OUTPUT" - elif [ "$WORKFLOW_NAME" = "Release Tag" ]; then - TAG="${{ github.event.workflow_run.head_branch }}" - if [ -z "$TAG" ]; then - echo "::error::Could not determine release tag from workflow_run" - exit 1 - fi - echo "tag=${TAG}" >> "$GITHUB_OUTPUT" - else - echo "::error::Unexpected triggering workflow: ${WORKFLOW_NAME}" - exit 1 - fi - fi + curl -LsSf https://raw.githubusercontent.com/NVIDIA/OpenShell/${{ github.event.workflow_run.head_sha || github.sha }}/install-dev.sh | sh + openshell status - - name: Install Debian package - run: | - set -euo pipefail - curl -LsSf https://raw.githubusercontent.com/NVIDIA/OpenShell/main/install-dev.sh \ - | OPENSHELL_VERSION=${{ steps.release.outputs.tag }} sh - - - name: Verify gateway and VM driver versions - run: | - set -euo pipefail - command -v openshell-gateway - test -x /usr/libexec/openshell/openshell-driver-vm - - GATEWAY_ACTUAL="$(openshell-gateway --version)" - DRIVER_ACTUAL="$(/usr/libexec/openshell/openshell-driver-vm --version)" - echo "Gateway: ${GATEWAY_ACTUAL}" - echo "Driver: ${DRIVER_ACTUAL}" - - TAG="${{ steps.release.outputs.tag }}" - if [[ "$TAG" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then - EXPECTED="${TAG#v}" - for actual in "$GATEWAY_ACTUAL" "$DRIVER_ACTUAL"; do - if [[ "$actual" != *"$EXPECTED"* ]]; then - echo "::error::Version mismatch: expected '$EXPECTED' in '$actual'" - exit 1 - fi - done - echo "Version check passed: found $EXPECTED in both binaries" - else - echo "Non-release tag ($TAG), skipping version check" - fi - - canary: - name: Canary package gateway (${{ matrix.arch }}) + ubuntu: + name: Ubuntu Docker if: ${{ github.event_name == 'workflow_dispatch' || github.event.workflow_run.conclusion == 'success' }} - strategy: - fail-fast: false - matrix: - include: - - arch: amd64 - runner: linux-amd64-cpu8 - - arch: arm64 - runner: linux-arm64-cpu8 - runs-on: ${{ matrix.runner }} - timeout-minutes: 30 - env: - OPENSHELL_REGISTRY_TOKEN: ${{ secrets.GITHUB_TOKEN }} - OPENSHELL_CANARY_PORT: "17670" + runs-on: ubuntu-latest + timeout-minutes: 20 steps: - - uses: actions/checkout@v6 - - - name: Determine release tag - id: release + - name: Ensure Docker run: | - set -euo pipefail - if [ "${{ github.event_name }}" = "workflow_dispatch" ]; then - echo "tag=${{ inputs.tag }}" >> "$GITHUB_OUTPUT" - else - WORKFLOW_NAME="${{ github.event.workflow_run.name }}" - if [ "$WORKFLOW_NAME" = "Release Dev" ]; then - echo "tag=dev" >> "$GITHUB_OUTPUT" - elif [ "$WORKFLOW_NAME" = "Release Tag" ]; then - TAG="${{ github.event.workflow_run.head_branch }}" - if [ -z "$TAG" ]; then - echo "::error::Could not determine release tag from workflow_run" - exit 1 - fi - echo "tag=${TAG}" >> "$GITHUB_OUTPUT" - else - echo "::error::Unexpected triggering workflow: ${WORKFLOW_NAME}" - exit 1 - fi + if ! command -v docker >/dev/null 2>&1; then + sudo apt-get update + sudo apt-get install -y docker.io fi + sudo systemctl start docker || sudo service docker start + mkdir -p "${HOME}/.config/openshell" + printf 'OPENSHELL_DRIVERS=docker\n' > "${HOME}/.config/openshell/gateway.env" + docker info - - name: Install Debian package + - name: Install dev and check status run: | - set -euo pipefail - curl -LsSf https://raw.githubusercontent.com/NVIDIA/OpenShell/main/install-dev.sh \ - | OPENSHELL_VERSION=${{ steps.release.outputs.tag }} sh - - - name: Verify package binaries - run: | - set -euo pipefail - command -v openshell - command -v openshell-gateway - test -x /usr/libexec/openshell/openshell-driver-vm - - CLI_ACTUAL="$(openshell --version)" - GATEWAY_ACTUAL="$(openshell-gateway --version)" - DRIVER_ACTUAL="$(/usr/libexec/openshell/openshell-driver-vm --version)" - echo "CLI: ${CLI_ACTUAL}" - echo "Gateway: ${GATEWAY_ACTUAL}" - echo "Driver: ${DRIVER_ACTUAL}" - - TAG="${{ steps.release.outputs.tag }}" - if [[ "$TAG" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then - EXPECTED="${TAG#v}" - for actual in "$CLI_ACTUAL" "$GATEWAY_ACTUAL" "$DRIVER_ACTUAL"; do - if [[ "$actual" != *"$EXPECTED"* ]]; then - echo "::error::Version mismatch: expected '$EXPECTED' in '$actual'" - exit 1 - fi - done - echo "Version check passed: found $EXPECTED in package binaries" - else - echo "Non-release tag ($TAG), skipping version check" - fi - - - name: Start packaged gateway - run: | - set -euo pipefail - - # The CLI no longer owns gateway lifecycle. In CI we start the - # gateway binary installed by the Debian package directly, using the - # Docker driver so this canary can launch a real sandbox. - systemctl --user stop openshell-gateway >/dev/null 2>&1 || true - - STATE_DIR="$(mktemp -d)" - LOG="${STATE_DIR}/openshell-gateway.log" - echo "GATEWAY_LOG=${LOG}" >> "$GITHUB_ENV" - echo "GATEWAY_STATE_DIR=${STATE_DIR}" >> "$GITHUB_ENV" - - OPENSHELL_BIND_ADDRESS=0.0.0.0 \ - OPENSHELL_SERVER_PORT="${OPENSHELL_CANARY_PORT}" \ - OPENSHELL_DISABLE_TLS=true \ - OPENSHELL_DISABLE_GATEWAY_AUTH=true \ - OPENSHELL_DRIVERS=docker \ - OPENSHELL_DB_URL="sqlite:${STATE_DIR}/openshell.db?mode=rwc" \ - OPENSHELL_GRPC_ENDPOINT="http://host.openshell.internal:${OPENSHELL_CANARY_PORT}" \ - OPENSHELL_SSH_GATEWAY_HOST=127.0.0.1 \ - OPENSHELL_SSH_GATEWAY_PORT="${OPENSHELL_CANARY_PORT}" \ - OPENSHELL_SANDBOX_NAMESPACE="canary-${{ matrix.arch }}-${{ github.run_id }}" \ - nohup openshell-gateway >"${LOG}" 2>&1 & - PID=$! - echo "GATEWAY_PID=${PID}" >> "$GITHUB_ENV" - - for _ in $(seq 1 60); do - if curl -fsS "http://127.0.0.1:${OPENSHELL_CANARY_PORT}/healthz" >/dev/null; then - break - fi - if ! kill -0 "$PID" 2>/dev/null; then - echo "::error::openshell-gateway exited before becoming healthy" - cat "$LOG" - exit 1 - fi - sleep 1 - done - - curl -fsS "http://127.0.0.1:${OPENSHELL_CANARY_PORT}/healthz" - openshell gateway remove local >/dev/null 2>&1 || true - openshell gateway add "http://127.0.0.1:${OPENSHELL_CANARY_PORT}" --local --name local - - - name: Run canary test - run: | - set -euo pipefail - - echo "Creating sandbox and running 'echo hello world'..." - OUTPUT=$(openshell sandbox create --no-keep --no-tty -- echo "hello world" 2>&1) || { - EXIT_CODE=$? - echo "::error::openshell sandbox create failed with exit code ${EXIT_CODE}" - echo "$OUTPUT" - exit $EXIT_CODE - } - - echo "$OUTPUT" - - if echo "$OUTPUT" | grep -q "hello world"; then - echo "Canary test passed: 'hello world' found in output" - else - echo "::error::Canary test failed: 'hello world' not found in output" - exit 1 - fi - - - name: Stop packaged gateway - if: always() - run: | - set -euo pipefail - if [ -n "${GATEWAY_PID:-}" ]; then - kill "$GATEWAY_PID" >/dev/null 2>&1 || true - fi - if [ "${{ job.status }}" != "success" ] && [ -n "${GATEWAY_LOG:-}" ] && [ -f "$GATEWAY_LOG" ]; then - echo "Gateway log:" - cat "$GATEWAY_LOG" - fi + curl -LsSf https://raw.githubusercontent.com/NVIDIA/OpenShell/${{ github.event.workflow_run.head_sha || github.sha }}/install-dev.sh | sh + openshell status diff --git a/install-dev.sh b/install-dev.sh index edacd2b8d..a55fcbce8 100755 --- a/install-dev.sh +++ b/install-dev.sh @@ -227,26 +227,10 @@ find_deb_asset() { _arch="$2" awk -v arch="$_arch" ' - { - name = $2 - sub("^\\*", "", name) - - if (name == "openshell-dev-" arch ".deb") { - selected = name - found = 1 - exit - } - - if (fallback == "" && name ~ "^openshell_.*_" arch "\\.deb$") { - fallback = name - } - } - END { - if (found) { - print selected - } else if (fallback != "") { - print fallback - } + $2 ~ "^\\*?openshell[-_].*[-_]" arch "\\.deb$" { + sub("^\\*", "", $2) + print $2 + exit } ' "$_checksums" } @@ -313,6 +297,19 @@ patch_homebrew_formula() { sed 's/entitlements\.write <<~XML/entitlements.atomic_write <<~XML/' "$_formula_file" >"$_patched_file" mv "$_patched_file" "$_formula_file" fi + + if ! grep -q 'OPENSHELL_DRIVERS:' "$_formula_file"; then + info "patching Homebrew formula to use VM driver..." + awk ' + { + print + if ($0 ~ /^[[:space:]]*environment_variables\(/) { + print " OPENSHELL_DRIVERS: \"vm\"," + } + } + ' "$_formula_file" >"$_patched_file" + mv "$_patched_file" "$_formula_file" + fi } start_user_gateway() { @@ -329,8 +326,54 @@ start_user_gateway() { as_target_user systemctl --user restart openshell-gateway as_target_user systemctl --user is-active --quiet openshell-gateway + wait_for_local_gateway_listener info "registering local gateway as ${TARGET_USER}..." register_local_gateway + wait_for_local_gateway_status +} + +wait_for_local_gateway_listener() { + _timeout="${OPENSHELL_INSTALL_GATEWAY_TIMEOUT:-30}" + _elapsed=0 + _last_output="" + _probe_url="http://127.0.0.1:${LOCAL_GATEWAY_PORT}/" + + info "waiting for local gateway listener to become reachable..." + while [ "$_elapsed" -lt "$_timeout" ]; do + if _last_output="$(as_target_user curl -sS --max-time 2 -o /dev/null "$_probe_url" 2>&1)"; then + info "local gateway listener is reachable" + return 0 + fi + sleep 1 + _elapsed=$((_elapsed + 1)) + done + + printf '%s\n' "$_last_output" >&2 + error "local gateway listener did not become reachable at ${_probe_url} within ${_timeout}s" +} + +wait_for_local_gateway_status() { + _timeout="${OPENSHELL_INSTALL_GATEWAY_TIMEOUT:-30}" + _elapsed=0 + _status_output="" + _register_bin="${OPENSHELL_REGISTER_BIN:-openshell}" + + info "waiting for openshell status to report connected..." + while [ "$_elapsed" -lt "$_timeout" ]; do + if _status_output="$(as_target_user env NO_COLOR=1 "$_register_bin" status 2>&1)"; then + case "$_status_output" in + *"Version:"*) + info "openshell status reports connected" + return 0 + ;; + esac + fi + sleep 1 + _elapsed=$((_elapsed + 1)) + done + + printf '%s\n' "$_status_output" >&2 + error "openshell status did not report connected within ${_timeout}s" } remove_local_gateway_registration() { @@ -354,7 +397,7 @@ register_local_gateway() { _register_bin="${OPENSHELL_REGISTER_BIN:-openshell}" if _add_output="$(as_target_user "$_register_bin" gateway add "http://127.0.0.1:${LOCAL_GATEWAY_PORT}" --local --name local 2>&1)"; then - [ -z "$_add_output" ] || printf '%s\n' "$_add_output" >&2 + [ -z "$_add_output" ] || print_gateway_add_output "$_add_output" return 0 else _add_status=$? @@ -373,6 +416,16 @@ register_local_gateway() { esac } +print_gateway_add_output() { + printf '%s\n' "$1" | while IFS= read -r _line; do + case "$_line" in + *"Gateway is not reachable at http://127.0.0.1:${LOCAL_GATEWAY_PORT}"*) ;; + *"Verify the gateway is running and the endpoint is correct."*) ;; + *) printf '%s\n' "$_line" >&2 ;; + esac + done +} + install_linux_deb() { check_linux_platform @@ -466,8 +519,10 @@ install_macos_homebrew() { OPENSHELL_REGISTER_BIN="${_brew_prefix}/bin/openshell" fi + wait_for_local_gateway_listener info "registering local gateway as ${TARGET_USER}..." register_local_gateway + wait_for_local_gateway_status } main() { From 645b880513bf9e025162822d051cc3c39ee069b0 Mon Sep 17 00:00:00 2001 From: Drew Newberry Date: Fri, 8 May 2026 07:45:30 -0700 Subject: [PATCH 005/157] feat(install): add rpm dev installer support (#1262) Signed-off-by: Drew Newberry --- .github/workflows/release-canary.yml | 21 +++ crates/openshell-driver-vm/README.md | 10 +- install-dev.sh | 188 ++++++++++++++++++++++++--- 3 files changed, 199 insertions(+), 20 deletions(-) diff --git a/.github/workflows/release-canary.yml b/.github/workflows/release-canary.yml index 415a4f597..5e895efc7 100644 --- a/.github/workflows/release-canary.yml +++ b/.github/workflows/release-canary.yml @@ -46,3 +46,24 @@ jobs: run: | curl -LsSf https://raw.githubusercontent.com/NVIDIA/OpenShell/${{ github.event.workflow_run.head_sha || github.sha }}/install-dev.sh | sh openshell status + + fedora: + name: Fedora RPM + if: ${{ github.event_name == 'workflow_dispatch' || github.event.workflow_run.conclusion == 'success' }} + runs-on: linux-amd64-cpu8 + timeout-minutes: 20 + container: + image: fedora:latest + options: --privileged + steps: + - name: Ensure Podman + run: | + dnf install -y curl podman + mkdir -p "${HOME}/.config/openshell" + printf 'OPENSHELL_DRIVERS=podman\n' > "${HOME}/.config/openshell/gateway.env" + podman info + + - name: Install dev and check status + run: | + curl -LsSf https://raw.githubusercontent.com/NVIDIA/OpenShell/${{ github.event.workflow_run.head_sha || github.sha }}/install-dev.sh | sh + openshell status diff --git a/crates/openshell-driver-vm/README.md b/crates/openshell-driver-vm/README.md index 0a11ceb0a..046554c59 100644 --- a/crates/openshell-driver-vm/README.md +++ b/crates/openshell-driver-vm/README.md @@ -180,12 +180,16 @@ The VM guest's serial console is appended to `//console.l - runtime tarballs: the rolling `vm-runtime` release, rebuilt on demand by `release-vm-kernel.yml` -On Linux amd64 and arm64, `install-dev.sh` installs the Debian package from the -selected `OPENSHELL_VERSION` release tag. That package includes -`openshell-gateway` and `openshell-driver-vm`, but leaves +On Debian-family Linux amd64 and arm64 systems, `install-dev.sh` installs the +Debian package from the selected `OPENSHELL_VERSION` release tag. That package +includes `openshell-gateway` and `openshell-driver-vm`, but leaves `OPENSHELL_DRIVERS` unset so the gateway uses its normal runtime auto-detection. Set `OPENSHELL_DRIVERS=vm` to force the VM driver. +On RPM-family Linux x86_64 and aarch64 systems, `install-dev.sh` installs the +`openshell` and `openshell-gateway` RPM packages from the selected release tag. +The RPM gateway package is configured for the Podman driver. + On Apple Silicon macOS, `install-dev.sh` stages the generated `openshell.rb` formula from the selected release in the `nvidia/openshell` Homebrew tap. Homebrew installs `openshell`, `openshell-gateway`, and diff --git a/install-dev.sh b/install-dev.sh index a55fcbce8..87234a4eb 100755 --- a/install-dev.sh +++ b/install-dev.sh @@ -4,9 +4,9 @@ # # Install the OpenShell development build from a GitHub release. # -# Linux keeps the Debian package install path. Apple Silicon macOS installs the -# generated Homebrew formula from the selected release, so Homebrew owns the -# binary layout and launchd service lifecycle. +# Linux installs either the Debian or RPM packages from the selected release. +# Apple Silicon macOS installs the generated Homebrew formula, so Homebrew owns +# the binary layout and launchd service lifecycle. # set -e @@ -52,7 +52,8 @@ NOTES: This installs the selected release from: ${GITHUB_URL}/releases/tag/${RELEASE_TAG} - Linux installs the Debian package on amd64 and arm64. + Linux installs the Debian package on amd64/arm64 or the RPM packages on + x86_64/aarch64, depending on the host package manager. macOS installs the release Homebrew formula on Apple Silicon and starts a brew services-backed local gateway. EOF @@ -84,10 +85,10 @@ download_release_asset() { return 0 fi - # GitHub normalizes `~` to `.` in release asset names, while the checksum file - # still records the Debian package filename with `~dev` for correct version - # ordering. Download the normalized asset but verify it against the checksum - # entry for the original package filename. + # GitHub normalizes `~` to `.` in release asset names, while checksum files + # can still record package filenames with `~dev` for correct version ordering. + # Download the normalized asset but verify it against the checksum entry for + # the original package filename. _normalized="$(printf '%s' "$_filename" | tr '~' '.')" if [ "$_normalized" != "$_filename" ]; then if download "${GITHUB_URL}/releases/download/${_tag}/${_normalized}" "$_output"; then @@ -186,7 +187,25 @@ detect_platform() { esac } -check_linux_platform() { +linux_package_method() { + if has_cmd dpkg; then + echo "deb" + elif has_cmd rpm; then + echo "rpm" + else + error "Linux dev installs require either dpkg or rpm" + fi +} + +set_linux_target_runtime_dir() { + if [ "$(id -u)" -eq "$TARGET_UID" ] && [ -n "${XDG_RUNTIME_DIR:-}" ]; then + TARGET_RUNTIME_DIR="$XDG_RUNTIME_DIR" + else + TARGET_RUNTIME_DIR="/run/user/${TARGET_UID}" + fi +} + +check_linux_deb_platform() { require_cmd dpkg } @@ -222,6 +241,30 @@ get_deb_arch() { esac } +get_rpm_arch() { + if has_cmd rpm; then + _arch="$(rpm --eval '%{_arch}' 2>/dev/null || true)" + else + _arch="" + fi + + if [ -z "$_arch" ]; then + _arch="$(uname -m)" + fi + + case "$_arch" in + x86_64|amd64) + echo "x86_64" + ;; + aarch64|arm64) + echo "aarch64" + ;; + *) + error "no dev RPM package is published for architecture: ${_arch}" + ;; + esac +} + find_deb_asset() { _checksums="$1" _arch="$2" @@ -235,6 +278,50 @@ find_deb_asset() { ' "$_checksums" } +find_rpm_asset() { + _checksums="$1" + _arch="$2" + _package="$3" + + case "$_package" in + openshell) + _dev_name="openshell-dev-${_arch}.rpm" + _fallback_re="^openshell-[0-9].*\\.${_arch}\\.rpm$" + ;; + openshell-gateway) + _dev_name="openshell-gateway-dev-${_arch}.rpm" + _fallback_re="^openshell-gateway-[0-9].*\\.${_arch}\\.rpm$" + ;; + *) + error "unknown RPM package selector: ${_package}" + ;; + esac + + awk -v dev_name="$_dev_name" -v fallback_re="$_fallback_re" ' + { + name = $2 + sub("^\\*", "", name) + + if (name == dev_name) { + selected = name + found = 1 + exit + } + + if (fallback == "" && name ~ fallback_re) { + fallback = name + } + } + END { + if (found) { + print selected + } else if (fallback != "") { + print fallback + } + } + ' "$_checksums" +} + verify_checksum() { _archive="$1" _checksums="$2" @@ -271,6 +358,21 @@ install_deb_package() { fi } +install_rpm_packages() { + if has_cmd dnf; then + as_root dnf install -y "$@" + elif has_cmd yum; then + as_root yum install -y "$@" + elif has_cmd zypper; then + as_root zypper --non-interactive install --allow-unsigned-rpm "$@" + elif has_cmd rpm; then + warn "installing with rpm directly; dependencies must already be installed" + as_root rpm -Uvh --replacepkgs "$@" + else + error "'dnf', 'yum', 'zypper', or 'rpm' is required to install RPM packages" + fi +} + homebrew_formula_path() { _tap="$1" _formula="$2" @@ -427,13 +529,8 @@ print_gateway_add_output() { } install_linux_deb() { - check_linux_platform - - if [ "$(id -u)" -eq "$TARGET_UID" ] && [ -n "${XDG_RUNTIME_DIR:-}" ]; then - TARGET_RUNTIME_DIR="$XDG_RUNTIME_DIR" - else - TARGET_RUNTIME_DIR="/run/user/${TARGET_UID}" - fi + check_linux_deb_platform + set_linux_target_runtime_dir _arch="$(get_deb_arch)" _tmpdir="$(mktemp -d)" @@ -471,6 +568,53 @@ install_linux_deb() { start_user_gateway } +install_linux_rpm() { + require_cmd rpm + set_linux_target_runtime_dir + + _arch="$(get_rpm_arch)" + _tmpdir="$(mktemp -d)" + chmod 0755 "$_tmpdir" + trap 'rm -rf "$_tmpdir"' EXIT + + _checksums_url="${GITHUB_URL}/releases/download/${RELEASE_TAG}/${CHECKSUMS_NAME}" + info "downloading ${RELEASE_TAG} release checksums..." + download "$_checksums_url" "${_tmpdir}/${CHECKSUMS_NAME}" || { + error "failed to download ${_checksums_url}" + } + + _rpm_file="$(find_rpm_asset "${_tmpdir}/${CHECKSUMS_NAME}" "$_arch" openshell)" + if [ -z "$_rpm_file" ]; then + error "no dev openshell RPM package found for architecture: ${_arch}" + fi + + _gateway_rpm_file="$(find_rpm_asset "${_tmpdir}/${CHECKSUMS_NAME}" "$_arch" openshell-gateway)" + if [ -z "$_gateway_rpm_file" ]; then + error "no dev openshell-gateway RPM package found for architecture: ${_arch}" + fi + + info "selected ${_rpm_file} and ${_gateway_rpm_file}" + + for _package_file in "$_rpm_file" "$_gateway_rpm_file"; do + _package_url="${GITHUB_URL}/releases/download/${RELEASE_TAG}/${_package_file}" + _package_path="${_tmpdir}/${_package_file}" + + info "downloading ${_package_file}..." + download_release_asset "$RELEASE_TAG" "$_package_file" "$_package_path" || { + error "failed to download ${_package_url}" + } + chmod 0644 "$_package_path" + + info "verifying checksum for ${_package_file}..." + verify_checksum "$_package_path" "${_tmpdir}/${CHECKSUMS_NAME}" "$_package_file" + done + + info "installing ${_rpm_file} and ${_gateway_rpm_file}..." + install_rpm_packages "${_tmpdir}/${_rpm_file}" "${_tmpdir}/${_gateway_rpm_file}" + info "installed ${APP_NAME} RPM packages from ${RELEASE_TAG}" + start_user_gateway +} + install_macos_homebrew() { check_macos_platform @@ -548,7 +692,17 @@ main() { case "$PLATFORM" in linux) - install_linux_deb + case "$(linux_package_method)" in + deb) + install_linux_deb + ;; + rpm) + install_linux_rpm + ;; + *) + error "unsupported Linux package method" + ;; + esac ;; darwin) install_macos_homebrew From 1f35abbefbacb8b1e2547a8574bf5651de1bc879 Mon Sep 17 00:00:00 2001 From: Mrunal Patel Date: Fri, 8 May 2026 09:39:48 -0700 Subject: [PATCH 006/157] feat(sandbox): add Kubernetes user namespace isolation (hostUsers: false) (#983) Add opt-in support for Kubernetes user namespace isolation on sandbox pods. When enabled, container UID 0 maps to an unprivileged host UID and capabilities become namespaced, providing defense-in-depth for the supervisor process. Configuration is two-layered: a cluster-wide default via OPENSHELL_ENABLE_USER_NAMESPACES (default false) and a per-sandbox override via the new `user_namespaces` field on SandboxTemplate. When user namespaces are active, the pod security context is extended with SETUID, SETGID, and DAC_READ_SEARCH capabilities to match the bounding-set requirements inside a user namespace. Introduces SandboxPodParams struct to replace long argument lists on sandbox_to_k8s_spec and sandbox_template_to_k8s. Validated end-to-end on OCP 4.22 (K8s 1.35.3, CRI-O 1.35, RHEL CoreOS, kernel 5.14) with full SSH tunnel and non-identity UID mapping. --- crates/openshell-core/src/config.rs | 9 + .../openshell-driver-kubernetes/src/config.rs | 1 + .../openshell-driver-kubernetes/src/driver.rs | 567 +++++++++++------- .../openshell-driver-kubernetes/src/main.rs | 4 + crates/openshell-server/src/cli.rs | 7 + crates/openshell-server/src/compute/mod.rs | 55 ++ crates/openshell-server/src/lib.rs | 7 +- .../helm/openshell/templates/statefulset.yaml | 4 + deploy/helm/openshell/values.yaml | 6 + docs/security/best-practices.mdx | 14 + e2e/rust/tests/user_namespaces.rs | 190 ++++++ proto/openshell.proto | 6 + 12 files changed, 637 insertions(+), 233 deletions(-) create mode 100644 e2e/rust/tests/user_namespaces.rs diff --git a/crates/openshell-core/src/config.rs b/crates/openshell-core/src/config.rs index 3f0d34b0f..a2a973011 100644 --- a/crates/openshell-core/src/config.rs +++ b/crates/openshell-core/src/config.rs @@ -297,6 +297,14 @@ pub struct Config { /// allowing them to reach services running on the Docker host. #[serde(default)] pub host_gateway_ip: String, + + /// Enable Kubernetes user namespace isolation (`hostUsers: false`) for + /// sandbox pods. When enabled, container UID 0 maps to an unprivileged + /// host UID and capabilities become namespaced. Requires Kubernetes 1.33+ + /// with user namespace support available (beta through 1.35, GA in 1.36+), + /// plus a supporting container runtime and Linux 5.12+. + #[serde(default)] + pub enable_user_namespaces: bool, } /// TLS configuration. @@ -410,6 +418,7 @@ impl Config { ssh_session_ttl_secs: default_ssh_session_ttl_secs(), client_tls_secret_name: String::new(), host_gateway_ip: String::new(), + enable_user_namespaces: false, } } diff --git a/crates/openshell-driver-kubernetes/src/config.rs b/crates/openshell-driver-kubernetes/src/config.rs index 838262c77..cbf55423d 100644 --- a/crates/openshell-driver-kubernetes/src/config.rs +++ b/crates/openshell-driver-kubernetes/src/config.rs @@ -19,4 +19,5 @@ pub struct KubernetesComputeConfig { pub ssh_handshake_skew_secs: u64, pub client_tls_secret_name: String, pub host_gateway_ip: String, + pub enable_user_namespaces: bool, } diff --git a/crates/openshell-driver-kubernetes/src/driver.rs b/crates/openshell-driver-kubernetes/src/driver.rs index 6c855f63e..668a18d8c 100644 --- a/crates/openshell-driver-kubernetes/src/driver.rs +++ b/crates/openshell-driver-kubernetes/src/driver.rs @@ -308,21 +308,22 @@ impl KubernetesComputeDriver { labels: Some(sandbox_labels(sandbox)), ..Default::default() }; - obj.data = sandbox_to_k8s_spec( - sandbox.spec.as_ref(), - &self.config.default_image, - &self.config.image_pull_policy, - &self.config.supervisor_image, - &self.config.supervisor_image_pull_policy, - &sandbox.id, - &sandbox.name, - &self.config.grpc_endpoint, - self.ssh_socket_path(), - self.ssh_handshake_secret(), - self.ssh_handshake_skew_secs(), - &self.config.client_tls_secret_name, - &self.config.host_gateway_ip, - ); + let params = SandboxPodParams { + default_image: &self.config.default_image, + image_pull_policy: &self.config.image_pull_policy, + supervisor_image: &self.config.supervisor_image, + supervisor_image_pull_policy: &self.config.supervisor_image_pull_policy, + sandbox_id: &sandbox.id, + sandbox_name: &sandbox.name, + grpc_endpoint: &self.config.grpc_endpoint, + ssh_socket_path: self.ssh_socket_path(), + ssh_handshake_secret: self.ssh_handshake_secret(), + ssh_handshake_skew_secs: self.ssh_handshake_skew_secs(), + client_tls_secret_name: &self.config.client_tls_secret_name, + host_gateway_ip: &self.config.host_gateway_ip, + enable_user_namespaces: self.config.enable_user_namespaces, + }; + obj.data = sandbox_to_k8s_spec(sandbox.spec.as_ref(), ¶ms); let api = self.api(); match tokio::time::timeout(KUBE_API_TIMEOUT, api.create(&PostParams::default(), &obj)).await @@ -926,21 +927,27 @@ fn default_workspace_volume_claim_templates() -> serde_json::Value { }]) } -#[allow(clippy::too_many_arguments)] +/// Parameters shared by `sandbox_to_k8s_spec` and `sandbox_template_to_k8s`. +#[derive(Default)] +struct SandboxPodParams<'a> { + default_image: &'a str, + image_pull_policy: &'a str, + supervisor_image: &'a str, + supervisor_image_pull_policy: &'a str, + sandbox_id: &'a str, + sandbox_name: &'a str, + grpc_endpoint: &'a str, + ssh_socket_path: &'a str, + ssh_handshake_secret: &'a str, + ssh_handshake_skew_secs: u64, + client_tls_secret_name: &'a str, + host_gateway_ip: &'a str, + enable_user_namespaces: bool, +} + fn sandbox_to_k8s_spec( spec: Option<&SandboxSpec>, - default_image: &str, - image_pull_policy: &str, - supervisor_image: &str, - supervisor_image_pull_policy: &str, - sandbox_id: &str, - sandbox_name: &str, - grpc_endpoint: &str, - ssh_socket_path: &str, - ssh_handshake_secret: &str, - ssh_handshake_skew_secs: u64, - client_tls_secret_name: &str, - host_gateway_ip: &str, + params: &SandboxPodParams<'_>, ) -> serde_json::Value { let mut root = serde_json::Map::new(); @@ -971,20 +978,9 @@ fn sandbox_to_k8s_spec( sandbox_template_to_k8s( template, spec.gpu, - default_image, - image_pull_policy, - supervisor_image, - supervisor_image_pull_policy, - sandbox_id, - sandbox_name, - grpc_endpoint, - ssh_socket_path, - ssh_handshake_secret, - ssh_handshake_skew_secs, &spec.environment, - client_tls_secret_name, - host_gateway_ip, inject_workspace, + params, ), ); if !template.agent_socket_path.is_empty() { @@ -1019,20 +1015,9 @@ fn sandbox_to_k8s_spec( sandbox_template_to_k8s( &SandboxTemplate::default(), spec.as_ref().is_some_and(|s| s.gpu), - default_image, - image_pull_policy, - supervisor_image, - supervisor_image_pull_policy, - sandbox_id, - sandbox_name, - grpc_endpoint, - ssh_socket_path, - ssh_handshake_secret, - ssh_handshake_skew_secs, spec_env, - client_tls_secret_name, - host_gateway_ip, inject_workspace, + params, ), ); } @@ -1042,24 +1027,12 @@ fn sandbox_to_k8s_spec( ) } -#[allow(clippy::too_many_arguments)] fn sandbox_template_to_k8s( template: &SandboxTemplate, gpu: bool, - default_image: &str, - image_pull_policy: &str, - supervisor_image: &str, - supervisor_image_pull_policy: &str, - sandbox_id: &str, - sandbox_name: &str, - grpc_endpoint: &str, - ssh_socket_path: &str, - ssh_handshake_secret: &str, - ssh_handshake_skew_secs: u64, spec_environment: &std::collections::HashMap, - client_tls_secret_name: &str, - host_gateway_ip: &str, inject_workspace: bool, + params: &SandboxPodParams<'_>, ) -> serde_json::Value { let mut metadata = serde_json::Map::new(); if !template.labels.is_empty() { @@ -1077,20 +1050,34 @@ fn sandbox_template_to_k8s( ); } + // Per-sandbox platform_config.host_users overrides the cluster-wide default. + let use_user_namespaces = platform_config_bool(template, "host_users") + .map_or(params.enable_user_namespaces, |host_users| !host_users); + + if use_user_namespaces { + spec.insert("hostUsers".to_string(), serde_json::json!(false)); + if gpu { + warn!( + "GPU sandbox with user namespaces enabled — \ + NVIDIA device plugin compatibility is unverified" + ); + } + } + let mut container = serde_json::Map::new(); container.insert("name".to_string(), serde_json::json!("agent")); // Use template image if provided, otherwise fall back to default let image = if template.image.is_empty() { - default_image + params.default_image } else { &template.image }; if !image.is_empty() { container.insert("image".to_string(), serde_json::json!(image)); - if !image_pull_policy.is_empty() { + if !params.image_pull_policy.is_empty() { container.insert( "imagePullPolicy".to_string(), - serde_json::json!(image_pull_policy), + serde_json::json!(params.image_pull_policy), ); } } @@ -1100,34 +1087,36 @@ fn sandbox_template_to_k8s( None, &template.environment, spec_environment, - sandbox_id, - sandbox_name, - grpc_endpoint, - ssh_socket_path, - ssh_handshake_secret, - ssh_handshake_skew_secs, - !client_tls_secret_name.is_empty(), + params.sandbox_id, + params.sandbox_name, + params.grpc_endpoint, + params.ssh_socket_path, + params.ssh_handshake_secret, + params.ssh_handshake_skew_secs, + !params.client_tls_secret_name.is_empty(), ); container.insert("env".to_string(), serde_json::Value::Array(env)); - // The sandbox process needs SYS_ADMIN (for seccomp filter installation and - // network namespace creation), NET_ADMIN (for network namespace veth setup), - // SYS_PTRACE (for the CONNECT proxy to read /proc//fd/ of sandbox-user - // processes to resolve binary identity for network policy enforcement), - // and SYSLOG (for reading /dev/kmsg to surface bypass detection diagnostics). - // This mirrors the capabilities used by `mise run sandbox`. + let mut capabilities: Vec<&str> = vec!["SYS_ADMIN", "NET_ADMIN", "SYS_PTRACE", "SYSLOG"]; + if use_user_namespaces { + // In a user namespace the bounding set is reset. SETUID/SETGID are + // needed for the supervisor to drop privileges to the sandbox user. + // DAC_READ_SEARCH is needed for cross-UID /proc//fd/ access + // for process identity resolution in network policy enforcement. + capabilities.extend(["SETUID", "SETGID", "DAC_READ_SEARCH"]); + } container.insert( "securityContext".to_string(), serde_json::json!({ "capabilities": { - "add": ["SYS_ADMIN", "NET_ADMIN", "SYS_PTRACE", "SYSLOG"] + "add": capabilities } }), ); // Mount client TLS secret for mTLS to the server. - if !client_tls_secret_name.is_empty() { + if !params.client_tls_secret_name.is_empty() { container.insert( "volumeMounts".to_string(), serde_json::json!([{ @@ -1148,22 +1137,22 @@ fn sandbox_template_to_k8s( // Add TLS secret volume. Mode 0400 (owner-read) prevents the // unprivileged sandbox user from reading the mTLS private key. - if !client_tls_secret_name.is_empty() { + if !params.client_tls_secret_name.is_empty() { spec.insert( "volumes".to_string(), serde_json::json!([{ "name": "openshell-client-tls", - "secret": { "secretName": client_tls_secret_name, "defaultMode": 256 } + "secret": { "secretName": params.client_tls_secret_name, "defaultMode": 256 } }]), ); } // Add hostAliases so sandbox pods can reach the Docker host. - if !host_gateway_ip.is_empty() { + if !params.host_gateway_ip.is_empty() { spec.insert( "hostAliases".to_string(), serde_json::json!([{ - "ip": host_gateway_ip, + "ip": params.host_gateway_ip, "hostnames": ["host.docker.internal", "host.openshell.internal"] }]), ); @@ -1179,13 +1168,17 @@ fn sandbox_template_to_k8s( // Side-load the supervisor binary via an init container that copies it // from the supervisor image into a shared emptyDir volume. - apply_supervisor_sideload(&mut result, supervisor_image, supervisor_image_pull_policy); + apply_supervisor_sideload( + &mut result, + params.supervisor_image, + params.supervisor_image_pull_policy, + ); // Inject workspace persistence (init container + PVC volume mount) so // that /sandbox data survives pod rescheduling. Skipped when the user // provides custom volumeClaimTemplates to avoid conflicts. if inject_workspace { - apply_workspace_persistence(&mut result, image, image_pull_policy); + apply_workspace_persistence(&mut result, image, params.image_pull_policy); } result @@ -1346,6 +1339,15 @@ fn platform_config_string(template: &SandboxTemplate, key: &str) -> Option Option { + let config = template.platform_config.as_ref()?; + let value = config.fields.get(key)?; + match value.kind.as_ref() { + Some(prost_types::value::Kind::BoolValue(b)) => Some(*b), + _ => None, + } +} + /// Extract a nested Struct value from the template's `platform_config`, /// converting it to `serde_json::Value`. fn platform_config_struct(template: &SandboxTemplate, key: &str) -> Option { @@ -1645,24 +1647,16 @@ mod tests { #[test] fn gpu_sandbox_adds_runtime_class_and_gpu_limit() { - let pod_template = sandbox_template_to_k8s( - &SandboxTemplate::default(), - true, - "openshell/sandbox:latest", - "", - "openshell/supervisor:latest", - "", - "sandbox-id", - "sandbox-name", - "https://gateway.example.com", - "0.0.0.0:2222", - "secret", - 300, - &std::collections::HashMap::new(), - "", - "", - true, - ); + let pod_template = { + let params = SandboxPodParams::default(); + sandbox_template_to_k8s( + &SandboxTemplate::default(), + true, + &std::collections::HashMap::new(), + true, + ¶ms, + ) + }; assert_eq!( pod_template["spec"]["runtimeClassName"], @@ -1689,24 +1683,16 @@ mod tests { ..SandboxTemplate::default() }; - let pod_template = sandbox_template_to_k8s( - &template, - true, - "openshell/sandbox:latest", - "", - "openshell/supervisor:latest", - "", - "sandbox-id", - "sandbox-name", - "https://gateway.example.com", - "0.0.0.0:2222", - "secret", - 300, - &std::collections::HashMap::new(), - "", - "", - true, - ); + let pod_template = { + let params = SandboxPodParams::default(); + sandbox_template_to_k8s( + &template, + true, + &std::collections::HashMap::new(), + true, + ¶ms, + ) + }; assert_eq!( pod_template["spec"]["runtimeClassName"], @@ -1729,24 +1715,16 @@ mod tests { ..SandboxTemplate::default() }; - let pod_template = sandbox_template_to_k8s( - &template, - false, - "openshell/sandbox:latest", - "", - "openshell/supervisor:latest", - "", - "sandbox-id", - "sandbox-name", - "https://gateway.example.com", - "0.0.0.0:2222", - "secret", - 300, - &std::collections::HashMap::new(), - "", - "", - true, - ); + let pod_template = { + let params = SandboxPodParams::default(); + sandbox_template_to_k8s( + &template, + false, + &std::collections::HashMap::new(), + true, + ¶ms, + ) + }; assert_eq!( pod_template["spec"]["runtimeClassName"], @@ -1765,24 +1743,16 @@ mod tests { ..SandboxTemplate::default() }; - let pod_template = sandbox_template_to_k8s( - &template, - true, - "openshell/sandbox:latest", - "", - "openshell/supervisor:latest", - "", - "sandbox-id", - "sandbox-name", - "https://gateway.example.com", - "0.0.0.0:2222", - "secret", - 300, - &std::collections::HashMap::new(), - "", - "", - true, - ); + let pod_template = { + let params = SandboxPodParams::default(); + sandbox_template_to_k8s( + &template, + true, + &std::collections::HashMap::new(), + true, + ¶ms, + ) + }; let limits = &pod_template["spec"]["containers"][0]["resources"]["limits"]; assert_eq!(limits["cpu"], serde_json::json!("2")); @@ -1794,24 +1764,19 @@ mod tests { #[test] fn host_aliases_injected_when_gateway_ip_set() { - let pod_template = sandbox_template_to_k8s( - &SandboxTemplate::default(), - false, - "openshell/sandbox:latest", - "", - "openshell/supervisor:latest", - "", - "sandbox-id", - "sandbox-name", - "https://gateway.example.com", - "0.0.0.0:2222", - "secret", - 300, - &std::collections::HashMap::new(), - "", - "172.17.0.1", - true, - ); + let pod_template = { + let params = SandboxPodParams { + host_gateway_ip: "172.17.0.1", + ..Default::default() + }; + sandbox_template_to_k8s( + &SandboxTemplate::default(), + false, + &std::collections::HashMap::new(), + true, + ¶ms, + ) + }; let host_aliases = pod_template["spec"]["hostAliases"] .as_array() @@ -1827,24 +1792,16 @@ mod tests { #[test] fn host_aliases_not_injected_when_gateway_ip_empty() { - let pod_template = sandbox_template_to_k8s( - &SandboxTemplate::default(), - false, - "openshell/sandbox:latest", - "", - "openshell/supervisor:latest", - "", - "sandbox-id", - "sandbox-name", - "https://gateway.example.com", - "0.0.0.0:2222", - "secret", - 300, - &std::collections::HashMap::new(), - "", - "", - true, - ); + let pod_template = { + let params = SandboxPodParams::default(); + sandbox_template_to_k8s( + &SandboxTemplate::default(), + false, + &std::collections::HashMap::new(), + true, + ¶ms, + ) + }; assert!( pod_template["spec"]["hostAliases"].is_null(), @@ -1855,24 +1812,19 @@ mod tests { #[test] fn tls_secret_volume_uses_restrictive_default_mode() { let template = SandboxTemplate::default(); - let pod_template = sandbox_template_to_k8s( - &template, - false, - "openshell/sandbox:latest", - "", - "openshell/supervisor:latest", - "", - "sandbox-id", - "sandbox-name", - "https://gateway.example.com", - "0.0.0.0:2222", - "secret", - 300, - &std::collections::HashMap::new(), - "my-tls-secret", - "", - true, - ); + let pod_template = { + let params = SandboxPodParams { + client_tls_secret_name: "my-tls-secret", + ..Default::default() + }; + sandbox_template_to_k8s( + &template, + false, + &std::collections::HashMap::new(), + true, + ¶ms, + ) + }; let volumes = pod_template["spec"]["volumes"] .as_array() @@ -2000,23 +1952,13 @@ mod tests { #[test] fn workspace_persistence_skipped_when_inject_workspace_false() { + let params = SandboxPodParams::default(); let pod_template = sandbox_template_to_k8s( &SandboxTemplate::default(), false, - "openshell/sandbox:latest", - "", - "openshell/supervisor:latest", - "", - "sandbox-id", - "sandbox-name", - "https://gateway.example.com", - "0.0.0.0:2222", - "secret", - 300, &std::collections::HashMap::new(), - "", - "", false, // user provided custom VCTs + ¶ms, ); // Only the supervisor init container should be present — no workspace init container @@ -2039,4 +1981,175 @@ mod tests { "workspace mount must NOT be present when inject_workspace is false" ); } + + // ----------------------------------------------------------------------- + // User namespace tests + // ----------------------------------------------------------------------- + + fn default_template_to_k8s(enable_user_namespaces: bool) -> serde_json::Value { + let params = SandboxPodParams { + enable_user_namespaces, + ..Default::default() + }; + sandbox_template_to_k8s( + &SandboxTemplate::default(), + false, + &std::collections::HashMap::new(), + true, + ¶ms, + ) + } + + #[test] + fn user_namespaces_disabled_by_default() { + let pod_template = default_template_to_k8s(false); + assert!( + pod_template["spec"]["hostUsers"].is_null(), + "hostUsers must not be set when user namespaces are disabled" + ); + let caps = pod_template["spec"]["containers"][0]["securityContext"]["capabilities"]["add"] + .as_array() + .unwrap(); + assert_eq!(caps.len(), 4); + assert!(!caps.contains(&serde_json::json!("SETUID"))); + } + + #[test] + fn user_namespaces_enabled_by_cluster_default() { + let pod_template = default_template_to_k8s(true); + assert_eq!( + pod_template["spec"]["hostUsers"], + serde_json::json!(false), + "hostUsers must be false when user namespaces are enabled" + ); + } + + #[test] + fn user_namespaces_adds_extra_capabilities() { + let pod_template = default_template_to_k8s(true); + let caps = pod_template["spec"]["containers"][0]["securityContext"]["capabilities"]["add"] + .as_array() + .unwrap(); + assert!(caps.contains(&serde_json::json!("SYS_ADMIN"))); + assert!(caps.contains(&serde_json::json!("NET_ADMIN"))); + assert!(caps.contains(&serde_json::json!("SYS_PTRACE"))); + assert!(caps.contains(&serde_json::json!("SYSLOG"))); + assert!(caps.contains(&serde_json::json!("SETUID"))); + assert!(caps.contains(&serde_json::json!("SETGID"))); + assert!(caps.contains(&serde_json::json!("DAC_READ_SEARCH"))); + assert_eq!(caps.len(), 7); + } + + #[test] + fn user_namespaces_per_sandbox_override_enables() { + let template = SandboxTemplate { + platform_config: Some(Struct { + fields: std::iter::once(( + "host_users".to_string(), + Value { + kind: Some(Kind::BoolValue(false)), + }, + )) + .collect(), + }), + ..SandboxTemplate::default() + }; + + let params = SandboxPodParams::default(); // cluster default is off + let pod_template = sandbox_template_to_k8s( + &template, + false, + &std::collections::HashMap::new(), + true, + ¶ms, + ); + + assert_eq!( + pod_template["spec"]["hostUsers"], + serde_json::json!(false), + "per-sandbox host_users: false must enable user namespaces" + ); + let caps = pod_template["spec"]["containers"][0]["securityContext"]["capabilities"]["add"] + .as_array() + .unwrap(); + assert!(caps.contains(&serde_json::json!("SETUID"))); + } + + #[test] + fn user_namespaces_per_sandbox_override_disables() { + let template = SandboxTemplate { + platform_config: Some(Struct { + fields: std::iter::once(( + "host_users".to_string(), + Value { + kind: Some(Kind::BoolValue(true)), + }, + )) + .collect(), + }), + ..SandboxTemplate::default() + }; + + let params = SandboxPodParams { + enable_user_namespaces: true, // cluster default is on + ..Default::default() + }; + let pod_template = sandbox_template_to_k8s( + &template, + false, + &std::collections::HashMap::new(), + true, + ¶ms, + ); + + assert!( + pod_template["spec"]["hostUsers"].is_null(), + "per-sandbox host_users: true must disable user namespaces even when cluster default is on" + ); + let caps = pod_template["spec"]["containers"][0]["securityContext"]["capabilities"]["add"] + .as_array() + .unwrap(); + assert_eq!( + caps.len(), + 4, + "extra capabilities must not be added when user namespaces are disabled" + ); + } + + #[test] + fn platform_config_bool_extracts_value() { + let template = SandboxTemplate { + platform_config: Some(Struct { + fields: std::iter::once(( + "my_bool".to_string(), + Value { + kind: Some(Kind::BoolValue(true)), + }, + )) + .collect(), + }), + ..SandboxTemplate::default() + }; + + assert_eq!(platform_config_bool(&template, "my_bool"), Some(true)); + assert_eq!(platform_config_bool(&template, "missing"), None); + } + + #[test] + fn platform_config_bool_returns_none_for_non_bool() { + let template = SandboxTemplate { + platform_config: Some(Struct { + fields: std::iter::once(( + "a_string".to_string(), + Value { + kind: Some(Kind::StringValue("hello".to_string())), + }, + )) + .collect(), + }), + ..SandboxTemplate::default() + }; + + assert_eq!(platform_config_bool(&template, "a_string"), None); + } } diff --git a/crates/openshell-driver-kubernetes/src/main.rs b/crates/openshell-driver-kubernetes/src/main.rs index 26d323f56..9e39e9d28 100644 --- a/crates/openshell-driver-kubernetes/src/main.rs +++ b/crates/openshell-driver-kubernetes/src/main.rs @@ -63,6 +63,9 @@ struct Args { #[arg(long, env = "OPENSHELL_SUPERVISOR_IMAGE_PULL_POLICY")] supervisor_image_pull_policy: Option, + + #[arg(long, env = "OPENSHELL_ENABLE_USER_NAMESPACES")] + enable_user_namespaces: bool, } #[tokio::main] @@ -88,6 +91,7 @@ async fn main() -> Result<()> { ssh_handshake_skew_secs: args.ssh_handshake_skew_secs, client_tls_secret_name: args.client_tls_secret_name.unwrap_or_default(), host_gateway_ip: args.host_gateway_ip.unwrap_or_default(), + enable_user_namespaces: args.enable_user_namespaces, }) .await .into_diagnostic()?; diff --git a/crates/openshell-server/src/cli.rs b/crates/openshell-server/src/cli.rs index ccc08cf2b..08180bc5e 100644 --- a/crates/openshell-server/src/cli.rs +++ b/crates/openshell-server/src/cli.rs @@ -220,6 +220,11 @@ struct Args { )] docker_network_name: String, + /// Enable Kubernetes user namespace isolation (hostUsers: false) for + /// sandbox pods. + #[arg(long, env = "OPENSHELL_ENABLE_USER_NAMESPACES")] + enable_user_namespaces: bool, + /// Disable TLS entirely — listen on plaintext HTTP. /// Use this when the gateway sits behind a reverse proxy or tunnel /// (e.g. Cloudflare Tunnel) that terminates TLS at the edge. @@ -404,6 +409,8 @@ async fn run_from_args(args: Args) -> Result<()> { }); } + config.enable_user_namespaces = args.enable_user_namespaces; + let vm_config = VmComputeConfig { state_dir: args.vm_driver_state_dir, driver_dir: args.driver_dir, diff --git a/crates/openshell-server/src/compute/mod.rs b/crates/openshell-server/src/compute/mod.rs index 2d6351637..be8ebe2ba 100644 --- a/crates/openshell-server/src/compute/mod.rs +++ b/crates/openshell-server/src/compute/mod.rs @@ -1230,6 +1230,19 @@ fn build_platform_config(template: &SandboxTemplate) -> Option/exe` (the kernel-trusted executable path). diff --git a/e2e/rust/tests/user_namespaces.rs b/e2e/rust/tests/user_namespaces.rs new file mode 100644 index 000000000..9aa714767 --- /dev/null +++ b/e2e/rust/tests/user_namespaces.rs @@ -0,0 +1,190 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#![cfg(feature = "e2e")] + +//! E2E test: verify Kubernetes user namespace pod spec generation. +//! +//! Enables `OPENSHELL_ENABLE_USER_NAMESPACES` on the gateway, triggers sandbox +//! creation, and inspects the resulting pod spec to confirm: +//! 1. `spec.hostUsers` is `false` +//! 2. The container security context includes the extra capabilities +//! (SETUID, SETGID, DAC_READ_SEARCH) required for user namespace operation +//! +//! The sandbox pod may fail to start in Docker-in-Docker dev clusters where the +//! filesystem does not support ID-mapped mounts. The test inspects the pod spec +//! regardless of runtime success. + +use std::process::Stdio; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; + +use openshell_e2e::harness::binary::openshell_cmd; +use tokio::process::Child; + +async fn kubectl(args: &[&str]) -> Result { + let output = tokio::process::Command::new("docker") + .args(["exec", "openshell-cluster-openshell", "kubectl"]) + .args(args) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .output() + .await + .map_err(|e| format!("failed to run kubectl: {e}"))?; + + let stdout = String::from_utf8_lossy(&output.stdout).to_string(); + let stderr = String::from_utf8_lossy(&output.stderr).to_string(); + + if !output.status.success() { + return Err(format!("kubectl {args:?} failed: {stdout}{stderr}")); + } + Ok(stdout) +} + +async fn set_user_namespaces(enable: bool) -> Result<(), String> { + let env_arg = if enable { + "OPENSHELL_ENABLE_USER_NAMESPACES=true" + } else { + "OPENSHELL_ENABLE_USER_NAMESPACES-" + }; + + kubectl(&[ + "set", "env", "statefulset/openshell", + "-n", "openshell", env_arg, + ]).await?; + + kubectl(&[ + "rollout", "status", "statefulset/openshell", + "-n", "openshell", "--timeout=120s", + ]).await?; + + // Give the gateway time to fully initialize after rollout. + tokio::time::sleep(Duration::from_secs(5)).await; + + Ok(()) +} + +async fn delete_sandbox(name: &str) { + let _ = kubectl(&["delete", "sandbox", name, "-n", "openshell"]).await; +} + +fn unique_sandbox_name() -> String { + let suffix = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_millis(); + format!("userns-e2e-{suffix}") +} + +async fn stop_child(child: &mut Child) { + let _ = child.kill().await; + let _ = child.wait().await; +} + +async fn wait_for_sandbox(name: &str, timeout_secs: u64) -> Result<(), String> { + let deadline = tokio::time::Instant::now() + Duration::from_secs(timeout_secs); + while tokio::time::Instant::now() < deadline { + if let Ok(n) = kubectl(&[ + "get", "sandbox", name, "-n", "openshell", + "-o", "jsonpath={.metadata.name}", + ]).await { + if !n.trim().is_empty() { + return Ok(()); + } + } + tokio::time::sleep(Duration::from_secs(2)).await; + } + Err(format!("sandbox {name} did not appear within {timeout_secs}s")) +} + +/// Find a sandbox pod by its sandbox CRD name. The CRD controller creates a +/// pod with the same name as the Sandbox resource. +async fn wait_for_sandbox_pod(name: &str, timeout_secs: u64) -> Result<(), String> { + let deadline = tokio::time::Instant::now() + Duration::from_secs(timeout_secs); + while tokio::time::Instant::now() < deadline { + if let Ok(n) = kubectl(&[ + "get", "pod", name, "-n", "openshell", + "-o", "jsonpath={.metadata.name}", + ]).await { + if !n.trim().is_empty() { + return Ok(()); + } + } + tokio::time::sleep(Duration::from_secs(2)).await; + } + Err(format!("sandbox pod {name} did not appear within {timeout_secs}s")) +} + +#[tokio::test] +async fn sandbox_pod_spec_has_user_namespace_fields() { + // Enable user namespaces on the gateway. + set_user_namespaces(true) + .await + .expect("failed to enable user namespaces on gateway"); + + let sandbox_name = unique_sandbox_name(); + + // Start sandbox creation in the background. The pod may never become + // ready in DinD environments, so we spawn the CLI and inspect the pod + // spec independently. + let mut cmd = openshell_cmd(); + cmd.arg("sandbox").arg("create") + .arg("--name").arg(&sandbox_name) + .arg("--").arg("sleep").arg("infinity"); + cmd.stdout(Stdio::piped()).stderr(Stdio::piped()); + + let mut child = cmd.spawn().expect("failed to spawn openshell create"); + + if let Err(e) = wait_for_sandbox(&sandbox_name, 60).await { + stop_child(&mut child).await; + delete_sandbox(&sandbox_name).await; + set_user_namespaces(false).await.ok(); + panic!("{e}"); + } + + // Wait for the pod to be created (the CRD controller creates it). + if let Err(e) = wait_for_sandbox_pod(&sandbox_name, 60).await { + stop_child(&mut child).await; + delete_sandbox(&sandbox_name).await; + set_user_namespaces(false).await.ok(); + panic!("{e}"); + } + + // Inspect the pod spec for hostUsers. + let host_users = kubectl(&[ + "get", "pod", &sandbox_name, "-n", "openshell", + "-o", "jsonpath={.spec.hostUsers}", + ]).await; + + // Inspect capabilities on the agent container. + let caps = kubectl(&[ + "get", "pod", &sandbox_name, "-n", "openshell", + "-o", "jsonpath={.spec.containers[?(@.name=='agent')].securityContext.capabilities.add}", + ]).await; + + // Clean up. + stop_child(&mut child).await; + delete_sandbox(&sandbox_name).await; + set_user_namespaces(false).await.ok(); + + // Assert hostUsers is false. + let host_users_val = host_users.expect("failed to get hostUsers from pod spec"); + assert_eq!( + host_users_val.trim(), "false", + "sandbox pod must have spec.hostUsers=false when user namespaces are enabled" + ); + + // Assert extra capabilities are present. + let caps_val = caps.expect("failed to get capabilities from pod spec"); + for cap in ["SETUID", "SETGID", "DAC_READ_SEARCH"] { + assert!( + caps_val.contains(cap), + "sandbox pod must include {cap} in capabilities when user namespaces are enabled, got: {caps_val}" + ); + } + for cap in ["SYS_ADMIN", "NET_ADMIN", "SYS_PTRACE", "SYSLOG"] { + assert!( + caps_val.contains(cap), + "sandbox pod must include {cap} in capabilities, got: {caps_val}" + ); + } +} diff --git a/proto/openshell.proto b/proto/openshell.proto index a4a18ce82..c6a9bab89 100644 --- a/proto/openshell.proto +++ b/proto/openshell.proto @@ -248,6 +248,12 @@ message SandboxTemplate { google.protobuf.Struct resources = 7; // Optional platform-specific volume claim templates. google.protobuf.Struct volume_claim_templates = 9; + // Enable Kubernetes user namespace isolation (hostUsers: false). + // When true, container UID 0 maps to a non-root host UID and capabilities + // become namespaced. Requires Kubernetes 1.33+ with user namespace support + // available (beta through 1.35, GA in 1.36+) and a supporting runtime. + // When unset, the cluster-wide default is used. + optional bool user_namespaces = 10; } // User-facing sandbox status derived by the gateway from compute-driver observations. From a4efc0b738cbaeeb34273a2dbdb64376fbf9d384 Mon Sep 17 00:00:00 2001 From: Taylor Mutch Date: Fri, 8 May 2026 09:49:07 -0700 Subject: [PATCH 007/157] feat(server): add generate-certs subcommand; replace alpine PKI hook (#1257) --- .agents/skills/helm-dev-environment/SKILL.md | 3 +- Cargo.lock | 3 + architecture/gateway.md | 26 + crates/openshell-server/Cargo.toml | 5 + crates/openshell-server/src/certgen.rs | 525 ++++++++++++++++++ crates/openshell-server/src/cli.rs | 131 ++++- crates/openshell-server/src/lib.rs | 1 + deploy/helm/openshell/README.md | 18 + deploy/helm/openshell/templates/certgen.yaml | 109 ++++ deploy/helm/openshell/templates/pki-hook.yaml | 191 ------- deploy/helm/openshell/values.yaml | 39 +- 11 files changed, 819 insertions(+), 232 deletions(-) create mode 100644 crates/openshell-server/src/certgen.rs create mode 100644 deploy/helm/openshell/templates/certgen.yaml delete mode 100644 deploy/helm/openshell/templates/pki-hook.yaml diff --git a/.agents/skills/helm-dev-environment/SKILL.md b/.agents/skills/helm-dev-environment/SKILL.md index 18d8c241e..623efb2e6 100644 --- a/.agents/skills/helm-dev-environment/SKILL.md +++ b/.agents/skills/helm-dev-environment/SKILL.md @@ -57,7 +57,8 @@ mise run helm:skaffold:run ``` Both commands build the `gateway` and `supervisor` images and deploy the OpenShell Helm -chart. The `pkiInitJob` hook runs on first install to generate mTLS secrets. Envoy Gateway opt-in; see the Optional Add-ons section below. +chart. The `pkiInitJob` hook (a pre-install Job that runs `openshell-gateway generate-certs`) +generates mTLS secrets on first install. Envoy Gateway opt-in; see the Optional Add-ons section below. The gateway Service uses ClusterIP. Access is via Envoy Gateway (port `8080`) or `kubectl port-forward`. diff --git a/Cargo.lock b/Cargo.lock index 28d86ea93..808956cd9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3630,10 +3630,13 @@ dependencies = [ "hyper-util", "ipnet", "jsonwebtoken 9.3.1", + "k8s-openapi", + "kube", "metrics", "metrics-exporter-prometheus", "miette", "nix", + "openshell-bootstrap", "openshell-core", "openshell-driver-docker", "openshell-driver-kubernetes", diff --git a/architecture/gateway.md b/architecture/gateway.md index f36878cf1..d89706e64 100644 --- a/architecture/gateway.md +++ b/architecture/gateway.md @@ -132,6 +132,32 @@ The same relay pattern backs interactive SSH, command execution, and file sync. The gateway tracks live sessions in memory and persists session records so tokens can expire or be revoked. +## PKI Bootstrap + +`openshell-gateway generate-certs` is the one place mTLS materials are +created. Both deployment paths use it: + +| Output mode | Selector | Layout | +|---|---|---| +| Kubernetes Secrets | (default) `--namespace`, `--server-secret-name`, `--client-secret-name` | Two `kubernetes.io/tls` Secrets with `tls.crt` / `tls.key` / `ca.crt`. | +| Filesystem | `--output-dir ` | `/{ca.crt, ca.key, server/tls.{crt,key}, client/tls.{crt,key}}`. Also copies client materials to `$XDG_CONFIG_HOME/openshell/gateways/openshell/mtls/` for CLI auto-discovery. | + +On Kubernetes, the Helm chart runs the command via a pre-install/pre-upgrade +hook Job using the gateway image itself — no separate cert-generation image, +no extra mirror burden in air-gapped environments. On the RPM gateway, the +same command runs from the systemd unit's `ExecStartPre` to bootstrap PKI +into the user's state directory on first start. + +Both modes share the same idempotency contract: all targets present → skip; +partial state → fail with a recovery hint; nothing present → generate and +write. This guards mTLS continuity across restarts and upgrades while still +recovering cleanly if an operator deletes everything and starts over. + +Operators who manage PKI externally (cert-manager, an enterprise CA, or +pre-created Secrets) disable the Helm hook via `pkiInitJob.enabled=false`. +The chart also ships a `certManager.*` path that produces equivalent Secrets +through cert-manager `Issuer`/`Certificate` resources. + ## Operational Constraints - Gateway TLS and client certificate distribution are deployment concerns owned diff --git a/crates/openshell-server/Cargo.toml b/crates/openshell-server/Cargo.toml index fab20186c..9cba99045 100644 --- a/crates/openshell-server/Cargo.toml +++ b/crates/openshell-server/Cargo.toml @@ -15,6 +15,7 @@ name = "openshell-gateway" path = "src/main.rs" [dependencies] +openshell-bootstrap = { path = "../openshell-bootstrap" } openshell-core = { path = "../openshell-core" } openshell-driver-docker = { path = "../openshell-driver-docker" } openshell-driver-kubernetes = { path = "../openshell-driver-kubernetes" } @@ -24,6 +25,10 @@ openshell-policy = { path = "../openshell-policy" } openshell-providers = { path = "../openshell-providers" } openshell-router = { path = "../openshell-router" } +# Kubernetes client (used by the `generate-certs` subcommand) +kube = { workspace = true } +k8s-openapi = { workspace = true } + # Async runtime tokio = { workspace = true } diff --git a/crates/openshell-server/src/certgen.rs b/crates/openshell-server/src/certgen.rs new file mode 100644 index 000000000..b9e4d7bd5 --- /dev/null +++ b/crates/openshell-server/src/certgen.rs @@ -0,0 +1,525 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! `generate-certs` subcommand: bootstrap mTLS PKI for the gateway. +//! +//! Two output modes, dispatched by the presence of `--output-dir`: +//! +//! - **Kubernetes mode** (default): create two `kubernetes.io/tls` Secrets +//! in the supplied namespace. Used by the Helm pre-install hook. Requires +//! `--namespace`, `--server-secret-name`, `--client-secret-name`. +//! - **Local mode** (`--output-dir `): write PEMs to a filesystem layout +//! matching `deploy/rpm/init-pki.sh`. Used by the RPM systemd unit's +//! `ExecStartPre`. Also copies client materials to +//! `$XDG_CONFIG_HOME/openshell/gateways/openshell/mtls/` so the local CLI +//! picks them up automatically. +//! +//! Both modes share the same idempotency contract: all targets present → +//! skip; partial state → error with a recovery hint; nothing present → +//! generate and write. + +use clap::Args; +use k8s_openapi::ByteString; +use k8s_openapi::api::core::v1::Secret; +use kube::Client; +use kube::api::{Api, ObjectMeta, PostParams}; +use miette::{IntoDiagnostic, Result, WrapErr}; +use openshell_bootstrap::pki::{PkiBundle, generate_pki}; +use openshell_core::paths::{create_dir_restricted, set_file_owner_only}; +use std::collections::BTreeMap; +use std::path::{Path, PathBuf}; +use tracing::{info, warn}; +use tracing_subscriber::EnvFilter; + +#[derive(Args, Debug)] +pub struct CertgenArgs { + /// Write PEMs to a filesystem directory instead of Kubernetes Secrets. + /// When set, the kube-related flags are not required. + #[arg(long, value_name = "DIR")] + output_dir: Option, + + /// Kubernetes namespace to create Secrets in. + /// Default comes from `POD_NAMESPACE`, which the Helm hook injects via + /// the downward API. + #[arg(long, env = "POD_NAMESPACE", required_unless_present = "output_dir")] + namespace: Option, + + /// Name of the server TLS Secret (`kubernetes.io/tls`) to create. + #[arg(long, required_unless_present = "output_dir")] + server_secret_name: Option, + + /// Name of the client TLS Secret (`kubernetes.io/tls`) to create. + #[arg(long, required_unless_present = "output_dir")] + client_secret_name: Option, + + /// Extra Subject Alternative Name for the server certificate. Repeatable. + /// Auto-detected as an IP address or DNS name. + #[arg(long = "server-san", value_name = "SAN")] + server_sans: Vec, + + /// Print the generated PEM materials to stdout instead of writing them. + /// For local debugging. + #[arg(long)] + dry_run: bool, +} + +pub async fn run(args: CertgenArgs) -> Result<()> { + tracing_subscriber::fmt() + .with_env_filter( + EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")), + ) + .init(); + + let bundle = generate_pki(&args.server_sans)?; + + if args.dry_run { + print_bundle(&bundle); + return Ok(()); + } + + if let Some(dir) = args.output_dir.as_deref() { + run_local(dir, &bundle) + } else { + run_kubernetes(&args, &bundle).await + } +} + +// ─────────────────────────── Kubernetes mode ─────────────────────────── + +#[derive(Debug, PartialEq, Eq)] +enum K8sAction { + SkipExists, + PartialState, + Create, +} + +fn decide_k8s(server_exists: bool, client_exists: bool) -> K8sAction { + match (server_exists, client_exists) { + (true, true) => K8sAction::SkipExists, + (false, false) => K8sAction::Create, + _ => K8sAction::PartialState, + } +} + +async fn run_kubernetes(args: &CertgenArgs, bundle: &PkiBundle) -> Result<()> { + let namespace = args + .namespace + .as_deref() + .ok_or_else(|| miette::miette!("--namespace is required (or set POD_NAMESPACE)"))?; + let server_name = args + .server_secret_name + .as_deref() + .ok_or_else(|| miette::miette!("--server-secret-name is required"))?; + let client_name = args + .client_secret_name + .as_deref() + .ok_or_else(|| miette::miette!("--client-secret-name is required"))?; + + let client = Client::try_default() + .await + .into_diagnostic() + .wrap_err("failed to construct in-cluster Kubernetes client")?; + let api: Api = Api::namespaced(client, namespace); + + let server_exists = api + .get_opt(server_name) + .await + .into_diagnostic() + .wrap_err_with(|| format!("failed to read secret {server_name}"))? + .is_some(); + let client_exists = api + .get_opt(client_name) + .await + .into_diagnostic() + .wrap_err_with(|| format!("failed to read secret {client_name}"))? + .is_some(); + + match decide_k8s(server_exists, client_exists) { + K8sAction::SkipExists => { + info!( + namespace = %namespace, + server = %server_name, + client = %client_name, + "PKI secrets already exist, skipping." + ); + return Ok(()); + } + K8sAction::PartialState => { + return Err(miette::miette!( + "partial PKI state in namespace {namespace}: exactly one of \ + {server_name} / {client_name} exists. Recover with: \ + kubectl delete secret -n {namespace} {server_name} {client_name}", + )); + } + K8sAction::Create => {} + } + + let server_secret = tls_secret( + server_name, + &bundle.server_cert_pem, + &bundle.server_key_pem, + &bundle.ca_cert_pem, + ); + let client_secret = tls_secret( + client_name, + &bundle.client_cert_pem, + &bundle.client_key_pem, + &bundle.ca_cert_pem, + ); + + api.create(&PostParams::default(), &server_secret) + .await + .into_diagnostic() + .wrap_err_with(|| format!("failed to create secret {server_name}"))?; + api.create(&PostParams::default(), &client_secret) + .await + .into_diagnostic() + .wrap_err_with(|| format!("failed to create secret {client_name}"))?; + + info!( + namespace = %namespace, + server = %server_name, + client = %client_name, + "PKI secrets created." + ); + Ok(()) +} + +fn tls_secret(name: &str, crt_pem: &str, key_pem: &str, ca_pem: &str) -> Secret { + let mut data = BTreeMap::new(); + data.insert( + "tls.crt".to_string(), + ByteString(crt_pem.as_bytes().to_vec()), + ); + data.insert( + "tls.key".to_string(), + ByteString(key_pem.as_bytes().to_vec()), + ); + data.insert("ca.crt".to_string(), ByteString(ca_pem.as_bytes().to_vec())); + Secret { + metadata: ObjectMeta { + name: Some(name.to_string()), + ..Default::default() + }, + type_: Some("kubernetes.io/tls".to_string()), + data: Some(data), + ..Default::default() + } +} + +// ─────────────────────────────── Local mode ─────────────────────────────── + +#[derive(Debug, PartialEq, Eq)] +enum LocalAction { + Skip, + PartialState, + Create, +} + +/// Layout under `` matches `deploy/rpm/init-pki.sh`: +/// +/// ```text +/// /ca.crt +/// /ca.key +/// /server/tls.crt +/// /server/tls.key +/// /client/tls.crt +/// /client/tls.key +/// ``` +struct LocalPaths { + ca_crt: PathBuf, + ca_key: PathBuf, + server_dir: PathBuf, + server_crt: PathBuf, + server_key: PathBuf, + client_dir: PathBuf, + client_crt: PathBuf, + client_key: PathBuf, +} + +impl LocalPaths { + fn resolve(dir: &Path) -> Self { + let server_dir = dir.join("server"); + let client_dir = dir.join("client"); + Self { + ca_crt: dir.join("ca.crt"), + ca_key: dir.join("ca.key"), + server_crt: server_dir.join("tls.crt"), + server_key: server_dir.join("tls.key"), + server_dir, + client_crt: client_dir.join("tls.crt"), + client_key: client_dir.join("tls.key"), + client_dir, + } + } + + fn all_files(&self) -> [&Path; 6] { + [ + &self.ca_crt, + &self.ca_key, + &self.server_crt, + &self.server_key, + &self.client_crt, + &self.client_key, + ] + } + + fn existence_count(&self) -> usize { + self.all_files().iter().filter(|p| p.exists()).count() + } +} + +fn decide_local(present: usize) -> LocalAction { + match present { + 6 => LocalAction::Skip, + 0 => LocalAction::Create, + _ => LocalAction::PartialState, + } +} + +fn run_local(dir: &Path, bundle: &PkiBundle) -> Result<()> { + let paths = LocalPaths::resolve(dir); + + match decide_local(paths.existence_count()) { + LocalAction::Skip => { + info!(dir = %dir.display(), "PKI files already exist, skipping."); + } + LocalAction::PartialState => { + return Err(miette::miette!( + "partial PKI state in {dir}: some files exist but not all. \ + Recover with: rm -rf {dir} (the gateway will regenerate on next start)", + dir = dir.display(), + )); + } + LocalAction::Create => { + write_local_bundle(dir, bundle, &paths)?; + info!(dir = %dir.display(), "PKI files created."); + } + } + + // Always make sure the CLI auto-discovery copy is in place. This + // self-heals the case where the operator wiped ~/.config/openshell but + // left the gateway state directory intact. + if let Err(e) = openshell_bootstrap::mtls::store_pki_bundle("openshell", bundle) { + warn!(error = %e, "failed to copy client mTLS materials for CLI auto-discovery"); + } + + Ok(()) +} + +fn write_local_bundle(dir: &Path, bundle: &PkiBundle, paths: &LocalPaths) -> Result<()> { + // Stage to a sibling tmp dir so individual renames into the final layout + // are atomic on the same filesystem. + let temp = sibling_temp_dir(dir); + if temp.exists() { + std::fs::remove_dir_all(&temp) + .into_diagnostic() + .wrap_err_with(|| format!("failed to remove stale {}", temp.display()))?; + } + + let temp_server = temp.join("server"); + let temp_client = temp.join("client"); + create_dir_restricted(&temp)?; + create_dir_restricted(&temp_server)?; + create_dir_restricted(&temp_client)?; + + write_pem(&temp.join("ca.crt"), &bundle.ca_cert_pem, false)?; + write_pem(&temp.join("ca.key"), &bundle.ca_key_pem, true)?; + write_pem(&temp_server.join("tls.crt"), &bundle.server_cert_pem, false)?; + write_pem(&temp_server.join("tls.key"), &bundle.server_key_pem, true)?; + write_pem(&temp_client.join("tls.crt"), &bundle.client_cert_pem, false)?; + write_pem(&temp_client.join("tls.key"), &bundle.client_key_pem, true)?; + + // Final destination (might not exist yet on first run). + create_dir_restricted(dir)?; + create_dir_restricted(&paths.server_dir)?; + create_dir_restricted(&paths.client_dir)?; + + let renames: [(PathBuf, &Path); 6] = [ + (temp.join("ca.crt"), paths.ca_crt.as_path()), + (temp.join("ca.key"), paths.ca_key.as_path()), + (temp_server.join("tls.crt"), paths.server_crt.as_path()), + (temp_server.join("tls.key"), paths.server_key.as_path()), + (temp_client.join("tls.crt"), paths.client_crt.as_path()), + (temp_client.join("tls.key"), paths.client_key.as_path()), + ]; + for (from, to) in &renames { + std::fs::rename(from, to) + .into_diagnostic() + .wrap_err_with(|| format!("failed to move {} -> {}", from.display(), to.display()))?; + } + + let _ = std::fs::remove_dir_all(&temp); + Ok(()) +} + +fn write_pem(path: &Path, contents: &str, owner_only: bool) -> Result<()> { + std::fs::write(path, contents) + .into_diagnostic() + .wrap_err_with(|| format!("failed to write {}", path.display()))?; + if owner_only { + set_file_owner_only(path)?; + } + Ok(()) +} + +fn sibling_temp_dir(dir: &Path) -> PathBuf { + // Use a sibling so std::fs::rename succeeds (same filesystem). + let mut name = dir + .file_name() + .map(std::ffi::OsStr::to_os_string) + .unwrap_or_default(); + name.push(".certgen.tmp"); + dir.with_file_name(name) +} + +// ────────────────────────────── Shared utility ───────────────────────────── + +fn print_bundle(bundle: &PkiBundle) { + println!("# CA certificate\n{}", bundle.ca_cert_pem); + println!("# Server certificate\n{}", bundle.server_cert_pem); + println!("# Server key\n{}", bundle.server_key_pem); + println!("# Client certificate\n{}", bundle.client_cert_pem); + println!("# Client key\n{}", bundle.client_key_pem); +} + +#[cfg(test)] +mod tests { + use super::{ + K8sAction, LocalAction, LocalPaths, decide_k8s, decide_local, sibling_temp_dir, tls_secret, + write_local_bundle, + }; + use openshell_bootstrap::pki::generate_pki; + use std::path::Path; + + // ── Kubernetes-mode decision ── + + #[test] + fn decide_k8s_skip_when_both_exist() { + assert_eq!(decide_k8s(true, true), K8sAction::SkipExists); + } + + #[test] + fn decide_k8s_create_when_neither_exists() { + assert_eq!(decide_k8s(false, false), K8sAction::Create); + } + + #[test] + fn decide_k8s_partial_when_only_server_exists() { + assert_eq!(decide_k8s(true, false), K8sAction::PartialState); + } + + #[test] + fn decide_k8s_partial_when_only_client_exists() { + assert_eq!(decide_k8s(false, true), K8sAction::PartialState); + } + + #[test] + fn tls_secret_has_kubernetes_io_tls_type_and_three_keys() { + let s = tls_secret("foo", "CRT-PEM", "KEY-PEM", "CA-PEM"); + assert_eq!(s.metadata.name.as_deref(), Some("foo")); + assert_eq!(s.type_.as_deref(), Some("kubernetes.io/tls")); + let data = s.data.expect("data set"); + assert_eq!(data.len(), 3); + assert_eq!(data["tls.crt"].0, b"CRT-PEM"); + assert_eq!(data["tls.key"].0, b"KEY-PEM"); + assert_eq!(data["ca.crt"].0, b"CA-PEM"); + } + + // ── Local-mode decision ── + + #[test] + fn decide_local_skip_when_all_six_present() { + assert_eq!(decide_local(6), LocalAction::Skip); + } + + #[test] + fn decide_local_create_when_none_present() { + assert_eq!(decide_local(0), LocalAction::Create); + } + + #[test] + fn decide_local_partial_for_any_count_in_between() { + for n in 1..=5 { + assert_eq!(decide_local(n), LocalAction::PartialState, "n = {n}"); + } + } + + // ── Local-mode layout & writes ── + + #[test] + fn local_paths_resolve_matches_init_pki_layout() { + let p = LocalPaths::resolve(Path::new("/tmp/openshell/tls")); + assert_eq!(p.ca_crt, Path::new("/tmp/openshell/tls/ca.crt")); + assert_eq!(p.ca_key, Path::new("/tmp/openshell/tls/ca.key")); + assert_eq!(p.server_crt, Path::new("/tmp/openshell/tls/server/tls.crt")); + assert_eq!(p.server_key, Path::new("/tmp/openshell/tls/server/tls.key")); + assert_eq!(p.client_crt, Path::new("/tmp/openshell/tls/client/tls.crt")); + assert_eq!(p.client_key, Path::new("/tmp/openshell/tls/client/tls.key")); + } + + #[test] + fn sibling_temp_dir_is_adjacent_to_target() { + assert_eq!( + sibling_temp_dir(Path::new("/var/lib/openshell/tls")), + Path::new("/var/lib/openshell/tls.certgen.tmp") + ); + } + + #[test] + fn write_local_bundle_writes_six_files_and_removes_temp() { + let parent = tempfile::tempdir().expect("tempdir"); + let dir = parent.path().join("tls"); + let bundle = generate_pki(&[]).expect("generate_pki"); + let paths = LocalPaths::resolve(&dir); + + write_local_bundle(&dir, &bundle, &paths).expect("write_local_bundle"); + + for f in paths.all_files() { + assert!(f.is_file(), "missing {}", f.display()); + } + assert!( + !sibling_temp_dir(&dir).exists(), + "temp dir should be cleaned up" + ); + + // Spot-check contents. + let ca = std::fs::read_to_string(&paths.ca_crt).unwrap(); + assert!(ca.contains("BEGIN CERTIFICATE")); + let server_key = std::fs::read_to_string(&paths.server_key).unwrap(); + assert!(server_key.contains("BEGIN PRIVATE KEY")); + } + + #[cfg(unix)] + #[test] + fn write_local_bundle_sets_owner_only_on_keys() { + use std::os::unix::fs::PermissionsExt; + let parent = tempfile::tempdir().expect("tempdir"); + let dir = parent.path().join("tls"); + let bundle = generate_pki(&[]).expect("generate_pki"); + let paths = LocalPaths::resolve(&dir); + + write_local_bundle(&dir, &bundle, &paths).expect("write_local_bundle"); + + for key in [&paths.ca_key, &paths.server_key, &paths.client_key] { + let mode = std::fs::metadata(key).unwrap().permissions().mode() & 0o777; + assert_eq!(mode, 0o600, "key {} has mode {:o}", key.display(), mode); + } + } + + #[test] + fn write_local_bundle_recovers_from_stale_temp_dir() { + let parent = tempfile::tempdir().expect("tempdir"); + let dir = parent.path().join("tls"); + let stale = sibling_temp_dir(&dir); + std::fs::create_dir_all(&stale).unwrap(); + std::fs::write(stale.join("garbage"), "stale").unwrap(); + + let bundle = generate_pki(&[]).expect("generate_pki"); + let paths = LocalPaths::resolve(&dir); + write_local_bundle(&dir, &bundle, &paths).expect("write_local_bundle"); + + assert!(paths.ca_crt.is_file()); + assert!(!stale.exists(), "stale temp dir should be removed"); + } +} diff --git a/crates/openshell-server/src/cli.rs b/crates/openshell-server/src/cli.rs index 08180bc5e..534e3da37 100644 --- a/crates/openshell-server/src/cli.rs +++ b/crates/openshell-server/src/cli.rs @@ -15,14 +15,34 @@ use std::path::PathBuf; use tracing::info; use tracing_subscriber::EnvFilter; +use crate::certgen; use crate::compute::{DockerComputeConfig, VmComputeConfig}; use crate::{run_server, tracing_bus::TracingLogBus}; /// `OpenShell` gateway process - gRPC and HTTP server with protocol multiplexing. +/// +/// Top-level CLI. When invoked without a subcommand the binary runs the +/// gateway server using `RunArgs`. The `generate-certs` subcommand is used by +/// the Helm pre-install hook to bootstrap mTLS Secrets. #[derive(Parser, Debug)] #[command(version = openshell_core::VERSION)] #[command(about = "OpenShell gRPC/HTTP server", long_about = None)] -struct Args { +struct Cli { + #[command(subcommand)] + command: Option, + + #[command(flatten)] + run: RunArgs, +} + +#[derive(clap::Subcommand, Debug)] +enum Commands { + /// Generate mTLS PKI and write Kubernetes Secrets (Helm pre-install hook). + GenerateCerts(certgen::CertgenArgs), +} + +#[derive(clap::Args, Debug)] +struct RunArgs { /// IP address to bind the server, health, and metrics listeners to. #[arg(long, default_value = "127.0.0.1", env = "OPENSHELL_BIND_ADDRESS")] bind_address: IpAddr, @@ -58,8 +78,12 @@ struct Args { tls_client_ca: Option, /// Database URL for persistence. - #[arg(long, env = "OPENSHELL_DB_URL", required = true)] - db_url: String, + /// + /// Required when running the gateway. Validated at the call site rather + /// than as a clap-level requirement so the `generate-certs` subcommand + /// (which does not need a database) can run without it. + #[arg(long, env = "OPENSHELL_DB_URL")] + db_url: Option, /// Compute drivers configured for this gateway. /// @@ -284,7 +308,7 @@ struct Args { } pub fn command() -> Command { - Args::command() + Cli::command() .name("openshell-gateway") .bin_name("openshell-gateway") } @@ -294,12 +318,15 @@ pub async fn run_cli() -> Result<()> { .install_default() .map_err(|e| miette::miette!("failed to install rustls crypto provider: {e:?}"))?; - let args = Args::from_arg_matches(&command().get_matches()).expect("clap validated args"); + let cli = Cli::from_arg_matches(&command().get_matches()).expect("clap validated args"); - Box::pin(run_from_args(args)).await + match cli.command { + Some(Commands::GenerateCerts(args)) => certgen::run(args).await, + None => Box::pin(run_from_args(cli.run)).await, + } } -async fn run_from_args(args: Args) -> Result<()> { +async fn run_from_args(args: RunArgs) -> Result<()> { let tracing_log_bus = TracingLogBus::new(); tracing_log_bus.install_subscriber( EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(&args.log_level)), @@ -331,6 +358,10 @@ async fn run_from_args(args: Args) -> Result<()> { }) }; + let db_url = args + .db_url + .ok_or_else(|| miette::miette!("--db-url is required (or set OPENSHELL_DB_URL)"))?; + let mut config = openshell_core::Config::new(tls) .with_bind_address(bind) .with_log_level(&args.log_level); @@ -364,7 +395,7 @@ async fn run_from_args(args: Args) -> Result<()> { } config = config - .with_database_url(args.db_url) + .with_database_url(db_url) .with_compute_drivers(args.drivers) .with_sandbox_namespace(args.sandbox_namespace) .with_ssh_gateway_host(args.ssh_gateway_host) @@ -451,7 +482,7 @@ fn parse_compute_driver(value: &str) -> std::result::Result at the clap level so subcommand parsing + // does not require it. The Run path validates it inside + // run_from_args. This test asserts the parse step succeeds with no + // --db-url, mirroring what the runtime check sees. + let _lock = ENV_LOCK + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let _g = EnvVarGuard::remove("OPENSHELL_DB_URL"); + + let cli = Cli::try_parse_from(["openshell-gateway"]).expect("parses without --db-url"); + assert!(cli.command.is_none()); + assert!(cli.run.db_url.is_none()); } } diff --git a/crates/openshell-server/src/lib.rs b/crates/openshell-server/src/lib.rs index 823eca51d..eaca911e4 100644 --- a/crates/openshell-server/src/lib.rs +++ b/crates/openshell-server/src/lib.rs @@ -20,6 +20,7 @@ //! [`compute::vm`]; keep this file driver-agnostic going forward. mod auth; +pub mod certgen; pub mod cli; mod compute; mod grpc; diff --git a/deploy/helm/openshell/README.md b/deploy/helm/openshell/README.md index ee7565f29..cc856731d 100644 --- a/deploy/helm/openshell/README.md +++ b/deploy/helm/openshell/README.md @@ -52,3 +52,21 @@ See [`values.yaml`](values.yaml) for configurable values. Selected overlays: - [`ci/values-gateway.yaml`](ci/values-gateway.yaml) — gateway-only configuration - [`ci/values-cert-manager.yaml`](ci/values-cert-manager.yaml) — cert-manager integration - [`ci/values-keycloak.yaml`](ci/values-keycloak.yaml) — Keycloak OIDC integration + +## PKI bootstrap + +By default, a pre-install/pre-upgrade hook Job runs `openshell-gateway generate-certs` +to create the gateway's server and client mTLS Secrets. The Job uses the gateway image +itself, so air-gapped environments only need to mirror that one image (no separate +openssl/alpine sidecar). + +The Job is idempotent: + +- Both target Secrets exist → log and exit 0. +- Exactly one exists → fail with `kubectl delete secret -n ` recovery hint. +- Neither exists → generate a CA, server cert, and client cert; POST both `kubernetes.io/tls` + Secrets (`tls.crt`, `tls.key`, `ca.crt`). + +Disable with `--set pkiInitJob.enabled=false` when bringing your own PKI (cert-manager, +external CA, or pre-created Secrets). See `certManager.*` in `values.yaml` for the +cert-manager alternative. diff --git a/deploy/helm/openshell/templates/certgen.yaml b/deploy/helm/openshell/templates/certgen.yaml new file mode 100644 index 000000000..d8136d581 --- /dev/null +++ b/deploy/helm/openshell/templates/certgen.yaml @@ -0,0 +1,109 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +{{- if and .Values.pkiInitJob.enabled .Values.certManager.enabled }} +{{- fail "pkiInitJob.enabled and certManager.enabled cannot both be true; disable one to avoid conflicting PKI sources." }} +{{- end }} +{{- if .Values.pkiInitJob.enabled }} +{{- $hookName := printf "%s-certgen" (include "openshell.fullname" .) }} +{{- $ns := .Release.Namespace }} +apiVersion: v1 +kind: ServiceAccount +metadata: + name: {{ $hookName }} + namespace: {{ $ns }} + labels: + {{- include "openshell.labels" . | nindent 4 }} + annotations: + helm.sh/hook: pre-install,pre-upgrade + helm.sh/hook-weight: "-30" + helm.sh/hook-delete-policy: before-hook-creation +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: Role +metadata: + name: {{ $hookName }} + namespace: {{ $ns }} + labels: + {{- include "openshell.labels" . | nindent 4 }} + annotations: + helm.sh/hook: pre-install,pre-upgrade + helm.sh/hook-weight: "-30" + helm.sh/hook-delete-policy: before-hook-creation +rules: + - apiGroups: [""] + resources: ["secrets"] + verbs: ["get", "create"] +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: RoleBinding +metadata: + name: {{ $hookName }} + namespace: {{ $ns }} + labels: + {{- include "openshell.labels" . | nindent 4 }} + annotations: + helm.sh/hook: pre-install,pre-upgrade + helm.sh/hook-weight: "-30" + helm.sh/hook-delete-policy: before-hook-creation +roleRef: + apiGroup: rbac.authorization.k8s.io + kind: Role + name: {{ $hookName }} +subjects: + - kind: ServiceAccount + name: {{ $hookName }} + namespace: {{ $ns }} +--- +apiVersion: batch/v1 +kind: Job +metadata: + name: {{ $hookName }} + namespace: {{ $ns }} + labels: + {{- include "openshell.labels" . | nindent 4 }} + annotations: + helm.sh/hook: pre-install,pre-upgrade + helm.sh/hook-weight: "-20" + helm.sh/hook-delete-policy: before-hook-creation,hook-succeeded +spec: + backoffLimit: 3 + activeDeadlineSeconds: 120 + ttlSecondsAfterFinished: 300 + template: + metadata: + labels: + {{- include "openshell.selectorLabels" . | nindent 8 }} + spec: + restartPolicy: OnFailure + serviceAccountName: {{ $hookName }} + {{- with .Values.imagePullSecrets }} + imagePullSecrets: + {{- toYaml . | nindent 8 }} + {{- end }} + containers: + - name: certgen + image: {{ include "openshell.image" . | quote }} + imagePullPolicy: {{ .Values.image.pullPolicy }} + securityContext: + allowPrivilegeEscalation: false + capabilities: + drop: + - ALL + env: + - name: POD_NAMESPACE + valueFrom: + fieldRef: + fieldPath: metadata.namespace + command: ["openshell-gateway"] + args: + - generate-certs + - --server-secret-name={{ .Values.server.tls.certSecretName }} + - --client-secret-name={{ .Values.server.tls.clientTlsSecretName }} + {{- range .Values.pkiInitJob.serverDnsNames }} + - --server-san={{ . }} + {{- end }} + {{- range .Values.pkiInitJob.serverIpAddresses }} + - --server-san={{ . }} + {{- end }} +{{- end }} diff --git a/deploy/helm/openshell/templates/pki-hook.yaml b/deploy/helm/openshell/templates/pki-hook.yaml deleted file mode 100644 index c5e83c734..000000000 --- a/deploy/helm/openshell/templates/pki-hook.yaml +++ /dev/null @@ -1,191 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -{{- if and .Values.pkiInitJob.enabled .Values.certManager.enabled }} -{{- fail "pkiInitJob.enabled and certManager.enabled cannot both be true; disable one to avoid conflicting PKI sources." }} -{{- end }} -{{- if .Values.pkiInitJob.enabled }} -{{- $hookName := printf "%s-pki-hook" (include "openshell.fullname" .) }} -{{- $ns := .Release.Namespace }} -{{- $serverSecret := .Values.server.tls.certSecretName }} -{{- $clientSecret := .Values.server.tls.clientTlsSecretName }} -{{- $sanParts := list }} -{{- range .Values.pkiInitJob.serverDnsNames }}{{- $sanParts = append $sanParts (printf "DNS:%s" .) }}{{- end }} -{{- range .Values.pkiInitJob.serverIpAddresses }}{{- $sanParts = append $sanParts (printf "IP:%s" .) }}{{- end }} -{{- $serverSans := join "," $sanParts }} -apiVersion: v1 -kind: ServiceAccount -metadata: - name: {{ $hookName }} - namespace: {{ $ns }} - labels: - {{- include "openshell.labels" . | nindent 4 }} - annotations: - helm.sh/hook: pre-install,pre-upgrade - helm.sh/hook-weight: "-30" - helm.sh/hook-delete-policy: before-hook-creation ---- -apiVersion: rbac.authorization.k8s.io/v1 -kind: Role -metadata: - name: {{ $hookName }} - namespace: {{ $ns }} - labels: - {{- include "openshell.labels" . | nindent 4 }} - annotations: - helm.sh/hook: pre-install,pre-upgrade - helm.sh/hook-weight: "-30" - helm.sh/hook-delete-policy: before-hook-creation -rules: - - apiGroups: [""] - resources: ["secrets"] - verbs: ["get", "create"] ---- -apiVersion: rbac.authorization.k8s.io/v1 -kind: RoleBinding -metadata: - name: {{ $hookName }} - namespace: {{ $ns }} - labels: - {{- include "openshell.labels" . | nindent 4 }} - annotations: - helm.sh/hook: pre-install,pre-upgrade - helm.sh/hook-weight: "-30" - helm.sh/hook-delete-policy: before-hook-creation -roleRef: - apiGroup: rbac.authorization.k8s.io - kind: Role - name: {{ $hookName }} -subjects: - - kind: ServiceAccount - name: {{ $hookName }} - namespace: {{ $ns }} ---- -apiVersion: batch/v1 -kind: Job -metadata: - name: {{ $hookName }} - namespace: {{ $ns }} - labels: - {{- include "openshell.labels" . | nindent 4 }} - annotations: - helm.sh/hook: pre-install,pre-upgrade - helm.sh/hook-weight: "-20" - helm.sh/hook-delete-policy: before-hook-creation,hook-succeeded -spec: - backoffLimit: 3 - activeDeadlineSeconds: 120 - ttlSecondsAfterFinished: 300 - template: - metadata: - labels: - {{- include "openshell.selectorLabels" . | nindent 8 }} - spec: - restartPolicy: OnFailure - serviceAccountName: {{ $hookName }} - containers: - - name: pki-gen - image: {{ .Values.pkiInitJob.image.repository }}:{{ .Values.pkiInitJob.image.tag }} - imagePullPolicy: {{ .Values.pkiInitJob.image.pullPolicy }} - securityContext: - allowPrivilegeEscalation: false - capabilities: - drop: - - ALL - env: - - name: NAMESPACE - valueFrom: - fieldRef: - fieldPath: metadata.namespace - - name: SERVER_SECRET - value: {{ $serverSecret | quote }} - - name: CLIENT_SECRET - value: {{ $clientSecret | quote }} - - name: CA_DAYS - value: {{ .Values.pkiInitJob.caValidityDays | quote }} - - name: CERT_DAYS - value: {{ .Values.pkiInitJob.certValidityDays | quote }} - - name: SERVER_SANS - value: {{ $serverSans | quote }} - command: - - /bin/sh - - -c - - | - set -eu - apk add --no-cache openssl curl >/dev/null 2>&1 - - TOKEN=$(cat /var/run/secrets/kubernetes.io/serviceaccount/token) - K8S_CA=/var/run/secrets/kubernetes.io/serviceaccount/ca.crt - API=https://kubernetes.default.svc - - # Idempotency: skip only when both TLS secrets already exist. - # Checking one is insufficient — a partial cleanup can leave one half - # of the pair behind, which would cause mTLS to fail at runtime. - HTTP_SERVER=$(curl -s -o /dev/null -w "%{http_code}" \ - -H "Authorization: Bearer $TOKEN" --cacert "$K8S_CA" \ - "$API/api/v1/namespaces/$NAMESPACE/secrets/$SERVER_SECRET") - HTTP_CLIENT=$(curl -s -o /dev/null -w "%{http_code}" \ - -H "Authorization: Bearer $TOKEN" --cacert "$K8S_CA" \ - "$API/api/v1/namespaces/$NAMESPACE/secrets/$CLIENT_SECRET") - if [ "$HTTP_SERVER" = "200" ] && [ "$HTTP_CLIENT" = "200" ]; then - echo "PKI secrets already exist, skipping." - exit 0 - fi - if [ "$HTTP_SERVER" = "200" ] || [ "$HTTP_CLIENT" = "200" ]; then - echo "ERROR: partial PKI state — one secret exists but not both." >&2 - echo "To recover: kubectl delete secret -n $NAMESPACE $SERVER_SECRET $CLIENT_SECRET" >&2 - exit 1 - fi - - cd /tmp - - # CA (ECDSA P-256) - openssl genpkey -algorithm EC -pkeyopt ec_paramgen_curve:P-256 -out ca.key 2>/dev/null - openssl req -new -x509 -sha256 -key ca.key -out ca.crt \ - -days "$CA_DAYS" -subj "/O=openshell/CN=openshell-ca" \ - -addext "basicConstraints=critical,CA:TRUE,pathlen:0" \ - -addext "keyUsage=critical,keyCertSign,cRLSign" - - # Server cert (ECDSA P-256) - printf "[ext]\nsubjectAltName=%s\nextendedKeyUsage=serverAuth\nkeyUsage=digitalSignature,keyEncipherment\n" \ - "$SERVER_SANS" > server.ext - openssl genpkey -algorithm EC -pkeyopt ec_paramgen_curve:P-256 -out server.key 2>/dev/null - openssl req -new -sha256 -key server.key -out server.csr -subj "/CN=openshell-server" - openssl x509 -req -sha256 -in server.csr -CA ca.crt -CAkey ca.key \ - -CAcreateserial -days "$CERT_DAYS" -extensions ext -extfile server.ext -out server.crt - - # Client cert (ECDSA P-256) - printf "[ext]\nextendedKeyUsage=clientAuth\nkeyUsage=digitalSignature,keyEncipherment\n" \ - > client.ext - openssl genpkey -algorithm EC -pkeyopt ec_paramgen_curve:P-256 -out client.key 2>/dev/null - openssl req -new -sha256 -key client.key -out client.csr -subj "/CN=openshell-client" - openssl x509 -req -sha256 -in client.csr -CA ca.crt -CAkey ca.key \ - -CAcreateserial -days "$CERT_DAYS" -extensions ext -extfile client.ext -out client.crt - - CA_B64=$(base64 -w0 ca.crt) - SERVER_CRT_B64=$(base64 -w0 server.crt) - SERVER_KEY_B64=$(base64 -w0 server.key) - CLIENT_CRT_B64=$(base64 -w0 client.crt) - CLIENT_KEY_B64=$(base64 -w0 client.key) - - # Create server TLS secret - printf '{"apiVersion":"v1","kind":"Secret","metadata":{"name":"%s","namespace":"%s"},"type":"kubernetes.io/tls","data":{"tls.crt":"%s","tls.key":"%s","ca.crt":"%s"}}\n' \ - "$SERVER_SECRET" "$NAMESPACE" \ - "$SERVER_CRT_B64" "$SERVER_KEY_B64" "$CA_B64" > server-secret.json - curl -sf -X POST \ - -H "Authorization: Bearer $TOKEN" -H "Content-Type: application/json" \ - --cacert "$K8S_CA" "$API/api/v1/namespaces/$NAMESPACE/secrets" \ - -d @server-secret.json - - # Create client TLS secret - printf '{"apiVersion":"v1","kind":"Secret","metadata":{"name":"%s","namespace":"%s"},"type":"kubernetes.io/tls","data":{"tls.crt":"%s","tls.key":"%s","ca.crt":"%s"}}\n' \ - "$CLIENT_SECRET" "$NAMESPACE" \ - "$CLIENT_CRT_B64" "$CLIENT_KEY_B64" "$CA_B64" > client-secret.json - curl -sf -X POST \ - -H "Authorization: Bearer $TOKEN" -H "Content-Type: application/json" \ - --cacert "$K8S_CA" "$API/api/v1/namespaces/$NAMESPACE/secrets" \ - -d @client-secret.json - - rm -f *.key *.csr *.crt *.ext *.srl *.json - echo "PKI secrets created." -{{- end }} diff --git a/deploy/helm/openshell/values.yaml b/deploy/helm/openshell/values.yaml index 89449012a..02389925a 100644 --- a/deploy/helm/openshell/values.yaml +++ b/deploy/helm/openshell/values.yaml @@ -160,32 +160,23 @@ sshHandshake: value: "" # PKI bootstrap via a pre-install/pre-upgrade hook Job. -# Generates a self-signed CA, server TLS secret, and client TLS secret using -# openssl (ECDSA P-256) inside the cluster. Key material is written directly to -# K8s Secrets and never appears in Helm release history. Idempotent: existing -# secrets are left untouched on upgrade. -# Air-gapped environments should override pkiInitJob.image with an image that has -# openssl and curl pre-installed (the default alpine image fetches them at runtime). +# Runs `openshell-gateway generate-certs` to create the server and client TLS +# Secrets in-cluster. Key material is written directly to K8s Secrets and +# never appears in Helm release history. Idempotent: existing secrets are +# left untouched on upgrade. Reuses the gateway image — no extra image to +# mirror in air-gapped environments. +# +# The server certificate already includes the built-in cluster SANs +# (`openshell`, `openshell.openshell.svc`, the cluster.local FQDN, `localhost`, +# `host.docker.internal`, and `127.0.0.1`) baked into the gateway binary. The +# lists below are *additional* SANs appended on top — typically a public +# hostname or load-balancer IP for remote deployments. pkiInitJob: enabled: true - image: - repository: alpine - tag: "3" - pullPolicy: IfNotPresent - # Days until the CA certificate expires. - caValidityDays: 3650 - # Days until server and client certificates expire. - certValidityDays: 3650 - # DNS SANs for the server certificate. - serverDnsNames: - - openshell - - openshell.openshell.svc - - openshell.openshell.svc.cluster.local - - localhost - - host.docker.internal - # IP SANs for the server certificate. - serverIpAddresses: - - 127.0.0.1 + # Extra DNS SANs to append to the server certificate. + serverDnsNames: [] + # Extra IP SANs to append to the server certificate. + serverIpAddresses: [] # cert-manager Certificate/Issuer resources (requires cert-manager CRDs in-cluster). # Uses namespaced Issuers only (no ClusterIssuer). Does not install cert-manager itself. From b74d24bcf62e210b62e0a4a0e78cec3a6cb5a2ce Mon Sep 17 00:00:00 2001 From: Drew Newberry Date: Fri, 8 May 2026 10:16:04 -0700 Subject: [PATCH 008/157] fix(docs): constrain landing terminal height (#1269) Signed-off-by: Drew Newberry --- fern/main.css | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/fern/main.css b/fern/main.css index aa20a6263..781acd15f 100644 --- a/fern/main.css +++ b/fern/main.css @@ -1004,19 +1004,31 @@ header[role="banner"] > div[class*="bg-(color:--accent)"][class*="text-(color:-- background: #27c93f; } .nc-term-body { + display: grid; + grid-template-rows: repeat(2, 1.8em); padding: 16px 20px; color: #d4d4d8; + overflow-x: auto; +} +.nc-term-body > div { + min-width: max-content; + white-space: nowrap; } .nc-term-body .nc-ps { color: #76b900; user-select: none; } .nc-swap { - display: inline-grid; - vertical-align: baseline; + display: inline-block; + position: relative; + min-width: 12ch; + height: 1.8em; + overflow: hidden; + vertical-align: top; } .nc-swap > span { - grid-area: 1 / 1; + position: absolute; + inset: 0 auto auto 0; white-space: nowrap; opacity: 0; animation: nc-cycle 12s ease-in-out infinite; From 3cfc915bfe49bd721f5b306b53650c32acde06b4 Mon Sep 17 00:00:00 2001 From: jtoelke2 <149006449+jtoelke2@users.noreply.github.com> Date: Fri, 8 May 2026 12:20:42 -0500 Subject: [PATCH 009/157] ci(os-132): remove stale remote buildx mode (#1267) Signed-off-by: Jonas Toelke --- .github/actions/setup-buildx/action.yml | 32 +++---------------------- .github/workflows/ci-image.yml | 1 - .github/workflows/docker-build.yml | 1 - .github/workflows/driver-vm-macos.yml | 2 -- .github/workflows/release-dev.yml | 6 ----- .github/workflows/release-tag.yml | 6 ----- 6 files changed, 3 insertions(+), 45 deletions(-) diff --git a/.github/actions/setup-buildx/action.yml b/.github/actions/setup-buildx/action.yml index 41210035f..3e1d54521 100644 --- a/.github/actions/setup-buildx/action.yml +++ b/.github/actions/setup-buildx/action.yml @@ -1,23 +1,11 @@ name: Setup Docker Buildx description: > - Create a Docker Buildx builder. Two modes: - * driver=remote (default) — multi-arch builder against in-cluster BuildKit - pods. Requires EKS connectivity. Behaviour unchanged from prior versions. - * driver=local — single-node buildx on the local docker-container driver. - Pair with cache-to/cache-from=type=gha on build steps for persistence. - Works on nv-gha-runners; no EKS needed. + Create a Docker Buildx builder on the local docker-container driver. Pair + with cache-to/cache-from=type=gha on build steps for persistence. Works on + nv-gha-runners; no EKS BuildKit service is required. Cleanup is automatic when the job finishes (docker/setup-buildx-action default). inputs: - driver: - description: "buildx driver: 'remote' or 'local'" - default: remote - amd64-endpoint: - description: BuildKit endpoint for linux/amd64 (remote driver only) - default: tcp://buildkit-amd64.buildkit:1234 - arm64-endpoint: - description: BuildKit endpoint for linux/arm64 (remote driver only) - default: tcp://buildkit-arm64.buildkit:1234 name: description: Builder instance name default: openshell @@ -34,21 +22,7 @@ inputs: runs: using: composite steps: - - name: Set up Docker Buildx (remote) - if: inputs.driver == 'remote' - uses: docker/setup-buildx-action@v3 - with: - name: ${{ inputs.name }} - driver: remote - endpoint: ${{ inputs.amd64-endpoint }} - platforms: linux/amd64 - append: | - - endpoint: ${{ inputs.arm64-endpoint }} - platforms: linux/arm64 - buildkitd-config: ${{ inputs.buildkitd-config }} - - name: Set up Docker Buildx (local) - if: inputs.driver == 'local' uses: docker/setup-buildx-action@v3 with: name: ${{ inputs.name }} diff --git a/.github/workflows/ci-image.yml b/.github/workflows/ci-image.yml index db98022d5..327ce0733 100644 --- a/.github/workflows/ci-image.yml +++ b/.github/workflows/ci-image.yml @@ -56,7 +56,6 @@ jobs: - name: Set up Docker Buildx uses: ./.github/actions/setup-buildx with: - driver: local buildkitd-config: ${{ steps.buildkit.outputs.config }} - name: Build and push CI image diff --git a/.github/workflows/docker-build.yml b/.github/workflows/docker-build.yml index 42d991b60..7447b1e42 100644 --- a/.github/workflows/docker-build.yml +++ b/.github/workflows/docker-build.yml @@ -204,7 +204,6 @@ jobs: - name: Set up buildx (local driver) uses: ./.github/actions/setup-buildx with: - driver: local buildkitd-config: /etc/buildkit/buildkitd.toml - name: Download Rust binary artifact diff --git a/.github/workflows/driver-vm-macos.yml b/.github/workflows/driver-vm-macos.yml index 5b2bac927..ed875e4ba 100644 --- a/.github/workflows/driver-vm-macos.yml +++ b/.github/workflows/driver-vm-macos.yml @@ -151,8 +151,6 @@ jobs: - name: Set up Docker Buildx uses: ./.github/actions/setup-buildx - with: - driver: local - name: Install zstd run: apt-get update && apt-get install -y --no-install-recommends zstd && rm -rf /var/lib/apt/lists/* diff --git a/.github/workflows/release-dev.yml b/.github/workflows/release-dev.yml index 7b9bb2f92..520a51c65 100644 --- a/.github/workflows/release-dev.yml +++ b/.github/workflows/release-dev.yml @@ -179,8 +179,6 @@ jobs: - name: Set up Docker Buildx uses: ./.github/actions/setup-buildx - with: - driver: local - name: Mark workspace safe for git run: git config --global --add safe.directory "$GITHUB_WORKSPACE" @@ -349,8 +347,6 @@ jobs: - name: Set up Docker Buildx uses: ./.github/actions/setup-buildx - with: - driver: local - name: Build macOS binary via Docker run: | @@ -498,8 +494,6 @@ jobs: - name: Set up Docker Buildx uses: ./.github/actions/setup-buildx - with: - driver: local - name: Build macOS binary via Docker run: | diff --git a/.github/workflows/release-tag.yml b/.github/workflows/release-tag.yml index 60966b3b6..19d9df47f 100644 --- a/.github/workflows/release-tag.yml +++ b/.github/workflows/release-tag.yml @@ -210,8 +210,6 @@ jobs: - name: Set up Docker Buildx uses: ./.github/actions/setup-buildx - with: - driver: local - name: Mark workspace safe for git run: git config --global --add safe.directory "$GITHUB_WORKSPACE" @@ -382,8 +380,6 @@ jobs: - name: Set up Docker Buildx uses: ./.github/actions/setup-buildx - with: - driver: local - name: Build macOS binary via Docker run: | @@ -617,8 +613,6 @@ jobs: - name: Set up Docker Buildx uses: ./.github/actions/setup-buildx - with: - driver: local - name: Build macOS binary via Docker run: | From 1d3b741ee38bbb8cee701e1760740e494398ccf7 Mon Sep 17 00:00:00 2001 From: "John T. Myers" <9696606+johntmyers@users.noreply.github.com> Date: Fri, 8 May 2026 11:14:44 -0700 Subject: [PATCH 010/157] feat(providers): support sandbox provider attach lifecycle (#1242) * feat(providers): support sandbox provider attach lifecycle Closes #1171 Adds sandbox provider list, attach, and detach API/CLI support while keeping provider policy and credential resolution derived from current sandbox attachments. * fix(providers): refresh sandbox provider credentials Adds provider environment revisions and generation-scoped sandbox credential snapshots so future SSH and exec launches pick up provider attach, detach, and credential updates without mutating already-running processes. Also blocks provider deletion while attached to prevent stale sandbox provider references. * fix(providers): serialize sandbox object mutations * test(providers): cover sandbox provider attach lifecycle * test(providers): accept versioned credential placeholders --- crates/openshell-cli/src/main.rs | 77 +++ crates/openshell-cli/src/run.rs | 203 +++++++- .../tests/ensure_providers_integration.rs | 43 +- .../openshell-cli/tests/mtls_integration.rs | 27 ++ .../tests/provider_commands_integration.rs | 237 +++++++++- .../sandbox_create_lifecycle_integration.rs | 44 +- .../sandbox_name_fallback_integration.rs | 42 +- crates/openshell-sandbox/src/grpc_client.rs | 15 +- crates/openshell-sandbox/src/lib.rs | 123 +++-- .../src/provider_credentials.rs | 143 ++++++ crates/openshell-sandbox/src/proxy.rs | 9 +- crates/openshell-sandbox/src/secrets.rs | 30 +- crates/openshell-sandbox/src/ssh.rs | 24 +- crates/openshell-server/src/auth/authz.rs | 29 +- crates/openshell-server/src/compute/mod.rs | 10 + crates/openshell-server/src/grpc/mod.rs | 55 ++- crates/openshell-server/src/grpc/policy.rs | 419 ++++++++++++++++- crates/openshell-server/src/grpc/provider.rs | 88 +++- crates/openshell-server/src/grpc/sandbox.rs | 438 +++++++++++++++++- .../tests/auth_endpoint_integration.rs | 30 ++ .../tests/edge_tunnel_auth.rs | 27 ++ .../tests/multiplex_integration.rs | 27 ++ .../tests/multiplex_tls_integration.rs | 27 ++ .../tests/supervisor_relay_integration.rs | 18 + .../tests/ws_tunnel_integration.rs | 27 ++ e2e/python/test_sandbox_providers.py | 100 +++- e2e/rust/tests/provider_auto_create.rs | 12 +- proto/openshell.proto | 55 +++ proto/sandbox.proto | 3 + 29 files changed, 2222 insertions(+), 160 deletions(-) create mode 100644 crates/openshell-sandbox/src/provider_credentials.rs diff --git a/crates/openshell-cli/src/main.rs b/crates/openshell-cli/src/main.rs index cd14568ef..25fa07cf2 100644 --- a/crates/openshell-cli/src/main.rs +++ b/crates/openshell-cli/src/main.rs @@ -1260,6 +1260,45 @@ enum SandboxCommands { #[arg(add = ArgValueCompleter::new(completers::complete_sandbox_names))] name: Option, }, + + /// Manage providers attached to a sandbox. + #[command(subcommand)] + Provider(SandboxProviderCommands), +} + +#[derive(Subcommand, Debug)] +enum SandboxProviderCommands { + /// List providers attached to a sandbox. + #[command(help_template = LEAF_HELP_TEMPLATE, next_help_heading = "FLAGS")] + List { + /// Sandbox name (defaults to last-used sandbox). + #[arg(add = ArgValueCompleter::new(completers::complete_sandbox_names))] + name: Option, + }, + + /// Attach a provider to a sandbox. + #[command(help_template = LEAF_HELP_TEMPLATE, next_help_heading = "FLAGS")] + Attach { + /// Sandbox name. + #[arg(add = ArgValueCompleter::new(completers::complete_sandbox_names))] + name: String, + + /// Provider name to attach. + #[arg(add = ArgValueCompleter::new(completers::complete_provider_names))] + provider: String, + }, + + /// Detach a provider from a sandbox. + #[command(help_template = LEAF_HELP_TEMPLATE, next_help_heading = "FLAGS")] + Detach { + /// Sandbox name. + #[arg(add = ArgValueCompleter::new(completers::complete_sandbox_names))] + name: String, + + /// Provider name to detach. + #[arg(add = ArgValueCompleter::new(completers::complete_provider_names))] + provider: String, + }, } #[derive(Subcommand, Debug)] @@ -2385,6 +2424,20 @@ async fn main() -> Result<()> { let name = resolve_sandbox_name(name, &ctx.name)?; run::print_ssh_config(&ctx.name, &name); } + SandboxCommands::Provider(command) => match command { + SandboxProviderCommands::List { name } => { + let name = resolve_sandbox_name(name, &ctx.name)?; + run::sandbox_provider_list(endpoint, &name, &tls).await?; + } + SandboxProviderCommands::Attach { name, provider } => { + run::sandbox_provider_attach(endpoint, &name, &provider, &tls) + .await?; + } + SandboxProviderCommands::Detach { name, provider } => { + run::sandbox_provider_detach(endpoint, &name, &provider, &tls) + .await?; + } + }, } } } @@ -2721,6 +2774,30 @@ mod tests { ); } + #[test] + fn sandbox_provider_subcommands_parse() { + let cli = Cli::try_parse_from([ + "openshell", + "sandbox", + "provider", + "attach", + "work-sandbox", + "work-github", + ]) + .expect("sandbox provider attach should parse"); + + let Some(Commands::Sandbox { + command: + Some(SandboxCommands::Provider(SandboxProviderCommands::Attach { name, provider })), + }) = cli.command + else { + panic!("expected sandbox provider attach command"); + }; + + assert_eq!(name, "work-sandbox"); + assert_eq!(provider, "work-github"); + } + #[test] fn completions_policy_flag_falls_back_to_file_paths() { let temp = tempfile::tempdir().expect("failed to create tempdir"); diff --git a/crates/openshell-cli/src/run.rs b/crates/openshell-cli/src/run.rs index 102bc87ab..dd1b6721d 100644 --- a/crates/openshell-cli/src/run.rs +++ b/crates/openshell-cli/src/run.rs @@ -25,18 +25,19 @@ use openshell_bootstrap::{ }; use openshell_core::proto::ProviderProfileCategory; use openshell_core::proto::{ - ApproveAllDraftChunksRequest, ApproveDraftChunkRequest, ClearDraftChunksRequest, - CreateProviderRequest, CreateSandboxRequest, DeleteProviderProfileRequest, - DeleteProviderRequest, DeleteSandboxRequest, ExecSandboxRequest, GetClusterInferenceRequest, + ApproveAllDraftChunksRequest, ApproveDraftChunkRequest, AttachSandboxProviderRequest, + ClearDraftChunksRequest, CreateProviderRequest, CreateSandboxRequest, + DeleteProviderProfileRequest, DeleteProviderRequest, DeleteSandboxRequest, + DetachSandboxProviderRequest, ExecSandboxRequest, GetClusterInferenceRequest, GetDraftHistoryRequest, GetDraftPolicyRequest, GetGatewayConfigRequest, GetProviderProfileRequest, GetProviderRequest, GetSandboxConfigRequest, GetSandboxLogsRequest, GetSandboxPolicyStatusRequest, GetSandboxRequest, HealthRequest, ImportProviderProfilesRequest, LintProviderProfilesRequest, ListProviderProfilesRequest, ListProvidersRequest, - ListSandboxPoliciesRequest, ListSandboxesRequest, PolicySource, PolicyStatus, Provider, - ProviderProfile, ProviderProfileDiagnostic, ProviderProfileImportItem, RejectDraftChunkRequest, - Sandbox, SandboxPhase, SandboxPolicy, SandboxSpec, SandboxTemplate, SetClusterInferenceRequest, - SettingScope, SettingValue, UpdateConfigRequest, UpdateProviderRequest, WatchSandboxRequest, - exec_sandbox_event, setting_value, + ListSandboxPoliciesRequest, ListSandboxProvidersRequest, ListSandboxesRequest, PolicySource, + PolicyStatus, Provider, ProviderProfile, ProviderProfileDiagnostic, ProviderProfileImportItem, + RejectDraftChunkRequest, Sandbox, SandboxPhase, SandboxPolicy, SandboxSpec, SandboxTemplate, + SetClusterInferenceRequest, SettingScope, SettingValue, UpdateConfigRequest, + UpdateProviderRequest, WatchSandboxRequest, exec_sandbox_event, setting_value, }; use openshell_core::settings::{self, SettingValueKind}; use openshell_core::{ObjectId, ObjectName}; @@ -2512,6 +2513,143 @@ pub async fn sandbox_list( Ok(()) } +pub async fn sandbox_provider_list(server: &str, name: &str, tls: &TlsOptions) -> Result<()> { + let mut client = grpc_client(server, tls).await?; + let response = client + .list_sandbox_providers(ListSandboxProvidersRequest { + sandbox_name: name.to_string(), + }) + .await + .into_diagnostic()?; + let providers = response.into_inner().providers; + + if providers.is_empty() { + println!("No providers attached to sandbox {name}."); + return Ok(()); + } + + print_provider_attachment_table(&providers); + Ok(()) +} + +pub async fn sandbox_provider_attach( + server: &str, + name: &str, + provider: &str, + tls: &TlsOptions, +) -> Result<()> { + let mut client = grpc_client(server, tls).await?; + let response = client + .attach_sandbox_provider(AttachSandboxProviderRequest { + sandbox_name: name.to_string(), + provider_name: provider.to_string(), + }) + .await + .into_diagnostic()? + .into_inner(); + + if response.attached { + println!( + "{} Attached provider {} to sandbox {}", + "✓".green().bold(), + provider, + name + ); + } else { + println!("Provider {provider} is already attached to sandbox {name}."); + } + Ok(()) +} + +pub async fn sandbox_provider_detach( + server: &str, + name: &str, + provider: &str, + tls: &TlsOptions, +) -> Result<()> { + let mut client = grpc_client(server, tls).await?; + let response = client + .detach_sandbox_provider(DetachSandboxProviderRequest { + sandbox_name: name.to_string(), + provider_name: provider.to_string(), + }) + .await + .into_diagnostic()? + .into_inner(); + + if response.detached { + println!( + "{} Detached provider {} from sandbox {}", + "✓".green().bold(), + provider, + name + ); + } else { + println!("Provider {provider} was not attached to sandbox {name}."); + } + Ok(()) +} + +fn print_provider_attachment_table(providers: &[Provider]) { + print!("{}", format_provider_attachment_table(providers, true)); +} + +fn format_provider_attachment_table(providers: &[Provider], color: bool) -> String { + use std::fmt::Write as _; + + let name_width = providers + .iter() + .map(|provider| provider.object_name().len()) + .max() + .unwrap_or(4) + .max(4); + let type_width = providers + .iter() + .map(|provider| provider.r#type.len()) + .max() + .unwrap_or(4) + .max(4); + + let name_header = if color { + "NAME".bold().to_string() + } else { + "NAME".to_string() + }; + let type_header = if color { + "TYPE".bold().to_string() + } else { + "TYPE".to_string() + }; + let credential_keys_header = if color { + "CREDENTIAL_KEYS".bold().to_string() + } else { + "CREDENTIAL_KEYS".to_string() + }; + let config_keys_header = if color { + "CONFIG_KEYS".bold().to_string() + } else { + "CONFIG_KEYS".to_string() + }; + + let mut output = String::new(); + let _ = writeln!( + output, + "{name_header: String { mod tests { use super::{ TlsOptions, dockerfile_sources_supported_for_gateway, format_gateway_select_header, - format_gateway_select_items, gateway_add, gateway_auth_label, gateway_env_override_warning, - gateway_select_with, gateway_type_label, git_sync_files, http_health_check, - image_requests_gpu, inferred_provider_type, parse_cli_setting_value, - parse_credential_pairs, plaintext_gateway_is_remote, provisioning_timeout_message, - ready_false_condition_message, resolve_from, sandbox_should_persist, + format_gateway_select_items, format_provider_attachment_table, gateway_add, + gateway_auth_label, gateway_env_override_warning, gateway_select_with, gateway_type_label, + git_sync_files, http_health_check, image_requests_gpu, inferred_provider_type, + parse_cli_setting_value, parse_credential_pairs, plaintext_gateway_is_remote, + provisioning_timeout_message, ready_false_condition_message, resolve_from, + sandbox_should_persist, }; use crate::TEST_ENV_LOCK; use hyper::StatusCode; @@ -5268,7 +5407,9 @@ mod tests { use std::thread; use openshell_bootstrap::GatewayMetadata; - use openshell_core::proto::{SandboxCondition, SandboxStatus}; + use openshell_core::proto::{ + Provider, SandboxCondition, SandboxStatus, datamodel::v1::ObjectMeta, + }; struct EnvVarGuard { key: &'static str, @@ -5371,6 +5512,40 @@ mod tests { )); } + #[test] + fn provider_attachment_table_formats_provider_counts() { + let output = format_provider_attachment_table( + &[Provider { + metadata: Some(ObjectMeta { + name: "work-custom".to_string(), + ..Default::default() + }), + r#type: "custom-api".to_string(), + credentials: [ + ("CUSTOM_API_KEY".to_string(), "REDACTED".to_string()), + ("CUSTOM_API_SECRET".to_string(), "REDACTED".to_string()), + ] + .into_iter() + .collect(), + config: std::iter::once(( + "BASE_URL".to_string(), + "https://api.custom.example".to_string(), + )) + .collect(), + }], + false, + ); + + assert!(output.contains("NAME")); + assert!(output.contains("TYPE")); + assert!(output.contains("CREDENTIAL_KEYS")); + assert!(output.contains("CONFIG_KEYS")); + assert!(output.contains("work-custom")); + assert!(output.contains("custom-api")); + assert!(output.contains('2')); + assert!(output.contains('1')); + } + #[cfg(feature = "dev-settings")] #[test] fn parse_cli_setting_value_parses_bool_aliases() { diff --git a/crates/openshell-cli/tests/ensure_providers_integration.rs b/crates/openshell-cli/tests/ensure_providers_integration.rs index 15f620e8e..fec161c53 100644 --- a/crates/openshell-cli/tests/ensure_providers_integration.rs +++ b/crates/openshell-cli/tests/ensure_providers_integration.rs @@ -9,16 +9,18 @@ use openshell_cli::run; use openshell_cli::tls::TlsOptions; use openshell_core::proto::open_shell_server::{OpenShell, OpenShellServer}; use openshell_core::proto::{ - CreateProviderRequest, CreateSandboxRequest, CreateSshSessionRequest, CreateSshSessionResponse, - DeleteProviderRequest, DeleteProviderResponse, DeleteSandboxRequest, DeleteSandboxResponse, - ExecSandboxEvent, ExecSandboxRequest, GatewayMessage, GetGatewayConfigRequest, - GetGatewayConfigResponse, GetProviderRequest, GetSandboxConfigRequest, - GetSandboxConfigResponse, GetSandboxProviderEnvironmentRequest, - GetSandboxProviderEnvironmentResponse, GetSandboxRequest, HealthRequest, HealthResponse, - ListProvidersRequest, ListProvidersResponse, ListSandboxesRequest, ListSandboxesResponse, - Provider, ProviderResponse, RevokeSshSessionRequest, RevokeSshSessionResponse, SandboxResponse, - SandboxStreamEvent, ServiceStatus, SupervisorMessage, UpdateProviderRequest, - WatchSandboxRequest, + AttachSandboxProviderRequest, AttachSandboxProviderResponse, CreateProviderRequest, + CreateSandboxRequest, CreateSshSessionRequest, CreateSshSessionResponse, DeleteProviderRequest, + DeleteProviderResponse, DeleteSandboxRequest, DeleteSandboxResponse, + DetachSandboxProviderRequest, DetachSandboxProviderResponse, ExecSandboxEvent, + ExecSandboxRequest, GatewayMessage, GetGatewayConfigRequest, GetGatewayConfigResponse, + GetProviderRequest, GetSandboxConfigRequest, GetSandboxConfigResponse, + GetSandboxProviderEnvironmentRequest, GetSandboxProviderEnvironmentResponse, GetSandboxRequest, + HealthRequest, HealthResponse, ListProvidersRequest, ListProvidersResponse, + ListSandboxProvidersRequest, ListSandboxProvidersResponse, ListSandboxesRequest, + ListSandboxesResponse, Provider, ProviderResponse, RevokeSshSessionRequest, + RevokeSshSessionResponse, SandboxResponse, SandboxStreamEvent, ServiceStatus, + SupervisorMessage, UpdateProviderRequest, WatchSandboxRequest, }; use openshell_core::{ObjectId, ObjectName}; use rcgen::{ @@ -153,6 +155,27 @@ impl OpenShell for TestOpenShell { Ok(Response::new(ListSandboxesResponse::default())) } + async fn list_sandbox_providers( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(ListSandboxProvidersResponse::default())) + } + + async fn attach_sandbox_provider( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(AttachSandboxProviderResponse::default())) + } + + async fn detach_sandbox_provider( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(DetachSandboxProviderResponse::default())) + } + async fn delete_sandbox( &self, _request: tonic::Request, diff --git a/crates/openshell-cli/tests/mtls_integration.rs b/crates/openshell-cli/tests/mtls_integration.rs index 866048a81..e833e7af9 100644 --- a/crates/openshell-cli/tests/mtls_integration.rs +++ b/crates/openshell-cli/tests/mtls_integration.rs @@ -99,6 +99,33 @@ impl OpenShell for TestOpenShell { )) } + async fn list_sandbox_providers( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new( + openshell_core::proto::ListSandboxProvidersResponse::default(), + )) + } + + async fn attach_sandbox_provider( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new( + openshell_core::proto::AttachSandboxProviderResponse::default(), + )) + } + + async fn detach_sandbox_provider( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new( + openshell_core::proto::DetachSandboxProviderResponse::default(), + )) + } + async fn delete_sandbox( &self, _request: tonic::Request, diff --git a/crates/openshell-cli/tests/provider_commands_integration.rs b/crates/openshell-cli/tests/provider_commands_integration.rs index 55ed69500..3902bda34 100644 --- a/crates/openshell-cli/tests/provider_commands_integration.rs +++ b/crates/openshell-cli/tests/provider_commands_integration.rs @@ -5,16 +5,18 @@ use openshell_cli::run; use openshell_cli::tls::TlsOptions; use openshell_core::proto::open_shell_server::{OpenShell, OpenShellServer}; use openshell_core::proto::{ - CreateProviderRequest, CreateSandboxRequest, CreateSshSessionRequest, CreateSshSessionResponse, - DeleteProviderRequest, DeleteProviderResponse, DeleteSandboxRequest, DeleteSandboxResponse, - ExecSandboxEvent, ExecSandboxRequest, GatewayMessage, GetGatewayConfigRequest, - GetGatewayConfigResponse, GetProviderRequest, GetSandboxConfigRequest, - GetSandboxConfigResponse, GetSandboxProviderEnvironmentRequest, - GetSandboxProviderEnvironmentResponse, GetSandboxRequest, HealthRequest, HealthResponse, - ListProvidersRequest, ListProvidersResponse, ListSandboxesRequest, ListSandboxesResponse, - Provider, ProviderProfile, ProviderResponse, RevokeSshSessionRequest, RevokeSshSessionResponse, - SandboxResponse, SandboxStreamEvent, ServiceStatus, SupervisorMessage, UpdateProviderRequest, - WatchSandboxRequest, + AttachSandboxProviderRequest, AttachSandboxProviderResponse, CreateProviderRequest, + CreateSandboxRequest, CreateSshSessionRequest, CreateSshSessionResponse, DeleteProviderRequest, + DeleteProviderResponse, DeleteSandboxRequest, DeleteSandboxResponse, + DetachSandboxProviderRequest, DetachSandboxProviderResponse, ExecSandboxEvent, + ExecSandboxRequest, GatewayMessage, GetGatewayConfigRequest, GetGatewayConfigResponse, + GetProviderRequest, GetSandboxConfigRequest, GetSandboxConfigResponse, + GetSandboxProviderEnvironmentRequest, GetSandboxProviderEnvironmentResponse, GetSandboxRequest, + HealthRequest, HealthResponse, ListProvidersRequest, ListProvidersResponse, + ListSandboxProvidersRequest, ListSandboxProvidersResponse, ListSandboxesRequest, + ListSandboxesResponse, Provider, ProviderProfile, ProviderResponse, RevokeSshSessionRequest, + RevokeSshSessionResponse, SandboxResponse, SandboxStreamEvent, ServiceStatus, + SupervisorMessage, UpdateProviderRequest, WatchSandboxRequest, }; use openshell_core::{ObjectId, ObjectName}; use rcgen::{ @@ -64,6 +66,23 @@ impl Drop for EnvVarGuard { struct ProviderState { providers: Arc>>, profiles: Arc>>, + sandbox_providers: Arc>>>, + sandbox_provider_requests: Arc>>, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +enum SandboxProviderRequestLog { + List { + sandbox_name: String, + }, + Attach { + sandbox_name: String, + provider_name: String, + }, + Detach { + sandbox_name: String, + provider_name: String, + }, } #[derive(Clone, Default)] @@ -104,6 +123,120 @@ impl OpenShell for TestOpenShell { Ok(Response::new(ListSandboxesResponse::default())) } + async fn list_sandbox_providers( + &self, + request: tonic::Request, + ) -> Result, Status> { + let sandbox_name = request.into_inner().sandbox_name; + self.state + .sandbox_provider_requests + .lock() + .await + .push(SandboxProviderRequestLog::List { + sandbox_name: sandbox_name.clone(), + }); + let provider_names = self + .state + .sandbox_providers + .lock() + .await + .get(&sandbox_name) + .cloned() + .unwrap_or_default(); + let providers_by_name = self.state.providers.lock().await; + let providers = provider_names + .iter() + .filter_map(|name| providers_by_name.get(name).cloned()) + .collect(); + Ok(Response::new(ListSandboxProvidersResponse { providers })) + } + + async fn attach_sandbox_provider( + &self, + request: tonic::Request, + ) -> Result, Status> { + let request = request.into_inner(); + self.state + .sandbox_provider_requests + .lock() + .await + .push(SandboxProviderRequestLog::Attach { + sandbox_name: request.sandbox_name.clone(), + provider_name: request.provider_name.clone(), + }); + if !self + .state + .providers + .lock() + .await + .contains_key(&request.provider_name) + { + return Err(Status::failed_precondition("provider not found")); + } + let mut sandbox_providers = self.state.sandbox_providers.lock().await; + let providers = sandbox_providers + .entry(request.sandbox_name.clone()) + .or_default(); + let attached = if providers.contains(&request.provider_name) { + false + } else { + providers.push(request.provider_name.clone()); + true + }; + let sandbox = openshell_core::proto::Sandbox { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + name: request.sandbox_name, + ..Default::default() + }), + spec: Some(openshell_core::proto::SandboxSpec { + providers: providers.clone(), + ..Default::default() + }), + ..Default::default() + }; + Ok(Response::new(AttachSandboxProviderResponse { + sandbox: Some(sandbox), + attached, + })) + } + + async fn detach_sandbox_provider( + &self, + request: tonic::Request, + ) -> Result, Status> { + let request = request.into_inner(); + self.state + .sandbox_provider_requests + .lock() + .await + .push(SandboxProviderRequestLog::Detach { + sandbox_name: request.sandbox_name.clone(), + provider_name: request.provider_name.clone(), + }); + let mut sandbox_providers = self.state.sandbox_providers.lock().await; + let providers = sandbox_providers + .entry(request.sandbox_name.clone()) + .or_default(); + let before_len = providers.len(); + providers.retain(|name| name != &request.provider_name); + let detached = providers.len() != before_len; + let sandbox = openshell_core::proto::Sandbox { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + name: request.sandbox_name, + ..Default::default() + }), + spec: Some(openshell_core::proto::SandboxSpec { + providers: providers.clone(), + ..Default::default() + }), + ..Default::default() + }; + Ok(Response::new(DetachSandboxProviderResponse { + sandbox: Some(sandbox), + detached, + })) + } + async fn delete_sandbox( &self, _request: tonic::Request, @@ -628,6 +761,90 @@ async fn provider_list_profiles_cli_uses_profile_browsing_rpc() { .expect("provider list-profiles"); } +#[tokio::test] +async fn sandbox_provider_cli_run_functions_wire_requests_and_idempotent_results() { + let ts = run_server().await; + + run::provider_create( + &ts.endpoint, + "work-github", + "github", + false, + &["GITHUB_TOKEN=ghp-test".to_string()], + &[], + &ts.tls, + ) + .await + .expect("provider create"); + + run::sandbox_provider_attach(&ts.endpoint, "dev-sandbox", "work-github", &ts.tls) + .await + .expect("sandbox provider attach"); + run::sandbox_provider_attach(&ts.endpoint, "dev-sandbox", "work-github", &ts.tls) + .await + .expect("sandbox provider attach is idempotent"); + run::sandbox_provider_list(&ts.endpoint, "dev-sandbox", &ts.tls) + .await + .expect("sandbox provider list"); + run::sandbox_provider_detach(&ts.endpoint, "dev-sandbox", "work-github", &ts.tls) + .await + .expect("sandbox provider detach"); + run::sandbox_provider_detach(&ts.endpoint, "dev-sandbox", "work-github", &ts.tls) + .await + .expect("sandbox provider detach is idempotent"); + + let requests = ts.state.sandbox_provider_requests.lock().await.clone(); + assert_eq!( + requests, + vec![ + SandboxProviderRequestLog::Attach { + sandbox_name: "dev-sandbox".to_string(), + provider_name: "work-github".to_string(), + }, + SandboxProviderRequestLog::Attach { + sandbox_name: "dev-sandbox".to_string(), + provider_name: "work-github".to_string(), + }, + SandboxProviderRequestLog::List { + sandbox_name: "dev-sandbox".to_string(), + }, + SandboxProviderRequestLog::Detach { + sandbox_name: "dev-sandbox".to_string(), + provider_name: "work-github".to_string(), + }, + SandboxProviderRequestLog::Detach { + sandbox_name: "dev-sandbox".to_string(), + provider_name: "work-github".to_string(), + }, + ] + ); + + let providers = ts.state.sandbox_providers.lock().await; + assert!(providers.get("dev-sandbox").is_none_or(Vec::is_empty)); +} + +#[tokio::test] +async fn sandbox_provider_attach_cli_surfaces_server_errors() { + let ts = run_server().await; + + let err = + run::sandbox_provider_attach(&ts.endpoint, "dev-sandbox", "missing-provider", &ts.tls) + .await + .expect_err("missing provider should fail"); + + assert!( + err.to_string().contains("provider not found"), + "unexpected error: {err}" + ); + assert_eq!( + ts.state.sandbox_provider_requests.lock().await.as_slice(), + [SandboxProviderRequestLog::Attach { + sandbox_name: "dev-sandbox".to_string(), + provider_name: "missing-provider".to_string(), + }] + ); +} + #[tokio::test] async fn provider_profile_cli_run_functions_support_custom_profiles() { let ts = run_server().await; diff --git a/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs b/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs index a8e359d54..eb28a18b3 100644 --- a/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs +++ b/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs @@ -6,16 +6,19 @@ use openshell_cli::run; use openshell_cli::tls::TlsOptions; use openshell_core::proto::open_shell_server::{OpenShell, OpenShellServer}; use openshell_core::proto::{ - CreateProviderRequest, CreateSandboxRequest, CreateSshSessionRequest, CreateSshSessionResponse, - DeleteProviderRequest, DeleteProviderResponse, DeleteSandboxRequest, DeleteSandboxResponse, - ExecSandboxEvent, ExecSandboxRequest, GatewayMessage, GetGatewayConfigRequest, - GetGatewayConfigResponse, GetProviderRequest, GetSandboxConfigRequest, - GetSandboxConfigResponse, GetSandboxProviderEnvironmentRequest, - GetSandboxProviderEnvironmentResponse, GetSandboxRequest, HealthRequest, HealthResponse, - ListProvidersRequest, ListProvidersResponse, ListSandboxesRequest, ListSandboxesResponse, - PlatformEvent, ProviderResponse, RevokeSshSessionRequest, RevokeSshSessionResponse, Sandbox, - SandboxPhase, SandboxResponse, SandboxStreamEvent, ServiceStatus, SupervisorMessage, - UpdateProviderRequest, WatchSandboxRequest, sandbox_stream_event, + AttachSandboxProviderRequest, AttachSandboxProviderResponse, CreateProviderRequest, + CreateSandboxRequest, CreateSshSessionRequest, CreateSshSessionResponse, DeleteProviderRequest, + DeleteProviderResponse, DeleteSandboxRequest, DeleteSandboxResponse, + DetachSandboxProviderRequest, DetachSandboxProviderResponse, ExecSandboxEvent, + ExecSandboxRequest, GatewayMessage, GetGatewayConfigRequest, GetGatewayConfigResponse, + GetProviderRequest, GetSandboxConfigRequest, GetSandboxConfigResponse, + GetSandboxProviderEnvironmentRequest, GetSandboxProviderEnvironmentResponse, GetSandboxRequest, + HealthRequest, HealthResponse, ListProvidersRequest, ListProvidersResponse, + ListSandboxProvidersRequest, ListSandboxProvidersResponse, ListSandboxesRequest, + ListSandboxesResponse, PlatformEvent, ProviderResponse, RevokeSshSessionRequest, + RevokeSshSessionResponse, Sandbox, SandboxPhase, SandboxResponse, SandboxStreamEvent, + ServiceStatus, SupervisorMessage, UpdateProviderRequest, WatchSandboxRequest, + sandbox_stream_event, }; use rcgen::{ BasicConstraints, Certificate, CertificateParams, ExtendedKeyUsagePurpose, IsCa, KeyPair, @@ -151,6 +154,27 @@ impl OpenShell for TestOpenShell { Ok(Response::new(ListSandboxesResponse::default())) } + async fn list_sandbox_providers( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(ListSandboxProvidersResponse::default())) + } + + async fn attach_sandbox_provider( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(AttachSandboxProviderResponse::default())) + } + + async fn detach_sandbox_provider( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(DetachSandboxProviderResponse::default())) + } + async fn delete_sandbox( &self, request: tonic::Request, diff --git a/crates/openshell-cli/tests/sandbox_name_fallback_integration.rs b/crates/openshell-cli/tests/sandbox_name_fallback_integration.rs index ac1ff37c6..7e6ea68b8 100644 --- a/crates/openshell-cli/tests/sandbox_name_fallback_integration.rs +++ b/crates/openshell-cli/tests/sandbox_name_fallback_integration.rs @@ -6,15 +6,18 @@ use openshell_cli::run; use openshell_cli::tls::TlsOptions; use openshell_core::proto::open_shell_server::{OpenShell, OpenShellServer}; use openshell_core::proto::{ - CreateProviderRequest, CreateSandboxRequest, CreateSshSessionRequest, CreateSshSessionResponse, - DeleteProviderRequest, DeleteProviderResponse, DeleteSandboxRequest, DeleteSandboxResponse, - ExecSandboxEvent, ExecSandboxRequest, GatewayMessage, GetGatewayConfigRequest, - GetGatewayConfigResponse, GetProviderRequest, GetSandboxConfigRequest, - GetSandboxConfigResponse, GetSandboxProviderEnvironmentRequest, - GetSandboxProviderEnvironmentResponse, GetSandboxRequest, HealthRequest, HealthResponse, - ListProvidersRequest, ListProvidersResponse, ListSandboxesRequest, ListSandboxesResponse, - ProviderResponse, Sandbox, SandboxPolicy, SandboxResponse, SandboxStreamEvent, ServiceStatus, - SupervisorMessage, UpdateProviderRequest, WatchSandboxRequest, + AttachSandboxProviderRequest, AttachSandboxProviderResponse, CreateProviderRequest, + CreateSandboxRequest, CreateSshSessionRequest, CreateSshSessionResponse, DeleteProviderRequest, + DeleteProviderResponse, DeleteSandboxRequest, DeleteSandboxResponse, + DetachSandboxProviderRequest, DetachSandboxProviderResponse, ExecSandboxEvent, + ExecSandboxRequest, GatewayMessage, GetGatewayConfigRequest, GetGatewayConfigResponse, + GetProviderRequest, GetSandboxConfigRequest, GetSandboxConfigResponse, + GetSandboxProviderEnvironmentRequest, GetSandboxProviderEnvironmentResponse, GetSandboxRequest, + HealthRequest, HealthResponse, ListProvidersRequest, ListProvidersResponse, + ListSandboxProvidersRequest, ListSandboxProvidersResponse, ListSandboxesRequest, + ListSandboxesResponse, ProviderResponse, Sandbox, SandboxPolicy, SandboxResponse, + SandboxStreamEvent, ServiceStatus, SupervisorMessage, UpdateProviderRequest, + WatchSandboxRequest, }; use rcgen::{ BasicConstraints, Certificate, CertificateParams, ExtendedKeyUsagePurpose, IsCa, KeyPair, @@ -129,6 +132,27 @@ impl OpenShell for TestOpenShell { Ok(Response::new(ListSandboxesResponse::default())) } + async fn list_sandbox_providers( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(ListSandboxProvidersResponse::default())) + } + + async fn attach_sandbox_provider( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(AttachSandboxProviderResponse::default())) + } + + async fn detach_sandbox_provider( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(DetachSandboxProviderResponse::default())) + } + async fn delete_sandbox( &self, _request: tonic::Request, diff --git a/crates/openshell-sandbox/src/grpc_client.rs b/crates/openshell-sandbox/src/grpc_client.rs index 44f372355..cc35f67b5 100644 --- a/crates/openshell-sandbox/src/grpc_client.rs +++ b/crates/openshell-sandbox/src/grpc_client.rs @@ -244,7 +244,7 @@ pub async fn sync_policy(endpoint: &str, sandbox: &str, policy: &ProtoSandboxPol pub async fn fetch_provider_environment( endpoint: &str, sandbox_id: &str, -) -> Result> { +) -> Result { debug!(endpoint = %endpoint, sandbox_id = %sandbox_id, "Fetching provider environment"); let mut client = connect(endpoint).await?; @@ -256,7 +256,11 @@ pub async fn fetch_provider_environment( .await .into_diagnostic()?; - Ok(response.into_inner().environment) + let inner = response.into_inner(); + Ok(ProviderEnvironmentResult { + environment: inner.environment, + provider_env_revision: inner.provider_env_revision, + }) } /// A reusable gRPC client for the `OpenShell` service. @@ -279,6 +283,12 @@ pub struct SettingsPollResult { pub settings: HashMap, /// When `policy_source` is `Global`, the version of the global policy revision. pub global_policy_version: u32, + pub provider_env_revision: u64, +} + +pub struct ProviderEnvironmentResult { + pub environment: HashMap, + pub provider_env_revision: u64, } impl CachedOpenShellClient { @@ -315,6 +325,7 @@ impl CachedOpenShellClient { .unwrap_or(PolicySource::Unspecified), settings: inner.settings, global_policy_version: inner.global_policy_version, + provider_env_revision: inner.provider_env_revision, }) } diff --git a/crates/openshell-sandbox/src/lib.rs b/crates/openshell-sandbox/src/lib.rs index 19424bd2b..abbf7eb65 100644 --- a/crates/openshell-sandbox/src/lib.rs +++ b/crates/openshell-sandbox/src/lib.rs @@ -17,6 +17,7 @@ pub mod opa; mod policy; mod process; pub mod procfs; +mod provider_credentials; pub mod proxy; mod sandbox; mod secrets; @@ -97,7 +98,6 @@ use crate::policy::{NetworkMode, NetworkPolicy, ProxyPolicy, SandboxPolicy}; use crate::proxy::ProxyHandle; #[cfg(target_os = "linux")] use crate::sandbox::linux::netns::NetworkNamespace; -use crate::secrets::SecretResolver; pub use process::{ProcessHandle, ProcessStatus}; pub use sandbox::apply_supervisor_startup_hardening; @@ -269,42 +269,46 @@ pub async fn run_sandbox( // Fetch provider environment variables from the server. // This is done after loading the policy so the sandbox can still start // even if provider env fetch fails (graceful degradation). - let provider_env = if let (Some(id), Some(endpoint)) = (&sandbox_id, &openshell_endpoint) { - match grpc_client::fetch_provider_environment(endpoint, id).await { - Ok(env) => { - ocsf_emit!( - ConfigStateChangeBuilder::new(ocsf_ctx()) - .severity(SeverityId::Informational) - .status(StatusId::Success) - .state(StateId::Enabled, "loaded") - .message(format!( - "Fetched provider environment [env_count:{}]", - env.len() - )) - .build() - ); - env - } - Err(e) => { - ocsf_emit!( - ConfigStateChangeBuilder::new(ocsf_ctx()) - .severity(SeverityId::Medium) - .status(StatusId::Failure) - .state(StateId::Other, "degraded") - .message(format!( - "Failed to fetch provider environment, continuing without: {e}" - )) - .build() - ); - std::collections::HashMap::new() + let (provider_env_revision, provider_env) = + if let (Some(id), Some(endpoint)) = (&sandbox_id, &openshell_endpoint) { + match grpc_client::fetch_provider_environment(endpoint, id).await { + Ok(result) => { + ocsf_emit!( + ConfigStateChangeBuilder::new(ocsf_ctx()) + .severity(SeverityId::Informational) + .status(StatusId::Success) + .state(StateId::Enabled, "loaded") + .message(format!( + "Fetched provider environment [env_count:{}]", + result.environment.len() + )) + .build() + ); + (result.provider_env_revision, result.environment) + } + Err(e) => { + ocsf_emit!( + ConfigStateChangeBuilder::new(ocsf_ctx()) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .state(StateId::Other, "degraded") + .message(format!( + "Failed to fetch provider environment, continuing without: {e}" + )) + .build() + ); + (0, std::collections::HashMap::new()) + } } - } - } else { - std::collections::HashMap::new() - }; + } else { + (0, std::collections::HashMap::new()) + }; - let (provider_env, secret_resolver) = SecretResolver::from_provider_env(provider_env); - let secret_resolver = secret_resolver.map(Arc::new); + let provider_credentials = provider_credentials::ProviderCredentialState::from_environment( + provider_env_revision, + provider_env, + ); + let provider_env = provider_credentials.snapshot().child_env.clone(); // Create identity cache for SHA256 TOFU when OPA is active let identity_cache = opa_engine @@ -480,7 +484,7 @@ pub async fn run_sandbox( entrypoint_pid.clone(), tls_state, inference_ctx, - secret_resolver.clone(), + Some(provider_credentials.clone()), denial_tx, ) .await?; @@ -619,7 +623,7 @@ pub async fn run_sandbox( let proxy_url = ssh_proxy_url; let netns_fd = ssh_netns_fd; let ca_paths = ca_file_paths.clone(); - let provider_env_clone = provider_env.clone(); + let provider_credentials_clone = provider_credentials.clone(); let (ssh_ready_tx, ssh_ready_rx) = tokio::sync::oneshot::channel(); @@ -632,7 +636,7 @@ pub async fn run_sandbox( netns_fd, proxy_url, ca_paths, - provider_env_clone, + provider_credentials_clone, ) .await { @@ -796,6 +800,7 @@ pub async fn run_sandbox( let poll_engine = engine.clone(); let poll_ocsf_enabled = ocsf_enabled.clone(); let poll_pid = entrypoint_pid.clone(); + let poll_provider_credentials = provider_credentials.clone(); let poll_interval_secs: u64 = std::env::var("OPENSHELL_POLICY_POLL_INTERVAL_SECS") .ok() .and_then(|v| v.parse().ok()) @@ -809,6 +814,7 @@ pub async fn run_sandbox( &poll_pid, poll_interval_secs, &poll_ocsf_enabled, + poll_provider_credentials, ) .await { @@ -2152,6 +2158,7 @@ async fn run_policy_poll_loop( entrypoint_pid: &Arc, interval_secs: u64, ocsf_enabled: &std::sync::atomic::AtomicBool, + provider_credentials: provider_credentials::ProviderCredentialState, ) -> Result<()> { use crate::grpc_client::CachedOpenShellClient; use openshell_core::proto::PolicySource; @@ -2159,6 +2166,7 @@ async fn run_policy_poll_loop( let client = CachedOpenShellClient::connect(endpoint).await?; let mut current_config_revision: u64 = 0; + let mut current_provider_env_revision: u64 = provider_credentials.snapshot().revision; let mut current_policy_hash = String::new(); let mut current_settings: std::collections::HashMap< String, @@ -2193,7 +2201,8 @@ async fn run_policy_poll_loop( } }; - if result.config_revision == current_config_revision { + let provider_env_changed = result.provider_env_revision != current_provider_env_revision; + if result.config_revision == current_config_revision && !provider_env_changed { continue; } @@ -2209,12 +2218,46 @@ async fn run_policy_poll_loop( .unmapped("old_config_revision", serde_json::json!(current_config_revision)) .unmapped("new_config_revision", serde_json::json!(result.config_revision)) .unmapped("policy_changed", serde_json::json!(policy_changed)) + .unmapped("provider_env_changed", serde_json::json!(provider_env_changed)) .message(format!( - "Settings poll: config change detected [old_revision:{current_config_revision} new_revision:{} policy_changed:{policy_changed}]", + "Settings poll: config change detected [old_revision:{current_config_revision} new_revision:{} policy_changed:{policy_changed} provider_env_changed:{provider_env_changed}]", result.config_revision )) .build()); + if provider_env_changed { + match grpc_client::fetch_provider_environment(endpoint, sandbox_id).await { + Ok(env_result) => { + let env_count = provider_credentials.install_environment( + env_result.provider_env_revision, + env_result.environment, + ); + current_provider_env_revision = env_result.provider_env_revision; + ocsf_emit!( + ConfigStateChangeBuilder::new(ocsf_ctx()) + .severity(SeverityId::Informational) + .status(StatusId::Success) + .state(StateId::Enabled, "loaded") + .unmapped( + "provider_env_revision", + serde_json::json!(current_provider_env_revision) + ) + .message(format!( + "Provider environment refreshed [revision:{current_provider_env_revision} env_count:{env_count}]" + )) + .build() + ); + } + Err(e) => { + warn!( + error = %e, + provider_env_revision = result.provider_env_revision, + "Settings poll: failed to refresh provider environment" + ); + } + } + } + // Only reload OPA when the policy payload actually changed. if policy_changed { let Some(policy) = result.policy.as_ref() else { diff --git a/crates/openshell-sandbox/src/provider_credentials.rs b/crates/openshell-sandbox/src/provider_credentials.rs new file mode 100644 index 000000000..bd28824ae --- /dev/null +++ b/crates/openshell-sandbox/src/provider_credentials.rs @@ -0,0 +1,143 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Runtime provider credential snapshots. + +use crate::secrets::SecretResolver; +use std::collections::{HashMap, VecDeque}; +use std::sync::{Arc, RwLock}; + +const MAX_RETAINED_CREDENTIAL_GENERATIONS: usize = 8; + +#[derive(Debug, Clone, Default)] +pub struct ProviderCredentialSnapshot { + pub revision: u64, + pub child_env: HashMap, +} + +#[derive(Debug)] +struct ProviderCredentialStateInner { + current: Arc, + generations: VecDeque>, + combined_resolver: Option>, +} + +#[derive(Debug, Clone)] +pub struct ProviderCredentialState { + inner: Arc>, +} + +impl ProviderCredentialState { + pub fn from_environment(revision: u64, env: HashMap) -> Self { + let (child_env, resolver) = SecretResolver::from_provider_env_for_revision(env, revision); + let snapshot = Arc::new(ProviderCredentialSnapshot { + revision, + child_env, + }); + let generations: VecDeque<_> = resolver.map(Arc::new).into_iter().collect(); + let combined_resolver = + SecretResolver::merge(generations.iter().map(Arc::as_ref)).map(Arc::new); + + Self { + inner: Arc::new(RwLock::new(ProviderCredentialStateInner { + current: snapshot, + generations, + combined_resolver, + })), + } + } + + pub fn snapshot(&self) -> Arc { + self.inner + .read() + .expect("provider credential state poisoned") + .current + .clone() + } + + pub fn resolver(&self) -> Option> { + self.inner + .read() + .expect("provider credential state poisoned") + .combined_resolver + .clone() + } + + pub fn install_environment(&self, revision: u64, env: HashMap) -> usize { + let (child_env, resolver) = SecretResolver::from_provider_env_for_revision(env, revision); + let mut inner = self + .inner + .write() + .expect("provider credential state poisoned"); + + inner.current = Arc::new(ProviderCredentialSnapshot { + revision, + child_env, + }); + + if let Some(resolver) = resolver { + inner.generations.push_back(Arc::new(resolver)); + while inner.generations.len() > MAX_RETAINED_CREDENTIAL_GENERATIONS { + inner.generations.pop_front(); + } + } + inner.combined_resolver = + SecretResolver::merge(inner.generations.iter().map(Arc::as_ref)).map(Arc::new); + inner.current.child_env.len() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn snapshots_use_revision_scoped_placeholders() { + let state = ProviderCredentialState::from_environment( + 10, + HashMap::from([("GITHUB_TOKEN".to_string(), "old".to_string())]), + ); + let first = state.snapshot(); + assert_eq!( + first.child_env.get("GITHUB_TOKEN").map(String::as_str), + Some("openshell:resolve:env:v10_GITHUB_TOKEN") + ); + + state.install_environment( + 11, + HashMap::from([("GITHUB_TOKEN".to_string(), "new".to_string())]), + ); + let second = state.snapshot(); + assert_eq!( + second.child_env.get("GITHUB_TOKEN").map(String::as_str), + Some("openshell:resolve:env:v11_GITHUB_TOKEN") + ); + + let resolver = state.resolver().expect("resolver"); + assert_eq!( + resolver.resolve_placeholder("openshell:resolve:env:v10_GITHUB_TOKEN"), + Some("old") + ); + assert_eq!( + resolver.resolve_placeholder("openshell:resolve:env:v11_GITHUB_TOKEN"), + Some("new") + ); + } + + #[test] + fn empty_refresh_removes_env_from_new_snapshots_but_retains_old_resolver() { + let state = ProviderCredentialState::from_environment( + 10, + HashMap::from([("GITHUB_TOKEN".to_string(), "old".to_string())]), + ); + + state.install_environment(11, HashMap::new()); + + assert!(state.snapshot().child_env.is_empty()); + let resolver = state.resolver().expect("old resolver retained"); + assert_eq!( + resolver.resolve_placeholder("openshell:resolve:env:v10_GITHUB_TOKEN"), + Some("old") + ); + } +} diff --git a/crates/openshell-sandbox/src/proxy.rs b/crates/openshell-sandbox/src/proxy.rs index 5344374ac..179576d82 100644 --- a/crates/openshell-sandbox/src/proxy.rs +++ b/crates/openshell-sandbox/src/proxy.rs @@ -8,6 +8,7 @@ use crate::identity::BinaryIdentityCache; use crate::l7::tls::ProxyTlsState; use crate::opa::{NetworkAction, OpaEngine, PolicyGenerationGuard}; use crate::policy::ProxyPolicy; +use crate::provider_credentials::ProviderCredentialState; use crate::secrets::{SecretResolver, rewrite_header_line}; use miette::{IntoDiagnostic, Result}; use openshell_core::net::{is_always_blocked_ip, is_internal_ip}; @@ -147,7 +148,7 @@ impl ProxyHandle { /// The proxy uses OPA for network decisions with process-identity binding /// via `/proc/net/tcp`. All connections are evaluated through OPA policy. #[allow(clippy::too_many_arguments)] - pub async fn start_with_bind_addr( + pub(crate) async fn start_with_bind_addr( policy: &ProxyPolicy, bind_addr: Option, opa_engine: Arc, @@ -155,7 +156,7 @@ impl ProxyHandle { entrypoint_pid: Arc, tls_state: Option>, inference_ctx: Option>, - secret_resolver: Option>, + provider_credentials: Option, denial_tx: Option>, ) -> Result { // Use override bind_addr, fall back to policy http_addr, then default @@ -194,7 +195,9 @@ impl ProxyHandle { let spid = entrypoint_pid.clone(); let tls = tls_state.clone(); let inf = inference_ctx.clone(); - let resolver = secret_resolver.clone(); + let resolver = provider_credentials + .as_ref() + .and_then(ProviderCredentialState::resolver); let dtx = denial_tx.clone(); tokio::spawn(async move { if let Err(err) = handle_tcp_connection( diff --git a/crates/openshell-sandbox/src/secrets.rs b/crates/openshell-sandbox/src/secrets.rs index 63e253e50..d645e1482 100644 --- a/crates/openshell-sandbox/src/secrets.rs +++ b/crates/openshell-sandbox/src/secrets.rs @@ -70,8 +70,16 @@ pub struct SecretResolver { } impl SecretResolver { + #[cfg_attr(not(test), allow(dead_code))] pub(crate) fn from_provider_env( provider_env: HashMap, + ) -> (HashMap, Option) { + Self::from_provider_env_for_revision(provider_env, 0) + } + + pub(crate) fn from_provider_env_for_revision( + provider_env: HashMap, + revision: u64, ) -> (HashMap, Option) { if provider_env.is_empty() { return (HashMap::new(), None); @@ -81,7 +89,7 @@ impl SecretResolver { let mut by_placeholder = HashMap::with_capacity(provider_env.len()); for (key, value) in provider_env { - let placeholder = placeholder_for_env_key(&key); + let placeholder = placeholder_for_env_key_for_revision(&key, revision); child_env.insert(key, placeholder.clone()); by_placeholder.insert(placeholder, value); } @@ -89,6 +97,18 @@ impl SecretResolver { (child_env, Some(Self { by_placeholder })) } + pub(crate) fn merge<'a>(resolvers: impl IntoIterator) -> Option { + let mut by_placeholder = HashMap::new(); + for resolver in resolvers { + by_placeholder.extend(resolver.by_placeholder.clone()); + } + if by_placeholder.is_empty() { + None + } else { + Some(Self { by_placeholder }) + } + } + /// Resolve a placeholder string to the real secret value. /// /// Returns `None` if the placeholder is unknown or the resolved value @@ -178,6 +198,14 @@ pub fn placeholder_for_env_key(key: &str) -> String { format!("{PLACEHOLDER_PREFIX}{key}") } +pub fn placeholder_for_env_key_for_revision(key: &str, revision: u64) -> String { + if revision == 0 { + placeholder_for_env_key(key) + } else { + format!("{PLACEHOLDER_PREFIX}v{revision}_{key}") + } +} + // --------------------------------------------------------------------------- // Secret validation (F1 — CWE-113) // --------------------------------------------------------------------------- diff --git a/crates/openshell-sandbox/src/ssh.rs b/crates/openshell-sandbox/src/ssh.rs index 9434d0a16..355fdc037 100644 --- a/crates/openshell-sandbox/src/ssh.rs +++ b/crates/openshell-sandbox/src/ssh.rs @@ -6,6 +6,7 @@ use crate::child_env; use crate::policy::SandboxPolicy; use crate::process::drop_privileges; +use crate::provider_credentials::ProviderCredentialState; use crate::sandbox; #[cfg(target_os = "linux")] use crate::{register_managed_child, unregister_managed_child}; @@ -105,7 +106,7 @@ pub async fn run_ssh_server( netns_fd: Option, proxy_url: Option, ca_file_paths: Option<(PathBuf, PathBuf)>, - provider_env: HashMap, + provider_credentials: ProviderCredentialState, ) -> Result<()> { let (listener, config, ca_paths) = match ssh_server_init(&listen_path, &ca_file_paths) { Ok(v) => { @@ -129,7 +130,7 @@ pub async fn run_ssh_server( let workdir = workdir.clone(); let proxy_url = proxy_url.clone(); let ca_paths = ca_paths.clone(); - let provider_env = provider_env.clone(); + let provider_credentials = provider_credentials.clone(); tokio::spawn(async move { if let Err(err) = handle_connection( @@ -140,7 +141,7 @@ pub async fn run_ssh_server( netns_fd, proxy_url, ca_paths, - provider_env, + provider_credentials, ) .await { @@ -166,7 +167,7 @@ async fn handle_connection( netns_fd: Option, proxy_url: Option, ca_file_paths: Option>, - provider_env: HashMap, + provider_credentials: ProviderCredentialState, ) -> Result<()> { // Access is gated by the Unix-socket filesystem permissions (root-only), // not by an application-level preface. The supervisor bridges the @@ -188,7 +189,7 @@ async fn handle_connection( netns_fd, proxy_url, ca_file_paths, - provider_env, + provider_credentials, ); russh::server::run_stream(config, stream, handler) .await @@ -215,7 +216,7 @@ struct SshHandler { netns_fd: Option, proxy_url: Option, ca_file_paths: Option>, - provider_env: HashMap, + provider_credentials: ProviderCredentialState, channels: HashMap, } @@ -226,7 +227,7 @@ impl SshHandler { netns_fd: Option, proxy_url: Option, ca_file_paths: Option>, - provider_env: HashMap, + provider_credentials: ProviderCredentialState, ) -> Self { Self { policy, @@ -234,7 +235,7 @@ impl SshHandler { netns_fd, proxy_url, ca_file_paths, - provider_env, + provider_credentials, channels: HashMap::new(), } } @@ -456,7 +457,7 @@ impl russh::server::Handler for SshHandler { self.netns_fd, self.proxy_url.clone(), self.ca_file_paths.clone(), - &self.provider_env, + &self.provider_credentials.snapshot().child_env, )?; let state = self.channels.get_mut(&channel).ok_or_else(|| { anyhow::anyhow!("subsystem_request on unknown channel {channel:?}") @@ -533,6 +534,7 @@ impl SshHandler { handle: Handle, command: Option, ) -> anyhow::Result<()> { + let provider_snapshot = self.provider_credentials.snapshot(); let state = self .channels .get_mut(&channel) @@ -550,7 +552,7 @@ impl SshHandler { self.netns_fd, self.proxy_url.clone(), self.ca_file_paths.clone(), - &self.provider_env, + &provider_snapshot.child_env, )?; state.pty_master = Some(pty_master); state.input_sender = Some(input_sender); @@ -567,7 +569,7 @@ impl SshHandler { self.netns_fd, self.proxy_url.clone(), self.ca_file_paths.clone(), - &self.provider_env, + &provider_snapshot.child_env, )?; state.input_sender = Some(input_sender); } diff --git a/crates/openshell-server/src/auth/authz.rs b/crates/openshell-server/src/auth/authz.rs index 05ac19354..7e69b1cd8 100644 --- a/crates/openshell-server/src/auth/authz.rs +++ b/crates/openshell-server/src/auth/authz.rs @@ -41,6 +41,10 @@ const SCOPED_METHODS: &[(&str, &str)] = &[ // sandbox:read ("/openshell.v1.OpenShell/GetSandbox", "sandbox:read"), ("/openshell.v1.OpenShell/ListSandboxes", "sandbox:read"), + ( + "/openshell.v1.OpenShell/ListSandboxProviders", + "sandbox:read", + ), ("/openshell.v1.OpenShell/WatchSandbox", "sandbox:read"), ("/openshell.v1.OpenShell/GetSandboxLogs", "sandbox:read"), ( @@ -57,6 +61,14 @@ const SCOPED_METHODS: &[(&str, &str)] = &[ ("/openshell.v1.OpenShell/ExecSandbox", "sandbox:write"), ("/openshell.v1.OpenShell/CreateSshSession", "sandbox:write"), ("/openshell.v1.OpenShell/RevokeSshSession", "sandbox:write"), + ( + "/openshell.v1.OpenShell/AttachSandboxProvider", + "sandbox:write", + ), + ( + "/openshell.v1.OpenShell/DetachSandboxProvider", + "sandbox:write", + ), // provider:read ("/openshell.v1.OpenShell/GetProvider", "provider:read"), ("/openshell.v1.OpenShell/ListProviders", "provider:read"), @@ -398,11 +410,26 @@ mod tests { .check(&id, "/openshell.v1.OpenShell/ListSandboxes") .is_ok() ); + assert!( + policy + .check(&id, "/openshell.v1.OpenShell/ListSandboxProviders") + .is_ok() + ); assert!( policy .check(&id, "/openshell.v1.OpenShell/CreateSandbox") .is_ok() ); + assert!( + policy + .check(&id, "/openshell.v1.OpenShell/AttachSandboxProvider") + .is_ok() + ); + assert!( + policy + .check(&id, "/openshell.v1.OpenShell/DetachSandboxProvider") + .is_ok() + ); } #[test] @@ -415,7 +442,7 @@ mod tests { .is_ok() ); let err = policy - .check(&id, "/openshell.v1.OpenShell/CreateSandbox") + .check(&id, "/openshell.v1.OpenShell/AttachSandboxProvider") .unwrap_err(); assert_eq!(err.code(), tonic::Code::PermissionDenied); assert!(err.message().contains("sandbox:write")); diff --git a/crates/openshell-server/src/compute/mod.rs b/crates/openshell-server/src/compute/mod.rs index be8ebe2ba..d2fd34011 100644 --- a/crates/openshell-server/src/compute/mod.rs +++ b/crates/openshell-server/src/compute/mod.rs @@ -275,6 +275,16 @@ impl ComputeRuntime { }) } + /// Serializes sandbox object read-modify-write operations within this + /// gateway process. + /// + /// This is a temporary single-gateway guard for full-object sandbox writes. + /// It is not HA-safe; replace it with DB-backed CAS/resource-version writes + /// tracked by #1255 before enabling multiple gateway writers. + pub(crate) async fn sandbox_sync_guard(&self) -> tokio::sync::OwnedMutexGuard<()> { + self.sync_lock.clone().lock_owned().await + } + pub async fn new_docker( config: openshell_core::Config, docker_config: DockerComputeConfig, diff --git a/crates/openshell-server/src/grpc/mod.rs b/crates/openshell-server/src/grpc/mod.rs index 87af948ed..ebb8b1021 100644 --- a/crates/openshell-server/src/grpc/mod.rs +++ b/crates/openshell-server/src/grpc/mod.rs @@ -10,27 +10,29 @@ mod validation; use openshell_core::proto::{ ApproveAllDraftChunksRequest, ApproveAllDraftChunksResponse, ApproveDraftChunkRequest, - ApproveDraftChunkResponse, ClearDraftChunksRequest, ClearDraftChunksResponse, - CreateProviderRequest, CreateSandboxRequest, CreateSshSessionRequest, CreateSshSessionResponse, - DeleteProviderProfileRequest, DeleteProviderProfileResponse, DeleteProviderRequest, - DeleteProviderResponse, DeleteSandboxRequest, DeleteSandboxResponse, EditDraftChunkRequest, - EditDraftChunkResponse, ExecSandboxEvent, ExecSandboxRequest, GatewayMessage, - GetDraftHistoryRequest, GetDraftHistoryResponse, GetDraftPolicyRequest, GetDraftPolicyResponse, - GetGatewayConfigRequest, GetGatewayConfigResponse, GetProviderProfileRequest, - GetProviderRequest, GetSandboxConfigRequest, GetSandboxConfigResponse, GetSandboxLogsRequest, + ApproveDraftChunkResponse, AttachSandboxProviderRequest, AttachSandboxProviderResponse, + ClearDraftChunksRequest, ClearDraftChunksResponse, CreateProviderRequest, CreateSandboxRequest, + CreateSshSessionRequest, CreateSshSessionResponse, DeleteProviderProfileRequest, + DeleteProviderProfileResponse, DeleteProviderRequest, DeleteProviderResponse, + DeleteSandboxRequest, DeleteSandboxResponse, DetachSandboxProviderRequest, + DetachSandboxProviderResponse, EditDraftChunkRequest, EditDraftChunkResponse, ExecSandboxEvent, + ExecSandboxRequest, GatewayMessage, GetDraftHistoryRequest, GetDraftHistoryResponse, + GetDraftPolicyRequest, GetDraftPolicyResponse, GetGatewayConfigRequest, + GetGatewayConfigResponse, GetProviderProfileRequest, GetProviderRequest, + GetSandboxConfigRequest, GetSandboxConfigResponse, GetSandboxLogsRequest, GetSandboxLogsResponse, GetSandboxPolicyStatusRequest, GetSandboxPolicyStatusResponse, GetSandboxProviderEnvironmentRequest, GetSandboxProviderEnvironmentResponse, GetSandboxRequest, HealthRequest, HealthResponse, ImportProviderProfilesRequest, ImportProviderProfilesResponse, LintProviderProfilesRequest, LintProviderProfilesResponse, ListProviderProfilesRequest, ListProviderProfilesResponse, ListProvidersRequest, ListProvidersResponse, - ListSandboxPoliciesRequest, ListSandboxPoliciesResponse, ListSandboxesRequest, - ListSandboxesResponse, ProviderProfileResponse, ProviderResponse, PushSandboxLogsRequest, - PushSandboxLogsResponse, RejectDraftChunkRequest, RejectDraftChunkResponse, RelayFrame, - ReportPolicyStatusRequest, ReportPolicyStatusResponse, RevokeSshSessionRequest, - RevokeSshSessionResponse, SandboxResponse, SandboxStreamEvent, ServiceStatus, - SubmitPolicyAnalysisRequest, SubmitPolicyAnalysisResponse, SupervisorMessage, - UndoDraftChunkRequest, UndoDraftChunkResponse, UpdateConfigRequest, UpdateConfigResponse, - UpdateProviderRequest, WatchSandboxRequest, open_shell_server::OpenShell, + ListSandboxPoliciesRequest, ListSandboxPoliciesResponse, ListSandboxProvidersRequest, + ListSandboxProvidersResponse, ListSandboxesRequest, ListSandboxesResponse, + ProviderProfileResponse, ProviderResponse, PushSandboxLogsRequest, PushSandboxLogsResponse, + RejectDraftChunkRequest, RejectDraftChunkResponse, RelayFrame, ReportPolicyStatusRequest, + ReportPolicyStatusResponse, RevokeSshSessionRequest, RevokeSshSessionResponse, SandboxResponse, + SandboxStreamEvent, ServiceStatus, SubmitPolicyAnalysisRequest, SubmitPolicyAnalysisResponse, + SupervisorMessage, UndoDraftChunkRequest, UndoDraftChunkResponse, UpdateConfigRequest, + UpdateConfigResponse, UpdateProviderRequest, WatchSandboxRequest, open_shell_server::OpenShell, }; use serde::{Deserialize, Serialize}; use std::collections::BTreeMap; @@ -199,6 +201,27 @@ impl OpenShell for OpenShellService { sandbox::handle_list_sandboxes(&self.state, request).await } + async fn list_sandbox_providers( + &self, + request: Request, + ) -> Result, Status> { + sandbox::handle_list_sandbox_providers(&self.state, request).await + } + + async fn attach_sandbox_provider( + &self, + request: Request, + ) -> Result, Status> { + sandbox::handle_attach_sandbox_provider(&self.state, request).await + } + + async fn detach_sandbox_provider( + &self, + request: Request, + ) -> Result, Status> { + sandbox::handle_detach_sandbox_provider(&self.state, request).await + } + async fn delete_sandbox( &self, request: Request, diff --git a/crates/openshell-server/src/grpc/policy.rs b/crates/openshell-server/src/grpc/policy.rs index 2c62c930a..d5a47bcba 100644 --- a/crates/openshell-server/src/grpc/policy.rs +++ b/crates/openshell-server/src/grpc/policy.rs @@ -10,7 +10,7 @@ #![allow(clippy::cast_precision_loss)] // f64->f32 for confidence scores #![allow(clippy::items_after_statements)] // DB_PORTS const inside function -use crate::persistence::{DraftChunkRecord, ObjectId, ObjectName, PolicyRecord, Store}; +use crate::persistence::{DraftChunkRecord, ObjectId, ObjectName, ObjectType, PolicyRecord, Store}; use crate::policy_store::PolicyStoreExt; use crate::{ServerState, auth::oidc}; use openshell_core::proto::policy_merge_operation; @@ -472,6 +472,8 @@ pub(super) async fn handle_get_sandbox_config( let settings = merge_effective_settings(&global_settings, &sandbox_settings)?; let config_revision = compute_config_revision(policy.as_ref(), &settings, policy_source); + let provider_env_revision = + compute_provider_env_revision(state.store.as_ref(), &sandbox_provider_names).await?; Ok(Response::new(GetSandboxConfigResponse { policy, @@ -481,9 +483,52 @@ pub(super) async fn handle_get_sandbox_config( config_revision, policy_source: policy_source.into(), global_policy_version, + provider_env_revision, })) } +pub(super) async fn compute_provider_env_revision( + store: &Store, + provider_names: &[String], +) -> Result { + let mut hasher = Sha256::new(); + hasher.update(b"openshell-provider-env-revision-v1"); + + for provider_name in provider_names { + hasher.update(provider_name.as_bytes()); + match store + .get_by_name(Provider::object_type(), provider_name) + .await + .map_err(|e| { + Status::internal(format!("fetch provider '{provider_name}' failed: {e}")) + })? { + Some(record) => { + hasher.update(record.id.as_bytes()); + hasher.update(record.updated_at_ms.to_le_bytes()); + + let provider = Provider::decode(record.payload.as_slice()).map_err(|e| { + Status::internal(format!("decode provider '{provider_name}' failed: {e}")) + })?; + hasher.update(provider.r#type.as_bytes()); + + let mut credential_keys: Vec<_> = provider.credentials.keys().collect(); + credential_keys.sort(); + for key in credential_keys { + hasher.update(key.as_bytes()); + } + } + None => { + hasher.update(b"missing"); + } + } + } + + let digest = hasher.finalize(); + Ok(u64::from_le_bytes(digest[..8].try_into().map_err( + |_| Status::internal("provider env revision digest too short"), + )?)) +} + async fn profile_provider_policy_layers( store: &Store, provider_names: &[String], @@ -571,19 +616,24 @@ pub(super) async fn handle_get_sandbox_provider_environment( .spec .ok_or_else(|| Status::internal("sandbox has no spec"))?; + let provider_names = spec.providers; + let provider_env_revision = + compute_provider_env_revision(state.store.as_ref(), &provider_names).await?; let environment = - super::provider::resolve_provider_environment(state.store.as_ref(), &spec.providers) + super::provider::resolve_provider_environment(state.store.as_ref(), &provider_names) .await?; info!( sandbox_id = %sandbox_id, - provider_count = spec.providers.len(), + provider_count = provider_names.len(), env_count = environment.len(), + provider_env_revision, "GetSandboxProviderEnvironment request completed successfully" ); Ok(Response::new(GetSandboxProviderEnvironmentResponse { environment, + provider_env_revision, })) } @@ -950,19 +1000,32 @@ pub(super) async fn handle_update_config( validate_static_fields_unchanged(baseline_policy, &new_policy)?; validate_policy_safety(&new_policy)?; } else { - let mut sandbox = sandbox; - if let Some(ref mut spec) = sandbox.spec { - spec.policy = Some(new_policy.clone()); - } - state + let _sandbox_sync_guard = state.compute.sandbox_sync_guard().await; + let mut sandbox = state .store - .put_message(&sandbox) + .get_message::(&sandbox_id) .await - .map_err(|e| Status::internal(format!("backfill spec.policy failed: {e}")))?; - info!( - sandbox_id = %sandbox_id, - "UpdateConfig: backfilled spec.policy from sandbox-discovered policy" - ); + .map_err(|e| Status::internal(format!("fetch sandbox failed: {e}")))? + .ok_or_else(|| Status::not_found("sandbox not found"))?; + let spec = sandbox + .spec + .as_mut() + .ok_or_else(|| Status::internal("sandbox has no spec"))?; + if let Some(baseline_policy) = spec.policy.as_ref() { + validate_static_fields_unchanged(baseline_policy, &new_policy)?; + validate_policy_safety(&new_policy)?; + } else { + spec.policy = Some(new_policy.clone()); + state + .store + .put_message(&sandbox) + .await + .map_err(|e| Status::internal(format!("backfill spec.policy failed: {e}")))?; + info!( + sandbox_id = %sandbox_id, + "UpdateConfig: backfilled spec.policy from sandbox-discovered policy" + ); + } } let latest = state @@ -1159,6 +1222,7 @@ pub(super) async fn handle_report_policy_status( .store .supersede_older_policies(&req.sandbox_id, version) .await; + let _sandbox_sync_guard = state.compute.sandbox_sync_guard().await; if let Ok(Some(mut sandbox)) = state.store.get_message::(&req.sandbox_id).await { sandbox.current_policy_version = req.version; let _ = state.store.put_message(&sandbox).await; @@ -3322,6 +3386,333 @@ mod tests { assert_eq!(v2_env.get("GITHUB_TOKEN"), Some(&"ghp-test".to_string())); } + #[tokio::test] + async fn provider_env_revision_changes_when_attached_provider_record_changes() { + use openshell_core::proto::GetSandboxProviderEnvironmentRequest; + use std::time::Duration; + + let state = test_server_state().await; + let mut provider = test_provider("work-github", "github"); + state.store.put_message(&provider).await.unwrap(); + state + .store + .put_message(&test_sandbox( + "sb-provider-revision", + "provider-revision", + test_policy_with_rule("sandbox_only", "sandbox.example.com"), + vec!["work-github".to_string()], + )) + .await + .unwrap(); + + let first = handle_get_sandbox_provider_environment( + &state, + Request::new(GetSandboxProviderEnvironmentRequest { + sandbox_id: "sb-provider-revision".to_string(), + }), + ) + .await + .unwrap() + .into_inner(); + + tokio::time::sleep(Duration::from_millis(2)).await; + provider + .credentials + .insert("GITHUB_TOKEN".to_string(), "rotated".to_string()); + state.store.put_message(&provider).await.unwrap(); + + let second = handle_get_sandbox_provider_environment( + &state, + Request::new(GetSandboxProviderEnvironmentRequest { + sandbox_id: "sb-provider-revision".to_string(), + }), + ) + .await + .unwrap() + .into_inner(); + + assert_ne!( + first.provider_env_revision, second.provider_env_revision, + "provider object updates must trigger sandbox credential refresh" + ); + assert_eq!( + second.environment.get("GITHUB_TOKEN"), + Some(&"rotated".to_string()) + ); + } + + #[tokio::test] + async fn sandbox_config_and_provider_env_follow_attached_provider_lifecycle() { + use crate::grpc::sandbox::{ + handle_attach_sandbox_provider, handle_detach_sandbox_provider, + }; + use openshell_core::proto::{ + AttachSandboxProviderRequest, DetachSandboxProviderRequest, + GetSandboxProviderEnvironmentRequest, + }; + + let state = test_server_state().await; + enable_providers_v2(&state).await; + state + .store + .put_message(&test_provider("work-github", "github")) + .await + .unwrap(); + state + .store + .put_message(&test_sandbox( + "sb-attach-lifecycle", + "attach-lifecycle", + test_policy_with_rule("sandbox_only", "sandbox.example.com"), + Vec::new(), + )) + .await + .unwrap(); + + let baseline_policy = get_sandbox_policy(&state, "sb-attach-lifecycle").await; + assert!( + !baseline_policy + .network_policies + .contains_key("_provider_work_github") + ); + let baseline_env = handle_get_sandbox_provider_environment( + &state, + Request::new(GetSandboxProviderEnvironmentRequest { + sandbox_id: "sb-attach-lifecycle".to_string(), + }), + ) + .await + .unwrap() + .into_inner(); + + handle_attach_sandbox_provider( + &state, + Request::new(AttachSandboxProviderRequest { + sandbox_name: "attach-lifecycle".to_string(), + provider_name: "work-github".to_string(), + }), + ) + .await + .unwrap(); + + let attached_policy = get_sandbox_policy(&state, "sb-attach-lifecycle").await; + assert!( + attached_policy + .network_policies + .contains_key("_provider_work_github") + ); + + let attached_env = handle_get_sandbox_provider_environment( + &state, + Request::new(GetSandboxProviderEnvironmentRequest { + sandbox_id: "sb-attach-lifecycle".to_string(), + }), + ) + .await + .unwrap() + .into_inner(); + assert_ne!( + baseline_env.provider_env_revision, + attached_env.provider_env_revision + ); + assert_eq!( + attached_env.environment.get("GITHUB_TOKEN"), + Some(&"ghp-test".to_string()) + ); + + handle_detach_sandbox_provider( + &state, + Request::new(DetachSandboxProviderRequest { + sandbox_name: "attach-lifecycle".to_string(), + provider_name: "work-github".to_string(), + }), + ) + .await + .unwrap(); + + let detached_policy = get_sandbox_policy(&state, "sb-attach-lifecycle").await; + assert!( + !detached_policy + .network_policies + .contains_key("_provider_work_github") + ); + + let detached_env = handle_get_sandbox_provider_environment( + &state, + Request::new(GetSandboxProviderEnvironmentRequest { + sandbox_id: "sb-attach-lifecycle".to_string(), + }), + ) + .await + .unwrap() + .into_inner(); + assert_ne!( + attached_env.provider_env_revision, + detached_env.provider_env_revision + ); + assert!(!detached_env.environment.contains_key("GITHUB_TOKEN")); + } + + #[tokio::test] + #[allow(deprecated)] + async fn custom_imported_profile_policy_and_env_follow_attach_detach_lifecycle() { + use crate::grpc::provider::handle_import_provider_profiles; + use crate::grpc::sandbox::{ + handle_attach_sandbox_provider, handle_detach_sandbox_provider, + }; + use openshell_core::proto::{ + AttachSandboxProviderRequest, DetachSandboxProviderRequest, + GetSandboxProviderEnvironmentRequest, ImportProviderProfilesRequest, NetworkBinary, + ProviderProfile, ProviderProfileCategory, ProviderProfileCredential, + ProviderProfileImportItem, + }; + + let state = test_server_state().await; + enable_providers_v2(&state).await; + handle_import_provider_profiles( + &state, + Request::new(ImportProviderProfilesRequest { + profiles: vec![ProviderProfileImportItem { + source: "custom-api.yaml".to_string(), + profile: Some(ProviderProfile { + id: "custom-api".to_string(), + display_name: "Custom API".to_string(), + description: String::new(), + category: ProviderProfileCategory::Other as i32, + credentials: vec![ProviderProfileCredential { + name: "api_key".to_string(), + env_vars: vec!["CUSTOM_API_KEY".to_string()], + auth_style: "bearer".to_string(), + header_name: "authorization".to_string(), + required: true, + ..Default::default() + }], + endpoints: vec![NetworkEndpoint { + host: "api.custom.example".to_string(), + port: 443, + protocol: "rest".to_string(), + rules: vec![L7Rule { + allow: Some(openshell_core::proto::L7Allow { + method: "GET".to_string(), + path: "/v1/**".to_string(), + ..Default::default() + }), + }], + ..Default::default() + }], + binaries: vec![NetworkBinary { + path: "/usr/bin/custom".to_string(), + harness: true, + }], + inference_capable: false, + }), + }], + }), + ) + .await + .unwrap(); + + let mut provider = test_provider("work-custom", "custom-api"); + provider.credentials = + std::iter::once(("CUSTOM_API_KEY".to_string(), "custom-secret".to_string())).collect(); + state.store.put_message(&provider).await.unwrap(); + state + .store + .put_message(&test_sandbox( + "sb-custom-attach-lifecycle", + "custom-attach-lifecycle", + test_policy_with_rule("sandbox_only", "sandbox.example.com"), + Vec::new(), + )) + .await + .unwrap(); + + let baseline_policy = get_sandbox_policy(&state, "sb-custom-attach-lifecycle").await; + assert!( + !baseline_policy + .network_policies + .contains_key("_provider_work_custom") + ); + let baseline_env = handle_get_sandbox_provider_environment( + &state, + Request::new(GetSandboxProviderEnvironmentRequest { + sandbox_id: "sb-custom-attach-lifecycle".to_string(), + }), + ) + .await + .unwrap() + .into_inner(); + + handle_attach_sandbox_provider( + &state, + Request::new(AttachSandboxProviderRequest { + sandbox_name: "custom-attach-lifecycle".to_string(), + provider_name: "work-custom".to_string(), + }), + ) + .await + .unwrap(); + + let attached_policy = get_sandbox_policy(&state, "sb-custom-attach-lifecycle").await; + let custom_rule = attached_policy + .network_policies + .get("_provider_work_custom") + .expect("custom provider rule should be composed after attach"); + assert_eq!(custom_rule.endpoints[0].host, "api.custom.example"); + assert_eq!(custom_rule.endpoints[0].protocol, "rest"); + assert_eq!(custom_rule.endpoints[0].rules.len(), 1); + assert_eq!(custom_rule.binaries[0].path, "/usr/bin/custom"); + + let attached_env = handle_get_sandbox_provider_environment( + &state, + Request::new(GetSandboxProviderEnvironmentRequest { + sandbox_id: "sb-custom-attach-lifecycle".to_string(), + }), + ) + .await + .unwrap() + .into_inner(); + assert_ne!( + baseline_env.provider_env_revision, + attached_env.provider_env_revision + ); + assert_eq!( + attached_env.environment.get("CUSTOM_API_KEY"), + Some(&"custom-secret".to_string()) + ); + + handle_detach_sandbox_provider( + &state, + Request::new(DetachSandboxProviderRequest { + sandbox_name: "custom-attach-lifecycle".to_string(), + provider_name: "work-custom".to_string(), + }), + ) + .await + .unwrap(); + + let detached_policy = get_sandbox_policy(&state, "sb-custom-attach-lifecycle").await; + assert!( + !detached_policy + .network_policies + .contains_key("_provider_work_custom") + ); + let detached_env = handle_get_sandbox_provider_environment( + &state, + Request::new(GetSandboxProviderEnvironmentRequest { + sandbox_id: "sb-custom-attach-lifecycle".to_string(), + }), + ) + .await + .unwrap() + .into_inner(); + assert_ne!( + attached_env.provider_env_revision, + detached_env.provider_env_revision + ); + assert!(!detached_env.environment.contains_key("CUSTOM_API_KEY")); + } + #[tokio::test] async fn global_policy_suppresses_provider_profile_layers_when_v2_enabled() { use openshell_core::proto::{ diff --git a/crates/openshell-server/src/grpc/provider.rs b/crates/openshell-server/src/grpc/provider.rs index 2f4893073..2ed4d439d 100644 --- a/crates/openshell-server/src/grpc/provider.rs +++ b/crates/openshell-server/src/grpc/provider.rs @@ -6,7 +6,7 @@ #![allow(clippy::result_large_err)] // gRPC handlers return Result, Status> use crate::persistence::{ObjectName, ObjectType, Store, generate_name}; -use openshell_core::proto::Provider; +use openshell_core::proto::{Provider, Sandbox}; use prost::Message; use tonic::Status; use tracing::warn; @@ -175,12 +175,57 @@ pub(super) async fn delete_provider_record(store: &Store, name: &str) -> Result< return Err(Status::invalid_argument("name is required")); } + let blocking_sandboxes = sandboxes_using_provider(store, name).await?; + if !blocking_sandboxes.is_empty() { + return Err(Status::failed_precondition(format!( + "provider '{name}' is attached to sandbox(es): {}", + blocking_sandboxes.join(", ") + ))); + } + store .delete_by_name(Provider::object_type(), name) .await .map_err(|e| Status::internal(format!("delete provider failed: {e}"))) } +async fn sandboxes_using_provider( + store: &Store, + provider_name: &str, +) -> Result, Status> { + let mut blocking = Vec::new(); + let mut offset = 0; + loop { + let records = store + .list(Sandbox::object_type(), 1000, offset) + .await + .map_err(|e| Status::internal(format!("list sandboxes failed: {e}")))?; + if records.is_empty() { + break; + } + offset = offset + .checked_add( + u32::try_from(records.len()) + .map_err(|_| Status::internal("sandbox page size exceeded u32"))?, + ) + .ok_or_else(|| Status::internal("sandbox pagination offset overflow"))?; + + for record in records { + let sandbox = Sandbox::decode(record.payload.as_slice()) + .map_err(|e| Status::internal(format!("decode sandbox failed: {e}")))?; + let Some(spec) = sandbox.spec.as_ref() else { + continue; + }; + if spec.providers.iter().any(|name| name == provider_name) { + blocking.push(sandbox.object_name().to_string()); + } + } + } + blocking.sort(); + blocking.dedup(); + Ok(blocking) +} + /// Merge an incoming map into an existing map. /// /// - If `incoming` is empty, return `existing` unchanged (no-op). @@ -278,8 +323,8 @@ use openshell_core::proto::{ ImportProviderProfilesRequest, ImportProviderProfilesResponse, LintProviderProfilesRequest, LintProviderProfilesResponse, ListProviderProfilesRequest, ListProviderProfilesResponse, ListProvidersRequest, ListProvidersResponse, ProviderProfile, ProviderProfileDiagnostic, - ProviderProfileImportItem, ProviderProfileResponse, ProviderResponse, Sandbox, - StoredProviderProfile, UpdateProviderRequest, + ProviderProfileImportItem, ProviderProfileResponse, ProviderResponse, StoredProviderProfile, + UpdateProviderRequest, }; use openshell_providers::{ ProfileValidationDiagnostic, ProviderTypeProfile, default_profiles, get_default_profile, @@ -1414,6 +1459,43 @@ mod tests { assert_eq!(missing.code(), Code::NotFound); } + #[tokio::test] + async fn delete_provider_rejects_attached_provider() { + let store = Store::connect("sqlite::memory:?cache=shared") + .await + .unwrap(); + + create_provider_record(&store, provider_with_values("gitlab-local", "gitlab")) + .await + .unwrap(); + store + .put_message(&Sandbox { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: "sandbox-id".to_string(), + name: "attached-sandbox".to_string(), + created_at_ms: 0, + labels: HashMap::new(), + }), + spec: Some(SandboxSpec { + providers: vec!["gitlab-local".to_string()], + ..Default::default() + }), + ..Default::default() + }) + .await + .unwrap(); + + let err = delete_provider_record(&store, "gitlab-local") + .await + .unwrap_err(); + assert_eq!(err.code(), Code::FailedPrecondition); + assert!( + err.message().contains("attached-sandbox"), + "error should identify blocking sandbox: {}", + err.message() + ); + } + #[tokio::test] async fn provider_validation_errors() { let store = Store::connect("sqlite::memory:?cache=shared") diff --git a/crates/openshell-server/src/grpc/sandbox.rs b/crates/openshell-server/src/grpc/sandbox.rs index ed1b4cdfc..65ac69acb 100644 --- a/crates/openshell-server/src/grpc/sandbox.rs +++ b/crates/openshell-server/src/grpc/sandbox.rs @@ -13,11 +13,13 @@ use crate::ServerState; use crate::persistence::{ObjectType, generate_name}; use futures::future; use openshell_core::proto::{ - CreateSandboxRequest, CreateSshSessionRequest, CreateSshSessionResponse, DeleteSandboxRequest, - DeleteSandboxResponse, ExecSandboxEvent, ExecSandboxExit, ExecSandboxRequest, - ExecSandboxStderr, ExecSandboxStdout, GetSandboxRequest, ListSandboxesRequest, - ListSandboxesResponse, RevokeSshSessionRequest, RevokeSshSessionResponse, SandboxResponse, - SandboxStreamEvent, WatchSandboxRequest, + AttachSandboxProviderRequest, AttachSandboxProviderResponse, CreateSandboxRequest, + CreateSshSessionRequest, CreateSshSessionResponse, DeleteSandboxRequest, DeleteSandboxResponse, + DetachSandboxProviderRequest, DetachSandboxProviderResponse, ExecSandboxEvent, ExecSandboxExit, + ExecSandboxRequest, ExecSandboxStderr, ExecSandboxStdout, GetSandboxRequest, + ListSandboxProvidersRequest, ListSandboxProvidersResponse, ListSandboxesRequest, + ListSandboxesResponse, Provider, RevokeSshSessionRequest, RevokeSshSessionResponse, + SandboxResponse, SandboxStreamEvent, WatchSandboxRequest, }; use openshell_core::proto::{Sandbox, SandboxPhase, SandboxTemplate, SshSession}; use prost::Message; @@ -31,12 +33,12 @@ use tracing::{info, warn}; use russh::ChannelMsg; use russh::client::AuthResult; -use super::provider::is_valid_env_key; +use super::provider::{get_provider_record, is_valid_env_key}; use super::validation::{ level_matches, source_matches, validate_exec_request_fields, validate_policy_safety, validate_sandbox_spec, }; -use super::{MAX_PAGE_SIZE, clamp_limit, current_time_ms}; +use super::{MAX_PAGE_SIZE, MAX_PROVIDERS, clamp_limit, current_time_ms}; // --------------------------------------------------------------------------- // Sandbox lifecycle handlers @@ -66,7 +68,7 @@ pub(super) async fn handle_create_sandbox( for name in &spec.providers { state .store - .get_message_by_name::(name) + .get_message_by_name::(name) .await .map_err(|e| Status::internal(format!("fetch provider failed: {e}")))? .ok_or_else(|| Status::failed_precondition(format!("provider '{name}' not found")))?; @@ -192,6 +194,130 @@ pub(super) async fn handle_list_sandboxes( Ok(Response::new(ListSandboxesResponse { sandboxes })) } +pub(super) async fn handle_list_sandbox_providers( + state: &Arc, + request: Request, +) -> Result, Status> { + let sandbox = sandbox_by_name(state, &request.into_inner().sandbox_name).await?; + let providers = providers_for_sandbox(state, &sandbox).await?; + Ok(Response::new(ListSandboxProvidersResponse { providers })) +} + +pub(super) async fn handle_attach_sandbox_provider( + state: &Arc, + request: Request, +) -> Result, Status> { + let request = request.into_inner(); + if request.provider_name.is_empty() { + return Err(Status::invalid_argument("provider_name is required")); + } + + get_provider_record(state.store.as_ref(), &request.provider_name) + .await + .map_err(|err| { + if err.code() == tonic::Code::NotFound { + Status::failed_precondition(format!( + "provider '{}' not found", + request.provider_name + )) + } else { + err + } + })?; + + let _sandbox_sync_guard = state.compute.sandbox_sync_guard().await; + let mut sandbox = sandbox_by_name(state, &request.sandbox_name).await?; + let sandbox_name = sandbox + .metadata + .as_ref() + .map_or_else(String::new, |metadata| metadata.name.clone()); + let spec = sandbox + .spec + .as_mut() + .ok_or_else(|| Status::failed_precondition("sandbox spec is missing"))?; + + dedupe_provider_names(&mut spec.providers); + let attached = if spec + .providers + .iter() + .any(|name| name == &request.provider_name) + { + false + } else { + if spec.providers.len() >= MAX_PROVIDERS { + return Err(Status::invalid_argument(format!( + "providers list exceeds maximum ({MAX_PROVIDERS})" + ))); + } + spec.providers.push(request.provider_name.clone()); + true + }; + validate_sandbox_spec(&sandbox_name, spec)?; + + state + .store + .put_message(&sandbox) + .await + .map_err(|e| Status::internal(format!("persist sandbox failed: {e}")))?; + + info!( + sandbox_name = %request.sandbox_name, + provider_name = %request.provider_name, + attached, + "AttachSandboxProvider request completed successfully" + ); + + Ok(Response::new(AttachSandboxProviderResponse { + sandbox: Some(sandbox), + attached, + })) +} + +pub(super) async fn handle_detach_sandbox_provider( + state: &Arc, + request: Request, +) -> Result, Status> { + let request = request.into_inner(); + if request.provider_name.is_empty() { + return Err(Status::invalid_argument("provider_name is required")); + } + + let _sandbox_sync_guard = state.compute.sandbox_sync_guard().await; + let mut sandbox = sandbox_by_name(state, &request.sandbox_name).await?; + let sandbox_name = sandbox + .metadata + .as_ref() + .map_or_else(String::new, |metadata| metadata.name.clone()); + let spec = sandbox + .spec + .as_mut() + .ok_or_else(|| Status::failed_precondition("sandbox spec is missing"))?; + + let before_len = spec.providers.len(); + spec.providers.retain(|name| name != &request.provider_name); + let detached = spec.providers.len() != before_len; + dedupe_provider_names(&mut spec.providers); + validate_sandbox_spec(&sandbox_name, spec)?; + + state + .store + .put_message(&sandbox) + .await + .map_err(|e| Status::internal(format!("persist sandbox failed: {e}")))?; + + info!( + sandbox_name = %request.sandbox_name, + provider_name = %request.provider_name, + detached, + "DetachSandboxProvider request completed successfully" + ); + + Ok(Response::new(DetachSandboxProviderResponse { + sandbox: Some(sandbox), + detached, + })) +} + pub(super) async fn handle_delete_sandbox( state: &Arc, request: Request, @@ -206,6 +332,56 @@ pub(super) async fn handle_delete_sandbox( Ok(Response::new(DeleteSandboxResponse { deleted })) } +async fn sandbox_by_name(state: &Arc, name: &str) -> Result { + if name.is_empty() { + return Err(Status::invalid_argument("sandbox_name is required")); + } + + state + .store + .get_message_by_name::(name) + .await + .map_err(|e| Status::internal(format!("fetch sandbox failed: {e}")))? + .ok_or_else(|| Status::not_found("sandbox not found")) +} + +async fn providers_for_sandbox( + state: &Arc, + sandbox: &Sandbox, +) -> Result, Status> { + let provider_names = sandbox + .spec + .as_ref() + .map(|spec| spec.providers.as_slice()) + .ok_or_else(|| Status::failed_precondition("sandbox spec is missing"))?; + + let mut providers = Vec::with_capacity(provider_names.len()); + for name in provider_names { + let provider = get_provider_record(state.store.as_ref(), name) + .await + .map_err(|err| { + if err.code() == tonic::Code::NotFound { + Status::failed_precondition(format!("provider '{name}' not found")) + } else { + err + } + })?; + providers.push(provider); + } + Ok(providers) +} + +fn dedupe_provider_names(provider_names: &mut Vec) { + let mut index = 0; + while index < provider_names.len() { + if provider_names[..index].contains(&provider_names[index]) { + provider_names.remove(index); + } else { + index += 1; + } + } +} + // --------------------------------------------------------------------------- // Watch handler // --------------------------------------------------------------------------- @@ -938,6 +1114,15 @@ async fn run_exec_with_russh( #[cfg(test)] mod tests { use super::*; + use crate::compute::new_test_runtime; + use crate::persistence::Store; + use crate::sandbox_index::SandboxIndex; + use crate::sandbox_watch::SandboxWatchBus; + use crate::supervisor_session::SupervisorSessionRegistry; + use crate::tracing_bus::TracingLogBus; + use openshell_core::Config; + use openshell_core::proto::datamodel::v1::ObjectMeta; + use std::collections::HashMap; // ---- shell_escape ---- @@ -1066,4 +1251,241 @@ mod tests { ); } } + + async fn test_server_state() -> Arc { + let store = Arc::new(Store::connect("sqlite::memory:").await.unwrap()); + let compute = new_test_runtime(store.clone()).await; + Arc::new(ServerState::new( + Config::new(None) + .with_database_url("sqlite::memory:") + .with_ssh_handshake_secret("test-secret"), + store, + compute, + SandboxIndex::new(), + SandboxWatchBus::new(), + TracingLogBus::new(), + Arc::new(SupervisorSessionRegistry::new()), + None, + )) + } + + fn test_provider(name: &str, provider_type: &str) -> Provider { + Provider { + metadata: Some(ObjectMeta { + id: format!("provider-{name}"), + name: name.to_string(), + created_at_ms: 1_000_000, + labels: HashMap::new(), + }), + r#type: provider_type.to_string(), + credentials: std::iter::once(("TOKEN".to_string(), "secret".to_string())).collect(), + config: HashMap::new(), + } + } + + fn test_sandbox(name: &str, providers: Vec) -> Sandbox { + Sandbox { + metadata: Some(ObjectMeta { + id: format!("sandbox-{name}"), + name: name.to_string(), + created_at_ms: 1_000_000, + labels: std::iter::once(("team".to_string(), "agents".to_string())).collect(), + }), + spec: Some(openshell_core::proto::SandboxSpec { + log_level: "debug".to_string(), + policy: Some(openshell_core::proto::SandboxPolicy::default()), + providers, + ..Default::default() + }), + phase: SandboxPhase::Ready as i32, + current_policy_version: 7, + ..Default::default() + } + } + + #[tokio::test] + async fn attach_sandbox_provider_persists_current_provider_list() { + let state = test_server_state().await; + state + .store + .put_message(&test_provider("work-github", "github")) + .await + .unwrap(); + state + .store + .put_message(&test_sandbox("work", Vec::new())) + .await + .unwrap(); + + let response = handle_attach_sandbox_provider( + &state, + Request::new(AttachSandboxProviderRequest { + sandbox_name: "work".to_string(), + provider_name: "work-github".to_string(), + }), + ) + .await + .unwrap() + .into_inner(); + + assert!(response.attached); + let sandbox = state + .store + .get_message_by_name::("work") + .await + .unwrap() + .unwrap(); + let spec = sandbox.spec.unwrap(); + assert_eq!(spec.providers, vec!["work-github"]); + assert_eq!(spec.log_level, "debug"); + assert_eq!(sandbox.phase, SandboxPhase::Ready as i32); + assert_eq!(sandbox.current_policy_version, 7); + } + + #[tokio::test] + async fn attach_sandbox_provider_is_idempotent_and_avoids_duplicates() { + let state = test_server_state().await; + state + .store + .put_message(&test_provider("work-github", "github")) + .await + .unwrap(); + state + .store + .put_message(&test_sandbox( + "work", + vec!["work-github".to_string(), "work-github".to_string()], + )) + .await + .unwrap(); + + let response = handle_attach_sandbox_provider( + &state, + Request::new(AttachSandboxProviderRequest { + sandbox_name: "work".to_string(), + provider_name: "work-github".to_string(), + }), + ) + .await + .unwrap() + .into_inner(); + + assert!(!response.attached); + let providers = state + .store + .get_message_by_name::("work") + .await + .unwrap() + .unwrap() + .spec + .unwrap() + .providers; + assert_eq!(providers, vec!["work-github"]); + } + + #[tokio::test] + async fn detach_sandbox_provider_is_idempotent_and_removes_all_matches() { + let state = test_server_state().await; + state + .store + .put_message(&test_sandbox( + "work", + vec![ + "work-github".to_string(), + "other".to_string(), + "work-github".to_string(), + ], + )) + .await + .unwrap(); + + let response = handle_detach_sandbox_provider( + &state, + Request::new(DetachSandboxProviderRequest { + sandbox_name: "work".to_string(), + provider_name: "work-github".to_string(), + }), + ) + .await + .unwrap() + .into_inner(); + + assert!(response.detached); + let providers = state + .store + .get_message_by_name::("work") + .await + .unwrap() + .unwrap() + .spec + .unwrap() + .providers; + assert_eq!(providers, vec!["other"]); + + let response = handle_detach_sandbox_provider( + &state, + Request::new(DetachSandboxProviderRequest { + sandbox_name: "work".to_string(), + provider_name: "work-github".to_string(), + }), + ) + .await + .unwrap() + .into_inner(); + assert!(!response.detached); + } + + #[tokio::test] + async fn list_sandbox_providers_returns_attached_provider_records() { + let state = test_server_state().await; + state + .store + .put_message(&test_provider("work-github", "github")) + .await + .unwrap(); + state + .store + .put_message(&test_sandbox("work", vec!["work-github".to_string()])) + .await + .unwrap(); + + let response = handle_list_sandbox_providers( + &state, + Request::new(ListSandboxProvidersRequest { + sandbox_name: "work".to_string(), + }), + ) + .await + .unwrap() + .into_inner(); + + assert_eq!(response.providers.len(), 1); + assert_eq!(response.providers[0].r#type, "github"); + assert_eq!( + response.providers[0].credentials.get("TOKEN"), + Some(&"REDACTED".to_string()) + ); + } + + #[tokio::test] + async fn attach_sandbox_provider_validates_provider_exists() { + let state = test_server_state().await; + state + .store + .put_message(&test_sandbox("work", Vec::new())) + .await + .unwrap(); + + let err = handle_attach_sandbox_provider( + &state, + Request::new(AttachSandboxProviderRequest { + sandbox_name: "work".to_string(), + provider_name: "missing".to_string(), + }), + ) + .await + .unwrap_err(); + + assert_eq!(err.code(), tonic::Code::FailedPrecondition); + } } diff --git a/crates/openshell-server/tests/auth_endpoint_integration.rs b/crates/openshell-server/tests/auth_endpoint_integration.rs index 7b16ee991..c66f2ad6b 100644 --- a/crates/openshell-server/tests/auth_endpoint_integration.rs +++ b/crates/openshell-server/tests/auth_endpoint_integration.rs @@ -426,6 +426,36 @@ impl openshell_core::proto::open_shell_server::OpenShell for TestOpenShell { )) } + async fn list_sandbox_providers( + &self, + _: tonic::Request, + ) -> Result, tonic::Status> + { + Ok(tonic::Response::new( + openshell_core::proto::ListSandboxProvidersResponse::default(), + )) + } + + async fn attach_sandbox_provider( + &self, + _: tonic::Request, + ) -> Result, tonic::Status> + { + Ok(tonic::Response::new( + openshell_core::proto::AttachSandboxProviderResponse::default(), + )) + } + + async fn detach_sandbox_provider( + &self, + _: tonic::Request, + ) -> Result, tonic::Status> + { + Ok(tonic::Response::new( + openshell_core::proto::DetachSandboxProviderResponse::default(), + )) + } + async fn delete_sandbox( &self, _: tonic::Request, diff --git a/crates/openshell-server/tests/edge_tunnel_auth.rs b/crates/openshell-server/tests/edge_tunnel_auth.rs index 39df0819f..706967d1f 100644 --- a/crates/openshell-server/tests/edge_tunnel_auth.rs +++ b/crates/openshell-server/tests/edge_tunnel_auth.rs @@ -106,6 +106,33 @@ impl OpenShell for TestOpenShell { Ok(Response::new(ListSandboxesResponse::default())) } + async fn list_sandbox_providers( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new( + openshell_core::proto::ListSandboxProvidersResponse::default(), + )) + } + + async fn attach_sandbox_provider( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new( + openshell_core::proto::AttachSandboxProviderResponse::default(), + )) + } + + async fn detach_sandbox_provider( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new( + openshell_core::proto::DetachSandboxProviderResponse::default(), + )) + } + async fn delete_sandbox( &self, _request: tonic::Request, diff --git a/crates/openshell-server/tests/multiplex_integration.rs b/crates/openshell-server/tests/multiplex_integration.rs index 49c6f9c92..d5631319d 100644 --- a/crates/openshell-server/tests/multiplex_integration.rs +++ b/crates/openshell-server/tests/multiplex_integration.rs @@ -64,6 +64,33 @@ impl OpenShell for TestOpenShell { Ok(Response::new(ListSandboxesResponse::default())) } + async fn list_sandbox_providers( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new( + openshell_core::proto::ListSandboxProvidersResponse::default(), + )) + } + + async fn attach_sandbox_provider( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new( + openshell_core::proto::AttachSandboxProviderResponse::default(), + )) + } + + async fn detach_sandbox_provider( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new( + openshell_core::proto::DetachSandboxProviderResponse::default(), + )) + } + async fn delete_sandbox( &self, _request: tonic::Request, diff --git a/crates/openshell-server/tests/multiplex_tls_integration.rs b/crates/openshell-server/tests/multiplex_tls_integration.rs index d6a244e49..c4f68eaf4 100644 --- a/crates/openshell-server/tests/multiplex_tls_integration.rs +++ b/crates/openshell-server/tests/multiplex_tls_integration.rs @@ -77,6 +77,33 @@ impl OpenShell for TestOpenShell { Ok(Response::new(ListSandboxesResponse::default())) } + async fn list_sandbox_providers( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new( + openshell_core::proto::ListSandboxProvidersResponse::default(), + )) + } + + async fn attach_sandbox_provider( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new( + openshell_core::proto::AttachSandboxProviderResponse::default(), + )) + } + + async fn detach_sandbox_provider( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new( + openshell_core::proto::DetachSandboxProviderResponse::default(), + )) + } + async fn delete_sandbox( &self, _request: tonic::Request, diff --git a/crates/openshell-server/tests/supervisor_relay_integration.rs b/crates/openshell-server/tests/supervisor_relay_integration.rs index f8519cdc7..8f5cac03a 100644 --- a/crates/openshell-server/tests/supervisor_relay_integration.rs +++ b/crates/openshell-server/tests/supervisor_relay_integration.rs @@ -111,6 +111,24 @@ impl OpenShell for RelayGateway { ) -> Result, Status> { Err(Status::unimplemented("unused")) } + async fn list_sandbox_providers( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + async fn attach_sandbox_provider( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + async fn detach_sandbox_provider( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } async fn delete_sandbox( &self, _: tonic::Request, diff --git a/crates/openshell-server/tests/ws_tunnel_integration.rs b/crates/openshell-server/tests/ws_tunnel_integration.rs index 173a7225d..8212b1085 100644 --- a/crates/openshell-server/tests/ws_tunnel_integration.rs +++ b/crates/openshell-server/tests/ws_tunnel_integration.rs @@ -100,6 +100,33 @@ impl OpenShell for TestOpenShell { Ok(Response::new(ListSandboxesResponse::default())) } + async fn list_sandbox_providers( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new( + openshell_core::proto::ListSandboxProvidersResponse::default(), + )) + } + + async fn attach_sandbox_provider( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new( + openshell_core::proto::AttachSandboxProviderResponse::default(), + )) + } + + async fn detach_sandbox_provider( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new( + openshell_core::proto::DetachSandboxProviderResponse::default(), + )) + } + async fn delete_sandbox( &self, _request: tonic::Request, diff --git a/e2e/python/test_sandbox_providers.py b/e2e/python/test_sandbox_providers.py index 0e9349dc0..337e3db8b 100644 --- a/e2e/python/test_sandbox_providers.py +++ b/e2e/python/test_sandbox_providers.py @@ -11,6 +11,7 @@ from __future__ import annotations +import time from contextlib import contextmanager from typing import TYPE_CHECKING @@ -30,6 +31,17 @@ # --------------------------------------------------------------------------- +def _is_placeholder_for_env_key(value: str, key: str) -> bool: + """Return true when value is an OpenShell credential placeholder for key.""" + prefix = "openshell:resolve:env:" + if value == f"{prefix}{key}": + return True + token = value.removeprefix(prefix) + if token == value: + return False + return token.startswith("v") and token.endswith(f"_{key}") + + def _default_policy() -> sandbox_pb2.SandboxPolicy: """Build a sandbox policy with standard filesystem/process/landlock settings.""" return sandbox_pb2.SandboxPolicy( @@ -117,7 +129,7 @@ def read_env_var() -> str: result = sb.exec_python(read_env_var) assert result.exit_code == 0, result.stderr value = result.stdout.strip() - assert value == "openshell:resolve:env:ANTHROPIC_API_KEY" + assert _is_placeholder_for_env_key(value, "ANTHROPIC_API_KEY") assert value != "sk-e2e-test-key-12345" @@ -150,10 +162,9 @@ def read_generic_env_vars() -> str: with sandbox(spec=spec, delete_on_exit=True) as sb: result = sb.exec_python(read_generic_env_vars) assert result.exit_code == 0, result.stderr - assert ( - result.stdout.strip() - == "openshell:resolve:env:CUSTOM_SERVICE_TOKEN|openshell:resolve:env:CUSTOM_SERVICE_URL" - ) + token, url = result.stdout.strip().split("|") + assert _is_placeholder_for_env_key(token, "CUSTOM_SERVICE_TOKEN") + assert _is_placeholder_for_env_key(url, "CUSTOM_SERVICE_URL") def test_nvidia_provider_injects_nvidia_api_key_env_var( @@ -180,7 +191,84 @@ def read_nvidia_key() -> str: with sandbox(spec=spec, delete_on_exit=True) as sb: result = sb.exec_python(read_nvidia_key) assert result.exit_code == 0, result.stderr - assert result.stdout.strip() == "openshell:resolve:env:NVIDIA_API_KEY" + assert _is_placeholder_for_env_key( + result.stdout.strip(), "NVIDIA_API_KEY" + ) + + +def test_attach_detach_updates_credentials_for_later_exec_launches( + sandbox: Callable[..., Sandbox], + sandbox_client: SandboxClient, +) -> None: + """Later exec launches see provider attach/detach credential changes.""" + stub = sandbox_client._stub + provider_name = "e2e-test-attach-detach-env" + + with provider( + stub, + name=provider_name, + provider_type="generic", + credentials={"CUSTOM_ATTACH_TOKEN": "token-attach-detach"}, + ): + spec = datamodel_pb2.SandboxSpec(policy=_default_policy(), providers=[]) + + def read_attach_token() -> str: + import os + + return os.environ.get("CUSTOM_ATTACH_TOKEN", "NOT_SET") + + def exec_token(sb: Sandbox) -> str: + result = sb.exec_python(read_attach_token) + assert result.exit_code == 0, result.stderr + return result.stdout.strip() + + def wait_for_token(sb: Sandbox, expected: str) -> None: + deadline = time.monotonic() + 35 + last = None + while time.monotonic() < deadline: + last = exec_token(sb) + if expected == "NOT_SET": + matched = last == expected + else: + matched = _is_placeholder_for_env_key(last, "CUSTOM_ATTACH_TOKEN") + if matched: + return + time.sleep(2) + pytest.fail(f"expected {expected!r}, last exec saw {last!r}") + + with sandbox(spec=spec, delete_on_exit=True) as sb: + assert exec_token(sb) == "NOT_SET" + + try: + stub.AttachSandboxProvider( + openshell_pb2.AttachSandboxProviderRequest( + sandbox_name=sb.sandbox.name, + provider_name=provider_name, + ) + ) + wait_for_token( + sb, + "openshell:resolve:env:CUSTOM_ATTACH_TOKEN", + ) + + stub.DetachSandboxProvider( + openshell_pb2.DetachSandboxProviderRequest( + sandbox_name=sb.sandbox.name, + provider_name=provider_name, + ) + ) + wait_for_token(sb, "NOT_SET") + finally: + try: + stub.DetachSandboxProvider( + openshell_pb2.DetachSandboxProviderRequest( + sandbox_name=sb.sandbox.name, + provider_name=provider_name, + ) + ) + except grpc.RpcError as exc: + if exc.code() != grpc.StatusCode.NOT_FOUND: + raise # =========================================================================== diff --git a/e2e/rust/tests/provider_auto_create.rs b/e2e/rust/tests/provider_auto_create.rs index 46ccb7999..c678c46c4 100644 --- a/e2e/rust/tests/provider_auto_create.rs +++ b/e2e/rust/tests/provider_auto_create.rs @@ -24,9 +24,17 @@ use openshell_e2e::harness::binary::openshell_cmd; use openshell_e2e::harness::output::{extract_field, strip_ansi}; const TEST_API_KEY: &str = "sk-e2e-auto-provider-test-key"; -const TEST_API_KEY_PLACEHOLDER: &str = "openshell:resolve:env:ANTHROPIC_API_KEY"; static CLAUDE_PROVIDER_LOCK: Mutex<()> = Mutex::new(()); +fn contains_placeholder_for_env_key(output: &str, key: &str) -> bool { + let legacy = format!("openshell:resolve:env:{key}"); + let revision_prefix = "openshell:resolve:env:v"; + let revision_suffix = format!("_{key}"); + output.split_whitespace().any(|token| { + token == legacy || (token.starts_with(revision_prefix) && token.ends_with(&revision_suffix)) + }) +} + /// Helper: delete a provider by name, ignoring errors. async fn delete_provider(name: &str) { let mut cmd = openshell_cmd(); @@ -123,7 +131,7 @@ async fn auto_created_provider_credential_available_in_sandbox() { ); assert!( - clean.contains(TEST_API_KEY_PLACEHOLDER), + contains_placeholder_for_env_key(&clean, "ANTHROPIC_API_KEY"), "sandbox should have placeholder ANTHROPIC_API_KEY in its environment:\n{clean}" ); diff --git a/proto/openshell.proto b/proto/openshell.proto index c6a9bab89..b0291254a 100644 --- a/proto/openshell.proto +++ b/proto/openshell.proto @@ -30,6 +30,18 @@ service OpenShell { // List sandboxes. rpc ListSandboxes(ListSandboxesRequest) returns (ListSandboxesResponse); + // List provider records attached to a sandbox. + rpc ListSandboxProviders(ListSandboxProvidersRequest) + returns (ListSandboxProvidersResponse); + + // Attach a provider record to an existing sandbox. + rpc AttachSandboxProvider(AttachSandboxProviderRequest) + returns (AttachSandboxProviderResponse); + + // Detach a provider record from an existing sandbox. + rpc DetachSandboxProvider(DetachSandboxProviderRequest) + returns (DetachSandboxProviderResponse); + // Delete a sandbox by name. rpc DeleteSandbox(DeleteSandboxRequest) returns (DeleteSandboxResponse); @@ -339,6 +351,28 @@ message ListSandboxesRequest { string label_selector = 3; } +// List providers attached to a sandbox request. +message ListSandboxProvidersRequest { + // Sandbox name (canonical lookup key). + string sandbox_name = 1; +} + +// Attach provider to sandbox request. +message AttachSandboxProviderRequest { + // Sandbox name (canonical lookup key). + string sandbox_name = 1; + // Provider name to attach. + string provider_name = 2; +} + +// Detach provider from sandbox request. +message DetachSandboxProviderRequest { + // Sandbox name (canonical lookup key). + string sandbox_name = 1; + // Provider name to detach. + string provider_name = 2; +} + // Delete sandbox request. message DeleteSandboxRequest { // Sandbox name (canonical lookup key). @@ -355,6 +389,25 @@ message ListSandboxesResponse { repeated Sandbox sandboxes = 1; } +// List providers attached to a sandbox response. +message ListSandboxProvidersResponse { + repeated openshell.datamodel.v1.Provider providers = 1; +} + +// Attach provider to sandbox response. +message AttachSandboxProviderResponse { + Sandbox sandbox = 1; + // True when the provider was newly attached. False means it was already attached. + bool attached = 2; +} + +// Detach provider from sandbox response. +message DetachSandboxProviderResponse { + Sandbox sandbox = 1; + // True when the provider was removed. False means it was not attached. + bool detached = 2; +} + // Delete sandbox response. message DeleteSandboxResponse { bool deleted = 1; @@ -713,6 +766,8 @@ message GetSandboxProviderEnvironmentRequest { message GetSandboxProviderEnvironmentResponse { // Provider credential environment variables. map environment = 1; + // Fingerprint for the provider credential inputs that produced environment. + uint64 provider_env_revision = 2; } // --------------------------------------------------------------------------- diff --git a/proto/sandbox.proto b/proto/sandbox.proto index 2ea8659da..f7df5945e 100644 --- a/proto/sandbox.proto +++ b/proto/sandbox.proto @@ -257,4 +257,7 @@ message GetSandboxConfigResponse { // When policy_source is GLOBAL, the version of the global policy revision. // Zero when no global policy is active or when policy_source is SANDBOX. uint32 global_policy_version = 7; + // Fingerprint for provider credential inputs attached to this sandbox. + // Changes when attached provider names or attached provider records change. + uint64 provider_env_revision = 8; } From 31f03456a948198b32b4be8c38c2acf11f9509eb Mon Sep 17 00:00:00 2001 From: jtoelke2 <149006449+jtoelke2@users.noreply.github.com> Date: Fri, 8 May 2026 14:45:22 -0500 Subject: [PATCH 011/157] ci(os-132): remove obsolete shadow workflows (#1273) Signed-off-by: Jonas Toelke --- .github/workflows/docker-build.yml | 2 +- ...native-build.yml => rust-native-build.yml} | 55 +------------- .github/workflows/shadow-docker-build.yml | 35 --------- .github/workflows/shadow-shared-cpu-spike.yml | 73 ------------------- 4 files changed, 4 insertions(+), 161 deletions(-) rename .github/workflows/{shadow-rust-native-build.yml => rust-native-build.yml} (83%) delete mode 100644 .github/workflows/shadow-docker-build.yml delete mode 100644 .github/workflows/shadow-shared-cpu-spike.yml diff --git a/.github/workflows/docker-build.yml b/.github/workflows/docker-build.yml index 7447b1e42..450d6b5c5 100644 --- a/.github/workflows/docker-build.yml +++ b/.github/workflows/docker-build.yml @@ -150,7 +150,7 @@ jobs: strategy: fail-fast: false matrix: ${{ fromJSON(needs.resolve.outputs.matrix) }} - uses: ./.github/workflows/shadow-rust-native-build.yml + uses: ./.github/workflows/rust-native-build.yml with: component: ${{ needs.resolve.outputs.binary_component }} arch: ${{ matrix.arch }} diff --git a/.github/workflows/shadow-rust-native-build.yml b/.github/workflows/rust-native-build.yml similarity index 83% rename from .github/workflows/shadow-rust-native-build.yml rename to .github/workflows/rust-native-build.yml index 7113d260f..40feb1dd0 100644 --- a/.github/workflows/shadow-rust-native-build.yml +++ b/.github/workflows/rust-native-build.yml @@ -3,8 +3,8 @@ name: Rust Native Build (openshell-gateway / openshell-sandbox) -# OS-128 Phase 4: build Rust binaries natively per Linux architecture before -# the Docker image build consumes them as prebuilt artifacts. +# Build Rust binaries natively per Linux architecture before the Docker image +# build consumes them as prebuilt artifacts. on: workflow_call: @@ -47,55 +47,6 @@ on: required: false type: string default: "" - workflow_dispatch: - inputs: - component: - description: "Binary component to build" - required: true - type: choice - default: gateway - options: - - gateway - - sandbox - arch: - description: "Linux architecture to build" - required: true - type: choice - default: amd64 - options: - - amd64 - - arm64 - cargo-version: - description: "Cargo version override" - required: false - type: string - default: "" - features: - description: "Cargo features to enable" - required: false - type: string - default: "openshell-core/dev-settings" - retention-days: - description: "Artifact retention period" - required: false - type: number - default: 5 - artifact-name: - description: "Artifact name override" - required: false - type: string - default: "" - checkout-ref: - description: "Git ref to check out for build inputs (defaults to the workflow SHA)" - required: false - type: string - default: "" - image-tag: - description: "Supervisor image tag to bake into gateway binaries" - required: false - type: string - default: "" - permissions: contents: read packages: read @@ -123,7 +74,7 @@ jobs: FEATURES: ${{ inputs.features }} # Partition the GHA sccache cache per (component, arch). Without this, # concurrent jobs collide on the same cache key and later-starting - # writers hit 409 Conflict (PR #961 fix for shadow-shared-cpu-spike). + # writers hit 409 Conflict. SCCACHE_GHA_VERSION: ${{ inputs.component }}-${{ inputs.arch }} container: image: ghcr.io/nvidia/openshell/ci:latest diff --git a/.github/workflows/shadow-docker-build.yml b/.github/workflows/shadow-docker-build.yml deleted file mode 100644 index 3c7642ab3..000000000 --- a/.github/workflows/shadow-docker-build.yml +++ /dev/null @@ -1,35 +0,0 @@ -name: Shadow Docker Build - -# OS-128 Phase 4: manual non-publishing exercise of the production Docker -# image workflow. This stays off main's push surface because the image path is -# not a required signal while the prebuilt-binary rollout is being measured. - -on: - workflow_dispatch: - inputs: - platform: - description: "Target platform(s)" - required: false - type: string - default: "linux/amd64,linux/arm64" - -permissions: - contents: read - packages: write - -jobs: - gateway: - uses: ./.github/workflows/docker-build.yml - with: - component: gateway - platform: ${{ inputs.platform }} - push: false - secrets: inherit - - supervisor: - uses: ./.github/workflows/docker-build.yml - with: - component: supervisor - platform: ${{ inputs.platform }} - push: false - secrets: inherit diff --git a/.github/workflows/shadow-shared-cpu-spike.yml b/.github/workflows/shadow-shared-cpu-spike.yml deleted file mode 100644 index 5a072c8e1..000000000 --- a/.github/workflows/shadow-shared-cpu-spike.yml +++ /dev/null @@ -1,73 +0,0 @@ -name: Shadow — Shared CPU Spike - -# OS-49 Phase 2 / PR 1 — non-blocking spike. Runs `cargo check` on the -# nv-gha-runners shared CPU pool (`linux-{amd64,arm64}-cpu8`) with a -# GHA-backed sccache. -# -# Plan, decision thresholds, and results live in the Linear doc attached -# to OS-126 ("OS-126 — Shared CPU spike plan & results"). Dispatch this -# workflow 4–5 times after merge and record numbers there. - -on: - workflow_dispatch: - -permissions: - contents: read - packages: read - -env: - CARGO_TERM_COLOR: always - CARGO_INCREMENTAL: "0" - MISE_GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - # Wire sccache (already RUSTC_WRAPPER in mise.toml) to the GHA cache - # backend instead of the EKS memcached used by ARC. - SCCACHE_GHA_ENABLED: "true" - -jobs: - rust-check: - name: cargo check (${{ matrix.runner }}) - strategy: - fail-fast: false - matrix: - runner: [linux-amd64-cpu8, linux-arm64-cpu8] - runs-on: ${{ matrix.runner }} - env: - # Partition the GHA sccache cache per-arch. Without this, matrix jobs - # collide on the same cache key, and the later-starting job's writes - # fail with 409 Conflict while the earlier one's writes land — leaving - # subsequent runs with asymmetric and mostly-empty caches. - # (Run 1 on 2026-04-24 showed amd64 with 0/1062 writes succeeding - # while arm64 got 575/1064.) - SCCACHE_GHA_VERSION: ${{ matrix.runner }} - container: - image: ghcr.io/nvidia/openshell/ci:latest - credentials: - username: ${{ github.actor }} - password: ${{ secrets.GITHUB_TOKEN }} - timeout-minutes: 60 - steps: - - uses: actions/checkout@v6 - - - name: Install tools - run: mise install - - - name: Configure GHA sccache backend - # Exposes ACTIONS_CACHE_URL / ACTIONS_RUNTIME_TOKEN to subsequent steps - # so sccache (wrapped around rustc via RUSTC_WRAPPER in mise.toml) can - # initialize the GHA cache. Without this, sccache fails at startup with - # "cache url for ghac not found". The action also installs its own - # sccache binary; harmless since mise's sccache remains on PATH. - uses: mozilla-actions/sccache-action@9e7fa8a12102821edf02ca5dbea1acd0f89a2696 # v0.0.10 - - - name: Cache Rust target and registry - uses: Swatinem/rust-cache@c19371144df3bb44fab255c43d04cbc2ab54d1c4 # v2 - with: - shared-key: shadow-shared-cpu-spike-${{ matrix.runner }} - cache-directories: .cache/sccache - - - name: cargo check - run: mise x -- cargo check --workspace --all-targets - - - name: sccache stats - if: always() - run: mise x -- sccache --show-stats From daa2a362d515080a9e570c20d028ce415b89e564 Mon Sep 17 00:00:00 2001 From: Drew Newberry Date: Fri, 8 May 2026 13:02:12 -0700 Subject: [PATCH 012/157] fix(packaging): enable mTLS for local packages (#1271) --- .github/workflows/branch-checks.yml | 2 +- .github/workflows/release-dev.yml | 1 + .github/workflows/release-tag.yml | 1 + architecture/build.md | 7 + crates/openshell-cli/src/run.rs | 158 +++++++++++++++++--- crates/openshell-driver-docker/src/lib.rs | 10 +- crates/openshell-driver-docker/src/tests.rs | 16 ++ crates/openshell-server/src/certgen.rs | 55 +++++-- deploy/deb/openshell-gateway.service | 17 ++- deploy/docker/Dockerfile.gateway-macos | 1 + docs/about/installation.mdx | 4 + docs/reference/gateway-auth.mdx | 4 + install-dev.sh | 26 +++- mise.lock | 1 + python/openshell/release_formula_test.py | 9 ++ tasks/scripts/package-deb-install.sh | 4 +- tasks/scripts/release.py | 49 +++++- 17 files changed, 308 insertions(+), 57 deletions(-) diff --git a/.github/workflows/branch-checks.yml b/.github/workflows/branch-checks.yml index abbcef423..f7bc6ad1f 100644 --- a/.github/workflows/branch-checks.yml +++ b/.github/workflows/branch-checks.yml @@ -128,7 +128,7 @@ jobs: stats_bin="${SCCACHE_PATH:-sccache}" "$stats_bin" --show-stats status=$? - if [[ $status -ne 0 ]]; then + if [ "$status" -ne 0 ]; then echo "::warning::sccache stats unavailable (exit $status)" fi exit 0 diff --git a/.github/workflows/release-dev.yml b/.github/workflows/release-dev.yml index 520a51c65..f1df71b3f 100644 --- a/.github/workflows/release-dev.yml +++ b/.github/workflows/release-dev.yml @@ -501,6 +501,7 @@ jobs: docker buildx build \ --file deploy/docker/Dockerfile.gateway-macos \ --build-arg OPENSHELL_CARGO_VERSION="${{ needs.compute-versions.outputs.cargo_version }}" \ + --build-arg OPENSHELL_IMAGE_TAG=dev \ --build-arg CARGO_TARGET_CACHE_SCOPE="${{ github.sha }}" \ --target binary \ --output type=local,dest=out/ \ diff --git a/.github/workflows/release-tag.yml b/.github/workflows/release-tag.yml index 19d9df47f..18bf74db5 100644 --- a/.github/workflows/release-tag.yml +++ b/.github/workflows/release-tag.yml @@ -620,6 +620,7 @@ jobs: docker buildx build \ --file deploy/docker/Dockerfile.gateway-macos \ --build-arg OPENSHELL_CARGO_VERSION="${{ needs.compute-versions.outputs.cargo_version }}" \ + --build-arg OPENSHELL_IMAGE_TAG="${{ needs.compute-versions.outputs.semver }}" \ --build-arg CARGO_TARGET_CACHE_SCOPE="${{ github.sha }}" \ --target binary \ --output type=local,dest=out/ \ diff --git a/architecture/build.md b/architecture/build.md index 266575efb..cfe13c4b1 100644 --- a/architecture/build.md +++ b/architecture/build.md @@ -27,6 +27,13 @@ target architecture, stages them under `deploy/docker/.build/`, and then uses Buildx to publish per-architecture images and multi-architecture tags. Gateway image builds bake the corresponding supervisor image tag into the gateway binary so Docker sandboxes do not depend on `:latest` by default. +Package formulas also pin Docker supervisor extraction to the matching release +image tag so standalone gateway binaries do not infer image tags from package +versions. +The Homebrew service keeps gateway TLS under the Homebrew state directory but +mirrors Docker sandbox client TLS into `$HOME/.local/state/openshell/homebrew/tls` +at service start, because Docker Desktop bind mounts must use paths visible to +the macOS user's shared home directory. Local image work should use `mise` tasks rather than direct Docker commands so the same staging and tagging assumptions are used locally and in CI. diff --git a/crates/openshell-cli/src/run.rs b/crates/openshell-cli/src/run.rs index dd1b6721d..fc30b03d6 100644 --- a/crates/openshell-cli/src/run.rs +++ b/crates/openshell-cli/src/run.rs @@ -643,6 +643,60 @@ fn mtls_certs_exist_for_endpoint(name: &str, endpoint: &str) -> bool { }) } +fn package_managed_tls_dirs() -> Vec { + if let Some(path) = std::env::var_os("OPENSHELL_LOCAL_TLS_DIR") { + return vec![PathBuf::from(path)]; + } + + let mut dirs = Vec::new(); + + if cfg!(target_os = "macos") { + dirs.push(PathBuf::from("/opt/homebrew/var/openshell/tls")); + dirs.push(PathBuf::from("/usr/local/var/openshell/tls")); + } + + let state_dir = std::env::var_os("XDG_STATE_HOME") + .map(PathBuf::from) + .or_else(|| std::env::var_os("HOME").map(|home| PathBuf::from(home).join(".local/state"))); + if let Some(state_dir) = state_dir { + dirs.push(state_dir.join("openshell/tls")); + } + + dirs +} + +fn import_local_package_mtls_bundle(name: &str) -> Result> { + for dir in package_managed_tls_dirs() { + let ca = dir.join("ca.crt"); + let cert = dir.join("client/tls.crt"); + let key = dir.join("client/tls.key"); + if !(ca.is_file() && cert.is_file() && key.is_file()) { + continue; + } + + let bundle = openshell_bootstrap::pki::PkiBundle { + ca_cert_pem: std::fs::read_to_string(&ca) + .into_diagnostic() + .wrap_err_with(|| format!("failed to read {}", ca.display()))?, + ca_key_pem: String::new(), + server_cert_pem: String::new(), + server_key_pem: String::new(), + client_cert_pem: std::fs::read_to_string(&cert) + .into_diagnostic() + .wrap_err_with(|| format!("failed to read {}", cert.display()))?, + client_key_pem: std::fs::read_to_string(&key) + .into_diagnostic() + .wrap_err_with(|| format!("failed to read {}", key.display()))?, + }; + openshell_bootstrap::mtls::store_pki_bundle(name, &bundle) + .wrap_err_with(|| format!("failed to store mTLS bundle for gateway '{name}'"))?; + + return Ok(Some(dir)); + } + + Ok(None) +} + fn plaintext_gateway_is_remote(endpoint: &str, remote: Option<&str>, local: bool) -> bool { if local { return false; @@ -924,16 +978,13 @@ pub async fn gateway_add( // Verify the gateway is reachable. let tls = TlsOptions::default(); - match http_health_check(&endpoint, &tls).await { - Ok(Some(status)) if status.is_success() => {} - _ => { - eprintln!( - "{} Gateway is not reachable at {endpoint}", - "⚠".yellow().bold(), - ); - if !has_mtls_certs { - eprintln!(" Verify the gateway is running and the endpoint is correct."); - } + if !gateway_reachable(&endpoint, &tls).await { + eprintln!( + "{} Gateway is not reachable at {endpoint}", + "⚠".yellow().bold(), + ); + if !has_mtls_certs { + eprintln!(" Verify the gateway is running and the endpoint is correct."); } } @@ -951,7 +1002,13 @@ pub async fn gateway_add( if remote.is_some() || local { // mTLS gateway (remote or local). - let certs_on_disk = mtls_certs_exist_for_endpoint(name, &endpoint); + let imported_mtls_dir = if local { + import_local_package_mtls_bundle(name)? + } else { + None + }; + let certs_on_disk = + imported_mtls_dir.is_some() || mtls_certs_exist_for_endpoint(name, &endpoint); if !certs_on_disk { return Err(miette::miette!( "mTLS certificates for gateway '{name}' were not found.\n\ @@ -984,14 +1041,11 @@ pub async fn gateway_add( // Verify the gateway is reachable over mTLS. let tls = TlsOptions::default().with_gateway_name(name); - match http_health_check(&endpoint, &tls).await { - Ok(Some(status)) if status.is_success() => {} - _ => { - eprintln!( - "{} Gateway is not reachable at {endpoint}. Verify the gateway is running.", - "⚠".yellow().bold(), - ); - } + if !gateway_reachable(&endpoint, &tls).await { + eprintln!( + "{} Gateway is not reachable at {endpoint}. Verify the gateway is running.", + "⚠".yellow().bold(), + ); } eprintln!( @@ -1252,6 +1306,16 @@ async fn http_health_check(server: &str, tls: &TlsOptions) -> Result bool { + if let Ok(mut client) = grpc_client(server, tls).await + && client.health(HealthRequest {}).await.is_ok() + { + return true; + } + + matches!(http_health_check(server, tls).await, Ok(Some(status)) if status.is_success()) +} + fn remove_gateway_registration(name: &str) { if let Err(err) = openshell_bootstrap::edge_token::remove_edge_token(name) { tracing::debug!("failed to remove edge token: {err}"); @@ -5391,10 +5455,10 @@ mod tests { TlsOptions, dockerfile_sources_supported_for_gateway, format_gateway_select_header, format_gateway_select_items, format_provider_attachment_table, gateway_add, gateway_auth_label, gateway_env_override_warning, gateway_select_with, gateway_type_label, - git_sync_files, http_health_check, image_requests_gpu, inferred_provider_type, - parse_cli_setting_value, parse_credential_pairs, plaintext_gateway_is_remote, - provisioning_timeout_message, ready_false_condition_message, resolve_from, - sandbox_should_persist, + git_sync_files, http_health_check, image_requests_gpu, import_local_package_mtls_bundle, + inferred_provider_type, package_managed_tls_dirs, parse_cli_setting_value, + parse_credential_pairs, plaintext_gateway_is_remote, provisioning_timeout_message, + ready_false_condition_message, resolve_from, sandbox_should_persist, }; use crate::TEST_ENV_LOCK; use hyper::StatusCode; @@ -5402,7 +5466,7 @@ mod tests { use std::fs; use std::io::{Read, Write}; use std::net::TcpListener; - use std::path::Path; + use std::path::{Path, PathBuf}; use std::process::Command; use std::thread; @@ -6023,6 +6087,52 @@ mod tests { assert_eq!(gateway_auth_label(&gateway), "mtls"); } + #[test] + fn package_managed_tls_dirs_respects_override() { + let _guard = TEST_ENV_LOCK + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let _tls_dir = EnvVarGuard::set("OPENSHELL_LOCAL_TLS_DIR", "/tmp/openshell-test-tls"); + + assert_eq!( + package_managed_tls_dirs(), + vec![PathBuf::from("/tmp/openshell-test-tls")], + ); + } + + #[test] + fn import_local_package_mtls_bundle_copies_client_materials() { + let tmpdir = tempfile::tempdir().expect("create tmpdir"); + let package_tls = tmpdir.path().join("package-tls"); + fs::create_dir_all(package_tls.join("client")).expect("create package tls dir"); + fs::write(package_tls.join("ca.crt"), "ca").expect("write ca"); + fs::write(package_tls.join("client/tls.crt"), "client cert").expect("write cert"); + fs::write(package_tls.join("client/tls.key"), "client key").expect("write key"); + + with_tmp_xdg(tmpdir.path(), || { + let _tls_dir = EnvVarGuard::set( + "OPENSHELL_LOCAL_TLS_DIR", + package_tls.to_str().expect("temp path should be utf-8"), + ); + + let imported = + import_local_package_mtls_bundle("openshell").expect("import local bundle"); + + assert_eq!(imported.as_deref(), Some(package_tls.as_path())); + + let mtls = tmpdir.path().join("openshell/gateways/openshell/mtls"); + assert_eq!(fs::read_to_string(mtls.join("ca.crt")).unwrap(), "ca"); + assert_eq!( + fs::read_to_string(mtls.join("tls.crt")).unwrap(), + "client cert", + ); + assert_eq!( + fs::read_to_string(mtls.join("tls.key")).unwrap(), + "client key", + ); + }); + } + #[test] fn plaintext_gateway_locality_infers_loopback_endpoints_as_local() { assert!(!plaintext_gateway_is_remote( diff --git a/crates/openshell-driver-docker/src/lib.rs b/crates/openshell-driver-docker/src/lib.rs index a864a3eb6..db197685d 100644 --- a/crates/openshell-driver-docker/src/lib.rs +++ b/crates/openshell-driver-docker/src/lib.rs @@ -89,7 +89,7 @@ pub fn default_docker_supervisor_image() -> String { /// fallback covers image build wrappers that already tag the gateway and /// supervisor together. Standalone release binaries also patch the Cargo /// package version, so use it when it has been set to a real release value. -fn default_docker_supervisor_image_tag() -> &'static str { +fn default_docker_supervisor_image_tag() -> String { resolve_default_docker_supervisor_image_tag( option_env!("OPENSHELL_IMAGE_TAG"), option_env!("IMAGE_TAG"), @@ -101,8 +101,8 @@ fn resolve_default_docker_supervisor_image_tag( openshell_image_tag: Option<&'static str>, image_tag: Option<&'static str>, cargo_pkg_version: &'static str, -) -> &'static str { - openshell_image_tag +) -> String { + let tag = openshell_image_tag .filter(|tag| !tag.is_empty()) .or_else(|| image_tag.filter(|tag| !tag.is_empty())) .unwrap_or_else(|| { @@ -111,7 +111,9 @@ fn resolve_default_docker_supervisor_image_tag( } else { cargo_pkg_version } - }) + }); + + tag.replace('+', "-") } /// Queried by the Docker driver to decide when a sandbox's supervisor diff --git a/crates/openshell-driver-docker/src/tests.rs b/crates/openshell-driver-docker/src/tests.rs index e41f2688e..41c9a5901 100644 --- a/crates/openshell-driver-docker/src/tests.rs +++ b/crates/openshell-driver-docker/src/tests.rs @@ -722,6 +722,22 @@ fn docker_supervisor_image_tag_prefers_explicit_build_tags() { ); } +#[test] +fn docker_supervisor_image_tag_sanitizes_build_metadata_for_docker() { + assert_eq!( + resolve_default_docker_supervisor_image_tag(None, None, "0.0.37-dev.156+g1d3b741ee"), + "0.0.37-dev.156-g1d3b741ee", + ); + assert_eq!( + resolve_default_docker_supervisor_image_tag( + Some("0.0.37-dev.156+g1d3b741ee"), + None, + "0.0.0", + ), + "0.0.37-dev.156-g1d3b741ee", + ); +} + #[test] fn supervisor_cache_path_namespaces_by_digest_under_openshell_data_dir() { let base = PathBuf::from("/var/cache/share"); diff --git a/crates/openshell-server/src/certgen.rs b/crates/openshell-server/src/certgen.rs index b9e4d7bd5..f7dcc0803 100644 --- a/crates/openshell-server/src/certgen.rs +++ b/crates/openshell-server/src/certgen.rs @@ -70,16 +70,16 @@ pub async fn run(args: CertgenArgs) -> Result<()> { ) .init(); - let bundle = generate_pki(&args.server_sans)?; - if args.dry_run { + let bundle = generate_pki(&args.server_sans)?; print_bundle(&bundle); return Ok(()); } if let Some(dir) = args.output_dir.as_deref() { - run_local(dir, &bundle) + run_local(dir, &args.server_sans) } else { + let bundle = generate_pki(&args.server_sans)?; run_kubernetes(&args, &bundle).await } } @@ -277,12 +277,13 @@ fn decide_local(present: usize) -> LocalAction { } } -fn run_local(dir: &Path, bundle: &PkiBundle) -> Result<()> { +fn run_local(dir: &Path, server_sans: &[String]) -> Result<()> { let paths = LocalPaths::resolve(dir); - match decide_local(paths.existence_count()) { + let bundle = match decide_local(paths.existence_count()) { LocalAction::Skip => { info!(dir = %dir.display(), "PKI files already exist, skipping."); + read_local_bundle(&paths)? } LocalAction::PartialState => { return Err(miette::miette!( @@ -292,21 +293,40 @@ fn run_local(dir: &Path, bundle: &PkiBundle) -> Result<()> { )); } LocalAction::Create => { - write_local_bundle(dir, bundle, &paths)?; + let bundle = generate_pki(server_sans)?; + write_local_bundle(dir, &bundle, &paths)?; info!(dir = %dir.display(), "PKI files created."); + bundle } - } + }; // Always make sure the CLI auto-discovery copy is in place. This // self-heals the case where the operator wiped ~/.config/openshell but // left the gateway state directory intact. - if let Err(e) = openshell_bootstrap::mtls::store_pki_bundle("openshell", bundle) { + if let Err(e) = openshell_bootstrap::mtls::store_pki_bundle("openshell", &bundle) { warn!(error = %e, "failed to copy client mTLS materials for CLI auto-discovery"); } Ok(()) } +fn read_local_bundle(paths: &LocalPaths) -> Result { + Ok(PkiBundle { + ca_cert_pem: read_pem(&paths.ca_crt)?, + ca_key_pem: read_pem(&paths.ca_key)?, + server_cert_pem: read_pem(&paths.server_crt)?, + server_key_pem: read_pem(&paths.server_key)?, + client_cert_pem: read_pem(&paths.client_crt)?, + client_key_pem: read_pem(&paths.client_key)?, + }) +} + +fn read_pem(path: &Path) -> Result { + std::fs::read_to_string(path) + .into_diagnostic() + .wrap_err_with(|| format!("failed to read {}", path.display())) +} + fn write_local_bundle(dir: &Path, bundle: &PkiBundle, paths: &LocalPaths) -> Result<()> { // Stage to a sibling tmp dir so individual renames into the final layout // are atomic on the same filesystem. @@ -386,8 +406,8 @@ fn print_bundle(bundle: &PkiBundle) { #[cfg(test)] mod tests { use super::{ - K8sAction, LocalAction, LocalPaths, decide_k8s, decide_local, sibling_temp_dir, tls_secret, - write_local_bundle, + K8sAction, LocalAction, LocalPaths, decide_k8s, decide_local, read_local_bundle, + sibling_temp_dir, tls_secret, write_local_bundle, }; use openshell_bootstrap::pki::generate_pki; use std::path::Path; @@ -490,6 +510,21 @@ mod tests { assert!(server_key.contains("BEGIN PRIVATE KEY")); } + #[test] + fn read_local_bundle_uses_existing_files() { + let parent = tempfile::tempdir().expect("tempdir"); + let dir = parent.path().join("tls"); + let bundle = generate_pki(&[]).expect("generate_pki"); + let paths = LocalPaths::resolve(&dir); + + write_local_bundle(&dir, &bundle, &paths).expect("write_local_bundle"); + + let read = read_local_bundle(&paths).expect("read_local_bundle"); + assert_eq!(read.ca_cert_pem, bundle.ca_cert_pem); + assert_eq!(read.client_cert_pem, bundle.client_cert_pem); + assert_eq!(read.client_key_pem, bundle.client_key_pem); + } + #[cfg(unix)] #[test] fn write_local_bundle_sets_owner_only_on_keys() { diff --git a/deploy/deb/openshell-gateway.service b/deploy/deb/openshell-gateway.service index 26e3c07be..9de94da22 100644 --- a/deploy/deb/openshell-gateway.service +++ b/deploy/deb/openshell-gateway.service @@ -9,14 +9,25 @@ StateDirectory=openshell/gateway # %S resolves to $XDG_STATE_HOME for user services. Environment=OPENSHELL_BIND_ADDRESS=127.0.0.1 Environment=OPENSHELL_SERVER_PORT=17670 -Environment=OPENSHELL_DISABLE_TLS=true -Environment=OPENSHELL_DISABLE_GATEWAY_AUTH=true +Environment=OPENSHELL_TLS_CERT=%S/openshell/tls/server/tls.crt +Environment=OPENSHELL_TLS_KEY=%S/openshell/tls/server/tls.key +Environment=OPENSHELL_TLS_CLIENT_CA=%S/openshell/tls/ca.crt Environment=OPENSHELL_DB_URL=sqlite:%S/openshell/gateway/openshell.db -Environment=OPENSHELL_GRPC_ENDPOINT=http://127.0.0.1:17670 +Environment=OPENSHELL_GRPC_ENDPOINT=https://127.0.0.1:17670 Environment=OPENSHELL_SSH_GATEWAY_HOST=127.0.0.1 Environment=OPENSHELL_SSH_GATEWAY_PORT=17670 Environment=OPENSHELL_VM_DRIVER_STATE_DIR=%S/openshell/vm-driver +Environment=OPENSHELL_VM_TLS_CA=%S/openshell/tls/ca.crt +Environment=OPENSHELL_VM_TLS_CERT=%S/openshell/tls/client/tls.crt +Environment=OPENSHELL_VM_TLS_KEY=%S/openshell/tls/client/tls.key +Environment=OPENSHELL_DOCKER_TLS_CA=%S/openshell/tls/ca.crt +Environment=OPENSHELL_DOCKER_TLS_CERT=%S/openshell/tls/client/tls.crt +Environment=OPENSHELL_DOCKER_TLS_KEY=%S/openshell/tls/client/tls.key +Environment=OPENSHELL_PODMAN_TLS_CA=%S/openshell/tls/ca.crt +Environment=OPENSHELL_PODMAN_TLS_CERT=%S/openshell/tls/client/tls.crt +Environment=OPENSHELL_PODMAN_TLS_KEY=%S/openshell/tls/client/tls.key EnvironmentFile=-%h/.config/openshell/gateway.env +ExecStartPre=/usr/bin/openshell-gateway generate-certs --output-dir %S/openshell/tls --server-san host.openshell.internal ExecStart=/usr/bin/openshell-gateway Restart=on-failure RestartSec=5s diff --git a/deploy/docker/Dockerfile.gateway-macos b/deploy/docker/Dockerfile.gateway-macos index 4cae2f0e7..27d2ffbba 100644 --- a/deploy/docker/Dockerfile.gateway-macos +++ b/deploy/docker/Dockerfile.gateway-macos @@ -94,6 +94,7 @@ RUN touch crates/openshell-core/src/lib.rs \ proto/*.proto ARG OPENSHELL_CARGO_VERSION +ARG OPENSHELL_IMAGE_TAG RUN --mount=type=cache,id=cargo-registry-gateway-macos,sharing=locked,target=/root/.cargo/registry \ --mount=type=cache,id=cargo-git-gateway-macos,sharing=locked,target=/root/.cargo/git \ --mount=type=cache,id=cargo-target-gateway-macos-${CARGO_TARGET_CACHE_SCOPE},sharing=locked,target=/build/target \ diff --git a/docs/about/installation.mdx b/docs/about/installation.mdx index 378439758..25bf96480 100644 --- a/docs/about/installation.mdx +++ b/docs/about/installation.mdx @@ -38,6 +38,8 @@ For detailed driver behavior, refer to [Sandbox Compute Drivers](/reference/sand On macOS, the install script uses Homebrew. The Homebrew package installs the `openshell` CLI, the gateway binary, and a Homebrew-managed gateway service. +The Homebrew service listens on `https://127.0.0.1:17670` and generates a local mTLS bundle on install. The CLI reads the client bundle from `~/.config/openshell/gateways/openshell/mtls/`. + The installer starts the service for you. Use Homebrew service commands when you need to inspect, restart, or stop the gateway service: ```shell @@ -51,6 +53,8 @@ On Fedora and RHEL, the install script uses RPM packages. The RPM installs the ` On Debian and Ubuntu, the install script uses a Debian package. The Debian package installs the `openshell` CLI, the `openshell-gateway` daemon, VM sandbox support, and a systemd user service. +The Debian user service listens on `https://127.0.0.1:17670` and generates a local mTLS bundle before the gateway starts. The CLI reads the client bundle from `~/.config/openshell/gateways/openshell/mtls/`. + The installer starts the service for you. Use systemd user commands when you need to inspect, restart, or stop the gateway service: ```shell diff --git a/docs/reference/gateway-auth.mdx b/docs/reference/gateway-auth.mdx index bd2f27f4e..e95ce854b 100644 --- a/docs/reference/gateway-auth.mdx +++ b/docs/reference/gateway-auth.mdx @@ -38,6 +38,10 @@ Set these environment variables before starting the gateway: For local access, the server certificate must be valid for the endpoint the CLI uses. Include `localhost` and `127.0.0.1` in the certificate SANs when users connect to a local gateway through loopback. +Package-managed local gateways on Homebrew and Debian generate this bundle automatically for the `openshell` gateway name and use `https://127.0.0.1:17670` by default. +When you register a package-managed local gateway with `openshell gateway add https://127.0.0.1:17670 --local --name openshell`, the CLI refreshes its mTLS bundle from the package-managed TLS directory. +On Homebrew, the gateway service also mirrors the Docker sandbox client bundle into `$HOME/.local/state/openshell/homebrew/tls` before startup so Docker Desktop can bind-mount the files into sandbox containers. + The CLI loads its mTLS bundle from `~/.config/openshell/gateways//mtls/`: | File | Purpose | diff --git a/install-dev.sh b/install-dev.sh index 87234a4eb..660795949 100755 --- a/install-dev.sh +++ b/install-dev.sh @@ -420,7 +420,7 @@ start_user_gateway() { if ! as_target_user systemctl --user daemon-reload; then info "could not reach the user systemd manager for ${TARGET_USER}" info "restart the gateway later with: systemctl --user enable openshell-gateway && systemctl --user restart openshell-gateway" - info "then register it with: openshell gateway add http://127.0.0.1:17670 --local --name local" + info "then register it with: openshell gateway add https://127.0.0.1:17670 --local --name openshell" return 0 fi @@ -438,11 +438,14 @@ wait_for_local_gateway_listener() { _timeout="${OPENSHELL_INSTALL_GATEWAY_TIMEOUT:-30}" _elapsed=0 _last_output="" - _probe_url="http://127.0.0.1:${LOCAL_GATEWAY_PORT}/" + _probe_url="https://127.0.0.1:${LOCAL_GATEWAY_PORT}/" + _mtls_dir="${TARGET_HOME}/.config/openshell/gateways/openshell/mtls" info "waiting for local gateway listener to become reachable..." while [ "$_elapsed" -lt "$_timeout" ]; do - if _last_output="$(as_target_user curl -sS --max-time 2 -o /dev/null "$_probe_url" 2>&1)"; then + if [ ! -f "${_mtls_dir}/ca.crt" ] || [ ! -f "${_mtls_dir}/tls.crt" ] || [ ! -f "${_mtls_dir}/tls.key" ]; then + _last_output="mTLS client bundle is not ready under ${_mtls_dir}" + elif _last_output="$(as_target_user curl -sS --max-time 2 --cacert "${_mtls_dir}/ca.crt" --cert "${_mtls_dir}/tls.crt" --key "${_mtls_dir}/tls.key" -o /dev/null "$_probe_url" 2>&1)"; then info "local gateway listener is reachable" return 0 fi @@ -488,8 +491,15 @@ remove_local_gateway_registration() { as_target_user sh -c ' config_dir=$1 rm -rf "${config_dir}/gateways/local" + mkdir -p "${config_dir}/gateways/openshell" + rm -f \ + "${config_dir}/gateways/openshell/metadata.json" \ + "${config_dir}/gateways/openshell/edge_token" \ + "${config_dir}/gateways/openshell/cf_token" \ + "${config_dir}/gateways/openshell/oidc_token.json" active="${config_dir}/active_gateway" - if [ "$(cat "$active" 2>/dev/null || true)" = "local" ]; then + active_name="$(cat "$active" 2>/dev/null || true)" + if [ "$active_name" = "local" ] || [ "$active_name" = "openshell" ]; then rm -f "$active" fi ' sh "$_config_dir" @@ -498,7 +508,7 @@ remove_local_gateway_registration() { register_local_gateway() { _register_bin="${OPENSHELL_REGISTER_BIN:-openshell}" - if _add_output="$(as_target_user "$_register_bin" gateway add "http://127.0.0.1:${LOCAL_GATEWAY_PORT}" --local --name local 2>&1)"; then + if _add_output="$(as_target_user "$_register_bin" gateway add "https://127.0.0.1:${LOCAL_GATEWAY_PORT}" --local --name openshell 2>&1)"; then [ -z "$_add_output" ] || print_gateway_add_output "$_add_output" return 0 else @@ -509,7 +519,7 @@ register_local_gateway() { *"already exists"*) info "local gateway already exists; removing and re-adding it..." remove_local_gateway_registration - as_target_user "$_register_bin" gateway add "http://127.0.0.1:${LOCAL_GATEWAY_PORT}" --local --name local + as_target_user "$_register_bin" gateway add "https://127.0.0.1:${LOCAL_GATEWAY_PORT}" --local --name openshell ;; *) printf '%s\n' "$_add_output" >&2 @@ -521,7 +531,7 @@ register_local_gateway() { print_gateway_add_output() { printf '%s\n' "$1" | while IFS= read -r _line; do case "$_line" in - *"Gateway is not reachable at http://127.0.0.1:${LOCAL_GATEWAY_PORT}"*) ;; + *"Gateway is not reachable at https://127.0.0.1:${LOCAL_GATEWAY_PORT}"*) ;; *"Verify the gateway is running and the endpoint is correct."*) ;; *) printf '%s\n' "$_line" >&2 ;; esac @@ -654,7 +664,7 @@ install_macos_homebrew() { if ! as_target_user brew services restart "$_formula_ref"; then warn "could not restart the OpenShell Homebrew service" info "restart it later with: brew services restart ${_formula_ref}" - info "then register it with: openshell gateway add http://127.0.0.1:${LOCAL_GATEWAY_PORT} --local --name local" + info "then register it with: openshell gateway add https://127.0.0.1:${LOCAL_GATEWAY_PORT} --local --name openshell" return 0 fi diff --git a/mise.lock b/mise.lock index 6ab204f6e..5110b03c9 100644 --- a/mise.lock +++ b/mise.lock @@ -287,6 +287,7 @@ version = "2.19.0" backend = "aqua:GoogleContainerTools/skaffold" [tools.skaffold."platforms.linux-arm64"] +checksum = "blake3:c62b62077ac47abb7f7d184836d37f467c8e82b47e47b5dce570e15db4bb30fe" url = "https://storage.googleapis.com/skaffold/releases/v2.19.0/skaffold-linux-arm64" [tools.skaffold."platforms.linux-arm64-musl"] diff --git a/python/openshell/release_formula_test.py b/python/openshell/release_formula_test.py index 56abd7af6..8eb02cada 100644 --- a/python/openshell/release_formula_test.py +++ b/python/openshell/release_formula_test.py @@ -53,5 +53,14 @@ def test_generate_homebrew_formula_uses_tagged_macos_driver_asset_without_defaul assert 'sha256 "' + "b" * 64 + '"' in formula assert "OPENSHELL_DRIVERS" not in formula assert 'OPENSHELL_DRIVER_DIR: "#{opt_libexec}"' in formula + assert ( + 'OPENSHELL_DOCKER_SUPERVISOR_IMAGE: "ghcr.io/nvidia/openshell/supervisor:0.0.10"' + ) in formula + assert 'run opt_libexec/"openshell-gateway-homebrew-service"' in formula + assert ( + 'docker_tls_dir="${OPENSHELL_DOCKER_TLS_DIR:-${HOME}/.local/state/openshell/homebrew/tls}"' + ) in formula + assert 'export OPENSHELL_DOCKER_TLS_CA="${docker_tls_dir}/ca.crt"' in formula + assert 'OPENSHELL_DOCKER_TLS_CA: "#{var}/openshell/tls/ca.crt"' not in formula assert "entitlements.atomic_write" in formula assert "brew services restart openshell" in formula diff --git a/tasks/scripts/package-deb-install.sh b/tasks/scripts/package-deb-install.sh index 735ce70a9..b6e473067 100755 --- a/tasks/scripts/package-deb-install.sh +++ b/tasks/scripts/package-deb-install.sh @@ -28,8 +28,8 @@ cd "$repo_root" VERSION="${OPENSHELL_DEB_VERSION:-0.0.0-local}" OUTPUT_DIR="${OPENSHELL_OUTPUT_DIR:-artifacts}" ARCH="${OPENSHELL_DEB_ARCH:-$(dpkg --print-architecture 2>/dev/null || uname -m)}" -GATEWAY_NAME="local" -GATEWAY_ENDPOINT="http://127.0.0.1:17670" +GATEWAY_NAME="openshell" +GATEWAY_ENDPOINT="https://127.0.0.1:17670" remove_existing_gateway_registration() { local config_home="${XDG_CONFIG_HOME:-${HOME}/.config}" diff --git a/tasks/scripts/release.py b/tasks/scripts/release.py index 0a6893121..cdb525e7b 100644 --- a/tasks/scripts/release.py +++ b/tasks/scripts/release.py @@ -214,6 +214,11 @@ def _asset_url(release_tag: str, filename: str) -> str: return f"{GITHUB_RELEASE_DOWNLOADS}/{release_tag}/{filename}" +def _homebrew_supervisor_image(release_tag: str) -> str: + image_tag = "dev" if release_tag == "dev" else release_tag.removeprefix("v") + return f"ghcr.io/nvidia/openshell/supervisor:{image_tag}" + + def render_homebrew_formula( *, release_tag: str, @@ -225,6 +230,7 @@ def render_homebrew_formula( raise ValueError(f"release tag contains unsupported characters: {release_tag}") version = release_tag.removeprefix("v") + docker_supervisor_image = _homebrew_supervisor_image(release_tag) return f"""# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # @@ -263,12 +269,37 @@ def install resource("openshell-driver-vm").stage do libexec.install "openshell-driver-vm" end + + (libexec/"openshell-gateway-homebrew-service").write <<~SH + #!/bin/sh + set -eu + + if [ -z "${{HOME:-}}" ]; then + echo "HOME must be set for Docker TLS bind mounts" >&2 + exit 1 + fi + + docker_tls_dir="${{OPENSHELL_DOCKER_TLS_DIR:-${{HOME}}/.local/state/openshell/homebrew/tls}}" + mkdir -p "${{docker_tls_dir}}/client" + chmod 700 "${{docker_tls_dir}}" "${{docker_tls_dir}}/client" + /usr/bin/install -m 0644 "#{{var}}/openshell/tls/ca.crt" "${{docker_tls_dir}}/ca.crt" + /usr/bin/install -m 0644 "#{{var}}/openshell/tls/client/tls.crt" "${{docker_tls_dir}}/client/tls.crt" + /usr/bin/install -m 0600 "#{{var}}/openshell/tls/client/tls.key" "${{docker_tls_dir}}/client/tls.key" + + export OPENSHELL_DOCKER_TLS_CA="${{docker_tls_dir}}/ca.crt" + export OPENSHELL_DOCKER_TLS_CERT="${{docker_tls_dir}}/client/tls.crt" + export OPENSHELL_DOCKER_TLS_KEY="${{docker_tls_dir}}/client/tls.key" + + exec "#{{opt_bin}}/openshell-gateway" + SH + chmod 0755, libexec/"openshell-gateway-homebrew-service" end def post_install (var/"openshell/gateway").mkpath (var/"openshell/vm-driver").mkpath (var/"log/openshell").mkpath + system bin/"openshell-gateway", "generate-certs", "--output-dir", var/"openshell/tls", "--server-san", "host.openshell.internal" entitlements = var/"openshell/openshell-driver-vm.entitlements.plist" entitlements.atomic_write <<~XML @@ -286,17 +317,25 @@ def post_install end service do - run opt_bin/"openshell-gateway" + run opt_libexec/"openshell-gateway-homebrew-service" environment_variables( OPENSHELL_BIND_ADDRESS: "127.0.0.1", OPENSHELL_SERVER_PORT: "{LOCAL_GATEWAY_PORT}", - OPENSHELL_DISABLE_TLS: "true", - OPENSHELL_DISABLE_GATEWAY_AUTH: "true", + OPENSHELL_TLS_CERT: "#{{var}}/openshell/tls/server/tls.crt", + OPENSHELL_TLS_KEY: "#{{var}}/openshell/tls/server/tls.key", + OPENSHELL_TLS_CLIENT_CA: "#{{var}}/openshell/tls/ca.crt", OPENSHELL_DB_URL: "sqlite:#{{var}}/openshell/gateway/openshell.db", - OPENSHELL_GRPC_ENDPOINT: "http://127.0.0.1:{LOCAL_GATEWAY_PORT}", + OPENSHELL_GRPC_ENDPOINT: "https://127.0.0.1:{LOCAL_GATEWAY_PORT}", OPENSHELL_SSH_GATEWAY_HOST: "127.0.0.1", OPENSHELL_SSH_GATEWAY_PORT: "{LOCAL_GATEWAY_PORT}", OPENSHELL_VM_DRIVER_STATE_DIR: "#{{var}}/openshell/vm-driver", + OPENSHELL_VM_TLS_CA: "#{{var}}/openshell/tls/ca.crt", + OPENSHELL_VM_TLS_CERT: "#{{var}}/openshell/tls/client/tls.crt", + OPENSHELL_VM_TLS_KEY: "#{{var}}/openshell/tls/client/tls.key", + OPENSHELL_DOCKER_SUPERVISOR_IMAGE: "{docker_supervisor_image}", + OPENSHELL_PODMAN_TLS_CA: "#{{var}}/openshell/tls/ca.crt", + OPENSHELL_PODMAN_TLS_CERT: "#{{var}}/openshell/tls/client/tls.crt", + OPENSHELL_PODMAN_TLS_KEY: "#{{var}}/openshell/tls/client/tls.key", OPENSHELL_DRIVER_DIR: "#{{opt_libexec}}", ) keep_alive successful_exit: false @@ -310,7 +349,7 @@ def caveats brew services restart openshell Register it with the OpenShell CLI: - openshell gateway add http://127.0.0.1:{LOCAL_GATEWAY_PORT} --local --name local + openshell gateway add https://127.0.0.1:{LOCAL_GATEWAY_PORT} --local --name openshell EOS end From eec949dd5edcb01570fc1225b1f23cbd5451c2f9 Mon Sep 17 00:00:00 2001 From: Drew Newberry Date: Fri, 8 May 2026 13:49:22 -0700 Subject: [PATCH 013/157] fix(installer): stop forcing Homebrew VM driver (#1277) --- install-dev.sh | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/install-dev.sh b/install-dev.sh index 660795949..48f935959 100755 --- a/install-dev.sh +++ b/install-dev.sh @@ -400,18 +400,6 @@ patch_homebrew_formula() { mv "$_patched_file" "$_formula_file" fi - if ! grep -q 'OPENSHELL_DRIVERS:' "$_formula_file"; then - info "patching Homebrew formula to use VM driver..." - awk ' - { - print - if ($0 ~ /^[[:space:]]*environment_variables\(/) { - print " OPENSHELL_DRIVERS: \"vm\"," - } - } - ' "$_formula_file" >"$_patched_file" - mv "$_patched_file" "$_formula_file" - fi } start_user_gateway() { From 316c788eac7485aff05e9f80ad068471e211eb8c Mon Sep 17 00:00:00 2001 From: Taylor Mutch Date: Fri, 8 May 2026 14:10:19 -0700 Subject: [PATCH 014/157] fix(helm): derive grpcEndpoint from chart context (#1241) * fix(helm): derive grpcEndpoint from chart context The chart hardcoded server.grpcEndpoint to https://openshell.openshell.svc.cluster.local:8080, which only matched the in-cluster Service DNS for the standard release name and namespace. A new helper now builds ://..svc.cluster.local: from chart context, picking the scheme from server.disableTls. An explicit server.grpcEndpoint override is passed through verbatim. * chore(scripts): validate k3d cluster name length early helm-k3s-local.sh derives the cluster name from the current branch suffix. Long branch names produced names exceeding k3d's 32-char cap and failed deep inside k3d cluster create with a confusing validation error. cmd_create now bails out before invoking docker/k3d with a copy-pasteable HELM_K3S_CLUSTER_NAME override hint. Status, start, stop, delete, and help remain unaffected so an over-long derived name does not block diagnostics. --- deploy/helm/openshell/templates/_helpers.tpl | 16 ++++++++++++++++ deploy/helm/openshell/templates/statefulset.yaml | 2 +- deploy/helm/openshell/values.yaml | 8 ++++++-- deploy/kube/manifests/openshell-helmchart.yaml | 1 - examples/gateway-deploy-connect.md | 3 +-- tasks/scripts/helm-k3s-local.sh | 12 ++++++++++++ 6 files changed, 36 insertions(+), 6 deletions(-) diff --git a/deploy/helm/openshell/templates/_helpers.tpl b/deploy/helm/openshell/templates/_helpers.tpl index 09159340d..93eff90a9 100644 --- a/deploy/helm/openshell/templates/_helpers.tpl +++ b/deploy/helm/openshell/templates/_helpers.tpl @@ -81,3 +81,19 @@ Namespaced Issuer (selfSigned) for cert-manager CA bootstrap. {{- define "openshell.issuerSelfSigned" -}} {{- printf "%s-selfsigned" (include "openshell.fullname" .) | trunc 63 | trimSuffix "-" }} {{- end }} + +{{/* +gRPC endpoint sandbox pods use to call back into the gateway. An explicit +.Values.server.grpcEndpoint is used verbatim. Otherwise it is derived from +the in-cluster Service DNS, release namespace, service port, and disableTls +flag — so the default value works for any release name or namespace without +override. +*/}} +{{- define "openshell.grpcEndpoint" -}} +{{- if .Values.server.grpcEndpoint -}} +{{- .Values.server.grpcEndpoint -}} +{{- else -}} +{{- $scheme := ternary "http" "https" (default false .Values.server.disableTls) -}} +{{- printf "%s://%s.%s.svc.cluster.local:%d" $scheme (include "openshell.fullname" .) .Release.Namespace (int .Values.service.port) -}} +{{- end -}} +{{- end }} diff --git a/deploy/helm/openshell/templates/statefulset.yaml b/deploy/helm/openshell/templates/statefulset.yaml index b91757ad4..2d3f731af 100644 --- a/deploy/helm/openshell/templates/statefulset.yaml +++ b/deploy/helm/openshell/templates/statefulset.yaml @@ -77,7 +77,7 @@ spec: value: {{ .Values.supervisor.image.pullPolicy | quote }} {{- end }} - name: OPENSHELL_GRPC_ENDPOINT - value: {{ if .Values.server.disableTls }}{{ .Values.server.grpcEndpoint | replace "https://" "http://" | quote }}{{ else }}{{ .Values.server.grpcEndpoint | quote }}{{ end }} + value: {{ include "openshell.grpcEndpoint" . | quote }} {{- if .Values.server.sshGatewayHost }} - name: OPENSHELL_SSH_GATEWAY_HOST value: {{ .Values.server.sshGatewayHost | quote }} diff --git a/deploy/helm/openshell/values.yaml b/deploy/helm/openshell/values.yaml index 02389925a..7630554f2 100644 --- a/deploy/helm/openshell/values.yaml +++ b/deploy/helm/openshell/values.yaml @@ -86,8 +86,12 @@ server: # (Always for :latest, IfNotPresent otherwise). Set to "Always" for dev # clusters so new images are picked up without manual eviction. sandboxImagePullPolicy: "" - # gRPC endpoint for sandboxes to callback to OpenShell (must be reachable from pods) - grpcEndpoint: "https://openshell.openshell.svc.cluster.local:8080" + # gRPC endpoint sandboxes call back into the gateway. Leave empty to derive + # it from the chart fullname, release namespace, service port, and + # disableTls flag (i.e. ://..svc.cluster.local:). + # Override only when sandboxes must reach the gateway via a different + # hostname (e.g. an external ingress or a host alias). + grpcEndpoint: "" # Public host/port returned to CLI clients for SSH proxy CONNECT requests. # For local clusters the default 127.0.0.1:8080 is correct; for remote # clusters these should be set to the externally reachable host and port. diff --git a/deploy/kube/manifests/openshell-helmchart.yaml b/deploy/kube/manifests/openshell-helmchart.yaml index eba79364c..ea4e370dc 100644 --- a/deploy/kube/manifests/openshell-helmchart.yaml +++ b/deploy/kube/manifests/openshell-helmchart.yaml @@ -35,7 +35,6 @@ spec: dbUrl: __DB_URL__ sshGatewayHost: __SSH_GATEWAY_HOST__ sshGatewayPort: __SSH_GATEWAY_PORT__ - grpcEndpoint: "https://openshell.openshell.svc.cluster.local:8080" hostGatewayIP: __HOST_GATEWAY_IP__ disableGatewayAuth: __DISABLE_GATEWAY_AUTH__ disableTls: __DISABLE_TLS__ diff --git a/examples/gateway-deploy-connect.md b/examples/gateway-deploy-connect.md index 27f690334..37ed37bf2 100644 --- a/examples/gateway-deploy-connect.md +++ b/examples/gateway-deploy-connect.md @@ -16,8 +16,7 @@ kubectl create namespace openshell helm upgrade --install openshell deploy/helm/openshell \ --namespace openshell \ --set server.disableTls=true \ - --set service.type=ClusterIP \ - --set server.grpcEndpoint=http://openshell.openshell.svc.cluster.local:8080 + --set service.type=ClusterIP ``` For local evaluation, forward the service and register the forwarded endpoint: diff --git a/tasks/scripts/helm-k3s-local.sh b/tasks/scripts/helm-k3s-local.sh index 3f268c2dc..d4f802c0f 100755 --- a/tasks/scripts/helm-k3s-local.sh +++ b/tasks/scripts/helm-k3s-local.sh @@ -20,6 +20,9 @@ ROOT="$(cd "${SCRIPT_DIR}/../.." && pwd)" _branch="$(git -C "${ROOT}" rev-parse --abbrev-ref HEAD 2>/dev/null)" || _branch="" _suffix="$(printf '%s' "${_branch##*/}" | tr '[:upper:]' '[:lower:]' | tr -cs 'a-z0-9' '-' | sed 's/-*$//')" CLUSTER_NAME="${HELM_K3S_CLUSTER_NAME:-openshell-dev${_suffix:+-${_suffix}}}" +# k3d caps cluster names at 32 chars; validated in cmd_create so the operator +# gets an actionable hint instead of a deep-stack k3d validation error. +K3D_CLUSTER_NAME_MAX=32 # Host port forwarded to port 80 via the k3d load balancer. # Used by Envoy Gateway's LoadBalancer service (values-gateway.yaml). HOST_LB_PORT="${HELM_K3S_LB_HOST_PORT:-8080}" @@ -154,6 +157,15 @@ cmd_create() { require_docker require_k3d + if (( ${#CLUSTER_NAME} > K3D_CLUSTER_NAME_MAX )); then + cat >&2 < Date: Fri, 8 May 2026 14:46:57 -0700 Subject: [PATCH 015/157] fix(e2e): isolate kubernetes user namespace test (#1276) --- e2e/rust/Cargo.toml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/e2e/rust/Cargo.toml b/e2e/rust/Cargo.toml index 57bc1ff68..0da2e417b 100644 --- a/e2e/rust/Cargo.toml +++ b/e2e/rust/Cargo.toml @@ -19,6 +19,7 @@ publish = false e2e = [] e2e-docker = ["e2e"] e2e-docker-gpu = ["e2e-docker"] +e2e-kubernetes = ["e2e"] [[test]] name = "custom_image" @@ -40,6 +41,11 @@ name = "gateway_resume" path = "tests/gateway_resume.rs" required-features = ["e2e-docker"] +[[test]] +name = "user_namespaces" +path = "tests/user_namespaces.rs" +required-features = ["e2e-kubernetes"] + [dependencies] tokio = { version = "1.43", features = ["full"] } tempfile = "3" From 7ad823ea826c35c10c6cc0c58f9e05ae1f130026 Mon Sep 17 00:00:00 2001 From: Drew Newberry Date: Fri, 8 May 2026 15:14:54 -0700 Subject: [PATCH 016/157] fix(install): register local gateway before probing listener (#1280) Signed-off-by: Drew Newberry --- install-dev.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/install-dev.sh b/install-dev.sh index 48f935959..e48152862 100755 --- a/install-dev.sh +++ b/install-dev.sh @@ -416,9 +416,9 @@ start_user_gateway() { as_target_user systemctl --user restart openshell-gateway as_target_user systemctl --user is-active --quiet openshell-gateway - wait_for_local_gateway_listener info "registering local gateway as ${TARGET_USER}..." register_local_gateway + wait_for_local_gateway_listener wait_for_local_gateway_status } @@ -661,9 +661,9 @@ install_macos_homebrew() { OPENSHELL_REGISTER_BIN="${_brew_prefix}/bin/openshell" fi - wait_for_local_gateway_listener info "registering local gateway as ${TARGET_USER}..." register_local_gateway + wait_for_local_gateway_listener wait_for_local_gateway_status } From 40417981ec3c0b2a986c08693ed07cb63e10b039 Mon Sep 17 00:00:00 2001 From: Saurabh Agarwal Date: Fri, 8 May 2026 19:15:01 -0400 Subject: [PATCH 017/157] fix(helm): derive sandboxNamespace from Release.Namespace instead of hardcoding (#1282) Signed-off-by: Saurabh Agarwal --- deploy/helm/openshell/templates/_helpers.tpl | 9 ++++ .../openshell/templates/networkpolicy.yaml | 2 +- .../helm/openshell/templates/statefulset.yaml | 2 +- .../tests/sandbox_namespace_test.yaml | 50 +++++++++++++++++++ deploy/helm/openshell/values.yaml | 4 +- 5 files changed, 64 insertions(+), 3 deletions(-) create mode 100644 deploy/helm/openshell/tests/sandbox_namespace_test.yaml diff --git a/deploy/helm/openshell/templates/_helpers.tpl b/deploy/helm/openshell/templates/_helpers.tpl index 93eff90a9..c0c8562c1 100644 --- a/deploy/helm/openshell/templates/_helpers.tpl +++ b/deploy/helm/openshell/templates/_helpers.tpl @@ -82,6 +82,15 @@ Namespaced Issuer (selfSigned) for cert-manager CA bootstrap. {{- printf "%s-selfsigned" (include "openshell.fullname" .) | trunc 63 | trimSuffix "-" }} {{- end }} +{{/* +Namespace where sandbox pods are created. An explicit +.Values.server.sandboxNamespace is used verbatim. Otherwise it defaults to +.Release.Namespace so `helm install -n my-ns` works without extra overrides. +*/}} +{{- define "openshell.sandboxNamespace" -}} +{{- .Values.server.sandboxNamespace | default .Release.Namespace -}} +{{- end }} + {{/* gRPC endpoint sandbox pods use to call back into the gateway. An explicit .Values.server.grpcEndpoint is used verbatim. Otherwise it is derived from diff --git a/deploy/helm/openshell/templates/networkpolicy.yaml b/deploy/helm/openshell/templates/networkpolicy.yaml index 3e0b6f504..e85571e5f 100644 --- a/deploy/helm/openshell/templates/networkpolicy.yaml +++ b/deploy/helm/openshell/templates/networkpolicy.yaml @@ -11,7 +11,7 @@ apiVersion: networking.k8s.io/v1 kind: NetworkPolicy metadata: name: {{ include "openshell.fullname" . }}-sandbox-ssh - namespace: {{ .Values.server.sandboxNamespace }} + namespace: {{ include "openshell.sandboxNamespace" . }} labels: {{- include "openshell.labels" . | nindent 4 }} spec: diff --git a/deploy/helm/openshell/templates/statefulset.yaml b/deploy/helm/openshell/templates/statefulset.yaml index 2d3f731af..b28e99c66 100644 --- a/deploy/helm/openshell/templates/statefulset.yaml +++ b/deploy/helm/openshell/templates/statefulset.yaml @@ -63,7 +63,7 @@ spec: - {{ .Values.server.dbUrl | quote }} env: - name: OPENSHELL_SANDBOX_NAMESPACE - value: {{ .Values.server.sandboxNamespace | quote }} + value: {{ include "openshell.sandboxNamespace" . | quote }} - name: OPENSHELL_SANDBOX_IMAGE value: {{ .Values.server.sandboxImage | quote }} {{- if .Values.server.sandboxImagePullPolicy }} diff --git a/deploy/helm/openshell/tests/sandbox_namespace_test.yaml b/deploy/helm/openshell/tests/sandbox_namespace_test.yaml new file mode 100644 index 000000000..a128cd440 --- /dev/null +++ b/deploy/helm/openshell/tests/sandbox_namespace_test.yaml @@ -0,0 +1,50 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +suite: sandboxNamespace defaulting +templates: + - templates/statefulset.yaml + - templates/networkpolicy.yaml +release: + name: openshell + namespace: my-namespace + +tests: + - it: defaults OPENSHELL_SANDBOX_NAMESPACE to release namespace + template: templates/statefulset.yaml + asserts: + - contains: + path: spec.template.spec.containers[0].env + content: + name: OPENSHELL_SANDBOX_NAMESPACE + value: "my-namespace" + + - it: uses explicit sandboxNamespace when set + template: templates/statefulset.yaml + set: + server.sandboxNamespace: other-ns + asserts: + - contains: + path: spec.template.spec.containers[0].env + content: + name: OPENSHELL_SANDBOX_NAMESPACE + value: "other-ns" + + - it: defaults NetworkPolicy namespace to release namespace + template: templates/networkpolicy.yaml + set: + networkPolicy.enabled: true + asserts: + - equal: + path: metadata.namespace + value: my-namespace + + - it: uses explicit sandboxNamespace for NetworkPolicy + template: templates/networkpolicy.yaml + set: + networkPolicy.enabled: true + server.sandboxNamespace: other-ns + asserts: + - equal: + path: metadata.namespace + value: other-ns diff --git a/deploy/helm/openshell/values.yaml b/deploy/helm/openshell/values.yaml index 7630554f2..389efc132 100644 --- a/deploy/helm/openshell/values.yaml +++ b/deploy/helm/openshell/values.yaml @@ -79,7 +79,9 @@ affinity: {} # Server configuration server: logLevel: info - sandboxNamespace: openshell + # Namespace where sandbox pods are created. Defaults to the Helm release + # namespace (.Release.Namespace) when left empty. + sandboxNamespace: "" dbUrl: "sqlite:/var/openshell/openshell.db" sandboxImage: "ghcr.io/nvidia/openshell-community/sandboxes/base:latest" # Kubernetes imagePullPolicy for sandbox pods. Empty = Kubernetes default From 529be37fd447827bce842cc6b35cd172e4ff9e72 Mon Sep 17 00:00:00 2001 From: Drew Newberry Date: Fri, 8 May 2026 17:04:15 -0700 Subject: [PATCH 018/157] chore(installer): promote package install script (#1261) --- .github/workflows/release-canary.yml | 16 +- .github/workflows/test-install.yml | 51 -- crates/openshell-driver-vm/README.md | 6 +- e2e/install/bash_test.sh | 80 --- e2e/install/fish_test.fish | 152 ----- e2e/install/helpers.sh | 100 ---- e2e/install/sh_test.sh | 105 ---- e2e/install/zsh_test.sh | 80 --- install-dev.sh | 714 ----------------------- install.sh | 820 ++++++++++++++++++++------- 10 files changed, 635 insertions(+), 1489 deletions(-) delete mode 100644 .github/workflows/test-install.yml delete mode 100755 e2e/install/bash_test.sh delete mode 100755 e2e/install/fish_test.fish delete mode 100644 e2e/install/helpers.sh delete mode 100755 e2e/install/sh_test.sh delete mode 100755 e2e/install/zsh_test.sh delete mode 100755 install-dev.sh diff --git a/.github/workflows/release-canary.yml b/.github/workflows/release-canary.yml index 5e895efc7..61f8a8a1e 100644 --- a/.github/workflows/release-canary.yml +++ b/.github/workflows/release-canary.yml @@ -20,9 +20,13 @@ jobs: runs-on: macos-latest-xlarge timeout-minutes: 20 steps: - - name: Install dev and check status + - name: Ensure VM driver run: | - curl -LsSf https://raw.githubusercontent.com/NVIDIA/OpenShell/${{ github.event.workflow_run.head_sha || github.sha }}/install-dev.sh | sh + launchctl setenv OPENSHELL_DRIVERS vm + + - name: Install and check status + run: | + curl -LsSf https://raw.githubusercontent.com/NVIDIA/OpenShell/${{ github.event.workflow_run.head_sha || github.sha }}/install.sh | sh openshell status ubuntu: @@ -42,9 +46,9 @@ jobs: printf 'OPENSHELL_DRIVERS=docker\n' > "${HOME}/.config/openshell/gateway.env" docker info - - name: Install dev and check status + - name: Install and check status run: | - curl -LsSf https://raw.githubusercontent.com/NVIDIA/OpenShell/${{ github.event.workflow_run.head_sha || github.sha }}/install-dev.sh | sh + curl -LsSf https://raw.githubusercontent.com/NVIDIA/OpenShell/${{ github.event.workflow_run.head_sha || github.sha }}/install.sh | sh openshell status fedora: @@ -63,7 +67,7 @@ jobs: printf 'OPENSHELL_DRIVERS=podman\n' > "${HOME}/.config/openshell/gateway.env" podman info - - name: Install dev and check status + - name: Install and check status run: | - curl -LsSf https://raw.githubusercontent.com/NVIDIA/OpenShell/${{ github.event.workflow_run.head_sha || github.sha }}/install-dev.sh | sh + curl -LsSf https://raw.githubusercontent.com/NVIDIA/OpenShell/${{ github.event.workflow_run.head_sha || github.sha }}/install.sh | sh openshell status diff --git a/.github/workflows/test-install.yml b/.github/workflows/test-install.yml deleted file mode 100644 index 06b1e007f..000000000 --- a/.github/workflows/test-install.yml +++ /dev/null @@ -1,51 +0,0 @@ -name: Test Install Script - -on: - pull_request: - paths: - - 'install.sh' - - 'e2e/install/**' - - '.github/workflows/test-install.yml' - push: - branches: [main] - paths: - - 'install.sh' - - 'e2e/install/**' - - '.github/workflows/test-install.yml' - workflow_dispatch: - -permissions: - contents: read - -jobs: - test-install: - name: install.sh (${{ matrix.shell }}) - runs-on: ubuntu-latest - strategy: - fail-fast: false - matrix: - include: - - shell: sh - test: e2e/install/sh_test.sh - run: sh e2e/install/sh_test.sh - - shell: bash - test: e2e/install/bash_test.sh - run: bash e2e/install/bash_test.sh - - shell: zsh - test: e2e/install/zsh_test.sh - run: zsh e2e/install/zsh_test.sh - install: zsh - - shell: fish - test: e2e/install/fish_test.fish - run: fish e2e/install/fish_test.fish - install: fish - - steps: - - uses: actions/checkout@v6 - - - name: Install ${{ matrix.shell }} - if: matrix.install - run: sudo apt-get update && sudo apt-get install -y ${{ matrix.install }} - - - name: Run tests (${{ matrix.shell }}) - run: ${{ matrix.run }} diff --git a/crates/openshell-driver-vm/README.md b/crates/openshell-driver-vm/README.md index 046554c59..49f2ef005 100644 --- a/crates/openshell-driver-vm/README.md +++ b/crates/openshell-driver-vm/README.md @@ -180,17 +180,17 @@ The VM guest's serial console is appended to `//console.l - runtime tarballs: the rolling `vm-runtime` release, rebuilt on demand by `release-vm-kernel.yml` -On Debian-family Linux amd64 and arm64 systems, `install-dev.sh` installs the +On Debian-family Linux amd64 and arm64 systems, `install.sh` installs the Debian package from the selected `OPENSHELL_VERSION` release tag. That package includes `openshell-gateway` and `openshell-driver-vm`, but leaves `OPENSHELL_DRIVERS` unset so the gateway uses its normal runtime auto-detection. Set `OPENSHELL_DRIVERS=vm` to force the VM driver. -On RPM-family Linux x86_64 and aarch64 systems, `install-dev.sh` installs the +On RPM-family Linux x86_64 and aarch64 systems, `install.sh` installs the `openshell` and `openshell-gateway` RPM packages from the selected release tag. The RPM gateway package is configured for the Podman driver. -On Apple Silicon macOS, `install-dev.sh` stages the generated `openshell.rb` +On Apple Silicon macOS, `install.sh` stages the generated `openshell.rb` formula from the selected release in the `nvidia/openshell` Homebrew tap. Homebrew installs `openshell`, `openshell-gateway`, and `openshell-driver-vm`, ad-hoc signs the driver with the Hypervisor entitlement diff --git a/e2e/install/bash_test.sh b/e2e/install/bash_test.sh deleted file mode 100755 index 2b4db1caf..000000000 --- a/e2e/install/bash_test.sh +++ /dev/null @@ -1,80 +0,0 @@ -#!/bin/bash -# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Bash e2e tests for install.sh. -# -# Downloads the latest release for real and validates: -# - Binary is installed to the correct directory -# - Binary is executable and runs -# - PATH guidance shows the correct export command for bash -# -set -euo pipefail - -. "$(dirname "$0")/helpers.sh" - -# --------------------------------------------------------------------------- -# Tests -# --------------------------------------------------------------------------- - -test_binary_installed() { - printf 'TEST: binary exists in install directory\n' - - if [ -f "$INSTALL_DIR/openshell" ]; then - pass "openshell binary exists at $INSTALL_DIR/openshell" - else - fail "openshell binary exists" "not found at $INSTALL_DIR/openshell" - fi -} - -test_binary_executable() { - printf 'TEST: binary is executable\n' - - if [ -x "$INSTALL_DIR/openshell" ]; then - pass "openshell binary is executable" - else - fail "openshell binary is executable" "$INSTALL_DIR/openshell is not executable" - fi -} - -test_binary_runs() { - printf 'TEST: binary runs successfully\n' - - if _version="$("$INSTALL_DIR/openshell" --version 2>/dev/null)"; then - pass "openshell --version succeeds: $_version" - else - fail "openshell --version succeeds" "exit code: $?" - fi -} - -test_guidance_shows_export_path() { - printf 'TEST: guidance shows export PATH for bash users\n' - - assert_output_contains "$INSTALL_OUTPUT" 'export PATH="' "shows export PATH command" - assert_output_not_contains "$INSTALL_OUTPUT" "fish_add_path" "does not show fish command" -} - -test_guidance_mentions_not_on_path() { - printf 'TEST: guidance mentions install dir is not on PATH\n' - - assert_output_contains "$INSTALL_OUTPUT" "is not on your PATH" "mentions PATH issue" - assert_output_contains "$INSTALL_OUTPUT" "$INSTALL_DIR" "includes install dir in guidance" -} - -# --------------------------------------------------------------------------- -# Runner -# --------------------------------------------------------------------------- - -printf '=== install.sh e2e tests: bash ===\n\n' - -printf 'Installing openshell...\n' -SHELL="/bin/bash" run_install -printf 'Done.\n\n' - -test_binary_installed; echo "" -test_binary_executable; echo "" -test_binary_runs; echo "" -test_guidance_shows_export_path; echo "" -test_guidance_mentions_not_on_path - -print_summary diff --git a/e2e/install/fish_test.fish b/e2e/install/fish_test.fish deleted file mode 100755 index 101760715..000000000 --- a/e2e/install/fish_test.fish +++ /dev/null @@ -1,152 +0,0 @@ -#!/usr/bin/env fish -# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Fish e2e tests for install.sh. -# -# Downloads the latest release for real and validates: -# - Binary is installed to the correct directory -# - Binary is executable and runs -# - PATH guidance shows fish_add_path (not export PATH) - -set -g PASS 0 -set -g FAIL 0 - -# Resolve paths relative to this script -set -g SCRIPT_DIR (builtin cd (dirname (status filename)) && pwd) -set -g REPO_ROOT (builtin cd "$SCRIPT_DIR/../.." && pwd) -set -g INSTALL_SCRIPT "$REPO_ROOT/install.sh" - -# Set by run_install -set -g INSTALL_DIR "" -set -g INSTALL_OUTPUT "" - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - -function pass - set -g PASS (math $PASS + 1) - printf ' PASS: %s\n' $argv[1] -end - -function fail - set -g FAIL (math $FAIL + 1) - printf ' FAIL: %s\n' $argv[1] >&2 - if test (count $argv) -gt 1 - printf ' %s\n' $argv[2] >&2 - end -end - -function assert_output_contains - set -l output $argv[1] - set -l pattern $argv[2] - set -l label $argv[3] - - if string match -q -- "*$pattern*" "$output" - pass "$label" - else - fail "$label" "expected '$pattern' in output" - end -end - -function assert_output_not_contains - set -l output $argv[1] - set -l pattern $argv[2] - set -l label $argv[3] - - if string match -q -- "*$pattern*" "$output" - fail "$label" "unexpected '$pattern' found in output" - else - pass "$label" - end -end - -function run_install - set -g INSTALL_DIR (mktemp -d)/bin - - set -g INSTALL_OUTPUT (OPENSHELL_INSTALL_DIR="$INSTALL_DIR" \ - SHELL="/usr/bin/fish" \ - PATH="/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin" \ - sh "$INSTALL_SCRIPT" 2>&1) - - if test $status -ne 0 - printf 'install.sh failed:\n%s\n' "$INSTALL_OUTPUT" >&2 - return 1 - end -end - -# --------------------------------------------------------------------------- -# Tests -# --------------------------------------------------------------------------- - -function test_binary_installed - printf 'TEST: binary exists in install directory\n' - - if test -f "$INSTALL_DIR/openshell" - pass "openshell binary exists at $INSTALL_DIR/openshell" - else - fail "openshell binary exists" "not found at $INSTALL_DIR/openshell" - end -end - -function test_binary_executable - printf 'TEST: binary is executable\n' - - if test -x "$INSTALL_DIR/openshell" - pass "openshell binary is executable" - else - fail "openshell binary is executable" "$INSTALL_DIR/openshell is not executable" - end -end - -function test_binary_runs - printf 'TEST: binary runs successfully\n' - - set -l version_output ("$INSTALL_DIR/openshell" --version 2>/dev/null) - if test $status -eq 0 - pass "openshell --version succeeds: $version_output" - else - fail "openshell --version succeeds" "exit code: $status" - end -end - -function test_guidance_shows_fish_add_path - printf 'TEST: guidance shows fish_add_path for fish users\n' - - assert_output_contains "$INSTALL_OUTPUT" "fish_add_path" "shows fish_add_path command" - assert_output_not_contains "$INSTALL_OUTPUT" 'export PATH="' "does not show POSIX export" -end - -function test_guidance_mentions_not_on_path - printf 'TEST: guidance mentions install dir is not on PATH\n' - - assert_output_contains "$INSTALL_OUTPUT" "is not on your PATH" "mentions PATH issue" - assert_output_contains "$INSTALL_OUTPUT" "$INSTALL_DIR" "includes install dir in guidance" -end - -# --------------------------------------------------------------------------- -# Runner -# --------------------------------------------------------------------------- - -printf '=== install.sh e2e tests: fish ===\n\n' - -printf 'Installing openshell...\n' -run_install -printf 'Done.\n\n' - -test_binary_installed -echo "" -test_binary_executable -echo "" -test_binary_runs -echo "" -test_guidance_shows_fish_add_path -echo "" -test_guidance_mentions_not_on_path - -printf '\n=== Results: %d passed, %d failed ===\n' $PASS $FAIL - -if test $FAIL -gt 0 - exit 1 -end diff --git a/e2e/install/helpers.sh b/e2e/install/helpers.sh deleted file mode 100644 index ff5f66376..000000000 --- a/e2e/install/helpers.sh +++ /dev/null @@ -1,100 +0,0 @@ -#!/bin/sh -# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Shared test helpers for install.sh e2e tests. -# Sourced by each per-shell test file (except fish, which has its own helpers). -# -# Provides: -# - pass / fail / print_summary -# - assert_output_contains / assert_output_not_contains -# - run_install (runs the real install.sh to a temp dir, captures output) -# - REPO_ROOT / INSTALL_SCRIPT paths -# - INSTALL_DIR / INSTALL_OUTPUT (set after run_install) - -HELPERS_DIR="$(cd "$(dirname "$0")" && pwd)" -REPO_ROOT="$(cd "$HELPERS_DIR/../.." && pwd)" -INSTALL_SCRIPT="$REPO_ROOT/install.sh" - -_PASS=0 -_FAIL=0 - -# Set by run_install -INSTALL_DIR="" -INSTALL_OUTPUT="" - -# --------------------------------------------------------------------------- -# Assertions -# --------------------------------------------------------------------------- - -pass() { - _PASS=$((_PASS + 1)) - printf ' PASS: %s\n' "$1" -} - -fail() { - _FAIL=$((_FAIL + 1)) - printf ' FAIL: %s\n' "$1" >&2 - if [ -n "${2:-}" ]; then - printf ' %s\n' "$2" >&2 - fi -} - -assert_output_contains() { - _aoc_output="$1" - _aoc_pattern="$2" - _aoc_label="$3" - - if printf '%s' "$_aoc_output" | grep -qF "$_aoc_pattern"; then - pass "$_aoc_label" - else - fail "$_aoc_label" "expected '$_aoc_pattern' in output" - fi -} - -assert_output_not_contains() { - _aonc_output="$1" - _aonc_pattern="$2" - _aonc_label="$3" - - if printf '%s' "$_aonc_output" | grep -qF "$_aonc_pattern"; then - fail "$_aonc_label" "unexpected '$_aonc_pattern' found in output" - else - pass "$_aonc_label" - fi -} - -# --------------------------------------------------------------------------- -# Install runner -# --------------------------------------------------------------------------- - -# Run the real install.sh, installing to a temp directory with the install -# dir removed from PATH so we always get PATH guidance output. -# -# Sets INSTALL_DIR and INSTALL_OUTPUT for subsequent assertions. -# The SHELL variable is passed through so tests can control which shell -# guidance is shown. -# -# Usage: -# SHELL="/bin/bash" run_install -run_install() { - INSTALL_DIR="$(mktemp -d)/bin" - - # Remove the install dir from PATH (it won't be there, but be explicit). - # Keep a minimal PATH so curl/tar/install are available. - INSTALL_OUTPUT="$(OPENSHELL_INSTALL_DIR="$INSTALL_DIR" \ - PATH="/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin" \ - sh "$INSTALL_SCRIPT" 2>&1)" || { - printf 'install.sh failed:\n%s\n' "$INSTALL_OUTPUT" >&2 - return 1 - } -} - -# --------------------------------------------------------------------------- -# Summary -# --------------------------------------------------------------------------- - -print_summary() { - printf '\n=== Results: %d passed, %d failed ===\n' "$_PASS" "$_FAIL" - [ "$_FAIL" -eq 0 ] -} diff --git a/e2e/install/sh_test.sh b/e2e/install/sh_test.sh deleted file mode 100755 index 320c00efb..000000000 --- a/e2e/install/sh_test.sh +++ /dev/null @@ -1,105 +0,0 @@ -#!/bin/sh -# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# POSIX sh e2e tests for install.sh. -# -# Downloads the latest release for real and validates: -# - Binary is installed to the correct directory -# - Binary is executable and runs -# - PATH guidance shows the correct export command for sh -# - No rc files or env scripts are created -# -set -eu - -. "$(dirname "$0")/helpers.sh" - -# --------------------------------------------------------------------------- -# Tests -# --------------------------------------------------------------------------- - -test_binary_installed() { - printf 'TEST: binary exists in install directory\n' - - if [ -f "$INSTALL_DIR/openshell" ]; then - pass "openshell binary exists at $INSTALL_DIR/openshell" - else - fail "openshell binary exists" "not found at $INSTALL_DIR/openshell" - fi -} - -test_binary_executable() { - printf 'TEST: binary is executable\n' - - if [ -x "$INSTALL_DIR/openshell" ]; then - pass "openshell binary is executable" - else - fail "openshell binary is executable" "$INSTALL_DIR/openshell is not executable" - fi -} - -test_binary_runs() { - printf 'TEST: binary runs successfully\n' - - if _version="$("$INSTALL_DIR/openshell" --version 2>/dev/null)"; then - pass "openshell --version succeeds: $_version" - else - fail "openshell --version succeeds" "exit code: $?" - fi -} - -test_guidance_shows_export_path() { - printf 'TEST: guidance shows export PATH for sh users\n' - - assert_output_contains "$INSTALL_OUTPUT" 'export PATH="' "shows export PATH command" - assert_output_not_contains "$INSTALL_OUTPUT" "fish_add_path" "does not show fish command" -} - -test_guidance_mentions_not_on_path() { - printf 'TEST: guidance mentions install dir is not on PATH\n' - - assert_output_contains "$INSTALL_OUTPUT" "is not on your PATH" "mentions PATH issue" - assert_output_contains "$INSTALL_OUTPUT" "$INSTALL_DIR" "includes install dir in guidance" -} - -test_guidance_mentions_restart() { - printf 'TEST: guidance tells user to restart shell\n' - - assert_output_contains "$INSTALL_OUTPUT" "restart your shell" "mentions shell restart" -} - -test_no_env_scripts_created() { - printf 'TEST: no env scripts are created in install dir\n' - - if [ -f "$INSTALL_DIR/env" ]; then - fail "no env script created" "found $INSTALL_DIR/env" - else - pass "no env script created" - fi - - if [ -f "$INSTALL_DIR/env.fish" ]; then - fail "no env.fish script created" "found $INSTALL_DIR/env.fish" - else - pass "no env.fish script created" - fi -} - -# --------------------------------------------------------------------------- -# Runner -# --------------------------------------------------------------------------- - -printf '=== install.sh e2e tests: sh ===\n\n' - -printf 'Installing openshell...\n' -SHELL="/bin/sh" run_install -printf 'Done.\n\n' - -test_binary_installed; echo "" -test_binary_executable; echo "" -test_binary_runs; echo "" -test_guidance_shows_export_path; echo "" -test_guidance_mentions_not_on_path; echo "" -test_guidance_mentions_restart; echo "" -test_no_env_scripts_created - -print_summary diff --git a/e2e/install/zsh_test.sh b/e2e/install/zsh_test.sh deleted file mode 100755 index 621d35f8e..000000000 --- a/e2e/install/zsh_test.sh +++ /dev/null @@ -1,80 +0,0 @@ -#!/bin/zsh -# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Zsh e2e tests for install.sh. -# -# Downloads the latest release for real and validates: -# - Binary is installed to the correct directory -# - Binary is executable and runs -# - PATH guidance shows the correct export command for zsh -# -set -eu - -. "$(dirname "$0")/helpers.sh" - -# --------------------------------------------------------------------------- -# Tests -# --------------------------------------------------------------------------- - -test_binary_installed() { - printf 'TEST: binary exists in install directory\n' - - if [ -f "$INSTALL_DIR/openshell" ]; then - pass "openshell binary exists at $INSTALL_DIR/openshell" - else - fail "openshell binary exists" "not found at $INSTALL_DIR/openshell" - fi -} - -test_binary_executable() { - printf 'TEST: binary is executable\n' - - if [ -x "$INSTALL_DIR/openshell" ]; then - pass "openshell binary is executable" - else - fail "openshell binary is executable" "$INSTALL_DIR/openshell is not executable" - fi -} - -test_binary_runs() { - printf 'TEST: binary runs successfully\n' - - if _version="$("$INSTALL_DIR/openshell" --version 2>/dev/null)"; then - pass "openshell --version succeeds: $_version" - else - fail "openshell --version succeeds" "exit code: $?" - fi -} - -test_guidance_shows_export_path() { - printf 'TEST: guidance shows export PATH for zsh users\n' - - assert_output_contains "$INSTALL_OUTPUT" 'export PATH="' "shows export PATH command" - assert_output_not_contains "$INSTALL_OUTPUT" "fish_add_path" "does not show fish command" -} - -test_guidance_mentions_not_on_path() { - printf 'TEST: guidance mentions install dir is not on PATH\n' - - assert_output_contains "$INSTALL_OUTPUT" "is not on your PATH" "mentions PATH issue" - assert_output_contains "$INSTALL_OUTPUT" "$INSTALL_DIR" "includes install dir in guidance" -} - -# --------------------------------------------------------------------------- -# Runner -# --------------------------------------------------------------------------- - -printf '=== install.sh e2e tests: zsh ===\n\n' - -printf 'Installing openshell...\n' -SHELL="/bin/zsh" run_install -printf 'Done.\n\n' - -test_binary_installed; echo "" -test_binary_executable; echo "" -test_binary_runs; echo "" -test_guidance_shows_export_path; echo "" -test_guidance_mentions_not_on_path - -print_summary diff --git a/install-dev.sh b/install-dev.sh deleted file mode 100755 index e48152862..000000000 --- a/install-dev.sh +++ /dev/null @@ -1,714 +0,0 @@ -#!/bin/sh -# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Install the OpenShell development build from a GitHub release. -# -# Linux installs either the Debian or RPM packages from the selected release. -# Apple Silicon macOS installs the generated Homebrew formula, so Homebrew owns -# the binary layout and launchd service lifecycle. -# -set -e - -APP_NAME="openshell" -REPO="NVIDIA/OpenShell" -GITHUB_URL="https://github.com/${REPO}" -RELEASE_TAG="${OPENSHELL_VERSION:-dev}" -CHECKSUMS_NAME="openshell-checksums-sha256.txt" -LOCAL_GATEWAY_PORT="17670" -HOMEBREW_TAP="nvidia/openshell" -HOMEBREW_FORMULA_NAME="openshell" - -info() { - printf '%s: %s\n' "$APP_NAME" "$*" >&2 -} - -warn() { - printf '%s: warning: %s\n' "$APP_NAME" "$*" >&2 -} - -error() { - printf '%s: error: %s\n' "$APP_NAME" "$*" >&2 - exit 1 -} - -usage() { - cat </dev/null 2>&1 -} - -require_cmd() { - if ! has_cmd "$1"; then - error "'$1' is required" - fi -} - -download() { - _url="$1" - _output="$2" - curl -fLsS --retry 3 --max-redirs 5 -o "$_output" "$_url" -} - -download_release_asset() { - _tag="$1" - _filename="$2" - _output="$3" - - if curl -fLs --retry 3 --max-redirs 5 -o "$_output" \ - "${GITHUB_URL}/releases/download/${_tag}/${_filename}"; then - return 0 - fi - - # GitHub normalizes `~` to `.` in release asset names, while checksum files - # can still record package filenames with `~dev` for correct version ordering. - # Download the normalized asset but verify it against the checksum entry for - # the original package filename. - _normalized="$(printf '%s' "$_filename" | tr '~' '.')" - if [ "$_normalized" != "$_filename" ]; then - if download "${GITHUB_URL}/releases/download/${_tag}/${_normalized}" "$_output"; then - info "using GitHub-normalized asset name ${_normalized}" - return 0 - fi - fi - - return 1 -} - -as_root() { - if [ "$(id -u)" -eq 0 ]; then - "$@" - elif has_cmd sudo; then - sudo "$@" - else - error "this installer needs root privileges; rerun as root or install sudo" - fi -} - -target_user() { - if [ "$(id -u)" -eq 0 ] && [ -n "${SUDO_USER:-}" ] && [ "${SUDO_USER}" != "root" ]; then - echo "$SUDO_USER" - else - id -un - fi -} - -user_home() { - _user="$1" - if has_cmd getent; then - _home="$(getent passwd "$_user" | awk -F: '{ print $6 }')" - if [ -n "$_home" ]; then - echo "$_home" - return 0 - fi - fi - - if [ "$(uname -s)" = "Darwin" ] && has_cmd dscl; then - _home="$(dscl . -read "/Users/${_user}" NFSHomeDirectory 2>/dev/null | awk '{ print $2 }')" - if [ -n "$_home" ]; then - echo "$_home" - return 0 - fi - fi - - if [ "$(id -un)" = "$_user" ]; then - echo "${HOME:-}" - return 0 - fi - - if [ "$(uname -s)" = "Darwin" ]; then - echo "/Users/${_user}" - return 0 - fi - - echo "/home/${_user}" -} - -as_target_user() { - if [ "${PLATFORM:-}" = "darwin" ]; then - if [ "$(id -u)" -eq "$TARGET_UID" ]; then - env HOME="$TARGET_HOME" "$@" - elif has_cmd sudo; then - sudo -u "$TARGET_USER" env HOME="$TARGET_HOME" "$@" - else - error "cannot run commands as ${TARGET_USER}; install sudo or run as ${TARGET_USER}" - fi - return - fi - - _bus="unix:path=${TARGET_RUNTIME_DIR}/bus" - if [ "$(id -u)" -eq "$TARGET_UID" ]; then - env HOME="$TARGET_HOME" XDG_RUNTIME_DIR="$TARGET_RUNTIME_DIR" DBUS_SESSION_BUS_ADDRESS="$_bus" "$@" - elif has_cmd sudo; then - sudo -u "$TARGET_USER" env HOME="$TARGET_HOME" XDG_RUNTIME_DIR="$TARGET_RUNTIME_DIR" DBUS_SESSION_BUS_ADDRESS="$_bus" "$@" - elif has_cmd runuser; then - runuser -u "$TARGET_USER" -- env HOME="$TARGET_HOME" XDG_RUNTIME_DIR="$TARGET_RUNTIME_DIR" DBUS_SESSION_BUS_ADDRESS="$_bus" "$@" - else - error "cannot run user service commands as ${TARGET_USER}; install sudo or run as ${TARGET_USER}" - fi -} - -detect_platform() { - case "$(uname -s)" in - Linux) - echo "linux" - ;; - Darwin) - echo "darwin" - ;; - *) - error "unsupported OS: $(uname -s); dev builds support Linux and macOS" - ;; - esac -} - -linux_package_method() { - if has_cmd dpkg; then - echo "deb" - elif has_cmd rpm; then - echo "rpm" - else - error "Linux dev installs require either dpkg or rpm" - fi -} - -set_linux_target_runtime_dir() { - if [ "$(id -u)" -eq "$TARGET_UID" ] && [ -n "${XDG_RUNTIME_DIR:-}" ]; then - TARGET_RUNTIME_DIR="$XDG_RUNTIME_DIR" - else - TARGET_RUNTIME_DIR="/run/user/${TARGET_UID}" - fi -} - -check_linux_deb_platform() { - require_cmd dpkg -} - -check_macos_platform() { - _arch="$(uname -m)" - - case "$_arch" in - arm64|aarch64) - ;; - x86_64|amd64) - error "Intel macOS is not supported because no x86_64-apple-darwin dev assets are published" - ;; - *) - error "no macOS dev build is published for architecture: ${_arch}" - ;; - esac - - if ! as_target_user brew --version >/dev/null 2>&1; then - error "Homebrew is required for macOS dev installs; install it from https://brew.sh" - fi -} - -get_deb_arch() { - _arch="$(dpkg --print-architecture)" - - case "$_arch" in - amd64|arm64) - echo "$_arch" - ;; - *) - error "no dev Debian package is published for architecture: ${_arch}" - ;; - esac -} - -get_rpm_arch() { - if has_cmd rpm; then - _arch="$(rpm --eval '%{_arch}' 2>/dev/null || true)" - else - _arch="" - fi - - if [ -z "$_arch" ]; then - _arch="$(uname -m)" - fi - - case "$_arch" in - x86_64|amd64) - echo "x86_64" - ;; - aarch64|arm64) - echo "aarch64" - ;; - *) - error "no dev RPM package is published for architecture: ${_arch}" - ;; - esac -} - -find_deb_asset() { - _checksums="$1" - _arch="$2" - - awk -v arch="$_arch" ' - $2 ~ "^\\*?openshell[-_].*[-_]" arch "\\.deb$" { - sub("^\\*", "", $2) - print $2 - exit - } - ' "$_checksums" -} - -find_rpm_asset() { - _checksums="$1" - _arch="$2" - _package="$3" - - case "$_package" in - openshell) - _dev_name="openshell-dev-${_arch}.rpm" - _fallback_re="^openshell-[0-9].*\\.${_arch}\\.rpm$" - ;; - openshell-gateway) - _dev_name="openshell-gateway-dev-${_arch}.rpm" - _fallback_re="^openshell-gateway-[0-9].*\\.${_arch}\\.rpm$" - ;; - *) - error "unknown RPM package selector: ${_package}" - ;; - esac - - awk -v dev_name="$_dev_name" -v fallback_re="$_fallback_re" ' - { - name = $2 - sub("^\\*", "", name) - - if (name == dev_name) { - selected = name - found = 1 - exit - } - - if (fallback == "" && name ~ fallback_re) { - fallback = name - } - } - END { - if (found) { - print selected - } else if (fallback != "") { - print fallback - } - } - ' "$_checksums" -} - -verify_checksum() { - _archive="$1" - _checksums="$2" - _filename="$3" - - if has_cmd sha256sum; then - _expected="$(awk -v name="$_filename" '($2 == name || $2 == "*" name) { print $1; exit }' "$_checksums")" - [ -n "$_expected" ] || error "no checksum entry found for ${_filename}" - echo "$_expected $_archive" | sha256sum -c --quiet - elif has_cmd shasum; then - _expected="$(awk -v name="$_filename" '($2 == name || $2 == "*" name) { print $1; exit }' "$_checksums")" - [ -n "$_expected" ] || error "no checksum entry found for ${_filename}" - echo "$_expected $_archive" | shasum -a 256 -c --quiet - else - error "neither 'sha256sum' nor 'shasum' found; cannot verify download integrity" - fi -} - -install_deb_package() { - _deb_path="$1" - - if has_cmd apt-get; then - as_root env DEBIAN_FRONTEND=noninteractive apt-get install -y \ - -o Dpkg::Options::=--force-confdef \ - -o Dpkg::Options::=--force-confnew \ - "$_deb_path" - elif has_cmd apt; then - as_root env DEBIAN_FRONTEND=noninteractive apt install -y \ - -o Dpkg::Options::=--force-confdef \ - -o Dpkg::Options::=--force-confnew \ - "$_deb_path" - else - as_root dpkg --force-confdef --force-confnew -i "$_deb_path" - fi -} - -install_rpm_packages() { - if has_cmd dnf; then - as_root dnf install -y "$@" - elif has_cmd yum; then - as_root yum install -y "$@" - elif has_cmd zypper; then - as_root zypper --non-interactive install --allow-unsigned-rpm "$@" - elif has_cmd rpm; then - warn "installing with rpm directly; dependencies must already be installed" - as_root rpm -Uvh --replacepkgs "$@" - else - error "'dnf', 'yum', 'zypper', or 'rpm' is required to install RPM packages" - fi -} - -homebrew_formula_path() { - _tap="$1" - _formula="$2" - - if ! as_target_user brew tap-info "$_tap" >/dev/null 2>&1; then - info "creating local Homebrew tap ${_tap}..." - as_target_user brew tap-new --no-git "$_tap" >/dev/null - fi - - _tap_dir="$(as_target_user brew --repository "$_tap" 2>/dev/null || true)" - [ -n "$_tap_dir" ] || error "could not locate Homebrew tap ${_tap}" - - _formula_dir="${_tap_dir}/Formula" - as_target_user mkdir -p "$_formula_dir" - printf '%s/%s.rb\n' "$_formula_dir" "$_formula" -} - -patch_homebrew_formula() { - _formula_file="$1" - _patched_file="${_formula_file}.patched" - - if grep -q 'entitlements.write <<~XML' "$_formula_file"; then - info "patching Homebrew formula for idempotent postinstall..." - sed 's/entitlements\.write <<~XML/entitlements.atomic_write <<~XML/' "$_formula_file" >"$_patched_file" - mv "$_patched_file" "$_formula_file" - fi - -} - -start_user_gateway() { - info "restarting openshell-gateway user service as ${TARGET_USER}..." - - if ! as_target_user systemctl --user daemon-reload; then - info "could not reach the user systemd manager for ${TARGET_USER}" - info "restart the gateway later with: systemctl --user enable openshell-gateway && systemctl --user restart openshell-gateway" - info "then register it with: openshell gateway add https://127.0.0.1:17670 --local --name openshell" - return 0 - fi - - as_target_user systemctl --user enable openshell-gateway - as_target_user systemctl --user restart openshell-gateway - as_target_user systemctl --user is-active --quiet openshell-gateway - - info "registering local gateway as ${TARGET_USER}..." - register_local_gateway - wait_for_local_gateway_listener - wait_for_local_gateway_status -} - -wait_for_local_gateway_listener() { - _timeout="${OPENSHELL_INSTALL_GATEWAY_TIMEOUT:-30}" - _elapsed=0 - _last_output="" - _probe_url="https://127.0.0.1:${LOCAL_GATEWAY_PORT}/" - _mtls_dir="${TARGET_HOME}/.config/openshell/gateways/openshell/mtls" - - info "waiting for local gateway listener to become reachable..." - while [ "$_elapsed" -lt "$_timeout" ]; do - if [ ! -f "${_mtls_dir}/ca.crt" ] || [ ! -f "${_mtls_dir}/tls.crt" ] || [ ! -f "${_mtls_dir}/tls.key" ]; then - _last_output="mTLS client bundle is not ready under ${_mtls_dir}" - elif _last_output="$(as_target_user curl -sS --max-time 2 --cacert "${_mtls_dir}/ca.crt" --cert "${_mtls_dir}/tls.crt" --key "${_mtls_dir}/tls.key" -o /dev/null "$_probe_url" 2>&1)"; then - info "local gateway listener is reachable" - return 0 - fi - sleep 1 - _elapsed=$((_elapsed + 1)) - done - - printf '%s\n' "$_last_output" >&2 - error "local gateway listener did not become reachable at ${_probe_url} within ${_timeout}s" -} - -wait_for_local_gateway_status() { - _timeout="${OPENSHELL_INSTALL_GATEWAY_TIMEOUT:-30}" - _elapsed=0 - _status_output="" - _register_bin="${OPENSHELL_REGISTER_BIN:-openshell}" - - info "waiting for openshell status to report connected..." - while [ "$_elapsed" -lt "$_timeout" ]; do - if _status_output="$(as_target_user env NO_COLOR=1 "$_register_bin" status 2>&1)"; then - case "$_status_output" in - *"Version:"*) - info "openshell status reports connected" - return 0 - ;; - esac - fi - sleep 1 - _elapsed=$((_elapsed + 1)) - done - - printf '%s\n' "$_status_output" >&2 - error "openshell status did not report connected within ${_timeout}s" -} - -remove_local_gateway_registration() { - [ -n "$TARGET_HOME" ] || error "cannot resolve home directory for ${TARGET_USER}" - _config_dir="${TARGET_HOME}/.config/openshell" - - # The install-dev gateway is a user service. Replace the CLI registration - # directly instead of asking `gateway destroy` to tear down Docker resources. - # shellcheck disable=SC2016 - as_target_user sh -c ' - config_dir=$1 - rm -rf "${config_dir}/gateways/local" - mkdir -p "${config_dir}/gateways/openshell" - rm -f \ - "${config_dir}/gateways/openshell/metadata.json" \ - "${config_dir}/gateways/openshell/edge_token" \ - "${config_dir}/gateways/openshell/cf_token" \ - "${config_dir}/gateways/openshell/oidc_token.json" - active="${config_dir}/active_gateway" - active_name="$(cat "$active" 2>/dev/null || true)" - if [ "$active_name" = "local" ] || [ "$active_name" = "openshell" ]; then - rm -f "$active" - fi - ' sh "$_config_dir" -} - -register_local_gateway() { - _register_bin="${OPENSHELL_REGISTER_BIN:-openshell}" - - if _add_output="$(as_target_user "$_register_bin" gateway add "https://127.0.0.1:${LOCAL_GATEWAY_PORT}" --local --name openshell 2>&1)"; then - [ -z "$_add_output" ] || print_gateway_add_output "$_add_output" - return 0 - else - _add_status=$? - fi - - case "$_add_output" in - *"already exists"*) - info "local gateway already exists; removing and re-adding it..." - remove_local_gateway_registration - as_target_user "$_register_bin" gateway add "https://127.0.0.1:${LOCAL_GATEWAY_PORT}" --local --name openshell - ;; - *) - printf '%s\n' "$_add_output" >&2 - return "$_add_status" - ;; - esac -} - -print_gateway_add_output() { - printf '%s\n' "$1" | while IFS= read -r _line; do - case "$_line" in - *"Gateway is not reachable at https://127.0.0.1:${LOCAL_GATEWAY_PORT}"*) ;; - *"Verify the gateway is running and the endpoint is correct."*) ;; - *) printf '%s\n' "$_line" >&2 ;; - esac - done -} - -install_linux_deb() { - check_linux_deb_platform - set_linux_target_runtime_dir - - _arch="$(get_deb_arch)" - _tmpdir="$(mktemp -d)" - chmod 0755 "$_tmpdir" - trap 'rm -rf "$_tmpdir"' EXIT - - _checksums_url="${GITHUB_URL}/releases/download/${RELEASE_TAG}/${CHECKSUMS_NAME}" - info "downloading ${RELEASE_TAG} release checksums..." - download "$_checksums_url" "${_tmpdir}/${CHECKSUMS_NAME}" || { - error "failed to download ${_checksums_url}" - } - - _deb_file="$(find_deb_asset "${_tmpdir}/${CHECKSUMS_NAME}" "$_arch")" - if [ -z "$_deb_file" ]; then - error "no dev Debian package found for architecture: ${_arch}" - fi - - _deb_url="${GITHUB_URL}/releases/download/${RELEASE_TAG}/${_deb_file}" - _deb_path="${_tmpdir}/${_deb_file}" - - info "selected ${_deb_file}" - - info "downloading ${_deb_file}..." - download_release_asset "$RELEASE_TAG" "$_deb_file" "$_deb_path" || { - error "failed to download ${_deb_url}" - } - chmod 0644 "$_deb_path" - - info "verifying checksum..." - verify_checksum "$_deb_path" "${_tmpdir}/${CHECKSUMS_NAME}" "$_deb_file" - - info "installing ${_deb_file}..." - install_deb_package "$_deb_path" - info "installed ${APP_NAME} package from ${RELEASE_TAG}" - start_user_gateway -} - -install_linux_rpm() { - require_cmd rpm - set_linux_target_runtime_dir - - _arch="$(get_rpm_arch)" - _tmpdir="$(mktemp -d)" - chmod 0755 "$_tmpdir" - trap 'rm -rf "$_tmpdir"' EXIT - - _checksums_url="${GITHUB_URL}/releases/download/${RELEASE_TAG}/${CHECKSUMS_NAME}" - info "downloading ${RELEASE_TAG} release checksums..." - download "$_checksums_url" "${_tmpdir}/${CHECKSUMS_NAME}" || { - error "failed to download ${_checksums_url}" - } - - _rpm_file="$(find_rpm_asset "${_tmpdir}/${CHECKSUMS_NAME}" "$_arch" openshell)" - if [ -z "$_rpm_file" ]; then - error "no dev openshell RPM package found for architecture: ${_arch}" - fi - - _gateway_rpm_file="$(find_rpm_asset "${_tmpdir}/${CHECKSUMS_NAME}" "$_arch" openshell-gateway)" - if [ -z "$_gateway_rpm_file" ]; then - error "no dev openshell-gateway RPM package found for architecture: ${_arch}" - fi - - info "selected ${_rpm_file} and ${_gateway_rpm_file}" - - for _package_file in "$_rpm_file" "$_gateway_rpm_file"; do - _package_url="${GITHUB_URL}/releases/download/${RELEASE_TAG}/${_package_file}" - _package_path="${_tmpdir}/${_package_file}" - - info "downloading ${_package_file}..." - download_release_asset "$RELEASE_TAG" "$_package_file" "$_package_path" || { - error "failed to download ${_package_url}" - } - chmod 0644 "$_package_path" - - info "verifying checksum for ${_package_file}..." - verify_checksum "$_package_path" "${_tmpdir}/${CHECKSUMS_NAME}" "$_package_file" - done - - info "installing ${_rpm_file} and ${_gateway_rpm_file}..." - install_rpm_packages "${_tmpdir}/${_rpm_file}" "${_tmpdir}/${_gateway_rpm_file}" - info "installed ${APP_NAME} RPM packages from ${RELEASE_TAG}" - start_user_gateway -} - -install_macos_homebrew() { - check_macos_platform - - _tmpdir="$(mktemp -d)" - chmod 0755 "$_tmpdir" - trap 'rm -rf "$_tmpdir"' EXIT - - _formula_file="${_tmpdir}/openshell.rb" - _formula_url="${GITHUB_URL}/releases/download/${RELEASE_TAG}/openshell.rb" - - info "downloading Homebrew formula from ${_formula_url}..." - download_release_asset "$RELEASE_TAG" "openshell.rb" "$_formula_file" || { - error "failed to download ${_formula_url}; the selected release may not include a Homebrew formula" - } - chmod 0644 "$_formula_file" - patch_homebrew_formula "$_formula_file" - - _tap_formula_file="$(homebrew_formula_path "$HOMEBREW_TAP" "$HOMEBREW_FORMULA_NAME")" - info "staging Homebrew formula in tap ${HOMEBREW_TAP}..." - cp "$_formula_file" "$_tap_formula_file" - chmod 0644 "$_tap_formula_file" - if [ "$(id -u)" -eq 0 ]; then - chown "$TARGET_USER" "$_tap_formula_file" 2>/dev/null || true - fi - - _formula_ref="${HOMEBREW_TAP}/${HOMEBREW_FORMULA_NAME}" - - if as_target_user brew list --formula openshell >/dev/null 2>&1; then - info "reinstalling OpenShell with Homebrew..." - as_target_user brew reinstall --formula "$_formula_ref" - else - info "installing OpenShell with Homebrew..." - as_target_user brew install --formula "$_formula_ref" - fi - - info "restarting OpenShell Homebrew service..." - if ! as_target_user brew services restart "$_formula_ref"; then - warn "could not restart the OpenShell Homebrew service" - info "restart it later with: brew services restart ${_formula_ref}" - info "then register it with: openshell gateway add https://127.0.0.1:${LOCAL_GATEWAY_PORT} --local --name openshell" - return 0 - fi - - _brew_prefix="$(as_target_user brew --prefix 2>/dev/null || true)" - if [ -n "$_brew_prefix" ] && [ -x "${_brew_prefix}/bin/openshell" ]; then - OPENSHELL_REGISTER_BIN="${_brew_prefix}/bin/openshell" - fi - - info "registering local gateway as ${TARGET_USER}..." - register_local_gateway - wait_for_local_gateway_listener - wait_for_local_gateway_status -} - -main() { - if [ "$#" -gt 0 ]; then - case "$1" in - --help) - usage - exit 0 - ;; - *) - error "unknown option: $1" - ;; - esac - fi - - require_cmd curl - PLATFORM="$(detect_platform)" - - TARGET_USER="$(target_user)" - TARGET_UID="$(id -u "$TARGET_USER" 2>/dev/null || true)" - [ -n "$TARGET_UID" ] || error "cannot resolve uid for ${TARGET_USER}" - TARGET_HOME="$(user_home "$TARGET_USER")" - - case "$PLATFORM" in - linux) - case "$(linux_package_method)" in - deb) - install_linux_deb - ;; - rpm) - install_linux_rpm - ;; - *) - error "unsupported Linux package method" - ;; - esac - ;; - darwin) - install_macos_homebrew - ;; - *) - error "unsupported platform: ${PLATFORM}" - ;; - esac -} - -main "$@" diff --git a/install.sh b/install.sh index 0ad6eee63..98d2d5d77 100755 --- a/install.sh +++ b/install.sh @@ -2,27 +2,22 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # -# Install the OpenShell CLI binary. +# Install OpenShell from a GitHub release. # -# Usage: -# curl -LsSf https://raw.githubusercontent.com/NVIDIA/OpenShell/main/install.sh | sh +# Linux installs either the Debian or RPM packages from the selected release. +# Apple Silicon macOS installs the generated Homebrew formula, so Homebrew owns +# the binary layout and launchd service lifecycle. # -# Or run directly: -# ./install.sh -# -# Environment variables: -# OPENSHELL_VERSION - Release tag to install (default: latest tagged release) -# OPENSHELL_INSTALL_DIR - Directory to install into (default: ~/.local/bin) -# -set -eu +set -e APP_NAME="openshell" REPO="NVIDIA/OpenShell" GITHUB_URL="https://github.com/${REPO}" - -# --------------------------------------------------------------------------- -# Logging -# --------------------------------------------------------------------------- +RELEASE_TAG="${OPENSHELL_VERSION:-}" +CHECKSUMS_NAME="openshell-checksums-sha256.txt" +LOCAL_GATEWAY_PORT="17670" +HOMEBREW_TAP="nvidia/openshell" +HOMEBREW_FORMULA_NAME="openshell" info() { printf '%s: %s\n' "$APP_NAME" "$*" >&2 @@ -37,284 +32,713 @@ error() { exit 1 } -# --------------------------------------------------------------------------- -# Usage -# --------------------------------------------------------------------------- - usage() { cat </dev/null 2>&1 } -check_downloader() { - if has_cmd curl; then - return 0 - elif has_cmd wget; then - return 0 - else - error "either 'curl' or 'wget' is required to download files" +require_cmd() { + if ! has_cmd "$1"; then + error "'$1' is required" fi } -# Download a URL to a file. Outputs nothing on success. download() { _url="$1" _output="$2" + curl -fLsS --retry 3 --max-redirs 5 -o "$_output" "$_url" +} - if has_cmd curl; then - curl -fLsS --retry 3 --max-redirs 5 -o "$_output" "$_url" - elif has_cmd wget; then - wget -q --tries=3 --max-redirect=5 -O "$_output" "$_url" +resolve_release_tag() { + if [ -n "${OPENSHELL_VERSION:-}" ]; then + echo "$OPENSHELL_VERSION" + return 0 fi + + info "resolving latest version..." + _latest_url="${GITHUB_URL}/releases/latest" + _resolved="$(curl -fLsS -o /dev/null -w '%{url_effective}' "$_latest_url")" || { + error "failed to resolve latest release from ${_latest_url}" + } + + case "$_resolved" in + https://github.com/${REPO}/releases/*) + ;; + *) + error "unexpected redirect target: ${_resolved} (expected https://github.com/${REPO}/releases/...)" + ;; + esac + + _version="${_resolved##*/}" + if [ -z "$_version" ] || [ "$_version" = "latest" ]; then + error "could not determine latest release version (resolved URL: ${_resolved})" + fi + + echo "$_version" } -# Follow a URL and print the final resolved URL (for detecting redirect targets). -resolve_redirect() { - _url="$1" +download_release_asset() { + _tag="$1" + _filename="$2" + _output="$3" - if has_cmd curl; then - curl -fLsS -o /dev/null -w '%{url_effective}' "$_url" - elif has_cmd wget; then - # wget --spider follows redirects; capture the final Location from stderr - wget --spider --max-redirect=10 "$_url" 2>&1 | sed -n 's/^.*Location: \([^ ]*\).*/\1/p' | tail -1 + if curl -fLs --retry 3 --max-redirs 5 -o "$_output" \ + "${GITHUB_URL}/releases/download/${_tag}/${_filename}"; then + return 0 + fi + + # GitHub normalizes `~` to `.` in release asset names, while checksum files + # can still record package filenames with `~dev` for correct version ordering. + # Download the normalized asset but verify it against the checksum entry for + # the original package filename. + _normalized="$(printf '%s' "$_filename" | tr '~' '.')" + if [ "$_normalized" != "$_filename" ]; then + if download "${GITHUB_URL}/releases/download/${_tag}/${_normalized}" "$_output"; then + info "using GitHub-normalized asset name ${_normalized}" + return 0 + fi fi + + return 1 } -# --------------------------------------------------------------------------- -# Platform detection -# --------------------------------------------------------------------------- +as_root() { + if [ "$(id -u)" -eq 0 ]; then + "$@" + elif has_cmd sudo; then + sudo "$@" + else + error "this installer needs root privileges; rerun as root or install sudo" + fi +} + +target_user() { + if [ "$(id -u)" -eq 0 ] && [ -n "${SUDO_USER:-}" ] && [ "${SUDO_USER}" != "root" ]; then + echo "$SUDO_USER" + else + id -un + fi +} + +user_home() { + _user="$1" + if has_cmd getent; then + _home="$(getent passwd "$_user" | awk -F: '{ print $6 }')" + if [ -n "$_home" ]; then + echo "$_home" + return 0 + fi + fi + + if [ "$(uname -s)" = "Darwin" ] && has_cmd dscl; then + _home="$(dscl . -read "/Users/${_user}" NFSHomeDirectory 2>/dev/null | awk '{ print $2 }')" + if [ -n "$_home" ]; then + echo "$_home" + return 0 + fi + fi + + if [ "$(id -un)" = "$_user" ]; then + echo "${HOME:-}" + return 0 + fi -get_os() { + if [ "$(uname -s)" = "Darwin" ]; then + echo "/Users/${_user}" + return 0 + fi + + echo "/home/${_user}" +} + +as_target_user() { + if [ "${PLATFORM:-}" = "darwin" ]; then + if [ "$(id -u)" -eq "$TARGET_UID" ]; then + env HOME="$TARGET_HOME" "$@" + elif has_cmd sudo; then + sudo -u "$TARGET_USER" env HOME="$TARGET_HOME" "$@" + else + error "cannot run commands as ${TARGET_USER}; install sudo or run as ${TARGET_USER}" + fi + return + fi + + _bus="unix:path=${TARGET_RUNTIME_DIR}/bus" + if [ "$(id -u)" -eq "$TARGET_UID" ]; then + env HOME="$TARGET_HOME" XDG_RUNTIME_DIR="$TARGET_RUNTIME_DIR" DBUS_SESSION_BUS_ADDRESS="$_bus" "$@" + elif has_cmd sudo; then + sudo -u "$TARGET_USER" env HOME="$TARGET_HOME" XDG_RUNTIME_DIR="$TARGET_RUNTIME_DIR" DBUS_SESSION_BUS_ADDRESS="$_bus" "$@" + elif has_cmd runuser; then + runuser -u "$TARGET_USER" -- env HOME="$TARGET_HOME" XDG_RUNTIME_DIR="$TARGET_RUNTIME_DIR" DBUS_SESSION_BUS_ADDRESS="$_bus" "$@" + else + error "cannot run user service commands as ${TARGET_USER}; install sudo or run as ${TARGET_USER}" + fi +} + +detect_platform() { case "$(uname -s)" in - Darwin) echo "apple-darwin" ;; - Linux) echo "unknown-linux-musl" ;; - *) error "unsupported OS: $(uname -s)" ;; + Linux) + echo "linux" + ;; + Darwin) + echo "darwin" + ;; + *) + error "unsupported OS: $(uname -s); this installer supports Linux and macOS" + ;; esac } -get_arch() { - case "$(uname -m)" in - x86_64|amd64) echo "x86_64" ;; - aarch64|arm64) echo "aarch64" ;; - *) error "unsupported architecture: $(uname -m)" ;; - esac +linux_package_method() { + if has_cmd dpkg; then + echo "deb" + elif has_cmd rpm; then + echo "rpm" + else + error "Linux installs require either dpkg or rpm" + fi +} + +set_linux_target_runtime_dir() { + if [ "$(id -u)" -eq "$TARGET_UID" ] && [ -n "${XDG_RUNTIME_DIR:-}" ]; then + TARGET_RUNTIME_DIR="$XDG_RUNTIME_DIR" + else + TARGET_RUNTIME_DIR="/run/user/${TARGET_UID}" + fi } -get_target() { - _arch="$(get_arch)" - _os="$(get_os)" - _target="${_arch}-${_os}" +check_linux_deb_platform() { + require_cmd dpkg +} + +check_macos_platform() { + _arch="$(uname -m)" - # Only these targets have published binaries. - case "$_target" in - x86_64-unknown-linux-musl|aarch64-unknown-linux-musl|aarch64-apple-darwin) ;; - x86_64-apple-darwin) error "macOS x86_64 is not supported; use Apple Silicon (aarch64) or Rosetta 2" ;; - *) error "no prebuilt binary for $_target" ;; + case "$_arch" in + arm64|aarch64) + ;; + x86_64|amd64) + error "Intel macOS is not supported because no x86_64-apple-darwin release assets are published" + ;; + *) + error "no macOS release build is published for architecture: ${_arch}" + ;; esac - echo "$_target" + if ! as_target_user brew --version >/dev/null 2>&1; then + error "Homebrew is required for macOS installs; install it from https://brew.sh" + fi } -# --------------------------------------------------------------------------- -# Version resolution -# --------------------------------------------------------------------------- +get_deb_arch() { + _arch="$(dpkg --print-architecture)" -resolve_version() { - if [ -n "${OPENSHELL_VERSION:-}" ]; then - echo "$OPENSHELL_VERSION" - return 0 + case "$_arch" in + amd64|arm64) + echo "$_arch" + ;; + *) + error "no Debian package is published for architecture: ${_arch}" + ;; + esac +} + +get_rpm_arch() { + if has_cmd rpm; then + _arch="$(rpm --eval '%{_arch}' 2>/dev/null || true)" + else + _arch="" fi - # Resolve "latest" by following the GitHub releases/latest redirect. - # GitHub redirects /releases/latest -> /releases/tag/ - info "resolving latest version..." - _latest_url="${GITHUB_URL}/releases/latest" - _resolved="$(resolve_redirect "$_latest_url")" || error "failed to resolve latest release from ${_latest_url}" + if [ -z "$_arch" ]; then + _arch="$(uname -m)" + fi - # Validate that the redirect stayed on the expected GitHub origin. - # A MITM or DNS hijack could redirect to an attacker-controlled domain, - # which would also serve a matching checksums file (making checksum - # verification useless). See: https://github.com/NVIDIA/OpenShell/issues/638 - case "$_resolved" in - https://github.com/${REPO}/releases/*) + case "$_arch" in + x86_64|amd64) + echo "x86_64" + ;; + aarch64|arm64) + echo "aarch64" ;; *) - error "unexpected redirect target: ${_resolved} (expected https://github.com/${REPO}/releases/...)" + error "no RPM package is published for architecture: ${_arch}" ;; esac +} - # Extract the tag from the resolved URL: .../releases/tag/v0.0.4 -> v0.0.4 - _version="${_resolved##*/}" +find_deb_asset() { + _checksums="$1" + _arch="$2" + + awk -v arch="$_arch" ' + $2 ~ "^\\*?openshell[-_].*[-_]" arch "\\.deb$" { + sub("^\\*", "", $2) + print $2 + exit + } + ' "$_checksums" +} - if [ -z "$_version" ] || [ "$_version" = "latest" ]; then - error "could not determine latest release version (resolved URL: ${_resolved})" +find_rpm_asset() { + _checksums="$1" + _arch="$2" + _package="$3" + + case "$_package" in + openshell) + _dev_name="openshell-dev-${_arch}.rpm" + _fallback_re="^openshell-[0-9].*\\.${_arch}\\.rpm$" + ;; + openshell-gateway) + _dev_name="openshell-gateway-dev-${_arch}.rpm" + _fallback_re="^openshell-gateway-[0-9].*\\.${_arch}\\.rpm$" + ;; + *) + error "unknown RPM package selector: ${_package}" + ;; + esac + + awk -v dev_name="$_dev_name" -v fallback_re="$_fallback_re" ' + { + name = $2 + sub("^\\*", "", name) + + if (name == dev_name) { + selected = name + found = 1 + exit + } + + if (fallback == "" && name ~ fallback_re) { + fallback = name + } + } + END { + if (found) { + print selected + } else if (fallback != "") { + print fallback + } + } + ' "$_checksums" +} + +verify_checksum() { + _archive="$1" + _checksums="$2" + _filename="$3" + + if has_cmd sha256sum; then + _expected="$(awk -v name="$_filename" '($2 == name || $2 == "*" name) { print $1; exit }' "$_checksums")" + [ -n "$_expected" ] || error "no checksum entry found for ${_filename}" + echo "$_expected $_archive" | sha256sum -c --quiet + elif has_cmd shasum; then + _expected="$(awk -v name="$_filename" '($2 == name || $2 == "*" name) { print $1; exit }' "$_checksums")" + [ -n "$_expected" ] || error "no checksum entry found for ${_filename}" + echo "$_expected $_archive" | shasum -a 256 -c --quiet + else + error "neither 'sha256sum' nor 'shasum' found; cannot verify download integrity" fi +} - echo "$_version" +install_deb_package() { + _deb_path="$1" + + if has_cmd apt-get; then + as_root env DEBIAN_FRONTEND=noninteractive apt-get install -y \ + -o Dpkg::Options::=--force-confdef \ + -o Dpkg::Options::=--force-confnew \ + "$_deb_path" + elif has_cmd apt; then + as_root env DEBIAN_FRONTEND=noninteractive apt install -y \ + -o Dpkg::Options::=--force-confdef \ + -o Dpkg::Options::=--force-confnew \ + "$_deb_path" + else + as_root dpkg --force-confdef --force-confnew -i "$_deb_path" + fi } -# --------------------------------------------------------------------------- -# Checksum verification -# --------------------------------------------------------------------------- +install_rpm_packages() { + if has_cmd dnf; then + as_root dnf install -y "$@" + elif has_cmd yum; then + as_root yum install -y "$@" + elif has_cmd zypper; then + as_root zypper --non-interactive install --allow-unsigned-rpm "$@" + elif has_cmd rpm; then + warn "installing with rpm directly; dependencies must already be installed" + as_root rpm -Uvh --replacepkgs "$@" + else + error "'dnf', 'yum', 'zypper', or 'rpm' is required to install RPM packages" + fi +} -verify_checksum() { - _vc_archive="$1" - _vc_checksums="$2" - _vc_filename="$3" +homebrew_formula_path() { + _tap="$1" + _formula="$2" - if ! has_cmd shasum && ! has_cmd sha256sum; then - error "neither 'shasum' nor 'sha256sum' found; cannot verify download integrity" + if ! as_target_user brew tap-info "$_tap" >/dev/null 2>&1; then + info "creating local Homebrew tap ${_tap}..." + as_target_user brew tap-new --no-git "$_tap" >/dev/null fi - _vc_expected="$(grep -F "$_vc_filename" "$_vc_checksums" | awk '{print $1}')" + _tap_dir="$(as_target_user brew --repository "$_tap" 2>/dev/null || true)" + [ -n "$_tap_dir" ] || error "could not locate Homebrew tap ${_tap}" + + _formula_dir="${_tap_dir}/Formula" + as_target_user mkdir -p "$_formula_dir" + printf '%s/%s.rb\n' "$_formula_dir" "$_formula" +} + +patch_homebrew_formula() { + _formula_file="$1" + _patched_file="${_formula_file}.patched" - if [ -z "$_vc_expected" ]; then - error "no checksum entry found for $_vc_filename in checksums file" + if grep -q 'entitlements.write <<~XML' "$_formula_file"; then + info "patching Homebrew formula for idempotent postinstall..." + sed 's/entitlements\.write <<~XML/entitlements.atomic_write <<~XML/' "$_formula_file" >"$_patched_file" + mv "$_patched_file" "$_formula_file" fi - if has_cmd shasum; then - echo "$_vc_expected $_vc_archive" | shasum -a 256 -c --quiet 2>/dev/null - elif has_cmd sha256sum; then - echo "$_vc_expected $_vc_archive" | sha256sum -c --quiet 2>/dev/null +} + +start_user_gateway() { + info "restarting openshell-gateway user service as ${TARGET_USER}..." + + if ! as_target_user systemctl --user daemon-reload; then + info "could not reach the user systemd manager for ${TARGET_USER}" + info "restart the gateway later with: systemctl --user enable openshell-gateway && systemctl --user restart openshell-gateway" + info "then register it with: openshell gateway add https://127.0.0.1:17670 --local --name openshell" + return 0 fi + + as_target_user systemctl --user enable openshell-gateway + as_target_user systemctl --user restart openshell-gateway + as_target_user systemctl --user is-active --quiet openshell-gateway + + info "registering local gateway as ${TARGET_USER}..." + register_local_gateway + wait_for_local_gateway_listener + wait_for_local_gateway_status } -# --------------------------------------------------------------------------- -# Install location -# --------------------------------------------------------------------------- +wait_for_local_gateway_listener() { + _timeout="${OPENSHELL_INSTALL_GATEWAY_TIMEOUT:-30}" + _elapsed=0 + _last_output="" + _probe_url="https://127.0.0.1:${LOCAL_GATEWAY_PORT}/" + _mtls_dir="${TARGET_HOME}/.config/openshell/gateways/openshell/mtls" + + info "waiting for local gateway listener to become reachable..." + while [ "$_elapsed" -lt "$_timeout" ]; do + if [ ! -f "${_mtls_dir}/ca.crt" ] || [ ! -f "${_mtls_dir}/tls.crt" ] || [ ! -f "${_mtls_dir}/tls.key" ]; then + _last_output="mTLS client bundle is not ready under ${_mtls_dir}" + elif _last_output="$(as_target_user curl -sS --max-time 2 --cacert "${_mtls_dir}/ca.crt" --cert "${_mtls_dir}/tls.crt" --key "${_mtls_dir}/tls.key" -o /dev/null "$_probe_url" 2>&1)"; then + info "local gateway listener is reachable" + return 0 + fi + sleep 1 + _elapsed=$((_elapsed + 1)) + done + + printf '%s\n' "$_last_output" >&2 + error "local gateway listener did not become reachable at ${_probe_url} within ${_timeout}s" +} + +wait_for_local_gateway_status() { + _timeout="${OPENSHELL_INSTALL_GATEWAY_TIMEOUT:-30}" + _elapsed=0 + _status_output="" + _register_bin="${OPENSHELL_REGISTER_BIN:-openshell}" + + info "waiting for openshell status to report connected..." + while [ "$_elapsed" -lt "$_timeout" ]; do + if _status_output="$(as_target_user env NO_COLOR=1 "$_register_bin" status 2>&1)"; then + case "$_status_output" in + *"Version:"*) + info "openshell status reports connected" + return 0 + ;; + esac + fi + sleep 1 + _elapsed=$((_elapsed + 1)) + done + + printf '%s\n' "$_status_output" >&2 + error "openshell status did not report connected within ${_timeout}s" +} + +remove_local_gateway_registration() { + [ -n "$TARGET_HOME" ] || error "cannot resolve home directory for ${TARGET_USER}" + _config_dir="${TARGET_HOME}/.config/openshell" + + # The install-dev gateway is a user service. Replace the CLI registration + # directly instead of asking `gateway destroy` to tear down Docker resources. + # shellcheck disable=SC2016 + as_target_user sh -c ' + config_dir=$1 + rm -rf "${config_dir}/gateways/local" + mkdir -p "${config_dir}/gateways/openshell" + rm -f \ + "${config_dir}/gateways/openshell/metadata.json" \ + "${config_dir}/gateways/openshell/edge_token" \ + "${config_dir}/gateways/openshell/cf_token" \ + "${config_dir}/gateways/openshell/oidc_token.json" + active="${config_dir}/active_gateway" + active_name="$(cat "$active" 2>/dev/null || true)" + if [ "$active_name" = "local" ] || [ "$active_name" = "openshell" ]; then + rm -f "$active" + fi + ' sh "$_config_dir" +} -get_install_dir() { - if [ -n "${OPENSHELL_INSTALL_DIR:-}" ]; then - echo "$OPENSHELL_INSTALL_DIR" +register_local_gateway() { + _register_bin="${OPENSHELL_REGISTER_BIN:-openshell}" + + if _add_output="$(as_target_user "$_register_bin" gateway add "https://127.0.0.1:${LOCAL_GATEWAY_PORT}" --local --name openshell 2>&1)"; then + [ -z "$_add_output" ] || print_gateway_add_output "$_add_output" + return 0 else - echo "${HOME}/.local/bin" + _add_status=$? fi -} -# Check if a directory is already on PATH. -is_on_path() { - _dir="$1" - case ":${PATH}:" in - *":${_dir}:"*) return 0 ;; - *) return 1 ;; + case "$_add_output" in + *"already exists"*) + info "local gateway already exists; removing and re-adding it..." + remove_local_gateway_registration + as_target_user "$_register_bin" gateway add "https://127.0.0.1:${LOCAL_GATEWAY_PORT}" --local --name openshell + ;; + *) + printf '%s\n' "$_add_output" >&2 + return "$_add_status" + ;; esac } -# --------------------------------------------------------------------------- -# Main -# --------------------------------------------------------------------------- - -main() { - # Parse CLI flags - for arg in "$@"; do - case "$arg" in - --help) - usage - exit 0 - ;; - *) - error "unknown option: $arg" - ;; +print_gateway_add_output() { + printf '%s\n' "$1" | while IFS= read -r _line; do + case "$_line" in + *"Gateway is not reachable at https://127.0.0.1:${LOCAL_GATEWAY_PORT}"*) ;; + *"Verify the gateway is running and the endpoint is correct."*) ;; + *) printf '%s\n' "$_line" >&2 ;; esac done +} - check_downloader - - _version="$(resolve_version)" - _target="$(get_target)" - _filename="${APP_NAME}-${_target}.tar.gz" - _download_url="${GITHUB_URL}/releases/download/${_version}/${_filename}" - _checksums_url="${GITHUB_URL}/releases/download/${_version}/${APP_NAME}-checksums-sha256.txt" - _install_dir="$(get_install_dir)" - - info "downloading ${APP_NAME} ${_version} (${_target})..." +install_linux_deb() { + check_linux_deb_platform + set_linux_target_runtime_dir + _arch="$(get_deb_arch)" _tmpdir="$(mktemp -d)" + chmod 0755 "$_tmpdir" trap 'rm -rf "$_tmpdir"' EXIT - if ! download "$_download_url" "${_tmpdir}/${_filename}"; then - error "failed to download ${_download_url}" + _checksums_url="${GITHUB_URL}/releases/download/${RELEASE_TAG}/${CHECKSUMS_NAME}" + info "downloading ${RELEASE_TAG} release checksums..." + download "$_checksums_url" "${_tmpdir}/${CHECKSUMS_NAME}" || { + error "failed to download ${_checksums_url}" + } + + _deb_file="$(find_deb_asset "${_tmpdir}/${CHECKSUMS_NAME}" "$_arch")" + if [ -z "$_deb_file" ]; then + error "no Debian package found for architecture: ${_arch}" fi - # Verify checksum (mandatory — never skip) + _deb_url="${GITHUB_URL}/releases/download/${RELEASE_TAG}/${_deb_file}" + _deb_path="${_tmpdir}/${_deb_file}" + + info "selected ${_deb_file}" + + info "downloading ${_deb_file}..." + download_release_asset "$RELEASE_TAG" "$_deb_file" "$_deb_path" || { + error "failed to download ${_deb_url}" + } + chmod 0644 "$_deb_path" + info "verifying checksum..." - if ! download "$_checksums_url" "${_tmpdir}/checksums.txt"; then - error "failed to download checksums file from ${_checksums_url}" + verify_checksum "$_deb_path" "${_tmpdir}/${CHECKSUMS_NAME}" "$_deb_file" + + info "installing ${_deb_file}..." + install_deb_package "$_deb_path" + info "installed ${APP_NAME} package from ${RELEASE_TAG}" + start_user_gateway +} + +install_linux_rpm() { + require_cmd rpm + set_linux_target_runtime_dir + + _arch="$(get_rpm_arch)" + _tmpdir="$(mktemp -d)" + chmod 0755 "$_tmpdir" + trap 'rm -rf "$_tmpdir"' EXIT + + _checksums_url="${GITHUB_URL}/releases/download/${RELEASE_TAG}/${CHECKSUMS_NAME}" + info "downloading ${RELEASE_TAG} release checksums..." + download "$_checksums_url" "${_tmpdir}/${CHECKSUMS_NAME}" || { + error "failed to download ${_checksums_url}" + } + + _rpm_file="$(find_rpm_asset "${_tmpdir}/${CHECKSUMS_NAME}" "$_arch" openshell)" + if [ -z "$_rpm_file" ]; then + error "no openshell RPM package found for architecture: ${_arch}" fi - if ! verify_checksum "${_tmpdir}/${_filename}" "${_tmpdir}/checksums.txt" "$_filename"; then - error "checksum verification failed for ${_filename}" + + _gateway_rpm_file="$(find_rpm_asset "${_tmpdir}/${CHECKSUMS_NAME}" "$_arch" openshell-gateway)" + if [ -z "$_gateway_rpm_file" ]; then + error "no openshell-gateway RPM package found for architecture: ${_arch}" fi - # Extract - info "extracting..." - tar -xzf "${_tmpdir}/${_filename}" -C "${_tmpdir}" --no-same-owner --no-same-permissions "${APP_NAME}" + info "selected ${_rpm_file} and ${_gateway_rpm_file}" + + for _package_file in "$_rpm_file" "$_gateway_rpm_file"; do + _package_url="${GITHUB_URL}/releases/download/${RELEASE_TAG}/${_package_file}" + _package_path="${_tmpdir}/${_package_file}" + + info "downloading ${_package_file}..." + download_release_asset "$RELEASE_TAG" "$_package_file" "$_package_path" || { + error "failed to download ${_package_url}" + } + chmod 0644 "$_package_path" + + info "verifying checksum for ${_package_file}..." + verify_checksum "$_package_path" "${_tmpdir}/${CHECKSUMS_NAME}" "$_package_file" + done + + info "installing ${_rpm_file} and ${_gateway_rpm_file}..." + install_rpm_packages "${_tmpdir}/${_rpm_file}" "${_tmpdir}/${_gateway_rpm_file}" + info "installed ${APP_NAME} RPM packages from ${RELEASE_TAG}" + start_user_gateway +} + +install_macos_homebrew() { + check_macos_platform - # Install - mkdir -p "$_install_dir" 2>/dev/null || true + _tmpdir="$(mktemp -d)" + chmod 0755 "$_tmpdir" + trap 'rm -rf "$_tmpdir"' EXIT + + _formula_file="${_tmpdir}/openshell.rb" + _formula_url="${GITHUB_URL}/releases/download/${RELEASE_TAG}/openshell.rb" + + info "downloading Homebrew formula from ${_formula_url}..." + download_release_asset "$RELEASE_TAG" "openshell.rb" "$_formula_file" || { + error "failed to download ${_formula_url}; the selected release may not include a Homebrew formula" + } + chmod 0644 "$_formula_file" + patch_homebrew_formula "$_formula_file" + + _tap_formula_file="$(homebrew_formula_path "$HOMEBREW_TAP" "$HOMEBREW_FORMULA_NAME")" + info "staging Homebrew formula in tap ${HOMEBREW_TAP}..." + cp "$_formula_file" "$_tap_formula_file" + chmod 0644 "$_tap_formula_file" + if [ "$(id -u)" -eq 0 ]; then + chown "$TARGET_USER" "$_tap_formula_file" 2>/dev/null || true + fi - if [ -w "$_install_dir" ] || mkdir -p "$_install_dir" 2>/dev/null; then - install -m 755 "${_tmpdir}/${APP_NAME}" "${_install_dir}/${APP_NAME}" + _formula_ref="${HOMEBREW_TAP}/${HOMEBREW_FORMULA_NAME}" + + if as_target_user brew list --formula openshell >/dev/null 2>&1; then + info "reinstalling OpenShell with Homebrew..." + as_target_user brew reinstall --formula "$_formula_ref" else - info "elevated permissions required to install to ${_install_dir}" - sudo mkdir -p "$_install_dir" - sudo install -m 755 "${_tmpdir}/${APP_NAME}" "${_install_dir}/${APP_NAME}" - fi - - _installed_version="$("${_install_dir}/${APP_NAME}" --version 2>/dev/null || echo "${_version}")" - info "installed ${_installed_version} to ${_install_dir}/${APP_NAME}" - - # If the install directory isn't on PATH, print instructions - if ! is_on_path "$_install_dir"; then - echo "" - info "${_install_dir} is not on your PATH." - info "" - info "Add it by appending the following to your shell configuration file" - info "(e.g. ~/.bashrc, ~/.zshrc, or ~/.config/fish/config.fish):" - info "" - - _current_shell="$(basename "${SHELL:-sh}" 2>/dev/null || echo "sh")" - case "$_current_shell" in - fish) - info " fish_add_path ${_install_dir}" + info "installing OpenShell with Homebrew..." + as_target_user brew install --formula "$_formula_ref" + fi + + info "restarting OpenShell Homebrew service..." + if ! as_target_user brew services restart "$_formula_ref"; then + warn "could not restart the OpenShell Homebrew service" + info "restart it later with: brew services restart ${_formula_ref}" + info "then register it with: openshell gateway add https://127.0.0.1:${LOCAL_GATEWAY_PORT} --local --name openshell" + return 0 + fi + + _brew_prefix="$(as_target_user brew --prefix 2>/dev/null || true)" + if [ -n "$_brew_prefix" ] && [ -x "${_brew_prefix}/bin/openshell" ]; then + OPENSHELL_REGISTER_BIN="${_brew_prefix}/bin/openshell" + fi + + info "registering local gateway as ${TARGET_USER}..." + register_local_gateway + wait_for_local_gateway_listener + wait_for_local_gateway_status +} + +main() { + if [ "$#" -gt 0 ]; then + case "$1" in + --help) + usage + exit 0 ;; *) - info " export PATH=\"${_install_dir}:\$PATH\"" + error "unknown option: $1" ;; esac - - info "" - info "Then restart your shell or run the command above in your current session." fi + + require_cmd curl + RELEASE_TAG="$(resolve_release_tag)" + PLATFORM="$(detect_platform)" + + TARGET_USER="$(target_user)" + TARGET_UID="$(id -u "$TARGET_USER" 2>/dev/null || true)" + [ -n "$TARGET_UID" ] || error "cannot resolve uid for ${TARGET_USER}" + TARGET_HOME="$(user_home "$TARGET_USER")" + + case "$PLATFORM" in + linux) + case "$(linux_package_method)" in + deb) + install_linux_deb + ;; + rpm) + install_linux_rpm + ;; + *) + error "unsupported Linux package method" + ;; + esac + ;; + darwin) + install_macos_homebrew + ;; + *) + error "unsupported platform: ${PLATFORM}" + ;; + esac } main "$@" From 1c79b2131f487bdc9c6a74aeedf97bf58e6e2c88 Mon Sep 17 00:00:00 2001 From: Alexander Watson Date: Fri, 8 May 2026 17:14:27 -0700 Subject: [PATCH 019/157] feat: agent-driven policy management MVP (#1151) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * docs(rfc): add agent-driven policy management * docs(rfc): switch policy MVP to local API * docs(rfc): clarify policy advisor skill and local logs * feat(sandbox): add agent-driven policy proposal loop * test(examples): add codex policy dogfood loop * refactor(examples): make policy demo agent-agnostic * refactor(examples): colocate policy validation harness * docs(examples): add policy demo env sample * docs(examples): use placeholder env example * feat(sandbox): wire policy.local denials to OCSF JSONL log Wires GET /v1/denials?last=N on the sandbox-local policy advisor API to read recent OCSF JSONL events from /var/log/openshell-ocsf.YYYY-MM-DD.log, filter to network/L7 denials (action_id=2, class_uid 4001/4002), and return a compact summary newest-first. Default limit is 10, capped at 100. Ran inside spawn_blocking so file I/O does not block the policy.local handler. Other cleanup: - POST /v1/proposals now uses the typed grpc_client wrapper instead of raw_client, so accepted/rejected counts surface to the agent uniformly. Wrapper return type extended to the response struct. - Drop the 'add_rule' snake_case alias in the proposal JSON; canonical form is camelCase 'addRule', matching the PolicyMergeOperation convention used elsewhere. - skills/policy_advisor.md updated to match: documents the now-real /v1/denials?last=10 endpoint and uses 'addRule' consistently. - skills.rs test asserts on the canonical 'addRule' phrase rather than the removed 'PolicyMergeOperation' substring. * feat(cli): show L7 protocol/method/path in rule get output format_endpoint() previously rendered only host:port, dropping protocol, access, and the L7 rules array. That made openshell rule get text output unable to distinguish a broad L4 grant from a method/path-scoped L7 REST rule -- exactly the distinction a developer needs at approval time. New rendering tags each endpoint with its enforcement layer and surfaces allow/deny rules: bare L4: api.example:443 [L4] L7 read-only: api.example:443 [L7 rest, access=read-only] L7 method/path: api.example:443 [L7 rest, allow PUT /v1/foo/bar] Pure display change: no proto, gateway, or behavior changes. Unit test covers all three rendering cases with synthetic fixtures. * refactor(examples): rewrite policy demo as Codex-default loop Re-shape examples/agent-driven-policy-management/ to be a single, clean end-to-end demonstration of the agent-driven policy loop. A Codex agent inside an OpenShell sandbox attempts a GitHub Contents API write, hits a structured 403 from the L7 proxy, reads the policy_advisor skill, drafts a narrow addRule proposal via http://policy.local/v1/proposals, the host auto-approves, the sandbox hot-reloads policy, and the agent's retry succeeds. Whole loop runs in roughly two minutes. Demo cleanup: - Drop .env file ceremony. Defaults resolve from gh: owner via 'gh api user --jq .login', repo defaults to 'openshell-policy-demo', token from gh auth token / GITHUB_TOKEN / GH_TOKEN. With gh auth login and codex login already done, 'bash demo.sh' Just Works. - Codex-specific. Bootstraps ~/.codex/auth.json from credentials injected by the OpenShell provider, runs codex exec --sandbox danger-full-access (OpenShell is the actual security boundary; bwrap nesting cannot create user namespaces inside the sandbox container). - Tighter narrative output: a single 'Preflight' step, a run summary banner before launch, an inline narration of what's happening inside the sandbox while we poll for the proposal (including the literal structured 403 body the agent acts on), and an OCSF trace at the end filtered to the three events that tell the story (DENY, RELOAD, ALLOW). - Replace Python heredoc templating with sed; uploads use the single-flag pattern (--upload "${PAYLOAD_DIR}:/sandbox") with files referenced at the basename-prefixed path that #952 / #1028 established. - README documents the trust model honestly: structured rule is the contract, agent rationale is a hint, prover validation badge in progress per RFC 0001. Move the deterministic no-LLM regression harness out of examples/ into e2e/policy-advisor/ -- it was a parallel demo, not an example. Same loop without the LLM, useful for iterating on the proxy and policy.local API. * style(sandbox,cli): apply rustfmt Whitespace-only fixups caught by mise run pre-commit. No functional change. * perf(examples): cap Codex reasoning at 'low' in policy demo The demo task is mechanical (one HTTP request, parse a structured 403, post a JSON proposal, retry). Codex's default high-effort reasoning roughly doubles the demo's wall time without improving outcomes; running at 'low' lands the same minimal L7 grant in roughly half the time. Override with DEMO_CODEX_REASONING=medium (or higher) to compare runs. * fix(sandbox): harden policy.local denials endpoint Three changes addressing review feedback before merging the agent-driven policy management MVP: - Distinguish "OCSF JSONL enabled, no denials" from "OCSF JSONL disabled, nothing to read." The endpoint now returns a `log_available` flag and an explanatory `note` when the log file is missing, so the in-sandbox agent can give the developer an accurate hint instead of a misleading empty list. - Stop echoing the OCSF `message` field in the per-denial summary. The proxy's denial messages can include the request path with query string (e.g., `?access_token=...`); the structured `host`/`port`/`method`/ `path`/`binary` fields carry everything the agent needs to draft a proposal, and `path` is sourced from `http_request.url.path` which already excludes the query string. - Cap `read_request_body` at a 15s timeout. Bounds slowloris-style stalls from a misbehaving in-sandbox process. The proxy listener only accepts loopback connections so practical impact is small, but this is cheap defense-in-depth. New tests cover the missing-log signal and the message-redaction guarantee. * fix(examples): redact tokens in agent log tail and validate DEMO_FILE_DIR Two small hardening passes on the policy management demo: - `fail()` now pipes the agent log tail through a redactor that masks the GitHub token and Codex credential triple before printing. Codex itself is well-behaved about not echoing the token, but a misbehaving tool call could leak it; this is a final safety net before the log hits the developer's terminal (and any clipboard or chat history that follows). - `validate_env` now regex-checks DEMO_FILE_DIR with the same allow-list the other path-shaped variables use. The value is interpolated through sed with `|` as the delimiter when rendering the agent task; rejecting unsupported characters keeps the templating predictable and stops a user-supplied value from breaking out into a shell context. * refactor(sandbox): centralize policy.local routes and skill path Addresses review feedback that the deny body's `next_steps` array and the route table could drift apart. The route paths and skill location now live as `pub const`s in `policy_local.rs` and feed both: - the dispatcher in `route_request` that matches against them - a new `agent_next_steps()` helper that builds the JSON the L7 deny body embeds `l7/rest.rs::deny_response_body` calls `policy_local::agent_next_steps()` instead of inlining the array, so adding or renaming a route is a one-line change in `policy_local.rs` and the agent contract follows automatically. * feat(sandbox): switch /v1/denials to shorthand log pass-through Previously /v1/denials parsed `/var/log/openshell-ocsf.*.log` (OCSF JSONL) and returned structured per-event objects. JSONL is opt-in via `ocsf_json_enabled`, so the endpoint returned an empty list with a "log not enabled" hint by default — agents had to navigate a setup step before the inspect-recent-denials guidance was useful. Switch to reading the shorthand log at `/var/log/openshell.*.log`, which is always-on and the same human-readable format `openshell logs` displays. The endpoint now returns raw shorthand lines (newest first) — the agent reads them directly, no field parsing. Tradeoffs: - Removes the JSONL-on-by-default debate: shorthand is already on, no defaults change. - Updating shorthand is a single-file change in this repo; no schema rev needed when we want to add fields. Implementation: - `read_recent_denial_lines` walks shorthand log files newest-first, filters lines with ` OCSF ` AND ` DENIED ` (the OCSF action label, uppercase, space-bounded). - `collect_shorthand_log_files` matches `openshell..log`; the trailing dot in `SHORTHAND_LOG_PREFIX = "openshell."` excludes `openshell-ocsf..log` so JSONL-on doesn't bleed into responses. - 4096-byte cap per surfaced line as defense against pathological inputs. - Skill doc updated to reflect that `/v1/denials` returns raw shorthand lines, not structured fields. Defense-in-depth on query-string secrets: - `redact_query_strings` strips `?` to `?[redacted]` from each surfaced line. The L7 relay path emits OCSF events using `redacted_target` (secret-placeholder redaction), but the FORWARD deny path in `proxy.rs` populates `OcsfUrl::new("http", host, path, port)` and `.message(...)` with the raw request path — query string included. Stripping queries at the consumer guards `/v1/denials` regardless of whether the upstream emit sites are tightened. The on-disk log is not rewritten by this change; that is a separate hardening task tracked for the FORWARD path emit sites in proxy.rs. - `truncate_at_char_boundary` is UTF-8 safe; redaction runs before truncation so a cut cannot slice mid-secret. Tests: - `recent_denials_returns_newest_first_from_shorthand_lines` covers the happy path with mixed allowed/denied/non-OCSF lines. - `recent_denials_skips_jsonl_log_files` confirms JSONL files don't surface even if present. - `recent_denials_truncates_pathological_lines` covers the cap. - `is_ocsf_denial_line_filters_correctly` covers the line-level filter. - `redact_query_strings_removes_query_from_url_token` and `redact_query_strings_removes_query_in_reason_tag` cover the redaction in both URL token and `[reason:...]` contexts. - `truncate_at_char_boundary_does_not_panic_on_multibyte` covers the UTF-8 safety. * chore(sandbox): align proto inits with main's L7 GraphQL additions Post-rebase fixups after #1083 (GraphQL L7 inspection) landed on main and introduced new fields on the proto types this branch constructs: - `crates/openshell-sandbox/src/l7/relay.rs`: two `deny_with_redacted_target` call sites (REST and GraphQL relay deny paths) now pass the `DenyResponseContext` argument that `rest::send_deny_response` expects. Both sites pass `host`, `port`, and `binary` from the existing `L7EvalContext`, matching the pattern used at the primary deny site. - `crates/openshell-sandbox/src/policy_local.rs`: `L7Allow`, `L7DenyRule`, and `NetworkEndpoint` proto initializers now populate the new GraphQL and path-scoping fields with empty defaults. Agent-authored proposals via `policy.local` target REST/SQL/L4 today; GraphQL operation matching is set on the gateway side or via direct YAML, so empty defaults are correct here. No behavior change. `cargo test -p openshell-sandbox --lib` (650 tests) and `cargo clippy -p openshell-sandbox --lib --tests -- -D warnings` clean. * feat(sandbox): gate agent policy proposals behind opt-in feature flag The agent-driven policy proposal surface delivered by this PR (skill install, `policy.local` API, `next_steps` array on L7 deny bodies) is now opt-in via the new `agent_policy_proposals_enabled` setting. Default false. Same shape as `providers_v2_enabled`: registered in `openshell-core::settings`, sandbox-level, hot-toggleable via the existing settings poll loop. Why: the surface is a novel agent-controlled mutation point in every sandbox. The per-proposal developer approval gate is a correctness control, but it doesn't address "should this sandbox have an agent-authoring API at all" — compliance teams may want that question closed. The flag is the second gate. Implementation: - New registry entry + `AGENT_POLICY_PROPOSALS_ENABLED_KEY` constant in `openshell-core::settings`. - `lib.rs`: process-wide `OnceLock>` mirroring the `OCSF_CTX` pattern. `agent_proposals_enabled()` is the single read point. - Initial settings fetch added to `run_sandbox` so skill install honors the flag at startup (not just on the poll loop's first tick). - Skill install in `run_sandbox` is gated on the flag. - `policy_local::route_request` returns `404 feature_disabled` for all routes when the flag is off — including the otherwise-public `current_policy` and `denials` routes. When the surface is off it's off entirely. - `policy_local::agent_next_steps` returns an empty array when the flag is off so deny bodies don't advertise routes that 404. - Poll loop updates the atomic on each tick, lazily installs the skill on a false→true transition (no claw-back on true→false; stale skill on disk is harmless because route + next_steps gate on the live atom). Tests: - Shared `test_helpers::ProposalsFlagGuard` mutex+atomic guard for the process-wide flag, used across `policy_local::tests` and `l7::rest::tests`. - New: `agent_next_steps_returns_empty_when_flag_off`, `agent_next_steps_returns_full_array_when_flag_on`, `route_request_returns_feature_disabled_when_flag_off`. - Updated existing tests that exercise the deny body or the route dispatcher to set the flag on first. - Full sandbox lib test suite: 653 pass, clippy clean. Demo and e2e: - `examples/agent-driven-policy-management/demo.sh` and `e2e/policy-advisor/test.sh` now snapshot the prior global value of the setting, set it to true before sandbox creation (so the supervisor's initial poll picks it up), and restore on exit (delete if previously unset, otherwise write the prior value back). Docs: - RFC 0001 MVP-implementation note documents the flag, default, and intended soft-launch posture. * test(policy-advisor): require proposal opt-in for e2e * refactor(sandbox): group policy poll loop state * test(e2e): isolate Kubernetes user namespace test --------- Co-authored-by: John Myers <9696606+johntmyers@users.noreply.github.com> --- crates/openshell-cli/src/run.rs | 109 +- crates/openshell-core/src/settings.rs | 15 + crates/openshell-sandbox/src/grpc_client.rs | 17 +- crates/openshell-sandbox/src/l7/relay.rs | 15 + crates/openshell-sandbox/src/l7/rest.rs | 198 ++- crates/openshell-sandbox/src/lib.rs | 234 +++- crates/openshell-sandbox/src/policy_local.rs | 1157 +++++++++++++++++ crates/openshell-sandbox/src/proxy.rs | 48 +- crates/openshell-sandbox/src/skills.rs | 63 + .../src/skills/policy_advisor.md | 100 ++ e2e/policy-advisor/README.md | 54 + e2e/policy-advisor/policy.template.yaml | 28 + e2e/policy-advisor/sandbox-runner.sh | 142 ++ e2e/policy-advisor/test.sh | 410 ++++++ .../agent-driven-policy-management/README.md | 87 ++ .../agent-task.md | 50 + .../agent-driven-policy-management/demo.sh | 442 +++++++ .../policy.template.yaml | 69 + .../sandbox-agent.sh | 82 ++ rfc/0001-agent-driven-policy-management.md | 723 ++++++++++ 20 files changed, 3986 insertions(+), 57 deletions(-) create mode 100644 crates/openshell-sandbox/src/policy_local.rs create mode 100644 crates/openshell-sandbox/src/skills.rs create mode 100644 crates/openshell-sandbox/src/skills/policy_advisor.md create mode 100644 e2e/policy-advisor/README.md create mode 100644 e2e/policy-advisor/policy.template.yaml create mode 100755 e2e/policy-advisor/sandbox-runner.sh create mode 100755 e2e/policy-advisor/test.sh create mode 100644 examples/agent-driven-policy-management/README.md create mode 100644 examples/agent-driven-policy-management/agent-task.md create mode 100755 examples/agent-driven-policy-management/demo.sh create mode 100644 examples/agent-driven-policy-management/policy.template.yaml create mode 100755 examples/agent-driven-policy-management/sandbox-agent.sh create mode 100644 rfc/0001-agent-driven-policy-management.md diff --git a/crates/openshell-cli/src/run.rs b/crates/openshell-cli/src/run.rs index fc30b03d6..165713b6e 100644 --- a/crates/openshell-cli/src/run.rs +++ b/crates/openshell-cli/src/run.rs @@ -5422,17 +5422,57 @@ pub async fn sandbox_draft_history(server: &str, name: &str, tls: &TlsOptions) - fn format_endpoints(rule: &openshell_core::proto::NetworkPolicyRule) -> String { rule.endpoints .iter() - .map(|e| { - if e.port > 0 { - format!("{}:{}", e.host, e.port) - } else { - e.host.clone() - } - }) + .map(format_endpoint) .collect::>() .join(", ") } +/// Render an endpoint as `host:port [layer, …allows…, …denies…]` so a reader +/// can tell L4-only access apart from a method/path-scoped L7 grant. The L7 +/// fields (`protocol: rest`, `rules`, `access`) materially change what gets +/// allowed; surfacing them in the default text output is what makes +/// `openshell rule get` useful for approval review. +fn format_endpoint(endpoint: &openshell_core::proto::NetworkEndpoint) -> String { + let host_port = if endpoint.port > 0 { + format!("{}:{}", endpoint.host, endpoint.port) + } else { + endpoint.host.clone() + }; + + let mut tags: Vec = Vec::new(); + let layer_tag = if endpoint.protocol.eq_ignore_ascii_case("rest") { + "L7 rest" + } else if endpoint.protocol.is_empty() { + "L4" + } else { + endpoint.protocol.as_str() + }; + tags.push(layer_tag.to_string()); + + if !endpoint.access.is_empty() { + tags.push(format!("access={}", endpoint.access)); + } + + for r in &endpoint.rules { + if let Some(allow) = &r.allow { + let method = non_empty_or(&allow.method, "*"); + let path = non_empty_or(&allow.path, "*"); + tags.push(format!("allow {method} {path}")); + } + } + for r in &endpoint.deny_rules { + let method = non_empty_or(&r.method, "*"); + let path = non_empty_or(&r.path, "*"); + tags.push(format!("deny {method} {path}")); + } + + format!("{host_port} [{}]", tags.join(", ")) +} + +fn non_empty_or<'a>(value: &'a str, fallback: &'a str) -> &'a str { + if value.is_empty() { fallback } else { value } +} + /// Format a millisecond timestamp into a readable string. fn format_timestamp_ms(ms: i64) -> String { if ms <= 0 { @@ -5452,10 +5492,11 @@ fn format_timestamp_ms(ms: i64) -> String { #[cfg(test)] mod tests { use super::{ - TlsOptions, dockerfile_sources_supported_for_gateway, format_gateway_select_header, - format_gateway_select_items, format_provider_attachment_table, gateway_add, - gateway_auth_label, gateway_env_override_warning, gateway_select_with, gateway_type_label, - git_sync_files, http_health_check, image_requests_gpu, import_local_package_mtls_bundle, + TlsOptions, dockerfile_sources_supported_for_gateway, format_endpoint, + format_gateway_select_header, format_gateway_select_items, + format_provider_attachment_table, gateway_add, gateway_auth_label, + gateway_env_override_warning, gateway_select_with, gateway_type_label, git_sync_files, + http_health_check, image_requests_gpu, import_local_package_mtls_bundle, inferred_provider_type, package_managed_tls_dirs, parse_cli_setting_value, parse_credential_pairs, plaintext_gateway_is_remote, provisioning_timeout_message, ready_false_condition_message, resolve_from, sandbox_should_persist, @@ -6254,4 +6295,50 @@ mod tests { server.join().expect("server thread"); assert_eq!(status, Some(StatusCode::OK)); } + #[test] + fn format_endpoint_distinguishes_l4_from_l7_rest() { + use openshell_core::proto::{L7Allow, L7DenyRule, L7Rule, NetworkEndpoint}; + + let l4 = NetworkEndpoint { + host: "host.example.test".to_string(), + port: 443, + ..Default::default() + }; + assert_eq!(format_endpoint(&l4), "host.example.test:443 [L4]"); + + let l7_readonly = NetworkEndpoint { + host: "host.example.test".to_string(), + port: 443, + protocol: "rest".to_string(), + access: "read-only".to_string(), + ..Default::default() + }; + assert_eq!( + format_endpoint(&l7_readonly), + "host.example.test:443 [L7 rest, access=read-only]" + ); + + let l7_scoped = NetworkEndpoint { + host: "host.example.test".to_string(), + port: 443, + protocol: "rest".to_string(), + rules: vec![L7Rule { + allow: Some(L7Allow { + method: "PUT".to_string(), + path: "/v1/example/resource".to_string(), + ..Default::default() + }), + }], + deny_rules: vec![L7DenyRule { + method: "DELETE".to_string(), + path: "/v1/example/resource".to_string(), + ..Default::default() + }], + ..Default::default() + }; + assert_eq!( + format_endpoint(&l7_scoped), + "host.example.test:443 [L7 rest, allow PUT /v1/example/resource, deny DELETE /v1/example/resource]" + ); + } } diff --git a/crates/openshell-core/src/settings.rs b/crates/openshell-core/src/settings.rs index 2765ebeda..897317a5a 100644 --- a/crates/openshell-core/src/settings.rs +++ b/crates/openshell-core/src/settings.rs @@ -50,6 +50,15 @@ pub struct RegisteredSetting { /// 5. Add a unit test in this module's `tests` section to cover the new key. pub const PROVIDERS_V2_ENABLED_KEY: &str = "providers_v2_enabled"; +/// Sandbox-level opt-in for the agent-driven policy proposal surface. +/// +/// When true, the supervisor installs the `policy_advisor` skill, serves +/// the `policy.local` API routes, and includes `next_steps` in L7 deny +/// bodies. See `crates/openshell-sandbox/src/policy_local.rs`. Defaults to +/// false. Independent of the per-proposal developer approval gate, which +/// still applies when this flag is on. +pub const AGENT_POLICY_PROPOSALS_ENABLED_KEY: &str = "agent_policy_proposals_enabled"; + pub const REGISTERED_SETTINGS: &[RegisteredSetting] = &[ // Gateway-level opt-in for provider profile policy composition. Defaults // to false when unset. @@ -64,6 +73,12 @@ pub const REGISTERED_SETTINGS: &[RegisteredSetting] = &[ key: "ocsf_json_enabled", kind: SettingValueKind::Bool, }, + // Sandbox-level opt-in for the agent-driven policy proposal surface. + // See AGENT_POLICY_PROPOSALS_ENABLED_KEY for details. Defaults to false. + RegisteredSetting { + key: AGENT_POLICY_PROPOSALS_ENABLED_KEY, + kind: SettingValueKind::Bool, + }, // Test-only keys live behind the `dev-settings` feature flag so they // don't appear in production builds. #[cfg(feature = "dev-settings")] diff --git a/crates/openshell-sandbox/src/grpc_client.rs b/crates/openshell-sandbox/src/grpc_client.rs index cc35f67b5..1cb15f929 100644 --- a/crates/openshell-sandbox/src/grpc_client.rs +++ b/crates/openshell-sandbox/src/grpc_client.rs @@ -11,8 +11,8 @@ use miette::{IntoDiagnostic, Result, WrapErr}; use openshell_core::proto::{ DenialSummary, GetInferenceBundleRequest, GetInferenceBundleResponse, GetSandboxConfigRequest, GetSandboxProviderEnvironmentRequest, PolicySource, PolicyStatus, ReportPolicyStatusRequest, - SandboxPolicy as ProtoSandboxPolicy, SubmitPolicyAnalysisRequest, UpdateConfigRequest, - inference_client::InferenceClient, open_shell_client::OpenShellClient, + SandboxPolicy as ProtoSandboxPolicy, SubmitPolicyAnalysisRequest, SubmitPolicyAnalysisResponse, + UpdateConfigRequest, inference_client::InferenceClient, open_shell_client::OpenShellClient, }; use tonic::service::interceptor::InterceptedService; use tonic::transport::{Certificate, Channel, ClientTlsConfig, Endpoint, Identity}; @@ -329,15 +329,20 @@ impl CachedOpenShellClient { }) } - /// Submit denial summaries for policy analysis. + /// Submit denial summaries and/or agent-authored proposals for policy analysis. + /// + /// Returns the gateway response so callers can surface accepted/rejected + /// counts and rejection reasons (e.g., the `policy.local` API forwards + /// these to the in-sandbox agent). pub async fn submit_policy_analysis( &self, sandbox_name: &str, summaries: Vec, proposed_chunks: Vec, analysis_mode: &str, - ) -> Result<()> { - self.client + ) -> Result { + let response = self + .client .clone() .submit_policy_analysis(SubmitPolicyAnalysisRequest { name: sandbox_name.to_string(), @@ -348,7 +353,7 @@ impl CachedOpenShellClient { .await .into_diagnostic()?; - Ok(()) + Ok(response.into_inner()) } /// Report policy load status back to the server. diff --git a/crates/openshell-sandbox/src/l7/relay.rs b/crates/openshell-sandbox/src/l7/relay.rs index d0599ea99..f099c3558 100644 --- a/crates/openshell-sandbox/src/l7/relay.rs +++ b/crates/openshell-sandbox/src/l7/relay.rs @@ -305,6 +305,11 @@ where &reason, client, Some(&redacted_target), + Some(crate::l7::rest::DenyResponseContext { + host: Some(&ctx.host), + port: Some(ctx.port), + binary: Some(&ctx.binary_path), + }), ) .await?; return Ok(()); @@ -584,6 +589,11 @@ where &reason, client, Some(&redacted_target), + Some(crate::l7::rest::DenyResponseContext { + host: Some(&ctx.host), + port: Some(ctx.port), + binary: Some(&ctx.binary_path), + }), ) .await?; return Ok(()); @@ -789,6 +799,11 @@ where &reason, client, Some(&redacted_target), + Some(crate::l7::rest::DenyResponseContext { + host: Some(&ctx.host), + port: Some(ctx.port), + binary: Some(&ctx.binary_path), + }), ) .await?; return Ok(()); diff --git a/crates/openshell-sandbox/src/l7/rest.rs b/crates/openshell-sandbox/src/l7/rest.rs index 19acdbf32..85ae01290 100644 --- a/crates/openshell-sandbox/src/l7/rest.rs +++ b/crates/openshell-sandbox/src/l7/rest.rs @@ -72,10 +72,19 @@ impl L7Provider for RestProvider { reason: &str, client: &mut C, ) -> Result<()> { - send_deny_response(req, policy_name, reason, client, None).await + send_deny_response(req, policy_name, reason, client, None, None).await } } +/// Extra sandbox-side context included in agent-readable deny responses when +/// the relay has it available. +#[derive(Debug, Clone, Copy, Default)] +pub(crate) struct DenyResponseContext<'a> { + pub(crate) host: Option<&'a str>, + pub(crate) port: Option, + pub(crate) binary: Option<&'a str>, +} + impl RestProvider { /// Deny with a redacted target for the response body. pub(crate) async fn deny_with_redacted_target( @@ -85,8 +94,9 @@ impl RestProvider { reason: &str, client: &mut C, redacted_target: Option<&str>, + context: Option>, ) -> Result<()> { - send_deny_response(req, policy_name, reason, client, redacted_target).await + send_deny_response(req, policy_name, reason, client, redacted_target, context).await } } @@ -452,14 +462,9 @@ async fn send_deny_response( reason: &str, client: &mut C, redacted_target: Option<&str>, + context: Option>, ) -> Result<()> { - let target = redacted_target.unwrap_or(&req.target); - let body = serde_json::json!({ - "error": "policy_denied", - "policy": policy_name, - "rule": format!("{} {}", req.action, target), - "detail": reason - }); + let body = deny_response_body(req, policy_name, reason, redacted_target, context); let body_bytes = body.to_string(); let response = format!( "HTTP/1.1 403 Forbidden\r\n\ @@ -481,6 +486,74 @@ async fn send_deny_response( Ok(()) } +fn deny_response_body( + req: &L7Request, + policy_name: &str, + reason: &str, + redacted_target: Option<&str>, + context: Option>, +) -> serde_json::Value { + let target = redacted_target.unwrap_or(&req.target); + let context = context.unwrap_or_default(); + let host = non_empty(context.host); + let binary = non_empty(context.binary); + + let mut rule_missing = serde_json::Map::new(); + rule_missing.insert("type".to_string(), serde_json::json!("rest_allow")); + rule_missing.insert("layer".to_string(), serde_json::json!("l7")); + rule_missing.insert("method".to_string(), serde_json::json!(req.action)); + rule_missing.insert("path".to_string(), serde_json::json!(target)); + if let Some(host) = host { + rule_missing.insert("host".to_string(), serde_json::json!(host)); + } + if let Some(port) = context.port { + rule_missing.insert("port".to_string(), serde_json::json!(port)); + } + if let Some(binary) = binary { + rule_missing.insert("binary".to_string(), serde_json::json!(binary)); + } + + let mut body = serde_json::Map::new(); + body.insert("error".to_string(), serde_json::json!("policy_denied")); + body.insert("policy".to_string(), serde_json::json!(policy_name)); + body.insert( + "rule".to_string(), + serde_json::json!(format!("{} {}", req.action, target)), + ); + body.insert("detail".to_string(), serde_json::json!(reason)); + body.insert("layer".to_string(), serde_json::json!("l7")); + body.insert("protocol".to_string(), serde_json::json!("rest")); + body.insert("method".to_string(), serde_json::json!(req.action)); + body.insert("path".to_string(), serde_json::json!(target)); + if let Some(host) = host { + body.insert("host".to_string(), serde_json::json!(host)); + } + if let Some(port) = context.port { + body.insert("port".to_string(), serde_json::json!(port)); + } + if let Some(binary) = binary { + body.insert("binary".to_string(), serde_json::json!(binary)); + } + body.insert( + "rule_missing".to_string(), + serde_json::Value::Object(rule_missing), + ); + // `next_steps` is generated by the policy_local module so the wire URLs + // and the on-disk skill path stay in sync with the route table. Adding + // or renaming a route only requires touching the constants in that + // module; this side picks up the change automatically. + body.insert( + "next_steps".to_string(), + crate::policy_local::agent_next_steps(), + ); + + serde_json::Value::Object(body) +} + +fn non_empty(value: Option<&str>) -> Option<&str> { + value.map(str::trim).filter(|value| !value.is_empty()) +} + /// Parse Content-Length or Transfer-Encoding from HTTP headers. /// /// Per RFC 7230 Section 3.3.3, rejects requests containing both @@ -977,6 +1050,113 @@ mod tests { const TEST_POLICY: &str = include_str!("../../data/sandbox-policy.rego"); + #[test] + fn deny_response_body_is_agent_readable_and_redacted() { + // Agent-readable next_steps is gated on the proposals feature flag. + let _proposals = crate::test_helpers::ProposalsFlagGuard::set_blocking(true); + let req = L7Request { + action: "PUT".to_string(), + target: "/repos/NVIDIA/OpenShell/contents/README.md?access_token=secret-token" + .to_string(), + query_params: HashMap::new(), + raw_header: Vec::new(), + body_length: BodyLength::ContentLength(128), + }; + + let body = deny_response_body( + &req, + "github-readonly", + "no matching L7 allow rule", + Some("/repos/NVIDIA/OpenShell/contents/README.md"), + Some(DenyResponseContext { + host: Some("api.github.com"), + port: Some(443), + binary: Some("/usr/bin/gh"), + }), + ); + + assert_eq!(body["error"], "policy_denied"); + assert_eq!(body["policy"], "github-readonly"); + assert_eq!(body["layer"], "l7"); + assert_eq!(body["protocol"], "rest"); + assert_eq!(body["method"], "PUT"); + assert_eq!(body["host"], "api.github.com"); + assert_eq!(body["port"], 443); + assert_eq!(body["binary"], "/usr/bin/gh"); + assert_eq!(body["path"], "/repos/NVIDIA/OpenShell/contents/README.md"); + assert_eq!( + body["rule"], + "PUT /repos/NVIDIA/OpenShell/contents/README.md" + ); + assert_eq!(body["rule_missing"]["type"], "rest_allow"); + assert_eq!(body["rule_missing"]["layer"], "l7"); + assert_eq!(body["rule_missing"]["method"], "PUT"); + assert_eq!( + body["rule_missing"]["path"], + "/repos/NVIDIA/OpenShell/contents/README.md" + ); + assert_eq!(body["rule_missing"]["host"], "api.github.com"); + assert_eq!(body["rule_missing"]["port"], 443); + assert_eq!(body["rule_missing"]["binary"], "/usr/bin/gh"); + assert_eq!(body["next_steps"][0]["action"], "read_skill"); + assert_eq!( + body["next_steps"][0]["path"], + "/etc/openshell/skills/policy_advisor.md" + ); + assert_eq!(body["next_steps"][3]["body_type"], "PolicyMergeOperation"); + assert!( + !body.to_string().contains("secret-token"), + "deny body must not leak query params or credential values" + ); + } + + #[tokio::test] + async fn send_deny_response_writes_structured_json_403() { + // Agent-readable next_steps is gated on the proposals feature flag. + let _proposals = crate::test_helpers::ProposalsFlagGuard::set(true).await; + let (mut client, mut server) = tokio::io::duplex(4096); + let send = tokio::spawn(async move { + let req = L7Request { + action: "POST".to_string(), + target: "/user/repos".to_string(), + query_params: HashMap::new(), + raw_header: Vec::new(), + body_length: BodyLength::ContentLength(64), + }; + send_deny_response( + &req, + "github-readonly", + "no matching L7 allow rule", + &mut server, + None, + Some(DenyResponseContext { + host: Some("api.github.com"), + port: Some(443), + binary: Some("/usr/bin/gh"), + }), + ) + .await + .unwrap(); + }); + + let mut received = Vec::new(); + client.read_to_end(&mut received).await.unwrap(); + send.await.unwrap(); + + let response = String::from_utf8(received).unwrap(); + assert!(response.starts_with("HTTP/1.1 403 Forbidden")); + assert!(response.contains("Content-Type: application/json")); + assert!(response.contains("X-OpenShell-Policy: github-readonly")); + + let (_, body) = response.split_once("\r\n\r\n").unwrap(); + let body: serde_json::Value = serde_json::from_str(body).unwrap(); + assert_eq!(body["error"], "policy_denied"); + assert_eq!(body["method"], "POST"); + assert_eq!(body["path"], "/user/repos"); + assert_eq!(body["rule_missing"]["host"], "api.github.com"); + assert_eq!(body["next_steps"][2]["action"], "inspect_recent_denials"); + } + #[test] fn parse_content_length() { let headers = "POST /api HTTP/1.1\r\nHost: example.com\r\nContent-Length: 42\r\n\r\n"; diff --git a/crates/openshell-sandbox/src/lib.rs b/crates/openshell-sandbox/src/lib.rs index abbf7eb65..25a28af54 100644 --- a/crates/openshell-sandbox/src/lib.rs +++ b/crates/openshell-sandbox/src/lib.rs @@ -15,12 +15,14 @@ pub mod log_push; pub mod mechanistic_mapper; pub mod opa; mod policy; +mod policy_local; mod process; pub mod procfs; mod provider_credentials; pub mod proxy; mod sandbox; mod secrets; +mod skills; mod ssh; mod supervisor_session; @@ -88,6 +90,83 @@ pub(crate) fn ocsf_ctx() -> &'static SandboxContext { OCSF_CTX.get().unwrap_or(&OCSF_CTX_FALLBACK) } +/// Process-wide flag for the agent-driven policy proposal surface. +/// Set once during `run_sandbox()` startup and updated by the settings poll +/// loop when `agent_policy_proposals_enabled` changes. Read by the +/// `policy.local` route handler and the L7 deny body's `next_steps` builder +/// to gate the agent-controlled mutation surface. Exposed `pub(crate)` so +/// unit tests in sibling modules can flip the flag through a serialized +/// guard (see `policy_local::tests::ProposalsFlagGuard`). +pub(crate) static AGENT_PROPOSALS_ENABLED: OnceLock> = + OnceLock::new(); + +/// Read the current value of the agent proposals feature flag. +/// +/// Returns `false` if `run_sandbox()` has not initialized the flag (e.g. +/// during unit tests), matching the documented default for the setting. +pub(crate) fn agent_proposals_enabled() -> bool { + AGENT_PROPOSALS_ENABLED + .get() + .is_some_and(|flag| flag.load(Ordering::Relaxed)) +} + +/// Test-only helpers shared across sibling test modules. +#[cfg(test)] +pub(crate) mod test_helpers { + #![allow( + clippy::redundant_pub_crate, + reason = "intentional crate-private module" + )] + use std::sync::Arc; + use std::sync::LazyLock; + use std::sync::atomic::{AtomicBool, Ordering}; + use tokio::sync::MutexGuard; + + static PROPOSALS_FLAG_LOCK: LazyLock> = + LazyLock::new(|| tokio::sync::Mutex::new(())); + + /// Guard for tests that toggle the process-wide + /// `AGENT_PROPOSALS_ENABLED` flag. Acquires a process-wide async mutex, + /// swaps in the requested value, and restores the previous value on drop. + /// Hold the guard for the duration of any code that reads + /// `agent_proposals_enabled()`. + pub(crate) struct ProposalsFlagGuard { + prev: bool, + flag: Arc, + _lock: MutexGuard<'static, ()>, + } + + impl ProposalsFlagGuard { + pub(crate) async fn set(enabled: bool) -> Self { + let lock = PROPOSALS_FLAG_LOCK.lock().await; + Self::with_lock(enabled, lock) + } + + pub(crate) fn set_blocking(enabled: bool) -> Self { + let lock = PROPOSALS_FLAG_LOCK.blocking_lock(); + Self::with_lock(enabled, lock) + } + + fn with_lock(enabled: bool, lock: MutexGuard<'static, ()>) -> Self { + let flag = super::AGENT_PROPOSALS_ENABLED + .get_or_init(|| Arc::new(AtomicBool::new(false))) + .clone(); + let prev = flag.swap(enabled, Ordering::Relaxed); + Self { + prev, + flag, + _lock: lock, + } + } + } + + impl Drop for ProposalsFlagGuard { + fn drop(&mut self) { + self.flag.store(self.prev, Ordering::Relaxed); + } + } +} + use crate::identity::BinaryIdentityCache; use crate::l7::tls::{ CertCache, ProxyTlsState, SandboxCa, build_upstream_client_config, read_system_ca_bundle, @@ -260,6 +339,11 @@ pub async fn run_sandbox( policy_data, ) .await?; + let policy_local_ctx = Arc::new(policy_local::PolicyLocalContext::new( + retained_proto.clone(), + openshell_endpoint.clone(), + sandbox_name_for_agg.clone().or_else(|| sandbox_id.clone()), + )); // Validate that the required "sandbox" user exists in this image. // All sandbox images must include this user for privilege dropping. @@ -318,6 +402,49 @@ pub async fn run_sandbox( // Prepare filesystem: create and chown read_write directories prepare_filesystem(&policy)?; + // Initialize the agent-proposals feature flag. Default false until the + // initial settings fetch (or the poll loop) tells us otherwise. The flag + // gates the skill install, the policy.local route handler, and the L7 + // deny body's `next_steps` field — see `agent_proposals_enabled()`. + let proposals_enabled = Arc::new(std::sync::atomic::AtomicBool::new(false)); + if AGENT_PROPOSALS_ENABLED + .set(proposals_enabled.clone()) + .is_err() + { + debug!("agent proposals flag already initialized, keeping existing"); + } + + // Eagerly fetch the initial settings so skill install can honor the flag + // at startup rather than waiting for the poll loop's first tick. In + // offline/file-mode there is no gateway, so the flag stays false. + if let (Some(id), Some(endpoint)) = (&sandbox_id, &openshell_endpoint) + && let Ok(client) = grpc_client::CachedOpenShellClient::connect(endpoint).await + && let Ok(result) = client.poll_settings(id).await + { + let initial = extract_bool_setting( + &result.settings, + openshell_core::settings::AGENT_POLICY_PROPOSALS_ENABLED_KEY, + ) + .unwrap_or(false); + proposals_enabled.store(initial, Ordering::Relaxed); + } + + if agent_proposals_enabled() { + match skills::install_static_skills() { + Ok(installed) => { + info!( + path = %installed.policy_advisor.display(), + "Installed sandbox agent skill" + ); + } + Err(error) => { + warn!(error = %error, "Failed to install sandbox agent skill"); + } + } + } else { + debug!("agent_policy_proposals_enabled is false at startup; skipping skill install"); + } + // Generate ephemeral CA and TLS state for HTTPS L7 inspection. // The CA cert is written to disk so sandbox processes can trust it. let (tls_state, ca_file_paths) = if matches!(policy.network.mode, NetworkMode::Proxy) { @@ -485,6 +612,7 @@ pub async fn run_sandbox( tls_state, inference_ctx, Some(provider_credentials.clone()), + Some(policy_local_ctx.clone()), denial_tx, ) .await?; @@ -801,23 +929,24 @@ pub async fn run_sandbox( let poll_ocsf_enabled = ocsf_enabled.clone(); let poll_pid = entrypoint_pid.clone(); let poll_provider_credentials = provider_credentials.clone(); + let poll_policy_local = policy_local_ctx.clone(); let poll_interval_secs: u64 = std::env::var("OPENSHELL_POLICY_POLL_INTERVAL_SECS") .ok() .and_then(|v| v.parse().ok()) .unwrap_or(10); + let poll_ctx = PolicyPollLoopContext { + endpoint: poll_endpoint, + sandbox_id: poll_id, + opa_engine: poll_engine, + entrypoint_pid: poll_pid, + interval_secs: poll_interval_secs, + ocsf_enabled: poll_ocsf_enabled, + provider_credentials: poll_provider_credentials, + policy_local_ctx: Some(poll_policy_local), + }; tokio::spawn(async move { - if let Err(e) = run_policy_poll_loop( - &poll_endpoint, - &poll_id, - &poll_engine, - &poll_pid, - poll_interval_secs, - &poll_ocsf_enabled, - poll_provider_credentials, - ) - .await - { + if let Err(e) = run_policy_poll_loop(poll_ctx).await { ocsf_emit!( AppLifecycleBuilder::new(ocsf_ctx()) .activity(ActivityId::Fail) @@ -2151,22 +2280,25 @@ async fn flush_proposals_to_gateway( /// /// When the entrypoint PID is available, policy reloads include symlink /// resolution for binary paths via the container filesystem. -async fn run_policy_poll_loop( - endpoint: &str, - sandbox_id: &str, - opa_engine: &Arc, - entrypoint_pid: &Arc, +struct PolicyPollLoopContext { + endpoint: String, + sandbox_id: String, + opa_engine: Arc, + entrypoint_pid: Arc, interval_secs: u64, - ocsf_enabled: &std::sync::atomic::AtomicBool, + ocsf_enabled: Arc, provider_credentials: provider_credentials::ProviderCredentialState, -) -> Result<()> { + policy_local_ctx: Option>, +} + +async fn run_policy_poll_loop(ctx: PolicyPollLoopContext) -> Result<()> { use crate::grpc_client::CachedOpenShellClient; use openshell_core::proto::PolicySource; use std::sync::atomic::Ordering; - let client = CachedOpenShellClient::connect(endpoint).await?; + let client = CachedOpenShellClient::connect(&ctx.endpoint).await?; let mut current_config_revision: u64 = 0; - let mut current_provider_env_revision: u64 = provider_credentials.snapshot().revision; + let mut current_provider_env_revision: u64 = ctx.provider_credentials.snapshot().revision; let mut current_policy_hash = String::new(); let mut current_settings: std::collections::HashMap< String, @@ -2174,7 +2306,7 @@ async fn run_policy_poll_loop( > = std::collections::HashMap::new(); // Initialize revision from the first poll. - match client.poll_settings(sandbox_id).await { + match client.poll_settings(&ctx.sandbox_id).await { Ok(result) => { current_config_revision = result.config_revision; current_policy_hash = result.policy_hash.clone(); @@ -2189,11 +2321,11 @@ async fn run_policy_poll_loop( } } - let interval = Duration::from_secs(interval_secs); + let interval = Duration::from_secs(ctx.interval_secs); loop { tokio::time::sleep(interval).await; - let result = match client.poll_settings(sandbox_id).await { + let result = match client.poll_settings(&ctx.sandbox_id).await { Ok(r) => r, Err(e) => { debug!(error = %e, "Settings poll: server unreachable, will retry"); @@ -2226,9 +2358,9 @@ async fn run_policy_poll_loop( .build()); if provider_env_changed { - match grpc_client::fetch_provider_environment(endpoint, sandbox_id).await { + match grpc_client::fetch_provider_environment(&ctx.endpoint, &ctx.sandbox_id).await { Ok(env_result) => { - let env_count = provider_credentials.install_environment( + let env_count = ctx.provider_credentials.install_environment( env_result.provider_env_revision, env_result.environment, ); @@ -2273,9 +2405,12 @@ async fn run_policy_poll_loop( continue; }; - let pid = entrypoint_pid.load(Ordering::Acquire); - match opa_engine.reload_from_proto_with_pid(policy, pid) { + let pid = ctx.entrypoint_pid.load(Ordering::Acquire); + match ctx.opa_engine.reload_from_proto_with_pid(policy, pid) { Ok(()) => { + if let Some(policy_local_ctx) = ctx.policy_local_ctx.as_ref() { + policy_local_ctx.set_current_policy(policy.clone()).await; + } if result.global_policy_version > 0 { ocsf_emit!(ConfigStateChangeBuilder::new(ocsf_ctx()) .severity(SeverityId::Informational) @@ -2306,7 +2441,7 @@ async fn run_policy_poll_loop( if result.version > 0 && result.policy_source == PolicySource::Sandbox && let Err(e) = client - .report_policy_status(sandbox_id, result.version, true, "") + .report_policy_status(&ctx.sandbox_id, result.version, true, "") .await { warn!(error = %e, "Failed to report policy load success"); @@ -2327,7 +2462,12 @@ async fn run_policy_poll_loop( if result.version > 0 && result.policy_source == PolicySource::Sandbox && let Err(report_err) = client - .report_policy_status(sandbox_id, result.version, false, &e.to_string()) + .report_policy_status( + &ctx.sandbox_id, + result.version, + false, + &e.to_string(), + ) .await { warn!(error = %report_err, "Failed to report policy load failure"); @@ -2338,11 +2478,45 @@ async fn run_policy_poll_loop( // Apply OCSF JSON toggle from the `ocsf_json_enabled` setting. let new_ocsf = extract_bool_setting(&result.settings, "ocsf_json_enabled").unwrap_or(false); - let prev_ocsf = ocsf_enabled.swap(new_ocsf, Ordering::Relaxed); + let prev_ocsf = ctx.ocsf_enabled.swap(new_ocsf, Ordering::Relaxed); if new_ocsf != prev_ocsf { info!(ocsf_json_enabled = new_ocsf, "OCSF JSONL logging toggled"); } + // Apply the agent-proposals feature toggle. On a false→true transition + // we lazily install the skill so a sandbox that started with the flag + // off picks up the surface without a recreate. We never uninstall on + // a true→false transition: stale skill content on disk is harmless + // because route_request and agent_next_steps both gate on the live + // atomic, so the agent that reads the skill will see 404s and an + // empty `next_steps` array regardless. + if let Some(flag) = AGENT_PROPOSALS_ENABLED.get() { + let new_proposals = extract_bool_setting( + &result.settings, + openshell_core::settings::AGENT_POLICY_PROPOSALS_ENABLED_KEY, + ) + .unwrap_or(false); + let prev_proposals = flag.swap(new_proposals, Ordering::Relaxed); + if new_proposals != prev_proposals { + info!( + agent_policy_proposals_enabled = new_proposals, + "agent-driven policy proposals toggled" + ); + if new_proposals && !prev_proposals { + match skills::install_static_skills() { + Ok(installed) => info!( + path = %installed.policy_advisor.display(), + "Installed sandbox agent skill on toggle-on" + ), + Err(error) => warn!( + error = %error, + "Failed to install sandbox agent skill on toggle-on" + ), + } + } + } + } + current_config_revision = result.config_revision; current_policy_hash = result.policy_hash; current_settings = result.settings; diff --git a/crates/openshell-sandbox/src/policy_local.rs b/crates/openshell-sandbox/src/policy_local.rs new file mode 100644 index 000000000..21556ec6a --- /dev/null +++ b/crates/openshell-sandbox/src/policy_local.rs @@ -0,0 +1,1157 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Sandbox-local policy advisor HTTP API. + +use miette::{IntoDiagnostic, Result}; +use openshell_core::proto::{ + L7Allow, L7DenyRule, L7Rule, NetworkBinary, NetworkEndpoint, NetworkPolicyRule, PolicyChunk, + SandboxPolicy as ProtoSandboxPolicy, +}; +use serde::Deserialize; +use std::collections::HashMap; +use std::path::{Path, PathBuf}; +use std::sync::Arc; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use tokio::sync::RwLock; + +pub const POLICY_LOCAL_HOST: &str = "policy.local"; + +/// Filesystem path of the static agent guidance bundle inside the sandbox. +/// Single source of truth: the skill installer writes here, the L7 deny body +/// references this path in `next_steps`, and the skill's own documentation +/// renders the same path. Changing the location is a one-line update here. +pub const SKILL_PATH: &str = "/etc/openshell/skills/policy_advisor.md"; + +/// Routes served by the in-sandbox policy advisor API. Held in one place so +/// the L7 deny `next_steps` array, the route dispatcher, the skill content, +/// and tests all stay in sync — change the wire path here and every caller +/// follows. See `agent_next_steps()` for the consumer that surfaces these +/// to the agent on a 403. +pub const ROUTE_POLICY_CURRENT: &str = "/v1/policy/current"; +pub const ROUTE_DENIALS: &str = "/v1/denials"; +pub const ROUTE_PROPOSALS: &str = "/v1/proposals"; + +const MAX_POLICY_LOCAL_BODY_BYTES: usize = 64 * 1024; +/// Hard ceiling on how long a single request body read can stall. Bounds a +/// slowloris-style upload from an in-sandbox process; the proxy listener only +/// accepts loopback connections, so practical impact is limited, but this is +/// cheap defense-in-depth. +const POLICY_LOCAL_BODY_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(15); +const DEFAULT_DENIALS_LIMIT: usize = 10; +const MAX_DENIALS_LIMIT: usize = 100; +/// The shorthand rolling appender keeps three files (daily rotation); read the +/// most recent two so a request just past midnight still has yesterday's +/// denials. +const DENIAL_LOG_FILES_TO_SCAN: usize = 2; +const LOG_DIR: &str = "/var/log"; +/// Shorthand log filenames are `openshell.YYYY-MM-DD.log`. The trailing dot in +/// the prefix is intentional: it disambiguates from the OCSF JSONL appender's +/// `openshell-ocsf.YYYY-MM-DD.log`, which we never want to surface here (the +/// JSONL is opt-in via `ocsf_json_enabled` and not the source of truth for +/// `/v1/denials`). +const SHORTHAND_LOG_PREFIX: &str = "openshell."; +/// Defensive cap on per-line length returned to the agent so a pathological +/// log entry (very long URL path, etc.) cannot blow up the response. +const MAX_DENIAL_LINE_BYTES: usize = 4096; + +#[derive(Debug)] +pub struct PolicyLocalContext { + current_policy: Arc>>, + gateway_endpoint: Option, + sandbox_name: Option, + shorthand_log_dir: PathBuf, +} + +impl PolicyLocalContext { + pub fn new( + current_policy: Option, + gateway_endpoint: Option, + sandbox_name: Option, + ) -> Self { + Self::with_log_dir( + current_policy, + gateway_endpoint, + sandbox_name, + PathBuf::from(LOG_DIR), + ) + } + + fn with_log_dir( + current_policy: Option, + gateway_endpoint: Option, + sandbox_name: Option, + shorthand_log_dir: PathBuf, + ) -> Self { + Self { + current_policy: Arc::new(RwLock::new(current_policy)), + gateway_endpoint, + sandbox_name, + shorthand_log_dir, + } + } + + pub async fn set_current_policy(&self, policy: ProtoSandboxPolicy) { + *self.current_policy.write().await = Some(policy); + } +} + +pub async fn handle_forward_request( + ctx: &PolicyLocalContext, + method: &str, + path: &str, + initial_request: &[u8], + client: &mut S, +) -> Result<()> +where + S: AsyncRead + AsyncWrite + Unpin, +{ + let body = read_request_body(initial_request, client).await?; + let (status, payload) = route_request(ctx, method, path, &body).await; + write_json_response(client, status, payload).await +} + +async fn route_request( + ctx: &PolicyLocalContext, + method: &str, + path: &str, + body: &[u8], +) -> (u16, serde_json::Value) { + let (route, query) = path.split_once('?').map_or((path, ""), |(r, q)| (r, q)); + // Gate every route on the feature flag so the agent surface is fully off + // when the flag is off — including the diagnostic `current_policy` and + // `denials` routes. The skill is also not installed in that mode, so a + // disabled sandbox has no entry point into this API at all. + if !crate::agent_proposals_enabled() { + return ( + 404, + serde_json::json!({ + "error": "feature_disabled", + "detail": "agent-driven policy proposals are not enabled in this sandbox; set the `agent_policy_proposals_enabled` setting to true to enable" + }), + ); + } + match (method, route) { + ("GET", ROUTE_POLICY_CURRENT) => current_policy_response(ctx).await, + ("GET", ROUTE_DENIALS) => recent_denials_response(ctx, query).await, + ("POST", ROUTE_PROPOSALS) => submit_proposal(ctx, body).await, + _ => ( + 404, + serde_json::json!({ + "error": "not_found", + "detail": format!("policy.local route not found: {method} {route}") + }), + ), + } +} + +/// Build the `next_steps` array embedded in the L7 deny body so the agent has +/// machine-readable pointers to this API. Centralizes the shape here to keep +/// the deny body and the actual route table from drifting — adding or +/// renaming a route only requires touching the route constants above. +/// +/// Returns an empty array when `agent_proposals_enabled()` is false so a +/// disabled sandbox doesn't advertise a surface that 404s. The deny body +/// caller still emits the field (with `[]`) so the wire shape is stable. +#[must_use] +pub fn agent_next_steps() -> serde_json::Value { + if !crate::agent_proposals_enabled() { + return serde_json::json!([]); + } + let host = POLICY_LOCAL_HOST; + serde_json::json!([ + { + "action": "read_skill", + "path": SKILL_PATH, + }, + { + "action": "inspect_policy", + "method": "GET", + "url": format!("http://{host}{ROUTE_POLICY_CURRENT}"), + }, + { + "action": "inspect_recent_denials", + "method": "GET", + "url": format!("http://{host}{ROUTE_DENIALS}?last=5"), + }, + { + "action": "submit_proposal", + "method": "POST", + "url": format!("http://{host}{ROUTE_PROPOSALS}"), + "body_type": "PolicyMergeOperation", + }, + ]) +} + +async fn current_policy_response(ctx: &PolicyLocalContext) -> (u16, serde_json::Value) { + let Some(policy) = ctx.current_policy.read().await.clone() else { + return ( + 404, + serde_json::json!({ + "error": "policy_unavailable", + "detail": "no current sandbox policy is loaded" + }), + ); + }; + + match openshell_policy::serialize_sandbox_policy(&policy) { + Ok(policy_yaml) => ( + 200, + serde_json::json!({ + "format": "yaml", + "policy_yaml": policy_yaml + }), + ), + Err(error) => ( + 500, + serde_json::json!({ + "error": "policy_serialize_failed", + "detail": error.to_string() + }), + ), + } +} + +async fn recent_denials_response( + ctx: &PolicyLocalContext, + query: &str, +) -> (u16, serde_json::Value) { + let limit = parse_last_query(query).unwrap_or(DEFAULT_DENIALS_LIMIT); + let log_dir = ctx.shorthand_log_dir.clone(); + + // Distinguish "shorthand log exists and no denials happened" from "no log + // file yet, so we have nothing to read." Without this flag the agent sees + // `[]` in both cases and cannot tell the difference. The shorthand log is + // always-on (no setting gates it), so the only way `log_available=false` + // happens in practice is if the supervisor has not flushed any events to + // disk yet, or `/var/log` is not writable in this image. + let log_available = matches!( + collect_shorthand_log_files(&log_dir, 1), + Ok(files) if !files.is_empty() + ); + + let denials = tokio::task::spawn_blocking(move || read_recent_denial_lines(&log_dir, limit)) + .await + .unwrap_or_default(); + + let mut payload = serde_json::json!({ + "denials": denials, + "log_available": log_available, + }); + if !log_available { + payload["note"] = serde_json::json!( + "no shorthand log file is present yet at /var/log/openshell.YYYY-MM-DD.log; the supervisor may not have emitted any events to disk yet" + ); + } + + (200, payload) +} + +fn parse_last_query(query: &str) -> Option { + if query.is_empty() { + return None; + } + for pair in query.split('&') { + let Some((key, value)) = pair.split_once('=') else { + continue; + }; + if key == "last" { + return value + .parse::() + .ok() + .map(|n| n.clamp(1, MAX_DENIALS_LIMIT)); + } + } + None +} + +/// Walk the shorthand log files (most-recent first) and return up to `limit` +/// raw denial lines in newest-first order. The agent receives the same +/// human-readable text that `openshell logs` displays — no parsing back into +/// structured form. Updating the shorthand format adds fields automatically; +/// no schema rev required. +/// +/// Reads files synchronously and is intended to run inside `spawn_blocking`. +fn read_recent_denial_lines(log_dir: &Path, limit: usize) -> Vec { + let Ok(files) = collect_shorthand_log_files(log_dir, DENIAL_LOG_FILES_TO_SCAN) else { + return Vec::new(); + }; + + let mut lines: Vec = Vec::with_capacity(limit); + for path in files { + let Ok(contents) = std::fs::read_to_string(&path) else { + continue; + }; + // Walk lines newest-first. Within a single file, the last line written + // is the freshest event. + for line in contents.lines().rev() { + if !is_ocsf_denial_line(line) { + continue; + } + // Defense-in-depth: redact query strings before truncation. The + // FORWARD deny path in `proxy.rs` populates the OCSF `message` + // and URL with the raw request path including `?query=...`, which + // the shorthand layer then renders verbatim. Stripping queries + // here means the agent never sees the secret even if an upstream + // emit site forgets to redact (TODO: harden the emit sites in + // proxy.rs FORWARD path so the on-disk shorthand log itself is + // clean — tracked separately). Redact first so truncation cannot + // slice mid-secret. + let redacted = redact_query_strings(line); + let surfaced = truncate_at_char_boundary(&redacted, MAX_DENIAL_LINE_BYTES); + lines.push(surfaced); + if lines.len() >= limit { + return lines; + } + } + } + lines +} + +/// Replace any `?` substring with `?[redacted]` to keep query-string +/// secrets out of the agent's view. Walks per Unicode scalar value so multi-byte +/// content is safe. A query is everything from `?` until the next whitespace or +/// `]` (the shorthand format uses `[...]` for context tags). +fn redact_query_strings(line: &str) -> String { + let mut out = String::with_capacity(line.len()); + let mut chars = line.chars(); + while let Some(c) = chars.next() { + if c == '?' { + out.push('?'); + out.push_str("[redacted]"); + // Consume until whitespace or `]` (preserved as the next token's + // boundary by writing it back out). + for next in chars.by_ref() { + if next.is_whitespace() || next == ']' { + out.push(next); + break; + } + } + } else { + out.push(c); + } + } + out +} + +/// Truncate `s` at the largest UTF-8 char boundary <= `max_bytes`, appending a +/// `...[truncated]` suffix. Returning a `String` (not `&str`) avoids surprising +/// callers about lifetime relationships with `s`. +fn truncate_at_char_boundary(s: &str, max_bytes: usize) -> String { + if s.len() <= max_bytes { + return s.to_string(); + } + let mut end = max_bytes; + while end > 0 && !s.is_char_boundary(end) { + end -= 1; + } + let mut out = String::with_capacity(end + "...[truncated]".len()); + out.push_str(&s[..end]); + out.push_str("...[truncated]"); + out +} + +/// True for OCSF denial events as rendered by the shorthand layer. The format +/// is ` OCSF <[SEV]> ...`. The literal +/// ` OCSF ` substring identifies an OCSF event (vs. a non-OCSF tracing line); +/// ` DENIED ` is the OCSF action label uppercased and surrounded by spaces, so +/// matching it is safe against substring collisions in URLs or hostnames. +fn is_ocsf_denial_line(line: &str) -> bool { + line.contains(" OCSF ") && line.contains(" DENIED ") +} + +fn collect_shorthand_log_files(log_dir: &Path, max_files: usize) -> std::io::Result> { + let mut entries: Vec<(std::time::SystemTime, PathBuf)> = std::fs::read_dir(log_dir)? + .filter_map(std::result::Result::ok) + .filter_map(|entry| { + let path = entry.path(); + let name = entry.file_name(); + let name = name.to_string_lossy(); + // `openshell.YYYY-MM-DD.log` only — the trailing dot in the prefix + // disambiguates from `openshell-ocsf.YYYY-MM-DD.log`. + if !name.starts_with(SHORTHAND_LOG_PREFIX) || !name.ends_with(".log") { + return None; + } + let modified = entry.metadata().and_then(|m| m.modified()).ok()?; + Some((modified, path)) + }) + .collect(); + + entries.sort_by_key(|entry| std::cmp::Reverse(entry.0)); + Ok(entries + .into_iter() + .take(max_files) + .map(|(_, p)| p) + .collect()) +} + +async fn submit_proposal(ctx: &PolicyLocalContext, body: &[u8]) -> (u16, serde_json::Value) { + let Some(endpoint) = ctx.gateway_endpoint.as_deref() else { + return ( + 503, + serde_json::json!({ + "error": "gateway_unavailable", + "detail": "policy proposal submission requires a gateway-connected sandbox" + }), + ); + }; + let Some(sandbox_name) = ctx + .sandbox_name + .as_deref() + .map(str::trim) + .filter(|name| !name.is_empty()) + else { + return ( + 503, + serde_json::json!({ + "error": "sandbox_name_unavailable", + "detail": "policy proposal submission requires a sandbox name" + }), + ); + }; + + let chunks = match proposal_chunks_from_body(body) { + Ok(chunks) => chunks, + Err(error) => return (400, error_payload("invalid_proposal", error)), + }; + + let client = match crate::grpc_client::CachedOpenShellClient::connect(endpoint).await { + Ok(client) => client, + Err(error) => { + return ( + 502, + serde_json::json!({ + "error": "gateway_connect_failed", + "detail": error.to_string() + }), + ); + } + }; + + let response = match client + .submit_policy_analysis(sandbox_name, vec![], chunks, "agent_authored") + .await + { + Ok(response) => response, + Err(error) => { + return ( + 502, + serde_json::json!({ + "error": "proposal_submit_failed", + "detail": error.to_string() + }), + ); + } + }; + + ( + 202, + serde_json::json!({ + "status": "submitted", + "accepted_chunks": response.accepted_chunks, + "rejected_chunks": response.rejected_chunks, + "rejection_reasons": response.rejection_reasons, + }), + ) +} + +fn proposal_chunks_from_body(body: &[u8]) -> std::result::Result, String> { + let request: ProposalRequest = serde_json::from_slice(body).map_err(|e| e.to_string())?; + if request.operations.is_empty() { + return Err("proposal requires at least one operation".to_string()); + } + + let mut chunks = Vec::new(); + for operation in request.operations { + let Some(add_rule) = operation.get("addRule").cloned() else { + return Err( + "this MVP accepts `addRule` operations; submit a full narrow NetworkPolicyRule" + .to_string(), + ); + }; + let add_rule: AddNetworkRuleJson = + serde_json::from_value(add_rule).map_err(|e| e.to_string())?; + chunks.push(policy_chunk_from_add_rule( + add_rule, + request.intent_summary.as_deref().unwrap_or_default(), + )?); + } + + Ok(chunks) +} + +fn policy_chunk_from_add_rule( + add_rule: AddNetworkRuleJson, + intent_summary: &str, +) -> std::result::Result { + let mut rule = network_rule_from_json(add_rule.rule)?; + let rule_name = add_rule + .rule_name + .as_deref() + .map(str::trim) + .filter(|name| !name.is_empty()) + .map_or_else(|| rule.name.clone(), ToString::to_string); + if rule_name.trim().is_empty() { + return Err("addRule.ruleName or rule.name is required".to_string()); + } + if rule.name.trim().is_empty() { + rule.name.clone_from(&rule_name); + } + + let binary = rule + .binaries + .first() + .map(|binary| binary.path.clone()) + .unwrap_or_default(); + + Ok(PolicyChunk { + id: String::new(), + status: "pending".to_string(), + rule_name, + proposed_rule: Some(rule), + rationale: intent_summary.to_string(), + security_notes: String::new(), + confidence: 0.75, + denial_summary_ids: vec![], + created_at_ms: 0, + decided_at_ms: 0, + stage: "agent".to_string(), + supersedes_chunk_id: String::new(), + hit_count: 1, + first_seen_ms: 0, + last_seen_ms: 0, + binary, + }) +} + +fn network_rule_from_json( + rule: NetworkPolicyRuleJson, +) -> std::result::Result { + if rule.endpoints.is_empty() { + return Err("rule.endpoints must contain at least one endpoint".to_string()); + } + + let endpoints = rule + .endpoints + .into_iter() + .map(network_endpoint_from_json) + .collect::, _>>()?; + let binaries = rule + .binaries + .into_iter() + .map(|binary| NetworkBinary { + path: binary.path, + ..Default::default() + }) + .collect(); + + Ok(NetworkPolicyRule { + name: rule.name.unwrap_or_default(), + endpoints, + binaries, + }) +} + +fn network_endpoint_from_json( + endpoint: NetworkEndpointJson, +) -> std::result::Result { + if endpoint.host.trim().is_empty() { + return Err("endpoint.host is required".to_string()); + } + + let mut ports = endpoint.ports; + if ports.is_empty() && endpoint.port > 0 { + ports.push(endpoint.port); + } + if ports.is_empty() { + return Err("endpoint.port or endpoint.ports is required".to_string()); + } + if endpoint + .rules + .iter() + .any(|rule| rule.allow.path.contains('?')) + { + return Err("L7 allow paths must not include query strings".to_string()); + } + + let port = ports.first().copied().unwrap_or_default(); + let rules = endpoint + .rules + .into_iter() + .map(|rule| L7Rule { + allow: Some(L7Allow { + method: rule.allow.method, + path: rule.allow.path, + command: rule.allow.command, + query: HashMap::new(), + // GraphQL fields default empty — agent-authored proposals from + // policy.local target REST/SQL/L4 endpoints; GraphQL operation + // matching is set on the policy server side or via direct YAML. + operation_type: String::new(), + operation_name: String::new(), + fields: Vec::new(), + }), + }) + .collect(); + let deny_rules = endpoint + .deny_rules + .into_iter() + .map(|rule| L7DenyRule { + method: rule.method, + path: rule.path, + command: rule.command, + query: HashMap::new(), + operation_type: String::new(), + operation_name: String::new(), + fields: Vec::new(), + }) + .collect(); + + Ok(NetworkEndpoint { + host: endpoint.host, + port, + protocol: endpoint.protocol, + tls: endpoint.tls, + enforcement: endpoint.enforcement, + access: endpoint.access, + rules, + allowed_ips: endpoint.allowed_ips, + ports, + deny_rules, + allow_encoded_slash: endpoint.allow_encoded_slash, + // GraphQL persisted-query knobs and path scoping default empty — + // agent proposals don't author them today. + persisted_queries: String::new(), + graphql_persisted_queries: HashMap::new(), + graphql_max_body_bytes: 0, + path: String::new(), + }) +} + +async fn read_request_body(initial_request: &[u8], client: &mut S) -> Result> +where + S: AsyncRead + Unpin, +{ + let Some(header_end) = find_header_end(initial_request) else { + return Ok(Vec::new()); + }; + let content_length = parse_content_length(&initial_request[..header_end])?; + if content_length > MAX_POLICY_LOCAL_BODY_BYTES { + return Err(miette::miette!( + "policy.local request body exceeds {MAX_POLICY_LOCAL_BODY_BYTES} bytes" + )); + } + + let mut body = initial_request[header_end..].to_vec(); + if body.len() > content_length { + body.truncate(content_length); + } + let read_loop = async { + while body.len() < content_length { + let remaining = content_length - body.len(); + let mut chunk = vec![0u8; remaining.min(8192)]; + let n = client.read(&mut chunk).await.into_diagnostic()?; + if n == 0 { + return Err(miette::miette!("policy.local request body ended early")); + } + body.extend_from_slice(&chunk[..n]); + } + Ok::<(), miette::Report>(()) + }; + tokio::time::timeout(POLICY_LOCAL_BODY_READ_TIMEOUT, read_loop) + .await + .map_err(|_| miette::miette!("policy.local request body read timed out"))??; + + Ok(body) +} + +fn parse_content_length(headers: &[u8]) -> Result { + let headers = String::from_utf8_lossy(headers); + for line in headers.lines().skip(1) { + if let Some((name, value)) = line.split_once(':') + && name.eq_ignore_ascii_case("content-length") + { + return value + .trim() + .parse::() + .into_diagnostic() + .map_err(|_| miette::miette!("invalid policy.local Content-Length")); + } + } + Ok(0) +} + +fn find_header_end(buf: &[u8]) -> Option { + buf.windows(4) + .position(|window| window == b"\r\n\r\n") + .map(|idx| idx + 4) +} + +async fn write_json_response( + client: &mut S, + status: u16, + payload: serde_json::Value, +) -> Result<()> +where + S: AsyncWrite + Unpin, +{ + let body = payload.to_string(); + let response = format!( + "HTTP/1.1 {status} {}\r\n\ + Content-Type: application/json\r\n\ + Content-Length: {}\r\n\ + Connection: close\r\n\ + \r\n\ + {}", + status_text(status), + body.len(), + body + ); + client + .write_all(response.as_bytes()) + .await + .into_diagnostic()?; + client.flush().await.into_diagnostic()?; + Ok(()) +} + +fn status_text(status: u16) -> &'static str { + match status { + 202 => "Accepted", + 400 => "Bad Request", + 404 => "Not Found", + 500 => "Internal Server Error", + 502 => "Bad Gateway", + 503 => "Service Unavailable", + _ => "OK", + } +} + +fn error_payload(error: &str, detail: String) -> serde_json::Value { + serde_json::json!({ + "error": error, + "detail": detail + }) +} + +#[derive(Debug, Deserialize)] +struct ProposalRequest { + #[serde(default)] + intent_summary: Option, + #[serde(default)] + operations: Vec, +} + +#[derive(Debug, Deserialize)] +struct AddNetworkRuleJson { + #[serde(default, rename = "ruleName")] + rule_name: Option, + rule: NetworkPolicyRuleJson, +} + +#[derive(Debug, Deserialize)] +struct NetworkPolicyRuleJson { + #[serde(default)] + name: Option, + #[serde(default)] + endpoints: Vec, + #[serde(default)] + binaries: Vec, +} + +#[derive(Debug, Deserialize)] +struct NetworkEndpointJson { + host: String, + #[serde(default)] + port: u32, + #[serde(default)] + ports: Vec, + #[serde(default)] + protocol: String, + #[serde(default)] + tls: String, + #[serde(default)] + enforcement: String, + #[serde(default)] + access: String, + #[serde(default)] + rules: Vec, + #[serde(default)] + allowed_ips: Vec, + #[serde(default)] + deny_rules: Vec, + #[serde(default)] + allow_encoded_slash: bool, +} + +#[derive(Debug, Deserialize)] +struct NetworkBinaryJson { + path: String, +} + +#[derive(Debug, Deserialize)] +struct L7RuleJson { + allow: L7AllowJson, +} + +#[derive(Debug, Deserialize)] +struct L7AllowJson { + #[serde(default)] + method: String, + #[serde(default)] + path: String, + #[serde(default)] + command: String, +} + +#[derive(Debug, Deserialize)] +struct L7DenyRuleJson { + #[serde(default)] + method: String, + #[serde(default)] + path: String, + #[serde(default)] + command: String, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn proposal_chunks_from_body_accepts_add_rule_operation() { + let body = br#"{ + "intent_summary": "Allow gh to create one repo.", + "operations": [ + { + "addRule": { + "ruleName": "github_api_repo_create", + "rule": { + "endpoints": [ + { + "host": "api.github.com", + "port": 443, + "protocol": "rest", + "tls": "terminate", + "enforcement": "enforce", + "rules": [ + { + "allow": { + "method": "POST", + "path": "/user/repos" + } + } + ] + } + ], + "binaries": [ + { + "path": "/usr/bin/gh" + } + ] + } + } + } + ] + }"#; + + let chunks = proposal_chunks_from_body(body).unwrap(); + + assert_eq!(chunks.len(), 1); + assert_eq!(chunks[0].rule_name, "github_api_repo_create"); + assert_eq!(chunks[0].rationale, "Allow gh to create one repo."); + assert_eq!(chunks[0].binary, "/usr/bin/gh"); + let rule = chunks[0].proposed_rule.as_ref().unwrap(); + assert_eq!(rule.name, "github_api_repo_create"); + assert_eq!(rule.endpoints[0].host, "api.github.com"); + assert_eq!(rule.endpoints[0].port, 443); + assert_eq!(rule.endpoints[0].ports, vec![443]); + assert_eq!(rule.endpoints[0].protocol, "rest"); + assert_eq!( + rule.endpoints[0].rules[0].allow.as_ref().unwrap().path, + "/user/repos" + ); + } + + #[test] + fn proposal_chunks_from_body_rejects_query_in_l7_path() { + let body = br#"{ + "operations": [ + { + "addRule": { + "ruleName": "bad", + "rule": { + "endpoints": [ + { + "host": "api.github.com", + "port": 443, + "rules": [ + { + "allow": { + "method": "GET", + "path": "/repos?token=secret" + } + } + ] + } + ] + } + } + } + ] + }"#; + + let error = proposal_chunks_from_body(body).unwrap_err(); + assert!(error.contains("query strings")); + assert!(!error.contains("secret")); + } + + #[test] + fn parse_last_query_clamps_to_max() { + assert_eq!(parse_last_query("last=5"), Some(5)); + assert_eq!(parse_last_query("foo=bar&last=20"), Some(20)); + assert_eq!(parse_last_query("last=999"), Some(MAX_DENIALS_LIMIT)); + assert_eq!(parse_last_query("last=0"), Some(1)); + assert_eq!(parse_last_query(""), None); + assert_eq!(parse_last_query("other=1"), None); + } + + #[test] + fn is_ocsf_denial_line_filters_correctly() { + // OCSF denial — match. + assert!(is_ocsf_denial_line( + "2026-05-06T17:02:00.000Z OCSF HTTP:PUT [MED] DENIED PUT http://api.github.com:443/x [policy:p engine:l7]" + )); + assert!(is_ocsf_denial_line( + "2026-05-06T17:02:00.000Z OCSF NET:OPEN [MED] DENIED curl(42) -> blocked.com:443 [policy:- engine:opa]" + )); + + // OCSF allowed — must not match. + assert!(!is_ocsf_denial_line( + "2026-05-06T17:02:00.000Z OCSF NET:OPEN [INFO] ALLOWED curl(42) -> api.example.com:443" + )); + + // Non-OCSF tracing line — must not match even if it contains the word DENIED. + assert!(!is_ocsf_denial_line( + "2026-05-06T17:02:00.000Z INFO some::module: request DENIED in upstream" + )); + + // Empty line — must not match. + assert!(!is_ocsf_denial_line("")); + } + + #[tokio::test] + async fn recent_denials_returns_newest_first_from_shorthand_lines() { + let dir = tempfile::tempdir().unwrap(); + let log_path = dir.path().join("openshell.2026-05-06.log"); + // Mixed file: allowed events, non-OCSF info lines, two denials. + // Lines are written in chronological order; reader walks newest-first. + let body = "\ +2026-05-06T17:02:00.000Z OCSF NET:OPEN [INFO] ALLOWED curl(10) -> api.example.com:443 [policy:default engine:opa] +2026-05-06T17:02:01.000Z INFO some::module: routine status check +2026-05-06T17:02:02.000Z OCSF HTTP:GET [MED] DENIED GET http://blocked.example/v1/data [policy:default-deny engine:l7] +2026-05-06T17:02:03.000Z OCSF NET:OPEN [INFO] ALLOWED curl(11) -> api.example.com:443 +2026-05-06T17:02:04.000Z OCSF HTTP:PUT [MED] DENIED PUT http://api.github.com:443/repos/x/y/contents/z [policy:gh_readonly engine:l7] +"; + std::fs::write(&log_path, body).unwrap(); + + let ctx = PolicyLocalContext::with_log_dir(None, None, None, dir.path().to_path_buf()); + let (status, payload) = recent_denials_response(&ctx, "last=10").await; + assert_eq!(status, 200); + assert_eq!(payload["log_available"], true); + let denials = payload["denials"].as_array().unwrap(); + assert_eq!(denials.len(), 2); + // Newest first. + assert!(denials[0].as_str().unwrap().contains("HTTP:PUT")); + assert!( + denials[0] + .as_str() + .unwrap() + .contains("/repos/x/y/contents/z") + ); + assert!(denials[1].as_str().unwrap().contains("HTTP:GET")); + assert!(denials[1].as_str().unwrap().contains("blocked.example")); + } + + #[tokio::test] + async fn recent_denials_skips_jsonl_log_files() { + // The shorthand reader must not surface `openshell-ocsf.*.log` content + // even if a deny-looking line is present, so the response stays + // independent of the JSONL appender's enabled state. + let dir = tempfile::tempdir().unwrap(); + let jsonl = dir.path().join("openshell-ocsf.2026-05-06.log"); + std::fs::write( + &jsonl, + r#"{"class_uid":4002,"action_id":2,"message":"DENIED","time":1}"#, + ) + .unwrap(); + + let ctx = PolicyLocalContext::with_log_dir(None, None, None, dir.path().to_path_buf()); + let (status, payload) = recent_denials_response(&ctx, "").await; + assert_eq!(status, 200); + assert_eq!(payload["log_available"], false); + assert_eq!(payload["denials"].as_array().unwrap().len(), 0); + } + + #[tokio::test] + async fn recent_denials_signals_when_log_is_missing() { + let dir = tempfile::tempdir().unwrap(); + let ctx = PolicyLocalContext::with_log_dir(None, None, None, dir.path().to_path_buf()); + let (status, payload) = recent_denials_response(&ctx, "").await; + assert_eq!(status, 200); + assert_eq!(payload["log_available"], false); + assert_eq!(payload["denials"].as_array().unwrap().len(), 0); + assert!( + payload["note"] + .as_str() + .unwrap() + .contains("/var/log/openshell.") + ); + } + + #[test] + fn redact_query_strings_removes_query_from_url_token() { + let line = "2026-05-06T17:02:00.000Z OCSF HTTP:PUT [MED] DENIED PUT http://api.github.com/x?access_token=secret-token-1234 [policy:p engine:l7]"; + let redacted = redact_query_strings(line); + assert!(!redacted.contains("secret-token-1234")); + assert!(!redacted.contains("access_token")); + assert!(redacted.contains("?[redacted]")); + // Bracketed tag after the URL preserved. + assert!(redacted.contains("[policy:p engine:l7]")); + } + + #[test] + fn redact_query_strings_removes_query_in_reason_tag() { + // The FORWARD deny path's `message` becomes `[reason:...]` and may + // include a path with query string lacking a `://` prefix. + let line = "2026-05-06T17:02:00.000Z OCSF HTTP:PUT [MED] DENIED PUT http://api.github.com/x [policy:p engine:opa] [reason:FORWARD denied PUT api.github.com:443/x?token=secret-456]"; + let redacted = redact_query_strings(line); + assert!(!redacted.contains("secret-456")); + assert!(!redacted.contains("token=secret")); + assert!(redacted.contains("?[redacted]]")); + } + + #[test] + fn redact_query_strings_handles_multibyte_chars() { + let line = "ÜLÅUTF8 ? secret-x [policy:p]"; + // No `?` here, so no redaction — but must not panic. + let _ = redact_query_strings(line); + } + + #[test] + fn truncate_at_char_boundary_does_not_panic_on_multibyte() { + // 4-byte emoji sequence so byte-naive slicing would panic. + let s = "🚀".repeat(2000); // 8000 bytes + let truncated = truncate_at_char_boundary(&s, 4096); + assert!(truncated.len() <= 4096 + "...[truncated]".len()); + assert!(truncated.ends_with("...[truncated]")); + // Result must be valid UTF-8 — implicit if we return without panic. + } + + #[tokio::test] + async fn recent_denials_truncates_pathological_lines() { + let dir = tempfile::tempdir().unwrap(); + let log_path = dir.path().join("openshell.2026-05-06.log"); + // A single OCSF denial line exceeding MAX_DENIAL_LINE_BYTES. + let huge_path = "/".to_string() + &"a".repeat(MAX_DENIAL_LINE_BYTES + 100); + let line = format!( + "2026-05-06T17:02:00.000Z OCSF HTTP:PUT [MED] DENIED PUT http://x{huge_path} [policy:p engine:l7]\n" + ); + std::fs::write(&log_path, line).unwrap(); + + let ctx = PolicyLocalContext::with_log_dir(None, None, None, dir.path().to_path_buf()); + let (_, payload) = recent_denials_response(&ctx, "last=1").await; + let denials = payload["denials"].as_array().unwrap(); + assert_eq!(denials.len(), 1); + let surfaced = denials[0].as_str().unwrap(); + assert!(surfaced.len() <= MAX_DENIAL_LINE_BYTES + "...[truncated]".len()); + assert!(surfaced.ends_with("...[truncated]")); + } + + use crate::test_helpers::ProposalsFlagGuard; + + #[test] + fn agent_next_steps_returns_empty_when_flag_off() { + let _guard = ProposalsFlagGuard::set_blocking(false); + let steps = agent_next_steps(); + let arr = steps.as_array().expect("agent_next_steps is an array"); + assert!( + arr.is_empty(), + "expected empty next_steps when feature is off, got {steps}" + ); + } + + #[test] + fn agent_next_steps_returns_full_array_when_flag_on() { + let _guard = ProposalsFlagGuard::set_blocking(true); + let steps = agent_next_steps(); + let arr = steps.as_array().expect("agent_next_steps is an array"); + assert_eq!(arr.len(), 4, "expected 4 next_steps when feature is on"); + let actions: Vec<&str> = arr + .iter() + .filter_map(|v| v.get("action").and_then(serde_json::Value::as_str)) + .collect(); + assert!(actions.contains(&"read_skill")); + assert!(actions.contains(&"submit_proposal")); + } + + #[tokio::test] + async fn route_request_returns_feature_disabled_when_flag_off() { + let _guard = ProposalsFlagGuard::set(false).await; + let ctx = PolicyLocalContext::new( + Some(ProtoSandboxPolicy { + version: 1, + ..Default::default() + }), + None, + None, + ); + + // Even the otherwise-public `current_policy` route returns 404 with + // a feature_disabled error: when the surface is off it's off + // entirely, not selectively. + let (status, payload) = route_request(&ctx, "GET", ROUTE_POLICY_CURRENT, &[]).await; + assert_eq!(status, 404); + assert_eq!(payload["error"], "feature_disabled"); + assert!( + payload["detail"] + .as_str() + .unwrap() + .contains("agent_policy_proposals_enabled"), + "feature_disabled detail must name the setting key for actionability" + ); + } + + #[tokio::test] + async fn current_policy_route_returns_yaml_envelope() { + let _guard = ProposalsFlagGuard::set(true).await; + let ctx = PolicyLocalContext::new( + Some(ProtoSandboxPolicy { + version: 1, + ..Default::default() + }), + None, + None, + ); + + let (mut client, mut server) = tokio::io::duplex(4096); + let request = + b"GET http://policy.local/v1/policy/current HTTP/1.1\r\nHost: policy.local\r\n\r\n"; + let task = tokio::spawn(async move { + handle_forward_request(&ctx, "GET", "/v1/policy/current", request, &mut server) + .await + .unwrap(); + }); + + let mut received = Vec::new(); + client.read_to_end(&mut received).await.unwrap(); + task.await.unwrap(); + + let response = String::from_utf8(received).unwrap(); + assert!(response.starts_with("HTTP/1.1 200 OK")); + let (_, body) = response.split_once("\r\n\r\n").unwrap(); + let body: serde_json::Value = serde_json::from_str(body).unwrap(); + assert_eq!(body["format"], "yaml"); + assert!(body["policy_yaml"].as_str().unwrap().contains("version: 1")); + } +} diff --git a/crates/openshell-sandbox/src/proxy.rs b/crates/openshell-sandbox/src/proxy.rs index 179576d82..f20e51655 100644 --- a/crates/openshell-sandbox/src/proxy.rs +++ b/crates/openshell-sandbox/src/proxy.rs @@ -8,6 +8,7 @@ use crate::identity::BinaryIdentityCache; use crate::l7::tls::ProxyTlsState; use crate::opa::{NetworkAction, OpaEngine, PolicyGenerationGuard}; use crate::policy::ProxyPolicy; +use crate::policy_local::{POLICY_LOCAL_HOST, PolicyLocalContext}; use crate::provider_credentials::ProviderCredentialState; use crate::secrets::{SecretResolver, rewrite_header_line}; use miette::{IntoDiagnostic, Result}; @@ -157,6 +158,7 @@ impl ProxyHandle { tls_state: Option>, inference_ctx: Option>, provider_credentials: Option, + policy_local_ctx: Option>, denial_tx: Option>, ) -> Result { // Use override bind_addr, fall back to policy http_addr, then default @@ -195,13 +197,22 @@ impl ProxyHandle { let spid = entrypoint_pid.clone(); let tls = tls_state.clone(); let inf = inference_ctx.clone(); + let policy_local = policy_local_ctx.clone(); let resolver = provider_credentials .as_ref() .and_then(ProviderCredentialState::resolver); let dtx = denial_tx.clone(); tokio::spawn(async move { if let Err(err) = handle_tcp_connection( - stream, opa, cache, spid, tls, inf, resolver, dtx, + stream, + opa, + cache, + spid, + tls, + inf, + policy_local, + resolver, + dtx, ) .await { @@ -316,6 +327,7 @@ async fn handle_tcp_connection( entrypoint_pid: Arc, tls_state: Option>, inference_ctx: Option>, + policy_local_ctx: Option>, secret_resolver: Option>, denial_tx: Option>, ) -> Result<()> { @@ -360,6 +372,7 @@ async fn handle_tcp_connection( opa_engine, identity_cache, entrypoint_pid, + policy_local_ctx, secret_resolver, denial_tx.as_ref(), ) @@ -2411,6 +2424,7 @@ async fn handle_forward_proxy( opa_engine: Arc, identity_cache: Arc, entrypoint_pid: Arc, + policy_local_ctx: Option>, secret_resolver: Option>, denial_tx: Option<&mpsc::UnboundedSender>, ) -> Result<()> { @@ -2434,6 +2448,38 @@ async fn handle_forward_proxy( }; let host_lc = host.to_ascii_lowercase(); + if host_lc == POLICY_LOCAL_HOST { + if scheme != "http" || port != 80 { + respond( + client, + &build_json_error_response( + 400, + "Bad Request", + "invalid_policy_local_scheme", + "Use http://policy.local only", + ), + ) + .await?; + return Ok(()); + } + if let Some(ctx) = policy_local_ctx { + return crate::policy_local::handle_forward_request( + &ctx, + method, + &path, + &buf[..used], + client, + ) + .await; + } + respond( + client, + b"HTTP/1.1 503 Service Unavailable\r\nContent-Length: 31\r\n\r\npolicy.local is not configured", + ) + .await?; + return Ok(()); + } + // 2. Reject HTTPS — must use CONNECT for TLS if scheme == "https" { { diff --git a/crates/openshell-sandbox/src/skills.rs b/crates/openshell-sandbox/src/skills.rs new file mode 100644 index 000000000..91654699f --- /dev/null +++ b/crates/openshell-sandbox/src/skills.rs @@ -0,0 +1,63 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Static agent guidance files exposed inside the sandbox. + +use miette::{IntoDiagnostic, Result}; +use std::path::{Path, PathBuf}; + +const SKILLS_RELATIVE_DIR: &str = "etc/openshell/skills"; +const POLICY_ADVISOR_FILE: &str = "policy_advisor.md"; +const POLICY_ADVISOR_CONTENT: &str = include_str!("skills/policy_advisor.md"); + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct InstalledSkills { + pub policy_advisor: PathBuf, +} + +pub fn install_static_skills() -> Result { + install_static_skills_at(Path::new("/")) +} + +fn install_static_skills_at(root: &Path) -> Result { + let skills_dir = root.join(SKILLS_RELATIVE_DIR); + std::fs::create_dir_all(&skills_dir).into_diagnostic()?; + + let policy_advisor = skills_dir.join(POLICY_ADVISOR_FILE); + std::fs::write(&policy_advisor, POLICY_ADVISOR_CONTENT).into_diagnostic()?; + + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt as _; + + std::fs::set_permissions(&policy_advisor, std::fs::Permissions::from_mode(0o444)) + .into_diagnostic()?; + } + + Ok(InstalledSkills { policy_advisor }) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn install_static_skills_at_writes_policy_advisor() { + let dir = tempfile::tempdir().unwrap(); + + let installed = install_static_skills_at(dir.path()).unwrap(); + + let expected = dir + .path() + .join("etc") + .join("openshell") + .join("skills") + .join("policy_advisor.md"); + assert_eq!(installed.policy_advisor, expected); + + let content = std::fs::read_to_string(expected).unwrap(); + assert!(content.contains("# OpenShell Policy Advisor")); + assert!(content.contains("policy.local")); + assert!(content.contains("addRule")); + } +} diff --git a/crates/openshell-sandbox/src/skills/policy_advisor.md b/crates/openshell-sandbox/src/skills/policy_advisor.md new file mode 100644 index 000000000..57546145c --- /dev/null +++ b/crates/openshell-sandbox/src/skills/policy_advisor.md @@ -0,0 +1,100 @@ +# OpenShell Policy Advisor + +Use this when OpenShell blocks a network request and the response or logs say +`policy_denied`. + +## Goal + +Draft the smallest policy proposal that allows the user's current task without +giving the sandbox broad new network access. The developer approves or rejects +the proposal; do not try to bypass policy. + +## Local API + +The sandbox-local policy API is reachable at `http://policy.local`: + +- `GET /v1/policy/current` — current effective policy as YAML. +- `GET /v1/denials?last=10` — most recent network/L7 denials seen by this + sandbox (newest first), returned as raw shorthand log lines. Each line + carries the timestamp, class, severity, action, host/port, binary, policy + name, and (for denied events) a short reason. Read the lines directly; you + do not need to parse them into structured fields. +- `POST /v1/proposals` — submit a proposal for developer approval. + +The proposal body takes an `intent_summary` and one or more `addRule` +operations. Each `addRule` carries a complete narrow `NetworkPolicyRule`. + +## Workflow + +1. Read the denial response body. Use `layer`, `method`, `path`, `host`, + `port`, `binary`, `rule_missing`, and `detail` as evidence. +2. Fetch the current policy from `/v1/policy/current`. +3. Fetch recent denials from `/v1/denials` if the response body is incomplete. +4. Prefer L7 REST rules for REST APIs. Use L4 only for non-REST protocols or + when the client tunnels opaque traffic that OpenShell cannot inspect. +5. Draft the narrowest rule: exact host, exact port, exact binary when known, + exact method, and the smallest safe path. +6. Submit the proposal, tell the developer what you proposed, and retry the + denied action only after approval. + +## Proposal shape + +A complete narrow REST-inspected rule looks like this: + +```json +{ + "intent_summary": "Allow gh to update repository contents in NVIDIA/OpenShell only.", + "operations": [ + { + "addRule": { + "ruleName": "github_api_repo_contents_write", + "rule": { + "name": "github_api_repo_contents_write", + "endpoints": [ + { + "host": "api.github.com", + "port": 443, + "protocol": "rest", + "enforcement": "enforce", + "rules": [ + { + "allow": { + "method": "PUT", + "path": "/repos/NVIDIA/OpenShell/contents/**" + } + } + ] + } + ], + "binaries": [ + { + "path": "/usr/bin/gh" + } + ] + } + } + } + ] +} +``` + +## Norms + +- Do not propose wildcard hosts such as `**` or `*.com`. +- Do not propose `access: full` to fix a single denied REST request. +- Do not include query strings, tokens, credentials, or secret values in + paths. +- Explain uncertainty in `intent_summary` instead of widening the rule. +- If pushing with `git` fails, that is a separate L4 or protocol-specific + path from GitHub REST API access. Propose it separately. + +## Local logs (read-only) + +Two local files complement the API and are useful when debugging policy +behavior: + +- `/var/log/openshell.YYYY-MM-DD.log` — shorthand log of sandbox activity. + This is what `/v1/denials` reads from. +- `/var/log/openshell-ocsf.YYYY-MM-DD.log` — full OCSF JSON events, only + written when the `ocsf_json_enabled` setting is on. Not used by + `/v1/denials`; useful for SIEM ingestion. diff --git a/e2e/policy-advisor/README.md b/e2e/policy-advisor/README.md new file mode 100644 index 000000000..79f496e3e --- /dev/null +++ b/e2e/policy-advisor/README.md @@ -0,0 +1,54 @@ + + + +# Policy Advisor end-to-end test + +Deterministic, no-LLM exercise of the agent-driven policy loop: + +1. Start a sandbox with a read-only GitHub L7 policy. +2. From inside the sandbox, attempt a GitHub contents PUT and assert OpenShell + returns a structured `policy_denied` 403. +3. Submit a narrow `addRule` proposal through `http://policy.local/v1/proposals`. +4. Approve the draft from the host and retry until the write succeeds. + +This proves the proxy, the structured deny body, the `policy.local` HTTP API, +the gateway proposal path, and the hot-reload of approved rules — without +involving an LLM. The user-facing demo (`examples/agent-driven-policy-management/`) +runs the same loop with Codex driving from inside the sandbox. + +## Run it + +Run against an ephemeral Docker gateway: + +```bash +DEMO_GITHUB_OWNER= \ +DEMO_GITHUB_REPO=openshell-policy-demo \ +e2e/with-docker-gateway.sh bash -lc ' + target/debug/openshell settings set --global \ + --key agent_policy_proposals_enabled \ + --value true \ + --yes + OPENSHELL_BIN="$PWD/target/debug/openshell" bash e2e/policy-advisor/test.sh +' +``` + +To keep the sandbox for debugging, start a local gateway first with +`mise run gateway:docker`, then run: + +```bash +target/debug/openshell settings set --global \ + --key agent_policy_proposals_enabled \ + --value true \ + --yes + +OPENSHELL_GATEWAY=docker-dev \ +OPENSHELL_BIN="$PWD/target/debug/openshell" \ +DEMO_KEEP_SANDBOX=1 \ +DEMO_GITHUB_OWNER= \ +DEMO_GITHUB_REPO=openshell-policy-demo \ +bash e2e/policy-advisor/test.sh +``` + +Requires Docker, `agent_policy_proposals_enabled=true`, and a GitHub token with +contents write on the repository. The test auto-resolves the token from +`DEMO_GITHUB_TOKEN`, `GITHUB_TOKEN`, `GH_TOKEN`, or `gh auth token`. diff --git a/e2e/policy-advisor/policy.template.yaml b/e2e/policy-advisor/policy.template.yaml new file mode 100644 index 000000000..6452cb01c --- /dev/null +++ b/e2e/policy-advisor/policy.template.yaml @@ -0,0 +1,28 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +version: 1 + +filesystem_policy: + include_workdir: true + read_only: [/usr, /lib, /proc, /dev/urandom, /app, /etc, /var/log] + read_write: [/sandbox, /tmp, /dev/null] + +landlock: + compatibility: best_effort + +process: + run_as_user: sandbox + run_as_group: sandbox + +network_policies: + github_api_readonly: + name: github-api-readonly + endpoints: + - host: api.github.com + port: 443 + protocol: rest + enforcement: enforce + access: read-only + binaries: + - { path: /usr/bin/curl } diff --git a/e2e/policy-advisor/sandbox-runner.sh b/e2e/policy-advisor/sandbox-runner.sh new file mode 100755 index 000000000..780e85573 --- /dev/null +++ b/e2e/policy-advisor/sandbox-runner.sh @@ -0,0 +1,142 @@ +#!/usr/bin/env bash + +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +set -euo pipefail + +cmd="$1" +shift + +json_status_response() { + local status="$1" + local body="$2" + printf 'HTTP_STATUS=%s\n' "$status" + cat "$body" + printf '\n' +} + +case "$cmd" in + check-skill) + test -f /etc/openshell/skills/policy_advisor.md + sed -n '1,40p' /etc/openshell/skills/policy_advisor.md + ;; + + current-policy) + body="$(mktemp)" + status="$(curl -sS -o "$body" -w "%{http_code}" http://policy.local/v1/policy/current)" + json_status_response "$status" "$body" + ;; + + put-file) + owner="$1" + repo="$2" + branch="$3" + file_path="$4" + run_id="$5" + body="$(mktemp)" + payload="$(mktemp)" + + python3 - "$branch" "$run_id" > "$payload" <<'PY' +import base64 +import json +import sys + +branch, run_id = sys.argv[1:3] +content = f"""# OpenShell policy advisor demo + +Run id: {run_id} + +This file was written from inside an OpenShell sandbox after an agent-authored +policy proposal was approved. +""" + +payload = { + "message": f"docs: add OpenShell policy advisor demo note {run_id}", + "branch": branch, + "content": base64.b64encode(content.encode("utf-8")).decode("ascii"), +} +print(json.dumps(payload)) +PY + + status="$(curl -sS \ + -o "$body" \ + -w "%{http_code}" \ + -X PUT \ + -H "Accept: application/vnd.github+json" \ + -H "Authorization: Bearer ${GITHUB_TOKEN}" \ + -H "X-GitHub-Api-Version: 2022-11-28" \ + -H "Content-Type: application/json" \ + --data-binary "@${payload}" \ + "https://api.github.com/repos/${owner}/${repo}/contents/${file_path}")" + json_status_response "$status" "$body" + ;; + + submit-proposal) + owner="$1" + repo="$2" + file_path="$3" + body="$(mktemp)" + payload="$(mktemp)" + + python3 - "$owner" "$repo" "$file_path" > "$payload" <<'PY' +import json +import sys + +owner, repo, file_path = sys.argv[1:4] +rule_path = f"/repos/{owner}/{repo}/contents/{file_path}" +payload = { + "intent_summary": ( + "Allow curl to write the demo note to " + f"{owner}/{repo} at {file_path} only." + ), + "operations": [ + { + "addRule": { + "ruleName": "github_api_demo_contents_write", + "rule": { + "name": "github_api_demo_contents_write", + "endpoints": [ + { + "host": "api.github.com", + "port": 443, + "protocol": "rest", + "enforcement": "enforce", + "rules": [ + { + "allow": { + "method": "PUT", + "path": rule_path, + } + } + ], + } + ], + "binaries": [ + { + "path": "/usr/bin/curl", + } + ], + }, + } + } + ], +} +print(json.dumps(payload)) +PY + + status="$(curl -sS \ + -o "$body" \ + -w "%{http_code}" \ + -X POST \ + -H "Content-Type: application/json" \ + --data-binary "@${payload}" \ + http://policy.local/v1/proposals)" + json_status_response "$status" "$body" + ;; + + *) + echo "unknown command: $cmd" >&2 + exit 64 + ;; +esac diff --git a/e2e/policy-advisor/test.sh b/e2e/policy-advisor/test.sh new file mode 100755 index 000000000..cef09d1ed --- /dev/null +++ b/e2e/policy-advisor/test.sh @@ -0,0 +1,410 @@ +#!/usr/bin/env bash + +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd "${SCRIPT_DIR}/../.." && pwd)" +POLICY_TEMPLATE="${SCRIPT_DIR}/policy.template.yaml" +RUNNER_SOURCE="${SCRIPT_DIR}/sandbox-runner.sh" + +if [[ -z "${OPENSHELL_BIN:-}" ]]; then + if [[ -x "${REPO_ROOT}/target/debug/openshell" ]]; then + OPENSHELL_BIN="${REPO_ROOT}/target/debug/openshell" + else + OPENSHELL_BIN="openshell" + fi +fi + +DEMO_BRANCH="${DEMO_BRANCH:-main}" +DEMO_RUN_ID="${DEMO_RUN_ID:-$(date +%Y%m%d-%H%M%S)}" +DEMO_FILE_DIR="${DEMO_FILE_DIR:-openshell-policy-advisor-validation}" +DEMO_FILE_PATH="${DEMO_FILE_PATH:-${DEMO_FILE_DIR}/${DEMO_RUN_ID}.md}" +DEMO_SANDBOX_NAME="${DEMO_SANDBOX_NAME:-policy-agent-validation-${DEMO_RUN_ID}}" +DEMO_GITHUB_PROVIDER_NAME="${DEMO_GITHUB_PROVIDER_NAME:-github-policy-validation-${DEMO_RUN_ID}}" +DEMO_KEEP_SANDBOX="${DEMO_KEEP_SANDBOX:-0}" +DEMO_RETRY_ATTEMPTS="${DEMO_RETRY_ATTEMPTS:-30}" +DEMO_RETRY_SLEEP="${DEMO_RETRY_SLEEP:-2}" + +TMP_DIR="" +POLICY_FILE="" +SSH_CONFIG="" +SSH_HOST="" + +BOLD='\033[1m' +DIM='\033[2m' +CYAN='\033[36m' +GREEN='\033[32m' +RED='\033[31m' +YELLOW='\033[33m' +RESET='\033[0m' + +step() { + printf "\n${BOLD}${CYAN}==> %s${RESET}\n\n" "$1" +} + +info() { + printf " %b\n" "$*" +} + +fail() { + printf "\n${RED}error:${RESET} %s\n" "$*" >&2 + exit 1 +} + +cleanup() { + local status=$? + + if [[ "$DEMO_KEEP_SANDBOX" != "1" ]]; then + "$OPENSHELL_BIN" sandbox delete "$DEMO_SANDBOX_NAME" >/dev/null 2>&1 || true + else + printf "\n${YELLOW}Keeping sandbox because DEMO_KEEP_SANDBOX=1: %s${RESET}\n" "$DEMO_SANDBOX_NAME" + fi + + "$OPENSHELL_BIN" provider delete "$DEMO_GITHUB_PROVIDER_NAME" >/dev/null 2>&1 || true + + if [[ -z "$TMP_DIR" ]]; then + return + fi + + if [[ $status -eq 0 ]]; then + rm -rf "$TMP_DIR" + else + printf "\n${YELLOW}Temporary files kept at: %s${RESET}\n" "$TMP_DIR" + fi +} +trap cleanup EXIT + +require_command() { + command -v "$1" >/dev/null 2>&1 || fail "missing required command: $1" +} + +validate_name() { + local label="$1" + local value="$2" + [[ "$value" =~ ^[A-Za-z0-9_.-]+$ ]] || fail "$label may contain only letters, numbers, '.', '_', and '-'" +} + +validate_path() { + local label="$1" + local value="$2" + [[ "$value" =~ ^[A-Za-z0-9._/-]+$ ]] || fail "$label may contain only letters, numbers, '.', '_', '-', and '/'" + [[ "$value" != /* ]] || fail "$label must be relative" + [[ "$value" != *..* ]] || fail "$label must not contain '..'" +} + +resolve_token() { + if [[ -z "${DEMO_GITHUB_TOKEN:-}" ]]; then + if [[ -n "${GITHUB_TOKEN:-}" ]]; then + DEMO_GITHUB_TOKEN="$GITHUB_TOKEN" + elif [[ -n "${GH_TOKEN:-}" ]]; then + DEMO_GITHUB_TOKEN="$GH_TOKEN" + elif command -v gh >/dev/null 2>&1; then + DEMO_GITHUB_TOKEN="$(gh auth token 2>/dev/null || true)" + fi + fi + + [[ -n "${DEMO_GITHUB_TOKEN:-}" ]] || fail "set DEMO_GITHUB_TOKEN, GITHUB_TOKEN, GH_TOKEN, or sign in with gh" + export GITHUB_TOKEN="$DEMO_GITHUB_TOKEN" +} + +validate_env() { + require_command curl + require_command jq + require_command ssh + require_command "$OPENSHELL_BIN" + + [[ -f "$RUNNER_SOURCE" ]] || fail "missing sandbox runner: $RUNNER_SOURCE" + [[ -n "${DEMO_GITHUB_OWNER:-}" ]] || fail "set DEMO_GITHUB_OWNER" + [[ -n "${DEMO_GITHUB_REPO:-}" ]] || fail "set DEMO_GITHUB_REPO" + [[ "$DEMO_RUN_ID" =~ ^[a-z0-9-]+$ ]] || fail "DEMO_RUN_ID may contain only lowercase letters, numbers, and '-'" + [[ "$DEMO_RETRY_ATTEMPTS" =~ ^[0-9]+$ ]] || fail "DEMO_RETRY_ATTEMPTS must be a number" + [[ "$DEMO_RETRY_SLEEP" =~ ^[0-9]+$ ]] || fail "DEMO_RETRY_SLEEP must be a number" + + validate_name "DEMO_GITHUB_OWNER" "$DEMO_GITHUB_OWNER" + validate_name "DEMO_GITHUB_REPO" "$DEMO_GITHUB_REPO" + validate_path "DEMO_BRANCH" "$DEMO_BRANCH" + validate_path "DEMO_FILE_PATH" "$DEMO_FILE_PATH" + + resolve_token +} + +github_api_status() { + local url="$1" + local body="$2" + curl -sS \ + -o "$body" \ + -w "%{http_code}" \ + -H "Accept: application/vnd.github+json" \ + -H "Authorization: Bearer ${DEMO_GITHUB_TOKEN}" \ + -H "X-GitHub-Api-Version: 2022-11-28" \ + "$url" +} + +check_gateway() { + step "Checking active OpenShell gateway" + if ! "$OPENSHELL_BIN" status >/dev/null 2>&1; then + fail "active OpenShell gateway is not reachable; start one separately, for example: mise run cluster" + fi + "$OPENSHELL_BIN" status | sed 's/^/ /' +} + +check_github_access() { + step "Checking GitHub repository access" + local body status branch branches_body branches_status branches + body="${TMP_DIR}/github-repo.json" + status="$(github_api_status "https://api.github.com/repos/${DEMO_GITHUB_OWNER}/${DEMO_GITHUB_REPO}" "$body")" + + if [[ "$status" != "200" ]]; then + printf '%s\n' "$(jq -r '.message // empty' "$body" 2>/dev/null)" | sed 's/^/ /' + fail "GitHub returned HTTP $status for ${DEMO_GITHUB_OWNER}/${DEMO_GITHUB_REPO}; check the repo name and token access" + fi + + if jq -e 'has("permissions") and (.permissions.push == false and .permissions.admin == false and .permissions.maintain == false)' "$body" >/dev/null; then + fail "GitHub token can read ${DEMO_GITHUB_OWNER}/${DEMO_GITHUB_REPO} but does not appear to have write access" + fi + + branch="$(jq -rn --arg v "$DEMO_BRANCH" '$v|@uri')" + body="${TMP_DIR}/github-branch.json" + status="$(github_api_status "https://api.github.com/repos/${DEMO_GITHUB_OWNER}/${DEMO_GITHUB_REPO}/branches/${branch}" "$body")" + if [[ "$status" != "200" ]]; then + branches_body="${TMP_DIR}/github-branches.json" + branches_status="$(github_api_status "https://api.github.com/repos/${DEMO_GITHUB_OWNER}/${DEMO_GITHUB_REPO}/branches?per_page=20" "$branches_body")" + if [[ "$branches_status" == "200" ]]; then + branches="$(jq -r 'map(.name) | join(", ")' "$branches_body")" + if [[ -z "$branches" ]]; then + fail "GitHub repo exists but has no branches yet; add an initial README or push ${DEMO_BRANCH} before running the demo" + fi + fail "GitHub returned HTTP $status for branch ${DEMO_BRANCH}; set DEMO_BRANCH to one of: ${branches}" + fi + fail "GitHub returned HTTP $status for branch ${DEMO_BRANCH}" + fi + + body="${TMP_DIR}/github-demo-file.json" + status="$(github_api_status "https://api.github.com/repos/${DEMO_GITHUB_OWNER}/${DEMO_GITHUB_REPO}/contents/${DEMO_FILE_PATH}?ref=${branch}" "$body")" + if [[ "$status" == "200" ]]; then + fail "validation output file already exists: ${DEMO_FILE_PATH}; choose a new DEMO_RUN_ID or DEMO_FILE_PATH" + fi + [[ "$status" == "404" ]] || fail "GitHub returned HTTP $status while checking demo output path ${DEMO_FILE_PATH}" + + info "${GREEN}GitHub repo, branch, and output path are safe for this run.${RESET}" +} + +create_provider() { + step "Creating temporary GitHub provider" + "$OPENSHELL_BIN" provider delete "$DEMO_GITHUB_PROVIDER_NAME" >/dev/null 2>&1 || true + "$OPENSHELL_BIN" provider create \ + --name "$DEMO_GITHUB_PROVIDER_NAME" \ + --type github \ + --credential GITHUB_TOKEN +} + +check_agent_proposals_enabled() { + step "Checking agent-driven policy proposal opt-in" + local value + value="$("$OPENSHELL_BIN" settings get --global --json 2>/dev/null \ + | jq -r '.settings.agent_policy_proposals_enabled // ""')" + if [[ "$value" != "true" ]]; then + fail "agent_policy_proposals_enabled must be true before running this test. +Enable it with: + $OPENSHELL_BIN settings set --global --key agent_policy_proposals_enabled --value true --yes" + fi + info "${GREEN}agent_policy_proposals_enabled=true${RESET}" +} + +create_temp_workspace() { + TMP_DIR="$(mktemp -d "${TMPDIR:-/tmp}/openshell-agent-policy.XXXXXX")" + POLICY_FILE="${TMP_DIR}/policy.yaml" + SSH_CONFIG="${TMP_DIR}/ssh_config" +} + +create_sandbox() { + step "Creating sandbox with read-only GitHub L7 policy" + cp "$POLICY_TEMPLATE" "$POLICY_FILE" + "$OPENSHELL_BIN" sandbox delete "$DEMO_SANDBOX_NAME" >/dev/null 2>&1 || true + "$OPENSHELL_BIN" sandbox create \ + --name "$DEMO_SANDBOX_NAME" \ + --provider "$DEMO_GITHUB_PROVIDER_NAME" \ + --policy "$POLICY_FILE" \ + --upload "${RUNNER_SOURCE}:/sandbox/policy-validation-runner.sh" \ + --no-git-ignore \ + --keep \ + --no-auto-providers \ + --no-tty \ + -- bash -lc "chmod +x /sandbox/policy-validation-runner.sh && echo sandbox ready" +} + +connect_ssh() { + step "Connecting to sandbox over SSH" + "$OPENSHELL_BIN" sandbox ssh-config "$DEMO_SANDBOX_NAME" > "$SSH_CONFIG" + SSH_HOST="$(awk '/^Host / { print $2; exit }' "$SSH_CONFIG")" + [[ -n "$SSH_HOST" ]] || fail "could not find Host entry in sandbox SSH config" + + local retries=30 + local i + for i in $(seq 1 "$retries"); do + if ssh -F "$SSH_CONFIG" "$SSH_HOST" true >/dev/null 2>&1; then + return + fi + sleep 2 + done + fail "SSH connection to sandbox timed out" +} + +sandbox_exec() { + ssh -F "$SSH_CONFIG" "$SSH_HOST" "$@" +} + +http_status() { + awk -F= '/^HTTP_STATUS=/ { print $2; exit }' +} + +http_body() { + sed '/^HTTP_STATUS=/d' +} + +run_policy_local_checks() { + step "Checking sandbox-local skill and policy.local" + sandbox_exec /sandbox/policy-validation-runner.sh check-skill >/dev/null + info "${GREEN}Skill installed:${RESET} /etc/openshell/skills/policy_advisor.md" + + local output + output="$(sandbox_exec /sandbox/policy-validation-runner.sh current-policy)" + local status + status="$(printf '%s\n' "$output" | http_status)" + [[ "$status" == "200" ]] || fail "policy.local current policy returned HTTP $status" + + info "${GREEN}policy.local returned the current sandbox policy.${RESET}" + info "Initial policy: read-only REST access to api.github.com for /usr/bin/curl" +} + +attempt_write() { + sandbox_exec /sandbox/policy-validation-runner.sh put-file \ + "$DEMO_GITHUB_OWNER" \ + "$DEMO_GITHUB_REPO" \ + "$DEMO_BRANCH" \ + "$DEMO_FILE_PATH" \ + "$DEMO_RUN_ID" +} + +submit_policy_proposal() { + sandbox_exec /sandbox/policy-validation-runner.sh submit-proposal \ + "$DEMO_GITHUB_OWNER" \ + "$DEMO_GITHUB_REPO" \ + "$DEMO_FILE_PATH" +} + +capture_initial_denial() { + step "Attempting GitHub contents write from inside sandbox" + local output + output="$(attempt_write)" + local status + local body + status="$(printf '%s\n' "$output" | http_status)" + body="$(printf '%s\n' "$output" | http_body)" + + [[ "$status" == "403" ]] || fail "expected OpenShell HTTP 403, got HTTP $status" + printf '%s\n' "$body" | jq -e '.error == "policy_denied"' >/dev/null \ + || fail "expected structured policy_denied body" + printf '%s\n' "$body" | jq -e '.layer == "l7" and .protocol == "rest" and .method == "PUT"' >/dev/null \ + || fail "expected structured L7 REST deny fields" + + printf '%s\n' "$body" | jq -r ' + "Denied: \(.method) \(.path)", + "Layer: \(.layer)/\(.protocol) host=\(.host):\(.port) binary=\(.binary)", + "Agent guidance: \(.next_steps | map(.action) | join(" -> "))" + ' | sed 's/^/ /' + info "${GREEN}Captured structured L7 policy denial.${RESET}" +} + +submit_and_approve() { + step "Submitting proposal through policy.local" + local output + output="$(submit_policy_proposal)" + local status + local body + status="$(printf '%s\n' "$output" | http_status)" + body="$(printf '%s\n' "$output" | http_body)" + + [[ "$status" == "202" ]] || fail "expected proposal submit HTTP 202, got HTTP $status" + [[ "$(printf '%s\n' "$body" | jq -r '.accepted_chunks // 0')" != "0" ]] \ + || fail "proposal was not accepted" + printf '%s\n' "$body" | jq -r '"Proposal submitted: \(.accepted_chunks) accepted, \(.rejected_chunks) rejected"' | sed 's/^/ /' + + step "Approving pending draft rule from outside the sandbox" + "$OPENSHELL_BIN" rule get "$DEMO_SANDBOX_NAME" --status pending | sed 's/^/ /' + "$OPENSHELL_BIN" rule approve-all "$DEMO_SANDBOX_NAME" | sed 's/^/ /' +} + +print_success_summary() { + jq '{ + path: .content.path, + html_url: .content.html_url, + commit: .commit.sha, + message: .commit.message + }' +} + +retry_until_allowed() { + step "Retrying GitHub contents write after approval" + local output status body attempt + + for attempt in $(seq 1 "$DEMO_RETRY_ATTEMPTS"); do + output="$(attempt_write)" + status="$(printf '%s\n' "$output" | http_status)" + body="$(printf '%s\n' "$output" | http_body)" + + if printf '%s\n' "$body" | jq -e '.error == "policy_denied"' >/dev/null 2>&1; then + info "${DIM}Attempt ${attempt}/${DEMO_RETRY_ATTEMPTS}: policy not loaded yet; retrying...${RESET}" + sleep "$DEMO_RETRY_SLEEP" + continue + fi + + if [[ "$status" == "200" || "$status" == "201" ]]; then + printf '%s\n' "$body" | print_success_summary | sed 's/^/ /' + info "${GREEN}GitHub write succeeded from inside the sandbox.${RESET}" + return + fi + + printf '%s\n' "$body" | jq . | sed 's/^/ /' + if [[ "$status" == "404" ]]; then + fail "policy allowed the request, but GitHub returned HTTP 404; check DEMO_GITHUB_OWNER, DEMO_GITHUB_REPO, and token access" + fi + fail "policy allowed the request, but GitHub returned HTTP $status" + done + + fail "timed out waiting for approved policy to load into the sandbox" +} + +show_logs() { + step "Policy decision trace" + "$OPENSHELL_BIN" logs "$DEMO_SANDBOX_NAME" --since 5m -n 50 2>&1 \ + | grep -E 'HTTP:PUT|CONFIG:LOADED|ReportPolicyStatus' \ + | tail -n 8 \ + | sed 's/^/ /' || true +} + +main() { + validate_env + check_gateway + check_agent_proposals_enabled + create_temp_workspace + check_github_access + create_provider + create_sandbox + connect_ssh + run_policy_local_checks + capture_initial_denial + submit_and_approve + retry_until_allowed + show_logs + + printf "\n${BOLD}${GREEN}✓ Validation complete.${RESET}\n\n" + printf " Sandbox: %s\n" "$DEMO_SANDBOX_NAME" + printf " Repository: https://github.com/%s/%s\n" "$DEMO_GITHUB_OWNER" "$DEMO_GITHUB_REPO" + printf " File: %s\n" "$DEMO_FILE_PATH" +} + +main "$@" diff --git a/examples/agent-driven-policy-management/README.md b/examples/agent-driven-policy-management/README.md new file mode 100644 index 000000000..7ff9a7780 --- /dev/null +++ b/examples/agent-driven-policy-management/README.md @@ -0,0 +1,87 @@ + + + +# Agent-Driven Policy Management Demo + +Run the full agent-driven policy loop end-to-end: + +1. A Codex agent inside an OpenShell sandbox tries to write a markdown file to + GitHub via the Contents API. +2. OpenShell denies the request with a structured `policy_denied` 403 because + the initial policy only allows read-only access to `api.github.com`. +3. The agent reads `/etc/openshell/skills/policy_advisor.md`, drafts the + narrowest rule needed, and submits it to `http://policy.local/v1/proposals`. +4. You approve the proposal from the host with one keystroke. +5. The sandbox hot-reloads the merged policy and the agent's retry succeeds. + +The whole loop usually finishes in under two minutes. + +## Prerequisites + +- An active OpenShell gateway (`openshell gateway start`). +- `gh auth login` (or a `GITHUB_TOKEN` env var with contents-write on a + scratch repo). +- `codex login` on the host. +- A scratch GitHub repository with at least one commit on the default branch. + If you don't have one yet: + + ```shell + gh repo create "$(gh api user --jq .login)/openshell-policy-demo" \ + --private --add-readme \ + --description "OpenShell policy advisor demo scratch repo" + ``` + +## Run it + +```shell +bash examples/agent-driven-policy-management/demo.sh +``` + +That's the whole thing. The demo resolves your GitHub handle from `gh`, picks +`openshell-policy-demo` as the repo, and writes one timestamped markdown file +under `openshell-policy-advisor-demo/` per run. + +## Overrides (all optional) + +| Env var | Default | +|---|---| +| `DEMO_GITHUB_OWNER` | `gh api user --jq .login` | +| `DEMO_GITHUB_REPO` | `openshell-policy-demo` | +| `DEMO_BRANCH` | `main` | +| `DEMO_RUN_ID` | timestamp | +| `DEMO_GITHUB_TOKEN` | falls back to `GITHUB_TOKEN`, `GH_TOKEN`, or `gh auth token` | +| `DEMO_KEEP_SANDBOX` | `0` (set `1` to inspect the sandbox after the demo) | +| `DEMO_APPROVAL_TIMEOUT_SECS` | `240` | +| `OPENSHELL_BIN` | `target/debug/openshell` if present, else `openshell` on `PATH` | + +## What the agent sees + +`policy.template.yaml` is the initial restrictive policy: a read-only L7 REST +rule for `api.github.com` plus the binary set Codex needs. The agent has to +ask for the additional `PUT /repos/.../contents/...` write itself — that's the +proposal you approve. + +## What gets approved (trust model) + +Every proposal lands in the gateway as a `PolicyChunk` — a structured object +with three parts, each with a different trust level: + +| Field | Source | Trust | +|---|---|---| +| `proposed_rule` (host, port, method, path, binary) | agent, schema-validated by the gateway | structured contract — *this* is what you're approving | +| `rationale` (free-form prose) | agent | hint only — a compromised agent can lie here | +| `validation_result` (prover output) | gateway-side prover | trust signal — but this surface is in progress (see [RFC 0001](../../rfc/0001-agent-driven-policy-management.md)) | + +The MVP today shows the structured rule plus the agent's rationale in +`openshell rule get` and the TUI inbox panel. The demo's `openshell rule +approve-all` auto-approves to keep the loop short — in a real session a +developer reviews the structured grant before pressing `a`. Prover-backed +validation badges, computed reachability deltas, and a richer "this is what +the rule actually permits" summary are the next phase. For now, **always +approve based on the structured rule, not the agent's rationale.** + +## Going further + +`e2e/policy-advisor/test.sh` runs the same loop deterministically without an +LLM (curl + the `policy.local` API directly). Use it to validate the proxy and +proposal pipeline when iterating on the sandbox or gateway code. diff --git a/examples/agent-driven-policy-management/agent-task.md b/examples/agent-driven-policy-management/agent-task.md new file mode 100644 index 000000000..9c7588181 --- /dev/null +++ b/examples/agent-driven-policy-management/agent-task.md @@ -0,0 +1,50 @@ + + + +# Agent Task + +You are running inside an OpenShell sandbox. Your job is to write one +markdown file to GitHub via the GitHub Contents API. + +## Target + +- Repository: `{{OWNER}}/{{REPO}}` +- Branch: `{{BRANCH}}` +- File path: `{{FILE_PATH}}` +- Run id: `{{RUN_ID}}` + +## What to do + +1. PUT the file with `curl` against + `https://api.github.com/repos/{{OWNER}}/{{REPO}}/contents/{{FILE_PATH}}`. + Authenticate with the `GITHUB_TOKEN` environment variable. Do not echo the + token. Body content (base64-encoded into the JSON `content` field): + + ``` + # OpenShell policy advisor demo + + Run id: {{RUN_ID}} + + Written from inside an OpenShell sandbox after a narrowly-scoped policy + proposal was approved by the developer. + ``` + +2. The first attempt will fail with HTTP 403 and a JSON body containing + `"error": "policy_denied"`. Read the body — it tells you which layer denied + you (`l7`/`rest`), which method/path was missing, and where to learn more. + +3. Read `/etc/openshell/skills/policy_advisor.md` and follow it. Submit the + narrowest possible proposal to `http://policy.local/v1/proposals` — exact + host, exact port, exact method, exact path, binary `/usr/bin/curl`. Do not + include query strings. Do not propose wildcard hosts. + +4. After submitting, retry the PUT every few seconds for up to 120 seconds. + The developer is approving from outside the sandbox; once approved, the + sandbox hot-reloads policy and the same PUT will succeed. + +5. Stop as soon as the PUT returns HTTP 200 or 201. Print a short summary + showing whether it succeeded, plus `content.path` and `content.html_url` + from the GitHub response. Do not print the full response body. + +If anything is unclear, prefer making a narrower proposal and asking for +approval again over widening the rule. diff --git a/examples/agent-driven-policy-management/demo.sh b/examples/agent-driven-policy-management/demo.sh new file mode 100755 index 000000000..6c3c60cfa --- /dev/null +++ b/examples/agent-driven-policy-management/demo.sh @@ -0,0 +1,442 @@ +#!/usr/bin/env bash + +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +# Agent-driven policy management demo. +# +# Runs the full loop: a Codex agent inside a sandbox hits an OpenShell policy +# block, reads the policy advisor skill, drafts a narrow rule via policy.local, +# the developer approves from the host, and the agent retries successfully. + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd "${SCRIPT_DIR}/../.." && pwd)" +POLICY_TEMPLATE="${SCRIPT_DIR}/policy.template.yaml" +TASK_TEMPLATE="${SCRIPT_DIR}/agent-task.md" +SANDBOX_AGENT="${SCRIPT_DIR}/sandbox-agent.sh" + +OPENSHELL_BIN="${OPENSHELL_BIN:-}" +if [[ -z "$OPENSHELL_BIN" ]]; then + if [[ -x "${REPO_ROOT}/target/debug/openshell" ]]; then + OPENSHELL_BIN="${REPO_ROOT}/target/debug/openshell" + else + OPENSHELL_BIN="openshell" + fi +fi + +DEMO_GITHUB_OWNER="${DEMO_GITHUB_OWNER:-}" +DEMO_GITHUB_REPO="${DEMO_GITHUB_REPO:-openshell-policy-demo}" +DEMO_BRANCH="${DEMO_BRANCH:-main}" +DEMO_RUN_ID="${DEMO_RUN_ID:-$(date +%Y%m%d-%H%M%S)}" +DEMO_FILE_DIR="${DEMO_FILE_DIR:-openshell-policy-advisor-demo}" +DEMO_FILE_PATH="${DEMO_FILE_DIR}/${DEMO_RUN_ID}.md" +DEMO_SANDBOX_NAME="${DEMO_SANDBOX_NAME:-policy-demo-${DEMO_RUN_ID}}" +DEMO_CODEX_PROVIDER_NAME="${DEMO_CODEX_PROVIDER_NAME:-codex-policy-demo-${DEMO_RUN_ID}}" +DEMO_GITHUB_PROVIDER_NAME="${DEMO_GITHUB_PROVIDER_NAME:-github-policy-demo-${DEMO_RUN_ID}}" +DEMO_APPROVAL_TIMEOUT_SECS="${DEMO_APPROVAL_TIMEOUT_SECS:-240}" +DEMO_KEEP_SANDBOX="${DEMO_KEEP_SANDBOX:-0}" + +TMP_DIR="$(mktemp -d "${TMPDIR:-/tmp}/openshell-policy-demo.XXXXXX")" +PAYLOAD_DIR="${TMP_DIR}/payload" +POLICY_FILE="${TMP_DIR}/policy.yaml" +AGENT_LOG="${TMP_DIR}/agent.log" +mkdir -p "$PAYLOAD_DIR" + +# Use ANSI-C quoting so the variables hold the actual ESC byte rather than a +# literal backslash sequence. This lets `cat`, heredocs, and any non-printf +# emitter render colors correctly without per-call interpretation. +BOLD=$'\033[1m' +DIM=$'\033[2m' +CYAN=$'\033[36m' +GREEN=$'\033[32m' +RED=$'\033[31m' +YELLOW=$'\033[33m' +RESET=$'\033[0m' + +AGENT_PID="" + +step() { printf "\n${BOLD}${CYAN}==> %s${RESET}\n\n" "$1"; } +info() { printf " %b\n" "$*"; } + +# Redact host-side credentials from the agent log tail before printing on +# failure. Codex shouldn't echo the token, but a misbehaving tool call (e.g., +# `curl -v`) could leak it; sanitize before showing the log. +redact_log() { + local replacement='[redacted]' + sed \ + -e "s|${DEMO_GITHUB_TOKEN:-__no_github_token__}|${replacement}|g" \ + -e "s|${CODEX_AUTH_ACCESS_TOKEN:-__no_codex_access__}|${replacement}|g" \ + -e "s|${CODEX_AUTH_REFRESH_TOKEN:-__no_codex_refresh__}|${replacement}|g" \ + -e "s|${CODEX_AUTH_ACCOUNT_ID:-__no_codex_account__}|${replacement}|g" +} + +fail() { + printf "\n${RED}error:${RESET} %s\n" "$*" >&2 + if [[ -f "$AGENT_LOG" ]]; then + printf "\n${YELLOW}Agent log tail:${RESET}\n" >&2 + tail -n 80 "$AGENT_LOG" | redact_log | sed 's/^/ /' >&2 || true + fi + exit 1 +} + +cleanup() { + local status=$? + + if [[ -n "$AGENT_PID" ]] && kill -0 "$AGENT_PID" >/dev/null 2>&1; then + kill "$AGENT_PID" >/dev/null 2>&1 || true + wait "$AGENT_PID" 2>/dev/null || true + fi + + if [[ "$DEMO_KEEP_SANDBOX" != "1" ]]; then + "$OPENSHELL_BIN" sandbox delete "$DEMO_SANDBOX_NAME" >/dev/null 2>&1 || true + else + printf "\n${YELLOW}Keeping sandbox because DEMO_KEEP_SANDBOX=1: %s${RESET}\n" "$DEMO_SANDBOX_NAME" + fi + "$OPENSHELL_BIN" provider delete "$DEMO_CODEX_PROVIDER_NAME" >/dev/null 2>&1 || true + "$OPENSHELL_BIN" provider delete "$DEMO_GITHUB_PROVIDER_NAME" >/dev/null 2>&1 || true + + # Restore the agent_policy_proposals_enabled setting to what it was + # before this run. + if [[ -n "${PRIOR_PROPOSALS_FLAG:-}" ]]; then + if [[ "$PRIOR_PROPOSALS_FLAG" == "(unset)" ]]; then + "$OPENSHELL_BIN" settings delete --global --key agent_policy_proposals_enabled \ + >/dev/null 2>&1 || true + else + "$OPENSHELL_BIN" settings set --global --key agent_policy_proposals_enabled \ + --value "$PRIOR_PROPOSALS_FLAG" >/dev/null 2>&1 || true + fi + fi + + if [[ $status -eq 0 ]]; then + rm -rf "$TMP_DIR" + else + printf "\n${YELLOW}Temporary files kept at: %s${RESET}\n" "$TMP_DIR" + fi +} +trap cleanup EXIT + +require_command() { + command -v "$1" >/dev/null 2>&1 || fail "missing required command: $1" +} + +resolve_github_owner() { + if [[ -n "$DEMO_GITHUB_OWNER" ]]; then + return + fi + if command -v gh >/dev/null 2>&1; then + DEMO_GITHUB_OWNER="$(gh api user --jq .login 2>/dev/null || true)" + fi + [[ -n "$DEMO_GITHUB_OWNER" ]] || fail "set DEMO_GITHUB_OWNER, or sign in with: gh auth login" +} + +resolve_github_token() { + DEMO_GITHUB_TOKEN="${DEMO_GITHUB_TOKEN:-${GITHUB_TOKEN:-${GH_TOKEN:-}}}" + if [[ -z "$DEMO_GITHUB_TOKEN" ]] && command -v gh >/dev/null 2>&1; then + DEMO_GITHUB_TOKEN="$(gh auth token 2>/dev/null || true)" + fi + [[ -n "$DEMO_GITHUB_TOKEN" ]] || fail "set DEMO_GITHUB_TOKEN, GITHUB_TOKEN, GH_TOKEN, or sign in with: gh auth login" + export DEMO_GITHUB_TOKEN +} + +resolve_codex_auth() { + [[ -f "${HOME}/.codex/auth.json" ]] || fail "missing local Codex sign-in; run: codex login" + export CODEX_AUTH_ACCESS_TOKEN CODEX_AUTH_REFRESH_TOKEN CODEX_AUTH_ACCOUNT_ID + CODEX_AUTH_ACCESS_TOKEN="$(jq -r '.tokens.access_token // empty' "${HOME}/.codex/auth.json")" + CODEX_AUTH_REFRESH_TOKEN="$(jq -r '.tokens.refresh_token // empty' "${HOME}/.codex/auth.json")" + CODEX_AUTH_ACCOUNT_ID="$(jq -r '.tokens.account_id // empty' "${HOME}/.codex/auth.json")" + [[ -n "$CODEX_AUTH_ACCESS_TOKEN" ]] || fail "Codex sign-in is missing an access token; run: codex login" + [[ -n "$CODEX_AUTH_REFRESH_TOKEN" ]] || fail "Codex sign-in is missing a refresh token; run: codex login" + [[ -n "$CODEX_AUTH_ACCOUNT_ID" ]] || fail "Codex sign-in is missing an account id; run: codex login" +} + +validate_env() { + require_command curl + require_command jq + require_command "$OPENSHELL_BIN" + + [[ -f "$POLICY_TEMPLATE" ]] || fail "missing policy template: $POLICY_TEMPLATE" + [[ -f "$TASK_TEMPLATE" ]] || fail "missing agent task template: $TASK_TEMPLATE" + [[ -f "$SANDBOX_AGENT" ]] || fail "missing sandbox agent script: $SANDBOX_AGENT" + + [[ "$DEMO_GITHUB_REPO" =~ ^[A-Za-z0-9_.-]+$ ]] || fail "DEMO_GITHUB_REPO contains unsupported characters" + [[ "$DEMO_BRANCH" =~ ^[A-Za-z0-9._/-]+$ ]] || fail "DEMO_BRANCH contains unsupported characters" + [[ "$DEMO_RUN_ID" =~ ^[A-Za-z0-9_.-]+$ ]] || fail "DEMO_RUN_ID contains unsupported characters" + # DEMO_FILE_DIR is interpolated through `sed` with `|` as the delimiter + # when rendering the agent task; reject any character that would break + # the substitution or escape into a shell context. + [[ "$DEMO_FILE_DIR" =~ ^[A-Za-z0-9._/-]+$ ]] || fail "DEMO_FILE_DIR contains unsupported characters" + + resolve_github_owner + [[ "$DEMO_GITHUB_OWNER" =~ ^[A-Za-z0-9_.-]+$ ]] || fail "DEMO_GITHUB_OWNER contains unsupported characters" + + resolve_github_token + resolve_codex_auth +} + +github_api_status() { + local url="$1" body="$2" + curl -sS -o "$body" -w "%{http_code}" \ + -H "Accept: application/vnd.github+json" \ + -H "Authorization: Bearer ${DEMO_GITHUB_TOKEN}" \ + -H "X-GitHub-Api-Version: 2022-11-28" \ + "$url" +} + +check_gateway() { + local raw version + # `openshell status` colorizes labels with ANSI even when piped, so strip + # escapes before parsing. Use NO_COLOR as a belt-and-suspenders hint for + # libraries that respect it. + raw="$(NO_COLOR=1 "$OPENSHELL_BIN" status 2>/dev/null \ + | sed 's/\x1b\[[0-9;]*m//g')" + version="$(awk -F': *' '/Version:/ { print $2; exit }' <<<"$raw")" + [[ -n "$version" ]] \ + || fail "active OpenShell gateway is not reachable; start one with: openshell gateway start" + info "gateway: connected · ${version}" +} + +show_run_summary() { + step "Run summary" + printf " %-9s %s/%s\n" "repo:" "$DEMO_GITHUB_OWNER" "$DEMO_GITHUB_REPO" + printf " %-9s %s\n" "branch:" "$DEMO_BRANCH" + printf " %-9s %s\n" "target:" "$DEMO_FILE_PATH" + printf " %-9s %s\n" "sandbox:" "$DEMO_SANDBOX_NAME" +} + +check_github_access() { + local body status branch sha + body="${TMP_DIR}/github-repo.json" + status="$(github_api_status "https://api.github.com/repos/${DEMO_GITHUB_OWNER}/${DEMO_GITHUB_REPO}" "$body")" + if [[ "$status" != "200" ]]; then + info "${RED}Repo not found:${RESET} ${DEMO_GITHUB_OWNER}/${DEMO_GITHUB_REPO}" + info "Create a private scratch repo first, then re-run:" + info " ${DIM}gh repo create ${DEMO_GITHUB_OWNER}/${DEMO_GITHUB_REPO} --private --add-readme \\${RESET}" + info " ${DIM} --description 'OpenShell policy advisor demo scratch repo'${RESET}" + fail "GitHub returned HTTP $status for ${DEMO_GITHUB_OWNER}/${DEMO_GITHUB_REPO}" + fi + if jq -e '.permissions.push == false and .permissions.admin == false and .permissions.maintain == false' "$body" >/dev/null; then + fail "GitHub token does not have write access to ${DEMO_GITHUB_OWNER}/${DEMO_GITHUB_REPO}" + fi + + branch="$(jq -rn --arg v "$DEMO_BRANCH" '$v|@uri')" + body="${TMP_DIR}/github-branch.json" + status="$(github_api_status "https://api.github.com/repos/${DEMO_GITHUB_OWNER}/${DEMO_GITHUB_REPO}/branches/${branch}" "$body")" + [[ "$status" == "200" ]] || fail "GitHub returned HTTP $status for branch ${DEMO_BRANCH}" + sha="$(jq -r '.commit.sha[0:7]' "$body")" + info "github: ${DEMO_GITHUB_OWNER}/${DEMO_GITHUB_REPO} @ ${DEMO_BRANCH} (${sha})" + + body="${TMP_DIR}/github-target.json" + status="$(github_api_status "https://api.github.com/repos/${DEMO_GITHUB_OWNER}/${DEMO_GITHUB_REPO}/contents/${DEMO_FILE_PATH}?ref=${branch}" "$body")" + if [[ "$status" == "200" ]]; then + fail "demo output file already exists: ${DEMO_FILE_PATH}; choose a new DEMO_RUN_ID" + fi + [[ "$status" == "404" ]] || fail "GitHub returned HTTP $status while checking output path" +} + +render_payload() { + sed \ + -e "s|{{OWNER}}|${DEMO_GITHUB_OWNER}|g" \ + -e "s|{{REPO}}|${DEMO_GITHUB_REPO}|g" \ + -e "s|{{BRANCH}}|${DEMO_BRANCH}|g" \ + -e "s|{{FILE_PATH}}|${DEMO_FILE_PATH}|g" \ + -e "s|{{RUN_ID}}|${DEMO_RUN_ID}|g" \ + "$TASK_TEMPLATE" > "${PAYLOAD_DIR}/agent-task.md" + cp "$SANDBOX_AGENT" "${PAYLOAD_DIR}/sandbox-agent.sh" + cp "$POLICY_TEMPLATE" "$POLICY_FILE" +} + +create_providers() { + "$OPENSHELL_BIN" provider delete "$DEMO_CODEX_PROVIDER_NAME" >/dev/null 2>&1 || true + "$OPENSHELL_BIN" provider delete "$DEMO_GITHUB_PROVIDER_NAME" >/dev/null 2>&1 || true + + "$OPENSHELL_BIN" provider create \ + --name "$DEMO_CODEX_PROVIDER_NAME" \ + --type generic \ + --credential CODEX_AUTH_ACCESS_TOKEN \ + --credential CODEX_AUTH_REFRESH_TOKEN \ + --credential CODEX_AUTH_ACCOUNT_ID >/dev/null + + "$OPENSHELL_BIN" provider create \ + --name "$DEMO_GITHUB_PROVIDER_NAME" \ + --type generic \ + --credential DEMO_GITHUB_TOKEN >/dev/null + + info "providers created (codex, github) — credentials injected as env vars only" +} + +start_agent_sandbox() { + step "Launching sandbox; agent will hit a policy block and draft a proposal" + "$OPENSHELL_BIN" sandbox delete "$DEMO_SANDBOX_NAME" >/dev/null 2>&1 || true + + info "initial policy: read-only access to api.github.com (no PUT)" + info "agent task: PUT /repos/${DEMO_GITHUB_OWNER}/${DEMO_GITHUB_REPO}/contents/${DEMO_FILE_PATH}" + info "live log: ${AGENT_LOG}" + + # `--upload :/sandbox` preserves the source directory basename + # (matches `scp -r`/`cp -r`, see PRs #952 / #1028), so `${PAYLOAD_DIR}` + # (basename `payload`) lands at `/sandbox/payload/...`. `--upload` accepts + # a single value, so we ship both files in one directory. + ( + "$OPENSHELL_BIN" sandbox create \ + --name "$DEMO_SANDBOX_NAME" \ + --from base \ + --provider "$DEMO_CODEX_PROVIDER_NAME" \ + --provider "$DEMO_GITHUB_PROVIDER_NAME" \ + --policy "$POLICY_FILE" \ + --upload "${PAYLOAD_DIR}:/sandbox" \ + --no-git-ignore \ + --no-auto-providers \ + --no-tty \ + -- bash /sandbox/payload/sandbox-agent.sh + ) >"$AGENT_LOG" 2>&1 & + AGENT_PID="$!" +} + +# Strip the rule_get output down to the lines a developer needs to make an +# informed approve/reject decision: rationale, binary, endpoint. Filters the +# noisy fields (UUID, agent-generated rule_name, hardcoded confidence, +# duplicate Binaries) until `openshell rule get` learns to print L7 +# method/path itself (tracked separately). +# +# `openshell rule get` colorizes labels with ANSI escapes; strip them before +# parsing so the field-name match works in piped contexts. +summarize_pending() { + local pending="$1" + sed 's/\x1b\[[0-9;]*m//g' "$pending" \ + | awk ' + /Rationale:/ { sub(/^[[:space:]]*/, ""); print " " $0; next } + /Binary:/ { sub(/^[[:space:]]*/, ""); print " " $0; next } + /Endpoints:/ { sub(/^[[:space:]]*/, ""); print " " $0; next } + ' +} + +narrate_sandbox_workflow() { + info "Inside the sandbox right now:" + info "" + info " ${BOLD}[1]${RESET} agent: ${DIM}curl -X PUT https://api.github.com/repos/${DEMO_GITHUB_OWNER}/${DEMO_GITHUB_REPO}/contents/...${RESET}" + info " ${BOLD}[2]${RESET} L7 proxy denies the write and returns a structured 403 the" + info " agent can parse and act on:" + cat </dev/null 2>&1; then + wait "$AGENT_PID" || true + AGENT_PID="" + fail "agent exited before a pending proposal appeared" + fi + + if "$OPENSHELL_BIN" rule get "$DEMO_SANDBOX_NAME" --status pending >"$pending" 2>/dev/null \ + && grep -q "Chunk:" "$pending" && grep -q "pending" "$pending"; then + info "" + info "${GREEN}proposal received:${RESET}" + summarize_pending "$pending" + + step "Approving and waiting for the agent to retry" + "$OPENSHELL_BIN" rule approve-all "$DEMO_SANDBOX_NAME" \ + | awk '/approved/ { print " " $0 }' + return + fi + + now="$(date +%s)" + if (( now - start >= DEMO_APPROVAL_TIMEOUT_SECS )); then + fail "timed out waiting for the agent to submit a policy proposal" + fi + sleep 2 + done +} + +wait_for_agent() { + if ! wait "$AGENT_PID"; then + AGENT_PID="" + fail "agent run failed" + fi + AGENT_PID="" + info "agent retried after policy hot-reload — write succeeded" +} + +verify_github_write() { + step "Verifying GitHub write" + local body status branch + branch="$(jq -rn --arg v "$DEMO_BRANCH" '$v|@uri')" + body="${TMP_DIR}/github-result.json" + status="$(github_api_status "https://api.github.com/repos/${DEMO_GITHUB_OWNER}/${DEMO_GITHUB_REPO}/contents/${DEMO_FILE_PATH}?ref=${branch}" "$body")" + [[ "$status" == "200" ]] || fail "expected demo file to exist after agent run; GitHub returned HTTP $status" + jq -r '" file: \(.path)", " url: \(.html_url)"' "$body" +} + +# Print the OCSF JSONL trace, filtered to the three events that *are* the +# demo's story: the L7 PUT deny, the policy hot-reload, and the L7 PUT allow. +# The native OCSF shorthand is informative and consistent with the rest of +# OpenShell's logging — keep it as-is rather than re-formatting. +show_logs() { + step "Policy decision trace (OCSF)" + "$OPENSHELL_BIN" logs "$DEMO_SANDBOX_NAME" --since 10m -n 200 2>&1 \ + | grep -E 'HTTP:PUT.*(DENIED|ALLOWED)|CONFIG:LOADED.*Policy reloaded' \ + | sed 's/^/ /' || true +} + +enable_agent_proposals() { + # The agent-driven proposal surface (skill, policy.local routes, deny + # next_steps) is opt-in. Snapshot the prior global value so cleanup() + # can restore it; the sentinel "(unset)" round-trips through `settings + # delete` rather than a value write. + local prior + prior="$("$OPENSHELL_BIN" settings get --global --json 2>/dev/null \ + | grep -o '"agent_policy_proposals_enabled"[^,}]*' \ + | grep -o 'true\|false' | head -1)" + PRIOR_PROPOSALS_FLAG="${prior:-(unset)}" + "$OPENSHELL_BIN" settings set --global \ + --key agent_policy_proposals_enabled --value true >/dev/null \ + || fail "could not enable agent_policy_proposals_enabled globally" +} + +main() { + validate_env + + step "Preflight" + check_gateway + check_github_access + render_payload + create_providers + enable_agent_proposals + + show_run_summary + + start_agent_sandbox + approve_when_pending + wait_for_agent + verify_github_write + show_logs + + printf "\n${BOLD}${GREEN}✓ Demo complete.${RESET}\n" +} + +main "$@" diff --git a/examples/agent-driven-policy-management/policy.template.yaml b/examples/agent-driven-policy-management/policy.template.yaml new file mode 100644 index 000000000..e920277b5 --- /dev/null +++ b/examples/agent-driven-policy-management/policy.template.yaml @@ -0,0 +1,69 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +# Initial sandbox policy for the agent-driven policy demo. +# +# The agent inside the sandbox can: +# - reach Codex's model and auth endpoints (codex) +# - clone Codex plugin repos read-only (codex_plugins) +# - read api.github.com via curl (github_api_readonly) +# +# The agent CANNOT write to GitHub yet. That's the proposal it has to draft +# and ask the developer to approve. + +version: 1 + +filesystem_policy: + include_workdir: true + read_only: [/usr, /lib, /proc, /dev/urandom, /app, /etc, /var/log] + read_write: [/sandbox, /tmp, /dev/null] + +landlock: + compatibility: best_effort + +process: + run_as_user: sandbox + run_as_group: sandbox + +network_policies: + codex: + name: codex + endpoints: + - { host: api.openai.com, port: 443, protocol: rest, enforcement: enforce, access: full } + - { host: auth.openai.com, port: 443, protocol: rest, enforcement: enforce, access: full } + - { host: chatgpt.com, port: 443, protocol: rest, enforcement: enforce, access: full } + - { host: ab.chatgpt.com, port: 443, protocol: rest, enforcement: enforce, access: full } + binaries: + - { path: /usr/bin/codex } + - { path: /usr/bin/node } + - { path: "/usr/lib/node_modules/@openai/**" } + + codex_plugins: + name: codex-plugins + endpoints: + - host: github.com + port: 443 + protocol: rest + enforcement: enforce + rules: + - allow: + method: GET + path: "/openai/plugins.git/info/refs*" + - allow: + method: POST + path: "/openai/plugins.git/git-upload-pack" + binaries: + - { path: /usr/bin/git } + - { path: /usr/lib/git-core/git-remote-http } + - { path: "/usr/lib/node_modules/@openai/**" } + + github_api_readonly: + name: github-api-readonly + endpoints: + - host: api.github.com + port: 443 + protocol: rest + enforcement: enforce + access: read-only + binaries: + - { path: /usr/bin/curl } diff --git a/examples/agent-driven-policy-management/sandbox-agent.sh b/examples/agent-driven-policy-management/sandbox-agent.sh new file mode 100755 index 000000000..052535c35 --- /dev/null +++ b/examples/agent-driven-policy-management/sandbox-agent.sh @@ -0,0 +1,82 @@ +#!/usr/bin/env bash + +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +# Runs inside the sandbox. Bootstraps Codex from the credentials injected by +# the openshell provider, then drives the agent-task prompt to completion. + +set -euo pipefail + +require_env() { + local name="$1" + [[ -n "${!name:-}" ]] || { echo "missing required env: $name" >&2; exit 1; } +} + +require_env CODEX_AUTH_ACCESS_TOKEN +require_env CODEX_AUTH_REFRESH_TOKEN +require_env CODEX_AUTH_ACCOUNT_ID +require_env DEMO_GITHUB_TOKEN + +# Make the GitHub token visible to Codex's tool loop under the conventional name. +export GITHUB_TOKEN="$DEMO_GITHUB_TOKEN" + +# Codex looks for ~/.codex/auth.json. The OpenShell provider only injects env +# vars, so we materialize the file Codex expects from those credentials. +mkdir -p "$HOME/.codex" +node - <<'NODE' +const fs = require("fs"); +const path = `${process.env.HOME}/.codex/auth.json`; +const b64u = (obj) => Buffer.from(JSON.stringify(obj)).toString("base64url"); +const now = Math.floor(Date.now() / 1000); +// Placeholder id_token is required by Codex but never validated against an +// upstream JWKS in this flow. +const idToken = [ + b64u({ alg: "none", typ: "JWT" }), + b64u({ + iss: "https://auth.openai.com", + aud: "codex", + sub: "openshell-policy-demo", + email: "demo@openshell.local", + iat: now, + exp: now + 3600, + }), + "placeholder", +].join("."); +fs.writeFileSync(path, JSON.stringify({ + auth_mode: "chatgpt", + OPENAI_API_KEY: null, + tokens: { + id_token: idToken, + access_token: process.env.CODEX_AUTH_ACCESS_TOKEN, + refresh_token: process.env.CODEX_AUTH_REFRESH_TOKEN, + account_id: process.env.CODEX_AUTH_ACCOUNT_ID, + }, + last_refresh: new Date().toISOString(), +}, null, 2)); +NODE +chmod 600 "$HOME/.codex/auth.json" + +# Codex needs a writable cwd; /sandbox is uploaded read-only-ish, so work in /tmp. +WORK="$(mktemp -d)" +cd "$WORK" + +# Disable Codex's internal bubblewrap sandbox — OpenShell is already the +# security boundary, and bwrap can't create nested user namespaces inside the +# OpenShell sandbox container without extra capabilities. The "danger" framing +# is from Codex's perspective on a developer host; here the OpenShell network +# policy and filesystem constraints are doing the actual containment. +# +# Cap Codex's reasoning effort at the lower end. The demo task is mechanical +# (one HTTP request, parse a structured 403, post a JSON proposal, retry); the +# default high-effort reasoning roughly doubles the demo's wall time without +# improving outcomes. Override with DEMO_CODEX_REASONING if you want to +# compare runs. +DEMO_CODEX_REASONING="${DEMO_CODEX_REASONING:-low}" + +exec codex exec \ + --skip-git-repo-check \ + --sandbox danger-full-access \ + --ephemeral \ + -c "model_reasoning_effort=\"${DEMO_CODEX_REASONING}\"" \ + "$(cat /sandbox/payload/agent-task.md)" diff --git a/rfc/0001-agent-driven-policy-management.md b/rfc/0001-agent-driven-policy-management.md new file mode 100644 index 000000000..6816d1331 --- /dev/null +++ b/rfc/0001-agent-driven-policy-management.md @@ -0,0 +1,723 @@ +--- +authors: + - "@alwatson" +state: draft +links: + - https://github.com/NVIDIA/OpenShell/issues/1062 + - https://github.com/NVIDIA/OpenShell/blob/main/architecture/policy-advisor.md +--- + +# RFC 0001 - Agent-Driven Policy Management + + + +## Summary + +Evolve OpenShell's existing Policy Advisor into an agent-driven policy management system that lets agents inspect current sandbox policy, draft narrow policy changes, submit them for review, and apply approved updates without restarting the sandbox. The safety model stays the same: sandbox-side analysis, gateway-side validation and persistence, and explicit approval boundaries. The main change is the authoring and review experience: every sandbox should expose local policy guidance and APIs, and every developer surface should expose a responsive inbox for reviewing proposals. + +## Motivation + +OpenShell already has the core of a dynamic policy editing experience: + +- The sandbox proxy emits deny events. +- The sandbox-side `DenialAggregator` and mechanistic mapper convert those into draft `PolicyChunk` proposals. +- The gateway persists proposals and merges approved rules into the active policy. +- The TUI and CLI already provide review and approval flows. +- Running sandboxes already hot-reload dynamic policy updates. + +That is a strong foundation, but the current experience is still fundamentally operator-driven and network-centric. It is excellent for "observe a deny, approve a generated endpoint rule" but incomplete for the broader product promise: an agent should be able to understand what is blocked, discover what policy language is available, generate the narrowest valid policy change, and submit it to the developer with enough rationale and verification signal that approval is fast and trustworthy. + +This matters because: + +- developers should not need to learn policy syntax before becoming productive +- agents have the most task context and can often draft narrower changes than humans +- approvals should feel like reviewing a validated outcome, not guessing about a YAML diff +- the inbox experience must be fast and clear across TUI, CLI, and SDK surfaces +- organizations need a path from human approval to trusted bounded automation without losing auditability or least privilege + +This RFC proposes the next layer: make policy adaptation an intentional, agent-native workflow instead of a reactive operator convenience. + +## MVP implementation note + +The first implementation is tracked in [#1062](https://github.com/NVIDIA/OpenShell/issues/1062). It intentionally starts with the smallest agent-driven loop that can validate the product experience: + +- structured L7 REST deny responses for agent-readable failures +- a sandbox-local `policy.local` HTTP API backed by existing files, logs, and per-sandbox mTLS gateway calls +- static sandbox-local agent guidance in `/etc/openshell/skills/policy_advisor.md` +- agent-authored proposal provenance, validation status, and rejection guidance in the existing draft policy flow +- TUI/CLI review for a single sandbox, with polling as the MVP refresh path + +The MVP deliberately defers the supervisor Unix-socket API, server-streaming multi-sandbox inbox, Slack/web adapters, org ceilings, trusted auto-apply, and in-process prover optimization. Those remain aligned with the RFC direction, but they are not required to prove the initial loop. + +The entire MVP surface is gated behind the `agent_policy_proposals_enabled` setting (see `crates/openshell-core/src/settings.rs`), default false. When disabled, the supervisor does not install the skill, the `policy.local` routes return `404 feature_disabled`, and L7 deny bodies omit the `next_steps` array. The flag is independent of the per-proposal developer approval gate; both apply when the feature is on. Treat this as a soft launch: enable per-sandbox or globally once the loop is validated, and leave it off in environments where agent-authored proposals should not be available at all. + +## Non-goals + +- Allowing an in-sandbox agent to self-approve or unilaterally apply its own policy changes. +- Moving proposal generation into the gateway. Sandbox-side analysis remains the architectural default. +- Solving every policy domain in the first release. Network policy is the initial scope because it is the only hot-reloadable policy domain in the current architecture; filesystem and process policy can follow later through a different lifecycle model. +- Replacing the existing mechanistic mapper. It remains the deterministic baseline and safety net. +- Making Rego authoring a direct end-user requirement. The system should expose policy semantics to agents and advanced users, not require hand-authored policy for common workflows. + +## Proposal + +### Enforcement model + +This RFC is not proposing a generic "policy update" system without specifying what gets enforced. The intended model is layered: + +- **L4 remains the universal baseline** + Every outbound connection is gated by host, port, and binary identity. +- **L7 is the preferred least-privilege model for supported application protocols** + Today this primarily means `protocol: rest` with per-method and per-path rules for HTTP APIs. +- **Protocol-aware or tool-aware policy layers may sit above L7 where useful** + MCP is a strong candidate for a future higher-level enforcement surface, but it should be modeled explicitly rather than implied. + +For the initial implementation of this RFC, dynamic policy management should be grounded in the enforcement model OpenShell already has in the codebase today: + +- L4 network policy for all outbound traffic +- L7 REST enforcement for HTTP APIs where `protocol: rest` is configured +- policy prover checks that can distinguish L4-only access from L7-enforced access + +This matters because "allow GitHub" is not a single thing: + +- `github.com:443` used by `git` may require L4-only allowance depending on the workflow and protocol behavior +- `api.github.com:443` used by `gh`, `curl`, or an SDK is often a great fit for L7 REST controls +- the best least-privilege design often splits those paths rather than treating GitHub as one broad capability + +This RFC is therefore network-first by design, not because other policy domains are unimportant, but because the current OpenShell architecture only supports live mutation of `network_policies`. Filesystem, Landlock, and process settings are applied at sandbox startup and are currently immutable for the lifetime of a sandbox. + +### Product direction + +Every OpenShell sandbox should be able to host an agent-capable policy workflow with four core affordances: + +1. A local capability description that teaches an agent how to inspect current policy state, understand the available policy language, and submit a proposal for review. +2. A sandbox-local or supervisor-adjacent API for reading effective policy, recent denials, sandbox-local activity logs, and proposal state. +3. A gateway-managed developer inbox for reviewing, editing, approving, rejecting, and auditing proposals in real time. +4. A validation pipeline that checks proposed policy changes before they are applied. + +The product bar is not just correctness. The interaction model itself must be good: + +- proposals should appear quickly after a deny or agent request +- review surfaces should be understandable without policy expertise +- the same proposal should look coherent in the TUI, CLI, and SDKs +- approval should take one action when the system has high confidence +- high-volume exploratory agent workflows should not drown the user in repetitive prompts + +### UX requirements and latency targets + +For the developer inbox experience: + +- OpenShell **must provide a push/subscription path** for proposal and decision updates to the TUI, CLI, and SDKs. +- Polling may exist as a fallback, but polling-only delivery is not sufficient for the intended UX. + +Target UX metrics: + +- **Proposal appearance latency** + From the time the gateway accepts a proposal or actionable deny-derived recommendation to the time it appears in a connected inbox client: + - target `p50 <= 2s` + - target `p95 <= 5s` +- **Decision propagation latency** + From approval, rejection, edit, or auto-apply at the gateway to the time all connected inbox clients reflect the new state: + - target `p50 <= 1s` + - target `p95 <= 3s` +- **Activation feedback latency** + From gateway receipt of sandbox policy status (`loaded` or `failed`) to visible client state update: + - target `p50 <= 1s` + - target `p95 <= 3s` + +If sandbox policy activation takes longer than these targets, the inbox should still update immediately with an intermediate state such as `pending_activation` rather than leaving the user uncertain. + +The desired user experience: + +1. An agent encounters a deny with a structured explanation from the sandbox supervisor. +2. The agent uses a local policy-management skill to inspect the current effective policy, denial context, and relevant policy primitives. +3. The agent produces a minimal proposed change and submits it through a stable proposal API. +4. The developer sees the proposal in the TUI, CLI, or SDK, reviews its rationale and validation results, and approves or rejects it. +5. OpenShell applies the approved change as a hot-reloaded policy update and preserves a durable audit trail. + +### What exists today + +The RFC explicitly builds on the current codebase rather than replacing it. + +Current implementation points: + +- `crates/openshell-sandbox/src/denial_aggregator.rs` + Sandbox-side aggregation of deny events. +- `crates/openshell-sandbox/src/mechanistic_mapper.rs` + Deterministic generation of draft `PolicyChunk` recommendations, including partial L7 support. +- `crates/openshell-server/src/grpc/policy.rs` + Persistence, approval, merge, rejection, edit, undo, and policy revision handling. +- `crates/openshell-tui/src/ui/sandbox_draft.rs` + TUI review and approval surface for network rules. +- `crates/openshell-cli/src/run.rs` + `openshell rule get|approve|reject|approve-all|history`. +- `architecture/policy-advisor.md` + Current sandbox-side recommendation design. + +Important capabilities that already exist and should be preserved: + +- sandbox-side proposal generation +- hot-reloadable policy updates +- proposal editing and undo RPCs in `proto/openshell.proto` +- durable draft chunk storage and approval history +- a deterministic, mechanistic proposal path that does not require an LLM +- a real distinction between L4-only and L7 REST enforcement + +### What is missing + +The current implementation lacks several parts required for the intended developer experience: + +- A standard in-sandbox skill or instruction bundle for local agents. +- A first-class proposal API that agents can use intentionally, not only through deny-triggered analysis. +- Rich proposal context beyond host/port/binary, especially for developer intent, repository/task context, and write operations. +- Validation outputs that explain what a proposal would permit before approval. +- A generalized "developer inbox" model that can power the TUI, CLI, SDK, and future Slack/web surfaces from the same backend abstraction. +- A clear separation between: + - observed deny events, + - agent-authored policy changes, + - validated approval-ready proposals, + - applied policy revisions. +- A trust model for non-human approvers, where a trusted external agent may apply policy changes automatically when those changes remain within an organization-defined maximum policy envelope. +- Explicit proposal semantics for whether a recommendation is: + - L4-only + - L7 REST + - a conversion from L4 to L7 + - a future protocol-aware policy type such as MCP-aware controls + +### Architecture + +```mermaid +flowchart LR + AGENT["Agent in sandbox"] --> SKILL["Local policy skill / instructions"] + SKILL --> API["Supervisor policy API"] + API --> STATE["Effective policy + deny history + schema help"] + AGENT --> PROPOSE["Submit proposal"] + PROPOSE --> GW["Gateway proposal service"] + DENY["Proxy deny + L7 deny"] --> AGG["Sandbox aggregator + mechanistic mapper"] + AGG --> GW + GW --> VALIDATE["Validation + simulation + prover"] + VALIDATE --> INBOX["Developer inbox"] + INBOX -->|approve| MERGE["Policy merge + revision"] + INBOX -->|reject/edit| GW + MERGE --> POLL["Sandbox policy poll / push"] + POLL --> API +``` + +The important architectural principle is that the current Policy Advisor pipeline becomes one producer of proposals, not the only producer. Agent-authored proposals and mechanistic proposals should land in the same gateway inbox and go through the same validation and approval machinery. + +The second architectural principle is that approval is policy-driven. Human approval is the default mode, but the same machinery must also support a trusted external control plane deciding that a proposal is safe to auto-apply because it fits under higher-level organizational constraints. + +The end-to-end interaction should look like this: + +```mermaid +sequenceDiagram + participant A as Agent in Sandbox + participant S as Local Skill + participant U as policy.local API + participant P as Local Prover Aid + participant G as Gateway Proposal Service + participant X as External Validator / Trusted Approver + participant I as Developer Inbox + + A->>S: read policy skill / instructions + A->>U: get-effective-policy, get-recent-denials + U-->>A: effective policy + deny context + A->>P: evaluate candidate policy change + P-->>A: advisory narrowing feedback + A->>U: submit proposal + U->>G: forward proposal + provenance + G->>X: validate against policy + org ceiling + X-->>G: human_review_required or eligible_for_automatic_apply + G-->>I: publish proposal / status update + alt human review + I->>G: approve / reject / reject_with_guidance + G-->>U: proposal result + U-->>A: optional structured feedback + else trusted external auto-apply + X->>G: apply approved change + G-->>I: auto-applied status + audit trail + end +``` + +### Core design concepts + +#### 1. Policy workspace inside every sandbox + +Every default sandbox should expose a stable local "policy workspace" that agents can rely on. This is not necessarily a mutable file checked into the user's repo; it is a capability surface available from inside the sandbox. + +The workspace should provide: + +- Effective sandbox policy, already resolved to the currently active version. +- Policy schema and examples relevant to the sandbox's supported policy features. +- A description of approval workflow semantics. +- Recent denials, sandbox-local activity logs, and related proposal history. +- Guidance for generating the narrowest change possible. + +The first implementation can be a static `/etc/openshell/skills/policy_advisor.md` plus a sandbox-local `policy.local` HTTP API. The long-term contract is the API; the skill is the ergonomic on-ramp. MCP can wrap this API later for agents that benefit from tool discovery, but it should not be the first load-bearing protocol or a separate implementation path. + +The sandbox-facing surface must also have an explicit information boundary: + +Visible to the in-sandbox agent: + +- the current effective sandbox policy +- deny reasons, sandbox-local logs, and local proposal history for that sandbox +- supported policy primitives and examples +- coarse proposal outcomes such as `needs_human_review`, `eligible_for_external_review`, or `rejected_for_safety` + +Not exposed directly to the in-sandbox agent: + +- the full org ceiling policy document +- trusted external approver policies or scoring heuristics +- hidden approval thresholds +- policy state for other sandboxes, users, or org workflows + +Ceiling evaluation, trust decisions, and automatic-apply eligibility should be computed outside the sandbox. + +OpenShell should also make prover-backed analysis available to agents as a proposal-authoring aid. This lets an agent test whether a candidate policy is broader than necessary and iterate toward a narrower proposal before submission. Inside the sandbox, prover output is advisory and used for drafting only. Approval and automatic-apply decisions must still rely on external validation in the gateway or trusted approver path. + +#### 2. Intentional proposal submission API + +Add a supervisor-facing proposal API so an agent can submit policy updates even before or after a deny aggregation cycle. + +Example agent actions: + +- "Show me the current policy affecting `git`." +- "Explain why this GitHub push was denied." +- "Draft the minimal rule to allow writes to `github.com` and `api.github.com` for `git` only." +- "Submit this proposal for human review." + +This proposal path should support two modes: + +- `draft_from_observation` + Builds on real deny history. +- `draft_from_agent_intent` + Allows an agent to proactively request a change based on planned work. + +Both should land in the same inbox with provenance captured. + +When multiple producers submit effectively the same proposal, the gateway should apply a deterministic merge policy: + +- mechanistic proposals establish the baseline proposal record +- richer agent-authored proposals for the same sandbox + endpoint + binary may upgrade the existing record's rationale, context, and proposed L7 refinement +- fallback observation updates may continue to bump hit counts and timestamps without discarding richer metadata + +The important product requirement is that a richer agent proposal must not be silently lost behind an earlier mechanistic proposal. + +#### 3. Proposal model evolution + +Extend the existing `PolicyChunk`/draft-chunk model into a more expressive proposal object while preserving backward compatibility for current rule review commands. + +Additional fields should include: + +- Proposal source: mechanistic, agent-authored, or hybrid. +- Requested capability summary in plain language. +- Validation status and findings summary. +- Diff against current effective policy. +- Enforcement layer for first-release proposal types: + - `l4` + - `l7_rest` +- Intended scope: + - endpoint-only + - L7 method/path + - binary restriction + - time-bounded or session-bounded, if supported later +- Optional task context: + - repo URL + - issue/RFC reference + - command or tool that triggered the need + +The inbox should make it obvious whether a proposal is an L4 tunnel, an L7 REST rule, or a conversion from broad access to narrower L7 controls. + +Future protocol-aware proposal kinds such as MCP-aware controls should extend the model later rather than forcing the first-release schema to generalize prematurely. + +#### 4. Validation before approval + +Approval should present validated consequences, not just a proposed rule. + +Validation stages: + +1. Schema and static safety validation. +2. Deterministic simulation: + - what new hosts, ports, methods, or binaries would become reachable + - whether the change overlaps or broadens an existing rule + - whether the proposal is L4-only or protected by L7 enforcement +3. Policy-specific safety checks: + - always-blocked destinations + - suspicious private IP overrides + - wildcard or full-access expansions + - binaries or protocols that bypass L7 inspection +4. Formal verification when supported: + - use the existing prover infrastructure to check that the proposal satisfies a declared intent and does not exceed it + +The validator should emit an approval summary such as: + +- "Allows `git` to `github.com:443` and `api.github.com:443`." +- "Does not grant access to other GitHub hosts." +- "Adds write-capable REST paths for repo push semantics." +- "Touches only dynamic network policy." +- "This change is L4-only and does not provide method/path restriction." +- "This change upgrades the endpoint to L7 REST enforcement." + +Validation should also support two decision modes: + +- `human_review_required` + The proposal is shown in the developer inbox for explicit approval. +- `eligible_for_automatic_apply` + The proposal remains within a trusted approval envelope and may be applied automatically by policy. + +For first release, the recommended automatic-apply scope is intentionally narrow: + +- trusted external approver only +- network policy only +- L7 REST preferred where supported +- ephemeral lease durability by default +- only when prover, validation, and org ceiling checks succeed without ambiguity + +#### 4a. Structured deny feedback + +Denied operations should not only appear in logs and inboxes. OpenShell should also provide a structured deny feedback path that helps the in-sandbox agent recover intelligently by returning: + +- a machine-readable explanation of what was denied +- the relevant enforcement layer (`l4` or `l7_rest`) +- the reason the current policy did not allow it +- a pointer to the local policy workspace/API for inspection and proposal drafting + +The delivery mechanism may vary, but the RFC requires this to be a first-class capability rather than only an operator-facing side effect. + +#### 5. Unified developer inbox + +The existing draft-chunk review surface should become a generalized developer inbox with: + +- Real-time updates from the gateway. +- Filterable by sandbox, status, source, severity, and validation state. +- Renderable in: + - TUI + - CLI + - SDK/API + - future Slack/web integrations +- Support for: + - approve + - reject + - reject with guidance + - edit + - bulk approve with safeguards + - undo + - audit/history inspection + +The current TUI "Network Rules" panel is the correct seed, but the mental model should shift from "network rules list" to "policy proposal inbox." + +To support the UX targets above, the inbox architecture should include a subscription mechanism from the gateway to clients, such as streaming gRPC, SSE, or an equivalent event feed. The exact transport can be implementation-specific, but the user-visible behavior should be push-first. + +Rejection should be part of a revise-and-resubmit loop rather than a dead end. Operators should be able to reject a proposal with explanation so the agent can draft a narrower or corrected follow-up without requiring the operator to hand-author the policy change themselves. + +#### 6. L7-first agent experience + +A major product requirement is enabling strong default sandboxes with granular approval flows, especially for APIs like GitHub: + +- The default sandbox permits read-only GitHub API access via L7 policy. +- An agent attempts a write operation. +- The sandbox returns a structured deny that tells the agent: + - what was blocked, + - what part of policy caused the denial, + - how to inspect current policy, + - how to submit a narrow proposal. +- The agent proposes the smallest change needed for the target repo/workflow. +- The developer reviews a proposal phrased in task terms, not raw YAML only. + +OpenShell should explicitly steer the system toward the narrowest viable enforcement level: + +- prefer L7 REST rules for HTTP APIs such as GitHub, LinkedIn, X, Slack, Jira, and similar services +- fall back to L4 only when the protocol or client behavior prevents meaningful L7 enforcement +- tell the developer when a proposal is broad because the workload itself is broad, not because the system failed to model it precisely + +### REST, L4, and MCP + +REST APIs are the clearest near-term least-privilege win because OpenShell already supports `protocol: rest`, access presets, explicit method/path rules, TLS termination, and prover logic that can distinguish L4-only access from L7 write exposure. L7 REST should therefore be the default recommendation path for HTTP APIs, while L4-only proposals remain available for non-HTTP or opaque clients and should be clearly marked as broader access. MCP remains strategically important, but it should not drive the first-release schema: remote MCP still rests on transport controls such as HTTP/SSE/WebSocket, while local stdio MCP does not map neatly to network enforcement. The near-term plan is simple: **Phase 1-4 focus on L4 + L7 REST policy management; MCP-aware controls land as a later dedicated track.** + +#### 7. Trusted external approvers and policy ceilings + +Human approval should remain the default, but the system should also support a second mode where a trusted agent outside the sandbox can approve and apply changes automatically on behalf of the user when: + +- the organization defines an immutable high-level policy ceiling +- the sandbox policy starts below that ceiling +- the agent proposes a narrower incremental change needed to complete a task +- the prover and policy validator can show that the change stays within the allowed envelope + +In this model: + +- the org-level ceiling acts as a non-bypassable maximum +- sandbox policy revisions can expand only within that ceiling +- a trusted external agent or control-plane service may auto-apply compliant changes +- every request, validation result, and applied revision is logged for audit + +This gives OpenShell a path to adaptive least privilege without forcing a human to approve every safe change in real time. + +### Trust and approval model + +OpenShell should support at least three approval modes: + +1. `human_in_the_loop` + Every proposal requires explicit user approval. +2. `trusted_agent_within_ceiling` + A trusted external agent may apply changes automatically when validation and prover checks confirm the proposal stays within an org or user-defined maximum. +3. `manual_only_locked_down` + No automatic apply; some proposals may be visible but categorically blocked from execution by policy. + +The RFC does not propose allowing an in-sandbox agent to self-approve its own policy requests. Trusted external auto-apply is **in scope**, but it is distinct from autonomous in-sandbox mutation. The minimum shippable baseline is still a strong human-in-the-loop workflow. + +### Organizational policy layering + +This RFC assumes policy layering rather than a single mutable document: + +- `org ceiling policy` + The maximum capability envelope defined by security or platform teams. +- `sandbox effective policy` + The currently active policy for a sandbox, always a subset of the org ceiling when one exists. +- `proposal diff` + The incremental change requested by an agent or generated from deny analysis. + +For a proposal to be auto-applied, it must satisfy all of: + +1. valid OpenShell policy schema and merge semantics +2. no violation of always-blocked destinations or other hard safety rules +3. no violation of org ceiling constraints +4. successful prover or simulation checks against declared assumptions +5. successful audit logging and attribution + +If any check fails, the proposal falls back to human review or outright rejection. + +### Durability model + +Policy changes should not all have the same lifecycle. This RFC proposes three durability classes: + +1. `ephemeral_lease` + A time-bounded grant that expires automatically unless renewed. This is the recommended default for automatically applied expansions. +2. `sandbox_durable` + A durable revision for a specific sandbox or long-lived workflow. Suitable for human-approved changes or explicit promotion from a lease. +3. `promoted_policy_artifact` + A reusable policy artifact intended for future sandboxes, templates, or org-managed defaults. + +Recommended defaults: + +- auto-applied trusted-agent changes should start as `ephemeral_lease` unless explicitly promoted +- human-approved changes may become `sandbox_durable` directly when the reviewer intends lasting behavior +- promotion into reusable artifacts should be a deliberate step + +### Reject with guidance + +Operators should be able to do more than approve or reject. The system should support a guided rejection path: + +- `approve` + Accept and apply the proposal. +- `reject` + Decline the proposal without expecting an immediate follow-up. +- `reject_with_guidance` + Decline the proposal while returning operator guidance that the agent can use to revise and resubmit. + +Guidance may include free-form explanation plus structured hints such as `too_broad`, `use_l7_not_l4`, `wrong_binary_scope`, `wrong_endpoint`, `needs_time_limit`, or `outside_org_ceiling`. + +### Example: trusted daily research workflow + +One motivating workflow is a recurring research task: search X and LinkedIn for posts about a topic, summarize the results, and email the summary to the user. In that flow, the sandbox may start with minimal permissions plus an email provider, then request new outbound access to X and LinkedIn. A trusted external policy agent can prefer L7 REST rules when possible and apply them automatically when they fit within the organization's permitted research ceiling. + +### API and component changes + +#### Sandbox supervisor + +Add a local policy interaction surface: + +- sandbox-local `policy.local` HTTP API +- optional future MCP wrapper backed by that API + +Representative operations: + +- read effective policy, recent denials, and sandbox-local activity logs +- inspect proposal guidance and current proposal state +- submit a policy proposal + +This surface must be readable by the agent but not self-approving. + +Phase 2 implementation decisions: + +- primary transport: sandbox-local HTTP JSON at `policy.local` +- ergonomic wrapper: defer MCP/CLI wrappers until the local API proves useful +- first trust model: the sandbox is treated as single-tenant, so local callers are part of the sandbox tenant; this does not grant approval rights +- first proposal format: reuse the `PolicyMergeOperation` shape behind `openshell policy update` inside a JSON request body; the supervisor/local service bundles those operations with intent, summary, and optional evidence references, sends them to the gateway over gRPC, and the gateway stores them as draft chunks for approval instead of applying them immediately + +#### Gateway / server + +Extend the gateway proposal service to support: + +- explicit agent-authored proposal submission +- richer proposal metadata +- validation result persistence +- inbox subscriptions for multiple frontends +- trusted approver identities and authorization policies +- automatic-apply decisions gated by org ceiling and validation outcomes +- enforcement-layer-aware summaries and diffing +- durability classes and lease expiration metadata +- rejection reasons and operator guidance that can feed follow-up proposals +- stronger audit records tying: + - deny event(s) + - proposal author/source + - approval decision + - resulting policy revision + +The existing gRPC policy service is the natural place to grow this. + +#### TUI, CLI, and SDK + +The TUI should evolve from the current rules panel into a richer inbox with proposal summaries, validation state, diff views, edit-before-approve flow, and a clear distinction between "awaiting you" and "already auto-applied within policy ceiling." The CLI should preserve `openshell rule` for compatibility while introducing clearer proposal-centric aliases, and CLI/SDK surfaces should expose the same approval metadata so integrators can build their own inboxes and automation. + +## Implementation plan + +### Phase 1: Productize the current Policy Advisor + +Goal: turn the existing network rule draft flow into a first-class, polished foundation. + +Deliverables: + +- Rename and frame the current draft-chunk system internally as a proposal inbox. +- Add proposal provenance fields and validation summary fields. +- Improve TUI and CLI language to emphasize reviewable proposals. +- Document the current approval loop as a stable workflow. +- Set explicit UX targets for proposal latency and review responsiveness. +- Add a push/subscription path for proposal and decision updates to inbox clients. +- Audit existing `PolicyChunk` and draft-chunk persistence fields, then either hydrate, deprecate, or remove hollow fields before extending the model further. + +This phase is mostly packaging and data-model hardening on top of existing code in: + +- `crates/openshell-sandbox` +- `crates/openshell-server` +- `crates/openshell-tui` +- `crates/openshell-cli` + +### Phase 2: Local agent skill and supervisor policy API + +Goal: let any agent in a sandbox intentionally inspect and draft policy changes. + +Deliverables: + +- Generated sandbox-local `policy_advisor.md` or equivalent instruction bundle. +- Supervisor read APIs for policy state, denials, local activity logs, and capabilities. +- Initial proposal submission API. +- Structured deny messages that point agents to the local policy workflow. +- Feedback path so agents can read operator rejection guidance and iterate on a proposal. + +This is the point where the feature becomes broadly useful to OpenClaw, Claude Code, Cursor, and other agents. + +### Phase 3: Validation and simulation + +Goal: make approval trustworthy and fast. + +Deliverables: + +- Policy diff generation. +- Consequence summaries for proposed changes. +- Integration with prover/simulation infrastructure where available. +- Clear validation statuses in TUI and CLI. +- Org ceiling checks and trusted-agent auto-apply eligibility. +- Clear reporting for L4-only versus L7-enforced proposals. +- Safety-aware redaction so sandbox-local introspection does not expose full ceiling internals. + +This phase is critical before broadening beyond simple endpoint approvals. + +### Phase 4: Rich L7 authoring and GitHub write flow + +Goal: demonstrate the full UX on a high-value developer workflow. + +Deliverables: + +- Structured GitHub write-policy proposals from agent intent. +- Support for method/path-level rule authoring via agent workflow. +- Validation tuned for common provider/API patterns. +- Demo and tutorial flows centered on repo write access. + +This phase should produce the canonical blocked-write upgrade experience. + +### Phase 5: Generalized inbox surfaces + +Goal: expose proposal review outside the TUI. + +Deliverables: + +- Stable SDK/API for proposal feeds and decisions. +- CLI parity for all proposal operations. +- Optional Slack/web notification adapters. + +### Phase 6: Trusted automation and recurring workflows + +Goal: support safe automatic policy evolution for approved automation patterns. + +Deliverables: + +- policy ceiling model for org or platform admins +- trusted external approver identity model +- automatic apply path when proposals stay within ceiling +- audit trail and reporting for auto-applied revisions +- lease-based durability for automatically applied changes +- reference workflow for recurring research-and-email automation + +### Future phase: protocol-aware policy adapters + +Goal: extend dynamic policy management beyond REST where higher-level semantics exist, including MCP-aware policy controls, richer SQL enforcement once enforce-mode support exists, and protocol-specific adapters for common tool ecosystems. + +## Migration and compatibility + +The intended rollout is additive-first. + +- Existing `openshell rule` commands should continue to work while proposal-centric APIs and UX are introduced. +- Existing mechanistic sandboxes should remain compatible with a newer gateway during the transition. +- Database and proto evolution should prefer additive fields and compatibility shims before any cleanup of legacy draft-chunk semantics. +- If proposal semantics outgrow the current draft-chunk schema, migration should preserve existing pending, approved, and rejected records rather than discarding inbox history. + +## Risks + +- Agent-authored proposals may overfit to task success and underweight least privilege. +- A local skill that teaches policy mutation could be abused if submission and approval boundaries are not crisp. +- Validation that is too weak will make approvals feel unsafe; validation that is too noisy will make the UX slow and frustrating. +- Expanding too quickly from network policy into filesystem/process policy could blur scope and delay a polished first release. +- Adding multiple proposal producers without a unified model could create duplicate or conflicting inbox entries. +- If the inbox UX is not excellent, developers may perceive OpenShell as secure but cumbersome and choose a less safe system with lower friction. +- Automatic apply under trusted-agent control could become a footgun if org ceiling semantics are vague or prover guarantees are misunderstood. + +## Alternatives + +### Keep the current Policy Advisor as-is + +This would preserve a useful feature, but it leaves the product short of the agent-native UX we want. Developers would still do too much translation work between denies, policy syntax, and human approval. + +### Rely only on a human-side coding agent outside the sandbox + +This is workable for expert users and is already partially demonstrated in tutorials, but it misses the core product insight: the in-sandbox agent has the best task context and should be the one drafting the narrowest possible change. + +### Let agents mutate policy directly without approval + +This would be faster, but it is not aligned with OpenShell's safety model and would erase the developer-control story that makes dynamic policy editing acceptable in the first place. + +### Require human approval for every policy change forever + +This is safer in a narrow sense, but it caps automation quality and makes some recurring workflows awkward or brittle. A trusted external approver model bounded by organizational ceilings provides a better long-term path. + +### Treat all network expansion as generic L4 access + +This would simplify the proposal model, but it would throw away one of OpenShell's strongest differentiators. For API-driven developer workflows, L7 REST enforcement is often the right least-privilege abstraction and should be surfaced directly in the RFC and UI. + +### Move proposal generation to the gateway + +This would centralize logic, but it weakens the current architecture. Sandbox-side analysis is the right default because it scales naturally and keeps task-local context near the source of truth. + +## Open questions + +- How should developer intent be declared for validation: + - free-form text + - a structured capability request + - both +- Do we want a single proposal inbox for all policy domains eventually, or separate inboxes that share infrastructure? +- How should org ceiling policy be authored and stored: OpenShell policy syntax, a separate constraint language, or both? +- Which identities are allowed to act as trusted external approvers, and how are those permissions delegated? +- How do we present auto-applied changes so users feel informed rather than surprised? +- When L7 policy is involved, how much raw request context can be safely shown to the developer without leaking sensitive request data? +- Should MCP-aware policy be modeled as network policy enrichment, a separate policy domain, or a capability layer above both? From af60d4e466997763a769862e955d7a210ec5b3e6 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Sat, 9 May 2026 01:16:47 -0400 Subject: [PATCH 020/157] docs: document OPENSHELL_SSH_HANDSHAKE_SECRET in Getting Started (#1287) The Podman and Kubernetes compute drivers require OPENSHELL_SSH_HANDSHAKE_SECRET to be set. This was introduced in 2e0afeab ("feat(vm): derive guest rootfs from sandbox images (#957)"), which exempted only the Docker and VM drivers from the check. The Getting Started instructions in CONTRIBUTING.md didn't mention the variable, so developers using Podman (the default on systems where it is installed) hit an opaque configuration error on first run. Add the export as a separate setup step with a comment explaining which drivers require it. Signed-off-by: Russell Bryant --- CONTRIBUTING.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 5c091c6c4..aa3d1b0a9 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -152,6 +152,10 @@ cargo build -p openshell-prover --features bundled-z3 # One-time trust mise trust +# Podman and Kubernetes drivers require an SSH handshake secret. +# Set any value for local development: +export OPENSHELL_SSH_HANDSHAKE_SECRET=dev-secret + # Run a standalone gateway for local development mise run gateway ``` From 57a80ed2ad2fd2757b3d25e95787e5c709f62158 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Sat, 9 May 2026 01:17:12 -0400 Subject: [PATCH 021/157] fix(gateway): update Podman supervisor build task name (#1288) The Podman path in the dev gateway script references the build:docker:supervisor-sideload mise task, which was removed in d8b84773 ("feat(rpm): add RPM packaging with Packit/COPR and GHA release publishing (#1126)"). That commit consolidated the sideload and standalone supervisor image builds into a single build:docker:supervisor task. Update the reference so `mise run gateway` works with the Podman compute driver. Signed-off-by: Russell Bryant --- tasks/scripts/gateway.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tasks/scripts/gateway.sh b/tasks/scripts/gateway.sh index 608eabbf2..823ccb5f0 100644 --- a/tasks/scripts/gateway.sh +++ b/tasks/scripts/gateway.sh @@ -177,7 +177,7 @@ ensure_podman_supervisor_image() { echo "Building Podman supervisor sideload image (${supervisor_image})..." require_mise - CONTAINER_ENGINE=podman IMAGE_TAG=dev mise run build:docker:supervisor-sideload + CONTAINER_ENGINE=podman IMAGE_TAG=dev mise run build:docker:supervisor if ! podman image exists "${supervisor_image}" >/dev/null 2>&1; then echo "ERROR: expected supervisor image '${supervisor_image}' after build" >&2 From 072f2272488518e5ece027d36aa0d80fa0a71f8f Mon Sep 17 00:00:00 2001 From: Drew Newberry Date: Sat, 9 May 2026 11:10:57 -0700 Subject: [PATCH 022/157] fix(installer): guard incompatible v0.0.37 upgrades (#1294) --- install.sh | 154 +++++++++++++++++++++++++++++++++++++++++++++++++++++ mise.lock | 14 +++-- 2 files changed, 160 insertions(+), 8 deletions(-) diff --git a/install.sh b/install.sh index 98d2d5d77..e6ed53bb4 100755 --- a/install.sh +++ b/install.sh @@ -18,6 +18,8 @@ CHECKSUMS_NAME="openshell-checksums-sha256.txt" LOCAL_GATEWAY_PORT="17670" HOMEBREW_TAP="nvidia/openshell" HOMEBREW_FORMULA_NAME="openshell" +BREAKING_RELEASE_VERSION="0.0.37" +UPGRADE_NOTICE_ACK="${OPENSHELL_ACK_BREAKING_UPGRADE:-}" info() { printf '%s: %s\n' "$APP_NAME" "$*" >&2 @@ -48,6 +50,9 @@ OPTIONS: ENVIRONMENT VARIABLES: OPENSHELL_VERSION Release tag to install (default: latest tagged release). Set OPENSHELL_VERSION=dev to install the rolling dev build. + OPENSHELL_ACK_BREAKING_UPGRADE + Set to 1 only after backing up and cleaning up a + pre-v0.0.37 installation. NOTES: When OPENSHELL_VERSION is unset, this resolves the latest tagged release @@ -76,6 +81,153 @@ download() { curl -fLsS --retry 3 --max-redirs 5 -o "$_output" "$_url" } +semver_core() { + _version="${1#v}" + _version="${_version%%[-+]*}" + printf '%s\n' "$_version" +} + +semver_at_least() { + _version="$(semver_core "$1")" + _minimum="$(semver_core "$2")" + + _major="${_version%%.*}" + _rest="${_version#*.}" + [ "$_rest" != "$_version" ] || return 1 + _minor="${_rest%%.*}" + _patch="${_rest#*.}" + _patch="${_patch%%.*}" + + _min_major="${_minimum%%.*}" + _min_rest="${_minimum#*.}" + [ "$_min_rest" != "$_minimum" ] || return 1 + _min_minor="${_min_rest%%.*}" + _min_patch="${_min_rest#*.}" + _min_patch="${_min_patch%%.*}" + + case "$_major:$_minor:$_patch:$_min_major:$_min_minor:$_min_patch" in + *[!0-9:]* | *::*) + return 1 + ;; + esac + + [ "$_major" -gt "$_min_major" ] && return 0 + [ "$_major" -lt "$_min_major" ] && return 1 + [ "$_minor" -gt "$_min_minor" ] && return 0 + [ "$_minor" -lt "$_min_minor" ] && return 1 + [ "$_patch" -ge "$_min_patch" ] +} + +target_uses_breaking_gateway_model() { + case "$RELEASE_TAG" in + dev) + return 0 + ;; + esac + + semver_at_least "$RELEASE_TAG" "$BREAKING_RELEASE_VERSION" +} + +installed_version_needs_breaking_upgrade_notice() { + _version="$1" + + if [ -z "$_version" ]; then + return 0 + fi + + ! semver_at_least "$_version" "$BREAKING_RELEASE_VERSION" +} + +find_existing_openshell_bin() { + _path="$(command -v openshell 2>/dev/null || true)" + if [ -n "$_path" ] && [ -x "$_path" ]; then + printf '%s\n' "$_path" + return 0 + fi + + for _candidate in \ + "${TARGET_HOME:-}/.local/bin/openshell" \ + /usr/local/bin/openshell \ + /usr/bin/openshell \ + /opt/homebrew/bin/openshell; do + if [ -n "$_candidate" ] && [ -x "$_candidate" ]; then + printf '%s\n' "$_candidate" + return 0 + fi + done + + return 1 +} + +existing_openshell_version() { + _bin="$1" + _output="$("$_bin" --version 2>/dev/null | sed -n '1p' || true)" + printf '%s\n' "$_output" | awk ' + { + for (i = 1; i <= NF; i++) { + if ($i ~ /^v?[0-9]+\.[0-9]+\.[0-9]+([-+][A-Za-z0-9.+~-]+)?$/) { + print $i + exit + } + } + } + ' +} + +print_breaking_upgrade_notice() { + _bin="$1" + _version="$2" + + if [ -n "$_version" ]; then + warn "detected existing OpenShell ${_version} at ${_bin}" + else + warn "detected an existing OpenShell installation at ${_bin}" + fi + + cat >&2 < Date: Sat, 9 May 2026 14:18:34 -0400 Subject: [PATCH 023/157] fix(docker): add SELinux labeling to bind mounts (#1291) --- crates/openshell-driver-docker/src/lib.rs | 38 +++++++++------------ crates/openshell-driver-docker/src/tests.rs | 8 ++--- 2 files changed, 20 insertions(+), 26 deletions(-) diff --git a/crates/openshell-driver-docker/src/lib.rs b/crates/openshell-driver-docker/src/lib.rs index db197685d..a74609fb5 100644 --- a/crates/openshell-driver-docker/src/lib.rs +++ b/crates/openshell-driver-docker/src/lib.rs @@ -9,8 +9,8 @@ use bollard::Docker; use bollard::errors::Error as BollardError; use bollard::models::{ ContainerCreateBody, ContainerSummary, ContainerSummaryStateEnum, DeviceRequest, - EndpointSettings, HostConfig, Mount, MountTypeEnum, NetworkCreateRequest, NetworkingConfig, - RestartPolicy, RestartPolicyNameEnum, SystemInfo, + EndpointSettings, HostConfig, NetworkCreateRequest, NetworkingConfig, RestartPolicy, + RestartPolicyNameEnum, SystemInfo, }; use bollard::query_parameters::{ CreateContainerOptionsBuilder, CreateImageOptions, DownloadFromContainerOptionsBuilder, @@ -865,28 +865,22 @@ impl ComputeDriver for DockerComputeDriver { } } -fn build_mounts(config: &DockerDriverRuntimeConfig) -> Vec { - let mut mounts = vec![bind_mount( - &config.supervisor_bin, - SUPERVISOR_MOUNT_PATH, - true, +fn build_binds(config: &DockerDriverRuntimeConfig) -> Vec { + let mut binds = vec![format!( + "{}:{}:ro,z", + config.supervisor_bin.display(), + SUPERVISOR_MOUNT_PATH )]; if let Some(tls) = &config.guest_tls { - mounts.push(bind_mount(&tls.ca, TLS_CA_MOUNT_PATH, true)); - mounts.push(bind_mount(&tls.cert, TLS_CERT_MOUNT_PATH, true)); - mounts.push(bind_mount(&tls.key, TLS_KEY_MOUNT_PATH, true)); - } - mounts -} - -fn bind_mount(source: &Path, target: &str, read_only: bool) -> Mount { - Mount { - target: Some(target.to_string()), - source: Some(source.display().to_string()), - typ: Some(MountTypeEnum::BIND), - read_only: Some(read_only), - ..Default::default() + binds.push(format!("{}:{}:ro,z", tls.ca.display(), TLS_CA_MOUNT_PATH)); + binds.push(format!( + "{}:{}:ro,z", + tls.cert.display(), + TLS_CERT_MOUNT_PATH + )); + binds.push(format!("{}:{}:ro,z", tls.key.display(), TLS_KEY_MOUNT_PATH)); } + binds } fn build_environment(sandbox: &DriverSandbox, config: &DockerDriverRuntimeConfig) -> Vec { @@ -999,7 +993,7 @@ fn build_container_create_body( nano_cpus: resource_limits.nano_cpus, memory: resource_limits.memory_bytes, device_requests: docker_gpu_device_requests(spec.gpu), - mounts: Some(build_mounts(config)), + binds: Some(build_binds(config)), restart_policy: Some(RestartPolicy { name: Some(RestartPolicyNameEnum::UNLESS_STOPPED), maximum_retry_count: None, diff --git a/crates/openshell-driver-docker/src/tests.rs b/crates/openshell-driver-docker/src/tests.rs index 41c9a5901..c89019398 100644 --- a/crates/openshell-driver-docker/src/tests.rs +++ b/crates/openshell-driver-docker/src/tests.rs @@ -317,11 +317,11 @@ fn build_environment_keeps_path_driver_controlled() { } #[test] -fn build_mounts_uses_docker_tls_directory() { - let mounts = build_mounts(&runtime_config()); - let targets = mounts +fn build_binds_uses_docker_tls_directory() { + let binds = build_binds(&runtime_config()); + let targets = binds .iter() - .filter_map(|mount| mount.target.clone()) + .filter_map(|bind| bind.split(':').nth(1).map(String::from)) .collect::>(); assert!(targets.contains(&SUPERVISOR_MOUNT_PATH.to_string())); assert!(targets.contains(&TLS_CA_MOUNT_PATH.to_string())); From 435048216023b9090ac51c75fc4f5ae3f8daa55b Mon Sep 17 00:00:00 2001 From: Drew Newberry Date: Sat, 9 May 2026 11:28:09 -0700 Subject: [PATCH 024/157] docs(readme): add roadmap and RFC issue guidance (#1284) Signed-off-by: Drew Newberry --- README.md | 1 + rfc/0000-template/README.md | 2 +- rfc/README.md | 12 ++++++------ 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 02447e421..14913b995 100644 --- a/README.md +++ b/README.md @@ -234,6 +234,7 @@ All implementation work is human-gated — agents propose plans, humans approve, - [Quickstart](https://docs.nvidia.com/openshell/latest/get-started/quickstart) — detailed install and first sandbox walkthrough - [GitHub Sandbox Tutorial](https://docs.nvidia.com/openshell/latest/tutorials/github-sandbox) — end-to-end scoped GitHub repo access - [Architecture](https://github.com/NVIDIA/OpenShell/tree/main/architecture) — detailed architecture docs and design decisions +- [Roadmap](https://github.com/orgs/NVIDIA/projects/233) — planned work and project priorities - [Support Matrix](https://docs.nvidia.com/openshell/latest/reference/support-matrix) — platforms, versions, and kernel requirements - [Brev Launchable](https://brev.nvidia.com/launchable/deploy/now?launchableID=env-3Ap3tL55zq4a8kew1AuW0FpSLsg) — try OpenShell on cloud compute without local setup - [Agent Instructions](AGENTS.md) — system prompt and workflow documentation for agent contributors diff --git a/rfc/0000-template/README.md b/rfc/0000-template/README.md index ec66ca967..1cd1810a5 100644 --- a/rfc/0000-template/README.md +++ b/rfc/0000-template/README.md @@ -3,7 +3,7 @@ authors: - "@your-github-username" state: draft links: - - (related PRs, discussions, or issues) + - (related PRs or issues) --- # RFC NNNN - Your Title Here diff --git a/rfc/README.md b/rfc/README.md index 67a643848..96296e5fe 100644 --- a/rfc/README.md +++ b/rfc/README.md @@ -2,16 +2,16 @@ Substantial changes to OpenShell should be proposed in writing before implementation begins. An RFC provides a consistent way to propose an idea, collect feedback from the community, build consensus, and document the decision for future contributors. Not every change needs an RFC — bug fixes, small features, and routine maintenance go through normal pull requests. RFCs are for the changes that are cross-cutting, potentially controversial, or significant enough that stakeholders should weigh in before code is written. -## Start with a GitHub Discussion +## Start with a GitHub issue -Before writing an RFC, consider opening a [GitHub Discussion](https://github.com/NVIDIA/OpenShell/discussions) to gauge interest and get early feedback. This helps: +Before writing an RFC, consider opening a [GitHub issue](https://github.com/NVIDIA/OpenShell/issues/new/choose) to scope the problem, gauge interest, and get early feedback. This helps: - Validate that the problem is worth solving - Surface potential concerns early - Build consensus before investing in a detailed proposal - Identify the right reviewers and stakeholders -If the discussion shows sufficient interest and the idea has merit, then it's time to write an RFC to detail the plan and technical approach. +If the ticket shows sufficient interest and the idea has merit, then it's time to write an RFC to detail the plan and technical approach. ## RFCs vs other artifacts @@ -19,7 +19,7 @@ OpenShell has several places where design information lives. Use this guide to p | Artifact | Purpose | When to use | |----------|---------|-------------| -| **GitHub Discussion** | Gauge interest in a rough idea | You have a thought but aren't sure it's worth a proposal yet | +| **GitHub issue** | Track and scope a rough idea | You have a thought but aren't sure it's worth a proposal yet | | **Spike issue** (`create-spike`) | Investigate implementation feasibility for a scoped change | You need to explore the codebase and produce a buildable issue for a specific component or feature | | **RFC** | Propose a cross-cutting decision that needs broad consensus | Architectural changes, API contracts, process changes, or anything that spans multiple components or teams | | **Architecture doc** (`architecture/`) | Document how things work today | Living reference material — updated as the system evolves | @@ -61,7 +61,7 @@ authors: state: draft links: - https://github.com/NVIDIA/OpenShell/pull/123 - - https://github.com/NVIDIA/OpenShell/discussions/456 + - https://github.com/NVIDIA/OpenShell/issues/456 --- ``` @@ -69,7 +69,7 @@ We track the following metadata: - **authors**: The authors (and therefore owners) of an RFC. Listed as GitHub usernames. - **state**: Must be one of the states discussed below. -- **links**: Related PRs, discussions, or issues. Add entries as the RFC progresses. +- **links**: Related PRs or issues. Add entries as the RFC progresses. - **superseded_by**: *(optional)* For RFCs in the `superseded` state, the RFC number that replaces this one (e.g., `0005`). An RFC can be in one of the following states: From ca6384195388a8cdd248403b2a713e0208c27579 Mon Sep 17 00:00:00 2001 From: Drew Newberry Date: Sat, 9 May 2026 11:30:30 -0700 Subject: [PATCH 025/157] docs(rfc): move policy management RFC to 0002 (#1283) --- .../README.md} | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename rfc/{0001-agent-driven-policy-management.md => 0002-agent-driven-policy-management/README.md} (99%) diff --git a/rfc/0001-agent-driven-policy-management.md b/rfc/0002-agent-driven-policy-management/README.md similarity index 99% rename from rfc/0001-agent-driven-policy-management.md rename to rfc/0002-agent-driven-policy-management/README.md index 6816d1331..9a1e3c4aa 100644 --- a/rfc/0001-agent-driven-policy-management.md +++ b/rfc/0002-agent-driven-policy-management/README.md @@ -7,7 +7,7 @@ links: - https://github.com/NVIDIA/OpenShell/blob/main/architecture/policy-advisor.md --- -# RFC 0001 - Agent-Driven Policy Management +# RFC 0002 - Agent-Driven Policy Management >CLI: Client bytes GW->>SUP: Relay bytes SUP-->>GW: Relay bytes ``` -The same relay pattern backs interactive SSH, command execution, and file sync. -The gateway tracks live sessions in memory and persists session records so -tokens can expire or be revoked. +The same relay pattern backs interactive SSH, command execution, file sync, and +local service forwarding. The gateway tracks live sessions in memory and +persists session records so tokens can expire or be revoked. + +`ForwardTcp` is the client-facing byte stream for SSH and service forwarding. +The first frame is a `TcpForwardInit` that carries the sandbox ID, an +authorization token from `CreateSshSession`, and an explicit target: +`target.ssh` for the sandbox SSH socket or `target.tcp` for a loopback service +inside the sandbox. The gateway validates the token and sandbox readiness, +sends a targeted `RelayOpen` to the supervisor, then bridges +`TcpForwardFrame::Data` to `RelayFrame::Data` until either side closes. + +For `target.tcp`, the gateway only accepts loopback destinations such as +`localhost`, `127.0.0.0/8`, or `::1`. The gateway never needs to know or dial a +sandbox pod IP; supervisors connect outbound and bridge only the explicit target +requested for that relay. ## PKI Bootstrap @@ -143,13 +157,13 @@ created. Both deployment paths use it: | Filesystem | `--output-dir ` | `/{ca.crt, ca.key, server/tls.{crt,key}, client/tls.{crt,key}}`. Also copies client materials to `$XDG_CONFIG_HOME/openshell/gateways/openshell/mtls/` for CLI auto-discovery. | On Kubernetes, the Helm chart runs the command via a pre-install/pre-upgrade -hook Job using the gateway image itself — no separate cert-generation image, +hook Job using the gateway image itself -- no separate cert-generation image, no extra mirror burden in air-gapped environments. On the RPM gateway, the same command runs from the systemd unit's `ExecStartPre` to bootstrap PKI into the user's state directory on first start. -Both modes share the same idempotency contract: all targets present → skip; -partial state → fail with a recovery hint; nothing present → generate and +Both modes share the same idempotency contract: all targets present -> skip; +partial state -> fail with a recovery hint; nothing present -> generate and write. This guards mTLS continuity across restarts and upgrades while still recovering cleanly if an operator deletes everything and starts over. diff --git a/crates/openshell-cli/Cargo.toml b/crates/openshell-cli/Cargo.toml index 8b86544b7..21068ad99 100644 --- a/crates/openshell-cli/Cargo.toml +++ b/crates/openshell-cli/Cargo.toml @@ -68,6 +68,7 @@ tokio-tungstenite = { workspace = true } # Streams futures = { workspace = true } +tokio-stream = { workspace = true } nix = { workspace = true } # URL parsing diff --git a/crates/openshell-cli/src/main.rs b/crates/openshell-cli/src/main.rs index e25fb7576..e370d1f27 100644 --- a/crates/openshell-cli/src/main.rs +++ b/crates/openshell-cli/src/main.rs @@ -8,6 +8,7 @@ use clap_complete::engine::ArgValueCompleter; use clap_complete::env::CompleteEnv; use miette::Result; use owo_colors::OwoColorize; +use std::collections::HashMap; use std::io::Write; use std::path::PathBuf; @@ -266,6 +267,7 @@ const FORWARD_EXAMPLES: &str = "\x1b[1mALIAS\x1b[0m \x1b[1mEXAMPLES\x1b[0m $ openshell forward start 8080 $ openshell forward start 3000 my-sandbox + $ openshell forward service my-sandbox --target-port 8000 --local 8000 $ openshell forward stop 8080 $ openshell forward list "; @@ -1612,6 +1614,26 @@ enum ForwardCommands { /// List active port forwards. #[command(help_template = LEAF_HELP_TEMPLATE, next_help_heading = "FLAGS")] List, + + /// Forward a local TCP port to a loopback service inside a sandbox over gRPC. + #[command(help_template = LEAF_HELP_TEMPLATE, next_help_heading = "FLAGS")] + Service { + /// Sandbox name (defaults to last-used sandbox). + #[arg(add = ArgValueCompleter::new(completers::complete_sandbox_names))] + name: Option, + + /// Target service port inside the sandbox. + #[arg(long)] + target_port: u16, + + /// Target service host inside the sandbox. Phase 1 accepts loopback only. + #[arg(long, default_value = "127.0.0.1")] + target_host: String, + + /// Local bind address and port: `[bind_address:]port`. Defaults to the target port. Use port 0 for dynamic assignment. + #[arg(long)] + local: Option, + }, } #[tokio::main] @@ -1854,6 +1876,27 @@ async fn main() -> Result<()> { } } } + ForwardCommands::Service { + name, + target_port, + target_host, + local, + } => { + let ctx = resolve_gateway(&cli.gateway, &cli.gateway_endpoint)?; + let mut tls = tls.with_gateway_name(&ctx.name); + apply_auth(&mut tls, &ctx.name); + let name = resolve_sandbox_name(name, &ctx.name)?; + let local = local.unwrap_or_else(|| target_port.to_string()); + run::service_forward_tcp( + &ctx.endpoint, + &name, + Some(&local), + &target_host, + target_port, + &tls, + ) + .await?; + } ForwardCommands::Start { port, name, @@ -2237,7 +2280,7 @@ async fn main() -> Result<()> { }; // Parse --label flags into a HashMap. - let mut labels_map = std::collections::HashMap::new(); + let mut labels_map = HashMap::new(); for label_str in &labels { let parts: Vec<&str> = label_str.splitn(2, '=').collect(); if parts.len() != 2 { diff --git a/crates/openshell-cli/src/run.rs b/crates/openshell-cli/src/run.rs index 165713b6e..2797bd66c 100644 --- a/crates/openshell-cli/src/run.rs +++ b/crates/openshell-cli/src/run.rs @@ -26,7 +26,7 @@ use openshell_bootstrap::{ use openshell_core::proto::ProviderProfileCategory; use openshell_core::proto::{ ApproveAllDraftChunksRequest, ApproveDraftChunkRequest, AttachSandboxProviderRequest, - ClearDraftChunksRequest, CreateProviderRequest, CreateSandboxRequest, + ClearDraftChunksRequest, CreateProviderRequest, CreateSandboxRequest, CreateSshSessionRequest, DeleteProviderProfileRequest, DeleteProviderRequest, DeleteSandboxRequest, DetachSandboxProviderRequest, ExecSandboxRequest, GetClusterInferenceRequest, GetDraftHistoryRequest, GetDraftPolicyRequest, GetGatewayConfigRequest, @@ -35,9 +35,10 @@ use openshell_core::proto::{ LintProviderProfilesRequest, ListProviderProfilesRequest, ListProvidersRequest, ListSandboxPoliciesRequest, ListSandboxProvidersRequest, ListSandboxesRequest, PolicySource, PolicyStatus, Provider, ProviderProfile, ProviderProfileDiagnostic, ProviderProfileImportItem, - RejectDraftChunkRequest, Sandbox, SandboxPhase, SandboxPolicy, SandboxSpec, SandboxTemplate, - SetClusterInferenceRequest, SettingScope, SettingValue, UpdateConfigRequest, - UpdateProviderRequest, WatchSandboxRequest, exec_sandbox_event, setting_value, + RejectDraftChunkRequest, RevokeSshSessionRequest, Sandbox, SandboxPhase, SandboxPolicy, + SandboxSpec, SandboxTemplate, SetClusterInferenceRequest, SettingScope, SettingValue, + TcpForwardFrame, TcpForwardInit, TcpRelayTarget, UpdateConfigRequest, UpdateProviderRequest, + WatchSandboxRequest, exec_sandbox_event, setting_value, tcp_forward_init, }; use openshell_core::settings::{self, SettingValueKind}; use openshell_core::{ObjectId, ObjectName}; @@ -1554,7 +1555,7 @@ pub async fn sandbox_create( status.message() )); } - Err(status) => return Err(status).into_diagnostic(), + Err(status) => return Err(miette::miette!(status.to_string())), }; let sandbox = response .into_inner() @@ -2438,6 +2439,292 @@ pub async fn sandbox_exec_grpc( Ok(exit_code) } +pub async fn service_forward_tcp( + server: &str, + name: &str, + local: Option<&str>, + target_host: &str, + target_port: u16, + tls: &TlsOptions, +) -> Result<()> { + let (bind_addr, bind_port) = parse_tcp_forward_spec(local, target_port)?; + let mut client = grpc_client(server, tls).await?; + + let sandbox = fetch_ready_sandbox_for_forward(&mut client, name).await?; + + let listener = tokio::net::TcpListener::bind((bind_addr.as_str(), bind_port)) + .await + .into_diagnostic() + .wrap_err_with(|| format!("failed to bind local forward on {bind_addr}:{bind_port}"))?; + let local_addr = listener + .local_addr() + .into_diagnostic() + .wrap_err("failed to read local forward address")?; + eprintln!( + "{} Forwarding {} -> {}:{} in sandbox {} via gRPC", + "✓".green().bold(), + local_addr, + target_host, + target_port, + name, + ); + + let sandbox_id = sandbox.object_id().to_string(); + let (fatal_tx, mut fatal_rx) = tokio::sync::mpsc::channel::(1); + let mut health_check = tokio::time::interval(Duration::from_secs(2)); + health_check.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); + loop { + tokio::select! { + Some(reason) = fatal_rx.recv() => { + return Err(miette::miette!("service forward stopped: {reason}")); + } + + _ = health_check.tick() => { + fetch_ready_sandbox_for_forward(&mut client, name).await?; + } + + accepted = listener.accept() => { + let (socket, peer) = accepted + .into_diagnostic() + .wrap_err("failed to accept local forward connection")?; + let mut client = client.clone(); + let sandbox_id = sandbox_id.clone(); + let target_host = target_host.to_string(); + let service_id = format!("service-forward:{name}:{target_host}:{target_port}"); + let fatal_tx = fatal_tx.clone(); + tokio::spawn(async move { + let token = match create_forward_session_token(&mut client, &sandbox_id).await { + Ok(token) => token, + Err(err) => { + tracing::warn!(peer = %peer, error = %err, "service forward session creation failed"); + if err.fatal { + let _ = fatal_tx.send(err.message).await; + } + return; + } + }; + if let Err(err) = forward_one_tcp_connection( + &mut client, + socket, + sandbox_id, + target_host, + target_port, + service_id, + token.clone(), + ) + .await + { + tracing::warn!(peer = %peer, error = %err, "service forward connection failed"); + if err.fatal { + let _ = fatal_tx.send(err.message).await; + } + } + let _ = client + .revoke_ssh_session(RevokeSshSessionRequest { token }) + .await; + }); + } + } + } +} + +async fn create_forward_session_token( + client: &mut crate::tls::GrpcClient, + sandbox_id: &str, +) -> std::result::Result { + let response = client + .create_ssh_session(CreateSshSessionRequest { + sandbox_id: sandbox_id.to_string(), + }) + .await + .map_err(ForwardTcpConnectionError::from_status)?; + Ok(response.into_inner().token) +} + +async fn fetch_ready_sandbox_for_forward( + client: &mut crate::tls::GrpcClient, + name: &str, +) -> Result { + let response = match client + .get_sandbox(GetSandboxRequest { + name: name.to_string(), + }) + .await + { + Ok(response) => response, + Err(status) if status.code() == Code::NotFound => { + return Err(miette::miette!( + "sandbox '{name}' no longer exists; stopping service forward" + )); + } + Err(status) => return Err(status).into_diagnostic(), + }; + + let sandbox = response + .into_inner() + .sandbox + .ok_or_else(|| miette::miette!("sandbox '{name}' not found"))?; + + if SandboxPhase::try_from(sandbox.phase) != Ok(SandboxPhase::Ready) { + return Err(miette::miette!( + "sandbox '{}' is no longer ready (phase: {}); stopping service forward", + name, + phase_name(sandbox.phase) + )); + } + + Ok(sandbox) +} + +#[derive(Debug)] +struct ForwardTcpConnectionError { + message: String, + fatal: bool, +} + +impl ForwardTcpConnectionError { + fn transient(message: impl Into) -> Self { + Self { + message: message.into(), + fatal: false, + } + } + + fn from_status(status: Status) -> Self { + let fatal = matches!(status.code(), Code::NotFound | Code::FailedPrecondition); + Self { + message: status.to_string(), + fatal, + } + } +} + +impl std::fmt::Display for ForwardTcpConnectionError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(&self.message) + } +} + +impl std::error::Error for ForwardTcpConnectionError {} + +fn parse_tcp_forward_spec(local: Option<&str>, default_port: u16) -> Result<(String, u16)> { + let Some(spec) = local else { + return Ok(("127.0.0.1".to_string(), default_port)); + }; + + if let Some(pos) = spec.rfind(':') { + let addr = &spec[..pos]; + let port_str = &spec[pos + 1..]; + if let Ok(port) = port_str.parse::() { + if addr.is_empty() { + return Err(miette::miette!("bind address is required before ':'")); + } + return Ok((addr.to_string(), port)); + } + } + + let port: u16 = spec.parse().map_err(|_| { + miette::miette!("invalid local forward spec '{spec}': expected [bind_address:]port") + })?; + Ok(("127.0.0.1".to_string(), port)) +} + +async fn forward_one_tcp_connection( + client: &mut crate::tls::GrpcClient, + socket: tokio::net::TcpStream, + sandbox_id: String, + target_host: String, + target_port: u16, + service_id: String, + authorization_token: String, +) -> std::result::Result<(), ForwardTcpConnectionError> { + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tokio_stream::wrappers::ReceiverStream; + + let (tx, rx) = tokio::sync::mpsc::channel::(16); + tx.send(TcpForwardFrame { + payload: Some(openshell_core::proto::tcp_forward_frame::Payload::Init( + TcpForwardInit { + sandbox_id, + service_id, + target: Some(tcp_forward_init::Target::Tcp(TcpRelayTarget { + host: target_host, + port: u32::from(target_port), + })), + authorization_token, + }, + )), + }) + .await + .map_err(|_| ForwardTcpConnectionError::transient("failed to initialize forward stream"))?; + + let mut response = match client.forward_tcp(ReceiverStream::new(rx)).await { + Ok(response) => response.into_inner(), + Err(status) => { + let err = ForwardTcpConnectionError::from_status(status); + drain_and_shutdown_local_socket(socket).await; + return Err(err); + } + }; + + let (mut local_read, mut local_write) = socket.into_split(); + + let to_gateway = tokio::spawn(async move { + let mut buf = vec![0u8; 64 * 1024]; + loop { + let n = local_read.read(&mut buf).await?; + if n == 0 { + break; + } + if tx + .send(TcpForwardFrame { + payload: Some(openshell_core::proto::tcp_forward_frame::Payload::Data( + buf[..n].to_vec(), + )), + }) + .await + .is_err() + { + break; + } + } + Ok::<(), std::io::Error>(()) + }); + + while let Some(frame) = response + .message() + .await + .map_err(ForwardTcpConnectionError::from_status)? + { + let Some(openshell_core::proto::tcp_forward_frame::Payload::Data(data)) = frame.payload + else { + continue; + }; + if data.is_empty() { + continue; + } + local_write + .write_all(&data) + .await + .map_err(|err| ForwardTcpConnectionError::transient(err.to_string()))?; + } + + let _ = local_write.shutdown().await; + to_gateway.abort(); + Ok(()) +} + +async fn drain_and_shutdown_local_socket(mut socket: tokio::net::TcpStream) { + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + let mut buf = [0u8; 4096]; + while matches!( + tokio::time::timeout(Duration::from_millis(25), socket.read(&mut buf)).await, + Ok(Ok(n)) if n != 0 + ) {} + let _ = socket.shutdown().await; +} + /// Print a single YAML line with dimmed keys and regular values. fn print_yaml_line(line: &str) { // Find leading whitespace diff --git a/crates/openshell-cli/src/ssh.rs b/crates/openshell-cli/src/ssh.rs index 89e9071e1..e99e6ee15 100644 --- a/crates/openshell-cli/src/ssh.rs +++ b/crates/openshell-cli/src/ssh.rs @@ -3,30 +3,30 @@ //! SSH connection and proxy utilities. -use crate::tls::{TlsOptions, build_rustls_config, grpc_client, require_tls_materials}; +use crate::tls::{TlsOptions, grpc_client}; use miette::{IntoDiagnostic, Result, WrapErr}; #[cfg(unix)] use nix::sys::signal::{SaFlags, SigAction, SigHandler, SigSet, Signal, sigaction}; use openshell_core::ObjectId; use openshell_core::forward::{ - build_proxy_command, find_ssh_forward_pid, resolve_ssh_gateway, shell_escape, - validate_ssh_session_response, write_forward_pid, + build_proxy_command, find_ssh_forward_pid, format_gateway_url, resolve_ssh_gateway, + shell_escape, validate_ssh_session_response, write_forward_pid, +}; +use openshell_core::proto::{ + CreateSshSessionRequest, GetSandboxRequest, SshRelayTarget, TcpForwardFrame, TcpForwardInit, + tcp_forward_init, }; -use openshell_core::proto::{CreateSshSessionRequest, GetSandboxRequest}; use owo_colors::OwoColorize; -use rustls::pki_types::ServerName; use std::fs; use std::io::{IsTerminal, Write}; #[cfg(unix)] use std::os::unix::process::CommandExt; use std::path::{Path, PathBuf}; use std::process::{Command, Stdio}; -use std::sync::Arc; use std::time::Duration; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader}; -use tokio::net::TcpStream; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::process::Command as TokioCommand; -use tokio_rustls::TlsConnector; +use tokio_stream::wrappers::ReceiverStream; const FOREGROUND_FORWARD_STARTUP_GRACE_PERIOD: Duration = Duration::from_secs(2); @@ -100,8 +100,7 @@ async fn ssh_session_config( // external tunnel endpoint (the cluster URL), not the server's internal // scheme/host/port which may be plaintext HTTP on 127.0.0.1. let gateway_url = if tls.is_bearer_auth() { - let base = server.trim_end_matches('/'); - format!("{base}{}", session.connect_path) + server.trim_end_matches('/').to_string() } else { // If the server returned a loopback gateway address, override it with the // cluster endpoint's host. This handles the case where the server defaults @@ -110,10 +109,7 @@ async fn ssh_session_config( let gateway_port_u16 = session.gateway_port as u16; let (gateway_host, gateway_port) = resolve_ssh_gateway(&session.gateway_host, gateway_port_u16, server); - format!( - "{}://{}:{}{}", - session.gateway_scheme, gateway_host, gateway_port, session.connect_path - ) + format_gateway_url(&session.gateway_scheme, &gateway_host, gateway_port) }; let gateway_name = tls .gateway_name() @@ -821,18 +817,82 @@ pub async fn sandbox_ssh_proxy( token: &str, tls: &TlsOptions, ) -> Result<()> { - // The gateway returns 412 (Precondition Failed) when the sandbox pod - // exists but hasn't reached Ready phase yet. This is a transient state - // after sandbox allocation — retry with backoff instead of failing - // immediately. - const MAX_CONNECT_WAIT: Duration = Duration::from_secs(60); - const INITIAL_BACKOFF: Duration = Duration::from_secs(1); + let server = grpc_server_from_ssh_gateway_url(gateway_url)?; + let mut client = grpc_client(&server, tls).await?; + + let (tx, rx) = tokio::sync::mpsc::channel::(16); + tx.send(TcpForwardFrame { + payload: Some(openshell_core::proto::tcp_forward_frame::Payload::Init( + TcpForwardInit { + sandbox_id: sandbox_id.to_string(), + service_id: format!("ssh-proxy:{sandbox_id}"), + target: Some(tcp_forward_init::Target::Ssh(SshRelayTarget {})), + authorization_token: token.to_string(), + }, + )), + }) + .await + .map_err(|_| miette::miette!("failed to initialize SSH forward stream"))?; + + let mut response = client + .forward_tcp(ReceiverStream::new(rx)) + .await + .into_diagnostic()? + .into_inner(); + let stdin = tokio::io::stdin(); + let stdout = tokio::io::stdout(); + + let to_remote = tokio::spawn(async move { + let mut stdin = stdin; + let mut buf = vec![0u8; 64 * 1024]; + while let Ok(n) = stdin.read(&mut buf).await { + if n == 0 { + break; + } + if tx + .send(TcpForwardFrame { + payload: Some(openshell_core::proto::tcp_forward_frame::Payload::Data( + buf[..n].to_vec(), + )), + }) + .await + .is_err() + { + break; + } + } + }); + let from_remote = tokio::spawn(async move { + let mut stdout = stdout; + loop { + let Ok(Some(frame)) = response.message().await else { + break; + }; + let Some(openshell_core::proto::tcp_forward_frame::Payload::Data(data)) = frame.payload + else { + continue; + }; + if data.is_empty() { + continue; + } + if stdout.write_all(&data).await.is_err() { + break; + } + let _ = stdout.flush().await; + } + }); + let _ = from_remote.await; + to_remote.abort(); + + Ok(()) +} + +fn grpc_server_from_ssh_gateway_url(gateway_url: &str) -> Result { let url: url::Url = gateway_url .parse() .into_diagnostic() .wrap_err("invalid gateway URL")?; - let scheme = url.scheme(); let gateway_host = url .host_str() @@ -840,69 +900,7 @@ pub async fn sandbox_ssh_proxy( let gateway_port = url .port_or_known_default() .ok_or_else(|| miette::miette!("gateway URL missing port"))?; - let connect_path = url.path(); - - let request = format!( - "CONNECT {connect_path} HTTP/1.1\r\nHost: {gateway_host}\r\nX-Sandbox-Id: {sandbox_id}\r\nX-Sandbox-Token: {token}\r\n\r\n" - ); - - let start = std::time::Instant::now(); - let mut backoff = INITIAL_BACKOFF; - let mut buf_stream; - - loop { - let mut stream: Box = - connect_gateway(scheme, gateway_host, gateway_port, tls).await?; - stream - .write_all(request.as_bytes()) - .await - .into_diagnostic()?; - - // Wrap in a BufReader **before** reading the HTTP response. The gateway - // may send the 200 OK response and the first SSH protocol bytes in the - // same TCP segment / WebSocket frame. A plain `read()` would consume - // those SSH bytes into our buffer and discard them, causing SSH to see a - // truncated protocol banner and exit with code 255. BufReader ensures - // any bytes read past the `\r\n\r\n` header boundary stay buffered and - // are returned by subsequent reads during the bidirectional copy phase. - buf_stream = BufReader::new(stream); - let status = read_connect_status(&mut buf_stream).await?; - if status == 200 { - break; - } - if status == 412 && start.elapsed() < MAX_CONNECT_WAIT { - tracing::debug!( - elapsed = ?start.elapsed(), - "sandbox not yet ready (HTTP 412), retrying in {backoff:?}" - ); - tokio::time::sleep(backoff).await; - backoff = (backoff * 2).min(Duration::from_secs(8)); - continue; - } - return Err(miette::miette!( - "gateway CONNECT failed with status {status}" - )); - } - - let (reader, writer) = tokio::io::split(buf_stream); - let stdin = tokio::io::stdin(); - let stdout = tokio::io::stdout(); - - // Spawn both copy directions as independent tasks. Using separate spawned - // tasks (instead of try_join!/select!) ensures that when one direction - // completes or errors, the other continues independently until it also - // finishes. This is critical: when the remote side closes the connection, - // we must keep the stdin→gateway copy alive so SSH can finish sending its - // protocol-close packets, and vice-versa. - let to_remote = tokio::spawn(copy_ignoring_errors(stdin, writer)); - let from_remote = tokio::spawn(copy_ignoring_errors(reader, stdout)); - let _ = from_remote.await; - // Once the remote→stdout direction is done, SSH has received all the data - // it needs. Drop the stdin→gateway task – SSH will close its pipe when - // it's done regardless. - to_remote.abort(); - - Ok(()) + Ok(format_gateway_url(scheme, gateway_host, gateway_port)) } /// Run the SSH proxy in "name mode": create a session on the fly, then proxy. @@ -1122,93 +1120,6 @@ pub fn print_ssh_config(gateway: &str, name: &str) { print!("{}", render_ssh_config(gateway, name)); } -/// Copy all bytes from `reader` to `writer`, flushing on completion. -/// Errors are intentionally discarded – connection teardown errors are -/// expected during normal SSH session shutdown. -async fn copy_ignoring_errors(mut reader: R, mut writer: W) -where - R: AsyncRead + Unpin, - W: AsyncWrite + Unpin, -{ - let _ = tokio::io::copy(&mut reader, &mut writer).await; - let _ = AsyncWriteExt::flush(&mut writer).await; - let _ = AsyncWriteExt::shutdown(&mut writer).await; -} - -async fn connect_gateway( - scheme: &str, - host: &str, - port: u16, - tls: &TlsOptions, -) -> Result> { - // When using Cloudflare edge bearer auth, route through the WebSocket - // tunnel proxy regardless of the origin scheme. The proxy handles edge - // auth headers and TLS termination at the edge; the origin may be - // plaintext HTTP behind the tunnel. OIDC tokens bypass the tunnel. - if let Some(token) = tls.edge_token.as_deref() { - let gateway_url = format!("https://{host}:{port}"); - let proxy = crate::edge_tunnel::start_tunnel_proxy(&gateway_url, token).await?; - let tcp = TcpStream::connect(proxy.local_addr) - .await - .into_diagnostic()?; - tcp.set_nodelay(true).into_diagnostic()?; - return Ok(Box::new(tcp)); - } - - let tcp = TcpStream::connect((host, port)).await.into_diagnostic()?; - tcp.set_nodelay(true).into_diagnostic()?; - if scheme.eq_ignore_ascii_case("https") { - let materials = require_tls_materials(&format!("https://{host}:{port}"), tls)?; - let config = build_rustls_config(&materials)?; - let connector = TlsConnector::from(Arc::new(config)); - let server_name = ServerName::try_from(host.to_string()) - .map_err(|_| miette::miette!("invalid server name: {host}"))?; - let tls = connector - .connect(server_name, tcp) - .await - .into_diagnostic()?; - Ok(Box::new(tls)) - } else { - Ok(Box::new(tcp)) - } -} - -/// Read exactly the HTTP response status line and headers up to `\r\n\r\n`. -/// -/// Uses byte-at-a-time reads so that the caller's `BufReader` retains any -/// bytes that arrived after the header boundary (e.g. the SSH protocol -/// banner that the gateway may send in the same TCP segment). -async fn read_connect_status(stream: &mut R) -> Result { - let mut buf = Vec::new(); - let mut byte = [0u8; 1]; - loop { - let n = stream.read(&mut byte).await.into_diagnostic()?; - if n == 0 { - break; - } - buf.push(byte[0]); - if buf.len() >= 4 && &buf[buf.len() - 4..] == b"\r\n\r\n" { - break; - } - if buf.len() > 8192 { - break; - } - } - let text = String::from_utf8_lossy(&buf); - let line = text.lines().next().unwrap_or(""); - let status = line - .split_whitespace() - .nth(1) - .unwrap_or("0") - .parse::() - .unwrap_or(0); - Ok(status) -} - -trait ProxyStream: AsyncRead + AsyncWrite + Unpin + Send {} - -impl ProxyStream for T where T: AsyncRead + AsyncWrite + Unpin + Send {} - #[cfg(test)] mod tests { use super::*; diff --git a/crates/openshell-cli/src/tls.rs b/crates/openshell-cli/src/tls.rs index c733d3db3..9c5de1773 100644 --- a/crates/openshell-cli/src/tls.rs +++ b/crates/openshell-cli/src/tls.rs @@ -342,6 +342,7 @@ pub async fn build_channel(server: &str, tls: &TlsOptions) -> Result { let endpoint = Endpoint::from_shared(server.to_string()) .into_diagnostic()? .connect_timeout(Duration::from_secs(10)) + .http2_adaptive_window(true) .http2_keep_alive_interval(Duration::from_secs(10)) .keep_alive_while_idle(true); return endpoint.connect().await.into_diagnostic(); @@ -362,6 +363,7 @@ pub async fn build_channel(server: &str, tls: &TlsOptions) -> Result { let endpoint = Endpoint::from_shared(local_url) .into_diagnostic()? .connect_timeout(Duration::from_secs(10)) + .http2_adaptive_window(true) .http2_keep_alive_interval(Duration::from_secs(10)) .keep_alive_while_idle(true); return endpoint.connect().await.into_diagnostic(); @@ -389,6 +391,7 @@ pub async fn build_channel(server: &str, tls: &TlsOptions) -> Result { let mut endpoint = Endpoint::from_shared(server.to_string()) .into_diagnostic()? .connect_timeout(Duration::from_secs(10)) + .http2_adaptive_window(true) .http2_keep_alive_interval(Duration::from_secs(10)) .keep_alive_while_idle(true); diff --git a/crates/openshell-cli/tests/ensure_providers_integration.rs b/crates/openshell-cli/tests/ensure_providers_integration.rs index fec161c53..f1a11e661 100644 --- a/crates/openshell-cli/tests/ensure_providers_integration.rs +++ b/crates/openshell-cli/tests/ensure_providers_integration.rs @@ -516,6 +516,17 @@ impl OpenShell for TestOpenShell { ) -> Result, Status> { Err(Status::unimplemented("not implemented in test")) } + + type ForwardTcpStream = tokio_stream::wrappers::ReceiverStream< + Result, + >; + + async fn forward_tcp( + &self, + _request: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } } // ── TLS helpers ────────────────────────────────────────────────────── diff --git a/crates/openshell-cli/tests/mtls_integration.rs b/crates/openshell-cli/tests/mtls_integration.rs index e833e7af9..c95e2cf98 100644 --- a/crates/openshell-cli/tests/mtls_integration.rs +++ b/crates/openshell-cli/tests/mtls_integration.rs @@ -407,6 +407,17 @@ impl OpenShell for TestOpenShell { ) -> Result, Status> { Err(Status::unimplemented("not implemented in test")) } + + type ForwardTcpStream = tokio_stream::wrappers::ReceiverStream< + Result, + >; + + async fn forward_tcp( + &self, + _request: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } } fn build_ca() -> (Certificate, KeyPair) { diff --git a/crates/openshell-cli/tests/provider_commands_integration.rs b/crates/openshell-cli/tests/provider_commands_integration.rs index 3902bda34..fbe824cbf 100644 --- a/crates/openshell-cli/tests/provider_commands_integration.rs +++ b/crates/openshell-cli/tests/provider_commands_integration.rs @@ -625,6 +625,17 @@ impl OpenShell for TestOpenShell { ) -> Result, Status> { Err(Status::unimplemented("not implemented in test")) } + + type ForwardTcpStream = tokio_stream::wrappers::ReceiverStream< + Result, + >; + + async fn forward_tcp( + &self, + _request: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } } fn install_rustls_provider() { diff --git a/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs b/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs index eb28a18b3..a2fedab82 100644 --- a/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs +++ b/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs @@ -222,7 +222,6 @@ impl OpenShell for TestOpenShell { gateway_scheme: "https".to_string(), gateway_host: "localhost".to_string(), gateway_port: 443, - connect_path: "/connect/ssh".to_string(), ..CreateSshSessionResponse::default() })) } @@ -491,6 +490,17 @@ impl OpenShell for TestOpenShell { ) -> Result, Status> { Err(Status::unimplemented("not implemented in test")) } + + type ForwardTcpStream = tokio_stream::wrappers::ReceiverStream< + Result, + >; + + async fn forward_tcp( + &self, + _request: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } } fn install_rustls_provider() { @@ -782,10 +792,9 @@ async fn sandbox_create_keeps_sandbox_with_forwarding() { let _env = test_env(&fake_ssh_dir, &xdg_dir); let tls = test_tls(&server); install_fake_ssh(&fake_ssh_dir); - let forward_port = { - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - listener.local_addr().unwrap().port() - }; + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let forward_port = listener.local_addr().unwrap().port(); + drop(listener); run::sandbox_create( &server.endpoint, diff --git a/crates/openshell-cli/tests/sandbox_name_fallback_integration.rs b/crates/openshell-cli/tests/sandbox_name_fallback_integration.rs index 7e6ea68b8..94d5b3cfa 100644 --- a/crates/openshell-cli/tests/sandbox_name_fallback_integration.rs +++ b/crates/openshell-cli/tests/sandbox_name_fallback_integration.rs @@ -428,6 +428,17 @@ impl OpenShell for TestOpenShell { ) -> Result, Status> { Err(Status::unimplemented("not implemented in test")) } + + type ForwardTcpStream = tokio_stream::wrappers::ReceiverStream< + Result, + >; + + async fn forward_tcp( + &self, + _request: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } } // ── helpers ─────────────────────────────────────────────────────────── diff --git a/crates/openshell-core/src/config.rs b/crates/openshell-core/src/config.rs index a2a973011..4922f5355 100644 --- a/crates/openshell-core/src/config.rs +++ b/crates/openshell-core/src/config.rs @@ -253,10 +253,6 @@ pub struct Config { #[serde(default = "default_ssh_gateway_port")] pub ssh_gateway_port: u16, - /// Path for SSH CONNECT/upgrade requests. - #[serde(default = "default_ssh_connect_path")] - pub ssh_connect_path: String, - /// SSH listen port inside sandbox containers that expose a TCP endpoint. #[serde(default = "default_sandbox_ssh_port")] pub sandbox_ssh_port: u16, @@ -410,7 +406,6 @@ impl Config { grpc_endpoint: String::new(), ssh_gateway_host: default_ssh_gateway_host(), ssh_gateway_port: default_ssh_gateway_port(), - ssh_connect_path: default_ssh_connect_path(), sandbox_ssh_port: default_sandbox_ssh_port(), sandbox_ssh_socket_path: default_sandbox_ssh_socket_path(), ssh_handshake_secret: String::new(), @@ -520,13 +515,6 @@ impl Config { self } - /// Create a new configuration with the SSH connect path. - #[must_use] - pub fn with_ssh_connect_path(mut self, path: impl Into) -> Self { - self.ssh_connect_path = path.into(); - self - } - /// Create a new configuration with the sandbox SSH port. #[must_use] pub const fn with_sandbox_ssh_port(mut self, port: u16) -> Self { @@ -601,10 +589,6 @@ const fn default_ssh_gateway_port() -> u16 { DEFAULT_SERVER_PORT } -fn default_ssh_connect_path() -> String { - "/connect/ssh".to_string() -} - fn default_sandbox_ssh_socket_path() -> String { "/run/openshell/ssh.sock".to_string() } diff --git a/crates/openshell-core/src/forward.rs b/crates/openshell-core/src/forward.rs index b48e5594a..82fe0114c 100644 --- a/crates/openshell-core/src/forward.rs +++ b/crates/openshell-core/src/forward.rs @@ -469,6 +469,20 @@ pub fn resolve_ssh_gateway( (gateway_host.to_string(), gateway_port) } +/// Format a gateway URL, bracketing IPv6 literals when needed. +pub fn format_gateway_url(scheme: &str, host: &str, port: u16) -> String { + let host = if host + .parse::() + .is_ok_and(|ip| ip.is_ipv6()) + && !host.starts_with('[') + { + format!("[{host}]") + } else { + host.to_string() + }; + format!("{scheme}://{host}:{port}") +} + /// Shell-escape a value for use inside a `ProxyCommand` string. pub fn shell_escape(value: &str) -> String { if value.is_empty() { @@ -525,14 +539,11 @@ pub enum SshSessionResponseError { InvalidScheme, #[error("gateway_port must be in range 1..=65535")] InvalidPort, - #[error("connect_path must start with '/'")] - ConnectPathNotAbsolute, } const MAX_SANDBOX_ID_LEN: usize = 128; const MAX_TOKEN_LEN: usize = 4096; const MAX_GATEWAY_HOST_LEN: usize = 253; -const MAX_CONNECT_PATH_LEN: usize = 2048; const MAX_FINGERPRINT_LEN: usize = 256; fn is_sandbox_id_byte(b: u8) -> bool { @@ -551,33 +562,6 @@ fn is_gateway_host_byte(b: u8) -> bool { b.is_ascii_alphanumeric() || matches!(b, b'.' | b'-' | b':' | b'[' | b']') } -fn is_connect_path_byte(b: u8) -> bool { - // RFC 3986 path charset (pchar) without `?`, `#`, space, backtick, or - // backslash. `%` is permitted so percent-encoded segments round-trip. - b.is_ascii_alphanumeric() - || matches!( - b, - b'-' | b'.' - | b'_' - | b'~' - | b'!' - | b'$' - | b'&' - | b'\'' - | b'(' - | b')' - | b'*' - | b'+' - | b',' - | b';' - | b'=' - | b':' - | b'@' - | b'/' - | b'%' - ) -} - fn is_fingerprint_byte(b: u8) -> bool { b.is_ascii_alphanumeric() || matches!(b, b':' | b'+' | b'/' | b'=' | b'-') } @@ -612,25 +596,6 @@ pub fn validate_ssh_session_response( if resp.gateway_port == 0 || resp.gateway_port > u32::from(u16::MAX) { return Err(SshSessionResponseError::InvalidPort); } - if resp.connect_path.is_empty() { - return Err(SshSessionResponseError::Empty { - field: "connect_path", - }); - } - if !resp.connect_path.starts_with('/') { - return Err(SshSessionResponseError::ConnectPathNotAbsolute); - } - if resp.connect_path.len() > MAX_CONNECT_PATH_LEN { - return Err(SshSessionResponseError::TooLong { - field: "connect_path", - max: MAX_CONNECT_PATH_LEN, - }); - } - if !resp.connect_path.bytes().all(is_connect_path_byte) { - return Err(SshSessionResponseError::InvalidChars { - field: "connect_path", - }); - } if !resp.host_key_fingerprint.is_empty() { if resp.host_key_fingerprint.len() > MAX_FINGERPRINT_LEN { return Err(SshSessionResponseError::TooLong { @@ -735,6 +700,26 @@ mod tests { assert_eq!(port, 8080); } + #[test] + fn format_gateway_url_brackets_ipv6_literals() { + assert_eq!( + format_gateway_url("https", "::1", 8080), + "https://[::1]:8080" + ); + } + + #[test] + fn format_gateway_url_leaves_dns_and_bracketed_ipv6_unchanged() { + assert_eq!( + format_gateway_url("https", "gateway.example.com", 443), + "https://gateway.example.com:443" + ); + assert_eq!( + format_gateway_url("https", "[::1]", 8080), + "https://[::1]:8080" + ); + } + #[test] fn shell_escape_empty() { assert_eq!(shell_escape(""), "''"); @@ -757,7 +742,6 @@ mod tests { gateway_scheme: "https".to_string(), gateway_host: "gateway.example.com".to_string(), gateway_port: 443, - connect_path: "/connect/ssh".to_string(), host_key_fingerprint: String::new(), expires_at_ms: 0, } @@ -857,33 +841,6 @@ mod tests { } } - #[test] - fn validate_ssh_session_response_rejects_connect_path_without_leading_slash() { - let mut r = valid_session_response(); - r.connect_path = "connect/ssh".to_string(); - assert!(matches!( - validate_ssh_session_response(&r), - Err(SshSessionResponseError::ConnectPathNotAbsolute) - )); - } - - #[test] - fn validate_ssh_session_response_rejects_injected_connect_path() { - // `$`, `(`, `)` are valid RFC 3986 sub-delims (pchar) so the validator - // permits them; shell_escape is the second defensive layer. The - // following characters are rejected at the validator boundary because - // they are either unambiguously hostile in a shell context or invalid - // per RFC 3986 in the path component. - for bad in ["/x`id`y", "/x y", "/x\nb", "/x\\b", "/x?q=1", "/x#frag"] { - let mut r = valid_session_response(); - r.connect_path = bad.to_string(); - assert!( - validate_ssh_session_response(&r).is_err(), - "expected reject for connect_path={bad:?}" - ); - } - } - #[test] fn build_proxy_command_escapes_shell_metacharacters() { // Attacker-controlled values in every escapable position. diff --git a/crates/openshell-driver-kubernetes/src/driver.rs b/crates/openshell-driver-kubernetes/src/driver.rs index a6107f907..cde1f4b22 100644 --- a/crates/openshell-driver-kubernetes/src/driver.rs +++ b/crates/openshell-driver-kubernetes/src/driver.rs @@ -2308,10 +2308,18 @@ mod tests { #[test] fn log_level_propagates_as_env_var_to_sandbox_pod() { - let spec = SandboxSpec { log_level: "debug".to_string(), ..SandboxSpec::default() }; + let spec = SandboxSpec { + log_level: "debug".to_string(), + ..SandboxSpec::default() + }; let cr = sandbox_to_k8s_spec(Some(&spec), &SandboxPodParams::default()); - let env = cr["spec"]["podTemplate"]["spec"]["containers"][0]["env"].as_array().unwrap(); - assert!(env.iter().any(|e| e["name"] == "OPENSHELL_LOG_LEVEL" && e["value"] == "debug")); + let env = cr["spec"]["podTemplate"]["spec"]["containers"][0]["env"] + .as_array() + .unwrap(); + assert!( + env.iter() + .any(|e| e["name"] == "OPENSHELL_LOG_LEVEL" && e["value"] == "debug") + ); assert!(cr["spec"].get("logLevel").is_none()); } } diff --git a/crates/openshell-ocsf/src/format/shorthand.rs b/crates/openshell-ocsf/src/format/shorthand.rs index 42b30fbae..08b413429 100644 --- a/crates/openshell-ocsf/src/format/shorthand.rs +++ b/crates/openshell-ocsf/src/format/shorthand.rs @@ -61,6 +61,7 @@ pub fn severity_tag(severity_id: u8) -> &'static str { /// Max length for the reason text in `[reason:...]` before truncation. const MAX_REASON_LEN: usize = 80; +const MAX_MESSAGE_LEN: usize = 120; /// Format a `[reason:...]` tag from `status_detail` (or `message` fallback) /// for denied events. Returns an empty string if neither field is set. @@ -80,6 +81,19 @@ fn reason_tag(base: &BaseEventData) -> String { } } +fn message_tag(base: &BaseEventData) -> String { + let text = base.message.as_deref().unwrap_or(""); + if text.is_empty() { + return String::new(); + } + let text = text.replace(['\n', '\r'], " "); + if text.len() > MAX_MESSAGE_LEN { + format!(" [msg:{}...]", &text[..MAX_MESSAGE_LEN]) + } else { + format!(" [msg:{text}]") + } +} + impl OcsfEvent { /// Produce the single-line shorthand for `openshell.log` and gRPC log push. /// @@ -140,7 +154,13 @@ impl OcsfEvent { (false, true) => format!(" {action}"), (false, false) => format!(" {action}{arrow}"), }; - format!("NET:{activity} {sev}{detail}{rule_ctx}{reason_ctx}") + let message_ctx = + if detail.is_empty() && rule_ctx.is_empty() && reason_ctx.is_empty() { + message_tag(&e.base) + } else { + String::new() + }; + format!("NET:{activity} {sev}{detail}{rule_ctx}{reason_ctx}{message_ctx}") } Self::HttpActivity(e) => { @@ -541,6 +561,33 @@ mod tests { ); } + #[test] + fn test_network_activity_shorthand_shows_message_when_no_key_fields() { + let event = OcsfEvent::NetworkActivity(NetworkActivityEvent { + base: { + let mut b = base(4001, "Network Activity", 4, "Network Activity", 1, "Open"); + b.set_message("relay open (channel_id=ch-42)"); + b + }, + src_endpoint: None, + dst_endpoint: None, + proxy_endpoint: None, + actor: None, + firewall_rule: None, + connection_info: None, + action: None, + disposition: None, + observation_point_id: None, + is_src_dst_assignment_known: None, + }); + + let shorthand = event.format_shorthand(); + assert_eq!( + shorthand, + "NET:OPEN [INFO] [msg:relay open (channel_id=ch-42)]" + ); + } + #[test] fn test_http_activity_shorthand_denied_shows_reason() { let mut b = base(4002, "HTTP Activity", 4, "Network Activity", 99, "Other"); diff --git a/crates/openshell-sandbox/src/lib.rs b/crates/openshell-sandbox/src/lib.rs index 25a28af54..81c575cfa 100644 --- a/crates/openshell-sandbox/src/lib.rs +++ b/crates/openshell-sandbox/src/lib.rs @@ -817,7 +817,7 @@ pub async fn run_sandbox( sandbox_id.as_ref(), ssh_socket_path.as_ref(), ) { - supervisor_session::spawn(endpoint.clone(), id.clone(), socket.clone()); + supervisor_session::spawn(endpoint.clone(), id.clone(), socket.clone(), ssh_netns_fd); info!("supervisor session task spawned"); } diff --git a/crates/openshell-sandbox/src/supervisor_session.rs b/crates/openshell-sandbox/src/supervisor_session.rs index 490a0cba7..6485dddf0 100644 --- a/crates/openshell-sandbox/src/supervisor_session.rs +++ b/crates/openshell-sandbox/src/supervisor_session.rs @@ -4,24 +4,28 @@ //! Persistent supervisor-to-gateway session. //! //! Maintains a long-lived `ConnectSupervisor` bidirectional gRPC stream to the -//! gateway. When the gateway sends `RelayOpen`, the supervisor initiates a -//! `RelayStream` gRPC call (a new HTTP/2 stream multiplexed over the same -//! TCP+TLS connection as the control stream) and bridges it to the local SSH -//! daemon. The supervisor is a dumb byte bridge — it has no protocol awareness -//! of the SSH or NSSH1 bytes flowing through. - +//! gateway. When the gateway sends `RelayOpen`, the supervisor dials the +//! requested local target, initiates a `RelayStream` gRPC call (a new HTTP/2 +//! stream multiplexed over the same TCP+TLS connection as the control stream), +//! and bridges bytes. The supervisor is a dumb byte bridge after target +//! selection — it has no protocol awareness of the bytes flowing through. + +use std::net::IpAddr; +#[cfg(target_os = "linux")] +use std::os::fd::RawFd; use std::time::Duration; use openshell_core::proto::open_shell_client::OpenShellClient; use openshell_core::proto::{ - GatewayMessage, RelayFrame, RelayInit, SupervisorHeartbeat, SupervisorHello, SupervisorMessage, - gateway_message, supervisor_message, + GatewayMessage, RelayFrame, RelayInit, RelayOpen, RelayOpenResult, SupervisorHeartbeat, + SupervisorHello, SupervisorMessage, TcpRelayTarget, gateway_message, relay_open, + supervisor_message, }; use openshell_ocsf::{ - ActivityId, Endpoint, NetworkActivityBuilder, OcsfEvent, SandboxContext, SeverityId, StatusId, - ocsf_emit, + ActivityId, ConnectionInfo, Endpoint, NetworkActivityBuilder, OcsfEvent, SandboxContext, + SeverityId, StatusId, ocsf_emit, }; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::sync::mpsc; use tokio_stream::StreamExt; use tonic::transport::Channel; @@ -91,33 +95,102 @@ fn session_failed_event( .build() } -fn relay_open_event(ctx: &SandboxContext, channel_id: &str) -> OcsfEvent { - NetworkActivityBuilder::new(ctx) +fn relay_target_endpoint(open: &RelayOpen) -> Option { + let relay_open::Target::Tcp(target) = open.target.as_ref()? else { + return None; + }; + let host = target.host.trim(); + let port = u16::try_from(target.port).ok()?; + host.parse().map_or_else( + |_| Some(Endpoint::from_domain(host, port)), + |ip| Some(Endpoint::from_ip(ip, port)), + ) +} + +fn relay_target_kind(open: &RelayOpen) -> &'static str { + match open.target.as_ref() { + Some(relay_open::Target::Tcp(_)) => "tcp relay", + Some(relay_open::Target::Ssh(_)) | None => "ssh relay", + } +} + +fn relay_target_message( + open: &RelayOpen, + state: &str, + ssh_socket_path: &std::path::Path, +) -> String { + let target = match open.target.as_ref() { + Some(relay_open::Target::Tcp(target)) => { + format!("{}:{}", target.host.trim(), target.port) + } + Some(relay_open::Target::Ssh(_)) | None => { + format!("unix:{}", ssh_socket_path.display()) + } + }; + + format!( + "{} {state} (channel_id={}, target={target})", + relay_target_kind(open), + open.channel_id + ) +} + +fn relay_open_event( + ctx: &SandboxContext, + open: &RelayOpen, + ssh_socket_path: &std::path::Path, +) -> OcsfEvent { + let mut builder = NetworkActivityBuilder::new(ctx) .activity(ActivityId::Open) .severity(SeverityId::Informational) .status(StatusId::Success) - .message(format!("relay open (channel_id={channel_id})")) - .build() + .message(relay_target_message(open, "open", ssh_socket_path)); + if let Some(endpoint) = relay_target_endpoint(open) { + builder = builder + .dst_endpoint(endpoint) + .connection_info(ConnectionInfo::new("tcp")); + } + builder.build() } -fn relay_closed_event(ctx: &SandboxContext, channel_id: &str) -> OcsfEvent { - NetworkActivityBuilder::new(ctx) +fn relay_closed_event( + ctx: &SandboxContext, + open: &RelayOpen, + ssh_socket_path: &std::path::Path, +) -> OcsfEvent { + let mut builder = NetworkActivityBuilder::new(ctx) .activity(ActivityId::Close) .severity(SeverityId::Informational) .status(StatusId::Success) - .message(format!("relay closed (channel_id={channel_id})")) - .build() + .message(relay_target_message(open, "closed", ssh_socket_path)); + if let Some(endpoint) = relay_target_endpoint(open) { + builder = builder + .dst_endpoint(endpoint) + .connection_info(ConnectionInfo::new("tcp")); + } + builder.build() } -fn relay_failed_event(ctx: &SandboxContext, channel_id: &str, error: &str) -> OcsfEvent { - NetworkActivityBuilder::new(ctx) +fn relay_failed_event( + ctx: &SandboxContext, + open: &RelayOpen, + ssh_socket_path: &std::path::Path, + error: &str, +) -> OcsfEvent { + let mut builder = NetworkActivityBuilder::new(ctx) .activity(ActivityId::Fail) .severity(SeverityId::Low) .status(StatusId::Failure) .message(format!( - "relay bridge failed (channel_id={channel_id}): {error}" - )) - .build() + "{}: {error}", + relay_target_message(open, "bridge failed", ssh_socket_path) + )); + if let Some(endpoint) = relay_target_endpoint(open) { + builder = builder + .dst_endpoint(endpoint) + .connection_info(ConnectionInfo::new("tcp")); + } + builder.build() } fn relay_close_from_gateway_event( @@ -139,6 +212,10 @@ fn relay_close_from_gateway_event( /// HTTP/2 frame size so each `RelayFrame::data` fits in one frame. const RELAY_CHUNK_SIZE: usize = 16 * 1024; +trait TargetStream: AsyncRead + AsyncWrite + Send + Unpin {} + +impl TargetStream for T where T: AsyncRead + AsyncWrite + Send + Unpin {} + fn map_stream_message( message: Result, tonic::Status>, eof_error: &'static str, @@ -158,14 +235,21 @@ pub fn spawn( endpoint: String, sandbox_id: String, ssh_socket_path: std::path::PathBuf, + netns_fd: Option, ) -> tokio::task::JoinHandle<()> { - tokio::spawn(run_session_loop(endpoint, sandbox_id, ssh_socket_path)) + tokio::spawn(run_session_loop( + endpoint, + sandbox_id, + ssh_socket_path, + netns_fd, + )) } async fn run_session_loop( endpoint: String, sandbox_id: String, ssh_socket_path: std::path::PathBuf, + netns_fd: Option, ) { let mut backoff = INITIAL_BACKOFF; let mut attempt: u64 = 0; @@ -173,7 +257,7 @@ async fn run_session_loop( loop { attempt += 1; - match run_single_session(&endpoint, &sandbox_id, &ssh_socket_path).await { + match run_single_session(&endpoint, &sandbox_id, &ssh_socket_path, netns_fd).await { Ok(()) => { let event = session_closed_event(crate::ocsf_ctx(), &endpoint, &sandbox_id); ocsf_emit!(event); @@ -194,6 +278,7 @@ async fn run_single_session( endpoint: &str, sandbox_id: &str, ssh_socket_path: &std::path::Path, + netns_fd: Option, ) -> Result<(), Box> { // Connect to the gateway. The same `Channel` is used for both the // long-lived control stream and all data-plane `RelayStream` calls, so @@ -262,7 +347,9 @@ async fn run_single_session( &msg, sandbox_id, ssh_socket_path, + netns_fd, &channel, + &tx, ); } _ = heartbeat_interval.tick() => { @@ -283,7 +370,9 @@ fn handle_gateway_message( msg: &GatewayMessage, sandbox_id: &str, ssh_socket_path: &std::path::Path, + netns_fd: Option, channel: &Channel, + tx: &mpsc::Sender, ) { match &msg.payload { Some(gateway_message::Payload::Heartbeat(_)) => { @@ -291,22 +380,30 @@ fn handle_gateway_message( } Some(gateway_message::Payload::RelayOpen(open)) => { let channel_id = open.channel_id.clone(); + let relay_open = open.clone(); let sandbox_id = sandbox_id.to_string(); let channel = channel.clone(); let ssh_socket_path = ssh_socket_path.to_path_buf(); + let tx = tx.clone(); - let event = relay_open_event(crate::ocsf_ctx(), &channel_id); + let event = relay_open_event(crate::ocsf_ctx(), &relay_open, &ssh_socket_path); ocsf_emit!(event); tokio::spawn(async move { - match handle_relay_open(&channel_id, &ssh_socket_path, channel).await { + let event_open = relay_open.clone(); + match handle_relay_open(relay_open, &ssh_socket_path, netns_fd, channel, tx).await { Ok(()) => { - let event = relay_closed_event(crate::ocsf_ctx(), &channel_id); + let event = + relay_closed_event(crate::ocsf_ctx(), &event_open, &ssh_socket_path); ocsf_emit!(event); } Err(e) => { - let event = - relay_failed_event(crate::ocsf_ctx(), &channel_id, &e.to_string()); + let event = relay_failed_event( + crate::ocsf_ctx(), + &event_open, + &ssh_socket_path, + &e.to_string(), + ); ocsf_emit!(event); warn!( sandbox_id = %sandbox_id, @@ -336,10 +433,23 @@ fn handle_gateway_message( /// TLS handshake. The first `RelayFrame` we send is a `RelayInit`; subsequent /// frames carry raw SSH bytes in `data`. async fn handle_relay_open( - channel_id: &str, + relay_open: RelayOpen, ssh_socket_path: &std::path::Path, + netns_fd: Option, channel: Channel, + tx: mpsc::Sender, ) -> Result<(), Box> { + let channel_id = relay_open.channel_id.clone(); + let target = match open_target(&relay_open, ssh_socket_path, netns_fd).await { + Ok(target) => target, + Err(err) => { + send_relay_open_result(&tx, &channel_id, false, err.to_string()).await; + return Err(err); + } + }; + + send_relay_open_result(&tx, &channel_id, true, String::new()).await; + let mut client = OpenShellClient::new(channel); // Outbound chunks to the gateway. @@ -351,7 +461,7 @@ async fn handle_relay_open( .send(RelayFrame { payload: Some(openshell_core::proto::relay_frame::Payload::Init( RelayInit { - channel_id: channel_id.to_string(), + channel_id: channel_id.clone(), }, )), }) @@ -366,21 +476,19 @@ async fn handle_relay_open( let mut inbound = response.into_inner(); // Connect to the local SSH daemon on its Unix socket. - let ssh = tokio::net::UnixStream::connect(ssh_socket_path).await?; - let (mut ssh_r, mut ssh_w) = ssh.into_split(); + let (mut target_r, mut target_w) = tokio::io::split(target); debug!( channel_id = %channel_id, - socket = %ssh_socket_path.display(), - "relay bridge: connected to local SSH daemon" + "relay bridge: connected to local target" ); - // SSH → gRPC (out_tx): read local SSH, forward as `RelayFrame::data`. + // Target → gRPC (out_tx): read local target, forward as `RelayFrame::data`. let out_tx_writer = out_tx.clone(); - let ssh_to_grpc = tokio::spawn(async move { + let target_to_grpc = tokio::spawn(async move { let mut buf = vec![0u8; RELAY_CHUNK_SIZE]; loop { - match ssh_r.read(&mut buf).await { + match target_r.read(&mut buf).await { Ok(0) | Err(_) => break, Ok(n) => { let chunk = RelayFrame { @@ -396,7 +504,7 @@ async fn handle_relay_open( } }); - // gRPC (inbound) → SSH: drain inbound chunks into the local SSH socket. + // gRPC (inbound) → target: drain inbound chunks into the local target socket. let mut inbound_err: Option = None; while let Some(next) = inbound.next().await { match next { @@ -409,8 +517,8 @@ async fn handle_relay_open( if data.is_empty() { continue; } - if let Err(e) = ssh_w.write_all(&data).await { - inbound_err = Some(format!("write to ssh failed: {e}")); + if let Err(e) = target_w.write_all(&data).await { + inbound_err = Some(format!("write to target failed: {e}")); break; } } @@ -421,13 +529,13 @@ async fn handle_relay_open( } } - // Half-close the SSH socket's write side so the daemon sees EOF. - let _ = ssh_w.shutdown().await; + // Half-close the target socket's write side so the service sees EOF. + let _ = target_w.shutdown().await; // Dropping out_tx closes the outbound gRPC stream, letting the gateway // observe EOF on its side too. drop(out_tx); - let _ = ssh_to_grpc.await; + let _ = target_to_grpc.await; if let Some(e) = inbound_err { return Err(e.into()); @@ -435,6 +543,165 @@ async fn handle_relay_open( Ok(()) } +async fn send_relay_open_result( + tx: &mpsc::Sender, + channel_id: &str, + success: bool, + error: String, +) { + let _ = tx + .send(SupervisorMessage { + payload: Some(supervisor_message::Payload::RelayOpenResult( + RelayOpenResult { + channel_id: channel_id.to_string(), + success, + error, + }, + )), + }) + .await; +} + +async fn open_target( + relay_open: &RelayOpen, + ssh_socket_path: &std::path::Path, + netns_fd: Option, +) -> Result, Box> { + match relay_open.target.as_ref() { + Some(relay_open::Target::Tcp(target)) => open_tcp_target(target, netns_fd).await, + Some(relay_open::Target::Ssh(_)) | None => { + let stream = tokio::net::UnixStream::connect(ssh_socket_path).await?; + Ok(Box::new(stream)) + } + } +} + +async fn open_tcp_target( + target: &TcpRelayTarget, + netns_fd: Option, +) -> Result, Box> { + let host = normalize_tcp_target_host(target)?; + let port = u16::try_from(target.port).map_err(|_| "tcp target port must fit in u16")?; + let stream = connect_tcp_target(host, port, netns_fd).await?; + Ok(Box::new(stream)) +} + +#[cfg(target_os = "linux")] +async fn connect_tcp_target( + host: String, + port: u16, + netns_fd: Option, +) -> Result> { + if let Some(fd) = netns_fd { + let (tx, rx) = tokio::sync::oneshot::channel(); + std::thread::spawn(move || { + let result = (|| -> std::io::Result { + #[allow(unsafe_code)] + let rc = unsafe { libc::setns(fd, libc::CLONE_NEWNET) }; + if rc != 0 { + return Err(std::io::Error::last_os_error()); + } + std::net::TcpStream::connect((host.as_str(), port)) + })(); + let _ = tx.send(result); + }); + + let stream = rx + .await + .map_err(|_| "netns tcp connect thread panicked")??; + stream.set_nonblocking(true)?; + return Ok(tokio::net::TcpStream::from_std(stream)?); + } + + Ok(tokio::net::TcpStream::connect((host.as_str(), port)).await?) +} + +#[cfg(not(target_os = "linux"))] +async fn connect_tcp_target( + host: String, + port: u16, + _netns_fd: Option, +) -> Result> { + Ok(tokio::net::TcpStream::connect((host.as_str(), port)).await?) +} + +#[cfg(test)] +fn validate_tcp_target(target: &TcpRelayTarget) -> Result<(), String> { + normalize_tcp_target_host(target).map(|_| ()) +} + +fn normalize_tcp_target_host(target: &TcpRelayTarget) -> Result { + if target.port == 0 || target.port > u32::from(u16::MAX) { + return Err("tcp target port must be between 1 and 65535".to_string()); + } + + let host = target.host.trim(); + if host.is_empty() { + return Err("tcp target host is required".to_string()); + } + if host.eq_ignore_ascii_case("localhost") { + return Ok("127.0.0.1".to_string()); + } + + let ip: IpAddr = host + .parse() + .map_err(|_| "tcp target host must be loopback".to_string())?; + if ip.is_loopback() { + Ok(ip.to_string()) + } else { + Err("tcp target host must be loopback".to_string()) + } +} + +#[cfg(test)] +mod target_tests { + use super::*; + + fn tcp(host: &str, port: u32) -> TcpRelayTarget { + TcpRelayTarget { + host: host.to_string(), + port, + } + } + + #[test] + fn tcp_target_allows_loopback_hosts() { + validate_tcp_target(&tcp("127.0.0.1", 8080)).expect("ipv4 loopback"); + validate_tcp_target(&tcp("::1", 8080)).expect("ipv6 loopback"); + validate_tcp_target(&tcp("localhost", 8080)).expect("localhost"); + } + + #[test] + fn tcp_target_normalizes_localhost_before_dialing() { + assert_eq!( + normalize_tcp_target_host(&tcp("localhost", 8080)).expect("localhost"), + "127.0.0.1" + ); + assert_eq!( + normalize_tcp_target_host(&tcp("LOCALHOST", 8080)).expect("localhost"), + "127.0.0.1" + ); + } + + #[test] + fn tcp_target_rejects_non_loopback_hosts() { + let err = validate_tcp_target(&tcp("10.0.0.1", 8080)).expect_err("private ip rejected"); + assert_eq!(err, "tcp target host must be loopback"); + + let err = validate_tcp_target(&tcp("example.com", 8080)).expect_err("hostname rejected"); + assert_eq!(err, "tcp target host must be loopback"); + } + + #[test] + fn tcp_target_rejects_invalid_ports() { + let err = validate_tcp_target(&tcp("127.0.0.1", 0)).expect_err("zero rejected"); + assert_eq!(err, "tcp target port must be between 1 and 65535"); + + let err = validate_tcp_target(&tcp("127.0.0.1", 70000)).expect_err("too large rejected"); + assert_eq!(err, "tcp target port must be between 1 and 65535"); + } +} + #[cfg(test)] mod ocsf_event_tests { use super::*; @@ -479,6 +746,31 @@ mod ocsf_event_tests { } } + fn ssh_relay_open(channel_id: &str) -> RelayOpen { + RelayOpen { + channel_id: channel_id.to_string(), + target: Some(relay_open::Target::Ssh( + openshell_core::proto::SshRelayTarget::default(), + )), + service_id: String::new(), + } + } + + fn tcp_relay_open(channel_id: &str, host: &str, port: u32) -> RelayOpen { + RelayOpen { + channel_id: channel_id.to_string(), + target: Some(relay_open::Target::Tcp(TcpRelayTarget { + host: host.to_string(), + port, + })), + service_id: String::new(), + } + } + + fn ssh_socket_path() -> &'static std::path::Path { + std::path::Path::new("/run/openshell/ssh.sock") + } + #[test] fn session_established_emits_network_open_success() { let event = session_established_event(&ctx(), "https://gw:443", "sess-1", 30); @@ -518,22 +810,43 @@ mod ocsf_event_tests { #[test] fn relay_open_emits_network_open_success() { - let event = relay_open_event(&ctx(), "ch-42"); + let event = relay_open_event(&ctx(), &ssh_relay_open("ch-42"), ssh_socket_path()); let na = network_activity(&event); assert_eq!(na.base.activity_id, ActivityId::Open.as_u8()); assert_eq!(na.base.severity, SeverityId::Informational); + let msg = na.base.message.as_deref().unwrap_or_default(); + assert!(msg.contains("ch-42"), "message: {msg}"); assert!( - na.base - .message - .as_deref() - .unwrap_or_default() - .contains("ch-42") + msg.contains("target=unix:/run/openshell/ssh.sock"), + "message: {msg}" + ); + } + + #[test] + fn tcp_relay_open_emits_target_endpoint() { + let event = relay_open_event( + &ctx(), + &tcp_relay_open("ch-42", "127.0.0.1", 8765), + ssh_socket_path(), + ); + let na = network_activity(&event); + assert_eq!(na.base.activity_id, ActivityId::Open.as_u8()); + assert_eq!( + na.dst_endpoint.as_ref().and_then(|e| e.ip.as_deref()), + Some("127.0.0.1") + ); + assert_eq!(na.dst_endpoint.as_ref().and_then(|e| e.port), Some(8765)); + assert_eq!( + na.connection_info + .as_ref() + .map(|c| c.protocol_name.as_str()), + Some("tcp") ); } #[test] fn relay_closed_emits_network_close_success() { - let event = relay_closed_event(&ctx(), "ch-42"); + let event = relay_closed_event(&ctx(), &ssh_relay_open("ch-42"), ssh_socket_path()); let na = network_activity(&event); assert_eq!(na.base.activity_id, ActivityId::Close.as_u8()); assert_eq!(na.base.status, Some(StatusId::Success)); @@ -541,7 +854,12 @@ mod ocsf_event_tests { #[test] fn relay_failed_emits_network_fail_low() { - let event = relay_failed_event(&ctx(), "ch-42", "write to ssh failed"); + let event = relay_failed_event( + &ctx(), + &ssh_relay_open("ch-42"), + ssh_socket_path(), + "write to ssh failed", + ); let na = network_activity(&event); assert_eq!(na.base.activity_id, ActivityId::Fail.as_u8()); assert_eq!(na.base.severity, SeverityId::Low); diff --git a/crates/openshell-server/src/auth/authz.rs b/crates/openshell-server/src/auth/authz.rs index 7e69b1cd8..70d9d738c 100644 --- a/crates/openshell-server/src/auth/authz.rs +++ b/crates/openshell-server/src/auth/authz.rs @@ -59,6 +59,7 @@ const SCOPED_METHODS: &[(&str, &str)] = &[ ("/openshell.v1.OpenShell/CreateSandbox", "sandbox:write"), ("/openshell.v1.OpenShell/DeleteSandbox", "sandbox:write"), ("/openshell.v1.OpenShell/ExecSandbox", "sandbox:write"), + ("/openshell.v1.OpenShell/ForwardTcp", "sandbox:write"), ("/openshell.v1.OpenShell/CreateSshSession", "sandbox:write"), ("/openshell.v1.OpenShell/RevokeSshSession", "sandbox:write"), ( @@ -420,6 +421,11 @@ mod tests { .check(&id, "/openshell.v1.OpenShell/CreateSandbox") .is_ok() ); + assert!( + policy + .check(&id, "/openshell.v1.OpenShell/ForwardTcp") + .is_ok() + ); assert!( policy .check(&id, "/openshell.v1.OpenShell/AttachSandboxProvider") diff --git a/crates/openshell-server/src/cli.rs b/crates/openshell-server/src/cli.rs index 534e3da37..a3098c1cf 100644 --- a/crates/openshell-server/src/cli.rs +++ b/crates/openshell-server/src/cli.rs @@ -127,14 +127,6 @@ struct RunArgs { #[arg(long, env = "OPENSHELL_SSH_GATEWAY_PORT", default_value_t = DEFAULT_SERVER_PORT)] ssh_gateway_port: u16, - /// HTTP path for SSH CONNECT/upgrade. - #[arg( - long, - env = "OPENSHELL_SSH_CONNECT_PATH", - default_value = "/connect/ssh" - )] - ssh_connect_path: String, - /// SSH port inside sandbox pods. #[arg(long, env = "OPENSHELL_SANDBOX_SSH_PORT", default_value_t = DEFAULT_SSH_PORT)] sandbox_ssh_port: u16, @@ -400,7 +392,6 @@ async fn run_from_args(args: RunArgs) -> Result<()> { .with_sandbox_namespace(args.sandbox_namespace) .with_ssh_gateway_host(args.ssh_gateway_host) .with_ssh_gateway_port(args.ssh_gateway_port) - .with_ssh_connect_path(args.ssh_connect_path) .with_sandbox_ssh_port(args.sandbox_ssh_port) .with_ssh_handshake_skew_secs(args.ssh_handshake_skew_secs); diff --git a/crates/openshell-server/src/grpc/mod.rs b/crates/openshell-server/src/grpc/mod.rs index ebb8b1021..16f016081 100644 --- a/crates/openshell-server/src/grpc/mod.rs +++ b/crates/openshell-server/src/grpc/mod.rs @@ -31,8 +31,9 @@ use openshell_core::proto::{ RejectDraftChunkRequest, RejectDraftChunkResponse, RelayFrame, ReportPolicyStatusRequest, ReportPolicyStatusResponse, RevokeSshSessionRequest, RevokeSshSessionResponse, SandboxResponse, SandboxStreamEvent, ServiceStatus, SubmitPolicyAnalysisRequest, SubmitPolicyAnalysisResponse, - SupervisorMessage, UndoDraftChunkRequest, UndoDraftChunkResponse, UpdateConfigRequest, - UpdateConfigResponse, UpdateProviderRequest, WatchSandboxRequest, open_shell_server::OpenShell, + SupervisorMessage, TcpForwardFrame, UndoDraftChunkRequest, UndoDraftChunkResponse, + UpdateConfigRequest, UpdateConfigResponse, UpdateProviderRequest, WatchSandboxRequest, + open_shell_server::OpenShell, }; use serde::{Deserialize, Serialize}; use std::collections::BTreeMap; @@ -240,6 +241,16 @@ impl OpenShell for OpenShellService { sandbox::handle_exec_sandbox(&self.state, request).await } + type ForwardTcpStream = + Pin> + Send + 'static>>; + + async fn forward_tcp( + &self, + request: Request>, + ) -> Result, Status> { + sandbox::handle_forward_tcp(&self.state, request).await + } + // --- SSH sessions --- async fn create_ssh_session( diff --git a/crates/openshell-server/src/grpc/sandbox.rs b/crates/openshell-server/src/grpc/sandbox.rs index 65ac69acb..ad37a5482 100644 --- a/crates/openshell-server/src/grpc/sandbox.rs +++ b/crates/openshell-server/src/grpc/sandbox.rs @@ -12,6 +12,7 @@ use crate::ServerState; use crate::persistence::{ObjectType, generate_name}; use futures::future; +use openshell_core::ObjectId; use openshell_core::proto::{ AttachSandboxProviderRequest, AttachSandboxProviderResponse, CreateSandboxRequest, CreateSshSessionRequest, CreateSshSessionResponse, DeleteSandboxRequest, DeleteSandboxResponse, @@ -19,10 +20,13 @@ use openshell_core::proto::{ ExecSandboxRequest, ExecSandboxStderr, ExecSandboxStdout, GetSandboxRequest, ListSandboxProvidersRequest, ListSandboxProvidersResponse, ListSandboxesRequest, ListSandboxesResponse, Provider, RevokeSshSessionRequest, RevokeSshSessionResponse, - SandboxResponse, SandboxStreamEvent, WatchSandboxRequest, + SandboxResponse, SandboxStreamEvent, SshRelayTarget, TcpForwardFrame, TcpForwardInit, + TcpRelayTarget, WatchSandboxRequest, relay_open, tcp_forward_init, }; use openshell_core::proto::{Sandbox, SandboxPhase, SandboxTemplate, SshSession}; use prost::Message; +use std::net::IpAddr; +use std::pin::Pin; use std::sync::Arc; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::mpsc; @@ -40,6 +44,8 @@ use super::validation::{ }; use super::{MAX_PAGE_SIZE, MAX_PROVIDERS, clamp_limit, current_time_ms}; +const TCP_FORWARD_CHUNK_SIZE: usize = 64 * 1024; + // --------------------------------------------------------------------------- // Sandbox lifecycle handlers // --------------------------------------------------------------------------- @@ -646,9 +652,8 @@ pub(super) async fn handle_exec_sandbox( } // Open a relay channel through the supervisor session. Use a 15s - // session-wait timeout — enough to cover a transient supervisor - // reconnect, but shorter than `/connect/ssh` since `ExecSandbox` is - // typically called during normal operation (not right after create). + // session-wait timeout, enough to cover a transient supervisor reconnect + // while still failing quickly during normal operation. let (channel_id, relay_rx) = state .supervisor_sessions .open_relay(sandbox.object_id(), std::time::Duration::from_secs(15)) @@ -669,7 +674,12 @@ pub(super) async fn handle_exec_sandbox( let relay_stream = match tokio::time::timeout(std::time::Duration::from_secs(10), relay_rx) .await { - Ok(Ok(stream)) => stream, + Ok(Ok(Ok(stream))) => stream, + Ok(Ok(Err(status))) => { + warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, error = %status.message(), "ExecSandbox: relay target open failed"); + let _ = tx.send(Err(status)).await; + return; + } Ok(Err(_)) => { warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, "ExecSandbox: relay channel dropped"); let _ = tx @@ -706,6 +716,328 @@ pub(super) async fn handle_exec_sandbox( Ok(Response::new(ReceiverStream::new(rx))) } +pub(super) async fn handle_forward_tcp( + state: &Arc, + request: Request>, +) -> Result< + Response< + Pin> + Send + 'static>>, + >, + Status, +> { + let mut inbound = request.into_inner(); + let first = inbound + .message() + .await? + .ok_or_else(|| Status::invalid_argument("empty ForwardTcp stream"))?; + let Some(openshell_core::proto::tcp_forward_frame::Payload::Init(init)) = first.payload else { + return Err(Status::invalid_argument( + "first TcpForwardFrame must be init", + )); + }; + + let target = validate_tcp_forward_init(&init)?; + + let sandbox = state + .store + .get_message::(&init.sandbox_id) + .await + .map_err(|e| Status::internal(format!("fetch sandbox failed: {e}")))? + .ok_or_else(|| Status::not_found("sandbox not found"))?; + + if SandboxPhase::try_from(sandbox.phase).ok() != Some(SandboxPhase::Ready) { + return Err(Status::failed_precondition("sandbox is not ready")); + } + + let connection_guard = acquire_forward_connection_guard(state, &init, &sandbox).await?; + let (channel_id, relay_rx) = state + .supervisor_sessions + .open_relay_with_target( + sandbox.object_id(), + target, + init.service_id.clone(), + std::time::Duration::from_secs(15), + ) + .await + .map_err(|e| Status::unavailable(format!("supervisor relay failed: {e}")))?; + + let sandbox_id = sandbox.object_id().to_string(); + let (tx, rx) = mpsc::channel::>(256); + tokio::spawn(async move { + let _connection_guard = connection_guard; + let relay_stream = match tokio::time::timeout(std::time::Duration::from_secs(10), relay_rx) + .await + { + Ok(Ok(Ok(stream))) => stream, + Ok(Ok(Err(status))) => { + warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, error = %status.message(), "ForwardTcp: relay target open failed"); + let _ = tx.send(Err(status)).await; + return; + } + Ok(Err(_)) => { + warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, "ForwardTcp: relay channel dropped"); + let _ = tx + .send(Err(Status::unavailable("relay channel dropped"))) + .await; + return; + } + Err(_) => { + warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, "ForwardTcp: relay open timed out"); + let _ = tx + .send(Err(Status::deadline_exceeded("relay open timed out"))) + .await; + return; + } + }; + + bridge_forward_tcp_stream(inbound, relay_stream, tx, &sandbox_id, &channel_id).await; + }); + + let stream: Pin< + Box> + Send + 'static>, + > = Box::pin(ReceiverStream::new(rx)); + Ok(Response::new(stream)) +} + +struct ForwardConnectionGuard { + state: Arc, + token: Option, + sandbox_id: String, +} + +impl Drop for ForwardConnectionGuard { + fn drop(&mut self) { + if let Some(token) = self.token.as_deref() { + decrement_ssh_connection_count(&self.state.ssh_connections_by_token, token); + decrement_ssh_connection_count( + &self.state.ssh_connections_by_sandbox, + &self.sandbox_id, + ); + } + } +} + +async fn acquire_forward_connection_guard( + state: &Arc, + init: &TcpForwardInit, + sandbox: &Sandbox, +) -> Result { + let sandbox_id = sandbox.object_id().to_string(); + let token = init.authorization_token.trim(); + if token.is_empty() { + return Err(Status::unauthenticated( + "authorization_token is required for ForwardTcp", + )); + } + + validate_ssh_forward_token(state, token, &sandbox_id).await?; + acquire_ssh_connection_slots( + &state.ssh_connections_by_token, + &state.ssh_connections_by_sandbox, + token, + &sandbox_id, + )?; + + Ok(ForwardConnectionGuard { + state: state.clone(), + token: Some(token.to_string()), + sandbox_id, + }) +} + +async fn validate_ssh_forward_token( + state: &Arc, + token: &str, + sandbox_id: &str, +) -> Result<(), Status> { + let session = state + .store + .get_message::(token) + .await + .map_err(|e| Status::internal(format!("fetch SSH session failed: {e}")))? + .ok_or_else(|| Status::unauthenticated("SSH session token not found"))?; + + if session.revoked || session.sandbox_id != sandbox_id { + return Err(Status::unauthenticated("SSH session token is not valid")); + } + + if session.expires_at_ms > 0 { + let now_ms = current_time_ms() + .map_err(|e| Status::internal(format!("timestamp generation failed: {e}")))?; + if now_ms > session.expires_at_ms { + return Err(Status::unauthenticated("SSH session token expired")); + } + } + + Ok(()) +} + +fn acquire_ssh_connection_slots( + token_counts: &std::sync::Mutex>, + sandbox_counts: &std::sync::Mutex>, + token: &str, + sandbox_id: &str, +) -> Result<(), Status> { + const MAX_CONNECTIONS_PER_TOKEN: u32 = 3; + const MAX_CONNECTIONS_PER_SANDBOX: u32 = 20; + + { + let mut counts = token_counts.lock().unwrap(); + let count = counts.entry(token.to_string()).or_insert(0); + if *count >= MAX_CONNECTIONS_PER_TOKEN { + return Err(Status::resource_exhausted( + "SSH session connection limit reached", + )); + } + *count += 1; + } + + { + let mut counts = sandbox_counts.lock().unwrap(); + let count = counts.entry(sandbox_id.to_string()).or_insert(0); + if *count >= MAX_CONNECTIONS_PER_SANDBOX { + decrement_ssh_connection_count(token_counts, token); + return Err(Status::resource_exhausted( + "sandbox SSH connection limit reached", + )); + } + *count += 1; + } + + Ok(()) +} + +fn decrement_ssh_connection_count( + counts: &std::sync::Mutex>, + key: &str, +) { + let mut counts = counts.lock().unwrap(); + if let Some(count) = counts.get_mut(key) { + *count = count.saturating_sub(1); + if *count == 0 { + counts.remove(key); + } + } +} + +fn validate_tcp_forward_init(init: &TcpForwardInit) -> Result { + if init.sandbox_id.is_empty() { + return Err(Status::invalid_argument("sandbox_id is required")); + } + + if let Some(target) = init.target.as_ref() { + return match target { + tcp_forward_init::Target::Ssh(_) => { + Ok(relay_open::Target::Ssh(SshRelayTarget::default())) + } + tcp_forward_init::Target::Tcp(target) => Ok(relay_open::Target::Tcp( + validate_tcp_forward_target(target)?, + )), + }; + } + + Err(Status::invalid_argument("tcp forward target is required")) +} + +fn validate_tcp_forward_target(target: &TcpRelayTarget) -> Result { + if target.port == 0 || target.port > u32::from(u16::MAX) { + return Err(Status::invalid_argument( + "tcp target port must be between 1 and 65535", + )); + } + + validate_tcp_target_parts(target.host.trim(), target.port).map(|host| TcpRelayTarget { + host, + port: target.port, + }) +} + +fn validate_tcp_target_parts(host: &str, _port: u32) -> Result { + if host.is_empty() { + return Err(Status::invalid_argument("tcp target host is required")); + } + if host.eq_ignore_ascii_case("localhost") { + return Ok("127.0.0.1".to_string()); + } + + let ip: IpAddr = host + .parse() + .map_err(|_| Status::invalid_argument("tcp target host must be loopback"))?; + if ip.is_loopback() { + Ok(ip.to_string()) + } else { + Err(Status::invalid_argument("tcp target host must be loopback")) + } +} + +async fn bridge_forward_tcp_stream( + mut inbound: tonic::Streaming, + relay_stream: tokio::io::DuplexStream, + tx: mpsc::Sender>, + sandbox_id: &str, + channel_id: &str, +) { + let (mut relay_read, mut relay_write) = tokio::io::split(relay_stream); + + let sandbox_id_in = sandbox_id.to_string(); + let channel_id_in = channel_id.to_string(); + tokio::spawn(async move { + loop { + match inbound.message().await { + Ok(Some(frame)) => { + let Some(openshell_core::proto::tcp_forward_frame::Payload::Data(data)) = + frame.payload + else { + warn!(sandbox_id = %sandbox_id_in, channel_id = %channel_id_in, "ForwardTcp: received non-data frame after init"); + break; + }; + if data.is_empty() { + continue; + } + if let Err(err) = + tokio::io::AsyncWriteExt::write_all(&mut relay_write, &data).await + { + warn!(sandbox_id = %sandbox_id_in, channel_id = %channel_id_in, error = %err, "ForwardTcp: write to relay failed"); + break; + } + } + Ok(None) => break, + Err(err) => { + warn!(sandbox_id = %sandbox_id_in, channel_id = %channel_id_in, error = %err, "ForwardTcp: inbound stream failed"); + break; + } + } + } + let _ = tokio::io::AsyncWriteExt::shutdown(&mut relay_write).await; + }); + + let mut buf = vec![0u8; TCP_FORWARD_CHUNK_SIZE]; + loop { + match tokio::io::AsyncReadExt::read(&mut relay_read, &mut buf).await { + Ok(0) => break, + Ok(n) => { + let frame = TcpForwardFrame { + payload: Some(openshell_core::proto::tcp_forward_frame::Payload::Data( + buf[..n].to_vec(), + )), + }; + if tx.send(Ok(frame)).await.is_err() { + break; + } + } + Err(err) => { + warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, error = %err, "ForwardTcp: read from relay failed"); + let _ = tx + .send(Err(Status::unavailable(format!( + "relay read failed: {err}" + )))) + .await; + break; + } + } + } +} + // --------------------------------------------------------------------------- // SSH session handlers // --------------------------------------------------------------------------- @@ -773,7 +1105,6 @@ pub(super) async fn handle_create_ssh_session( gateway_host, gateway_port: gateway_port.into(), gateway_scheme: scheme.to_string(), - connect_path: state.config.ssh_connect_path.clone(), host_key_fingerprint: String::new(), expires_at_ms, })) @@ -882,8 +1213,7 @@ fn build_remote_exec_command(req: &ExecSandboxRequest) -> Result /// /// This is the relay equivalent of `stream_exec_over_ssh`. Instead of dialing a /// sandbox endpoint directly, the SSH transport runs over a `DuplexStream` that -/// is bridged to the supervisor's local SSH daemon via a reverse HTTP CONNECT -/// tunnel. +/// is bridged to the supervisor's local SSH daemon via `RelayStream`. #[allow(clippy::too_many_arguments)] async fn stream_exec_over_relay( tx: mpsc::Sender>, @@ -1219,6 +1549,87 @@ mod tests { assert!(build_remote_exec_command(&req).is_err()); } + #[test] + fn tcp_forward_init_allows_loopback_targets() { + for host in ["127.0.0.1", "::1", "localhost"] { + let init = TcpForwardInit { + sandbox_id: "sbx".to_string(), + service_id: String::new(), + target: Some(tcp_forward_init::Target::Tcp(TcpRelayTarget { + host: host.to_string(), + port: 8080, + })), + authorization_token: String::new(), + }; + validate_tcp_forward_init(&init).expect("loopback target should pass"); + } + } + + #[test] + fn tcp_forward_init_allows_ssh_target() { + let init = TcpForwardInit { + sandbox_id: "sbx".to_string(), + target: Some(tcp_forward_init::Target::Ssh(SshRelayTarget::default())), + ..Default::default() + }; + match validate_tcp_forward_init(&init).expect("ssh target should pass") { + relay_open::Target::Ssh(_) => {} + other @ relay_open::Target::Tcp(_) => panic!("expected SSH target, got {other:?}"), + } + } + + #[test] + fn tcp_forward_init_rejects_non_loopback_targets() { + let init = TcpForwardInit { + sandbox_id: "sbx".to_string(), + service_id: String::new(), + target: Some(tcp_forward_init::Target::Tcp(TcpRelayTarget { + host: "example.com".to_string(), + port: 8080, + })), + authorization_token: String::new(), + }; + assert_eq!( + validate_tcp_forward_init(&init) + .expect_err("hostname rejected") + .message(), + "tcp target host must be loopback" + ); + } + + #[test] + fn tcp_forward_init_rejects_invalid_port() { + let init = TcpForwardInit { + sandbox_id: "sbx".to_string(), + service_id: String::new(), + target: Some(tcp_forward_init::Target::Tcp(TcpRelayTarget { + host: "127.0.0.1".to_string(), + port: 0, + })), + authorization_token: String::new(), + }; + assert_eq!( + validate_tcp_forward_init(&init) + .expect_err("zero port rejected") + .message(), + "tcp target port must be between 1 and 65535" + ); + } + + #[test] + fn tcp_forward_init_requires_target() { + let init = TcpForwardInit { + sandbox_id: "sbx".to_string(), + ..Default::default() + }; + assert_eq!( + validate_tcp_forward_init(&init) + .expect_err("missing target rejected") + .message(), + "tcp forward target is required" + ); + } + // ---- petname / generate_name ---- #[test] diff --git a/crates/openshell-server/src/http.rs b/crates/openshell-server/src/http.rs index 7650c2339..7ca9cb8bf 100644 --- a/crates/openshell-server/src/http.rs +++ b/crates/openshell-server/src/http.rs @@ -59,7 +59,5 @@ async fn render_metrics(State(handle): State) -> impl IntoResp /// Create the HTTP router. pub fn http_router(state: Arc) -> Router { - crate::ssh_tunnel::router(state.clone()) - .merge(crate::ws_tunnel::router(state.clone())) - .merge(crate::auth::router(state)) + crate::ws_tunnel::router(state.clone()).merge(crate::auth::router(state)) } diff --git a/crates/openshell-server/src/lib.rs b/crates/openshell-server/src/lib.rs index 07c3cef5c..bca6e44aa 100644 --- a/crates/openshell-server/src/lib.rs +++ b/crates/openshell-server/src/lib.rs @@ -31,7 +31,7 @@ mod persistence; pub(crate) mod policy_store; mod sandbox_index; mod sandbox_watch; -mod ssh_tunnel; +mod ssh_sessions; pub mod supervisor_session; mod tls; pub mod tracing_bus; @@ -220,7 +220,7 @@ pub async fn run_server( } state.compute.spawn_watchers(); - ssh_tunnel::spawn_session_reaper(store.clone(), Duration::from_secs(3600)); + ssh_sessions::spawn_session_reaper(store.clone(), Duration::from_secs(3600)); supervisor_session::spawn_relay_reaper(state.clone(), Duration::from_secs(30)); // Create the multiplexed service diff --git a/crates/openshell-server/src/multiplex.rs b/crates/openshell-server/src/multiplex.rs index 93e58d202..bca9a2171 100644 --- a/crates/openshell-server/src/multiplex.rs +++ b/crates/openshell-server/src/multiplex.rs @@ -470,7 +470,6 @@ fn grpc_status_from_response(res: &Response) -> String { fn normalize_http_path(path: &str) -> &'static str { match path { - p if p.starts_with("/connect/ssh") => "/connect/ssh", p if p.starts_with("/_ws_tunnel") => "/_ws_tunnel", p if p.starts_with("/auth/") => "/auth", _ => "unknown", @@ -724,19 +723,6 @@ mod tests { assert_eq!(grpc_method_from_path(""), ""); } - #[test] - fn normalize_ssh_path() { - assert_eq!(normalize_http_path("/connect/ssh"), "/connect/ssh"); - } - - #[test] - fn normalize_ssh_path_with_trailing_segments() { - assert_eq!( - normalize_http_path("/connect/ssh?token=abc"), - "/connect/ssh" - ); - } - #[test] fn normalize_ws_tunnel() { assert_eq!(normalize_http_path("/_ws_tunnel"), "/_ws_tunnel"); diff --git a/crates/openshell-server/src/ssh_sessions.rs b/crates/openshell-server/src/ssh_sessions.rs new file mode 100644 index 000000000..c3294b361 --- /dev/null +++ b/crates/openshell-server/src/ssh_sessions.rs @@ -0,0 +1,189 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! SSH session token storage and cleanup. + +use openshell_core::ObjectId; +use openshell_core::proto::SshSession; +use prost::Message; +use std::sync::Arc; +use std::time::Duration; +use tracing::{info, warn}; + +use crate::persistence::{ObjectType, Store}; + +impl ObjectType for SshSession { + fn object_type() -> &'static str { + "ssh_session" + } +} + +/// Spawn a background task that periodically reaps expired and revoked SSH sessions. +pub fn spawn_session_reaper(store: Arc, interval: Duration) { + tokio::spawn(async move { + tokio::time::sleep(interval).await; + + loop { + if let Err(e) = reap_expired_sessions(&store).await { + warn!(error = %e, "SSH session reaper sweep failed"); + } + tokio::time::sleep(interval).await; + } + }); +} + +async fn reap_expired_sessions(store: &Store) -> Result<(), String> { + let now_ms = unix_epoch_millis(); + + let records = store + .list(SshSession::object_type(), 1000, 0) + .await + .map_err(|e| e.to_string())?; + + let mut reaped = 0u32; + for record in records { + let session: SshSession = match Message::decode(record.payload.as_slice()) { + Ok(s) => s, + Err(_) => continue, + }; + + let should_delete = + (session.expires_at_ms > 0 && now_ms > session.expires_at_ms) || session.revoked; + + if should_delete { + if let Err(e) = store + .delete(SshSession::object_type(), session.object_id()) + .await + { + warn!(session_id = %session.object_id(), error = %e, "Failed to reap SSH session"); + } else { + reaped += 1; + } + } + } + + if reaped > 0 { + info!(count = reaped, "SSH session reaper: cleaned up sessions"); + } + Ok(()) +} + +fn unix_epoch_millis() -> i64 { + i64::try_from( + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_millis(), + ) + .unwrap_or(i64::MAX) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::collections::HashMap; + + fn make_session(id: &str, sandbox_id: &str, expires_at_ms: i64, revoked: bool) -> SshSession { + SshSession { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: id.to_string(), + name: format!("session-{id}"), + created_at_ms: 1000, + labels: HashMap::new(), + }), + sandbox_id: sandbox_id.to_string(), + token: id.to_string(), + expires_at_ms, + revoked, + } + } + + fn now_ms() -> i64 { + unix_epoch_millis() + } + + #[tokio::test] + async fn reaper_deletes_expired_sessions() { + let store = Store::connect("sqlite::memory:?cache=shared") + .await + .unwrap(); + + let expired = make_session("expired1", "sbx1", now_ms() - 60_000, false); + store.put_message(&expired).await.unwrap(); + + let valid = make_session("valid1", "sbx1", now_ms() + 3_600_000, false); + store.put_message(&valid).await.unwrap(); + + reap_expired_sessions(&store).await.unwrap(); + + assert!( + store + .get_message::("expired1") + .await + .unwrap() + .is_none(), + "expired session should be reaped" + ); + assert!( + store + .get_message::("valid1") + .await + .unwrap() + .is_some(), + "valid session should be kept" + ); + } + + #[tokio::test] + async fn reaper_deletes_revoked_sessions() { + let store = Store::connect("sqlite::memory:?cache=shared") + .await + .unwrap(); + + let revoked = make_session("revoked1", "sbx1", 0, true); + store.put_message(&revoked).await.unwrap(); + + let active = make_session("active1", "sbx1", 0, false); + store.put_message(&active).await.unwrap(); + + reap_expired_sessions(&store).await.unwrap(); + + assert!( + store + .get_message::("revoked1") + .await + .unwrap() + .is_none(), + "revoked session should be reaped" + ); + assert!( + store + .get_message::("active1") + .await + .unwrap() + .is_some(), + "active session should be kept" + ); + } + + #[tokio::test] + async fn reaper_preserves_zero_expiry_sessions() { + let store = Store::connect("sqlite::memory:?cache=shared") + .await + .unwrap(); + + let no_expiry = make_session("noexpiry1", "sbx1", 0, false); + store.put_message(&no_expiry).await.unwrap(); + + reap_expired_sessions(&store).await.unwrap(); + + assert!( + store + .get_message::("noexpiry1") + .await + .unwrap() + .is_some(), + "session with no expiry should be preserved" + ); + } +} diff --git a/crates/openshell-server/src/ssh_tunnel.rs b/crates/openshell-server/src/ssh_tunnel.rs deleted file mode 100644 index bd317d53f..000000000 --- a/crates/openshell-server/src/ssh_tunnel.rs +++ /dev/null @@ -1,541 +0,0 @@ -// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -// SPDX-License-Identifier: Apache-2.0 - -//! SSH tunnel handler for the multiplexed gateway. - -use axum::{Router, extract::State, http::Method, response::IntoResponse, routing::any}; -use http::StatusCode; -use hyper::Request; -use hyper_util::rt::TokioIo; -use openshell_core::proto::{Sandbox, SandboxPhase, SshSession}; -use prost::Message; -use std::sync::Arc; -use std::time::Duration; -use tokio::io::AsyncWriteExt; -use tracing::{info, warn}; - -use crate::ServerState; -use crate::persistence::{ObjectType, Store}; - -const HEADER_SANDBOX_ID: &str = "x-sandbox-id"; -const HEADER_TOKEN: &str = "x-sandbox-token"; - -/// Maximum concurrent SSH tunnel connections per session token. -const MAX_CONNECTIONS_PER_TOKEN: u32 = 3; - -/// Redact a bearer token for safe logging — show only the last 4 characters. -fn redact_token(token: &str) -> String { - if token.len() <= 4 { - "****".to_string() - } else { - format!("****{}", &token[token.len() - 4..]) - } -} - -/// Maximum concurrent SSH tunnel connections per sandbox. -const MAX_CONNECTIONS_PER_SANDBOX: u32 = 20; - -fn acquire_connection_slots( - token_counts: &std::sync::Mutex>, - sandbox_counts: &std::sync::Mutex>, - token: &str, - sandbox_id: &str, -) -> Result<(), ConnectionLimit> { - { - let mut counts = token_counts.lock().unwrap(); - let count = counts.entry(token.to_string()).or_insert(0); - if *count >= MAX_CONNECTIONS_PER_TOKEN { - return Err(ConnectionLimit::PerToken); - } - *count += 1; - } - - { - let mut counts = sandbox_counts.lock().unwrap(); - let count = counts.entry(sandbox_id.to_string()).or_insert(0); - if *count >= MAX_CONNECTIONS_PER_SANDBOX { - decrement_connection_count(token_counts, token); - return Err(ConnectionLimit::PerSandbox); - } - *count += 1; - } - - Ok(()) -} - -enum ConnectionLimit { - PerToken, - PerSandbox, -} - -pub fn router(state: Arc) -> Router { - Router::new() - .route("/connect/ssh", any(ssh_connect)) - .with_state(state) -} - -async fn ssh_connect( - State(state): State>, - req: Request, -) -> impl IntoResponse { - if req.method() != Method::CONNECT { - return StatusCode::METHOD_NOT_ALLOWED.into_response(); - } - - let sandbox_id = match header_value(req.headers(), HEADER_SANDBOX_ID) { - Ok(value) => value, - Err(status) => return status.into_response(), - }; - let token = match header_value(req.headers(), HEADER_TOKEN) { - Ok(value) => value, - Err(status) => return status.into_response(), - }; - - let session = match state.store.get_message::(&token).await { - Ok(Some(session)) => session, - Ok(None) => return StatusCode::UNAUTHORIZED.into_response(), - Err(err) => { - warn!(error = %err, "Failed to fetch SSH session"); - return StatusCode::INTERNAL_SERVER_ERROR.into_response(); - } - }; - - if session.revoked || session.sandbox_id != sandbox_id { - return StatusCode::UNAUTHORIZED.into_response(); - } - - // Check token expiry (0 means no expiry for backward compatibility). - if session.expires_at_ms > 0 { - let now_ms = i64::try_from( - std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap_or_default() - .as_millis(), - ) - .unwrap_or(i64::MAX); - if now_ms > session.expires_at_ms { - return StatusCode::UNAUTHORIZED.into_response(); - } - } - - let sandbox = match state.store.get_message::(&sandbox_id).await { - Ok(Some(sandbox)) => sandbox, - Ok(None) => return StatusCode::NOT_FOUND.into_response(), - Err(err) => { - warn!(error = %err, "Failed to fetch sandbox"); - return StatusCode::INTERNAL_SERVER_ERROR.into_response(); - } - }; - - if SandboxPhase::try_from(sandbox.phase).ok() != Some(SandboxPhase::Ready) { - return StatusCode::PRECONDITION_FAILED.into_response(); - } - - // Enforce connection caps *before* opening a relay — otherwise denied - // calls churn pending relay slots and wake the supervisor until the relay - // timeout elapses. - if let Err(limit) = acquire_connection_slots( - &state.ssh_connections_by_token, - &state.ssh_connections_by_sandbox, - &token, - &sandbox_id, - ) { - match limit { - ConnectionLimit::PerToken => { - warn!(token = %redact_token(&token), "SSH tunnel: per-token connection limit reached"); - } - ConnectionLimit::PerSandbox => { - warn!(sandbox_id = %sandbox_id, "SSH tunnel: per-sandbox connection limit reached"); - } - } - return StatusCode::TOO_MANY_REQUESTS.into_response(); - } - - // Open a relay channel through the supervisor session. Use a generous - // 30s session-wait timeout because `/connect/ssh` is typically called - // immediately after `sandbox create`, so we need to cover the supervisor's - // initial TLS + gRPC handshake on a cold-started pod. The old - // direct-connect path tolerated ~34s here for similar reasons. - let (channel_id, relay_rx) = match state - .supervisor_sessions - .open_relay(&sandbox_id, Duration::from_secs(30)) - .await - { - Ok(pair) => pair, - Err(status) => { - warn!(sandbox_id = %sandbox_id, error = %status.message(), "SSH tunnel: supervisor session not available"); - decrement_connection_count(&state.ssh_connections_by_token, &token); - decrement_connection_count(&state.ssh_connections_by_sandbox, &sandbox_id); - return StatusCode::BAD_GATEWAY.into_response(); - } - }; - - let sandbox_id_clone = sandbox_id.clone(); - let token_clone = token.clone(); - let state_clone = state.clone(); - - let upgrade = hyper::upgrade::on(req); - tokio::spawn(async move { - // Wait for the supervisor to open its `RelayStream` and deliver the - // bridge half of the relay. - let mut relay = match tokio::time::timeout(Duration::from_secs(10), relay_rx).await { - Ok(Ok(stream)) => stream, - Ok(Err(_)) => { - warn!(sandbox_id = %sandbox_id_clone, channel_id = %channel_id, "SSH tunnel: relay channel dropped"); - decrement_connection_count(&state_clone.ssh_connections_by_token, &token_clone); - decrement_connection_count( - &state_clone.ssh_connections_by_sandbox, - &sandbox_id_clone, - ); - return; - } - Err(_) => { - warn!(sandbox_id = %sandbox_id_clone, channel_id = %channel_id, "SSH tunnel: relay open timed out"); - decrement_connection_count(&state_clone.ssh_connections_by_token, &token_clone); - decrement_connection_count( - &state_clone.ssh_connections_by_sandbox, - &sandbox_id_clone, - ); - return; - } - }; - - info!(sandbox_id = %sandbox_id_clone, channel_id = %channel_id, "SSH tunnel: relay established, bridging client"); - - match upgrade.await { - Ok(upgraded) => { - let mut upgraded = TokioIo::new(upgraded); - let _ = tokio::io::copy_bidirectional(&mut upgraded, &mut relay).await; - let _ = AsyncWriteExt::shutdown(&mut upgraded).await; - } - Err(err) => { - warn!(error = %err, "SSH upgrade failed"); - } - } - - // Decrement connection counts on tunnel completion. - decrement_connection_count(&state_clone.ssh_connections_by_token, &token_clone); - decrement_connection_count(&state_clone.ssh_connections_by_sandbox, &sandbox_id_clone); - }); - - StatusCode::OK.into_response() -} - -fn header_value(headers: &http::HeaderMap, name: &str) -> Result { - let value = headers - .get(name) - .ok_or(StatusCode::UNAUTHORIZED)? - .to_str() - .map_err(|_| StatusCode::BAD_REQUEST)? - .trim() - .to_string(); - if value.is_empty() { - return Err(StatusCode::BAD_REQUEST); - } - Ok(value) -} - -impl ObjectType for SshSession { - fn object_type() -> &'static str { - "ssh_session" - } -} - -/// Decrement a connection count entry, removing it if it reaches zero. -fn decrement_connection_count( - counts: &std::sync::Mutex>, - key: &str, -) { - let mut map = counts.lock().unwrap(); - if let Some(count) = map.get_mut(key) { - *count = count.saturating_sub(1); - if *count == 0 { - map.remove(key); - } - } -} - -/// Spawn a background task that periodically reaps expired and revoked SSH sessions. -pub fn spawn_session_reaper(store: Arc, interval: Duration) { - tokio::spawn(async move { - // Initial delay to let startup settle. - tokio::time::sleep(interval).await; - - loop { - if let Err(e) = reap_expired_sessions(&store).await { - warn!(error = %e, "SSH session reaper sweep failed"); - } - tokio::time::sleep(interval).await; - } - }); -} - -async fn reap_expired_sessions(store: &Store) -> Result<(), String> { - let now_ms = i64::try_from( - std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap_or_default() - .as_millis(), - ) - .unwrap_or(i64::MAX); - - let records = store - .list(SshSession::object_type(), 1000, 0) - .await - .map_err(|e| e.to_string())?; - - let mut reaped = 0u32; - for record in records { - let session: SshSession = match Message::decode(record.payload.as_slice()) { - Ok(s) => s, - Err(_) => continue, - }; - - let should_delete = - // Expired sessions (expires_at_ms > 0 means expiry is set). - (session.expires_at_ms > 0 && now_ms > session.expires_at_ms) - // Revoked sessions — already invalidated, just cleaning up storage. - || session.revoked; - - if should_delete { - use openshell_core::ObjectId; - if let Err(e) = store - .delete(SshSession::object_type(), session.object_id()) - .await - { - warn!(session_id = %session.object_id(), error = %e, "Failed to reap SSH session"); - } else { - reaped += 1; - } - } - } - - if reaped > 0 { - info!(count = reaped, "SSH session reaper: cleaned up sessions"); - } - Ok(()) -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::persistence::Store; - use std::collections::HashMap; - use std::sync::Mutex; - - fn make_session(id: &str, sandbox_id: &str, expires_at_ms: i64, revoked: bool) -> SshSession { - SshSession { - metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { - id: id.to_string(), - name: format!("session-{id}"), - created_at_ms: 1000, - labels: HashMap::new(), - }), - sandbox_id: sandbox_id.to_string(), - token: id.to_string(), - expires_at_ms, - revoked, - } - } - - fn now_ms() -> i64 { - i64::try_from( - std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap() - .as_millis(), - ) - .unwrap_or(i64::MAX) - } - - // ---- Connection limit tests ---- - - #[test] - fn decrement_removes_entry_at_zero() { - let counts: Mutex> = Mutex::new(HashMap::new()); - counts.lock().unwrap().insert("tok1".to_string(), 1); - decrement_connection_count(&counts, "tok1"); - assert!(counts.lock().unwrap().is_empty()); - } - - #[test] - fn decrement_reduces_count() { - let counts: Mutex> = Mutex::new(HashMap::new()); - counts.lock().unwrap().insert("tok1".to_string(), 5); - decrement_connection_count(&counts, "tok1"); - assert_eq!(*counts.lock().unwrap().get("tok1").unwrap(), 4); - } - - #[test] - fn decrement_missing_key_is_noop() { - let counts: Mutex> = Mutex::new(HashMap::new()); - decrement_connection_count(&counts, "nonexistent"); - assert!(counts.lock().unwrap().is_empty()); - } - - #[test] - fn per_token_connection_limit_enforced() { - let counts: Mutex> = Mutex::new(HashMap::new()); - counts - .lock() - .unwrap() - .insert("tok1".to_string(), MAX_CONNECTIONS_PER_TOKEN); - let current = *counts.lock().unwrap().get("tok1").unwrap(); - assert!(current >= MAX_CONNECTIONS_PER_TOKEN); - } - - #[test] - fn per_sandbox_connection_limit_enforced() { - let counts: Mutex> = Mutex::new(HashMap::new()); - counts - .lock() - .unwrap() - .insert("sbx1".to_string(), MAX_CONNECTIONS_PER_SANDBOX); - let current = *counts.lock().unwrap().get("sbx1").unwrap(); - assert!(current >= MAX_CONNECTIONS_PER_SANDBOX); - } - - #[test] - fn acquire_connection_slots_rejects_per_token_limit_without_touching_sandbox() { - let token_counts: Mutex> = Mutex::new(HashMap::new()); - let sandbox_counts: Mutex> = Mutex::new(HashMap::new()); - token_counts - .lock() - .unwrap() - .insert("tok1".to_string(), MAX_CONNECTIONS_PER_TOKEN); - - let result = acquire_connection_slots(&token_counts, &sandbox_counts, "tok1", "sbx1"); - - assert!(matches!(result, Err(ConnectionLimit::PerToken))); - assert!(sandbox_counts.lock().unwrap().is_empty()); - } - - #[test] - fn acquire_connection_slots_rolls_back_token_increment_on_sandbox_limit() { - let token_counts: Mutex> = Mutex::new(HashMap::new()); - let sandbox_counts: Mutex> = Mutex::new(HashMap::new()); - sandbox_counts - .lock() - .unwrap() - .insert("sbx1".to_string(), MAX_CONNECTIONS_PER_SANDBOX); - - let result = acquire_connection_slots(&token_counts, &sandbox_counts, "tok1", "sbx1"); - - assert!(matches!(result, Err(ConnectionLimit::PerSandbox))); - assert!(token_counts.lock().unwrap().is_empty()); - } - - // ---- Session reaper tests ---- - - #[tokio::test] - async fn reaper_deletes_expired_sessions() { - let store = Store::connect("sqlite::memory:?cache=shared") - .await - .unwrap(); - - let expired = make_session("expired1", "sbx1", now_ms() - 60_000, false); - store.put_message(&expired).await.unwrap(); - - let valid = make_session("valid1", "sbx1", now_ms() + 3_600_000, false); - store.put_message(&valid).await.unwrap(); - - reap_expired_sessions(&store).await.unwrap(); - - assert!( - store - .get_message::("expired1") - .await - .unwrap() - .is_none(), - "expired session should be reaped" - ); - assert!( - store - .get_message::("valid1") - .await - .unwrap() - .is_some(), - "valid session should be kept" - ); - } - - #[tokio::test] - async fn reaper_deletes_revoked_sessions() { - let store = Store::connect("sqlite::memory:?cache=shared") - .await - .unwrap(); - - let revoked = make_session("revoked1", "sbx1", 0, true); - store.put_message(&revoked).await.unwrap(); - - let active = make_session("active1", "sbx1", 0, false); - store.put_message(&active).await.unwrap(); - - reap_expired_sessions(&store).await.unwrap(); - - assert!( - store - .get_message::("revoked1") - .await - .unwrap() - .is_none(), - "revoked session should be reaped" - ); - assert!( - store - .get_message::("active1") - .await - .unwrap() - .is_some(), - "active session should be kept" - ); - } - - #[tokio::test] - async fn reaper_preserves_zero_expiry_sessions() { - let store = Store::connect("sqlite::memory:?cache=shared") - .await - .unwrap(); - - // expires_at_ms = 0 means no expiry (backward compatible). - let no_expiry = make_session("noexpiry1", "sbx1", 0, false); - store.put_message(&no_expiry).await.unwrap(); - - reap_expired_sessions(&store).await.unwrap(); - - assert!( - store - .get_message::("noexpiry1") - .await - .unwrap() - .is_some(), - "session with no expiry should be preserved" - ); - } - - // ---- Expiry validation logic tests ---- - - #[test] - fn expired_session_is_detected() { - let session = make_session("tok1", "sbx1", now_ms() - 1000, false); - let is_expired = session.expires_at_ms > 0 && now_ms() > session.expires_at_ms; - assert!(is_expired, "session in the past should be expired"); - } - - #[test] - fn future_session_is_not_expired() { - let session = make_session("tok1", "sbx1", now_ms() + 3_600_000, false); - let is_expired = session.expires_at_ms > 0 && now_ms() > session.expires_at_ms; - assert!(!is_expired, "session in the future should not be expired"); - } - - #[test] - fn zero_expiry_is_not_expired() { - let session = make_session("tok1", "sbx1", 0, false); - let is_expired = session.expires_at_ms > 0 && now_ms() > session.expires_at_ms; - assert!( - !is_expired, - "session with zero expiry should never be expired" - ); - } -} diff --git a/crates/openshell-server/src/supervisor_session.rs b/crates/openshell-server/src/supervisor_session.rs index 94c352ba5..19d358826 100644 --- a/crates/openshell-server/src/supervisor_session.rs +++ b/crates/openshell-server/src/supervisor_session.rs @@ -13,8 +13,8 @@ use tracing::{info, warn}; use uuid::Uuid; use openshell_core::proto::{ - GatewayMessage, RelayFrame, RelayInit, RelayOpen, Sandbox, SessionAccepted, SupervisorMessage, - gateway_message, supervisor_message, + GatewayMessage, RelayFrame, RelayInit, RelayOpen, Sandbox, SessionAccepted, SshRelayTarget, + SupervisorMessage, gateway_message, relay_open, supervisor_message, }; use crate::ServerState; @@ -58,8 +58,9 @@ struct LiveSession { connected_at: Instant, } -/// Holds a oneshot sender that will deliver the upgraded relay stream. -type RelayStreamSender = oneshot::Sender; +/// Holds a oneshot sender that will deliver the upgraded relay stream or a +/// target-open failure reported by the supervisor. +type RelayStreamSender = oneshot::Sender>; impl openshell_driver_docker::SupervisorReadiness for SupervisorSessionRegistry { fn is_supervisor_connected(&self, sandbox_id: &str) -> bool { @@ -79,6 +80,7 @@ pub struct SupervisorSessionRegistry { struct PendingRelay { sender: RelayStreamSender, sandbox_id: String, + relay_open: RelayOpen, created_at: Instant, } @@ -234,12 +236,45 @@ impl SupervisorSessionRegistry { &self, sandbox_id: &str, session_wait_timeout: Duration, - ) -> Result<(String, oneshot::Receiver), Status> { + ) -> Result< + ( + String, + oneshot::Receiver>, + ), + Status, + > { + self.open_relay_with_target( + sandbox_id, + relay_open::Target::Ssh(SshRelayTarget {}), + String::new(), + session_wait_timeout, + ) + .await + } + + pub async fn open_relay_with_target( + &self, + sandbox_id: &str, + target: relay_open::Target, + service_id: String, + session_wait_timeout: Duration, + ) -> Result< + ( + String, + oneshot::Receiver>, + ), + Status, + > { let tx = self .wait_for_session(sandbox_id, session_wait_timeout) .await?; let channel_id = Uuid::new_v4().to_string(); + let relay_open = RelayOpen { + channel_id: channel_id.clone(), + target: Some(target), + service_id, + }; // Register the pending relay before sending RelayOpen to avoid a race. // Both caps are checked and the insert happens under a single lock hold @@ -267,15 +302,14 @@ impl SupervisorSessionRegistry { PendingRelay { sender: relay_tx, sandbox_id: sandbox_id.to_string(), + relay_open: relay_open.clone(), created_at: Instant::now(), }, ); } let msg = GatewayMessage { - payload: Some(gateway_message::Payload::RelayOpen(RelayOpen { - channel_id: channel_id.clone(), - })), + payload: Some(gateway_message::Payload::RelayOpen(relay_open)), }; if tx.send(msg).await.is_err() { @@ -287,6 +321,16 @@ impl SupervisorSessionRegistry { Ok((channel_id, relay_rx)) } + pub fn fail_pending_relay(&self, channel_id: &str, error: String) -> bool { + let pending = self.pending_relays.lock().unwrap().remove(channel_id); + if let Some(pending) = pending { + let _ = pending.sender.send(Err(Status::unavailable(error))); + true + } else { + false + } + } + /// Claim a pending relay channel. Called by the `/relay/{channel_id}` HTTP handler /// when the supervisor's reverse CONNECT arrives. /// @@ -308,8 +352,8 @@ impl SupervisorSessionRegistry { // the supervisor HTTP CONNECT handler. let (gateway_stream, supervisor_stream) = tokio::io::duplex(64 * 1024); - // Send the gateway-side stream to the waiter (ssh_tunnel or exec handler). - if pending.sender.send(gateway_stream).is_err() { + // Send the gateway-side stream to the waiter (exec handler or forward handler). + if pending.sender.send(Ok(gateway_stream)).is_err() { return Err(Status::internal("relay requester dropped")); } @@ -329,10 +373,17 @@ impl SupervisorSessionRegistry { pub async fn replay_pending_relays(&self, sandbox_id: &str, tx: &mpsc::Sender) { for channel_id in self.pending_channel_ids(sandbox_id) { + let relay_open = { + let pending = self.pending_relays.lock().unwrap(); + pending + .get(&channel_id) + .map(|pending| pending.relay_open.clone()) + }; + let Some(relay_open) = relay_open else { + continue; + }; let msg = GatewayMessage { - payload: Some(gateway_message::Payload::RelayOpen(RelayOpen { - channel_id: channel_id.clone(), - })), + payload: Some(gateway_message::Payload::RelayOpen(relay_open)), }; if tx.send(msg).await.is_err() { warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, "supervisor session: failed to replay pending relay to superseding session"); @@ -626,7 +677,7 @@ pub async fn handle_connect_supervisor( } async fn run_session_loop( - _state: &Arc, + state: &Arc, sandbox_id: &str, session_id: &str, tx: &mpsc::Sender, @@ -647,7 +698,7 @@ async fn run_session_loop( msg = inbound.message() => { match msg { Ok(Some(msg)) => { - handle_supervisor_message(sandbox_id, session_id, msg); + handle_supervisor_message(state, sandbox_id, session_id, msg); } Ok(None) => { info!(sandbox_id = %sandbox_id, session_id = %session_id, "supervisor session: stream closed by supervisor"); @@ -674,7 +725,12 @@ async fn run_session_loop( } } -fn handle_supervisor_message(sandbox_id: &str, session_id: &str, msg: SupervisorMessage) { +fn handle_supervisor_message( + state: &Arc, + sandbox_id: &str, + session_id: &str, + msg: SupervisorMessage, +) { match msg.payload { Some(supervisor_message::Payload::Heartbeat(_)) => { // Heartbeat received — nothing to do for now. @@ -688,11 +744,15 @@ fn handle_supervisor_message(sandbox_id: &str, session_id: &str, msg: Supervisor "supervisor session: relay opened successfully" ); } else { + let failed = state + .supervisor_sessions + .fail_pending_relay(&result.channel_id, result.error.clone()); warn!( sandbox_id = %sandbox_id, session_id = %session_id, channel_id = %result.channel_id, error = %result.error, + pending_relay_failed = failed, "supervisor session: relay open failed" ); } @@ -745,6 +805,23 @@ mod tests { } } + fn pending_relay( + sandbox_id: &str, + relay_tx: RelayStreamSender, + created_at: Instant, + ) -> PendingRelay { + PendingRelay { + sender: relay_tx, + sandbox_id: sandbox_id.to_string(), + relay_open: RelayOpen { + channel_id: "ch-test".to_string(), + target: Some(relay_open::Target::Ssh(SshRelayTarget {})), + service_id: String::new(), + }, + created_at, + } + } + // ---- registry: register / remove ---- #[test] @@ -863,6 +940,7 @@ mod tests { match msg.payload { Some(gateway_message::Payload::RelayOpen(open)) => { assert_eq!(open.channel_id, channel_id); + assert!(matches!(open.target, Some(relay_open::Target::Ssh(_)))); } other => panic!("expected RelayOpen, got {other:?}"), } @@ -944,11 +1022,7 @@ mod tests { let sandbox_id = if i % 2 == 0 { "sbx-a" } else { "sbx-b" }; pending.insert( format!("channel-{i}"), - PendingRelay { - sender: oneshot_tx, - sandbox_id: sandbox_id.to_string(), - created_at: Instant::now(), - }, + pending_relay(sandbox_id, oneshot_tx, Instant::now()), ); } } @@ -973,11 +1047,7 @@ mod tests { let (oneshot_tx, _) = oneshot::channel(); pending.insert( format!("channel-{i}"), - PendingRelay { - sender: oneshot_tx, - sandbox_id: "sbx".to_string(), - created_at: Instant::now(), - }, + pending_relay("sbx", oneshot_tx, Instant::now()), ); } } @@ -1174,11 +1244,7 @@ mod tests { let (relay_tx, _relay_rx) = oneshot::channel(); registry.pending_relays.lock().unwrap().insert( "ch-1".to_string(), - PendingRelay { - sender: relay_tx, - sandbox_id: "sbx-test".to_string(), - created_at: Instant::now(), - }, + pending_relay("sbx-test", relay_tx, Instant::now()), ); let result = registry.claim_relay("ch-1"); @@ -1186,19 +1252,43 @@ mod tests { assert!(!registry.pending_relays.lock().unwrap().contains_key("ch-1")); } + #[tokio::test] + async fn relay_open_failure_completes_pending_waiter() { + let registry = SupervisorSessionRegistry::new(); + let (relay_tx, relay_rx) = oneshot::channel(); + registry.pending_relays.lock().unwrap().insert( + "ch-fail".to_string(), + pending_relay("sbx-test", relay_tx, Instant::now()), + ); + + assert!(registry.fail_pending_relay("ch-fail", "target refused".to_string())); + assert!( + !registry + .pending_relays + .lock() + .unwrap() + .contains_key("ch-fail") + ); + + let result = relay_rx.await.expect("failure should wake waiter"); + let status = result.expect_err("waiter should receive status failure"); + assert_eq!(status.code(), tonic::Code::Unavailable); + assert_eq!(status.message(), "target refused"); + } + #[test] fn claim_relay_expired_returns_deadline_exceeded() { let registry = SupervisorSessionRegistry::new(); let (relay_tx, _relay_rx) = oneshot::channel(); registry.pending_relays.lock().unwrap().insert( "ch-old".to_string(), - PendingRelay { - sender: relay_tx, - sandbox_id: "sbx-test".to_string(), - created_at: Instant::now() + pending_relay( + "sbx-test", + relay_tx, + Instant::now() .checked_sub(Duration::from_secs(60)) - .expect("test instant subtraction underflow"), - }, + .expect("test duration should be before now"), + ), ); let err = registry @@ -1218,15 +1308,11 @@ mod tests { #[test] fn claim_relay_receiver_dropped_returns_internal() { let registry = SupervisorSessionRegistry::new(); - let (relay_tx, relay_rx) = oneshot::channel::(); + let (relay_tx, relay_rx) = oneshot::channel::>(); drop(relay_rx); // Gateway-side waiter has given up already. registry.pending_relays.lock().unwrap().insert( "ch-1".to_string(), - PendingRelay { - sender: relay_tx, - sandbox_id: "sbx-test".to_string(), - created_at: Instant::now(), - }, + pending_relay("sbx-test", relay_tx, Instant::now()), ); let err = registry @@ -1238,18 +1324,17 @@ mod tests { #[tokio::test] async fn claim_relay_connects_both_ends() { let registry = SupervisorSessionRegistry::new(); - let (relay_tx, relay_rx) = oneshot::channel::(); + let (relay_tx, relay_rx) = oneshot::channel::>(); registry.pending_relays.lock().unwrap().insert( "ch-io".to_string(), - PendingRelay { - sender: relay_tx, - sandbox_id: "sbx-test".to_string(), - created_at: Instant::now(), - }, + pending_relay("sbx-test", relay_tx, Instant::now()), ); let mut supervisor_side = registry.claim_relay("ch-io").expect("claim should succeed"); - let mut gateway_side = relay_rx.await.expect("gateway side should receive stream"); + let mut gateway_side = relay_rx + .await + .expect("gateway side should receive result") + .expect("gateway side should receive stream"); // Supervisor side writes → gateway side reads. supervisor_side.write_all(b"hello").await.unwrap(); @@ -1272,13 +1357,13 @@ mod tests { let (relay_tx, _relay_rx) = oneshot::channel(); registry.pending_relays.lock().unwrap().insert( "ch-old".to_string(), - PendingRelay { - sender: relay_tx, - sandbox_id: "sbx-test".to_string(), - created_at: Instant::now() + pending_relay( + "sbx-test", + relay_tx, + Instant::now() .checked_sub(Duration::from_secs(60)) - .expect("test instant subtraction underflow"), - }, + .expect("test duration should be before now"), + ), ); registry.reap_expired_relays(); @@ -1297,11 +1382,7 @@ mod tests { let (relay_tx, _relay_rx) = oneshot::channel(); registry.pending_relays.lock().unwrap().insert( "ch-fresh".to_string(), - PendingRelay { - sender: relay_tx, - sandbox_id: "sbx-test".to_string(), - created_at: Instant::now(), - }, + pending_relay("sbx-test", relay_tx, Instant::now()), ); registry.reap_expired_relays(); diff --git a/crates/openshell-server/tests/auth_endpoint_integration.rs b/crates/openshell-server/tests/auth_endpoint_integration.rs index c66f2ad6b..f160f98b8 100644 --- a/crates/openshell-server/tests/auth_endpoint_integration.rs +++ b/crates/openshell-server/tests/auth_endpoint_integration.rs @@ -754,6 +754,21 @@ impl openshell_core::proto::open_shell_server::OpenShell for TestOpenShell { ) -> Result, tonic::Status> { Err(tonic::Status::unimplemented("not implemented in test")) } + + type ForwardTcpStream = std::pin::Pin< + Box< + dyn tokio_stream::Stream< + Item = Result, + > + Send, + >, + >; + + async fn forward_tcp( + &self, + _request: tonic::Request>, + ) -> Result, tonic::Status> { + Err(tonic::Status::unimplemented("not implemented in test")) + } } /// Test 7: Plaintext server (no TLS) accepts both gRPC and HTTP. diff --git a/crates/openshell-server/tests/edge_tunnel_auth.rs b/crates/openshell-server/tests/edge_tunnel_auth.rs index 706967d1f..689cfcf59 100644 --- a/crates/openshell-server/tests/edge_tunnel_auth.rs +++ b/crates/openshell-server/tests/edge_tunnel_auth.rs @@ -42,9 +42,9 @@ use openshell_core::proto::{ GetSandboxConfigResponse, GetSandboxProviderEnvironmentRequest, GetSandboxProviderEnvironmentResponse, GetSandboxRequest, HealthRequest, HealthResponse, ListProvidersRequest, ListProvidersResponse, ListSandboxesRequest, ListSandboxesResponse, - ProviderResponse, RevokeSshSessionRequest, RevokeSshSessionResponse, SandboxResponse, - SandboxStreamEvent, ServiceStatus, SupervisorMessage, UpdateProviderRequest, - WatchSandboxRequest, + ProviderResponse, RelayFrame, RevokeSshSessionRequest, RevokeSshSessionResponse, + SandboxResponse, SandboxStreamEvent, ServiceStatus, SupervisorMessage, TcpForwardFrame, + UpdateProviderRequest, WatchSandboxRequest, open_shell_client::OpenShellClient, open_shell_server::{OpenShell, OpenShellServer}, }; @@ -379,14 +379,24 @@ impl OpenShell for TestOpenShell { Err(Status::unimplemented("not implemented in test")) } - type RelayStreamStream = ReceiverStream>; + type RelayStreamStream = ReceiverStream>; async fn relay_stream( &self, - _request: tonic::Request>, + _request: tonic::Request>, ) -> Result, Status> { Err(Status::unimplemented("not implemented in test")) } + + type ForwardTcpStream = + std::pin::Pin> + Send>>; + + async fn forward_tcp( + &self, + _request: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } } // --------------------------------------------------------------------------- diff --git a/crates/openshell-server/tests/multiplex_integration.rs b/crates/openshell-server/tests/multiplex_integration.rs index d5631319d..9cab950db 100644 --- a/crates/openshell-server/tests/multiplex_integration.rs +++ b/crates/openshell-server/tests/multiplex_integration.rs @@ -16,9 +16,9 @@ use openshell_core::proto::{ GetSandboxConfigResponse, GetSandboxProviderEnvironmentRequest, GetSandboxProviderEnvironmentResponse, GetSandboxRequest, HealthRequest, HealthResponse, ListProvidersRequest, ListProvidersResponse, ListSandboxesRequest, ListSandboxesResponse, - ProviderResponse, RevokeSshSessionRequest, RevokeSshSessionResponse, SandboxResponse, - SandboxStreamEvent, ServiceStatus, SupervisorMessage, UpdateProviderRequest, - WatchSandboxRequest, + ProviderResponse, RelayFrame, RevokeSshSessionRequest, RevokeSshSessionResponse, + SandboxResponse, SandboxStreamEvent, ServiceStatus, SupervisorMessage, TcpForwardFrame, + UpdateProviderRequest, WatchSandboxRequest, open_shell_client::OpenShellClient, open_shell_server::{OpenShell, OpenShellServer}, }; @@ -347,14 +347,24 @@ impl OpenShell for TestOpenShell { Err(Status::unimplemented("not implemented in test")) } - type RelayStreamStream = ReceiverStream>; + type RelayStreamStream = ReceiverStream>; async fn relay_stream( &self, - _request: tonic::Request>, + _request: tonic::Request>, ) -> Result, Status> { Err(Status::unimplemented("not implemented in test")) } + + type ForwardTcpStream = + std::pin::Pin> + Send>>; + + async fn forward_tcp( + &self, + _request: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } } #[tokio::test] diff --git a/crates/openshell-server/tests/multiplex_tls_integration.rs b/crates/openshell-server/tests/multiplex_tls_integration.rs index c4f68eaf4..21b75c12c 100644 --- a/crates/openshell-server/tests/multiplex_tls_integration.rs +++ b/crates/openshell-server/tests/multiplex_tls_integration.rs @@ -18,9 +18,9 @@ use openshell_core::proto::{ GetSandboxConfigResponse, GetSandboxProviderEnvironmentRequest, GetSandboxProviderEnvironmentResponse, GetSandboxRequest, HealthRequest, HealthResponse, ListProvidersRequest, ListProvidersResponse, ListSandboxesRequest, ListSandboxesResponse, - ProviderResponse, RevokeSshSessionRequest, RevokeSshSessionResponse, SandboxResponse, - SandboxStreamEvent, ServiceStatus, SupervisorMessage, UpdateProviderRequest, - WatchSandboxRequest, + ProviderResponse, RelayFrame, RevokeSshSessionRequest, RevokeSshSessionResponse, + SandboxResponse, SandboxStreamEvent, ServiceStatus, SupervisorMessage, TcpForwardFrame, + UpdateProviderRequest, WatchSandboxRequest, open_shell_client::OpenShellClient, open_shell_server::{OpenShell, OpenShellServer}, }; @@ -360,14 +360,24 @@ impl OpenShell for TestOpenShell { Err(Status::unimplemented("not implemented in test")) } - type RelayStreamStream = ReceiverStream>; + type RelayStreamStream = ReceiverStream>; async fn relay_stream( &self, - _request: tonic::Request>, + _request: tonic::Request>, ) -> Result, Status> { Err(Status::unimplemented("not implemented in test")) } + + type ForwardTcpStream = + std::pin::Pin> + Send>>; + + async fn forward_tcp( + &self, + _request: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } } /// PKI bundle: CA cert, server cert+key, client cert+key. diff --git a/crates/openshell-server/tests/supervisor_relay_integration.rs b/crates/openshell-server/tests/supervisor_relay_integration.rs index 8f5cac03a..2d722b051 100644 --- a/crates/openshell-server/tests/supervisor_relay_integration.rs +++ b/crates/openshell-server/tests/supervisor_relay_integration.rs @@ -23,7 +23,7 @@ use hyper_util::{ server::conn::auto::Builder, }; use openshell_core::proto::{ - GatewayMessage, RelayFrame, RelayInit, SupervisorMessage, + GatewayMessage, RelayFrame, RelayInit, SupervisorMessage, TcpForwardFrame, open_shell_client::OpenShellClient, open_shell_server::{OpenShell, OpenShellServer}, }; @@ -87,6 +87,15 @@ impl OpenShell for RelayGateway { Err(Status::unimplemented("unused")) } + type ForwardTcpStream = + std::pin::Pin> + Send>>; + async fn forward_tcp( + &self, + _: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + async fn health( &self, _: tonic::Request, @@ -439,7 +448,7 @@ async fn relay_round_trips_bytes() { tokio::spawn(run_echo_supervisor(channel, channel_id)); - let relay = relay_rx.await.expect("relay duplex"); + let relay = relay_rx.await.expect("relay result").expect("relay duplex"); let (mut read_half, mut write_half) = tokio::io::split(relay); write_half.write_all(b"hello relay").await.expect("write"); @@ -464,7 +473,7 @@ async fn relay_closes_cleanly_when_gateway_drops() { let supervisor = tokio::spawn(run_echo_supervisor(channel, channel_id)); - let relay = relay_rx.await.expect("relay duplex"); + let relay = relay_rx.await.expect("relay result").expect("relay duplex"); drop(relay); // The supervisor's inbound stream should terminate shortly after the @@ -509,7 +518,7 @@ async fn relay_sees_eof_when_supervisor_closes() { }) }; - let relay = relay_rx.await.expect("relay duplex"); + let relay = relay_rx.await.expect("relay result").expect("relay duplex"); let (mut read_half, _write_half) = tokio::io::split(relay); let mut buf = [0u8; 16]; let n = tokio::time::timeout(Duration::from_secs(5), read_half.read(&mut buf)) @@ -555,8 +564,8 @@ async fn concurrent_relays_multiplex_independently() { tokio::spawn(run_echo_supervisor(channel.clone(), id_a)); tokio::spawn(run_echo_supervisor(channel, id_b)); - let relay_a = rx_a.await.expect("relay a"); - let relay_b = rx_b.await.expect("relay b"); + let relay_a = rx_a.await.expect("relay a result").expect("relay a"); + let relay_b = rx_b.await.expect("relay b result").expect("relay b"); let (mut ra, mut wa) = tokio::io::split(relay_a); let (mut rb, mut wb) = tokio::io::split(relay_b); diff --git a/crates/openshell-server/tests/ws_tunnel_integration.rs b/crates/openshell-server/tests/ws_tunnel_integration.rs index 8212b1085..14d5e9bb7 100644 --- a/crates/openshell-server/tests/ws_tunnel_integration.rs +++ b/crates/openshell-server/tests/ws_tunnel_integration.rs @@ -45,9 +45,9 @@ use openshell_core::proto::{ GetSandboxConfigResponse, GetSandboxProviderEnvironmentRequest, GetSandboxProviderEnvironmentResponse, GetSandboxRequest, HealthRequest, HealthResponse, ListProvidersRequest, ListProvidersResponse, ListSandboxesRequest, ListSandboxesResponse, - ProviderResponse, RevokeSshSessionRequest, RevokeSshSessionResponse, SandboxResponse, - SandboxStreamEvent, ServiceStatus, SupervisorMessage, UpdateProviderRequest, - WatchSandboxRequest, + ProviderResponse, RelayFrame, RevokeSshSessionRequest, RevokeSshSessionResponse, + SandboxResponse, SandboxStreamEvent, ServiceStatus, SupervisorMessage, TcpForwardFrame, + UpdateProviderRequest, WatchSandboxRequest, open_shell_client::OpenShellClient, open_shell_server::{OpenShell, OpenShellServer}, }; @@ -373,14 +373,24 @@ impl OpenShell for TestOpenShell { Err(Status::unimplemented("not implemented in test")) } - type RelayStreamStream = ReceiverStream>; + type RelayStreamStream = ReceiverStream>; async fn relay_stream( &self, - _request: tonic::Request>, + _request: tonic::Request>, ) -> Result, Status> { Err(Status::unimplemented("not implemented in test")) } + + type ForwardTcpStream = + std::pin::Pin> + Send>>; + + async fn forward_tcp( + &self, + _request: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } } // --------------------------------------------------------------------------- diff --git a/crates/openshell-tui/src/lib.rs b/crates/openshell-tui/src/lib.rs index 8571ebbe1..b96c0abbf 100644 --- a/crates/openshell-tui/src/lib.rs +++ b/crates/openshell-tui/src/lib.rs @@ -839,10 +839,7 @@ async fn handle_shell_connect( let gateway_port_u16 = session.gateway_port as u16; let (gateway_host, gateway_port) = resolve_ssh_gateway(&session.gateway_host, gateway_port_u16, &app.endpoint); - let gateway_url = format!( - "{}://{}:{gateway_port}{}", - session.gateway_scheme, gateway_host, session.connect_path - ); + let gateway_url = format_gateway_url(&session.gateway_scheme, &gateway_host, gateway_port); // Step 4: Build the ProxyCommand using our own binary. let exe = match std::env::current_exe() { @@ -988,10 +985,7 @@ async fn handle_exec_command( let gateway_port_u16 = session.gateway_port as u16; let (gateway_host, gateway_port) = resolve_ssh_gateway(&session.gateway_host, gateway_port_u16, &app.endpoint); - let gateway_url = format!( - "{}://{}:{gateway_port}{}", - session.gateway_scheme, gateway_host, session.connect_path - ); + let gateway_url = format_gateway_url(&session.gateway_scheme, &gateway_host, gateway_port); let exe = match std::env::current_exe() { Ok(p) => p, @@ -1080,7 +1074,8 @@ async fn handle_exec_command( // SSH utility functions are shared via openshell_core::forward. use openshell_core::forward::{ - build_proxy_command, resolve_ssh_gateway, shell_escape, validate_ssh_session_response, + build_proxy_command, format_gateway_url, resolve_ssh_gateway, shell_escape, + validate_ssh_session_response, }; /// Convert a `SandboxPolicy` proto into styled ratatui lines for the policy viewer. @@ -1424,10 +1419,7 @@ async fn start_port_forwards( let gateway_port_u16 = session.gateway_port as u16; let (gateway_host, gateway_port) = resolve_ssh_gateway(&session.gateway_host, gateway_port_u16, endpoint); - let gateway_url = format!( - "{}://{}:{gateway_port}{}", - session.gateway_scheme, gateway_host, session.connect_path - ); + let gateway_url = format_gateway_url(&session.gateway_scheme, &gateway_host, gateway_port); // Build ProxyCommand. let exe = match std::env::current_exe() { diff --git a/proto/openshell.proto b/proto/openshell.proto index b0291254a..883c1576c 100644 --- a/proto/openshell.proto +++ b/proto/openshell.proto @@ -54,6 +54,9 @@ service OpenShell { // Execute a command in a ready sandbox and stream output. rpc ExecSandbox(ExecSandboxRequest) returns (stream ExecSandboxEvent); + // Forward one CLI-side TCP connection to a loopback TCP target in a sandbox. + rpc ForwardTcp(stream TcpForwardFrame) returns (stream TcpForwardFrame); + // Create a provider. rpc CreateProvider(CreateProviderRequest) returns (ProviderResponse); @@ -127,8 +130,9 @@ service OpenShell { // // The supervisor opens this stream at startup and keeps it alive for the // sandbox lifetime. The gateway uses it to coordinate relay channels for - // SSH connect and ExecSandbox. Raw SSH bytes flow over RelayStream calls - // (separate HTTP/2 streams on the same connection), not over this stream. + // SSH connect, ExecSandbox, and targetable sandbox services. Raw service + // bytes flow over RelayStream calls (separate HTTP/2 streams on the same + // connection), not over this stream. rpc ConnectSupervisor(stream SupervisorMessage) returns (stream GatewayMessage); // Raw byte relay between supervisor and gateway. @@ -137,8 +141,8 @@ service OpenShell { // on its ConnectSupervisor stream. The first RelayFrame carries a // RelayInit with the channel_id to associate the new HTTP/2 stream with // the pending relay slot on the gateway. Subsequent frames carry raw bytes in either - // direction between the gateway-side waiter (ssh_tunnel / exec handler) - // and the supervisor-side local SSH daemon bridge. + // direction between the gateway-side waiter (ForwardTcp / exec handler) + // and the supervisor-side target bridge. // // This rides the same TCP+TLS+HTTP/2 connection as ConnectSupervisor — // no new TLS handshake, no reverse HTTP CONNECT. @@ -446,11 +450,6 @@ message CreateSshSessionResponse { // Gateway scheme. Must be exactly "http" or "https". string gateway_scheme = 5; - // HTTP path for the CONNECT/upgrade endpoint. Must begin with `/`. RFC - // 3986 path charset only ([A-Za-z0-9._~!$&'()*+,;=:@/-] plus %HH). - // Must not contain `?`, `#`, whitespace, backtick, or backslash. - string connect_path = 6; - // Optional host key fingerprint. If non-empty, [A-Za-z0-9:+/=-] only. string host_key_fingerprint = 7; @@ -518,6 +517,30 @@ message ExecSandboxEvent { } } +// Initial frame for one TCP forward stream. +message TcpForwardInit { + // Sandbox id. + string sandbox_id = 1; + // Optional service identifier for audit/correlation. + string service_id = 4; + // Target the gateway should request from the supervisor. + oneof target { + SshRelayTarget ssh = 5; + TcpRelayTarget tcp = 6; + } + // Optional target-specific authorization token. SSH targets use this as the + // short-lived SSH session token issued by CreateSshSession. + string authorization_token = 7; +} + +// A single frame on the CLI-to-gateway TCP forward stream. +message TcpForwardFrame { + oneof payload { + TcpForwardInit init = 1; + bytes data = 2; + } +} + // SSH session record stored in persistence. message SshSession { // Kubernetes-style metadata (id, name, labels, timestamps, resource version). @@ -1030,10 +1053,29 @@ message GatewayHeartbeat {} // On receiving this, the supervisor should initiate a RelayStream RPC to // the gateway, sending a RelayInit in the first RelayFrame to associate // the new HTTP/2 stream with the pending relay slot. The supervisor -// bridges that stream to the local SSH daemon. +// bridges that stream to the requested local target. message RelayOpen { // Gateway-allocated channel identifier (UUID). string channel_id = 1; + // Target the supervisor should dial inside the sandbox. + // If absent, supervisors treat the relay as SSH for compatibility. + oneof target { + SshRelayTarget ssh = 2; + TcpRelayTarget tcp = 3; + } + // Optional service identifier for audit/correlation. + string service_id = 5; +} + +// Built-in SSH relay target. +message SshRelayTarget {} + +// TCP target dialed by the supervisor from inside the sandbox. +message TcpRelayTarget { + // Phase 1 accepts loopback only: 127.0.0.1, ::1, or localhost. + string host = 1; + // Target port. Must fit in u16 and be non-zero. + uint32 port = 2; } // Initial RelayStream frame sent by the supervisor to claim a pending relay. From 9ea94b645ddad445650ad0bcbee093beeaeb1451 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aaron=20Erickson=20=F0=9F=A6=9E?= Date: Mon, 11 May 2026 22:40:44 -0700 Subject: [PATCH 038/157] fix(sandbox): rewrite messaging credential placeholders (#1286) * fix(sandbox): rewrite credential placeholders in websocket text frames Signed-off-by: Aaron Erickson * fix(sandbox): harden websocket credential rewrite Signed-off-by: Aaron Erickson * feat(sandbox): add websocket l7 inspection and compression Signed-off-by: Aaron Erickson * fix(sandbox): harden websocket upgrade validation Signed-off-by: Aaron Erickson * test(sandbox): cover route-selected websocket upgrades Signed-off-by: Aaron Erickson * fix(sandbox): harden websocket negotiation parsing Signed-off-by: Aaron Erickson * test(sandbox): add websocket conformance relay matrix Signed-off-by: Aaron Erickson * test(e2e): add websocket conformance lane Signed-off-by: Aaron Erickson * fix(policy): support websocket incremental rules Signed-off-by: Aaron Erickson * feat(policy): enable websocket credential rewrite updates Signed-off-by: Aaron Erickson * fix(cli): make websocket rewrite endpoint-local Signed-off-by: Aaron Erickson * feat(sandbox): support graphql websocket policy Signed-off-by: Aaron Erickson * fix(policy): allow private IPs for websocket endpoints * feat(sandbox): rewrite REST credential placeholders Signed-off-by: Aaron Erickson * refactor(sandbox): generalize credential aliases Signed-off-by: Aaron Erickson * fix(sandbox): rewrite encoded form credentials Signed-off-by: Aaron Erickson * fix(sandbox): close websocket policy and provider alias gaps Signed-off-by: Aaron Erickson * fix(e2e): route websocket probe through host gateway * fix(e2e): stabilize websocket probe handshake * fix(e2e): exercise websocket probe through proxy * ci: remove websocket conformance workflow --------- Signed-off-by: Aaron Erickson Co-authored-by: John Myers <9696606+johntmyers@users.noreply.github.com> --- Cargo.lock | 2 + architecture/security-policy.md | 10 +- crates/openshell-cli/src/main.rs | 7 +- crates/openshell-cli/src/policy_update.rs | 360 ++- crates/openshell-driver-docker/src/lib.rs | 21 +- crates/openshell-driver-docker/src/tests.rs | 16 +- crates/openshell-policy/src/lib.rs | 87 + crates/openshell-policy/src/merge.rs | 209 +- crates/openshell-providers/src/profiles.rs | 8 + crates/openshell-sandbox/Cargo.toml | 2 + .../data/sandbox-policy.rego | 10 + crates/openshell-sandbox/src/l7/graphql.rs | 13 + crates/openshell-sandbox/src/l7/mod.rs | 479 ++- crates/openshell-sandbox/src/l7/provider.rs | 5 +- crates/openshell-sandbox/src/l7/relay.rs | 731 ++++- crates/openshell-sandbox/src/l7/rest.rs | 2827 +++++++++++++++-- crates/openshell-sandbox/src/l7/websocket.rs | 1937 +++++++++++ crates/openshell-sandbox/src/opa.rs | 252 ++ crates/openshell-sandbox/src/policy_local.rs | 2 + .../src/provider_credentials.rs | 50 +- crates/openshell-sandbox/src/proxy.rs | 831 ++++- crates/openshell-sandbox/src/secrets.rs | 597 +++- .../tests/websocket_upgrade.rs | 2 +- crates/openshell-server/src/grpc/policy.rs | 62 + docs/reference/policy-schema.mdx | 91 +- docs/sandboxes/policies.mdx | 125 +- docs/security/best-practices.mdx | 8 +- e2e/rust/Cargo.lock | 19 + e2e/rust/Cargo.toml | 7 + e2e/rust/tests/websocket_conformance.rs | 482 +++ proto/sandbox.proto | 10 +- tasks/test.toml | 7 + 32 files changed, 8647 insertions(+), 622 deletions(-) create mode 100644 crates/openshell-sandbox/src/l7/websocket.rs create mode 100644 e2e/rust/tests/websocket_conformance.rs diff --git a/Cargo.lock b/Cargo.lock index 808956cd9..05a1bdff2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3571,6 +3571,7 @@ dependencies = [ "base64 0.22.1", "bytes", "clap", + "flate2", "futures", "glob", "hex", @@ -3594,6 +3595,7 @@ dependencies = [ "serde", "serde_json", "serde_yml", + "sha1 0.10.6", "sha2 0.10.9", "temp-env", "tempfile", diff --git a/architecture/security-policy.md b/architecture/security-policy.md index e5f179dc1..5c04bebf5 100644 --- a/architecture/security-policy.md +++ b/architecture/security-policy.md @@ -43,9 +43,13 @@ with the sandbox's ephemeral CA and inspect method/path or protocol-specific metadata before forwarding. The proxy also supports credential injection on terminated HTTP streams when policy allows the endpoint. -Raw streams, HTTP upgrades, and long-lived response bodies are connection -scoped. Policy reloads affect the next connection or the next parsed HTTP -request; they do not rewrite bytes already being relayed. +Raw streams and long-lived response bodies are connection scoped. Policy +reloads affect the next connection or the next parsed HTTP request; they do not +rewrite bytes already being relayed. HTTP upgrades switch to raw relay by +default. A `protocol: rest` endpoint can opt in to +`websocket_credential_rewrite` for client-to-server WebSocket text messages +after an allowed `101` upgrade; server-to-client traffic and all other upgraded +protocols remain raw passthrough. ## Live Updates diff --git a/crates/openshell-cli/src/main.rs b/crates/openshell-cli/src/main.rs index e370d1f27..c06dda62d 100644 --- a/crates/openshell-cli/src/main.rs +++ b/crates/openshell-cli/src/main.rs @@ -289,6 +289,7 @@ const POLICY_EXAMPLES: &str = "\x1b[1mALIAS\x1b[0m $ openshell policy get my-sandbox $ openshell policy set my-sandbox --policy policy.yaml $ openshell policy update my-sandbox --add-endpoint api.github.com:443:read-only:rest:enforce + $ openshell policy update my-sandbox --add-endpoint realtime.example.com:443:read-write:websocket:enforce:websocket-credential-rewrite,allowed-ip=10.0.0.0/8 $ openshell policy update my-sandbox --add-allow 'api.github.com:443:GET:/repos/**' $ openshell policy set --global --policy policy.yaml $ openshell policy delete --global @@ -1406,7 +1407,7 @@ enum PolicyCommands { #[arg(add = ArgValueCompleter::new(completers::complete_sandbox_names))] name: Option, - /// Add or merge an endpoint: host:port[:access[:protocol[:enforcement]]]. + /// Add or merge an endpoint: host:port[:access[:protocol[:enforcement[:options]]]]. #[arg(long = "add-endpoint")] add_endpoints: Vec, @@ -1414,11 +1415,11 @@ enum PolicyCommands { #[arg(long = "remove-endpoint")] remove_endpoints: Vec, - /// Add a REST allow rule: `host:port:METHOD:path_glob`. + /// Add a REST or WebSocket method/path allow rule: `host:port:METHOD:path_glob`. #[arg(long = "add-allow")] add_allow: Vec, - /// Add a REST deny rule: `host:port:METHOD:path_glob`. + /// Add a REST or WebSocket method/path deny rule: `host:port:METHOD:path_glob`. #[arg(long = "add-deny")] add_deny: Vec, diff --git a/crates/openshell-cli/src/policy_update.rs b/crates/openshell-cli/src/policy_update.rs index 322a28df6..57656b878 100644 --- a/crates/openshell-cli/src/policy_update.rs +++ b/crates/openshell-cli/src/policy_update.rs @@ -18,6 +18,7 @@ pub struct PolicyUpdatePlan { pub preview_operations: Vec, } +#[allow(clippy::too_many_arguments)] pub fn build_policy_update_plan( add_endpoints: &[String], remove_endpoints: &[String], @@ -41,7 +42,6 @@ pub fn build_policy_update_plan( "--rule-name is only supported when exactly one --add-endpoint is provided" )); } - let mut merge_operations = Vec::new(); let mut preview_operations = Vec::new(); @@ -155,6 +155,40 @@ pub fn build_policy_update_plan( }) } +fn ensure_websocket_credential_rewrite_protocol( + spec: &str, + endpoint: &NetworkEndpoint, +) -> Result<()> { + if matches!(endpoint.protocol.as_str(), "rest" | "websocket") { + return Ok(()); + } + let protocol = if endpoint.protocol.is_empty() { + "" + } else { + endpoint.protocol.as_str() + }; + Err(miette!( + "websocket-credential-rewrite endpoint option requires --add-endpoint protocol segment to be 'rest' or 'websocket'; got '{protocol}' in '{spec}'" + )) +} + +fn ensure_request_body_credential_rewrite_protocol( + spec: &str, + endpoint: &NetworkEndpoint, +) -> Result<()> { + if endpoint.protocol == "rest" { + return Ok(()); + } + let protocol = if endpoint.protocol.is_empty() { + "" + } else { + endpoint.protocol.as_str() + }; + Err(miette!( + "request-body-credential-rewrite endpoint option requires --add-endpoint protocol segment to be 'rest'; got '{protocol}' in '{spec}'" + )) +} + fn group_allow_rules(specs: &[String]) -> Result>> { let mut grouped = BTreeMap::new(); for spec in specs { @@ -257,9 +291,9 @@ fn parse_remove_endpoint_spec(spec: &str) -> Result<(String, u32)> { fn parse_add_endpoint_spec(spec: &str) -> Result { let parts = spec.split(':').collect::>(); - if !(2..=5).contains(&parts.len()) { + if !(2..=6).contains(&parts.len()) { return Err(miette!( - "--add-endpoint expects host:port[:access[:protocol[:enforcement]]], got '{spec}'" + "--add-endpoint expects host:port[:access[:protocol[:enforcement[:options]]]], got '{spec}'" )); } @@ -269,12 +303,18 @@ fn parse_add_endpoint_spec(spec: &str) -> Result { let access = parts.get(2).copied().unwrap_or("").trim(); let protocol = parts.get(3).copied().unwrap_or("").trim(); let enforcement = parts.get(4).copied().unwrap_or("").trim(); + let options = parts.get(5).copied().unwrap_or("").trim(); if parts.len() == 3 && access.is_empty() { return Err(miette!( "--add-endpoint has an empty access segment in '{spec}'; omit it entirely if you do not need access or protocol fields" )); } + if parts.len() == 6 && options.is_empty() { + return Err(miette!( + "--add-endpoint has an empty options segment in '{spec}'; omit it entirely if you do not need endpoint options" + )); + } if !enforcement.is_empty() && protocol.is_empty() { return Err(miette!( "--add-endpoint cannot set enforcement without protocol in '{spec}'" @@ -285,9 +325,9 @@ fn parse_add_endpoint_spec(spec: &str) -> Result { "--add-endpoint access segment must be one of read-only, read-write, or full; got '{access}' in '{spec}'" )); } - if !protocol.is_empty() && !matches!(protocol, "rest" | "sql") { + if !protocol.is_empty() && !matches!(protocol, "rest" | "websocket" | "sql") { return Err(miette!( - "--add-endpoint protocol segment must be 'rest' or 'sql'; got '{protocol}' in '{spec}'" + "--add-endpoint protocol segment must be 'rest', 'websocket', or 'sql'; got '{protocol}' in '{spec}'" )); } if !enforcement.is_empty() && !matches!(enforcement, "enforce" | "audit") { @@ -296,7 +336,7 @@ fn parse_add_endpoint_spec(spec: &str) -> Result { )); } - Ok(NetworkEndpoint { + let mut endpoint = NetworkEndpoint { host, port, ports: vec![port], @@ -304,7 +344,65 @@ fn parse_add_endpoint_spec(spec: &str) -> Result { enforcement: enforcement.to_string(), access: access.to_string(), ..Default::default() - }) + }; + apply_add_endpoint_options(spec, &mut endpoint, options)?; + Ok(endpoint) +} + +fn apply_add_endpoint_options( + spec: &str, + endpoint: &mut NetworkEndpoint, + options: &str, +) -> Result<()> { + if options.is_empty() { + return Ok(()); + } + + for option in options.split(',') { + let option = option.trim(); + if option.is_empty() { + return Err(miette!( + "--add-endpoint options segment must not contain empty options in '{spec}'" + )); + } + match option { + "websocket-credential-rewrite" => { + ensure_websocket_credential_rewrite_protocol(spec, endpoint)?; + endpoint.websocket_credential_rewrite = true; + } + "request-body-credential-rewrite" => { + ensure_request_body_credential_rewrite_protocol(spec, endpoint)?; + endpoint.request_body_credential_rewrite = true; + } + _ => { + let Some(allowed_ip) = option.strip_prefix("allowed-ip=") else { + return Err(miette!( + "--add-endpoint options segment supports only 'websocket-credential-rewrite', 'request-body-credential-rewrite', and 'allowed-ip='; got '{option}' in '{spec}'" + )); + }; + let allowed_ip = allowed_ip.trim(); + if allowed_ip.is_empty() { + return Err(miette!( + "--add-endpoint allowed-ip option must include a CIDR or IP value in '{spec}'" + )); + } + if allowed_ip.contains(char::is_whitespace) { + return Err(miette!( + "--add-endpoint allowed-ip option must not contain whitespace in '{spec}'" + )); + } + if !endpoint + .allowed_ips + .iter() + .any(|existing| existing == allowed_ip) + { + endpoint.allowed_ips.push(allowed_ip.to_string()); + } + } + } + } + + Ok(()) } fn parse_host(flag: &str, spec: &str, host: &str) -> Result { @@ -352,7 +450,30 @@ fn dedup_strings(values: &[String]) -> Vec { #[cfg(test)] mod tests { - use super::build_policy_update_plan; + use super::{ + PolicyUpdatePlan, build_policy_update_plan as build_policy_update_plan_with_options, + }; + use openshell_policy::PolicyMergeOp; + + fn build_policy_update_plan( + add_endpoints: &[String], + remove_endpoints: &[String], + add_deny: &[String], + add_allow: &[String], + remove_rules: &[String], + binaries: &[String], + rule_name: Option<&str>, + ) -> miette::Result { + build_policy_update_plan_with_options( + add_endpoints, + remove_endpoints, + add_deny, + add_allow, + remove_rules, + binaries, + rule_name, + ) + } #[test] fn parse_add_endpoint_basic_l4() { @@ -392,6 +513,229 @@ mod tests { .expect("plan should build"); } + #[test] + fn parse_add_endpoint_accepts_websocket_protocol() { + let plan = build_policy_update_plan( + &["realtime.example.com:443:read-write:websocket:enforce".to_string()], + &[], + &[], + &[], + &[], + &[], + None, + ) + .expect("plan should build"); + + let PolicyMergeOp::AddRule { rule, .. } = &plan.preview_operations[0] else { + panic!("expected add-rule preview"); + }; + let endpoint = &rule.endpoints[0]; + assert_eq!(endpoint.host, "realtime.example.com"); + assert_eq!(endpoint.protocol, "websocket"); + assert_eq!(endpoint.access, "read-write"); + assert_eq!(endpoint.enforcement, "enforce"); + } + + #[test] + fn parse_add_endpoint_enables_websocket_credential_rewrite() { + let plan = build_policy_update_plan( + &["realtime.example.com:443:read-write:websocket:enforce:websocket-credential-rewrite" + .to_string()], + &[], + &[], + &[], + &[], + &[], + None, + ) + .expect("plan should build"); + + let PolicyMergeOp::AddRule { rule, .. } = &plan.preview_operations[0] else { + panic!("expected add-rule preview"); + }; + assert!(rule.endpoints[0].websocket_credential_rewrite); + } + + #[test] + fn parse_add_endpoint_enables_websocket_credential_rewrite_on_rest_compat_endpoint() { + let plan = build_policy_update_plan( + &[ + "realtime.example.com:443:read-write:rest:enforce:websocket-credential-rewrite" + .to_string(), + ], + &[], + &[], + &[], + &[], + &[], + None, + ) + .expect("plan should build"); + + let PolicyMergeOp::AddRule { rule, .. } = &plan.preview_operations[0] else { + panic!("expected add-rule preview"); + }; + assert!(rule.endpoints[0].websocket_credential_rewrite); + } + + #[test] + fn parse_add_endpoint_enables_request_body_credential_rewrite_on_rest_endpoint() { + let plan = build_policy_update_plan( + &[ + "api.example.com:443:read-write:rest:enforce:request-body-credential-rewrite" + .to_string(), + ], + &[], + &[], + &[], + &[], + &[], + None, + ) + .expect("plan should build"); + + let PolicyMergeOp::AddRule { rule, .. } = &plan.preview_operations[0] else { + panic!("expected add-rule preview"); + }; + let endpoint = &rule.endpoints[0]; + assert_eq!(endpoint.protocol, "rest"); + assert!(endpoint.request_body_credential_rewrite); + } + + #[test] + fn parse_add_endpoint_merges_allowed_ips_with_websocket_options() { + let plan = build_policy_update_plan( + &[ + "realtime.example.com:443:read-write:websocket:enforce:websocket-credential-rewrite,allowed-ip=10.0.0.0/8,allowed-ip=172.16.0.0/12,allowed-ip=10.0.0.0/8" + .to_string(), + ], + &[], + &[], + &[], + &[], + &[], + None, + ) + .expect("plan should build"); + + let PolicyMergeOp::AddRule { rule, .. } = &plan.preview_operations[0] else { + panic!("expected add-rule preview"); + }; + let endpoint = &rule.endpoints[0]; + assert!(endpoint.websocket_credential_rewrite); + assert_eq!( + endpoint.allowed_ips, + vec!["10.0.0.0/8".to_string(), "172.16.0.0/12".to_string()] + ); + } + + #[test] + fn parse_add_endpoint_accepts_allowed_ip_on_rest_endpoint() { + let plan = build_policy_update_plan( + &["api.example.com:443:read-write:rest:enforce:allowed-ip=192.168.0.0/16".to_string()], + &[], + &[], + &[], + &[], + &[], + None, + ) + .expect("plan should build"); + + let PolicyMergeOp::AddRule { rule, .. } = &plan.preview_operations[0] else { + panic!("expected add-rule preview"); + }; + assert_eq!(rule.endpoints[0].allowed_ips, vec!["192.168.0.0/16"]); + } + + #[test] + fn parse_add_endpoint_rejects_empty_allowed_ip() { + let error = build_policy_update_plan( + &["api.example.com:443:read-write:rest:enforce:allowed-ip=".to_string()], + &[], + &[], + &[], + &[], + &[], + None, + ) + .expect_err("plan should fail"); + assert!(error.to_string().contains("allowed-ip option")); + } + + #[test] + fn websocket_credential_rewrite_rejects_l4_endpoint() { + let error = build_policy_update_plan( + &["realtime.example.com:443::::websocket-credential-rewrite".to_string()], + &[], + &[], + &[], + &[], + &[], + None, + ) + .expect_err("plan should fail"); + assert!(error.to_string().contains("protocol segment")); + } + + #[test] + fn request_body_credential_rewrite_rejects_non_rest_endpoint() { + let error = build_policy_update_plan( + &[ + "realtime.example.com:443:read-write:websocket:enforce:request-body-credential-rewrite" + .to_string(), + ], + &[], + &[], + &[], + &[], + &[], + None, + ) + .expect_err("plan should fail"); + + assert!(error.to_string().contains("protocol segment")); + assert!(error.to_string().contains("'rest'")); + } + + #[test] + fn parse_add_endpoint_rejects_unknown_options() { + let error = build_policy_update_plan( + &["realtime.example.com:443:read-write:websocket:enforce:future-option".to_string()], + &[], + &[], + &[], + &[], + &[], + None, + ) + .expect_err("plan should fail"); + assert!(error.to_string().contains("options segment")); + } + + #[test] + fn parse_add_allow_accepts_websocket_text_method() { + let plan = build_policy_update_plan( + &[], + &[], + &[], + &["realtime.example.com:443:websocket_text:/v1/messages/**".to_string()], + &[], + &[], + None, + ) + .expect("plan should build"); + + let PolicyMergeOp::AddAllowRules { host, port, rules } = &plan.preview_operations[0] else { + panic!("expected add-allow preview"); + }; + assert_eq!(host, "realtime.example.com"); + assert_eq!(*port, 443); + let allow = rules[0].allow.as_ref().expect("allow rule"); + assert_eq!(allow.method, "WEBSOCKET_TEXT"); + assert_eq!(allow.path, "/v1/messages/**"); + } + #[test] fn parse_add_deny_rejects_empty_method() { let error = build_policy_update_plan( diff --git a/crates/openshell-driver-docker/src/lib.rs b/crates/openshell-driver-docker/src/lib.rs index cff4c57b9..6059596ab 100644 --- a/crates/openshell-driver-docker/src/lib.rs +++ b/crates/openshell-driver-docker/src/lib.rs @@ -1103,7 +1103,7 @@ fn docker_gateway_route( }; } - if is_compat_docker_runtime(info) { + if uses_host_gateway_alias(info) { DockerGatewayRoute::HostGateway } else { DockerGatewayRoute::Bridge { @@ -1113,15 +1113,15 @@ fn docker_gateway_route( } } -/// Detect Docker Desktop and behaviourally compatible runtimes — Colima, -/// Lima, Rancher Desktop, and `OrbStack` — that share Docker Desktop's -/// routing constraint: the bridge gateway IP is reachable from inside -/// containers but not from the `OpenShell` server process running on the -/// host, so callbacks must traverse `host-gateway`. +/// Detect Docker Desktop and behaviourally compatible runtimes - Colima, +/// Lima, Rancher Desktop, and `OrbStack` - that share Docker Desktop's routing +/// constraint: the bridge gateway IP is reachable from inside containers but +/// not from the `OpenShell` server process running on the host, so callbacks +/// must traverse `host-gateway`. /// /// Each runtime is detected via the daemon's reported OS string or hostname, /// supplemented by labels where the runtime publishes them. -fn is_compat_docker_runtime(info: &SystemInfo) -> bool { +fn uses_host_gateway_alias(info: &SystemInfo) -> bool { let operating_system = info .operating_system .as_deref() @@ -1159,9 +1159,10 @@ fn docker_extra_hosts(route: &DockerGatewayRoute) -> Vec { format!("{HOST_DOCKER_INTERNAL}:{host_alias_ip}"), format!("{HOST_OPENSHELL_INTERNAL}:{host_alias_ip}"), ], - DockerGatewayRoute::HostGateway => { - vec![format!("{HOST_OPENSHELL_INTERNAL}:host-gateway")] - } + DockerGatewayRoute::HostGateway => vec![ + format!("{HOST_DOCKER_INTERNAL}:host-gateway"), + format!("{HOST_OPENSHELL_INTERNAL}:host-gateway"), + ], } } diff --git a/crates/openshell-driver-docker/src/tests.rs b/crates/openshell-driver-docker/src/tests.rs index 2e43e96e9..df68d39d6 100644 --- a/crates/openshell-driver-docker/src/tests.rs +++ b/crates/openshell-driver-docker/src/tests.rs @@ -160,27 +160,37 @@ fn docker_gateway_route_uses_host_gateway_for_docker_desktop() { ); assert_eq!( docker_extra_hosts(&DockerGatewayRoute::HostGateway), - vec!["host.openshell.internal:host-gateway".to_string()] + vec![ + "host.docker.internal:host-gateway".to_string(), + "host.openshell.internal:host-gateway".to_string() + ] ); } #[test] fn docker_gateway_route_uses_host_gateway_for_colima() { let info = SystemInfo { - operating_system: Some("Ubuntu 24.04 LTS".to_string()), name: Some("colima".to_string()), + operating_system: Some("Ubuntu 24.04.4 LTS".to_string()), ..Default::default() }; assert_eq!( docker_gateway_route( &info, - IpAddr::V4(Ipv4Addr::new(172, 18, 0, 1)), + IpAddr::V4(Ipv4Addr::new(172, 20, 0, 1)), DEFAULT_SERVER_PORT, None, ), DockerGatewayRoute::HostGateway ); + assert_eq!( + docker_extra_hosts(&DockerGatewayRoute::HostGateway), + vec![ + "host.docker.internal:host-gateway".to_string(), + "host.openshell.internal:host-gateway".to_string() + ] + ); } #[test] diff --git a/crates/openshell-policy/src/lib.rs b/crates/openshell-policy/src/lib.rs index 61df0aadb..908450111 100644 --- a/crates/openshell-policy/src/lib.rs +++ b/crates/openshell-policy/src/lib.rs @@ -120,6 +120,15 @@ struct NetworkEndpointDef { /// Defaults to false (strict). #[serde(default, skip_serializing_if = "std::ops::Not::not")] allow_encoded_slash: bool, + /// When true, client-to-server WebSocket text messages on this REST + /// endpoint rewrite credential placeholders after an allowed 101 upgrade. + /// Defaults to false. + #[serde(default, skip_serializing_if = "std::ops::Not::not")] + websocket_credential_rewrite: bool, + /// When true, supported textual REST request bodies rewrite credential + /// placeholders before forwarding upstream. Defaults to false. + #[serde(default, skip_serializing_if = "std::ops::Not::not")] + request_body_credential_rewrite: bool, #[serde(default, skip_serializing_if = "String::is_empty")] persisted_queries: String, #[serde(default, skip_serializing_if = "BTreeMap::is_empty")] @@ -317,6 +326,8 @@ fn to_proto(raw: PolicyFile) -> SandboxPolicy { }) .collect(), allow_encoded_slash: e.allow_encoded_slash, + websocket_credential_rewrite: e.websocket_credential_rewrite, + request_body_credential_rewrite: e.request_body_credential_rewrite, persisted_queries: e.persisted_queries, graphql_persisted_queries: e .graphql_persisted_queries @@ -480,6 +491,8 @@ fn from_proto(policy: &SandboxPolicy) -> PolicyFile { }) .collect(), allow_encoded_slash: e.allow_encoded_slash, + websocket_credential_rewrite: e.websocket_credential_rewrite, + request_body_credential_rewrite: e.request_body_credential_rewrite, persisted_queries: e.persisted_queries.clone(), graphql_persisted_queries: e .graphql_persisted_queries @@ -1656,6 +1669,80 @@ network_policies: assert_eq!(ep.deny_rules[0].fields, vec!["deleteRepository"]); } + #[test] + fn round_trip_preserves_websocket_credential_rewrite() { + let yaml = r" +version: 1 +network_policies: + discord_gateway: + name: discord_gateway + endpoints: + - host: gateway.example.com + port: 443 + protocol: rest + enforcement: enforce + access: full + websocket_credential_rewrite: true + binaries: + - path: /usr/bin/node +"; + let proto1 = parse_sandbox_policy(yaml).expect("parse failed"); + let yaml_out = serialize_sandbox_policy(&proto1).expect("serialize failed"); + let proto2 = parse_sandbox_policy(&yaml_out).expect("re-parse failed"); + + let ep = &proto2.network_policies["discord_gateway"].endpoints[0]; + assert_eq!(ep.protocol, "rest"); + assert!(ep.websocket_credential_rewrite); + assert!(yaml_out.contains("websocket_credential_rewrite: true")); + } + + #[test] + fn round_trip_preserves_request_body_credential_rewrite() { + let yaml = r" +version: 1 +network_policies: + slack_api: + name: slack_api + endpoints: + - host: slack.com + port: 443 + protocol: rest + enforcement: enforce + access: read-write + request_body_credential_rewrite: true + binaries: + - path: /usr/bin/node +"; + let proto1 = parse_sandbox_policy(yaml).expect("parse failed"); + let yaml_out = serialize_sandbox_policy(&proto1).expect("serialize failed"); + let proto2 = parse_sandbox_policy(&yaml_out).expect("re-parse failed"); + + let ep = &proto2.network_policies["slack_api"].endpoints[0]; + assert_eq!(ep.protocol, "rest"); + assert!(ep.request_body_credential_rewrite); + assert!(yaml_out.contains("request_body_credential_rewrite: true")); + } + + #[test] + fn websocket_credential_rewrite_defaults_false() { + let yaml = r" +version: 1 +network_policies: + gateway: + endpoints: + - host: gateway.example.com + port: 443 + protocol: rest + access: full + binaries: + - path: /usr/bin/node +"; + let proto = parse_sandbox_policy(yaml).expect("parse failed"); + let ep = &proto.network_policies["gateway"].endpoints[0]; + assert!(!ep.websocket_credential_rewrite); + assert!(!ep.request_body_credential_rewrite); + } + #[test] fn parse_rejects_unknown_fields_in_deny_rule() { let yaml = r" diff --git a/crates/openshell-policy/src/merge.rs b/crates/openshell-policy/src/merge.rs index 7a5dec916..d99d9c216 100644 --- a/crates/openshell-policy/src/merge.rs +++ b/crates/openshell-policy/src/merge.rs @@ -184,7 +184,7 @@ impl std::fmt::Display for PolicyMergeError { protocol, } => write!( f, - "endpoint {host}:{port} uses unsupported protocol '{protocol}'; this operation currently supports only protocol 'rest'" + "endpoint {host}:{port} uses unsupported protocol '{protocol}'; this operation currently supports only protocol 'rest' or 'websocket'" ), Self::EndpointHasNoAllowBase { host, port } => write!( f, @@ -265,7 +265,7 @@ fn apply_operation( port: *port, } })?; - ensure_rest_endpoint(endpoint, host, *port)?; + ensure_method_path_endpoint(endpoint, host, *port)?; if endpoint.access.is_empty() && endpoint.rules.is_empty() { return Err(PolicyMergeError::EndpointHasNoAllowBase { host: host.clone(), @@ -281,7 +281,7 @@ fn apply_operation( port: *port, } })?; - ensure_rest_endpoint(endpoint, host, *port)?; + ensure_method_path_endpoint(endpoint, host, *port)?; expand_existing_access(endpoint, host, *port, warnings)?; append_unique_l7_rules(&mut endpoint.rules, rules); } @@ -462,6 +462,9 @@ fn merge_endpoint( append_unique_deny_rules(&mut existing.deny_rules, &incoming.deny_rules); append_unique_strings(&mut existing.allowed_ips, &incoming.allowed_ips); + existing.allow_encoded_slash |= incoming.allow_encoded_slash; + existing.websocket_credential_rewrite |= incoming.websocket_credential_rewrite; + existing.request_body_credential_rewrite |= incoming.request_body_credential_rewrite; normalize_endpoint(existing); Ok(()) } @@ -568,7 +571,7 @@ fn endpoint_matches_host_port(endpoint: &NetworkEndpoint, host: &str, port: u32) endpoint.host.eq_ignore_ascii_case(host) && canonical_ports(endpoint).contains(&port) } -fn ensure_rest_endpoint( +fn ensure_method_path_endpoint( endpoint: &NetworkEndpoint, host: &str, port: u32, @@ -579,7 +582,7 @@ fn ensure_rest_endpoint( port, }); } - if endpoint.protocol != "rest" { + if !matches!(endpoint.protocol.as_str(), "rest" | "websocket") { return Err(PolicyMergeError::UnsupportedEndpointProtocol { host: host.to_string(), port, @@ -600,12 +603,13 @@ fn expand_existing_access( } let access = endpoint.access.clone(); - let expanded = - expand_access_preset(&access).ok_or_else(|| PolicyMergeError::UnsupportedAccessPreset { + let expanded = expand_access_preset(&endpoint.protocol, &access).ok_or_else(|| { + PolicyMergeError::UnsupportedAccessPreset { host: host.to_string(), port, access: access.clone(), - })?; + } + })?; endpoint.access.clear(); append_unique_l7_rules(&mut endpoint.rules, &expanded); warnings.push(PolicyMergeWarning::ExpandedAccessPreset { @@ -616,11 +620,13 @@ fn expand_existing_access( Ok(()) } -fn expand_access_preset(access: &str) -> Option> { - let methods = match access { - "read-only" => vec!["GET", "HEAD", "OPTIONS"], - "read-write" => vec!["GET", "HEAD", "OPTIONS", "POST", "PUT", "PATCH"], - "full" => vec!["*"], +fn expand_access_preset(protocol: &str, access: &str) -> Option> { + let methods = match (protocol, access) { + (_, "full") => vec!["*"], + ("websocket", "read-only") => vec!["GET"], + ("websocket", "read-write") => vec!["GET", "WEBSOCKET_TEXT"], + (_, "read-only") => vec!["GET", "HEAD", "OPTIONS"], + (_, "read-write") => vec!["GET", "HEAD", "OPTIONS", "POST", "PUT", "PATCH"], _ => return None, }; @@ -870,6 +876,96 @@ mod tests { assert_eq!(rule.binaries.len(), 2); } + #[test] + fn add_rule_merges_websocket_credential_rewrite_flag() { + let mut policy = restrictive_default_policy(); + policy.network_policies.insert( + "existing".to_string(), + NetworkPolicyRule { + name: "existing".to_string(), + endpoints: vec![NetworkEndpoint { + host: "realtime.example.com".to_string(), + port: 443, + ports: vec![443], + protocol: "websocket".to_string(), + access: "read-write".to_string(), + ..Default::default() + }], + ..Default::default() + }, + ); + + let incoming = NetworkPolicyRule { + name: "incoming".to_string(), + endpoints: vec![NetworkEndpoint { + host: "realtime.example.com".to_string(), + port: 443, + ports: vec![443], + protocol: "websocket".to_string(), + websocket_credential_rewrite: true, + ..Default::default() + }], + ..Default::default() + }; + + let result = merge_policy( + policy, + &[PolicyMergeOp::AddRule { + rule_name: "allow_realtime_example_com_443".to_string(), + rule: incoming, + }], + ) + .expect("merge should succeed"); + + let endpoint = &result.policy.network_policies["existing"].endpoints[0]; + assert!(endpoint.websocket_credential_rewrite); + } + + #[test] + fn add_rule_merges_request_body_credential_rewrite_flag() { + let mut policy = restrictive_default_policy(); + policy.network_policies.insert( + "existing".to_string(), + NetworkPolicyRule { + name: "existing".to_string(), + endpoints: vec![NetworkEndpoint { + host: "slack.com".to_string(), + port: 443, + ports: vec![443], + protocol: "rest".to_string(), + access: "read-write".to_string(), + ..Default::default() + }], + ..Default::default() + }, + ); + + let incoming = NetworkPolicyRule { + name: "incoming".to_string(), + endpoints: vec![NetworkEndpoint { + host: "slack.com".to_string(), + port: 443, + ports: vec![443], + protocol: "rest".to_string(), + request_body_credential_rewrite: true, + ..Default::default() + }], + ..Default::default() + }; + + let result = merge_policy( + policy, + &[PolicyMergeOp::AddRule { + rule_name: "allow_slack_com_443".to_string(), + rule: incoming, + }], + ) + .expect("merge should succeed"); + + let endpoint = &result.policy.network_policies["existing"].endpoints[0]; + assert!(endpoint.request_body_credential_rewrite); + } + #[test] fn add_allow_expands_access_preset() { let mut policy = restrictive_default_policy(); @@ -909,7 +1005,92 @@ mod tests { } #[test] - fn add_deny_requires_rest_protocol() { + fn add_allow_expands_websocket_access_preset() { + let mut policy = restrictive_default_policy(); + policy.network_policies.insert( + "realtime".to_string(), + NetworkPolicyRule { + name: "realtime".to_string(), + endpoints: vec![NetworkEndpoint { + host: "realtime.example.com".to_string(), + port: 443, + ports: vec![443], + protocol: "websocket".to_string(), + access: "read-write".to_string(), + ..Default::default() + }], + ..Default::default() + }, + ); + + let result = merge_policy( + policy, + &[PolicyMergeOp::AddAllowRules { + host: "realtime.example.com".to_string(), + port: 443, + rules: vec![rest_rule("WEBSOCKET_TEXT", "/rooms/private/**")], + }], + ) + .expect("merge should succeed"); + + let endpoint = &result.policy.network_policies["realtime"].endpoints[0]; + assert!(endpoint.access.is_empty()); + assert_eq!(endpoint.rules.len(), 3); + assert!(endpoint.rules.contains(&rest_rule("GET", "**"))); + assert!(endpoint.rules.contains(&rest_rule("WEBSOCKET_TEXT", "**"))); + assert!( + endpoint + .rules + .contains(&rest_rule("WEBSOCKET_TEXT", "/rooms/private/**")) + ); + assert!(!endpoint.rules.contains(&rest_rule("POST", "**"))); + assert!(result.warnings.iter().any(|warning| matches!( + warning, + PolicyMergeWarning::ExpandedAccessPreset { access, .. } if access == "read-write" + ))); + } + + #[test] + fn add_deny_accepts_websocket_protocol() { + let mut policy = restrictive_default_policy(); + policy.network_policies.insert( + "realtime".to_string(), + NetworkPolicyRule { + name: "realtime".to_string(), + endpoints: vec![NetworkEndpoint { + host: "realtime.example.com".to_string(), + port: 443, + ports: vec![443], + protocol: "websocket".to_string(), + access: "read-write".to_string(), + ..Default::default() + }], + ..Default::default() + }, + ); + + let result = merge_policy( + policy, + &[PolicyMergeOp::AddDenyRules { + host: "realtime.example.com".to_string(), + port: 443, + deny_rules: vec![L7DenyRule { + method: "WEBSOCKET_TEXT".to_string(), + path: "/admin/**".to_string(), + ..Default::default() + }], + }], + ) + .expect("merge should succeed"); + + let endpoint = &result.policy.network_policies["realtime"].endpoints[0]; + assert_eq!(endpoint.deny_rules.len(), 1); + assert_eq!(endpoint.deny_rules[0].method, "WEBSOCKET_TEXT"); + assert_eq!(endpoint.deny_rules[0].path, "/admin/**"); + } + + #[test] + fn add_deny_rejects_unsupported_protocol() { let mut policy = restrictive_default_policy(); policy.network_policies.insert( "db".to_string(), diff --git a/crates/openshell-providers/src/profiles.rs b/crates/openshell-providers/src/profiles.rs index 8c3f247cf..588e77702 100644 --- a/crates/openshell-providers/src/profiles.rs +++ b/crates/openshell-providers/src/profiles.rs @@ -114,6 +114,10 @@ pub struct EndpointProfile { pub deny_rules: Vec, #[serde(default, skip_serializing_if = "is_false")] pub allow_encoded_slash: bool, + #[serde(default, skip_serializing_if = "is_false")] + pub websocket_credential_rewrite: bool, + #[serde(default, skip_serializing_if = "is_false")] + pub request_body_credential_rewrite: bool, #[serde(default, skip_serializing_if = "String::is_empty")] pub persisted_queries: String, #[serde(default, skip_serializing_if = "HashMap::is_empty")] @@ -414,6 +418,8 @@ fn endpoint_to_proto(endpoint: &EndpointProfile) -> NetworkEndpoint { ports: endpoint.ports.clone(), deny_rules: endpoint.deny_rules.iter().map(deny_rule_to_proto).collect(), allow_encoded_slash: endpoint.allow_encoded_slash, + websocket_credential_rewrite: endpoint.websocket_credential_rewrite, + request_body_credential_rewrite: endpoint.request_body_credential_rewrite, persisted_queries: endpoint.persisted_queries.clone(), graphql_persisted_queries: endpoint .graphql_persisted_queries @@ -442,6 +448,8 @@ fn endpoint_from_proto(endpoint: &NetworkEndpoint) -> EndpointProfile { .map(deny_rule_from_proto) .collect(), allow_encoded_slash: endpoint.allow_encoded_slash, + websocket_credential_rewrite: endpoint.websocket_credential_rewrite, + request_body_credential_rewrite: endpoint.request_body_credential_rewrite, persisted_queries: endpoint.persisted_queries.clone(), graphql_persisted_queries: endpoint .graphql_persisted_queries diff --git a/crates/openshell-sandbox/Cargo.toml b/crates/openshell-sandbox/Cargo.toml index 4e07521ce..29919ede4 100644 --- a/crates/openshell-sandbox/Cargo.toml +++ b/crates/openshell-sandbox/Cargo.toml @@ -58,6 +58,8 @@ uuid = { workspace = true } # Encoding base64 = { workspace = true } +flate2 = "1" +sha1 = "0.10" # IP network / CIDR parsing ipnet = "2" diff --git a/crates/openshell-sandbox/data/sandbox-policy.rego b/crates/openshell-sandbox/data/sandbox-policy.rego index 9fa820627..a8e4affce 100644 --- a/crates/openshell-sandbox/data/sandbox-policy.rego +++ b/crates/openshell-sandbox/data/sandbox-policy.rego @@ -260,6 +260,16 @@ request_denied_for_endpoint(request, endpoint) if { not graphql_request_allowed(request, endpoint) } +# The same authority applies when a WebSocket endpoint opts into GraphQL +# operation policy. Once the relay classifies a client text message as a +# GraphQL-over-WebSocket operation, generic WEBSOCKET_TEXT rules must not bypass +# operation_type / operation_name / fields policy. +request_denied_for_endpoint(request, endpoint) if { + endpoint.protocol == "websocket" + is_object(request.graphql) + not graphql_request_allowed(request, endpoint) +} + # Deny query matching: fail-closed semantics. # If no query rules on the deny rule, match unconditionally (any query params). # If query rules present, trigger the deny if ANY value for a configured key diff --git a/crates/openshell-sandbox/src/l7/graphql.rs b/crates/openshell-sandbox/src/l7/graphql.rs index db91ecb45..5d0746d01 100644 --- a/crates/openshell-sandbox/src/l7/graphql.rs +++ b/crates/openshell-sandbox/src/l7/graphql.rs @@ -78,6 +78,19 @@ pub fn classify_request(request: &L7Request, body: &[u8]) -> GraphqlRequestInfo } } +pub fn classify_json_envelope_value(value: &Value) -> GraphqlRequestInfo { + match classify_json_envelope(value) { + Ok(operations) => GraphqlRequestInfo { + operations, + error: None, + }, + Err(err) => GraphqlRequestInfo { + operations: Vec::new(), + error: Some(err), + }, + } +} + fn classify_request_inner( request: &L7Request, body: &[u8], diff --git a/crates/openshell-sandbox/src/l7/mod.rs b/crates/openshell-sandbox/src/l7/mod.rs index 5301ac4d5..09278b4f8 100644 --- a/crates/openshell-sandbox/src/l7/mod.rs +++ b/crates/openshell-sandbox/src/l7/mod.rs @@ -15,11 +15,13 @@ pub mod provider; pub mod relay; pub mod rest; pub mod tls; +pub(crate) mod websocket; /// Application-layer protocol for L7 inspection. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum L7Protocol { Rest, + Websocket, Graphql, Sql, } @@ -28,6 +30,7 @@ impl L7Protocol { pub fn parse(s: &str) -> Option { match s.to_ascii_lowercase().as_str() { "rest" => Some(Self::Rest), + "websocket" => Some(Self::Websocket), "graphql" => Some(Self::Graphql), "sql" => Some(Self::Sql), _ => None, @@ -58,6 +61,10 @@ pub enum EnforcementMode { } /// L7 configuration for an endpoint, extracted from policy data. +#[allow( + clippy::struct_excessive_bools, + reason = "Endpoint config mirrors independent policy schema toggles." +)] #[derive(Debug, Clone)] pub struct L7EndpointConfig { pub protocol: L7Protocol, @@ -72,6 +79,15 @@ pub struct L7EndpointConfig { /// rather than rejected at the parser. Needed by upstreams like GitLab /// that embed `%2F` in namespaced project paths. Defaults to false. pub allow_encoded_slash: bool, + /// Opt-in rewrite of credential placeholders in client-to-server + /// WebSocket text messages after an allowed HTTP 101 upgrade. + pub websocket_credential_rewrite: bool, + /// Opt-in rewrite of credential placeholders in supported textual REST + /// request bodies before forwarding upstream. + pub request_body_credential_rewrite: bool, + /// When true, client-to-server GraphQL-over-WebSocket operation messages + /// are classified with the same operation policy used by GraphQL-over-HTTP. + pub websocket_graphql_policy: bool, } /// Result of an L7 policy decision for a single request. @@ -138,6 +154,12 @@ pub fn parse_l7_config(val: ®orus::Value) -> Option { }; let allow_encoded_slash = get_object_bool(val, "allow_encoded_slash").unwrap_or(false); + let websocket_credential_rewrite = + get_object_bool(val, "websocket_credential_rewrite").unwrap_or(false); + let request_body_credential_rewrite = + get_object_bool(val, "request_body_credential_rewrite").unwrap_or(false); + let websocket_graphql_policy = + protocol == L7Protocol::Websocket && endpoint_has_graphql_policy(val); let graphql_max_body_bytes = get_object_u64(val, "graphql_max_body_bytes") .and_then(|v| usize::try_from(v).ok()) .filter(|v| *v > 0) @@ -150,6 +172,9 @@ pub fn parse_l7_config(val: ®orus::Value) -> Option { enforcement, graphql_max_body_bytes, allow_encoded_slash, + websocket_credential_rewrite, + request_body_credential_rewrite, + websocket_graphql_policy, }) } @@ -231,6 +256,60 @@ fn get_object_str(val: ®orus::Value, key: &str) -> Option { } } +fn endpoint_has_graphql_policy(val: ®orus::Value) -> bool { + has_non_empty_object_field(val, "graphql_persisted_queries") + || has_graphql_persisted_query_mode(val) + || rules_have_graphql_policy(val, "rules", true) + || rules_have_graphql_policy(val, "deny_rules", false) +} + +fn rules_have_graphql_policy(val: ®orus::Value, key: &str, allow_wrapped: bool) -> bool { + let Some(regorus::Value::Array(rules)) = get_object_value(val, key) else { + return false; + }; + rules.iter().any(|rule| { + let rule = if allow_wrapped { + get_object_value(rule, "allow").unwrap_or(rule) + } else { + rule + }; + has_graphql_rule_fields(rule) + }) +} + +fn has_graphql_rule_fields(val: ®orus::Value) -> bool { + has_non_empty_string_field(val, "operation_type") + || has_non_empty_string_field(val, "operation_name") + || has_non_empty_array_field(val, "fields") +} + +fn has_non_empty_string_field(val: ®orus::Value, key: &str) -> bool { + matches!(get_object_value(val, key), Some(regorus::Value::String(s)) if !s.is_empty()) +} + +fn has_non_empty_array_field(val: ®orus::Value, key: &str) -> bool { + matches!(get_object_value(val, key), Some(regorus::Value::Array(values)) if !values.is_empty()) +} + +fn has_non_empty_object_field(val: ®orus::Value, key: &str) -> bool { + matches!(get_object_value(val, key), Some(regorus::Value::Object(values)) if !values.is_empty()) +} + +fn has_graphql_persisted_query_mode(val: ®orus::Value) -> bool { + matches!( + get_object_value(val, "persisted_queries"), + Some(regorus::Value::String(mode)) if !mode.is_empty() && mode.as_ref() != "deny" + ) +} + +fn get_object_value<'a>(val: &'a regorus::Value, key: &str) -> Option<&'a regorus::Value> { + let key_val = regorus::Value::String(key.into()); + match val { + regorus::Value::Object(map) => map.get(&key_val), + _ => None, + } +} + /// Check a glob pattern for obvious syntax issues. /// /// Returns `Some(warning_message)` if the pattern looks malformed. @@ -353,6 +432,45 @@ fn validate_graphql_rule( validate_graphql_fields(errors, warnings, loc, rule.get("fields")); } +fn json_rule_has_graphql_fields(rule: &serde_json::Value) -> bool { + rule.get("operation_type") + .and_then(|v| v.as_str()) + .is_some_and(|v| !v.is_empty()) + || rule + .get("operation_name") + .and_then(|v| v.as_str()) + .is_some_and(|v| !v.is_empty()) + || rule.get("fields").is_some() +} + +fn json_rule_has_transport_fields(rule: &serde_json::Value) -> bool { + rule.get("method").is_some() || rule.get("path").is_some() || rule.get("query").is_some() +} + +fn json_endpoint_has_graphql_policy(ep: &serde_json::Value) -> bool { + ep.get("graphql_persisted_queries") + .and_then(|v| v.as_object()) + .is_some_and(|v| !v.is_empty()) + || ep + .get("persisted_queries") + .and_then(|v| v.as_str()) + .is_some_and(|v| !v.is_empty() && v != "deny") + || ep + .get("rules") + .and_then(|v| v.as_array()) + .is_some_and(|rules| { + rules.iter().any(|rule| { + rule.get("allow") + .or(Some(rule)) + .is_some_and(json_rule_has_graphql_fields) + }) + }) + || ep + .get("deny_rules") + .and_then(|v| v.as_array()) + .is_some_and(|rules| rules.iter().any(json_rule_has_graphql_fields)) +} + /// Validate L7 policy configuration in the loaded OPA data. /// /// Returns a list of errors and warnings. Errors should prevent sandbox startup; @@ -382,6 +500,8 @@ pub fn validate_l7_policies(data_json: &serde_json::Value) -> (Vec, Vec< .get("rules") .and_then(|v| v.as_array()) .is_some_and(|a| !a.is_empty()); + let websocket_has_graphql_policy = + protocol == "websocket" && json_endpoint_has_graphql_policy(ep); let host = ep.get("host").and_then(|v| v.as_str()).unwrap_or(""); let endpoint_path = ep.get("path").and_then(|v| v.as_str()).unwrap_or(""); @@ -462,7 +582,7 @@ pub fn validate_l7_policies(data_json: &serde_json::Value) -> (Vec, Vec< if !protocol.is_empty() && L7Protocol::parse(protocol).is_none() { errors.push(format!( - "{loc}: unknown protocol '{protocol}' (expected rest, graphql, or sql)" + "{loc}: unknown protocol '{protocol}' (expected rest, websocket, graphql, or sql)" )); } @@ -489,12 +609,36 @@ pub fn validate_l7_policies(data_json: &serde_json::Value) -> (Vec, Vec< } if protocol != "graphql" + && protocol != "websocket" && (ep.get("persisted_queries").is_some() || ep.get("graphql_persisted_queries").is_some() || ep.get("graphql_max_body_bytes").is_some()) { warnings.push(format!( - "{loc}: GraphQL-specific endpoint fields are ignored unless protocol is graphql" + "{loc}: GraphQL-specific endpoint fields are ignored unless protocol is graphql or websocket" + )); + } + + if ep + .get("websocket_credential_rewrite") + .and_then(serde_json::Value::as_bool) + .unwrap_or(false) + && protocol != "rest" + && protocol != "websocket" + { + warnings.push(format!( + "{loc}: websocket_credential_rewrite is ignored unless protocol is rest or websocket" + )); + } + + if ep + .get("request_body_credential_rewrite") + .and_then(serde_json::Value::as_bool) + .unwrap_or(false) + && protocol != "rest" + { + warnings.push(format!( + "{loc}: request_body_credential_rewrite is ignored unless protocol is rest" )); } @@ -574,14 +718,13 @@ pub fn validate_l7_policies(data_json: &serde_json::Value) -> (Vec, Vec< // Validate method if let Some(method) = deny_rule.get("method").and_then(|m| m.as_str()) && !method.is_empty() - && protocol == "rest" + && (protocol == "rest" || protocol == "websocket") { - let valid_methods = [ - "GET", "HEAD", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "*", - ]; + let valid_methods = valid_methods_for_protocol(protocol); if !valid_methods.contains(&method.to_ascii_uppercase().as_str()) { warnings.push(format!( - "{deny_loc}: Unknown HTTP method '{method}'. Standard methods: GET, HEAD, POST, PUT, DELETE, PATCH, OPTIONS." + "{deny_loc}: Unknown HTTP/WebSocket method '{method}'. Standard methods: {}." + , valid_methods.join(", ") )); } } @@ -701,7 +844,17 @@ pub fn validate_l7_policies(data_json: &serde_json::Value) -> (Vec, Vec< .push(format!("{deny_loc}: command is for SQL protocol, not REST")); } - if protocol == "graphql" { + let deny_has_graphql = json_rule_has_graphql_fields(deny_rule); + if protocol == "websocket" + && deny_has_graphql + && json_rule_has_transport_fields(deny_rule) + { + errors.push(format!( + "{deny_loc}: WebSocket GraphQL deny rules must not combine method/path/query with operation_type/operation_name/fields" + )); + } + + if protocol == "graphql" || (protocol == "websocket" && deny_has_graphql) { validate_graphql_rule( &mut errors, &mut warnings, @@ -709,12 +862,9 @@ pub fn validate_l7_policies(data_json: &serde_json::Value) -> (Vec, Vec< deny_rule, true, ); - } else if deny_rule.get("operation_type").is_some() - || deny_rule.get("operation_name").is_some() - || deny_rule.get("fields").is_some() - { + } else if deny_has_graphql { warnings.push(format!( - "{deny_loc}: GraphQL rule fields are ignored unless protocol is graphql" + "{deny_loc}: GraphQL rule fields are ignored unless protocol is graphql or websocket" )); } } @@ -733,10 +883,8 @@ pub fn validate_l7_policies(data_json: &serde_json::Value) -> (Vec, Vec< } // Validate HTTP methods in rules - if has_rules && protocol == "rest" { - let valid_methods = [ - "GET", "HEAD", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "*", - ]; + if has_rules && (protocol == "rest" || protocol == "websocket") { + let valid_methods = valid_methods_for_protocol(protocol); if let Some(rules) = ep.get("rules").and_then(|v| v.as_array()) { for (rule_idx, rule) in rules.iter().enumerate() { if let Some(method) = rule @@ -747,7 +895,8 @@ pub fn validate_l7_policies(data_json: &serde_json::Value) -> (Vec, Vec< && !valid_methods.contains(&method.to_ascii_uppercase().as_str()) { warnings.push(format!( - "{loc}: Unknown HTTP method '{method}'. Standard methods: GET, HEAD, POST, PUT, DELETE, PATCH, OPTIONS." + "{loc}: Unknown HTTP/WebSocket method '{method}'. Standard methods: {}." + , valid_methods.join(", ") )); } @@ -858,14 +1007,36 @@ pub fn validate_l7_policies(data_json: &serde_json::Value) -> (Vec, Vec< } } - if has_rules - && protocol == "graphql" - && let Some(rules) = ep.get("rules").and_then(|v| v.as_array()) - { + if has_rules && let Some(rules) = ep.get("rules").and_then(|v| v.as_array()) { for (rule_idx, rule) in rules.iter().enumerate() { let allow = rule.get("allow").unwrap_or(rule); let rule_loc = format!("{loc}.rules[{rule_idx}].allow"); - validate_graphql_rule(&mut errors, &mut warnings, &rule_loc, allow, true); + let allow_has_graphql = json_rule_has_graphql_fields(allow); + if websocket_has_graphql_policy + && allow + .get("method") + .and_then(|m| m.as_str()) + .is_some_and(|method| method.eq_ignore_ascii_case("WEBSOCKET_TEXT")) + { + errors.push(format!( + "{rule_loc}: WebSocket endpoints with GraphQL operation policy must use operation_type/operation_name/fields rules for client messages instead of WEBSOCKET_TEXT" + )); + } + if protocol == "websocket" + && allow_has_graphql + && json_rule_has_transport_fields(allow) + { + errors.push(format!( + "{rule_loc}: WebSocket GraphQL allow rules must not combine method/path/query with operation_type/operation_name/fields" + )); + } + if protocol == "graphql" || (protocol == "websocket" && allow_has_graphql) { + validate_graphql_rule(&mut errors, &mut warnings, &rule_loc, allow, true); + } else if allow_has_graphql { + warnings.push(format!( + "{rule_loc}: GraphQL rule fields are ignored unless protocol is graphql or websocket" + )); + } } } } @@ -921,6 +1092,13 @@ pub fn expand_access_presets(data: &mut serde_json::Value) { "full" => vec![graphql_rule_json("*")], _ => continue, } + } else if protocol == "websocket" { + match access.as_str() { + "read-only" => vec![rule_json("GET", "**")], + "read-write" => vec![rule_json("GET", "**"), rule_json("WEBSOCKET_TEXT", "**")], + "full" => vec![rule_json("*", "**")], + _ => continue, + } } else { match access.as_str() { "read-only" => vec![ @@ -957,6 +1135,15 @@ fn rule_json(method: &str, path: &str) -> serde_json::Value { }) } +fn valid_methods_for_protocol(protocol: &str) -> &'static [&'static str] { + match protocol { + "websocket" => &["GET", "WEBSOCKET_TEXT", "*"], + _ => &[ + "GET", "HEAD", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "*", + ], + } +} + fn graphql_rule_json(operation_type: &str) -> serde_json::Value { serde_json::json!({ "allow": { @@ -994,6 +1181,16 @@ mod tests { assert_eq!(config.enforcement, EnforcementMode::Audit); } + #[test] + fn parse_l7_config_websocket_protocol() { + let val = regorus::Value::from_json_str( + r#"{"protocol": "websocket", "host": "gateway.example.com", "port": 443}"#, + ) + .unwrap(); + let config = parse_l7_config(&val).unwrap(); + assert_eq!(config.protocol, L7Protocol::Websocket); + } + #[test] fn parse_l7_config_skip() { let val = regorus::Value::from_json_str( @@ -1031,6 +1228,242 @@ mod tests { assert!(config.allow_encoded_slash); } + #[test] + fn parse_l7_config_websocket_credential_rewrite_defaults_false() { + let val = regorus::Value::from_json_str( + r#"{"protocol": "rest", "host": "gateway.example.com", "port": 443}"#, + ) + .unwrap(); + let config = parse_l7_config(&val).unwrap(); + assert!(!config.websocket_credential_rewrite); + } + + #[test] + fn parse_l7_config_websocket_credential_rewrite_opt_in() { + let val = regorus::Value::from_json_str( + r#"{"protocol": "rest", "host": "gateway.example.com", "port": 443, "websocket_credential_rewrite": true}"#, + ) + .unwrap(); + let config = parse_l7_config(&val).unwrap(); + assert!(config.websocket_credential_rewrite); + } + + #[test] + fn parse_l7_config_request_body_credential_rewrite_defaults_false() { + let val = regorus::Value::from_json_str( + r#"{"protocol": "rest", "host": "slack.com", "port": 443}"#, + ) + .unwrap(); + let config = parse_l7_config(&val).unwrap(); + assert!(!config.request_body_credential_rewrite); + } + + #[test] + fn parse_l7_config_request_body_credential_rewrite_opt_in() { + let val = regorus::Value::from_json_str( + r#"{"protocol": "rest", "host": "slack.com", "port": 443, "request_body_credential_rewrite": true}"#, + ) + .unwrap(); + let config = parse_l7_config(&val).unwrap(); + assert!(config.request_body_credential_rewrite); + } + + #[test] + fn parse_l7_config_websocket_graphql_policy_defaults_false() { + let val = regorus::Value::from_json_str( + r#"{"protocol": "websocket", "host": "gateway.example.com", "port": 443, "rules": [{"allow": {"method": "GET", "path": "/graphql"}}, {"allow": {"method": "WEBSOCKET_TEXT", "path": "/graphql"}}]}"#, + ) + .unwrap(); + let config = parse_l7_config(&val).unwrap(); + assert!(!config.websocket_graphql_policy); + } + + #[test] + fn parse_l7_config_websocket_graphql_policy_detects_operation_rules() { + let val = regorus::Value::from_json_str( + r#"{"protocol": "websocket", "host": "gateway.example.com", "port": 443, "rules": [{"allow": {"method": "GET", "path": "/graphql"}}, {"allow": {"operation_type": "subscription", "fields": ["messageAdded"]}}]}"#, + ) + .unwrap(); + let config = parse_l7_config(&val).unwrap(); + assert!(config.websocket_graphql_policy); + } + + #[test] + fn validate_websocket_credential_rewrite_warns_unless_rest_or_websocket() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "gateway.example.com", + "port": 443, + "websocket_credential_rewrite": true + }], + "binaries": [] + } + } + }); + let (_errors, warnings) = validate_l7_policies(&data); + assert!( + warnings + .iter() + .any(|w| w.contains("websocket_credential_rewrite is ignored")), + "expected websocket_credential_rewrite warning: {warnings:?}" + ); + } + + #[test] + fn validate_request_body_credential_rewrite_warns_unless_rest() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "gateway.example.com", + "port": 443, + "protocol": "websocket", + "request_body_credential_rewrite": true + }], + "binaries": [] + } + } + }); + let (_errors, warnings) = validate_l7_policies(&data); + assert!( + warnings + .iter() + .any(|w| w.contains("request_body_credential_rewrite is ignored")), + "expected request_body_credential_rewrite warning: {warnings:?}" + ); + } + + #[test] + fn expand_websocket_read_write_access_includes_text_messages() { + let mut data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "gateway.example.com", + "port": 443, + "protocol": "websocket", + "access": "read-write" + }], + "binaries": [] + } + } + }); + + expand_access_presets(&mut data); + let rules = data["network_policies"]["test"]["endpoints"][0]["rules"] + .as_array() + .unwrap(); + let methods: Vec<&str> = rules + .iter() + .map(|r| r["allow"]["method"].as_str().unwrap()) + .collect(); + assert!(methods.contains(&"GET")); + assert!(methods.contains(&"WEBSOCKET_TEXT")); + } + + #[test] + fn validate_websocket_accepts_graphql_operation_rules() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "gateway.example.com", + "port": 443, + "protocol": "websocket", + "rules": [ + {"allow": {"method": "GET", "path": "/graphql"}}, + {"allow": {"operation_type": "subscription", "fields": ["messageAdded"]}} + ] + }], + "binaries": [] + } + } + }); + let (errors, warnings) = validate_l7_policies(&data); + assert!(errors.is_empty(), "expected no errors: {errors:?}"); + assert!(warnings.is_empty(), "expected no warnings: {warnings:?}"); + } + + #[test] + fn validate_websocket_graphql_rule_requires_operation_type() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "gateway.example.com", + "port": 443, + "protocol": "websocket", + "rules": [ + {"allow": {"method": "GET", "path": "/graphql"}}, + {"allow": {"fields": ["messageAdded"]}} + ] + }], + "binaries": [] + } + } + }); + let (errors, _warnings) = validate_l7_policies(&data); + assert!( + errors.iter().any(|e| e.contains("operation_type")), + "expected missing operation_type error: {errors:?}" + ); + } + + #[test] + fn validate_websocket_graphql_rule_rejects_mixed_transport_fields() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "gateway.example.com", + "port": 443, + "protocol": "websocket", + "rules": [ + {"allow": {"method": "GET", "path": "/graphql"}}, + {"allow": {"method": "WEBSOCKET_TEXT", "path": "/graphql", "operation_type": "subscription"}} + ] + }], + "binaries": [] + } + } + }); + let (errors, _warnings) = validate_l7_policies(&data); + assert!( + errors.iter().any(|e| e.contains("must not combine")), + "expected mixed-field error: {errors:?}" + ); + } + + #[test] + fn validate_websocket_graphql_policy_rejects_raw_text_message_rule() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "gateway.example.com", + "port": 443, + "protocol": "websocket", + "rules": [ + {"allow": {"method": "GET", "path": "/graphql"}}, + {"allow": {"method": "WEBSOCKET_TEXT", "path": "/graphql"}}, + {"allow": {"operation_type": "query"}} + ] + }], + "binaries": [] + } + } + }); + let (errors, _warnings) = validate_l7_policies(&data); + assert!( + errors + .iter() + .any(|e| e.contains("instead of WEBSOCKET_TEXT")), + "expected raw WEBSOCKET_TEXT rejection: {errors:?}" + ); + } + #[test] fn validate_rules_and_access_mutual_exclusion() { let data = serde_json::json!({ diff --git a/crates/openshell-sandbox/src/l7/provider.rs b/crates/openshell-sandbox/src/l7/provider.rs index 7516aa85c..864d94ad2 100644 --- a/crates/openshell-sandbox/src/l7/provider.rs +++ b/crates/openshell-sandbox/src/l7/provider.rs @@ -27,7 +27,10 @@ pub enum RelayOutcome { /// Contains any overflow bytes read from upstream past the 101 response /// headers that belong to the upgraded protocol. The 101 headers /// themselves have already been forwarded to the client. - Upgraded { overflow: Vec }, + Upgraded { + overflow: Vec, + websocket_permessage_deflate: bool, + }, } /// Body framing for HTTP requests/responses. diff --git a/crates/openshell-sandbox/src/l7/relay.rs b/crates/openshell-sandbox/src/l7/relay.rs index f099c3558..6d271af21 100644 --- a/crates/openshell-sandbox/src/l7/relay.rs +++ b/crates/openshell-sandbox/src/l7/relay.rs @@ -8,6 +8,7 @@ //! and either forwards or denies the request. use crate::l7::provider::{L7Provider, RelayOutcome}; +use crate::l7::rest::WebSocketExtensionMode; use crate::l7::{EnforcementMode, L7EndpointConfig, L7Protocol, L7RequestInfo}; use crate::opa::{PolicyGenerationGuard, TunnelPolicyEngine}; use crate::secrets::{self, SecretResolver}; @@ -38,6 +39,44 @@ pub struct L7EvalContext { pub(crate) secret_resolver: Option>, } +#[derive(Default)] +pub(crate) struct UpgradeRelayOptions<'a> { + pub(crate) websocket_request: bool, + pub(crate) websocket: WebSocketUpgradeBehavior, + pub(crate) secret_resolver: Option>, + pub(crate) engine: Option<&'a TunnelPolicyEngine>, + pub(crate) ctx: Option<&'a L7EvalContext>, + pub(crate) enforcement: EnforcementMode, + pub(crate) target: String, + pub(crate) query_params: std::collections::HashMap>, + pub(crate) policy_name: String, +} + +#[derive(Default)] +pub(crate) struct WebSocketUpgradeBehavior { + pub(crate) credential_rewrite: bool, + pub(crate) message_policy: WebSocketMessagePolicy, + pub(crate) permessage_deflate: bool, +} + +#[derive(Clone, Copy, Default, PartialEq, Eq)] +pub(crate) enum WebSocketMessagePolicy { + #[default] + None, + Transport, + Graphql, +} + +impl WebSocketMessagePolicy { + fn inspects_messages(self) -> bool { + self != Self::None + } + + fn is_graphql(self) -> bool { + self == Self::Graphql + } +} + #[derive(Debug, Clone, Copy)] enum ParseRejectionMode { L7Endpoint, @@ -101,7 +140,9 @@ where U: AsyncRead + AsyncWrite + Unpin + Send, { match config.protocol { - L7Protocol::Rest => relay_rest(config, &engine, client, upstream, ctx).await, + L7Protocol::Rest | L7Protocol::Websocket => { + relay_rest(config, &engine, client, upstream, ctx).await + } L7Protocol::Graphql => relay_graphql(config, &engine, client, upstream, ctx).await, L7Protocol::Sql => { if close_if_stale(engine.generation_guard(), ctx) { @@ -242,6 +283,24 @@ where query_params: req.query_params.clone(), graphql: graphql_info.clone(), }; + let websocket_request = crate::l7::rest::request_is_websocket_upgrade(&req.raw_header); + if config.protocol == L7Protocol::Websocket && !websocket_request { + crate::l7::rest::RestProvider::default() + .deny_with_redacted_target( + &req, + &ctx.policy_name, + "websocket endpoint requires a valid WebSocket upgrade request", + client, + Some(&redacted_target), + Some(crate::l7::rest::DenyResponseContext { + host: Some(&ctx.host), + port: Some(ctx.port), + binary: Some(&ctx.binary_path), + }), + ) + .await?; + return Ok(()); + } let parse_error_reason = graphql_info .as_ref() @@ -264,10 +323,10 @@ where (false, EnforcementMode::Audit) => "audit", (false, EnforcementMode::Enforce) => "deny", }; - let engine_type = if config.protocol == L7Protocol::Graphql { - "l7-graphql" - } else { - "l7" + let engine_type = match config.protocol { + L7Protocol::Graphql => "l7-graphql", + L7Protocol::Websocket => "l7-websocket", + L7Protocol::Rest | L7Protocol::Sql => "l7", }; emit_l7_request_log( ctx, @@ -282,19 +341,39 @@ where let _ = &eval_target; if allowed || (config.enforcement == EnforcementMode::Audit && !force_deny) { - let outcome = crate::l7::rest::relay_http_request_with_resolver_guarded( + let outcome = crate::l7::rest::relay_http_request_with_options_guarded( &req, client, upstream, - ctx.secret_resolver.as_deref(), - Some(engine.generation_guard()), + crate::l7::rest::RelayRequestOptions { + resolver: ctx.secret_resolver.as_deref(), + generation_guard: Some(engine.generation_guard()), + websocket_extensions: websocket_extension_mode(config), + request_body_credential_rewrite: config.protocol == L7Protocol::Rest + && config.request_body_credential_rewrite, + }, ) .await?; match outcome { RelayOutcome::Reusable => {} RelayOutcome::Consumed => return Ok(()), - RelayOutcome::Upgraded { overflow } => { - return handle_upgrade(client, upstream, overflow, &ctx.host, ctx.port).await; + RelayOutcome::Upgraded { + overflow, + websocket_permessage_deflate, + } => { + let mut options = upgrade_options( + config, + ctx, + websocket_request, + &redacted_target, + &req.query_params, + Some(&engine), + ); + options.websocket.permessage_deflate = websocket_permessage_deflate; + return handle_upgrade( + client, upstream, overflow, &ctx.host, ctx.port, options, + ) + .await; } } } else { @@ -374,20 +453,29 @@ fn emit_l7_request_log( /// Handle an upgraded connection (101 Switching Protocols). /// /// Forwards any overflow bytes from the upgrade response to the client, then -/// switches to raw bidirectional TCP copy for the upgraded protocol (WebSocket, -/// HTTP/2, etc.). L7 policy enforcement does not apply after the upgrade — -/// the initial HTTP request was already evaluated. +/// either switches to a parsed WebSocket relay for opted-in message policy / +/// credential rewriting or to raw bidirectional TCP copy for other upgrades. pub(crate) async fn handle_upgrade( client: &mut C, upstream: &mut U, overflow: Vec, host: &str, port: u16, + options: UpgradeRelayOptions<'_>, ) -> Result<()> where C: AsyncRead + AsyncWrite + Unpin + Send, U: AsyncRead + AsyncWrite + Unpin + Send, { + let use_websocket_relay = options.websocket_request + && (options.websocket.message_policy.inspects_messages() + || options.websocket.permessage_deflate + || (options.websocket.credential_rewrite && options.secret_resolver.is_some())); + let relay_mode = if use_websocket_relay { + "websocket parsed relay" + } else { + "raw bidirectional relay (L7 enforcement no longer active)" + }; ocsf_emit!( NetworkActivityBuilder::new(crate::ocsf_ctx()) .activity(ActivityId::Other) @@ -395,12 +483,56 @@ where .severity(SeverityId::Informational) .dst_endpoint(Endpoint::from_domain(host, port)) .message(format!( - "101 Switching Protocols — raw bidirectional relay (L7 enforcement no longer active) \ - [host:{host} port:{port} overflow_bytes:{}]", + "101 Switching Protocols — {relay_mode} [host:{host} port:{port} overflow_bytes:{}]", overflow.len() )) .build() ); + if use_websocket_relay { + let resolver = if options.websocket.credential_rewrite { + options.secret_resolver.as_deref() + } else { + None + }; + let inspector = if options.websocket.message_policy.inspects_messages() { + match (options.engine, options.ctx) { + (Some(engine), Some(ctx)) => Some(crate::l7::websocket::InspectionOptions { + engine, + ctx, + enforcement: options.enforcement, + target: options.target.clone(), + query_params: options.query_params.clone(), + graphql_policy: options.websocket.message_policy.is_graphql(), + }), + _ => { + return Err(miette!( + "websocket message inspection missing policy context" + )); + } + } + } else { + None + }; + let compression = if options.websocket.permessage_deflate { + crate::l7::websocket::WebSocketCompression::PermessageDeflate + } else { + crate::l7::websocket::WebSocketCompression::None + }; + return crate::l7::websocket::relay_with_options( + client, + upstream, + overflow, + host, + port, + crate::l7::websocket::RelayOptions { + policy_name: &options.policy_name, + resolver, + inspector, + compression, + }, + ) + .await; + } if !overflow.is_empty() { client.write_all(&overflow).await.into_diagnostic()?; client.flush().await.into_diagnostic()?; @@ -411,6 +543,57 @@ where Ok(()) } +pub(crate) fn upgrade_options<'a>( + config: &L7EndpointConfig, + ctx: &'a L7EvalContext, + websocket_request: bool, + target: &str, + query_params: &std::collections::HashMap>, + engine: Option<&'a TunnelPolicyEngine>, +) -> UpgradeRelayOptions<'a> { + let websocket_credential_rewrite = + matches!(config.protocol, L7Protocol::Rest | L7Protocol::Websocket) + && config.websocket_credential_rewrite; + let websocket_message_policy = if config.protocol == L7Protocol::Websocket { + if config.websocket_graphql_policy { + WebSocketMessagePolicy::Graphql + } else { + WebSocketMessagePolicy::Transport + } + } else { + WebSocketMessagePolicy::None + }; + UpgradeRelayOptions { + websocket_request, + websocket: WebSocketUpgradeBehavior { + credential_rewrite: websocket_credential_rewrite, + message_policy: websocket_message_policy, + permessage_deflate: false, + }, + secret_resolver: if websocket_credential_rewrite { + ctx.secret_resolver.clone() + } else { + None + }, + engine, + ctx: engine.map(|_| ctx), + enforcement: config.enforcement, + target: target.to_string(), + query_params: query_params.clone(), + policy_name: ctx.policy_name.clone(), + } +} + +pub(crate) fn websocket_extension_mode(config: &L7EndpointConfig) -> WebSocketExtensionMode { + if config.protocol == L7Protocol::Websocket + || (config.protocol == L7Protocol::Rest && config.websocket_credential_rewrite) + { + WebSocketExtensionMode::PermessageDeflate + } else { + WebSocketExtensionMode::Preserve + } +} + /// REST relay loop: parse request -> evaluate -> allow/deny -> relay response -> repeat. async fn relay_rest( config: &L7EndpointConfig, @@ -490,6 +673,24 @@ where query_params: req.query_params.clone(), graphql: None, }; + let websocket_request = crate::l7::rest::request_is_websocket_upgrade(&req.raw_header); + if config.protocol == L7Protocol::Websocket && !websocket_request { + provider + .deny_with_redacted_target( + &req, + &ctx.policy_name, + "websocket endpoint requires a valid WebSocket upgrade request", + client, + Some(&redacted_target), + Some(crate::l7::rest::DenyResponseContext { + host: Some(&ctx.host), + port: Some(ctx.port), + binary: Some(&ctx.binary_path), + }), + ) + .await?; + return Ok(()); + } // Evaluate L7 policy via Rego (using redacted target) let (allowed, reason) = evaluate_l7_request(engine, ctx, &request_info)?; @@ -558,12 +759,17 @@ where if allowed || config.enforcement == EnforcementMode::Audit { // Forward request to upstream and relay response - let outcome = crate::l7::rest::relay_http_request_with_resolver_guarded( + let outcome = crate::l7::rest::relay_http_request_with_options_guarded( &req, client, upstream, - ctx.secret_resolver.as_deref(), - Some(engine.generation_guard()), + crate::l7::rest::RelayRequestOptions { + resolver: ctx.secret_resolver.as_deref(), + generation_guard: Some(engine.generation_guard()), + websocket_extensions: websocket_extension_mode(config), + request_body_credential_rewrite: config.protocol == L7Protocol::Rest + && config.request_body_credential_rewrite, + }, ) .await?; match outcome { @@ -576,8 +782,23 @@ where ); return Ok(()); } - RelayOutcome::Upgraded { overflow } => { - return handle_upgrade(client, upstream, overflow, &ctx.host, ctx.port).await; + RelayOutcome::Upgraded { + overflow, + websocket_permessage_deflate, + } => { + let mut options = upgrade_options( + config, + ctx, + websocket_request, + &redacted_target, + &req.query_params, + Some(engine), + ); + options.websocket.permessage_deflate = websocket_permessage_deflate; + return handle_upgrade( + client, upstream, overflow, &ctx.host, ctx.port, options, + ) + .await; } } } else { @@ -787,8 +1008,21 @@ where ); return Ok(()); } - RelayOutcome::Upgraded { overflow } => { - return handle_upgrade(client, upstream, overflow, &ctx.host, ctx.port).await; + RelayOutcome::Upgraded { + overflow, + websocket_permessage_deflate, + } => { + let options = UpgradeRelayOptions { + websocket: WebSocketUpgradeBehavior { + permessage_deflate: websocket_permessage_deflate, + ..Default::default() + }, + ..Default::default() + }; + return handle_upgrade( + client, upstream, overflow, &ctx.host, ctx.port, options, + ) + .await; } } } else { @@ -1016,20 +1250,31 @@ where // Forward request with credential rewriting and relay the response. // relay_http_request_with_resolver handles both directions: it sends // the request upstream and reads the response back to the client. - let outcome = crate::l7::rest::relay_http_request_with_resolver_guarded( + let outcome = crate::l7::rest::relay_http_request_with_options_guarded( &req, client, upstream, - resolver, - Some(generation_guard), + crate::l7::rest::RelayRequestOptions { + resolver, + generation_guard: Some(generation_guard), + ..Default::default() + }, ) .await?; match outcome { RelayOutcome::Reusable => {} // continue loop RelayOutcome::Consumed => break, - RelayOutcome::Upgraded { overflow } => { - return handle_upgrade(client, upstream, overflow, &ctx.host, ctx.port).await; + RelayOutcome::Upgraded { overflow, .. } => { + return handle_upgrade( + client, + upstream, + overflow, + &ctx.host, + ctx.port, + UpgradeRelayOptions::default(), + ) + .await; } } } @@ -1049,7 +1294,7 @@ mod tests { use super::*; use crate::opa::{NetworkInput, OpaEngine}; use std::path::PathBuf; - use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt}; const TEST_POLICY: &str = include_str!("../../data/sandbox-policy.rego"); @@ -1086,6 +1331,436 @@ mod tests { ); } + #[test] + fn websocket_text_policy_requires_explicit_message_rule() { + let data = r#" +network_policies: + ws_api: + name: ws_api + endpoints: + - host: gateway.example.test + port: 443 + protocol: websocket + enforcement: enforce + rules: + - allow: + method: GET + path: "/ws" + binaries: + - { path: /usr/bin/node } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + let input = NetworkInput { + host: "gateway.example.test".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/node"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let generation = engine + .evaluate_network_action_with_generation(&input) + .unwrap() + .1; + let tunnel_engine = engine.clone_engine_for_tunnel(generation).unwrap(); + let ctx = L7EvalContext { + host: "gateway.example.test".into(), + port: 443, + policy_name: "ws_api".into(), + binary_path: "/usr/bin/node".into(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: None, + }; + let request = L7RequestInfo { + action: "WEBSOCKET_TEXT".into(), + target: "/ws".into(), + query_params: std::collections::HashMap::new(), + graphql: None, + }; + + let (allowed, reason) = evaluate_l7_request(&tunnel_engine, &ctx, &request).unwrap(); + + assert!(!allowed); + assert!(reason.contains("WEBSOCKET_TEXT /ws not permitted")); + } + + #[tokio::test] + async fn route_selected_websocket_upgrade_rejects_invalid_accept_without_forwarding_101() { + let data = r#" +network_policies: + route_api: + name: route_api + endpoints: + - host: gateway.example.test + port: 443 + protocol: rest + enforcement: enforce + rules: + - allow: + method: GET + path: "/ws" + binaries: + - { path: /usr/bin/node } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + let tunnel_engine = engine + .clone_engine_for_tunnel(engine.current_generation()) + .unwrap(); + let configs = vec![L7EndpointConfig { + protocol: L7Protocol::Rest, + path: "/ws".into(), + tls: crate::l7::TlsMode::Auto, + enforcement: EnforcementMode::Enforce, + graphql_max_body_bytes: 0, + allow_encoded_slash: false, + websocket_credential_rewrite: true, + request_body_credential_rewrite: false, + websocket_graphql_policy: false, + }]; + let ctx = L7EvalContext { + host: "gateway.example.test".into(), + port: 443, + policy_name: "route_api".into(), + binary_path: "/usr/bin/node".into(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: None, + }; + + let (mut app, mut relay_client) = tokio::io::duplex(8192); + let (mut relay_upstream, mut upstream) = tokio::io::duplex(8192); + let relay = tokio::spawn(async move { + relay_with_route_selection( + &configs, + tunnel_engine, + &mut relay_client, + &mut relay_upstream, + &ctx, + ) + .await + }); + + app.write_all( + b"GET /ws HTTP/1.1\r\nHost: gateway.example.test\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n", + ) + .await + .unwrap(); + + let mut forwarded = [0u8; 1024]; + let n = tokio::time::timeout( + std::time::Duration::from_secs(1), + upstream.read(&mut forwarded), + ) + .await + .expect("upgrade request should reach upstream") + .unwrap(); + let forwarded = String::from_utf8_lossy(&forwarded[..n]); + assert!(forwarded.contains("Upgrade: websocket\r\n")); + assert!(forwarded.contains("Connection: Upgrade\r\n")); + + upstream + .write_all( + b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: invalid\r\n\r\n", + ) + .await + .unwrap(); + + let err = tokio::time::timeout(std::time::Duration::from_secs(1), relay) + .await + .expect("relay should fail closed on invalid accept") + .unwrap() + .expect_err("invalid accept must fail the route-selected relay"); + assert!(err.to_string().contains("Sec-WebSocket-Accept")); + + let mut response = [0u8; 1]; + let n = tokio::time::timeout(std::time::Duration::from_secs(1), app.read(&mut response)) + .await + .expect("client side should close without 101") + .unwrap(); + assert_eq!(n, 0, "invalid response must not forward 101 headers"); + } + + #[tokio::test] + async fn route_selected_websocket_rewrites_text_credentials_after_upgrade() { + let data = r#" +network_policies: + route_api: + name: route_api + endpoints: + - host: gateway.example.test + port: 443 + protocol: websocket + enforcement: enforce + rules: + - allow: + method: GET + path: "/ws" + - allow: + method: WEBSOCKET_TEXT + path: "/ws" + websocket_credential_rewrite: true + binaries: + - { path: /usr/bin/node } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + let tunnel_engine = engine + .clone_engine_for_tunnel(engine.current_generation()) + .unwrap(); + let configs = vec![L7EndpointConfig { + protocol: L7Protocol::Websocket, + path: "/ws".into(), + tls: crate::l7::TlsMode::Auto, + enforcement: EnforcementMode::Enforce, + graphql_max_body_bytes: 0, + allow_encoded_slash: false, + websocket_credential_rewrite: true, + request_body_credential_rewrite: false, + websocket_graphql_policy: false, + }]; + let (child_env, resolver) = SecretResolver::from_provider_env( + std::iter::once(("DISCORD_BOT_TOKEN".to_string(), "real-token".to_string())).collect(), + ); + let placeholder = child_env.get("DISCORD_BOT_TOKEN").expect("placeholder env"); + let ctx = L7EvalContext { + host: "gateway.example.test".into(), + port: 443, + policy_name: "route_api".into(), + binary_path: "/usr/bin/node".into(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: resolver.map(Arc::new), + }; + + let (mut app, mut relay_client) = tokio::io::duplex(8192); + let (mut relay_upstream, mut upstream) = tokio::io::duplex(8192); + let relay = tokio::spawn(async move { + relay_with_route_selection( + &configs, + tunnel_engine, + &mut relay_client, + &mut relay_upstream, + &ctx, + ) + .await + }); + + app.write_all( + b"GET /ws HTTP/1.1\r\nHost: gateway.example.test\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n", + ) + .await + .unwrap(); + + let mut forwarded = [0u8; 1024]; + let n = tokio::time::timeout( + std::time::Duration::from_secs(1), + upstream.read(&mut forwarded), + ) + .await + .expect("upgrade request should reach upstream") + .unwrap(); + let forwarded = String::from_utf8_lossy(&forwarded[..n]); + assert!(forwarded.contains("Upgrade: websocket\r\n")); + + upstream + .write_all( + b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\r\n", + ) + .await + .unwrap(); + + let mut response = [0u8; 1024]; + let n = tokio::time::timeout(std::time::Duration::from_secs(1), app.read(&mut response)) + .await + .expect("client should receive upgrade response") + .unwrap(); + assert!(String::from_utf8_lossy(&response[..n]).contains("101 Switching Protocols")); + + let payload = format!(r#"{{"op":2,"d":{{"token":"{placeholder}"}}}}"#); + app.write_all(&masked_text_frame(payload.as_bytes())) + .await + .unwrap(); + + let (masked, rewritten) = tokio::time::timeout( + std::time::Duration::from_secs(1), + read_text_frame(&mut upstream), + ) + .await + .expect("rewritten websocket text should reach upstream") + .unwrap(); + assert!(masked, "client-to-server frame must remain masked"); + assert_eq!(rewritten, r#"{"op":2,"d":{"token":"real-token"}}"#); + assert!(!rewritten.contains(placeholder)); + + drop(app); + drop(upstream); + let _ = tokio::time::timeout(std::time::Duration::from_secs(1), relay).await; + } + + #[tokio::test] + async fn route_selected_graphql_websocket_rewrites_connection_init_credentials_after_upgrade() { + let data = r#" +network_policies: + route_api: + name: route_api + endpoints: + - host: gateway.example.test + port: 443 + path: "/graphql" + protocol: websocket + enforcement: enforce + rules: + - allow: + method: GET + path: "/graphql" + - allow: + operation_type: query + fields: [viewer] + websocket_credential_rewrite: true + binaries: + - { path: /usr/bin/node } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + let tunnel_engine = engine + .clone_engine_for_tunnel(engine.current_generation()) + .unwrap(); + let configs = vec![L7EndpointConfig { + protocol: L7Protocol::Websocket, + path: "/graphql".into(), + tls: crate::l7::TlsMode::Auto, + enforcement: EnforcementMode::Enforce, + graphql_max_body_bytes: 0, + allow_encoded_slash: false, + websocket_credential_rewrite: true, + request_body_credential_rewrite: false, + websocket_graphql_policy: true, + }]; + let (child_env, resolver) = SecretResolver::from_provider_env( + std::iter::once(("T".to_string(), "real-token".to_string())).collect(), + ); + let placeholder = child_env.get("T").expect("placeholder env"); + let ctx = L7EvalContext { + host: "gateway.example.test".into(), + port: 443, + policy_name: "route_api".into(), + binary_path: "/usr/bin/node".into(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: resolver.map(Arc::new), + }; + + let (mut app, mut relay_client) = tokio::io::duplex(8192); + let (mut relay_upstream, mut upstream) = tokio::io::duplex(8192); + let relay = tokio::spawn(async move { + relay_with_route_selection( + &configs, + tunnel_engine, + &mut relay_client, + &mut relay_upstream, + &ctx, + ) + .await + }); + + app.write_all( + b"GET /graphql HTTP/1.1\r\nHost: gateway.example.test\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n", + ) + .await + .unwrap(); + + let mut forwarded = [0u8; 1024]; + let n = tokio::time::timeout( + std::time::Duration::from_secs(1), + upstream.read(&mut forwarded), + ) + .await + .expect("upgrade request should reach upstream") + .unwrap(); + let forwarded = String::from_utf8_lossy(&forwarded[..n]); + assert!(forwarded.contains("GET /graphql HTTP/1.1")); + assert!(forwarded.contains("Upgrade: websocket\r\n")); + + upstream + .write_all( + b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\r\n", + ) + .await + .unwrap(); + + let mut response = [0u8; 1024]; + let n = tokio::time::timeout(std::time::Duration::from_secs(1), app.read(&mut response)) + .await + .expect("client should receive upgrade response") + .unwrap(); + assert!(String::from_utf8_lossy(&response[..n]).contains("101 Switching Protocols")); + + let payload = format!( + r#"{{"type":"connection_init","payload":{{"authorization":"{placeholder}"}}}}"# + ); + app.write_all(&masked_text_frame(payload.as_bytes())) + .await + .unwrap(); + + let (masked, rewritten) = tokio::time::timeout( + std::time::Duration::from_secs(1), + read_text_frame(&mut upstream), + ) + .await + .expect("rewritten GraphQL WebSocket control message should reach upstream") + .unwrap(); + assert!(masked, "client-to-server frame must remain masked"); + assert_eq!( + rewritten, + r#"{"type":"connection_init","payload":{"authorization":"real-token"}}"# + ); + assert!(!rewritten.contains(placeholder)); + + drop(app); + drop(upstream); + let _ = tokio::time::timeout(std::time::Duration::from_secs(1), relay).await; + } + + fn masked_text_frame(payload: &[u8]) -> Vec { + let mask = [0x11, 0x22, 0x33, 0x44]; + assert!( + payload.len() <= 125, + "test helper only supports small frames" + ); + let payload_len = u8::try_from(payload.len()).expect("small frame length"); + let mut frame = vec![0x81, 0x80 | payload_len]; + frame.extend_from_slice(&mask); + frame.extend( + payload + .iter() + .enumerate() + .map(|(idx, byte)| byte ^ mask[idx % 4]), + ); + frame + } + + async fn read_text_frame( + reader: &mut R, + ) -> std::io::Result<(bool, String)> { + let mut header = [0u8; 2]; + reader.read_exact(&mut header).await?; + assert_eq!(header[0] & 0x0f, 0x1, "expected text frame"); + let masked = header[1] & 0x80 != 0; + let payload_len = usize::from(header[1] & 0x7f); + assert!(payload_len <= 125, "test helper only supports small frames"); + let mut mask = [0u8; 4]; + if masked { + reader.read_exact(&mut mask).await?; + } + let mut payload = vec![0u8; payload_len]; + reader.read_exact(&mut payload).await?; + if masked { + for (idx, byte) in payload.iter_mut().enumerate() { + *byte ^= mask[idx % 4]; + } + } + Ok((masked, String::from_utf8(payload).expect("text payload"))) + } + #[tokio::test] async fn l7_relay_closes_keep_alive_tunnel_after_policy_generation_change() { let initial_data = r#" diff --git a/crates/openshell-sandbox/src/l7/rest.rs b/crates/openshell-sandbox/src/l7/rest.rs index 85ae01290..c513499f4 100644 --- a/crates/openshell-sandbox/src/l7/rest.rs +++ b/crates/openshell-sandbox/src/l7/rest.rs @@ -9,13 +9,18 @@ use crate::l7::provider::{BodyLength, L7Provider, L7Request, RelayOutcome}; use crate::opa::PolicyGenerationGuard; -use crate::secrets::rewrite_http_header_block; +use crate::secrets::{ + SecretResolver, contains_reserved_credential_marker, rewrite_http_header_block, +}; +use base64::Engine as _; use miette::{IntoDiagnostic, Result, miette}; -use std::collections::HashMap; +use sha1::{Digest, Sha1}; +use std::collections::{HashMap, HashSet}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tracing::debug; const MAX_HEADER_BYTES: usize = 16384; // 16 KiB for HTTP headers +const MAX_REWRITE_BODY_BYTES: usize = 256 * 1024; const RELAY_BUF_SIZE: usize = 8192; /// Idle timeout for `relay_until_eof`. If no data arrives within this window /// the body is considered complete. Prevents blocking on servers that keep @@ -343,7 +348,7 @@ pub(crate) async fn relay_http_request_with_resolver( req: &L7Request, client: &mut C, upstream: &mut U, - resolver: Option<&crate::secrets::SecretResolver>, + resolver: Option<&SecretResolver>, ) -> Result where C: AsyncRead + AsyncWrite + Unpin, @@ -356,9 +361,48 @@ pub(crate) async fn relay_http_request_with_resolver_guarded( req: &L7Request, client: &mut C, upstream: &mut U, - resolver: Option<&crate::secrets::SecretResolver>, + resolver: Option<&SecretResolver>, generation_guard: Option<&PolicyGenerationGuard>, ) -> Result +where + C: AsyncRead + AsyncWrite + Unpin, + U: AsyncRead + AsyncWrite + Unpin, +{ + relay_http_request_with_options_guarded( + req, + client, + upstream, + RelayRequestOptions { + resolver, + generation_guard, + websocket_extensions: WebSocketExtensionMode::Preserve, + request_body_credential_rewrite: false, + }, + ) + .await +} + +#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] +pub(crate) enum WebSocketExtensionMode { + #[default] + Preserve, + PermessageDeflate, +} + +#[derive(Clone, Copy, Default)] +pub(crate) struct RelayRequestOptions<'a> { + pub(crate) resolver: Option<&'a SecretResolver>, + pub(crate) generation_guard: Option<&'a PolicyGenerationGuard>, + pub(crate) websocket_extensions: WebSocketExtensionMode, + pub(crate) request_body_credential_rewrite: bool, +} + +pub(crate) async fn relay_http_request_with_options_guarded( + req: &L7Request, + client: &mut C, + upstream: &mut U, + options: RelayRequestOptions<'_>, +) -> Result where C: AsyncRead + AsyncWrite + Unpin, U: AsyncRead + AsyncWrite + Unpin, @@ -368,88 +412,815 @@ where .windows(4) .position(|w| w == b"\r\n\r\n") .map_or(req.raw_header.len(), |p| p + 4); + let header_str = std::str::from_utf8(&req.raw_header[..header_end]) + .map_err(|_| miette!("HTTP headers contain invalid UTF-8"))?; + let client_requested_upgrade = client_requested_upgrade(header_str); + let websocket_request = if options.websocket_extensions == WebSocketExtensionMode::Preserve { + None + } else { + parse_websocket_upgrade_request(&req.raw_header[..header_end])? + }; - let rewrite_result = rewrite_http_header_block(&req.raw_header[..header_end], resolver) + let (header_bytes, expected_websocket_extension) = rewrite_websocket_extensions_for_mode( + &req.raw_header[..header_end], + options.websocket_extensions, + websocket_request.is_some(), + )?; + let websocket_response = + websocket_request + .as_ref() + .map(|request| WebSocketResponseValidation { + expected_accept: websocket_accept_for_key(&request.sec_key), + expected_extension: expected_websocket_extension.clone(), + offered_subprotocols: request.subprotocols.clone(), + }); + + let rewrite_result = rewrite_http_header_block(&header_bytes, options.resolver) .map_err(|e| miette!("credential injection failed: {e}"))?; - if let Some(guard) = generation_guard { + if let Some(guard) = options.generation_guard { guard.ensure_current()?; } - upstream - .write_all(&rewrite_result.rewritten) - .await - .into_diagnostic()?; + if options.request_body_credential_rewrite { + let body = collect_and_rewrite_request_body( + req, + client, + &rewrite_result.rewritten, + header_str, + &req.raw_header[header_end..], + options.resolver, + options.generation_guard, + ) + .await?; + upstream.write_all(&body.headers).await.into_diagnostic()?; + if !body.body.is_empty() { + upstream.write_all(&body.body).await.into_diagnostic()?; + } + } else { + upstream + .write_all(&rewrite_result.rewritten) + .await + .into_diagnostic()?; - let overflow = &req.raw_header[header_end..]; - if !overflow.is_empty() { - if let Some(guard) = generation_guard { - guard.ensure_current()?; + let overflow = &req.raw_header[header_end..]; + if !overflow.is_empty() { + if let Some(guard) = options.generation_guard { + guard.ensure_current()?; + } + upstream.write_all(overflow).await.into_diagnostic()?; + } + let overflow_len = overflow.len() as u64; + + match req.body_length { + BodyLength::ContentLength(len) => { + let remaining = len.saturating_sub(overflow_len); + if remaining > 0 { + relay_fixed(client, upstream, remaining, options.generation_guard).await?; + } + } + BodyLength::Chunked => { + relay_chunked( + client, + upstream, + &req.raw_header[header_end..], + options.generation_guard, + ) + .await?; + } + BodyLength::None => {} } - upstream.write_all(overflow).await.into_diagnostic()?; } - let overflow_len = overflow.len() as u64; + upstream.flush().await.into_diagnostic()?; + + let outcome = relay_response( + &req.action, + upstream, + client, + RelayResponseOptions { + websocket_extensions: options.websocket_extensions, + websocket: websocket_response, + client_requested_upgrade, + }, + ) + .await?; + Ok(outcome) +} + +struct PreparedRequestBody { + headers: Vec, + body: Vec, +} + +async fn collect_and_rewrite_request_body( + req: &L7Request, + client: &mut C, + rewritten_headers: &[u8], + original_header_str: &str, + already_read: &[u8], + resolver: Option<&SecretResolver>, + generation_guard: Option<&PolicyGenerationGuard>, +) -> Result { match req.body_length { + BodyLength::None => { + if body_bytes_contain_reserved_marker(already_read) { + return Err(miette!( + "request body credential rewrite cannot resolve placeholders without explicit body framing" + )); + } + Ok(PreparedRequestBody { + headers: rewritten_headers.to_vec(), + body: already_read.to_vec(), + }) + } BodyLength::ContentLength(len) => { - let remaining = len.saturating_sub(overflow_len); - if remaining > 0 { - relay_fixed(client, upstream, remaining, generation_guard).await?; + let len = usize::try_from(len) + .map_err(|_| miette!("request body is too large for credential rewrite"))?; + if len > MAX_REWRITE_BODY_BYTES { + return Err(miette!( + "request body credential rewrite buffers at most {MAX_REWRITE_BODY_BYTES} bytes" + )); } + let mut body = Vec::with_capacity(len); + let initial_len = already_read.len().min(len); + body.extend_from_slice(&already_read[..initial_len]); + let mut remaining = len.saturating_sub(initial_len); + let mut buf = [0u8; RELAY_BUF_SIZE]; + while remaining > 0 { + let to_read = remaining.min(buf.len()); + let n = client.read(&mut buf[..to_read]).await.into_diagnostic()?; + if n == 0 { + return Err(miette!( + "Connection closed with {remaining} body bytes remaining" + )); + } + if let Some(guard) = generation_guard { + guard.ensure_current()?; + } + body.extend_from_slice(&buf[..n]); + remaining -= n; + } + let (headers, body) = + rewrite_buffered_body(rewritten_headers, original_header_str, body, resolver)?; + Ok(PreparedRequestBody { headers, body }) } BodyLength::Chunked => { - relay_chunked( - client, - upstream, - &req.raw_header[header_end..], - generation_guard, - ) - .await?; + let body = collect_chunked_body(client, already_read, generation_guard).await?; + if body_bytes_contain_reserved_marker(&body) { + return Err(miette!( + "request body credential rewrite does not support chunked bodies containing credential placeholders" + )); + } + Ok(PreparedRequestBody { + headers: rewritten_headers.to_vec(), + body, + }) } - BodyLength::None => {} } - upstream.flush().await.into_diagnostic()?; +} - let outcome = relay_response(&req.action, upstream, client).await?; - - // Validate that the client actually requested an upgrade before accepting - // a 101 from upstream. Per RFC 9110 Section 7.8, the server MUST NOT send - // 101 unless the client sent Upgrade + Connection: Upgrade headers. A - // non-compliant or malicious upstream could send an unsolicited 101 to - // bypass L7 inspection. - if matches!(outcome, RelayOutcome::Upgraded { .. }) { - let header_str = String::from_utf8_lossy(&req.raw_header[..header_end]); - if !client_requested_upgrade(&header_str) { - openshell_ocsf::ocsf_emit!( - openshell_ocsf::DetectionFindingBuilder::new(crate::ocsf_ctx()) - .activity(openshell_ocsf::ActivityId::Open) - .action(openshell_ocsf::ActionId::Denied) - .disposition(openshell_ocsf::DispositionId::Blocked) - .severity(openshell_ocsf::SeverityId::High) - .confidence(openshell_ocsf::ConfidenceId::High) - .is_alert(true) - .finding_info( - openshell_ocsf::FindingInfo::new( - "unsolicited-101-upgrade", - "Unsolicited 101 Switching Protocols", - ) - .with_desc(&format!( - "Upstream sent 101 without client Upgrade request for {} {} — \ - possible L7 inspection bypass. Connection closed.", - req.action, req.target, - )), - ) - .message(format!( - "Unsolicited 101 upgrade blocked: {} {}", - req.action, req.target, - )) - .build() - ); - return Ok(RelayOutcome::Consumed); +fn rewrite_buffered_body( + headers: &[u8], + original_header_str: &str, + body: Vec, + resolver: Option<&SecretResolver>, +) -> Result<(Vec, Vec)> { + if body.is_empty() { + return Ok((headers.to_vec(), body)); + } + + let content_type = content_type(original_header_str); + if !is_rewritable_content_type(content_type.as_deref()) { + if body_bytes_contain_reserved_marker(&body) { + return Err(miette!( + "request body credential rewrite found placeholders in an unsupported content type" + )); } + return Ok((headers.to_vec(), body)); } - Ok(outcome) + let mut text = String::from_utf8(body) + .map_err(|_| miette!("request body credential rewrite requires UTF-8 text bodies"))?; + if !contains_reserved_credential_marker(&text) { + return Ok((headers.to_vec(), text.into_bytes())); + } + + let Some(resolver) = resolver else { + return Err(miette!( + "request body credential rewrite found placeholders but no resolver is available" + )); + }; + + let replacements = if content_type.as_deref() == Some("application/x-www-form-urlencoded") { + let (rewritten, replacements) = rewrite_form_urlencoded_body(&text, resolver)?; + text = rewritten; + replacements + } else { + resolver + .rewrite_text_placeholders(&mut text, "request_body") + .map_err(|e| miette!("credential injection failed: {e}"))? + }; + if replacements == 0 || contains_reserved_credential_marker(&text) { + return Err(miette!( + "request body credential rewrite left unresolved credential placeholders" + )); + } + + let body = text.into_bytes(); + let headers = set_content_length(headers, body.len())?; + Ok((headers, body)) +} + +fn rewrite_form_urlencoded_body(body: &str, resolver: &SecretResolver) -> Result<(String, usize)> { + let mut rewritten = String::with_capacity(body.len()); + let mut replacements = 0usize; + + for (idx, field) in body.split('&').enumerate() { + if idx > 0 { + rewritten.push('&'); + } + + let (name, value) = field + .split_once('=') + .map_or((field, None), |(name, value)| (name, Some(value))); + let decoded_name = form_url_decode(name)?; + if contains_reserved_credential_marker(&decoded_name) { + return Err(miette!( + "request body credential rewrite does not support placeholders in form field names" + )); + } + + rewritten.push_str(name); + let Some(value) = value else { + continue; + }; + + rewritten.push('='); + let decoded_value = form_url_decode(value)?; + if !contains_reserved_credential_marker(&decoded_value) { + rewritten.push_str(value); + continue; + } + + let mut rewritten_value = decoded_value; + let field_replacements = resolver + .rewrite_text_placeholders(&mut rewritten_value, "request_body") + .map_err(|e| miette!("credential injection failed: {e}"))?; + if field_replacements == 0 || contains_reserved_credential_marker(&rewritten_value) { + return Err(miette!( + "request body credential rewrite left unresolved credential placeholders" + )); + } + replacements += field_replacements; + rewritten.push_str(&form_url_encode(&rewritten_value)); + } + + Ok((rewritten, replacements)) +} + +fn form_url_decode(input: &str) -> Result { + let bytes = input.as_bytes(); + let mut decoded = Vec::with_capacity(bytes.len()); + let mut pos = 0usize; + + while pos < bytes.len() { + match bytes[pos] { + b'+' => { + decoded.push(b' '); + pos += 1; + } + b'%' if pos + 2 < bytes.len() => { + if let (Some(hi), Some(lo)) = (hex_value(bytes[pos + 1]), hex_value(bytes[pos + 2])) + { + decoded.push((hi << 4) | lo); + pos += 3; + } else { + decoded.push(bytes[pos]); + pos += 1; + } + } + byte => { + decoded.push(byte); + pos += 1; + } + } + } + + String::from_utf8(decoded).map_err(|_| { + miette!("request body credential rewrite requires UTF-8 form-url-encoded fields") + }) +} + +fn form_url_encode(input: &str) -> String { + let mut encoded = String::with_capacity(input.len()); + for byte in input.bytes() { + match byte { + b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'.' | b'_' | b'*' => { + encoded.push(byte as char); + } + b' ' => encoded.push('+'), + _ => { + use std::fmt::Write as _; + let _ = write!(encoded, "%{byte:02X}"); + } + } + } + encoded +} + +fn hex_value(byte: u8) -> Option { + match byte { + b'0'..=b'9' => Some(byte - b'0'), + b'a'..=b'f' => Some(byte - b'a' + 10), + b'A'..=b'F' => Some(byte - b'A' + 10), + _ => None, + } +} + +async fn collect_chunked_body( + client: &mut C, + already_read: &[u8], + generation_guard: Option<&PolicyGenerationGuard>, +) -> Result> { + let mut read_buf = [0u8; RELAY_BUF_SIZE]; + let mut parse_buf = Vec::from(already_read); + let mut pos = 0usize; + + loop { + if parse_buf.len() > MAX_REWRITE_BODY_BYTES { + return Err(miette!( + "request body credential rewrite buffers at most {MAX_REWRITE_BODY_BYTES} bytes" + )); + } + + let size_line_end = loop { + if let Some(end) = find_crlf(&parse_buf, pos) { + break end; + } + let n = client.read(&mut read_buf).await.into_diagnostic()?; + if n == 0 { + return Err(miette!("Chunked body ended before chunk-size line")); + } + if let Some(guard) = generation_guard { + guard.ensure_current()?; + } + parse_buf.extend_from_slice(&read_buf[..n]); + if parse_buf.len() > MAX_REWRITE_BODY_BYTES { + return Err(miette!( + "request body credential rewrite buffers at most {MAX_REWRITE_BODY_BYTES} bytes" + )); + } + }; + + let size_line = std::str::from_utf8(&parse_buf[pos..size_line_end]) + .into_diagnostic() + .map_err(|_| miette!("Invalid UTF-8 in chunk-size line"))?; + let size_token = size_line + .split(';') + .next() + .map(str::trim) + .unwrap_or_default(); + let chunk_size = usize::from_str_radix(size_token, 16) + .into_diagnostic() + .map_err(|_| miette!("Invalid chunk size token: {size_token:?}"))?; + pos = size_line_end + 2; + + if chunk_size == 0 { + loop { + let trailer_end = loop { + if let Some(end) = find_crlf(&parse_buf, pos) { + break end; + } + let n = client.read(&mut read_buf).await.into_diagnostic()?; + if n == 0 { + return Err(miette!("Chunked body ended before trailer terminator")); + } + if let Some(guard) = generation_guard { + guard.ensure_current()?; + } + parse_buf.extend_from_slice(&read_buf[..n]); + if parse_buf.len() > MAX_REWRITE_BODY_BYTES { + return Err(miette!( + "request body credential rewrite buffers at most {MAX_REWRITE_BODY_BYTES} bytes" + )); + } + }; + let trailer_line = &parse_buf[pos..trailer_end]; + pos = trailer_end + 2; + if trailer_line.is_empty() { + return Ok(parse_buf); + } + } + } + + let chunk_end = pos + .checked_add(chunk_size) + .ok_or_else(|| miette!("Chunk size overflow"))?; + let chunk_with_crlf_end = chunk_end + .checked_add(2) + .ok_or_else(|| miette!("Chunk size overflow"))?; + while parse_buf.len() < chunk_with_crlf_end { + let n = client.read(&mut read_buf).await.into_diagnostic()?; + if n == 0 { + return Err(miette!("Chunked body ended mid-chunk")); + } + if let Some(guard) = generation_guard { + guard.ensure_current()?; + } + parse_buf.extend_from_slice(&read_buf[..n]); + if parse_buf.len() > MAX_REWRITE_BODY_BYTES { + return Err(miette!( + "request body credential rewrite buffers at most {MAX_REWRITE_BODY_BYTES} bytes" + )); + } + } + if &parse_buf[chunk_end..chunk_with_crlf_end] != b"\r\n" { + return Err(miette!("Chunk missing terminating CRLF")); + } + pos = chunk_with_crlf_end; + } +} + +fn content_type(headers: &str) -> Option { + headers.lines().skip(1).find_map(|line| { + let (name, value) = line.split_once(':')?; + name.trim().eq_ignore_ascii_case("content-type").then(|| { + value + .split(';') + .next() + .unwrap_or("") + .trim() + .to_ascii_lowercase() + }) + }) +} + +fn is_rewritable_content_type(content_type: Option<&str>) -> bool { + let Some(content_type) = content_type else { + return false; + }; + content_type == "application/json" + || content_type == "application/x-www-form-urlencoded" + || content_type.starts_with("text/") +} + +fn body_bytes_contain_reserved_marker(body: &[u8]) -> bool { + if body.is_empty() { + return false; + } + String::from_utf8_lossy(body) + .split('\0') + .any(contains_reserved_credential_marker) +} + +fn set_content_length(headers: &[u8], len: usize) -> Result> { + use std::fmt::Write as _; + + let header_str = + std::str::from_utf8(headers).map_err(|_| miette!("HTTP headers contain invalid UTF-8"))?; + let mut out = String::with_capacity(header_str.len() + 32); + let mut inserted = false; + for line in header_str.split("\r\n") { + if line.is_empty() { + if !inserted { + let _ = write!(out, "Content-Length: {len}\r\n"); + } + out.push_str("\r\n"); + break; + } + if line + .split_once(':') + .is_some_and(|(name, _)| name.trim().eq_ignore_ascii_case("content-length")) + { + if !inserted { + let _ = write!(out, "Content-Length: {len}\r\n"); + inserted = true; + } + continue; + } + out.push_str(line); + out.push_str("\r\n"); + } + Ok(out.into_bytes()) +} + +pub(crate) fn request_is_websocket_upgrade(raw_header: &[u8]) -> bool { + let header_end = raw_header + .windows(4) + .position(|w| w == b"\r\n\r\n") + .map_or(raw_header.len(), |p| p + 4); + validate_websocket_upgrade_request(&raw_header[..header_end]).unwrap_or(false) +} + +fn rewrite_websocket_extensions_for_mode( + raw_header: &[u8], + mode: WebSocketExtensionMode, + websocket_request: bool, +) -> Result<(Vec, Option)> { + if !websocket_request || mode == WebSocketExtensionMode::Preserve { + return Ok((raw_header.to_vec(), None)); + } + match mode { + WebSocketExtensionMode::Preserve => Ok((raw_header.to_vec(), None)), + WebSocketExtensionMode::PermessageDeflate => { + rewrite_websocket_extensions_for_permessage_deflate(raw_header) + } + } +} + +fn rewrite_websocket_extensions_for_permessage_deflate( + raw_header: &[u8], +) -> Result<(Vec, Option)> { + let header_str = std::str::from_utf8(raw_header) + .map_err(|_| miette!("HTTP headers contain invalid UTF-8"))?; + let safe_offer = supported_permessage_deflate_offer(header_str)?; + let mut out = Vec::with_capacity(raw_header.len()); + let mut inserted = false; + + for line in header_str.split_inclusive("\r\n") { + let bare = line.strip_suffix("\r\n").unwrap_or(line); + if bare + .to_ascii_lowercase() + .starts_with("sec-websocket-extensions:") + { + continue; + } + if bare.is_empty() && !inserted { + if let Some(offer) = safe_offer.as_deref() { + out.extend_from_slice(b"Sec-WebSocket-Extensions: "); + out.extend_from_slice(offer.as_bytes()); + out.extend_from_slice(b"\r\n"); + } + inserted = true; + } + out.extend_from_slice(line.as_bytes()); + } + Ok((out, safe_offer)) +} + +fn supported_permessage_deflate_offer(header_str: &str) -> Result> { + for offer in websocket_extension_offers(header_str)? { + if !offer.name.eq_ignore_ascii_case("permessage-deflate") { + continue; + } + let mut client_no_context_takeover = false; + let mut server_no_context_takeover = false; + let mut unsupported = false; + let mut seen = HashSet::new(); + for param in &offer.params { + let name = param.name.to_ascii_lowercase(); + if param.value.is_some() || !seen.insert(name.clone()) { + unsupported = true; + break; + } + if name == "client_no_context_takeover" { + client_no_context_takeover = true; + } else if name == "server_no_context_takeover" { + server_no_context_takeover = true; + } else { + unsupported = true; + break; + } + } + if client_no_context_takeover && !unsupported { + let mut offer = "permessage-deflate; client_no_context_takeover".to_string(); + if server_no_context_takeover { + offer.push_str("; server_no_context_takeover"); + } + return Ok(Some(offer)); + } + } + Ok(None) +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct WebSocketExtensionOffer { + name: String, + params: Vec, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct WebSocketExtensionParam { + name: String, + value: Option, +} + +fn websocket_extension_offers(header_str: &str) -> Result> { + let mut offers = Vec::new(); + for line in header_str.lines().skip(1) { + let Some((name, value)) = line.split_once(':') else { + continue; + }; + if !name.trim().eq_ignore_ascii_case("sec-websocket-extensions") { + continue; + } + for extension in value.split(',') { + let mut parts = extension.split(';').map(str::trim); + let Some(extension_name) = parts.next().filter(|name| !name.is_empty()) else { + return Err(miette!("invalid WebSocket extension offer")); + }; + if !is_http_token(extension_name) { + return Err(miette!("invalid WebSocket extension token")); + } + let mut params = Vec::new(); + for param in parts { + if param.is_empty() { + return Err(miette!("invalid WebSocket extension parameter")); + } + let (param_name, param_value) = match param.split_once('=') { + Some((name, value)) => { + let value = value.trim(); + if value.is_empty() || value.starts_with('"') || !is_http_token(value) { + return Err(miette!("unsupported WebSocket extension parameter value")); + } + (name.trim(), Some(value.to_string())) + } + None => (param, None), + }; + if param_name.is_empty() || !is_http_token(param_name) { + return Err(miette!("invalid WebSocket extension parameter")); + } + params.push(WebSocketExtensionParam { + name: param_name.to_string(), + value: param_value, + }); + } + offers.push(WebSocketExtensionOffer { + name: extension_name.to_string(), + params, + }); + } + } + Ok(offers) +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct WebSocketUpgradeRequest { + sec_key: String, + subprotocols: Vec, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct WebSocketResponseValidation { + expected_accept: String, + expected_extension: Option, + offered_subprotocols: Vec, +} + +fn validate_websocket_upgrade_request(raw_header: &[u8]) -> Result { + parse_websocket_upgrade_request(raw_header).map(|request| request.is_some()) +} + +fn parse_websocket_upgrade_request(raw_header: &[u8]) -> Result> { + let header_str = std::str::from_utf8(raw_header) + .map_err(|_| miette!("HTTP headers contain invalid UTF-8"))?; + let mut lines = header_str.lines(); + let Some(request_line) = lines.next() else { + return Ok(None); + }; + let method = request_line.split_whitespace().next().unwrap_or_default(); + let mut headers = WebSocketUpgradeHeaders::default(); + + for line in lines { + if line.is_empty() { + break; + } + let Some((name, value)) = line.split_once(':') else { + continue; + }; + let name = name.trim().to_ascii_lowercase(); + let value = value.trim(); + match name.as_str() { + "upgrade" if header_value_contains_token(value, "websocket") => { + headers.upgrade_websocket = true; + } + "connection" if header_value_contains_token(value, "upgrade") => { + headers.connection_upgrade = true; + } + "sec-websocket-key" => { + headers.sec_key_count += 1; + headers.sec_key = Some(value.to_string()); + } + "sec-websocket-version" => { + headers.version_count += 1; + headers.version = Some(value.to_string()); + } + "sec-websocket-protocol" => { + headers.subprotocols.extend(parse_http_token_list(value)?); + } + _ => {} + } + } + + if !headers.is_attempt() { + return Ok(None); + } + if !method.eq_ignore_ascii_case("GET") { + return Err(miette!("websocket upgrade request must use GET")); + } + if !headers.upgrade_websocket { + return Err(miette!( + "websocket upgrade request missing Upgrade: websocket" + )); + } + if !headers.connection_upgrade { + return Err(miette!( + "websocket upgrade request missing Connection: Upgrade" + )); + } + if headers.sec_key_count != 1 { + return Err(miette!( + "websocket upgrade request must include exactly one Sec-WebSocket-Key" + )); + } + let key = headers.sec_key.as_deref().unwrap_or_default(); + let decoded_key = base64::engine::general_purpose::STANDARD + .decode(key.as_bytes()) + .map_err(|_| miette!("websocket upgrade request has invalid Sec-WebSocket-Key"))?; + if decoded_key.len() != 16 { + return Err(miette!( + "websocket upgrade request has invalid Sec-WebSocket-Key length" + )); + } + if headers.version_count != 1 || headers.version.as_deref() != Some("13") { + return Err(miette!( + "websocket upgrade request must use Sec-WebSocket-Version: 13" + )); + } + Ok(Some(WebSocketUpgradeRequest { + sec_key: key.to_string(), + subprotocols: headers.subprotocols, + })) +} + +fn websocket_accept_for_key(sec_key: &str) -> String { + const WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + let mut hasher = Sha1::new(); + hasher.update(sec_key.as_bytes()); + hasher.update(WEBSOCKET_GUID.as_bytes()); + base64::engine::general_purpose::STANDARD.encode(hasher.finalize()) +} + +fn header_value_contains_token(value: &str, expected: &str) -> bool { + value + .split(',') + .any(|token| token.trim().eq_ignore_ascii_case(expected)) +} + +fn parse_http_token_list(value: &str) -> Result> { + let mut tokens = Vec::new(); + for token in value.split(',') { + let token = token.trim(); + if token.is_empty() || !is_http_token(token) { + return Err(miette!("invalid HTTP token list")); + } + tokens.push(token.to_string()); + } + Ok(tokens) +} + +fn is_http_token(value: &str) -> bool { + !value.is_empty() + && value.as_bytes().iter().all(|byte| { + matches!( + byte, + b'!' | b'#' + | b'$' + | b'%' + | b'&' + | b'\'' + | b'*' + | b'+' + | b'-' + | b'.' + | b'^' + | b'_' + | b'`' + | b'|' + | b'~' + | b'0'..=b'9' + | b'A'..=b'Z' + | b'a'..=b'z' + ) + }) +} + +#[derive(Default)] +struct WebSocketUpgradeHeaders { + upgrade_websocket: bool, + connection_upgrade: bool, + sec_key: Option, + sec_key_count: usize, + version: Option, + version_count: usize, + subprotocols: Vec, +} + +impl WebSocketUpgradeHeaders { + fn is_attempt(&self) -> bool { + self.upgrade_websocket || self.sec_key.is_some() || self.version.is_some() + } } /// Send a 403 Forbidden JSON deny response. @@ -768,10 +1539,28 @@ fn find_crlf(buf: &[u8], start: usize) -> Option { .map(|offset| start + offset) } +#[derive(Clone)] +struct RelayResponseOptions { + websocket_extensions: WebSocketExtensionMode, + client_requested_upgrade: bool, + websocket: Option, +} + +impl Default for RelayResponseOptions { + fn default() -> Self { + Self { + websocket_extensions: WebSocketExtensionMode::Preserve, + client_requested_upgrade: true, + websocket: None, + } + } +} + async fn relay_response( request_method: &str, upstream: &mut U, client: &mut C, + options: RelayResponseOptions, ) -> Result where U: AsyncRead + Unpin, @@ -825,6 +1614,14 @@ where // from upstream beyond the headers are overflow that belong to the // upgraded protocol and must be forwarded before switching. if status_code == 101 { + if !options.client_requested_upgrade { + return Ok(RelayOutcome::Consumed); + } + let websocket_permessage_deflate = validate_websocket_response( + &header_str, + options.websocket_extensions, + options.websocket.as_ref(), + )?; client .write_all(&buf[..header_end]) .await @@ -836,7 +1633,10 @@ where overflow_bytes = overflow.len(), "101 Switching Protocols — signaling protocol upgrade" ); - return Ok(RelayOutcome::Upgraded { overflow }); + return Ok(RelayOutcome::Upgraded { + overflow, + websocket_permessage_deflate, + }); } // Bodiless responses (HEAD, 1xx, 204, 304): forward headers only, skip body @@ -938,6 +1738,159 @@ fn parse_connection_close(headers: &str) -> bool { false } +fn validate_websocket_response( + headers: &str, + mode: WebSocketExtensionMode, + websocket: Option<&WebSocketResponseValidation>, +) -> Result { + let Some(validation) = websocket else { + return validate_websocket_response_extensions_preserved(headers, mode); + }; + + let mut upgrade_websocket = false; + let mut connection_upgrade = false; + let mut accept_count = 0usize; + let mut accept_matches = false; + let mut subprotocol_count = 0usize; + let mut selected_subprotocol = None; + + for line in headers.lines().skip(1) { + let Some((name, value)) = line.split_once(':') else { + continue; + }; + let name = name.trim().to_ascii_lowercase(); + let value = value.trim(); + match name.as_str() { + "upgrade" if header_value_contains_token(value, "websocket") => { + upgrade_websocket = true; + } + "connection" if header_value_contains_token(value, "upgrade") => { + connection_upgrade = true; + } + "sec-websocket-accept" => { + accept_count += 1; + accept_matches = value == validation.expected_accept; + } + "sec-websocket-protocol" => { + subprotocol_count += 1; + if !is_http_token(value) { + return Err(miette!( + "websocket upgrade response has invalid Sec-WebSocket-Protocol" + )); + } + selected_subprotocol = Some(value.to_string()); + } + _ => {} + } + } + + if !upgrade_websocket { + return Err(miette!( + "websocket upgrade response missing Upgrade: websocket" + )); + } + if !connection_upgrade { + return Err(miette!( + "websocket upgrade response missing Connection: Upgrade" + )); + } + if accept_count != 1 || !accept_matches { + return Err(miette!( + "websocket upgrade response has invalid Sec-WebSocket-Accept" + )); + } + if subprotocol_count > 1 { + return Err(miette!( + "websocket upgrade response has multiple Sec-WebSocket-Protocol headers" + )); + } + if let Some(protocol) = selected_subprotocol + && !validation + .offered_subprotocols + .iter() + .any(|offered| offered == &protocol) + { + return Err(miette!( + "upstream selected WebSocket subprotocol that was not offered" + )); + } + + let actual_extension = normalized_websocket_extension(headers)?; + match (&validation.expected_extension, actual_extension.as_deref()) { + (None, Some(_)) => Err(miette!( + "upstream negotiated WebSocket extension that was not offered" + )), + (None | Some(_), None) => Ok(false), + (Some(expected), Some(actual)) if expected.eq_ignore_ascii_case(actual) => Ok(true), + (Some(_), Some(_)) => Err(miette!( + "upstream negotiated WebSocket extension that does not match the safe offer" + )), + } +} + +fn validate_websocket_response_extensions_preserved( + headers: &str, + mode: WebSocketExtensionMode, +) -> Result { + match mode { + WebSocketExtensionMode::Preserve => Ok(false), + WebSocketExtensionMode::PermessageDeflate => { + let offers = websocket_extension_offers(headers)?; + if offers.is_empty() { + Ok(false) + } else { + Err(miette!( + "upstream negotiated WebSocket extension that was not offered" + )) + } + } + } +} + +fn normalized_websocket_extension(headers: &str) -> Result> { + let offers = websocket_extension_offers(headers)?; + if offers.is_empty() { + return Ok(None); + } + if offers.len() != 1 { + return Err(miette!("upstream negotiated multiple WebSocket extensions")); + } + let offer = &offers[0]; + if !offer.name.eq_ignore_ascii_case("permessage-deflate") { + return Err(miette!( + "upstream negotiated unsupported WebSocket extension" + )); + } + let mut client_no_context_takeover = false; + let mut server_no_context_takeover = false; + let mut seen = HashSet::new(); + for param in &offer.params { + let name = param.name.to_ascii_lowercase(); + if param.value.is_some() || !seen.insert(name.clone()) { + return Err(miette!( + "upstream negotiated unsupported permessage-deflate parameter" + )); + } + if name == "client_no_context_takeover" { + client_no_context_takeover = true; + } else if name == "server_no_context_takeover" { + server_no_context_takeover = true; + } else { + return Err(miette!( + "upstream negotiated unsupported permessage-deflate parameter" + )); + } + } + let mut normalized = String::from("permessage-deflate"); + if client_no_context_takeover { + normalized.push_str("; client_no_context_takeover"); + } + if server_no_context_takeover { + normalized.push_str("; server_no_context_takeover"); + } + Ok(Some(normalized)) +} + /// Check if the client request headers contain both `Upgrade` and /// `Connection: Upgrade` headers, indicating the client requested a /// protocol upgrade (e.g. WebSocket). @@ -1034,21 +1987,297 @@ fn is_benign_close(err: &std::io::Error) -> bool { ) } -#[cfg(test)] -#[allow( - clippy::iter_on_single_items, - clippy::manual_string_new, - clippy::collapsible_if, - clippy::cast_possible_truncation, - reason = "Test code: test fixtures and explicit value-shape assertions are idiomatic in tests." -)] -mod tests { - use super::*; - use crate::opa::OpaEngine; - use crate::secrets::SecretResolver; - use base64::Engine as _; +#[cfg(test)] +#[allow( + clippy::iter_on_single_items, + clippy::manual_string_new, + clippy::collapsible_if, + clippy::cast_possible_truncation, + reason = "Test code: test fixtures and explicit value-shape assertions are idiomatic in tests." +)] +mod tests { + use super::*; + use crate::opa::OpaEngine; + use crate::secrets::SecretResolver; + use flate2::{Compress, Compression, Decompress, FlushCompress, FlushDecompress, Status}; + use std::sync::Arc; + + const TEST_POLICY: &str = include_str!("../../data/sandbox-policy.rego"); + const VALID_WS_KEY: &str = "dGhlIHNhbXBsZSBub25jZQ=="; + const VALID_WS_ACCEPT: &str = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="; + const TEXT_OPCODE: u8 = 0x1; + + #[derive(Debug)] + struct CapturedFrame { + fin_opcode: u8, + masked: bool, + payload: Vec, + } + + async fn read_http_header_block(reader: &mut R) -> Vec { + tokio::time::timeout(std::time::Duration::from_secs(2), async { + let mut header = Vec::new(); + let mut byte = [0u8; 1]; + loop { + reader.read_exact(&mut byte).await.unwrap(); + header.push(byte[0]); + if header.ends_with(b"\r\n\r\n") { + break; + } + } + header + }) + .await + .expect("HTTP header block should arrive") + } + + async fn read_websocket_frame(reader: &mut R) -> CapturedFrame { + tokio::time::timeout(std::time::Duration::from_secs(2), async { + let mut prefix = [0u8; 2]; + reader.read_exact(&mut prefix).await.unwrap(); + let masked = prefix[1] & 0x80 != 0; + let mut payload_len = u64::from(prefix[1] & 0x7f); + if payload_len == 126 { + let mut extended = [0u8; 2]; + reader.read_exact(&mut extended).await.unwrap(); + payload_len = u64::from(u16::from_be_bytes(extended)); + } else if payload_len == 127 { + let mut extended = [0u8; 8]; + reader.read_exact(&mut extended).await.unwrap(); + payload_len = u64::from_be_bytes(extended); + } + let mut mask_key = [0u8; 4]; + if masked { + reader.read_exact(&mut mask_key).await.unwrap(); + } + let payload_len = usize::try_from(payload_len).unwrap(); + let mut payload = vec![0u8; payload_len]; + reader.read_exact(&mut payload).await.unwrap(); + if masked { + apply_test_mask(&mut payload, mask_key); + } + CapturedFrame { + fin_opcode: prefix[0], + masked, + payload, + } + }) + .await + .expect("WebSocket frame should arrive") + } + + fn masked_frame_with_rsv(opcode: u8, rsv: u8, payload: &[u8]) -> Vec { + let mask_key = [0x37, 0xfa, 0x21, 0x3d]; + let mut frame = Vec::new(); + frame.push(0x80 | rsv | opcode); + write_test_payload_len(&mut frame, 0x80, payload.len()); + frame.extend_from_slice(&mask_key); + let mut masked = payload.to_vec(); + apply_test_mask(&mut masked, mask_key); + frame.extend_from_slice(&masked); + frame + } + + fn unmasked_frame(opcode: u8, payload: &[u8]) -> Vec { + let mut frame = Vec::new(); + frame.push(0x80 | opcode); + write_test_payload_len(&mut frame, 0, payload.len()); + frame.extend_from_slice(payload); + frame + } + + fn write_test_payload_len(frame: &mut Vec, mask_bit: u8, payload_len: usize) { + if payload_len < 126 { + frame.push(mask_bit | payload_len as u8); + } else if u16::try_from(payload_len).is_ok() { + frame.push(mask_bit | 0x7e); + frame.extend_from_slice(&(payload_len as u16).to_be_bytes()); + } else { + frame.push(mask_bit | 0x7f); + frame.extend_from_slice(&(payload_len as u64).to_be_bytes()); + } + } + + fn apply_test_mask(payload: &mut [u8], mask_key: [u8; 4]) { + for (index, byte) in payload.iter_mut().enumerate() { + *byte ^= mask_key[index % 4]; + } + } + + fn compress_test_permessage_deflate(payload: &[u8]) -> Vec { + let mut compressor = Compress::new(Compression::fast(), false); + let mut out = Vec::with_capacity(payload.len().saturating_add(128)); + loop { + let consumed = usize::try_from(compressor.total_in()).unwrap(); + if consumed >= payload.len() { + break; + } + let before_in = compressor.total_in(); + let before_out = compressor.total_out(); + let status = compressor + .compress_vec(&payload[consumed..], &mut out, FlushCompress::None) + .unwrap(); + if matches!(status, Status::BufError) + || (compressor.total_in() == before_in && compressor.total_out() == before_out) + { + out.reserve(out.capacity().max(1024)); + } + } + loop { + out.reserve(64); + let before_out = compressor.total_out(); + compressor + .compress_vec(&[], &mut out, FlushCompress::Sync) + .unwrap(); + if out.ends_with(&[0x00, 0x00, 0xff, 0xff]) { + break; + } + if compressor.total_out() == before_out { + out.reserve(out.capacity().max(1024)); + } + } + out.truncate(out.len() - 4); + out + } + + fn decompress_test_permessage_deflate(payload: &[u8]) -> Vec { + let mut decoder = Decompress::new(false); + let mut input = Vec::with_capacity(payload.len() + 4); + input.extend_from_slice(payload); + input.extend_from_slice(&[0x00, 0x00, 0xff, 0xff]); + let mut out = Vec::new(); + let mut input_pos = 0usize; + let mut scratch = [0u8; RELAY_BUF_SIZE]; + loop { + let before_in = decoder.total_in(); + let before_out = decoder.total_out(); + let status = decoder + .decompress(&input[input_pos..], &mut scratch, FlushDecompress::Sync) + .unwrap(); + let read = usize::try_from(decoder.total_in() - before_in).unwrap(); + let written = usize::try_from(decoder.total_out() - before_out).unwrap(); + input_pos += read; + out.extend_from_slice(&scratch[..written]); + if matches!(status, Status::StreamEnd) { + break; + } + if input_pos >= input.len() && written < scratch.len() { + break; + } + assert!( + read != 0 || written != 0, + "test permessage-deflate decompression did not make progress" + ); + } + out + } + + fn websocket_request(extension: Option<&str>) -> L7Request { + let mut raw_header = format!( + "GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {VALID_WS_KEY}\r\n" + ); + if let Some(extension) = extension { + raw_header.push_str("Sec-WebSocket-Extensions: "); + raw_header.push_str(extension); + raw_header.push_str("\r\n"); + } + raw_header.push_str("Sec-WebSocket-Version: 13\r\n\r\n"); + L7Request { + action: "GET".to_string(), + target: "/ws".to_string(), + query_params: HashMap::new(), + raw_header: raw_header.into_bytes(), + body_length: BodyLength::None, + } + } + + async fn run_upgraded_websocket_case( + request_extension: Option<&'static str>, + response_extension: Option<&'static str>, + extension_mode: WebSocketExtensionMode, + resolver: Option>, + client_frame: Vec, + ) -> (String, CapturedFrame) { + let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(16384); + let (mut client_app, mut proxy_to_client) = tokio::io::duplex(16384); + let req = websocket_request(request_extension); + let resolver_for_header = resolver.clone(); + let resolver_for_upgrade = resolver.clone(); - const TEST_POLICY: &str = include_str!("../../data/sandbox-policy.rego"); + let upstream_task = tokio::spawn(async move { + let forwarded = read_http_header_block(&mut upstream_side).await; + let forwarded = String::from_utf8(forwarded).unwrap(); + let mut response = format!( + "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: {VALID_WS_ACCEPT}\r\n" + ); + if let Some(extension) = response_extension { + response.push_str("Sec-WebSocket-Extensions: "); + response.push_str(extension); + response.push_str("\r\n"); + } + response.push_str("\r\n"); + upstream_side.write_all(response.as_bytes()).await.unwrap(); + upstream_side.flush().await.unwrap(); + let frame = read_websocket_frame(&mut upstream_side).await; + (forwarded, frame) + }); + + let relay_task = tokio::spawn(async move { + let outcome = relay_http_request_with_options_guarded( + &req, + &mut proxy_to_client, + &mut proxy_to_upstream, + RelayRequestOptions { + resolver: resolver_for_header.as_deref(), + websocket_extensions: extension_mode, + ..Default::default() + }, + ) + .await + .expect("handshake relay should succeed"); + let RelayOutcome::Upgraded { + overflow, + websocket_permessage_deflate, + } = outcome + else { + panic!("expected upgraded relay outcome"); + }; + let credential_rewrite = resolver_for_upgrade.is_some(); + crate::l7::relay::handle_upgrade( + &mut proxy_to_client, + &mut proxy_to_upstream, + overflow, + "example.com", + 443, + crate::l7::relay::UpgradeRelayOptions { + websocket_request: true, + websocket: crate::l7::relay::WebSocketUpgradeBehavior { + credential_rewrite, + permessage_deflate: websocket_permessage_deflate, + ..Default::default() + }, + secret_resolver: resolver_for_upgrade, + target: "/ws".to_string(), + policy_name: "test-policy".to_string(), + ..Default::default() + }, + ) + .await + }); + + let response = read_http_header_block(&mut client_app).await; + assert!( + String::from_utf8_lossy(&response).contains("101 Switching Protocols"), + "client must receive the upgrade before frame relay starts" + ); + client_app.write_all(&client_frame).await.unwrap(); + client_app.flush().await.unwrap(); + + let result = upstream_task.await.expect("upstream task should complete"); + drop(client_app); + let _ = tokio::time::timeout(std::time::Duration::from_secs(2), relay_task).await; + result + } #[test] fn deny_response_body_is_agent_readable_and_redacted() { @@ -1711,7 +2940,12 @@ mod tests { let result = tokio::time::timeout( std::time::Duration::from_secs(2), - relay_response("GET", &mut upstream_read, &mut client_write), + relay_response( + "GET", + &mut upstream_read, + &mut client_write, + RelayResponseOptions::default(), + ), ) .await .expect("relay_response should not deadlock"); @@ -1752,7 +2986,12 @@ mod tests { let result = tokio::time::timeout( std::time::Duration::from_secs(2), - relay_response("GET", &mut upstream_read, &mut client_write), + relay_response( + "GET", + &mut upstream_read, + &mut client_write, + RelayResponseOptions::default(), + ), ) .await .expect("must not block when no Connection: close"); @@ -1788,7 +3027,12 @@ mod tests { let result = tokio::time::timeout( std::time::Duration::from_secs(2), - relay_response("HEAD", &mut upstream_read, &mut client_write), + relay_response( + "HEAD", + &mut upstream_read, + &mut client_write, + RelayResponseOptions::default(), + ), ) .await .expect("HEAD relay must not deadlock waiting for body"); @@ -1803,281 +3047,706 @@ mod tests { let mut received = Vec::new(); client_read.read_to_end(&mut received).await.unwrap(); let received_str = String::from_utf8_lossy(&received); - assert!(received_str.contains("200 OK")); - // Should NOT contain body bytes - assert!(!received_str.contains('\0')); + assert!(received_str.contains("200 OK")); + // Should NOT contain body bytes + assert!(!received_str.contains('\0')); + } + + #[tokio::test] + async fn relay_response_204_no_body() { + let response = b"HTTP/1.1 204 No Content\r\nServer: test\r\n\r\n"; + + let (mut upstream_read, mut upstream_write) = tokio::io::duplex(4096); + let (mut client_read, mut client_write) = tokio::io::duplex(4096); + + tokio::spawn(async move { + upstream_write.write_all(response).await.unwrap(); + }); + + let result = tokio::time::timeout( + std::time::Duration::from_secs(2), + relay_response( + "GET", + &mut upstream_read, + &mut client_write, + RelayResponseOptions::default(), + ), + ) + .await + .expect("204 relay must not deadlock"); + + let outcome = result.expect("relay_response should succeed"); + assert!( + matches!(outcome, RelayOutcome::Reusable), + "204 response should be reusable" + ); + + client_write.shutdown().await.unwrap(); + let mut received = Vec::new(); + client_read.read_to_end(&mut received).await.unwrap(); + assert!(String::from_utf8_lossy(&received).contains("204 No Content")); + } + + #[tokio::test] + async fn relay_response_chunked_body_complete_in_overflow() { + // Entire chunked body (including terminal 0\r\n\r\n) arrives with + // headers in the same read. relay_chunked must NOT be called or it + // will block forever waiting for data that was already consumed. + let response = + b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n5\r\nhello\r\n0\r\n\r\n"; + + let (mut upstream_read, mut upstream_write) = tokio::io::duplex(4096); + let (mut client_read, mut client_write) = tokio::io::duplex(4096); + + tokio::spawn(async move { + upstream_write.write_all(response).await.unwrap(); + // Do NOT close — if relay_chunked is called it will block forever + }); + + let result = tokio::time::timeout( + std::time::Duration::from_secs(2), + relay_response( + "GET", + &mut upstream_read, + &mut client_write, + RelayResponseOptions::default(), + ), + ) + .await + .expect("must not block when chunked body is complete in overflow"); + + let outcome = result.expect("relay_response should succeed"); + assert!( + matches!(outcome, RelayOutcome::Reusable), + "connection should be reusable" + ); + + client_write.shutdown().await.unwrap(); + let mut received = Vec::new(); + client_read.read_to_end(&mut received).await.unwrap(); + let received_str = String::from_utf8_lossy(&received); + assert!( + received_str.contains("hello"), + "chunked body should be forwarded" + ); + } + + #[tokio::test] + async fn relay_response_chunked_with_trailers_does_not_wait_for_eof() { + // Last-chunk can be followed by trailers, so body terminator is not + // always literal "0\r\n\r\n". We must stop at final empty trailer + // line without waiting for upstream connection close. + let response = b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n5\r\nhello\r\n0\r\nx-checksum: abc123\r\n\r\n"; + + let (mut upstream_read, mut upstream_write) = tokio::io::duplex(4096); + let (mut client_read, mut client_write) = tokio::io::duplex(4096); + + tokio::spawn(async move { + upstream_write.write_all(response).await.unwrap(); + // Keep stream open to ensure relay terminates by framing, not EOF. + tokio::time::sleep(std::time::Duration::from_secs(10)).await; + }); + + let result = tokio::time::timeout( + std::time::Duration::from_secs(2), + relay_response( + "GET", + &mut upstream_read, + &mut client_write, + RelayResponseOptions::default(), + ), + ) + .await + .expect("must not block when chunked response has trailers"); + + let outcome = result.expect("relay_response should succeed"); + assert!( + matches!(outcome, RelayOutcome::Reusable), + "chunked response should be reusable" + ); + + client_write.shutdown().await.unwrap(); + let mut received = Vec::new(); + client_read.read_to_end(&mut received).await.unwrap(); + let received_str = String::from_utf8_lossy(&received); + assert!( + received_str.contains("hello"), + "chunked body should be forwarded" + ); + assert!( + received_str.contains("x-checksum: abc123"), + "trailers should be forwarded" + ); + } + + #[tokio::test] + async fn relay_response_normal_content_length() { + let response = b"HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello"; + + let (mut upstream_read, mut upstream_write) = tokio::io::duplex(4096); + let (mut client_read, mut client_write) = tokio::io::duplex(4096); + + tokio::spawn(async move { + upstream_write.write_all(response).await.unwrap(); + }); + + let result = tokio::time::timeout( + std::time::Duration::from_secs(2), + relay_response( + "GET", + &mut upstream_read, + &mut client_write, + RelayResponseOptions::default(), + ), + ) + .await + .expect("normal relay must not deadlock"); + + let outcome = result.expect("relay_response should succeed"); + assert!( + matches!(outcome, RelayOutcome::Reusable), + "Content-Length response should be reusable" + ); + + client_write.shutdown().await.unwrap(); + let mut received = Vec::new(); + client_read.read_to_end(&mut received).await.unwrap(); + let received_str = String::from_utf8_lossy(&received); + assert!(received_str.contains("hello")); + } + + #[tokio::test] + async fn relay_response_connection_close_with_content_length() { + let response = b"HTTP/1.1 200 OK\r\nContent-Length: 5\r\nConnection: close\r\n\r\nhello"; + + let (mut upstream_read, mut upstream_write) = tokio::io::duplex(4096); + let (mut client_read, mut client_write) = tokio::io::duplex(4096); + + tokio::spawn(async move { + upstream_write.write_all(response).await.unwrap(); + }); + + let result = tokio::time::timeout( + std::time::Duration::from_secs(2), + relay_response( + "GET", + &mut upstream_read, + &mut client_write, + RelayResponseOptions::default(), + ), + ) + .await + .expect("relay must not deadlock"); + + let outcome = result.expect("relay_response should succeed"); + // With explicit framing, Connection: close is still reported as reusable + // so the relay loop continues. The *next* upstream write will fail and + // exit the loop via the normal error path. + assert!( + matches!(outcome, RelayOutcome::Reusable), + "explicit framing keeps loop alive despite Connection: close" + ); + + client_write.shutdown().await.unwrap(); + let mut received = Vec::new(); + client_read.read_to_end(&mut received).await.unwrap(); + assert!(String::from_utf8_lossy(&received).contains("hello")); + } + + #[tokio::test] + async fn relay_response_101_switching_protocols_returns_upgraded_with_overflow() { + // Build a 101 response followed by WebSocket frame data (overflow). + let mut response = Vec::new(); + response.extend_from_slice(b"HTTP/1.1 101 Switching Protocols\r\n"); + response.extend_from_slice(b"Upgrade: websocket\r\n"); + response.extend_from_slice(b"Connection: Upgrade\r\n"); + response.extend_from_slice(b"\r\n"); + response.extend_from_slice(b"\x81\x05hello"); // WebSocket frame + + let (upstream_read, mut upstream_write) = tokio::io::duplex(4096); + let (mut client_read, client_write) = tokio::io::duplex(4096); + + upstream_write.write_all(&response).await.unwrap(); + drop(upstream_write); + + let mut upstream_read = upstream_read; + let mut client_write = client_write; + + let result = tokio::time::timeout( + std::time::Duration::from_secs(2), + relay_response( + "GET", + &mut upstream_read, + &mut client_write, + RelayResponseOptions::default(), + ), + ) + .await + .expect("relay_response should not deadlock"); + + let outcome = result.expect("relay_response should succeed"); + match outcome { + RelayOutcome::Upgraded { overflow, .. } => { + assert_eq!( + &overflow, b"\x81\x05hello", + "overflow should contain WebSocket frame data" + ); + } + other => panic!("Expected Upgraded, got {other:?}"), + } + + client_write.shutdown().await.unwrap(); + let mut received = Vec::new(); + client_read.read_to_end(&mut received).await.unwrap(); + let received_str = String::from_utf8_lossy(&received); + assert!( + received_str.contains("101 Switching Protocols"), + "client should receive the 101 response headers" + ); } #[tokio::test] - async fn relay_response_204_no_body() { - let response = b"HTTP/1.1 204 No Content\r\nServer: test\r\n\r\n"; + async fn relay_response_101_no_overflow() { + // 101 response with no trailing bytes — overflow should be empty. + let response = b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n"; - let (mut upstream_read, mut upstream_write) = tokio::io::duplex(4096); - let (mut client_read, mut client_write) = tokio::io::duplex(4096); + let (upstream_read, mut upstream_write) = tokio::io::duplex(4096); + let (_client_read, client_write) = tokio::io::duplex(4096); - tokio::spawn(async move { - upstream_write.write_all(response).await.unwrap(); - }); + upstream_write.write_all(response).await.unwrap(); + drop(upstream_write); + + let mut upstream_read = upstream_read; + let mut client_write = client_write; let result = tokio::time::timeout( std::time::Duration::from_secs(2), - relay_response("GET", &mut upstream_read, &mut client_write), + relay_response( + "GET", + &mut upstream_read, + &mut client_write, + RelayResponseOptions::default(), + ), ) .await - .expect("204 relay must not deadlock"); - - let outcome = result.expect("relay_response should succeed"); - assert!( - matches!(outcome, RelayOutcome::Reusable), - "204 response should be reusable" - ); + .expect("relay_response should not deadlock"); - client_write.shutdown().await.unwrap(); - let mut received = Vec::new(); - client_read.read_to_end(&mut received).await.unwrap(); - assert!(String::from_utf8_lossy(&received).contains("204 No Content")); + match result.expect("should succeed") { + RelayOutcome::Upgraded { overflow, .. } => { + assert!(overflow.is_empty(), "no overflow expected"); + } + other => panic!("Expected Upgraded, got {other:?}"), + } } #[tokio::test] - async fn relay_response_chunked_body_complete_in_overflow() { - // Entire chunked body (including terminal 0\r\n\r\n) arrives with - // headers in the same read. relay_chunked must NOT be called or it - // will block forever waiting for data that was already consumed. - let response = - b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n5\r\nhello\r\n0\r\n\r\n"; + async fn relay_rejects_unsolicited_101_without_client_upgrade_header() { + // Client sends a normal GET without Upgrade headers. + // Upstream responds with 101 (non-compliant). The relay should + // reject the upgrade and return Consumed instead. + let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); + let (mut _app_side, mut proxy_to_client) = tokio::io::duplex(8192); - let (mut upstream_read, mut upstream_write) = tokio::io::duplex(4096); - let (mut client_read, mut client_write) = tokio::io::duplex(4096); + let req = L7Request { + action: "GET".to_string(), + target: "/api".to_string(), + query_params: HashMap::new(), + raw_header: b"GET /api HTTP/1.1\r\nHost: example.com\r\n\r\n".to_vec(), + body_length: BodyLength::None, + }; - tokio::spawn(async move { - upstream_write.write_all(response).await.unwrap(); - // Do NOT close — if relay_chunked is called it will block forever + let upstream_task = tokio::spawn(async move { + // Read the request + let mut buf = vec![0u8; 4096]; + let mut total = 0; + loop { + let n = upstream_side.read(&mut buf[total..]).await.unwrap(); + if n == 0 { + break; + } + total += n; + if buf[..total].windows(4).any(|w| w == b"\r\n\r\n") { + break; + } + } + // Send unsolicited 101 + upstream_side + .write_all( + b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n", + ) + .await + .unwrap(); + upstream_side.flush().await.unwrap(); }); let result = tokio::time::timeout( - std::time::Duration::from_secs(2), - relay_response("GET", &mut upstream_read, &mut client_write), + std::time::Duration::from_secs(5), + relay_http_request_with_resolver( + &req, + &mut proxy_to_client, + &mut proxy_to_upstream, + None, + ), ) .await - .expect("must not block when chunked body is complete in overflow"); + .expect("relay must not deadlock"); - let outcome = result.expect("relay_response should succeed"); + let outcome = result.expect("relay should succeed"); assert!( - matches!(outcome, RelayOutcome::Reusable), - "connection should be reusable" + matches!(outcome, RelayOutcome::Consumed), + "unsolicited 101 should be rejected as Consumed, got {outcome:?}" ); - client_write.shutdown().await.unwrap(); - let mut received = Vec::new(); - client_read.read_to_end(&mut received).await.unwrap(); - let received_str = String::from_utf8_lossy(&received); - assert!( - received_str.contains("hello"), - "chunked body should be forwarded" - ); + upstream_task.await.expect("upstream task should complete"); } #[tokio::test] - async fn relay_response_chunked_with_trailers_does_not_wait_for_eof() { - // Last-chunk can be followed by trailers, so body terminator is not - // always literal "0\r\n\r\n". We must stop at final empty trailer - // line without waiting for upstream connection close. - let response = b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n5\r\nhello\r\n0\r\nx-checksum: abc123\r\n\r\n"; + async fn relay_accepts_101_with_client_upgrade_header() { + // Client sends a proper upgrade request with Upgrade + Connection headers. + let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); + let (mut _app_side, mut proxy_to_client) = tokio::io::duplex(8192); - let (mut upstream_read, mut upstream_write) = tokio::io::duplex(4096); - let (mut client_read, mut client_write) = tokio::io::duplex(4096); + let req = L7Request { + action: "GET".to_string(), + target: "/ws".to_string(), + query_params: HashMap::new(), + raw_header: b"GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n".to_vec(), + body_length: BodyLength::None, + }; - tokio::spawn(async move { - upstream_write.write_all(response).await.unwrap(); - // Keep stream open to ensure relay terminates by framing, not EOF. - tokio::time::sleep(std::time::Duration::from_secs(10)).await; + let upstream_task = tokio::spawn(async move { + let mut buf = vec![0u8; 4096]; + let mut total = 0; + loop { + let n = upstream_side.read(&mut buf[total..]).await.unwrap(); + if n == 0 { + break; + } + total += n; + if buf[..total].windows(4).any(|w| w == b"\r\n\r\n") { + break; + } + } + upstream_side + .write_all( + b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n", + ) + .await + .unwrap(); + upstream_side.flush().await.unwrap(); }); let result = tokio::time::timeout( - std::time::Duration::from_secs(2), - relay_response("GET", &mut upstream_read, &mut client_write), + std::time::Duration::from_secs(5), + relay_http_request_with_resolver( + &req, + &mut proxy_to_client, + &mut proxy_to_upstream, + None, + ), ) .await - .expect("must not block when chunked response has trailers"); + .expect("relay must not deadlock"); - let outcome = result.expect("relay_response should succeed"); + let outcome = result.expect("relay should succeed"); assert!( - matches!(outcome, RelayOutcome::Reusable), - "chunked response should be reusable" + matches!(outcome, RelayOutcome::Upgraded { .. }), + "proper upgrade request should be accepted, got {outcome:?}" ); - client_write.shutdown().await.unwrap(); - let mut received = Vec::new(); - client_read.read_to_end(&mut received).await.unwrap(); - let received_str = String::from_utf8_lossy(&received); + upstream_task.await.expect("upstream task should complete"); + } + + #[tokio::test] + async fn opted_in_websocket_relay_rejects_invalid_upgrade_before_upstream_write() { + let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); + let (mut _app_side, mut proxy_to_client) = tokio::io::duplex(8192); + + let req = L7Request { + action: "GET".to_string(), + target: "/ws".to_string(), + query_params: HashMap::new(), + raw_header: b"GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Version: 13\r\n\r\n".to_vec(), + body_length: BodyLength::None, + }; + + let result = relay_http_request_with_options_guarded( + &req, + &mut proxy_to_client, + &mut proxy_to_upstream, + RelayRequestOptions { + websocket_extensions: WebSocketExtensionMode::PermessageDeflate, + ..Default::default() + }, + ) + .await; + assert!( - received_str.contains("hello"), - "chunked body should be forwarded" + result.is_err(), + "missing Sec-WebSocket-Key must fail closed" ); + drop(proxy_to_upstream); + let mut forwarded = Vec::new(); + upstream_side.read_to_end(&mut forwarded).await.unwrap(); assert!( - received_str.contains("x-checksum: abc123"), - "trailers should be forwarded" + forwarded.is_empty(), + "invalid opted-in upgrade must not reach upstream" ); } #[tokio::test] - async fn relay_response_normal_content_length() { - let response = b"HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello"; + async fn opted_in_websocket_relay_strips_request_extensions_and_rejects_response_extensions() { + let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); + let (mut app_side, mut proxy_to_client) = tokio::io::duplex(8192); - let (mut upstream_read, mut upstream_write) = tokio::io::duplex(4096); - let (mut client_read, mut client_write) = tokio::io::duplex(4096); + let req = L7Request { + action: "GET".to_string(), + target: "/ws".to_string(), + query_params: HashMap::new(), + raw_header: format!( + "GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {VALID_WS_KEY}\r\nSec-WebSocket-Extensions: permessage-deflate\r\nSec-WebSocket-Version: 13\r\n\r\n" + ) + .into_bytes(), + body_length: BodyLength::None, + }; - tokio::spawn(async move { - upstream_write.write_all(response).await.unwrap(); + let upstream_task = tokio::spawn(async move { + let mut buf = vec![0u8; 4096]; + let mut total = 0; + loop { + let n = upstream_side.read(&mut buf[total..]).await.unwrap(); + if n == 0 { + break; + } + total += n; + if buf[..total].windows(4).any(|w| w == b"\r\n\r\n") { + break; + } + } + let forwarded = String::from_utf8_lossy(&buf[..total]); + assert!( + !forwarded + .to_ascii_lowercase() + .contains("sec-websocket-extensions"), + "opted-in request must strip extension negotiation" + ); + upstream_side + .write_all( + format!( + "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: {VALID_WS_ACCEPT}\r\nSec-WebSocket-Extensions: permessage-deflate\r\n\r\n" + ) + .as_bytes(), + ) + .await + .unwrap(); + upstream_side.flush().await.unwrap(); }); - let result = tokio::time::timeout( - std::time::Duration::from_secs(2), - relay_response("GET", &mut upstream_read, &mut client_write), - ) - .await - .expect("normal relay must not deadlock"); + let result = relay_http_request_with_options_guarded( + &req, + &mut proxy_to_client, + &mut proxy_to_upstream, + RelayRequestOptions { + websocket_extensions: WebSocketExtensionMode::PermessageDeflate, + ..Default::default() + }, + ) + .await; - let outcome = result.expect("relay_response should succeed"); - assert!( - matches!(outcome, RelayOutcome::Reusable), - "Content-Length response should be reusable" - ); + let err = result.expect_err("upstream extension negotiation must fail closed"); + assert!(err.to_string().contains("not offered")); + upstream_task.await.expect("upstream task should complete"); - client_write.shutdown().await.unwrap(); + drop(proxy_to_client); let mut received = Vec::new(); - client_read.read_to_end(&mut received).await.unwrap(); - let received_str = String::from_utf8_lossy(&received); - assert!(received_str.contains("hello")); + app_side.read_to_end(&mut received).await.unwrap(); + assert!( + received.is_empty(), + "rejected extension negotiation must not forward 101 headers" + ); } #[tokio::test] - async fn relay_response_connection_close_with_content_length() { - let response = b"HTTP/1.1 200 OK\r\nContent-Length: 5\r\nConnection: close\r\n\r\nhello"; + async fn permessage_deflate_mode_allows_supported_no_context_takeover() { + let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); + let (mut _app_side, mut proxy_to_client) = tokio::io::duplex(8192); - let (mut upstream_read, mut upstream_write) = tokio::io::duplex(4096); - let (mut client_read, mut client_write) = tokio::io::duplex(4096); + let req = L7Request { + action: "GET".to_string(), + target: "/ws".to_string(), + query_params: HashMap::new(), + raw_header: format!( + "GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {VALID_WS_KEY}\r\nSec-WebSocket-Extensions: permessage-deflate; client_no_context_takeover; server_no_context_takeover\r\nSec-WebSocket-Version: 13\r\n\r\n" + ) + .into_bytes(), + body_length: BodyLength::None, + }; - tokio::spawn(async move { - upstream_write.write_all(response).await.unwrap(); + let upstream_task = tokio::spawn(async move { + let mut buf = vec![0u8; 4096]; + let mut total = 0; + loop { + let n = upstream_side.read(&mut buf[total..]).await.unwrap(); + if n == 0 { + break; + } + total += n; + if buf[..total].windows(4).any(|w| w == b"\r\n\r\n") { + break; + } + } + let forwarded = String::from_utf8_lossy(&buf[..total]).to_ascii_lowercase(); + assert!(forwarded.contains( + "sec-websocket-extensions: permessage-deflate; client_no_context_takeover; server_no_context_takeover" + )); + upstream_side + .write_all( + format!( + "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: {VALID_WS_ACCEPT}\r\nSec-WebSocket-Extensions: permessage-deflate; client_no_context_takeover; server_no_context_takeover\r\n\r\n" + ) + .as_bytes(), + ) + .await + .unwrap(); + upstream_side.flush().await.unwrap(); }); - let result = tokio::time::timeout( - std::time::Duration::from_secs(2), - relay_response("GET", &mut upstream_read, &mut client_write), + let outcome = relay_http_request_with_options_guarded( + &req, + &mut proxy_to_client, + &mut proxy_to_upstream, + RelayRequestOptions { + websocket_extensions: WebSocketExtensionMode::PermessageDeflate, + ..Default::default() + }, ) .await - .expect("relay must not deadlock"); + .expect("safe permessage-deflate negotiation should pass"); - let outcome = result.expect("relay_response should succeed"); - // With explicit framing, Connection: close is still reported as reusable - // so the relay loop continues. The *next* upstream write will fail and - // exit the loop via the normal error path. assert!( - matches!(outcome, RelayOutcome::Reusable), - "explicit framing keeps loop alive despite Connection: close" + matches!( + outcome, + RelayOutcome::Upgraded { + websocket_permessage_deflate: true, + .. + } + ), + "safe permessage-deflate must be marked negotiated" ); - - client_write.shutdown().await.unwrap(); - let mut received = Vec::new(); - client_read.read_to_end(&mut received).await.unwrap(); - assert!(String::from_utf8_lossy(&received).contains("hello")); + upstream_task.await.expect("upstream task should complete"); } #[tokio::test] - async fn relay_response_101_switching_protocols_returns_upgraded_with_overflow() { - // Build a 101 response followed by WebSocket frame data (overflow). - let mut response = Vec::new(); - response.extend_from_slice(b"HTTP/1.1 101 Switching Protocols\r\n"); - response.extend_from_slice(b"Upgrade: websocket\r\n"); - response.extend_from_slice(b"Connection: Upgrade\r\n"); - response.extend_from_slice(b"\r\n"); - response.extend_from_slice(b"\x81\x05hello"); // WebSocket frame - - let (upstream_read, mut upstream_write) = tokio::io::duplex(4096); - let (mut client_read, client_write) = tokio::io::duplex(4096); - - upstream_write.write_all(&response).await.unwrap(); - drop(upstream_write); - - let mut upstream_read = upstream_read; - let mut client_write = client_write; - - let result = tokio::time::timeout( - std::time::Duration::from_secs(2), - relay_response("GET", &mut upstream_read, &mut client_write), + async fn websocket_conformance_preserve_mode_relays_raw_frames_without_validation() { + let (forwarded, frame) = run_upgraded_websocket_case( + None, + None, + WebSocketExtensionMode::Preserve, + None, + unmasked_frame(TEXT_OPCODE, b"raw-unmasked"), ) - .await - .expect("relay_response should not deadlock"); - - let outcome = result.expect("relay_response should succeed"); - match outcome { - RelayOutcome::Upgraded { overflow } => { - assert_eq!( - &overflow, b"\x81\x05hello", - "overflow should contain WebSocket frame data" - ); - } - other => panic!("Expected Upgraded, got {other:?}"), - } + .await; - client_write.shutdown().await.unwrap(); - let mut received = Vec::new(); - client_read.read_to_end(&mut received).await.unwrap(); - let received_str = String::from_utf8_lossy(&received); assert!( - received_str.contains("101 Switching Protocols"), - "client should receive the 101 response headers" + forwarded.contains("Upgrade: websocket"), + "raw preserve path should still forward the upgrade request" + ); + assert!( + !frame.masked, + "raw preserve path must not validate or rewrite client frame masking" ); + assert_eq!(frame.fin_opcode & 0x0f, TEXT_OPCODE); + assert_eq!(frame.payload, b"raw-unmasked"); } #[tokio::test] - async fn relay_response_101_no_overflow() { - // 101 response with no trailing bytes — overflow should be empty. - let response = b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n"; - - let (upstream_read, mut upstream_write) = tokio::io::duplex(4096); - let (_client_read, client_write) = tokio::io::duplex(4096); + async fn websocket_conformance_rewrite_mode_rewrites_text_after_upgrade() { + let (child_env, resolver) = SecretResolver::from_provider_env( + [("DISCORD_BOT_TOKEN".to_string(), "real-token".to_string())] + .into_iter() + .collect(), + ); + let placeholder = child_env.get("DISCORD_BOT_TOKEN").unwrap(); + let payload = format!(r#"{{"op":2,"d":{{"token":"{placeholder}"}}}}"#); - upstream_write.write_all(response).await.unwrap(); - drop(upstream_write); + let (forwarded, frame) = run_upgraded_websocket_case( + None, + None, + WebSocketExtensionMode::PermessageDeflate, + resolver.map(Arc::new), + masked_frame_with_rsv(TEXT_OPCODE, 0, payload.as_bytes()), + ) + .await; - let mut upstream_read = upstream_read; - let mut client_write = client_write; + assert!( + !forwarded + .to_ascii_lowercase() + .contains("sec-websocket-extensions"), + "plain rewrite path should not offer compression when the client did not offer a safe subset" + ); + assert!(frame.masked, "parsed relay must preserve client masking"); + assert_eq!(frame.fin_opcode & 0x0f, TEXT_OPCODE); + assert_eq!( + String::from_utf8(frame.payload).unwrap(), + r#"{"op":2,"d":{"token":"real-token"}}"# + ); + } - let result = tokio::time::timeout( - std::time::Duration::from_secs(2), - relay_response("GET", &mut upstream_read, &mut client_write), + #[tokio::test] + async fn websocket_conformance_deflate_rewrites_compressed_text_after_upgrade() { + let (child_env, resolver) = SecretResolver::from_provider_env( + [("DISCORD_BOT_TOKEN".to_string(), "real-token".to_string())] + .into_iter() + .collect(), + ); + let placeholder = child_env.get("DISCORD_BOT_TOKEN").unwrap(); + let payload = format!(r#"{{"op":2,"d":{{"token":"{placeholder}"}}}}"#); + let compressed = compress_test_permessage_deflate(payload.as_bytes()); + + let (forwarded, frame) = run_upgraded_websocket_case( + Some("permessage-deflate; server_no_context_takeover; client_no_context_takeover"), + Some("permessage-deflate; server_no_context_takeover; client_no_context_takeover"), + WebSocketExtensionMode::PermessageDeflate, + resolver.map(Arc::new), + masked_frame_with_rsv(TEXT_OPCODE, 0x40, &compressed), ) - .await - .expect("relay_response should not deadlock"); + .await; - match result.expect("should succeed") { - RelayOutcome::Upgraded { overflow } => { - assert!(overflow.is_empty(), "no overflow expected"); - } - other => panic!("Expected Upgraded, got {other:?}"), - } + assert!( + forwarded.to_ascii_lowercase().contains( + "sec-websocket-extensions: permessage-deflate; client_no_context_takeover; server_no_context_takeover" + ), + "safe extension offer should be canonicalized before forwarding" + ); + assert!(frame.masked, "parsed relay must preserve client masking"); + assert_eq!(frame.fin_opcode & 0x0f, TEXT_OPCODE); + assert!( + frame.fin_opcode & 0x40 != 0, + "rewritten compressed text must retain RSV1" + ); + assert_eq!( + String::from_utf8(decompress_test_permessage_deflate(&frame.payload)).unwrap(), + r#"{"op":2,"d":{"token":"real-token"}}"# + ); } #[tokio::test] - async fn relay_rejects_unsolicited_101_without_client_upgrade_header() { - // Client sends a normal GET without Upgrade headers. - // Upstream responds with 101 (non-compliant). The relay should - // reject the upgrade and return Consumed instead. + async fn opted_in_websocket_relay_rejects_invalid_accept_before_forwarding_101() { let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); - let (mut _app_side, mut proxy_to_client) = tokio::io::duplex(8192); + let (mut app_side, mut proxy_to_client) = tokio::io::duplex(8192); let req = L7Request { action: "GET".to_string(), - target: "/api".to_string(), + target: "/ws".to_string(), query_params: HashMap::new(), - raw_header: b"GET /api HTTP/1.1\r\nHost: example.com\r\n\r\n".to_vec(), + raw_header: format!( + "GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {VALID_WS_KEY}\r\nSec-WebSocket-Version: 13\r\n\r\n" + ) + .into_bytes(), body_length: BodyLength::None, }; let upstream_task = tokio::spawn(async move { - // Read the request let mut buf = vec![0u8; 4096]; let mut total = 0; loop { @@ -2090,92 +3759,249 @@ mod tests { break; } } - // Send unsolicited 101 upstream_side .write_all( - b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n", + b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: invalid\r\n\r\n", ) .await .unwrap(); upstream_side.flush().await.unwrap(); }); - let result = tokio::time::timeout( - std::time::Duration::from_secs(5), - relay_http_request_with_resolver( - &req, - &mut proxy_to_client, - &mut proxy_to_upstream, - None, + let result = relay_http_request_with_options_guarded( + &req, + &mut proxy_to_client, + &mut proxy_to_upstream, + RelayRequestOptions { + websocket_extensions: WebSocketExtensionMode::PermessageDeflate, + ..Default::default() + }, + ) + .await; + + let err = result.expect_err("invalid Sec-WebSocket-Accept must fail closed"); + assert!(err.to_string().contains("Sec-WebSocket-Accept")); + upstream_task.await.expect("upstream task should complete"); + + drop(proxy_to_client); + let mut received = Vec::new(); + app_side.read_to_end(&mut received).await.unwrap(); + assert!( + received.is_empty(), + "invalid websocket response must not forward 101 headers" + ); + } + + #[test] + fn websocket_accept_matches_rfc_6455_sample() { + assert_eq!(websocket_accept_for_key(VALID_WS_KEY), VALID_WS_ACCEPT); + } + + #[test] + fn strict_response_validation_rejects_missing_upgrade_headers() { + let validation = WebSocketResponseValidation { + expected_accept: VALID_WS_ACCEPT.to_string(), + expected_extension: None, + offered_subprotocols: Vec::new(), + }; + + let err = validate_websocket_response( + "HTTP/1.1 101 Switching Protocols\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\r\n", + WebSocketExtensionMode::PermessageDeflate, + Some(&validation), + ) + .expect_err("missing Upgrade/Connection must fail"); + + assert!(err.to_string().contains("Upgrade: websocket")); + } + + #[test] + fn permessage_deflate_response_must_match_exact_safe_offer() { + let validation = WebSocketResponseValidation { + expected_accept: VALID_WS_ACCEPT.to_string(), + expected_extension: Some( + "permessage-deflate; client_no_context_takeover; server_no_context_takeover" + .to_string(), + ), + offered_subprotocols: Vec::new(), + }; + + let err = validate_websocket_response( + "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\nSec-WebSocket-Extensions: permessage-deflate; client_no_context_takeover\r\n\r\n", + WebSocketExtensionMode::PermessageDeflate, + Some(&validation), + ) + .expect_err("extension response must exactly match the safe offer"); + + assert!(err.to_string().contains("safe offer")); + } + + #[test] + fn permessage_deflate_offer_requires_client_no_context_takeover() { + let raw = format!( + "GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {VALID_WS_KEY}\r\nSec-WebSocket-Extensions: permessage-deflate; client_max_window_bits\r\nSec-WebSocket-Version: 13\r\n\r\n" + ); + assert!( + supported_permessage_deflate_offer(&raw) + .expect("valid unsupported extension offer should parse") + .is_none() + ); + } + + #[test] + fn permessage_deflate_offer_canonicalizes_safe_params() { + let raw = format!( + "GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {VALID_WS_KEY}\r\nSec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\nSec-WebSocket-Version: 13\r\n\r\n" + ); + + assert_eq!( + supported_permessage_deflate_offer(&raw) + .expect("safe extension offer should parse") + .as_deref(), + Some("permessage-deflate; client_no_context_takeover; server_no_context_takeover") + ); + } + + #[test] + fn permessage_deflate_offer_rejects_duplicate_safe_params() { + let raw = format!( + "GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {VALID_WS_KEY}\r\nSec-WebSocket-Extensions: permessage-deflate; client_no_context_takeover; client_no_context_takeover\r\nSec-WebSocket-Version: 13\r\n\r\n" + ); + + assert!( + supported_permessage_deflate_offer(&raw) + .expect("duplicate safe param should parse but not be supported") + .is_none() + ); + } + + #[test] + fn permessage_deflate_offer_rejects_quoted_values() { + let raw = format!( + "GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {VALID_WS_KEY}\r\nSec-WebSocket-Extensions: permessage-deflate; client_no_context_takeover=\"true\"\r\nSec-WebSocket-Version: 13\r\n\r\n" + ); + + let err = supported_permessage_deflate_offer(&raw) + .expect_err("quoted permessage-deflate parameter values should fail closed"); + assert!(err.to_string().contains("parameter value")); + } + + #[test] + fn permessage_deflate_response_accepts_reordered_safe_params() { + let validation = WebSocketResponseValidation { + expected_accept: VALID_WS_ACCEPT.to_string(), + expected_extension: Some( + "permessage-deflate; client_no_context_takeover; server_no_context_takeover" + .to_string(), ), + offered_subprotocols: Vec::new(), + }; + + let negotiated = validate_websocket_response( + "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\nSec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n\r\n", + WebSocketExtensionMode::PermessageDeflate, + Some(&validation), + ) + .expect("reordered safe extension params should canonicalize"); + + assert!(negotiated); + } + + #[test] + fn permessage_deflate_response_rejects_duplicate_safe_params() { + let validation = WebSocketResponseValidation { + expected_accept: VALID_WS_ACCEPT.to_string(), + expected_extension: Some("permessage-deflate; client_no_context_takeover".to_string()), + offered_subprotocols: Vec::new(), + }; + + let err = validate_websocket_response( + "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\nSec-WebSocket-Extensions: permessage-deflate; client_no_context_takeover; client_no_context_takeover\r\n\r\n", + WebSocketExtensionMode::PermessageDeflate, + Some(&validation), + ) + .expect_err("duplicate extension params should fail closed"); + + assert!(err.to_string().contains("unsupported permessage-deflate")); + } + + #[test] + fn preserve_mode_leaves_malformed_extension_response_raw() { + let negotiated = validate_websocket_response( + "HTTP/1.1 101 Switching Protocols\r\nSec-WebSocket-Extensions: permessage-deflate; client_no_context_takeover=\"true\"\r\n\r\n", + WebSocketExtensionMode::Preserve, + None, ) - .await - .expect("relay must not deadlock"); + .expect("preserve mode should not parse or reject raw extension negotiation"); - let outcome = result.expect("relay should succeed"); - assert!( - matches!(outcome, RelayOutcome::Consumed), - "unsolicited 101 should be rejected as Consumed, got {outcome:?}" + assert!(!negotiated); + } + + #[test] + fn parse_websocket_upgrade_request_tracks_subprotocols() { + let raw = format!( + "GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {VALID_WS_KEY}\r\nSec-WebSocket-Protocol: chat, superchat\r\nSec-WebSocket-Version: 13\r\n\r\n" ); - upstream_task.await.expect("upstream task should complete"); - } + let request = parse_websocket_upgrade_request(raw.as_bytes()) + .expect("request should parse") + .expect("request should be websocket"); - #[tokio::test] - async fn relay_accepts_101_with_client_upgrade_header() { - // Client sends a proper upgrade request with Upgrade + Connection headers. - let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); - let (mut _app_side, mut proxy_to_client) = tokio::io::duplex(8192); + assert_eq!(request.subprotocols, ["chat", "superchat"]); + } - let req = L7Request { - action: "GET".to_string(), - target: "/ws".to_string(), - query_params: HashMap::new(), - raw_header: b"GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n".to_vec(), - body_length: BodyLength::None, + #[test] + fn strict_response_validation_allows_offered_subprotocol() { + let validation = WebSocketResponseValidation { + expected_accept: VALID_WS_ACCEPT.to_string(), + expected_extension: None, + offered_subprotocols: vec!["chat".to_string(), "superchat".to_string()], }; - let upstream_task = tokio::spawn(async move { - let mut buf = vec![0u8; 4096]; - let mut total = 0; - loop { - let n = upstream_side.read(&mut buf[total..]).await.unwrap(); - if n == 0 { - break; - } - total += n; - if buf[..total].windows(4).any(|w| w == b"\r\n\r\n") { - break; - } - } - upstream_side - .write_all( - b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n", - ) - .await - .unwrap(); - upstream_side.flush().await.unwrap(); - }); + let negotiated = validate_websocket_response( + "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\nSec-WebSocket-Protocol: superchat\r\n\r\n", + WebSocketExtensionMode::PermessageDeflate, + Some(&validation), + ) + .expect("offered subprotocol should validate"); - let result = tokio::time::timeout( - std::time::Duration::from_secs(5), - relay_http_request_with_resolver( - &req, - &mut proxy_to_client, - &mut proxy_to_upstream, - None, - ), + assert!(!negotiated); + } + + #[test] + fn strict_response_validation_rejects_unoffered_subprotocol() { + let validation = WebSocketResponseValidation { + expected_accept: VALID_WS_ACCEPT.to_string(), + expected_extension: None, + offered_subprotocols: vec!["chat".to_string()], + }; + + let err = validate_websocket_response( + "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\nSec-WebSocket-Protocol: admin\r\n\r\n", + WebSocketExtensionMode::PermessageDeflate, + Some(&validation), ) - .await - .expect("relay must not deadlock"); + .expect_err("unoffered subprotocol should fail closed"); - let outcome = result.expect("relay should succeed"); - assert!( - matches!(outcome, RelayOutcome::Upgraded { .. }), - "proper upgrade request should be accepted, got {outcome:?}" - ); + assert!(err.to_string().contains("subprotocol")); + } - upstream_task.await.expect("upstream task should complete"); + #[test] + fn strict_response_validation_rejects_multiple_subprotocol_headers() { + let validation = WebSocketResponseValidation { + expected_accept: VALID_WS_ACCEPT.to_string(), + expected_extension: None, + offered_subprotocols: vec!["chat".to_string(), "superchat".to_string()], + }; + + let err = validate_websocket_response( + "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\nSec-WebSocket-Protocol: chat\r\nSec-WebSocket-Protocol: superchat\r\n\r\n", + WebSocketExtensionMode::PermessageDeflate, + Some(&validation), + ) + .expect_err("multiple selected subprotocols should fail closed"); + + assert!(err.to_string().contains("Sec-WebSocket-Protocol")); } #[tokio::test] @@ -2243,6 +4069,94 @@ mod tests { assert!(client_requested_upgrade(headers)); } + #[test] + fn request_is_websocket_upgrade_detects_websocket_upgrade() { + let raw = format!( + "GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: keep-alive, Upgrade\r\nSec-WebSocket-Key: {VALID_WS_KEY}\r\nSec-WebSocket-Version: 13\r\n\r\n" + ); + assert!(request_is_websocket_upgrade(raw.as_bytes())); + } + + #[test] + fn request_is_websocket_upgrade_rejects_missing_key() { + let raw = b"GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Version: 13\r\n\r\n"; + assert!(!request_is_websocket_upgrade(raw)); + assert!(validate_websocket_upgrade_request(raw).is_err()); + } + + #[test] + fn request_is_websocket_upgrade_rejects_wrong_method() { + let raw = format!( + "POST /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {VALID_WS_KEY}\r\nSec-WebSocket-Version: 13\r\n\r\n" + ); + assert!(!request_is_websocket_upgrade(raw.as_bytes())); + assert!(validate_websocket_upgrade_request(raw.as_bytes()).is_err()); + } + + #[test] + fn request_is_websocket_upgrade_rejects_wrong_version() { + let raw = format!( + "GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {VALID_WS_KEY}\r\nSec-WebSocket-Version: 12\r\n\r\n" + ); + assert!(!request_is_websocket_upgrade(raw.as_bytes())); + assert!(validate_websocket_upgrade_request(raw.as_bytes()).is_err()); + } + + #[test] + fn validate_websocket_upgrade_ignores_plain_rest_request() { + let raw = b"GET /api HTTP/1.1\r\nHost: example.com\r\n\r\n"; + assert!(!request_is_websocket_upgrade(raw)); + assert!(!validate_websocket_upgrade_request(raw).expect("plain request should parse")); + } + + #[test] + fn validate_websocket_upgrade_ignores_non_websocket_upgrade() { + let raw = b"GET /h2c HTTP/1.1\r\nHost: example.com\r\nUpgrade: h2c\r\nConnection: Upgrade\r\n\r\n"; + assert!(!request_is_websocket_upgrade(raw)); + assert!(!validate_websocket_upgrade_request(raw).expect("h2c request should parse")); + } + + #[test] + fn strip_websocket_extensions_removes_extension_negotiation() { + let raw = format!( + "GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {VALID_WS_KEY}\r\nSec-WebSocket-Extensions: permessage-deflate; client_max_window_bits\r\nSec-WebSocket-Version: 13\r\n\r\n" + ); + + let (stripped, offered) = rewrite_websocket_extensions_for_mode( + raw.as_bytes(), + WebSocketExtensionMode::PermessageDeflate, + true, + ) + .expect("strip should succeed"); + assert!(offered.is_none()); + let stripped = String::from_utf8(stripped).unwrap(); + + assert!(stripped.contains("Upgrade: websocket\r\n")); + assert!(stripped.contains("Sec-WebSocket-Key: ")); + assert!(stripped.contains("Sec-WebSocket-Version: 13\r\n")); + assert!( + !stripped + .to_ascii_lowercase() + .contains("sec-websocket-extensions") + ); + assert!(stripped.ends_with("\r\n\r\n")); + } + + #[test] + fn strip_websocket_extensions_leaves_non_websocket_request_unchanged() { + let raw = b"GET /api HTTP/1.1\r\nHost: example.com\r\nSec-WebSocket-Extensions: permessage-deflate\r\n\r\n"; + + let (stripped, offered) = rewrite_websocket_extensions_for_mode( + raw, + WebSocketExtensionMode::PermessageDeflate, + false, + ) + .expect("strip should succeed"); + + assert!(offered.is_none()); + assert_eq!(stripped, raw); + } + #[test] fn rewrite_header_block_resolves_placeholder_auth_headers() { let (_, resolver) = SecretResolver::from_provider_env( @@ -2514,6 +4428,257 @@ mod tests { Ok(forwarded) } + async fn relay_and_capture_with_options( + raw_header: Vec, + body_length: BodyLength, + resolver: Option<&SecretResolver>, + request_body_credential_rewrite: bool, + ) -> Result { + let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); + let (mut _app_side, mut proxy_to_client) = tokio::io::duplex(8192); + + let header_str = String::from_utf8_lossy(&raw_header); + let first_line = header_str.lines().next().unwrap_or(""); + let parts: Vec<&str> = first_line.splitn(3, ' ').collect(); + let action = parts.first().unwrap_or(&"GET").to_string(); + let target = parts.get(1).unwrap_or(&"/").to_string(); + + let req = L7Request { + action, + target, + query_params: HashMap::new(), + raw_header, + body_length, + }; + + let upstream_task = tokio::spawn(async move { + let mut buf = vec![0u8; 8192]; + let mut total = 0usize; + let mut header_end = None; + let mut expected_total = None; + loop { + let n = upstream_side.read(&mut buf[total..]).await.unwrap(); + if n == 0 { + break; + } + total += n; + if header_end.is_none() + && let Some(end) = buf[..total].windows(4).position(|w| w == b"\r\n\r\n") + { + let end = end + 4; + let headers = String::from_utf8_lossy(&buf[..end]); + let len = headers + .lines() + .find_map(|line| { + let (name, value) = line.split_once(':')?; + name.eq_ignore_ascii_case("content-length") + .then(|| value.trim().parse::().ok()) + .flatten() + }) + .unwrap_or(0); + header_end = Some(end); + expected_total = Some(end + len); + } + if expected_total.is_some_and(|expected| total >= expected) { + break; + } + } + upstream_side + .write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok") + .await + .unwrap(); + upstream_side.flush().await.unwrap(); + String::from_utf8_lossy(&buf[..total]).to_string() + }); + + relay_http_request_with_options_guarded( + &req, + &mut proxy_to_client, + &mut proxy_to_upstream, + RelayRequestOptions { + resolver, + request_body_credential_rewrite, + ..Default::default() + }, + ) + .await?; + + upstream_task + .await + .map_err(|e| miette!("upstream task failed: {e}")) + } + + #[tokio::test] + async fn relay_request_body_rewrites_provider_alias_header_and_urlencoded_token() { + let (_, resolver) = SecretResolver::from_provider_env( + [("API_TOKEN".to_string(), "provider-real-token".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let alias = "provider.v1-OPENSHELL-RESOLVE-ENV-API_TOKEN"; + let body = format!("token={alias}&channel=C123"); + let raw = format!( + "POST /api/messages HTTP/1.1\r\n\ + Host: api.example.com\r\n\ + Authorization: Bearer {alias}\r\n\ + Content-Type: application/x-www-form-urlencoded\r\n\ + Content-Length: {}\r\n\r\n{}", + body.len(), + body + ); + + let forwarded = relay_and_capture_with_options( + raw.into_bytes(), + BodyLength::ContentLength(body.len() as u64), + Some(&resolver), + true, + ) + .await + .expect("relay should succeed"); + + let expected_body = "token=provider-real-token&channel=C123"; + assert!(forwarded.contains("Authorization: Bearer provider-real-token\r\n")); + assert!(forwarded.contains(&format!("Content-Length: {}\r\n", expected_body.len()))); + assert!(forwarded.ends_with(expected_body)); + assert!(!forwarded.contains("OPENSHELL-RESOLVE-ENV")); + } + + #[tokio::test] + async fn relay_request_body_rewrites_percent_encoded_canonical_urlencoded_token() { + let (_, resolver) = SecretResolver::from_provider_env( + [("API_TOKEN".to_string(), "provider-real-token".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let body = "token=openshell%3Aresolve%3Aenv%3AAPI_TOKEN¬e=hello+world"; + let raw = format!( + "POST /api/messages HTTP/1.1\r\n\ + Host: api.example.com\r\n\ + Content-Type: application/x-www-form-urlencoded\r\n\ + Content-Length: {}\r\n\r\n{}", + body.len(), + body + ); + + let forwarded = relay_and_capture_with_options( + raw.into_bytes(), + BodyLength::ContentLength(body.len() as u64), + Some(&resolver), + true, + ) + .await + .expect("relay should succeed"); + + let expected_body = "token=provider-real-token¬e=hello+world"; + assert!(forwarded.contains(&format!("Content-Length: {}\r\n", expected_body.len()))); + assert!(forwarded.ends_with(expected_body)); + assert!(!forwarded.contains("openshell%3Aresolve%3Aenv%3AAPI_TOKEN")); + assert!(!forwarded.contains("openshell:resolve:env:API_TOKEN")); + } + + #[tokio::test] + async fn relay_request_body_unresolved_alias_fails_before_upstream_write() { + let (_, resolver) = SecretResolver::from_provider_env( + [("API_TOKEN".to_string(), "provider-real-token".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let body = "token=provider-OPENSHELL-RESOLVE-ENV-APP_TOKEN"; + let raw = format!( + "POST /api/connections.open HTTP/1.1\r\n\ + Host: api.example.com\r\n\ + Content-Type: application/x-www-form-urlencoded\r\n\ + Content-Length: {}\r\n\r\n{}", + body.len(), + body + ); + let req = L7Request { + action: "POST".to_string(), + target: "/api/connections.open".to_string(), + query_params: HashMap::new(), + raw_header: raw.into_bytes(), + body_length: BodyLength::ContentLength(body.len() as u64), + }; + let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); + let (mut _app_side, mut proxy_to_client) = tokio::io::duplex(8192); + + let err = relay_http_request_with_options_guarded( + &req, + &mut proxy_to_client, + &mut proxy_to_upstream, + RelayRequestOptions { + resolver: Some(&resolver), + request_body_credential_rewrite: true, + ..Default::default() + }, + ) + .await + .expect_err("unknown body alias should fail closed"); + + assert!(!err.to_string().contains("provider-real-token")); + drop(proxy_to_upstream); + let mut forwarded = Vec::new(); + upstream_side.read_to_end(&mut forwarded).await.unwrap(); + assert!( + forwarded.is_empty(), + "failed body rewrite must not reach upstream" + ); + } + + #[tokio::test] + async fn relay_request_body_unresolved_encoded_canonical_fails_before_upstream_write() { + let (_, resolver) = SecretResolver::from_provider_env( + [("API_TOKEN".to_string(), "provider-real-token".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let body = "token=openshell%3Aresolve%3Aenv%3AMISSING_TOKEN"; + let raw = format!( + "POST /api/messages HTTP/1.1\r\n\ + Host: api.example.com\r\n\ + Content-Type: application/x-www-form-urlencoded\r\n\ + Content-Length: {}\r\n\r\n{}", + body.len(), + body + ); + let req = L7Request { + action: "POST".to_string(), + target: "/api/messages".to_string(), + query_params: HashMap::new(), + raw_header: raw.into_bytes(), + body_length: BodyLength::ContentLength(body.len() as u64), + }; + let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); + let (mut _app_side, mut proxy_to_client) = tokio::io::duplex(8192); + + let err = relay_http_request_with_options_guarded( + &req, + &mut proxy_to_client, + &mut proxy_to_upstream, + RelayRequestOptions { + resolver: Some(&resolver), + request_body_credential_rewrite: true, + ..Default::default() + }, + ) + .await + .expect_err("unknown encoded body placeholder should fail closed"); + + assert!(!err.to_string().contains("provider-real-token")); + assert!(!err.to_string().contains("MISSING_TOKEN")); + drop(proxy_to_upstream); + let mut forwarded = Vec::new(); + upstream_side.read_to_end(&mut forwarded).await.unwrap(); + assert!( + forwarded.is_empty(), + "failed body rewrite must not reach upstream" + ); + } + #[tokio::test] async fn relay_injects_bearer_header_credential() { let (child_env, resolver) = SecretResolver::from_provider_env( diff --git a/crates/openshell-sandbox/src/l7/websocket.rs b/crates/openshell-sandbox/src/l7/websocket.rs new file mode 100644 index 000000000..2dc1b25c3 --- /dev/null +++ b/crates/openshell-sandbox/src/l7/websocket.rs @@ -0,0 +1,1937 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! WebSocket relay for opt-in credential placeholder rewriting and message policy. +//! +//! The relay parses only client-to-server frames. Server-to-client bytes stay +//! raw passthrough so inspection and rewriting cannot expose response payloads. + +use crate::l7::relay::{L7EvalContext, evaluate_l7_request}; +use crate::l7::{EnforcementMode, L7RequestInfo}; +use crate::opa::TunnelPolicyEngine; +use crate::secrets::SecretResolver; +use flate2::{Compress, Compression, Decompress, FlushCompress, FlushDecompress, Status}; +use miette::{IntoDiagnostic, Result, miette}; +use openshell_ocsf::{ + ActionId, ActivityId, DispositionId, Endpoint, NetworkActivityBuilder, SeverityId, StatusId, + ocsf_emit, +}; +use std::collections::HashMap; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; + +const MAX_TEXT_MESSAGE_BYTES: usize = 1024 * 1024; +const MAX_RAW_FRAME_PAYLOAD_BYTES: u64 = 16 * 1024 * 1024; +const COPY_BUF_SIZE: usize = 8192; +const OPCODE_CONTINUATION: u8 = 0x0; +const OPCODE_TEXT: u8 = 0x1; +const OPCODE_BINARY: u8 = 0x2; +const OPCODE_CLOSE: u8 = 0x8; +const OPCODE_PING: u8 = 0x9; +const OPCODE_PONG: u8 = 0xA; + +#[derive(Debug)] +struct FrameHeader { + fin: bool, + rsv: u8, + opcode: u8, + masked: bool, + payload_len: u64, + mask_key: Option<[u8; 4]>, + raw_header: Vec, +} + +#[derive(Debug)] +enum FragmentState { + None, + Text { payload: Vec, compressed: bool }, + Binary, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(super) enum WebSocketCompression { + None, + PermessageDeflate, +} + +pub(super) struct InspectionOptions<'a> { + pub(super) engine: &'a TunnelPolicyEngine, + pub(super) ctx: &'a L7EvalContext, + pub(super) enforcement: EnforcementMode, + pub(super) target: String, + pub(super) query_params: HashMap>, + pub(super) graphql_policy: bool, +} + +pub(super) struct RelayOptions<'a> { + pub(super) policy_name: &'a str, + pub(super) resolver: Option<&'a SecretResolver>, + pub(super) inspector: Option>, + pub(super) compression: WebSocketCompression, +} + +/// Relay an upgraded WebSocket connection with optional client text inspection, +/// credential rewriting, and strict permessage-deflate handling. +pub(super) async fn relay_with_options( + client: &mut C, + upstream: &mut U, + overflow: Vec, + host: &str, + port: u16, + options: RelayOptions<'_>, +) -> Result<()> +where + C: AsyncRead + AsyncWrite + Unpin + Send, + U: AsyncRead + AsyncWrite + Unpin + Send, +{ + let (mut client_read, mut client_write) = tokio::io::split(client); + let (mut upstream_read, mut upstream_write) = tokio::io::split(upstream); + + if !overflow.is_empty() { + client_write.write_all(&overflow).await.into_diagnostic()?; + client_write.flush().await.into_diagnostic()?; + } + + let client_to_server = + relay_client_to_server(&mut client_read, &mut upstream_write, host, port, &options); + let server_to_client = async { + tokio::io::copy(&mut upstream_read, &mut client_write) + .await + .into_diagnostic()?; + client_write.flush().await.into_diagnostic()?; + Ok::<(), miette::Report>(()) + }; + + let result = tokio::select! { + result = client_to_server => result, + result = server_to_client => result, + }; + let _ = upstream_write.shutdown().await; + let _ = client_write.shutdown().await; + result +} + +async fn relay_client_to_server( + reader: &mut R, + writer: &mut W, + host: &str, + port: u16, + options: &RelayOptions<'_>, +) -> Result<()> +where + R: AsyncRead + Unpin, + W: AsyncWrite + Unpin, +{ + let mut fragments = FragmentState::None; + let mut close_seen = false; + + loop { + let Some(frame) = read_frame_header(reader).await.inspect_err(|e| { + emit_protocol_failure(host, port, options.policy_name, protocol_failure_class(e)); + })? + else { + writer.shutdown().await.into_diagnostic()?; + return Ok(()); + }; + + if close_seen { + let e = miette!("websocket frame received after close frame"); + emit_protocol_failure(host, port, options.policy_name, protocol_failure_class(&e)); + return Err(e); + } + + if let Err(e) = validate_frame_header(&frame, &fragments, options.compression) { + emit_protocol_failure(host, port, options.policy_name, protocol_failure_class(&e)); + return Err(e); + } + + match frame.opcode { + OPCODE_TEXT => { + let payload = read_masked_payload(reader, &frame).await.inspect_err(|e| { + emit_protocol_failure( + host, + port, + options.policy_name, + protocol_failure_class(e), + ); + })?; + let compressed = frame.rsv == 0x40; + if frame.fin { + relay_text_payload( + writer, &frame, payload, false, compressed, host, port, options, + ) + .await + .inspect_err(|e| { + emit_protocol_failure( + host, + port, + options.policy_name, + protocol_failure_class(e), + ); + })?; + } else { + fragments = FragmentState::Text { + payload, + compressed, + }; + } + } + OPCODE_CONTINUATION => match &mut fragments { + FragmentState::Text { + payload, + compressed, + } => { + let next = read_masked_payload(reader, &frame).await.inspect_err(|e| { + emit_protocol_failure( + host, + port, + options.policy_name, + protocol_failure_class(e), + ); + })?; + if let Err(e) = append_text_fragment(payload, next) { + emit_protocol_failure( + host, + port, + options.policy_name, + protocol_failure_class(&e), + ); + return Err(e); + } + if frame.fin { + let complete = std::mem::take(payload); + let was_compressed = *compressed; + fragments = FragmentState::None; + relay_text_payload( + writer, + &frame, + complete, + true, + was_compressed, + host, + port, + options, + ) + .await + .inspect_err(|e| { + emit_protocol_failure( + host, + port, + options.policy_name, + protocol_failure_class(e), + ); + })?; + } + } + FragmentState::Binary => { + copy_raw_frame_payload(reader, writer, &frame) + .await + .inspect_err(|e| { + emit_protocol_failure( + host, + port, + options.policy_name, + protocol_failure_class(e), + ); + })?; + if frame.fin { + fragments = FragmentState::None; + } + } + FragmentState::None => { + let e = + miette!("websocket continuation frame without active fragmented message"); + emit_protocol_failure( + host, + port, + options.policy_name, + protocol_failure_class(&e), + ); + return Err(e); + } + }, + OPCODE_BINARY => { + if !frame.fin { + fragments = FragmentState::Binary; + } + copy_raw_frame_payload(reader, writer, &frame) + .await + .inspect_err(|e| { + emit_protocol_failure( + host, + port, + options.policy_name, + protocol_failure_class(e), + ); + })?; + } + OPCODE_CLOSE | OPCODE_PING | OPCODE_PONG => { + relay_control_frame(reader, writer, &frame) + .await + .inspect_err(|e| { + emit_protocol_failure( + host, + port, + options.policy_name, + protocol_failure_class(e), + ); + })?; + if frame.opcode == OPCODE_CLOSE { + close_seen = true; + } + } + _ => unreachable!("validated opcode"), + } + } +} + +async fn read_frame_header(reader: &mut R) -> Result> { + let first = match reader.read_u8().await { + Ok(byte) => byte, + Err(e) + if matches!( + e.kind(), + std::io::ErrorKind::UnexpectedEof + | std::io::ErrorKind::ConnectionReset + | std::io::ErrorKind::BrokenPipe + ) => + { + return Ok(None); + } + Err(e) => return Err(miette!("{e}")), + }; + let second = reader + .read_u8() + .await + .map_err(|e| miette!("malformed websocket frame header: {e}"))?; + + let mut raw_header = vec![first, second]; + let len_code = second & 0x7F; + let payload_len = match len_code { + 0..=125 => u64::from(len_code), + 126 => { + let mut bytes = [0u8; 2]; + reader + .read_exact(&mut bytes) + .await + .map_err(|e| miette!("malformed websocket extended length: {e}"))?; + raw_header.extend_from_slice(&bytes); + let len = u64::from(u16::from_be_bytes(bytes)); + if len < 126 { + return Err(miette!( + "websocket frame uses non-minimal 16-bit extended length" + )); + } + len + } + 127 => { + let mut bytes = [0u8; 8]; + reader + .read_exact(&mut bytes) + .await + .map_err(|e| miette!("malformed websocket extended length: {e}"))?; + if bytes[0] & 0x80 != 0 { + return Err(miette!("websocket frame uses non-canonical 64-bit length")); + } + raw_header.extend_from_slice(&bytes); + let len = u64::from_be_bytes(bytes); + if u16::try_from(len).is_ok() { + return Err(miette!( + "websocket frame uses non-minimal 64-bit extended length" + )); + } + len + } + _ => unreachable!("7-bit length code"), + }; + + let masked = second & 0x80 != 0; + let mask_key = if masked { + let mut key = [0u8; 4]; + reader + .read_exact(&mut key) + .await + .map_err(|e| miette!("malformed websocket mask key: {e}"))?; + raw_header.extend_from_slice(&key); + Some(key) + } else { + None + }; + + Ok(Some(FrameHeader { + fin: first & 0x80 != 0, + rsv: first & 0x70, + opcode: first & 0x0F, + masked, + payload_len, + mask_key, + raw_header, + })) +} + +fn validate_frame_header( + frame: &FrameHeader, + fragments: &FragmentState, + compression: WebSocketCompression, +) -> Result<()> { + if !valid_rsv_bits(frame, fragments, compression) { + return Err(miette!( + "websocket frame has unsupported RSV bits or extension state" + )); + } + if !frame.masked { + return Err(miette!("websocket client frame is not masked")); + } + if !matches!( + frame.opcode, + OPCODE_CONTINUATION + | OPCODE_TEXT + | OPCODE_BINARY + | OPCODE_CLOSE + | OPCODE_PING + | OPCODE_PONG + ) { + return Err(miette!("websocket frame uses reserved opcode")); + } + if matches!(frame.opcode, OPCODE_CLOSE | OPCODE_PING | OPCODE_PONG) { + if !frame.fin { + return Err(miette!("websocket control frame is fragmented")); + } + if frame.payload_len > 125 { + return Err(miette!("websocket control frame exceeds 125 bytes")); + } + } + if matches!(frame.opcode, OPCODE_TEXT | OPCODE_BINARY) + && !matches!(fragments, FragmentState::None) + { + return Err(miette!( + "websocket data frame started before previous fragmented message completed" + )); + } + if matches!(frame.opcode, OPCODE_CONTINUATION) && matches!(fragments, FragmentState::None) { + return Err(miette!( + "websocket continuation frame without active fragmented message" + )); + } + if (frame.opcode == OPCODE_BINARY + || (frame.opcode == OPCODE_CONTINUATION && matches!(fragments, FragmentState::Binary))) + && frame.payload_len > MAX_RAW_FRAME_PAYLOAD_BYTES + { + return Err(miette!( + "websocket binary frame exceeds {MAX_RAW_FRAME_PAYLOAD_BYTES} byte relay limit" + )); + } + Ok(()) +} + +fn valid_rsv_bits( + frame: &FrameHeader, + fragments: &FragmentState, + compression: WebSocketCompression, +) -> bool { + if frame.rsv == 0 { + return true; + } + if compression != WebSocketCompression::PermessageDeflate || frame.rsv != 0x40 { + return false; + } + matches!(fragments, FragmentState::None) && matches!(frame.opcode, OPCODE_TEXT | OPCODE_BINARY) +} + +async fn read_masked_payload( + reader: &mut R, + frame: &FrameHeader, +) -> Result> { + let payload_len = usize::try_from(frame.payload_len) + .map_err(|_| miette!("websocket text frame is too large to buffer"))?; + if payload_len > MAX_TEXT_MESSAGE_BYTES { + return Err(miette!( + "websocket text message exceeds {MAX_TEXT_MESSAGE_BYTES} byte limit" + )); + } + let mut payload = vec![0u8; payload_len]; + reader + .read_exact(&mut payload) + .await + .map_err(|e| miette!("malformed websocket payload: {e}"))?; + let mask_key = frame + .mask_key + .ok_or_else(|| miette!("websocket client frame is not masked"))?; + apply_mask(&mut payload, mask_key); + Ok(payload) +} + +fn append_text_fragment(buffer: &mut Vec, next: Vec) -> Result<()> { + let new_len = buffer + .len() + .checked_add(next.len()) + .ok_or_else(|| miette!("websocket text message length overflow"))?; + if new_len > MAX_TEXT_MESSAGE_BYTES { + return Err(miette!( + "websocket text message exceeds {MAX_TEXT_MESSAGE_BYTES} byte limit" + )); + } + buffer.extend_from_slice(&next); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +async fn relay_text_payload( + writer: &mut W, + frame: &FrameHeader, + payload: Vec, + force_reframe: bool, + compressed: bool, + host: &str, + port: u16, + options: &RelayOptions<'_>, +) -> Result<()> { + let message_payload = if compressed { + decompress_permessage_deflate(&payload)? + } else { + payload + }; + let mut text = String::from_utf8(message_payload) + .map_err(|_| miette!("websocket text message is not valid UTF-8"))?; + let replacements = if let Some(resolver) = options.resolver { + resolver + .rewrite_websocket_text_placeholders(&mut text) + .map_err(|_| miette!("websocket credential placeholder resolution failed"))? + } else { + 0 + }; + + if let Some(inspector) = options.inspector.as_ref() { + inspect_websocket_text_message(host, port, options.policy_name, inspector, &text)?; + } + + if replacements == 0 && !force_reframe && !compressed { + writer + .write_all(&frame.raw_header) + .await + .into_diagnostic()?; + let mut payload = text.into_bytes(); + let mask_key = frame + .mask_key + .ok_or_else(|| miette!("websocket client frame is not masked"))?; + apply_mask(&mut payload, mask_key); + writer.write_all(&payload).await.into_diagnostic()?; + writer.flush().await.into_diagnostic()?; + return Ok(()); + } + + if replacements > 0 { + emit_rewrite_event(host, port, options.policy_name, replacements); + } + if compressed { + let compressed_payload = compress_permessage_deflate(text.as_bytes())?; + return write_masked_frame_with_rsv(writer, OPCODE_TEXT, 0x40, &compressed_payload).await; + } + write_masked_frame(writer, OPCODE_TEXT, text.as_bytes()).await +} + +fn inspect_websocket_text_message( + host: &str, + port: u16, + policy_name: &str, + inspector: &InspectionOptions<'_>, + text: &str, +) -> Result<()> { + if inspector.graphql_policy { + return inspect_graphql_websocket_message(host, port, policy_name, inspector, text); + } + + let request_info = L7RequestInfo { + action: "WEBSOCKET_TEXT".to_string(), + target: inspector.target.clone(), + query_params: inspector.query_params.clone(), + graphql: None, + }; + let (allowed, reason) = evaluate_l7_request(inspector.engine, inspector.ctx, &request_info)?; + let decision = match (allowed, inspector.enforcement) { + (true, _) => "allow", + (false, EnforcementMode::Audit) => "audit", + (false, EnforcementMode::Enforce) => "deny", + }; + emit_websocket_l7_event( + host, + port, + policy_name, + &request_info, + decision, + &reason, + None, + ); + if !allowed && inspector.enforcement == EnforcementMode::Enforce { + return Err(miette!("websocket text message denied by policy")); + } + Ok(()) +} + +fn inspect_graphql_websocket_message( + host: &str, + port: u16, + policy_name: &str, + inspector: &InspectionOptions<'_>, + text: &str, +) -> Result<()> { + match classify_graphql_websocket_message(text) { + GraphqlWebSocketMessage::Control { message_type } => { + let request_info = L7RequestInfo { + action: "WEBSOCKET_CONTROL".to_string(), + target: inspector.target.clone(), + query_params: inspector.query_params.clone(), + graphql: None, + }; + emit_websocket_l7_event( + host, + port, + policy_name, + &request_info, + "allow", + &format!("GraphQL WebSocket control message {message_type}"), + None, + ); + Ok(()) + } + GraphqlWebSocketMessage::Operation { + message_type, + graphql, + } => { + let request_info = L7RequestInfo { + action: "WEBSOCKET_TEXT".to_string(), + target: inspector.target.clone(), + query_params: inspector.query_params.clone(), + graphql: Some(graphql.clone()), + }; + let parse_error_reason = graphql + .error + .as_deref() + .map(|error| format!("GraphQL WebSocket message rejected: {error}")); + let force_deny = parse_error_reason.is_some(); + let (allowed, reason) = if let Some(reason) = parse_error_reason { + (false, reason) + } else { + evaluate_l7_request(inspector.engine, inspector.ctx, &request_info)? + }; + let decision = match (allowed, inspector.enforcement) { + (_, _) if force_deny => "deny", + (true, _) => "allow", + (false, EnforcementMode::Audit) => "audit", + (false, EnforcementMode::Enforce) => "deny", + }; + let reason = format!("graphql_ws_type={message_type} {reason}"); + emit_websocket_l7_event( + host, + port, + policy_name, + &request_info, + decision, + &reason, + Some(&graphql), + ); + if (!allowed && inspector.enforcement == EnforcementMode::Enforce) || force_deny { + return Err(miette!("websocket GraphQL message denied by policy")); + } + Ok(()) + } + } +} + +#[derive(Debug)] +enum GraphqlWebSocketMessage { + Control { + message_type: String, + }, + Operation { + message_type: String, + graphql: crate::l7::graphql::GraphqlRequestInfo, + }, +} + +fn classify_graphql_websocket_message(text: &str) -> GraphqlWebSocketMessage { + let value = match serde_json::from_str::(text) { + Ok(value) => value, + Err(err) => { + return GraphqlWebSocketMessage::Operation { + message_type: "unknown".to_string(), + graphql: graphql_error(format!( + "GraphQL WebSocket message is not valid JSON: {err}" + )), + }; + } + }; + let Some(obj) = value.as_object() else { + return GraphqlWebSocketMessage::Operation { + message_type: "unknown".to_string(), + graphql: graphql_error("GraphQL WebSocket message must be a JSON object"), + }; + }; + let Some(message_type) = obj.get("type").and_then(serde_json::Value::as_str) else { + return GraphqlWebSocketMessage::Operation { + message_type: "unknown".to_string(), + graphql: graphql_error("GraphQL WebSocket message missing string type"), + }; + }; + + match message_type { + "subscribe" | "start" => { + if obj + .get("id") + .and_then(serde_json::Value::as_str) + .is_none_or(str::is_empty) + { + return GraphqlWebSocketMessage::Operation { + message_type: message_type.to_string(), + graphql: graphql_error( + "GraphQL WebSocket operation message missing non-empty id", + ), + }; + } + let Some(payload) = obj.get("payload").filter(|value| value.is_object()) else { + return GraphqlWebSocketMessage::Operation { + message_type: message_type.to_string(), + graphql: graphql_error( + "GraphQL WebSocket operation message missing object payload", + ), + }; + }; + GraphqlWebSocketMessage::Operation { + message_type: message_type.to_string(), + graphql: crate::l7::graphql::classify_json_envelope_value(payload), + } + } + "connection_init" | "connection_terminate" | "ping" | "pong" | "complete" | "stop" => { + GraphqlWebSocketMessage::Control { + message_type: message_type.to_string(), + } + } + _ => GraphqlWebSocketMessage::Operation { + message_type: message_type.to_string(), + graphql: graphql_error(format!( + "unsupported GraphQL WebSocket client message type {message_type:?}" + )), + }, + } +} + +fn graphql_error(message: impl Into) -> crate::l7::graphql::GraphqlRequestInfo { + crate::l7::graphql::GraphqlRequestInfo { + operations: Vec::new(), + error: Some(message.into()), + } +} + +async fn relay_control_frame( + reader: &mut R, + writer: &mut W, + frame: &FrameHeader, +) -> Result<()> +where + R: AsyncRead + Unpin, + W: AsyncWrite + Unpin, +{ + let raw_payload_len = usize::try_from(frame.payload_len) + .map_err(|_| miette!("websocket control frame payload length overflow"))?; + let mut raw_payload = vec![0u8; raw_payload_len]; + reader + .read_exact(&mut raw_payload) + .await + .map_err(|e| miette!("malformed websocket control payload: {e}"))?; + + if frame.opcode == OPCODE_CLOSE { + let mut payload = raw_payload.clone(); + let mask_key = frame + .mask_key + .ok_or_else(|| miette!("websocket client frame is not masked"))?; + apply_mask(&mut payload, mask_key); + validate_close_payload(&payload)?; + } + + writer + .write_all(&frame.raw_header) + .await + .into_diagnostic()?; + writer.write_all(&raw_payload).await.into_diagnostic()?; + writer.flush().await.into_diagnostic()?; + Ok(()) +} + +fn validate_close_payload(payload: &[u8]) -> Result<()> { + if payload.len() == 1 { + return Err(miette!( + "websocket close frame payload cannot be exactly one byte" + )); + } + if payload.len() < 2 { + return Ok(()); + } + + let code = u16::from_be_bytes([payload[0], payload[1]]); + if !valid_close_code(code) { + return Err(miette!("websocket close frame uses invalid close code")); + } + if std::str::from_utf8(&payload[2..]).is_err() { + return Err(miette!("websocket close frame reason is not valid UTF-8")); + } + Ok(()) +} + +fn valid_close_code(code: u16) -> bool { + (matches!(code, 1000..=1014) && !matches!(code, 1004..=1006)) || (3000..=4999).contains(&code) +} + +async fn copy_raw_frame_payload( + reader: &mut R, + writer: &mut W, + frame: &FrameHeader, +) -> Result<()> +where + R: AsyncRead + Unpin, + W: AsyncWrite + Unpin, +{ + writer + .write_all(&frame.raw_header) + .await + .into_diagnostic()?; + let mut remaining = frame.payload_len; + let mut buf = [0u8; COPY_BUF_SIZE]; + while remaining > 0 { + let to_read = usize::try_from(remaining) + .unwrap_or(buf.len()) + .min(buf.len()); + let n = reader.read(&mut buf[..to_read]).await.into_diagnostic()?; + if n == 0 { + return Err(miette!("websocket payload ended before declared length")); + } + writer.write_all(&buf[..n]).await.into_diagnostic()?; + remaining -= n as u64; + } + writer.flush().await.into_diagnostic()?; + Ok(()) +} + +async fn write_masked_frame( + writer: &mut W, + opcode: u8, + payload: &[u8], +) -> Result<()> { + write_masked_frame_with_rsv(writer, opcode, 0, payload).await +} + +async fn write_masked_frame_with_rsv( + writer: &mut W, + opcode: u8, + rsv: u8, + payload: &[u8], +) -> Result<()> { + let mut header = Vec::with_capacity(14); + header.push(0x80 | rsv | opcode); + match payload.len() { + 0..=125 => header.push(0x80 | u8::try_from(payload.len()).expect("payload <= 125")), + 126..=65_535 => { + header.push(0x80 | 0x7e); + header.extend_from_slice( + &u16::try_from(payload.len()) + .expect("payload <= 65535") + .to_be_bytes(), + ); + } + _ => { + header.push(0x80 | 127); + header.extend_from_slice(&(payload.len() as u64).to_be_bytes()); + } + } + let mask_key = new_mask_key(); + header.extend_from_slice(&mask_key); + + let mut masked = payload.to_vec(); + apply_mask(&mut masked, mask_key); + writer.write_all(&header).await.into_diagnostic()?; + writer.write_all(&masked).await.into_diagnostic()?; + writer.flush().await.into_diagnostic()?; + Ok(()) +} + +fn decompress_permessage_deflate(payload: &[u8]) -> Result> { + let mut decoder = Decompress::new(false); + let mut input = Vec::with_capacity(payload.len() + 4); + input.extend_from_slice(payload); + input.extend_from_slice(&[0x00, 0x00, 0xff, 0xff]); + let mut out = Vec::with_capacity(payload.len().saturating_mul(2).min(MAX_TEXT_MESSAGE_BYTES)); + let mut input_pos = 0usize; + let mut scratch = [0u8; COPY_BUF_SIZE]; + loop { + let before_in = decoder.total_in(); + let before_out = decoder.total_out(); + let status = decoder + .decompress(&input[input_pos..], &mut scratch, FlushDecompress::Sync) + .map_err(|e| miette!("websocket permessage-deflate decompression failed: {e}"))?; + let read = usize::try_from(decoder.total_in() - before_in) + .map_err(|_| miette!("websocket permessage-deflate input length overflow"))?; + let written = usize::try_from(decoder.total_out() - before_out) + .map_err(|_| miette!("websocket permessage-deflate output length overflow"))?; + input_pos = input_pos + .checked_add(read) + .ok_or_else(|| miette!("websocket permessage-deflate input length overflow"))?; + if out.len().saturating_add(written) > MAX_TEXT_MESSAGE_BYTES { + return Err(miette!( + "websocket text message exceeds {MAX_TEXT_MESSAGE_BYTES} byte limit" + )); + } + out.extend_from_slice(&scratch[..written]); + if matches!(status, Status::StreamEnd) { + break; + } + if input_pos >= input.len() && written < scratch.len() { + break; + } + if read == 0 && written == 0 { + return Err(miette!( + "websocket permessage-deflate decompression did not make progress" + )); + } + } + Ok(out) +} + +fn compress_permessage_deflate(payload: &[u8]) -> Result> { + let mut compressor = Compress::new(Compression::fast(), false); + let expansion = payload.len() / 16; + let mut out = Vec::with_capacity(payload.len().saturating_add(expansion).saturating_add(128)); + loop { + let consumed = usize::try_from(compressor.total_in()) + .map_err(|_| miette!("websocket permessage-deflate input length overflow"))?; + if consumed >= payload.len() { + break; + } + let before_in = compressor.total_in(); + let before_out = compressor.total_out(); + let status = compressor + .compress_vec(&payload[consumed..], &mut out, FlushCompress::None) + .map_err(|e| miette!("websocket permessage-deflate compression failed: {e}"))?; + if matches!(status, Status::BufError) + || (compressor.total_in() == before_in && compressor.total_out() == before_out) + { + out.reserve(out.capacity().max(1024)); + } + } + loop { + out.reserve(64); + let before_out = compressor.total_out(); + compressor + .compress_vec(&[], &mut out, FlushCompress::Sync) + .map_err(|e| miette!("websocket permessage-deflate compression failed: {e}"))?; + if out.ends_with(&[0x00, 0x00, 0xff, 0xff]) { + break; + } + if compressor.total_out() == before_out { + out.reserve(out.capacity().max(1024)); + } + } + if !out.ends_with(&[0x00, 0x00, 0xff, 0xff]) { + return Err(miette!( + "websocket permessage-deflate compression missing sync marker" + )); + } + out.truncate(out.len() - 4); + Ok(out) +} + +fn new_mask_key() -> [u8; 4] { + let bytes = uuid::Uuid::new_v4().into_bytes(); + [bytes[0], bytes[1], bytes[2], bytes[3]] +} + +fn apply_mask(payload: &mut [u8], mask_key: [u8; 4]) { + for (i, byte) in payload.iter_mut().enumerate() { + *byte ^= mask_key[i % 4]; + } +} + +fn emit_rewrite_event(host: &str, port: u16, policy_name: &str, replacements: usize) { + let policy_name = if policy_name.is_empty() { + "-" + } else { + policy_name + }; + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Other) + .action(ActionId::Allowed) + .disposition(DispositionId::Allowed) + .severity(SeverityId::Informational) + .status(StatusId::Success) + .dst_endpoint(Endpoint::from_domain(host, port)) + .firewall_rule(policy_name, "l7-websocket") + .message(rewrite_event_message(host, port, replacements)) + .build(); + ocsf_emit!(event); +} + +fn rewrite_event_message(host: &str, port: u16, replacements: usize) -> String { + format!( + "WEBSOCKET_CREDENTIAL_REWRITE rewrote client text message [host:{host} port:{port} replacements:{replacements}]" + ) +} + +fn emit_websocket_l7_event( + host: &str, + port: u16, + policy_name: &str, + request_info: &L7RequestInfo, + decision: &str, + reason: &str, + graphql: Option<&crate::l7::graphql::GraphqlRequestInfo>, +) { + let policy_name = if policy_name.is_empty() { + "-" + } else { + policy_name + }; + let (action_id, disposition_id, severity) = match decision { + "deny" => (ActionId::Denied, DispositionId::Blocked, SeverityId::Medium), + "allow" | "audit" => ( + ActionId::Allowed, + DispositionId::Allowed, + SeverityId::Informational, + ), + _ => ( + ActionId::Other, + DispositionId::Other, + SeverityId::Informational, + ), + }; + let summary = graphql.map(graphql_log_summary).unwrap_or_default(); + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Other) + .action(action_id) + .disposition(disposition_id) + .severity(severity) + .status(StatusId::Success) + .dst_endpoint(Endpoint::from_domain(host, port)) + .firewall_rule(policy_name, "l7-websocket") + .message(format!( + "WEBSOCKET_L7_REQUEST {decision} {} {host}:{port}{}{} reason={reason}", + request_info.action, request_info.target, summary + )) + .build(); + ocsf_emit!(event); +} + +fn graphql_log_summary(info: &crate::l7::graphql::GraphqlRequestInfo) -> String { + if let Some(error) = info.error.as_deref() { + return format!(" graphql_error={error:?}"); + } + let ops: Vec = info + .operations + .iter() + .map(|op| { + let name = op.operation_name.as_deref().unwrap_or("-"); + let fields = if op.fields.is_empty() { + "-".to_string() + } else { + op.fields.join(",") + }; + let persisted = op + .persisted_query_hash + .as_deref() + .or(op.persisted_query_id.as_deref()) + .unwrap_or("-"); + format!( + "type={} name={} fields={} persisted={}", + op.operation_type, name, fields, persisted + ) + }) + .collect(); + format!(" graphql_ops={}", ops.join(";")) +} + +fn protocol_failure_class(error: &miette::Report) -> &'static str { + let msg = error.to_string().to_ascii_lowercase(); + if msg.contains("credential") { + "credential_resolution_failed" + } else if msg.contains("utf-8") { + "invalid_utf8" + } else if msg.contains("close frame") || msg.contains("after close") { + "invalid_close_frame" + } else if msg.contains("control frame") { + "invalid_control_frame" + } else if msg.contains("length") + || msg.contains("too large") + || msg.contains("exceeds") + || msg.contains("overflow") + { + "invalid_length" + } else if msg.contains("continuation") || msg.contains("fragmented") { + "invalid_fragmentation" + } else if msg.contains("reserved opcode") { + "reserved_opcode" + } else if msg.contains("not masked") { + "unmasked_client_frame" + } else if msg.contains("rsv") { + "rsv_bits" + } else if msg.contains("malformed") { + "malformed_frame" + } else { + "protocol_error" + } +} + +fn emit_protocol_failure(host: &str, port: u16, policy_name: &str, failure_class: &str) { + let policy_name = if policy_name.is_empty() { + "-" + } else { + policy_name + }; + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Open) + .action(ActionId::Denied) + .disposition(DispositionId::Blocked) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .dst_endpoint(Endpoint::from_domain(host, port)) + .firewall_rule(policy_name, "l7-websocket") + .message(protocol_failure_message(host, port)) + .status_detail(failure_class) + .build(); + ocsf_emit!(event); +} + +fn protocol_failure_message(host: &str, port: u16) -> String { + format!("WEBSOCKET_CREDENTIAL_REWRITE closed ambiguous client frame [host:{host} port:{port}]") +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::l7::relay::L7EvalContext; + use crate::opa::{NetworkInput, OpaEngine}; + use crate::secrets::SecretResolver; + use std::path::PathBuf; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + const TEST_POLICY: &str = include_str!("../../data/sandbox-policy.rego"); + const GRAPHQL_WS_POLICY: &str = r#" +network_policies: + graphql_ws: + name: graphql_ws + endpoints: + - host: realtime.graphql.test + port: 443 + path: "/graphql" + protocol: websocket + enforcement: enforce + rules: + - allow: + method: GET + path: "/graphql" + - allow: + operation_type: query + fields: [viewer] + - allow: + operation_type: subscription + fields: [messageAdded] + binaries: + - { path: /usr/bin/node } +"#; + + fn resolver() -> (HashMap, SecretResolver) { + let (child_env, resolver) = SecretResolver::from_provider_env( + std::iter::once(("DISCORD_BOT_TOKEN".to_string(), "real-token".to_string())).collect(), + ); + (child_env, resolver.expect("resolver")) + } + + fn masked_frame(fin: bool, opcode: u8, payload: &[u8]) -> Vec { + masked_frame_with_rsv(fin, opcode, 0, payload) + } + + fn masked_frame_with_rsv(fin: bool, opcode: u8, rsv: u8, payload: &[u8]) -> Vec { + let mask_key = [0x37, 0xfa, 0x21, 0x3d]; + let mut frame = Vec::new(); + frame.push((if fin { 0x80 } else { 0 }) | rsv | opcode); + match payload.len() { + 0..=125 => frame.push(0x80 | u8::try_from(payload.len()).expect("payload <= 125")), + 126..=65_535 => { + frame.push(0x80 | 0x7e); + frame.extend_from_slice( + &u16::try_from(payload.len()) + .expect("payload <= 65535") + .to_be_bytes(), + ); + } + _ => { + frame.push(0x80 | 127); + frame.extend_from_slice(&(payload.len() as u64).to_be_bytes()); + } + } + frame.extend_from_slice(&mask_key); + for (i, byte) in payload.iter().enumerate() { + frame.push(byte ^ mask_key[i % 4]); + } + frame + } + + fn unmasked_frame(opcode: u8, payload: &[u8]) -> Vec { + let mut frame = Vec::new(); + frame.push(0x80 | opcode); + frame.push(u8::try_from(payload.len()).expect("test payload fits in one byte")); + frame.extend_from_slice(payload); + frame + } + + fn masked_frame_with_declared_len(opcode: u8, declared_len: u64) -> Vec { + let mut frame = Vec::new(); + frame.push(0x80 | opcode); + frame.push(0x80 | 127); + frame.extend_from_slice(&declared_len.to_be_bytes()); + frame.extend_from_slice(&[0x37, 0xfa, 0x21, 0x3d]); + frame + } + + fn masked_frame_with_non_minimal_16_bit_len(opcode: u8, payload: &[u8]) -> Vec { + let mask_key = [0x37, 0xfa, 0x21, 0x3d]; + let mut frame = Vec::new(); + frame.push(0x80 | opcode); + frame.push(0x80 | 0x7e); + frame.extend_from_slice( + &u16::try_from(payload.len()) + .expect("test payload fits u16") + .to_be_bytes(), + ); + frame.extend_from_slice(&mask_key); + for (i, byte) in payload.iter().enumerate() { + frame.push(byte ^ mask_key[i % 4]); + } + frame + } + + fn close_payload(code: u16, reason: &[u8]) -> Vec { + let mut payload = Vec::with_capacity(2 + reason.len()); + payload.extend_from_slice(&code.to_be_bytes()); + payload.extend_from_slice(reason); + payload + } + + async fn run_client_to_server(input: Vec) -> Result> { + let (_, resolver) = resolver(); + let (mut client_write, mut relay_read) = tokio::io::duplex(MAX_TEXT_MESSAGE_BYTES + 1024); + let (mut relay_write, mut upstream_read) = tokio::io::duplex(MAX_TEXT_MESSAGE_BYTES + 1024); + + client_write.write_all(&input).await.unwrap(); + drop(client_write); + + let options = RelayOptions { + policy_name: "test-policy", + resolver: Some(&resolver), + inspector: None, + compression: WebSocketCompression::None, + }; + let result = relay_client_to_server( + &mut relay_read, + &mut relay_write, + "gateway.example.test", + 443, + &options, + ) + .await; + drop(relay_write); + + let mut output = Vec::new(); + upstream_read.read_to_end(&mut output).await.unwrap(); + result.map(|()| output) + } + + async fn run_client_to_server_with_graphql_policy( + input: Vec, + resolver: Option<&SecretResolver>, + ) -> Result> { + let engine = OpaEngine::from_strings(TEST_POLICY, GRAPHQL_WS_POLICY) + .expect("GraphQL WebSocket policy should load"); + let network_input = NetworkInput { + host: "realtime.graphql.test".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/node"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let generation = engine + .evaluate_network_action_with_generation(&network_input) + .expect("network action should evaluate") + .1; + let tunnel_engine = engine + .clone_engine_for_tunnel(generation) + .expect("tunnel engine"); + let ctx = L7EvalContext { + host: "realtime.graphql.test".into(), + port: 443, + policy_name: "graphql_ws".into(), + binary_path: "/usr/bin/node".into(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: None, + }; + let (mut client_write, mut relay_read) = tokio::io::duplex(MAX_TEXT_MESSAGE_BYTES + 1024); + let (mut relay_write, mut upstream_read) = tokio::io::duplex(MAX_TEXT_MESSAGE_BYTES + 1024); + + client_write.write_all(&input).await.unwrap(); + drop(client_write); + + let options = RelayOptions { + policy_name: "graphql_ws", + resolver, + inspector: Some(InspectionOptions { + engine: &tunnel_engine, + ctx: &ctx, + enforcement: EnforcementMode::Enforce, + target: "/graphql".to_string(), + query_params: HashMap::new(), + graphql_policy: true, + }), + compression: WebSocketCompression::None, + }; + let result = relay_client_to_server( + &mut relay_read, + &mut relay_write, + "realtime.graphql.test", + 443, + &options, + ) + .await; + drop(relay_write); + + let mut output = Vec::new(); + upstream_read.read_to_end(&mut output).await.unwrap(); + result.map(|()| output) + } + + async fn run_client_to_server_compressed(input: Vec) -> Result> { + let (_, resolver) = resolver(); + let (mut client_write, mut relay_read) = tokio::io::duplex(MAX_TEXT_MESSAGE_BYTES + 1024); + let (mut relay_write, mut upstream_read) = tokio::io::duplex(MAX_TEXT_MESSAGE_BYTES + 1024); + + client_write.write_all(&input).await.unwrap(); + drop(client_write); + + let options = RelayOptions { + policy_name: "test-policy", + resolver: Some(&resolver), + inspector: None, + compression: WebSocketCompression::PermessageDeflate, + }; + let result = relay_client_to_server( + &mut relay_read, + &mut relay_write, + "gateway.example.test", + 443, + &options, + ) + .await; + drop(relay_write); + + let mut output = Vec::new(); + upstream_read.read_to_end(&mut output).await.unwrap(); + result.map(|()| output) + } + + fn decode_masked_text_frame(frame: &[u8]) -> String { + assert_eq!(frame[0] & 0x0F, OPCODE_TEXT); + assert_ne!(frame[1] & 0x80, 0); + String::from_utf8(decode_masked_payload(frame)).unwrap() + } + + fn decode_masked_payload(frame: &[u8]) -> Vec { + assert_ne!(frame[1] & 0x80, 0); + let len_code = frame[1] & 0x7F; + let (payload_len, mask_offset) = match len_code { + 0..=125 => (usize::from(len_code), 2), + 126 => (usize::from(u16::from_be_bytes([frame[2], frame[3]])), 4), + 127 => { + let len = u64::from_be_bytes(frame[2..10].try_into().unwrap()); + (usize::try_from(len).unwrap(), 10) + } + _ => unreachable!(), + }; + let mask_key: [u8; 4] = frame[mask_offset..mask_offset + 4].try_into().unwrap(); + let mut payload = frame[mask_offset + 4..mask_offset + 4 + payload_len].to_vec(); + apply_mask(&mut payload, mask_key); + payload + } + + fn decode_compressed_masked_text_frame(frame: &[u8]) -> String { + assert_eq!(frame[0] & 0x0F, OPCODE_TEXT); + assert_eq!(frame[0] & 0x40, 0x40); + let payload = decode_masked_payload(frame); + String::from_utf8(decompress_permessage_deflate(&payload).unwrap()).unwrap() + } + + async fn read_one_frame(reader: &mut R) -> Vec { + let mut header = [0u8; 2]; + reader.read_exact(&mut header).await.unwrap(); + let len_code = header[1] & 0x7F; + let extended_len = match len_code { + 0..=125 => Vec::new(), + 126 => { + let mut bytes = vec![0u8; 2]; + reader.read_exact(&mut bytes).await.unwrap(); + bytes + } + 127 => { + let mut bytes = vec![0u8; 8]; + reader.read_exact(&mut bytes).await.unwrap(); + bytes + } + _ => unreachable!(), + }; + let payload_len = match len_code { + 0..=125 => usize::from(len_code), + 126 => usize::from(u16::from_be_bytes( + extended_len.as_slice().try_into().unwrap(), + )), + 127 => usize::try_from(u64::from_be_bytes( + extended_len.as_slice().try_into().unwrap(), + )) + .unwrap(), + _ => unreachable!(), + }; + let mask_len = if header[1] & 0x80 != 0 { 4 } else { 0 }; + let mut rest = vec![0u8; extended_len.len() + mask_len + payload_len]; + rest[..extended_len.len()].copy_from_slice(&extended_len); + reader + .read_exact(&mut rest[extended_len.len()..]) + .await + .unwrap(); + + let mut frame = header.to_vec(); + frame.extend_from_slice(&rest); + frame + } + + #[test] + fn classifies_graphql_transport_ws_subscribe_operation() { + let message = r#"{"type":"subscribe","id":"1","payload":{"query":"subscription NewMessages { messageAdded }"}}"#; + + match classify_graphql_websocket_message(message) { + GraphqlWebSocketMessage::Operation { + message_type, + graphql, + } => { + assert_eq!(message_type, "subscribe"); + assert!( + graphql.error.is_none(), + "unexpected error: {:?}", + graphql.error + ); + assert_eq!(graphql.operations.len(), 1); + assert_eq!(graphql.operations[0].operation_type, "subscription"); + assert_eq!( + graphql.operations[0].operation_name.as_deref(), + Some("NewMessages") + ); + assert_eq!(graphql.operations[0].fields, vec!["messageAdded"]); + } + other @ GraphqlWebSocketMessage::Control { .. } => { + panic!("expected operation, got {other:?}") + } + } + } + + #[test] + fn classifies_legacy_graphql_ws_start_operation() { + let message = r#"{"type":"start","id":"1","payload":{"query":"query Viewer { viewer }"}}"#; + + match classify_graphql_websocket_message(message) { + GraphqlWebSocketMessage::Operation { + message_type, + graphql, + } => { + assert_eq!(message_type, "start"); + assert!( + graphql.error.is_none(), + "unexpected error: {:?}", + graphql.error + ); + assert_eq!(graphql.operations[0].operation_type, "query"); + assert_eq!(graphql.operations[0].fields, vec!["viewer"]); + } + other @ GraphqlWebSocketMessage::Control { .. } => { + panic!("expected operation, got {other:?}") + } + } + } + + #[test] + fn classifies_graphql_websocket_control_message_without_payload_logging() { + match classify_graphql_websocket_message( + r#"{"type":"connection_init","payload":{"authorization":"secret"}}"#, + ) { + GraphqlWebSocketMessage::Control { message_type } => { + assert_eq!(message_type, "connection_init"); + } + other @ GraphqlWebSocketMessage::Operation { .. } => { + panic!("expected control message, got {other:?}") + } + } + } + + #[test] + fn unsupported_graphql_websocket_message_type_fails_closed() { + match classify_graphql_websocket_message(r#"{"type":"next","id":"1"}"#) { + GraphqlWebSocketMessage::Operation { graphql, .. } => { + assert!( + graphql + .error + .as_deref() + .is_some_and(|error| error.contains("unsupported")) + ); + } + other @ GraphqlWebSocketMessage::Control { .. } => { + panic!("expected operation error, got {other:?}") + } + } + } + + #[test] + fn graphql_websocket_log_summary_excludes_payload_variables_and_secrets() { + let placeholder = "openshell:resolve:env:T"; + let message = format!( + r#"{{"type":"subscribe","id":"1","payload":{{"query":"query Viewer {{ viewer }}","variables":{{"token":"{placeholder}"}}}}}}"# + ); + let graphql = match classify_graphql_websocket_message(&message) { + GraphqlWebSocketMessage::Operation { graphql, .. } => graphql, + other @ GraphqlWebSocketMessage::Control { .. } => { + panic!("expected operation, got {other:?}") + } + }; + let summary = graphql_log_summary(&graphql); + + assert!(summary.contains("type=query")); + assert!(summary.contains("fields=viewer")); + assert!(!summary.contains(placeholder)); + assert!(!summary.contains("real-token")); + assert!(!summary.contains("variables")); + assert!(!summary.contains("token")); + assert!(!summary.contains("secret_len")); + } + + #[tokio::test] + async fn rewrites_discord_like_identify_text_payload() { + let (child_env, _) = resolver(); + let placeholder = child_env.get("DISCORD_BOT_TOKEN").unwrap(); + let payload = format!(r#"{{"op":2,"d":{{"token":"{placeholder}"}}}}"#); + + let output = run_client_to_server(masked_frame(true, OPCODE_TEXT, payload.as_bytes())) + .await + .expect("relay should succeed"); + + assert_eq!( + decode_masked_text_frame(&output), + r#"{"op":2,"d":{"token":"real-token"}}"# + ); + } + + #[tokio::test] + async fn upgraded_relay_rewrites_client_text_before_upstream_receives_it() { + let (child_env, resolver) = resolver(); + let placeholder = child_env.get("DISCORD_BOT_TOKEN").unwrap(); + let payload = format!(r#"{{"op":2,"d":{{"token":"{placeholder}"}}}}"#); + let client_frame = masked_frame(true, OPCODE_TEXT, payload.as_bytes()); + assert!( + !String::from_utf8_lossy(&client_frame).contains("real-token"), + "client-side fixture must not contain the real token" + ); + + let (mut client_app, mut relay_client) = tokio::io::duplex(4096); + let (mut relay_upstream, mut upstream_app) = tokio::io::duplex(4096); + let relay = tokio::spawn(async move { + relay_with_options( + &mut relay_client, + &mut relay_upstream, + Vec::new(), + "gateway.example.test", + 443, + RelayOptions { + policy_name: "test-policy", + resolver: Some(&resolver), + inspector: None, + compression: WebSocketCompression::None, + }, + ) + .await + }); + + client_app.write_all(&client_frame).await.unwrap(); + client_app.flush().await.unwrap(); + + let upstream_frame = tokio::time::timeout( + std::time::Duration::from_secs(2), + read_one_frame(&mut upstream_app), + ) + .await + .expect("upstream should receive rewritten frame"); + assert_eq!( + decode_masked_text_frame(&upstream_frame), + r#"{"op":2,"d":{"token":"real-token"}}"# + ); + + drop(client_app); + drop(upstream_app); + let _ = tokio::time::timeout(std::time::Duration::from_secs(2), relay).await; + } + + #[tokio::test] + async fn graphql_websocket_policy_allows_subscription_operation() { + let payload = r#"{"type":"subscribe","id":"1","payload":{"query":"subscription NewMessages { messageAdded }"}}"#; + let frame = masked_frame(true, OPCODE_TEXT, payload.as_bytes()); + + let output = run_client_to_server_with_graphql_policy(frame.clone(), None) + .await + .expect("allowed subscription should relay"); + + assert_eq!(output, frame); + assert_eq!(decode_masked_text_frame(&output), payload); + } + + #[tokio::test] + async fn graphql_websocket_policy_denies_unlisted_operation_field() { + let payload = + r#"{"type":"subscribe","id":"1","payload":{"query":"query Admin { adminAuditLog }"}}"#; + let frame = masked_frame(true, OPCODE_TEXT, payload.as_bytes()); + + let err = run_client_to_server_with_graphql_policy(frame, None) + .await + .expect_err("unlisted field should be denied"); + + assert!(err.to_string().contains("websocket GraphQL message denied")); + } + + #[tokio::test] + async fn graphql_websocket_control_message_rewrites_credentials_before_relay() { + let (child_env, resolver) = SecretResolver::from_provider_env( + std::iter::once(("T".to_string(), "real-token".to_string())).collect(), + ); + let resolver = resolver.expect("resolver"); + let placeholder = child_env.get("T").expect("placeholder env"); + let payload = format!( + r#"{{"type":"connection_init","payload":{{"authorization":"{placeholder}"}}}}"# + ); + let frame = masked_frame(true, OPCODE_TEXT, payload.as_bytes()); + + let output = run_client_to_server_with_graphql_policy(frame, Some(&resolver)) + .await + .expect("control message should relay after credential rewrite"); + + let rewritten = decode_masked_text_frame(&output); + assert_eq!( + rewritten, + r#"{"type":"connection_init","payload":{"authorization":"real-token"}}"# + ); + assert!(!rewritten.contains(placeholder)); + } + + #[tokio::test] + async fn text_without_placeholder_passes_semantically_unchanged() { + let frame = masked_frame(true, OPCODE_TEXT, br#"{"op":1,"d":42}"#); + let output = run_client_to_server(frame.clone()) + .await + .expect("relay should succeed"); + + assert_eq!(output, frame); + assert_eq!(decode_masked_text_frame(&output), r#"{"op":1,"d":42}"#); + } + + #[tokio::test] + async fn unknown_placeholder_fails_closed() { + let frame = masked_frame( + true, + OPCODE_TEXT, + br#"{"token":"openshell:resolve:env:UNKNOWN"}"#, + ); + + let err = run_client_to_server(frame) + .await + .expect_err("unknown placeholder should fail"); + + assert!( + err.to_string() + .contains("credential placeholder resolution") + ); + } + + #[tokio::test] + async fn fragmented_text_rewrites_after_final_continuation() { + let (child_env, _) = resolver(); + let placeholder = child_env.get("DISCORD_BOT_TOKEN").unwrap(); + let first = format!(r#"{{"token":"{placeholder}"#); + let second = r#""}"#; + let mut input = masked_frame(false, OPCODE_TEXT, first.as_bytes()); + input.extend(masked_frame(true, OPCODE_CONTINUATION, second.as_bytes())); + + let output = run_client_to_server(input) + .await + .expect("relay should succeed"); + + assert_eq!( + decode_masked_text_frame(&output), + r#"{"token":"real-token"}"# + ); + } + + #[tokio::test] + async fn rejects_rsv_bits() { + let mut frame = masked_frame(true, OPCODE_TEXT, b"hello"); + frame[0] |= 0x40; + + let err = run_client_to_server(frame) + .await + .expect_err("RSV frame should fail"); + + assert!(err.to_string().contains("RSV bits")); + } + + #[tokio::test] + async fn rejects_unmasked_client_frame() { + let err = run_client_to_server(unmasked_frame(OPCODE_TEXT, b"hello")) + .await + .expect_err("unmasked frame should fail"); + + assert!(err.to_string().contains("not masked")); + } + + #[tokio::test] + async fn rejects_invalid_utf8_text() { + let err = run_client_to_server(masked_frame(true, OPCODE_TEXT, &[0xff])) + .await + .expect_err("invalid UTF-8 should fail"); + + assert!(err.to_string().contains("valid UTF-8")); + } + + #[tokio::test] + async fn rejects_oversize_text_message() { + let payload = vec![b'a'; MAX_TEXT_MESSAGE_BYTES + 1]; + let err = run_client_to_server(masked_frame(true, OPCODE_TEXT, &payload)) + .await + .expect_err("oversize text should fail"); + + assert!(err.to_string().contains("exceeds")); + } + + #[tokio::test] + async fn fragmented_text_allows_interleaved_ping_pong_and_rewrites_at_completion() { + let (child_env, _) = resolver(); + let placeholder = child_env.get("DISCORD_BOT_TOKEN").unwrap(); + let first = format!(r#"{{"token":"{placeholder}"#); + let first_control_frame = masked_frame(true, OPCODE_PING, b"p"); + let second_control_frame = masked_frame(true, OPCODE_PONG, b"q"); + let mut input = masked_frame(false, OPCODE_TEXT, first.as_bytes()); + input.extend_from_slice(&first_control_frame); + input.extend_from_slice(&second_control_frame); + input.extend(masked_frame(true, OPCODE_CONTINUATION, br#""}"#)); + + let output = run_client_to_server(input) + .await + .expect("relay should allow interleaved control frames"); + + assert!(output.starts_with(&first_control_frame)); + assert_eq!( + &output + [first_control_frame.len()..first_control_frame.len() + second_control_frame.len()], + second_control_frame.as_slice() + ); + assert_eq!( + decode_masked_text_frame( + &output[first_control_frame.len() + second_control_frame.len()..] + ), + r#"{"token":"real-token"}"# + ); + } + + #[tokio::test] + async fn compressed_text_rewrites_with_permessage_deflate() { + let (child_env, _) = resolver(); + let placeholder = child_env.get("DISCORD_BOT_TOKEN").unwrap(); + let payload = format!(r#"{{"token":"{placeholder}"}}"#); + let compressed = compress_permessage_deflate(payload.as_bytes()).unwrap(); + let input = masked_frame_with_rsv(true, OPCODE_TEXT, 0x40, &compressed); + + let output = run_client_to_server_compressed(input) + .await + .expect("compressed text should relay"); + + assert_eq!( + decode_compressed_masked_text_frame(&output), + r#"{"token":"real-token"}"# + ); + } + + #[tokio::test] + async fn compressed_text_rejects_decompressed_oversize_message() { + let payload = vec![b'a'; MAX_TEXT_MESSAGE_BYTES + 1]; + let compressed = compress_permessage_deflate(&payload).unwrap(); + let input = masked_frame_with_rsv(true, OPCODE_TEXT, 0x40, &compressed); + + let err = run_client_to_server_compressed(input) + .await + .expect_err("oversize decompressed text should fail"); + + assert!(err.to_string().contains("exceeds")); + } + + #[tokio::test] + async fn binary_frame_passes_through_unchanged() { + let frame = masked_frame(true, OPCODE_BINARY, &[0, 1, 2, 3, 255]); + + let output = run_client_to_server(frame.clone()) + .await + .expect("binary frame should pass through"); + + assert_eq!(output, frame); + } + + #[tokio::test] + async fn rejects_reserved_opcode() { + let err = run_client_to_server(masked_frame(true, 0x3, b"reserved")) + .await + .expect_err("reserved opcode should fail"); + + assert!(err.to_string().contains("reserved opcode")); + } + + #[tokio::test] + async fn rejects_continuation_without_active_message() { + let err = run_client_to_server(masked_frame(true, OPCODE_CONTINUATION, b"orphan")) + .await + .expect_err("orphan continuation should fail"); + + assert!(err.to_string().contains("continuation")); + } + + #[tokio::test] + async fn rejects_new_data_frame_before_fragment_completion() { + let mut input = masked_frame(false, OPCODE_TEXT, b"partial"); + input.extend(masked_frame(true, OPCODE_TEXT, b"second")); + + let err = run_client_to_server(input) + .await + .expect_err("new data frame during fragmentation should fail"); + + assert!(err.to_string().contains("previous fragmented message")); + } + + #[tokio::test] + async fn rejects_fragmented_control_frame() { + let err = run_client_to_server(masked_frame(false, OPCODE_PING, b"ping")) + .await + .expect_err("fragmented control frame should fail"); + + assert!(err.to_string().contains("control frame is fragmented")); + } + + #[tokio::test] + async fn rejects_control_frame_over_125_bytes() { + let payload = vec![b'a'; 126]; + let err = run_client_to_server(masked_frame(true, OPCODE_PING, &payload)) + .await + .expect_err("oversize control frame should fail"); + + assert!(err.to_string().contains("control frame exceeds")); + } + + #[tokio::test] + async fn rejects_non_minimal_extended_length() { + let err = run_client_to_server(masked_frame_with_non_minimal_16_bit_len( + OPCODE_TEXT, + b"hello", + )) + .await + .expect_err("non-minimal length should fail"); + + assert!(err.to_string().contains("non-minimal")); + } + + #[tokio::test] + async fn rejects_oversize_binary_frame_before_payload_buffering() { + let err = run_client_to_server(masked_frame_with_declared_len( + OPCODE_BINARY, + MAX_RAW_FRAME_PAYLOAD_BYTES + 1, + )) + .await + .expect_err("oversize binary frame should fail"); + + assert!(err.to_string().contains("binary frame exceeds")); + } + + #[tokio::test] + async fn validates_close_frame_payloads() { + let frame = masked_frame(true, OPCODE_CLOSE, &close_payload(1000, b"done")); + + let output = run_client_to_server(frame.clone()) + .await + .expect("valid close frame should pass through"); + + assert_eq!(output, frame); + } + + #[tokio::test] + async fn rejects_close_frame_with_one_byte_payload() { + let err = run_client_to_server(masked_frame(true, OPCODE_CLOSE, &[0x03])) + .await + .expect_err("one-byte close frame should fail"); + + assert!(err.to_string().contains("exactly one byte")); + } + + #[tokio::test] + async fn rejects_reserved_close_code() { + let err = run_client_to_server(masked_frame(true, OPCODE_CLOSE, &close_payload(1005, b""))) + .await + .expect_err("reserved close code should fail"); + + assert!(err.to_string().contains("invalid close code")); + } + + #[tokio::test] + async fn rejects_close_reason_with_invalid_utf8() { + let err = run_client_to_server(masked_frame( + true, + OPCODE_CLOSE, + &close_payload(1000, &[0xff]), + )) + .await + .expect_err("invalid close reason should fail"); + + assert!(err.to_string().contains("valid UTF-8")); + } + + #[tokio::test] + async fn rejects_frames_after_client_close_frame() { + let mut input = masked_frame(true, OPCODE_CLOSE, &close_payload(1000, b"done")); + input.extend(masked_frame(true, OPCODE_TEXT, b"late")); + + let err = run_client_to_server(input) + .await + .expect_err("frames after close should fail"); + + assert!(err.to_string().contains("after close")); + } + + #[test] + fn websocket_ocsf_messages_do_not_include_payload_or_secret_material() { + let placeholder = "openshell:resolve:env:DISCORD_BOT_TOKEN"; + let secret = "real-token"; + let payload = format!(r#"{{"op":2,"d":{{"token":"{placeholder}"}}}}"#); + + let rewrite = rewrite_event_message("gateway.example.test", 443, 1); + let failure = protocol_failure_message("gateway.example.test", 443); + let messages = [rewrite, failure]; + + for message in messages { + assert!(!message.contains(placeholder)); + assert!(!message.contains(secret)); + assert!(!message.contains(&payload)); + assert!(!message.contains("secret_len")); + assert!(!message.contains("payload_len")); + } + } +} diff --git a/crates/openshell-sandbox/src/opa.rs b/crates/openshell-sandbox/src/opa.rs index 5897679a0..a9ab94a2b 100644 --- a/crates/openshell-sandbox/src/opa.rs +++ b/crates/openshell-sandbox/src/opa.rs @@ -1061,6 +1061,12 @@ fn proto_to_opa_data_json(proto: &ProtoSandboxPolicy, entrypoint_pid: u32) -> St if e.allow_encoded_slash { ep["allow_encoded_slash"] = true.into(); } + if e.websocket_credential_rewrite { + ep["websocket_credential_rewrite"] = true.into(); + } + if e.request_body_credential_rewrite { + ep["request_body_credential_rewrite"] = true.into(); + } if !e.persisted_queries.is_empty() { ep["persisted_queries"] = e.persisted_queries.clone().into(); } @@ -1811,6 +1817,28 @@ network_policies: access: read-only binaries: - { path: /usr/bin/curl } + graphql_ws: + name: graphql_ws + endpoints: + - host: realtime.graphql.com + ports: [443] + path: "/graphql" + protocol: websocket + enforcement: enforce + rules: + - allow: + method: GET + path: "/graphql" + - allow: + operation_type: query + fields: [viewer] + - allow: + operation_type: subscription + fields: [messageAdded] + deny_rules: + - operation_type: mutation + binaries: + - { path: /usr/bin/curl } l4_only: name: l4_only endpoints: @@ -1897,6 +1925,25 @@ process: }) } + fn l7_websocket_graphql_input(host: &str, operations: serde_json::Value) -> serde_json::Value { + serde_json::json!({ + "network": { "host": host, "port": 443 }, + "exec": { + "path": "/usr/bin/curl", + "ancestors": [], + "cmdline_paths": [] + }, + "request": { + "method": "WEBSOCKET_TEXT", + "path": "/graphql", + "query_params": {}, + "graphql": { + "operations": operations + } + } + }) + } + fn eval_l7(engine: &OpaEngine, input: &serde_json::Value) -> bool { let mut eng = engine.engine.lock().unwrap(); eng.set_input_json(&input.to_string()).unwrap(); @@ -2134,6 +2181,97 @@ process: assert!(!eval_l7(&engine, &mutation)); } + #[test] + fn l7_websocket_graphql_subscription_allowed_by_field_rule() { + let engine = l7_engine(); + let input = l7_websocket_graphql_input( + "realtime.graphql.com", + serde_json::json!([{ + "operation_type": "subscription", + "operation_name": "NewMessages", + "fields": ["messageAdded"], + "persisted_query": false + }]), + ); + assert!(eval_l7(&engine, &input)); + } + + #[test] + fn l7_websocket_graphql_unlisted_field_denied() { + let engine = l7_engine(); + let input = l7_websocket_graphql_input( + "realtime.graphql.com", + serde_json::json!([{ + "operation_type": "query", + "fields": ["adminAuditLog"], + "persisted_query": false + }]), + ); + assert!(!eval_l7(&engine, &input)); + } + + #[test] + fn l7_websocket_graphql_deny_rule_takes_precedence() { + let engine = l7_engine(); + let input = l7_websocket_graphql_input( + "realtime.graphql.com", + serde_json::json!([{ + "operation_type": "mutation", + "operation_name": "DeleteRepo", + "fields": ["deleteRepository"], + "persisted_query": false + }]), + ); + assert!(!eval_l7(&engine, &input)); + } + + #[test] + fn l7_websocket_graphql_not_bypassed_by_generic_text_rule() { + let data = r#" +network_policies: + graphql_ws: + name: graphql_ws + endpoints: + - host: realtime.graphql.com + ports: [443] + path: "/graphql" + protocol: websocket + enforcement: enforce + rules: + - allow: + method: GET + path: "/graphql" + - allow: + method: WEBSOCKET_TEXT + path: "/graphql" + - allow: + operation_type: query + fields: [viewer] + binaries: + - { path: /usr/bin/curl } +"#; + let data_json: serde_json::Value = + serde_yml::from_str(data).expect("fixture should parse as YAML"); + let mut rego = regorus::Engine::new(); + rego.add_policy("policy.rego".into(), TEST_POLICY.into()) + .expect("policy should load"); + rego.add_data_json(&data_json.to_string()) + .expect("data should load"); + let engine = OpaEngine { + engine: Mutex::new(rego), + generation: Arc::new(AtomicU64::new(0)), + }; + let input = l7_websocket_graphql_input( + "realtime.graphql.com", + serde_json::json!([{ + "operation_type": "query", + "fields": ["adminAuditLog"], + "persisted_query": false + }]), + ); + assert!(!eval_l7(&engine, &input)); + } + #[test] fn l7_endpoint_path_scopes_rest_and_graphql_on_same_host() { let data = r#" @@ -2463,6 +2601,120 @@ network_policies: assert!(l7.allow_encoded_slash); } + #[test] + fn l7_endpoint_config_preserves_proto_websocket_credential_rewrite() { + let mut network_policies = std::collections::HashMap::new(); + network_policies.insert( + "gateway".to_string(), + NetworkPolicyRule { + name: "gateway".to_string(), + endpoints: vec![NetworkEndpoint { + host: "gateway.example.com".to_string(), + port: 443, + protocol: "rest".to_string(), + enforcement: "enforce".to_string(), + access: "full".to_string(), + websocket_credential_rewrite: true, + ..Default::default() + }], + binaries: vec![NetworkBinary { + path: "/usr/bin/node".to_string(), + ..Default::default() + }], + }, + ); + let proto = ProtoSandboxPolicy { + version: 1, + filesystem: Some(ProtoFs { + include_workdir: true, + read_only: vec![], + read_write: vec![], + }), + landlock: Some(openshell_core::proto::LandlockPolicy { + compatibility: "best_effort".to_string(), + }), + process: Some(ProtoProc { + run_as_user: "sandbox".to_string(), + run_as_group: "sandbox".to_string(), + }), + network_policies, + }; + + let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); + let input = NetworkInput { + host: "gateway.example.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/node"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + + let config = engine + .query_endpoint_config(&input) + .unwrap() + .expect("endpoint config"); + let l7 = crate::l7::parse_l7_config(&config).unwrap(); + assert!(l7.websocket_credential_rewrite); + } + + #[test] + fn l7_endpoint_config_preserves_proto_request_body_credential_rewrite() { + let mut network_policies = std::collections::HashMap::new(); + network_policies.insert( + "slack".to_string(), + NetworkPolicyRule { + name: "slack".to_string(), + endpoints: vec![NetworkEndpoint { + host: "slack.com".to_string(), + port: 443, + protocol: "rest".to_string(), + enforcement: "enforce".to_string(), + access: "read-write".to_string(), + request_body_credential_rewrite: true, + ..Default::default() + }], + binaries: vec![NetworkBinary { + path: "/usr/bin/node".to_string(), + ..Default::default() + }], + }, + ); + let proto = ProtoSandboxPolicy { + version: 1, + filesystem: Some(ProtoFs { + include_workdir: true, + read_only: vec![], + read_write: vec![], + }), + landlock: Some(openshell_core::proto::LandlockPolicy { + compatibility: "best_effort".to_string(), + }), + process: Some(ProtoProc { + run_as_user: "sandbox".to_string(), + run_as_group: "sandbox".to_string(), + }), + network_policies, + }; + + let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); + let input = NetworkInput { + host: "slack.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/node"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + + let config = engine + .query_endpoint_config(&input) + .unwrap() + .expect("endpoint config"); + let l7 = crate::l7::parse_l7_config(&config).unwrap(); + assert!(l7.request_body_credential_rewrite); + } + #[test] fn l7_endpoint_config_none_for_l4_only() { let engine = l7_engine(); diff --git a/crates/openshell-sandbox/src/policy_local.rs b/crates/openshell-sandbox/src/policy_local.rs index 21556ec6a..165b0c1bd 100644 --- a/crates/openshell-sandbox/src/policy_local.rs +++ b/crates/openshell-sandbox/src/policy_local.rs @@ -619,6 +619,8 @@ fn network_endpoint_from_json( ports, deny_rules, allow_encoded_slash: endpoint.allow_encoded_slash, + websocket_credential_rewrite: false, + request_body_credential_rewrite: false, // GraphQL persisted-query knobs and path scoping default empty — // agent proposals don't author them today. persisted_queries: String::new(), diff --git a/crates/openshell-sandbox/src/provider_credentials.rs b/crates/openshell-sandbox/src/provider_credentials.rs index bd28824ae..829e1b226 100644 --- a/crates/openshell-sandbox/src/provider_credentials.rs +++ b/crates/openshell-sandbox/src/provider_credentials.rs @@ -19,6 +19,7 @@ pub struct ProviderCredentialSnapshot { struct ProviderCredentialStateInner { current: Arc, generations: VecDeque>, + current_resolver: Option>, combined_resolver: Option>, } @@ -29,19 +30,21 @@ pub struct ProviderCredentialState { impl ProviderCredentialState { pub fn from_environment(revision: u64, env: HashMap) -> Self { - let (child_env, resolver) = SecretResolver::from_provider_env_for_revision(env, revision); + let (child_env, generation_resolver, current_resolver) = + SecretResolver::from_provider_env_for_current_revision(env, revision); let snapshot = Arc::new(ProviderCredentialSnapshot { revision, child_env, }); - let generations: VecDeque<_> = resolver.map(Arc::new).into_iter().collect(); - let combined_resolver = - SecretResolver::merge(generations.iter().map(Arc::as_ref)).map(Arc::new); + let generations: VecDeque<_> = generation_resolver.map(Arc::new).into_iter().collect(); + let current_resolver = current_resolver.map(Arc::new); + let combined_resolver = merge_resolvers(&generations, current_resolver.as_ref()); Self { inner: Arc::new(RwLock::new(ProviderCredentialStateInner { current: snapshot, generations, + current_resolver, combined_resolver, })), } @@ -64,7 +67,8 @@ impl ProviderCredentialState { } pub fn install_environment(&self, revision: u64, env: HashMap) -> usize { - let (child_env, resolver) = SecretResolver::from_provider_env_for_revision(env, revision); + let (child_env, generation_resolver, current_resolver) = + SecretResolver::from_provider_env_for_current_revision(env, revision); let mut inner = self .inner .write() @@ -74,19 +78,33 @@ impl ProviderCredentialState { revision, child_env, }); + inner.current_resolver = current_resolver.map(Arc::new); - if let Some(resolver) = resolver { + if let Some(resolver) = generation_resolver { inner.generations.push_back(Arc::new(resolver)); while inner.generations.len() > MAX_RETAINED_CREDENTIAL_GENERATIONS { inner.generations.pop_front(); } } inner.combined_resolver = - SecretResolver::merge(inner.generations.iter().map(Arc::as_ref)).map(Arc::new); + merge_resolvers(&inner.generations, inner.current_resolver.as_ref()); inner.current.child_env.len() } } +fn merge_resolvers( + generations: &VecDeque>, + current_resolver: Option<&Arc>, +) -> Option> { + SecretResolver::merge( + generations + .iter() + .map(Arc::as_ref) + .chain(current_resolver.into_iter().map(Arc::as_ref)), + ) + .map(Arc::new) +} + #[cfg(test)] mod tests { use super::*; @@ -122,10 +140,18 @@ mod tests { resolver.resolve_placeholder("openshell:resolve:env:v11_GITHUB_TOKEN"), Some("new") ); + assert_eq!( + resolver.resolve_placeholder("openshell:resolve:env:GITHUB_TOKEN"), + Some("new") + ); + assert_eq!( + resolver.resolve_placeholder("provider-OPENSHELL-RESOLVE-ENV-GITHUB_TOKEN"), + Some("new") + ); } #[test] - fn empty_refresh_removes_env_from_new_snapshots_but_retains_old_resolver() { + fn empty_refresh_removes_current_aliases_but_retains_revisioned_resolver() { let state = ProviderCredentialState::from_environment( 10, HashMap::from([("GITHUB_TOKEN".to_string(), "old".to_string())]), @@ -139,5 +165,13 @@ mod tests { resolver.resolve_placeholder("openshell:resolve:env:v10_GITHUB_TOKEN"), Some("old") ); + assert_eq!( + resolver.resolve_placeholder("openshell:resolve:env:GITHUB_TOKEN"), + None + ); + assert_eq!( + resolver.resolve_placeholder("provider-OPENSHELL-RESOLVE-ENV-GITHUB_TOKEN"), + None + ); } } diff --git a/crates/openshell-sandbox/src/proxy.rs b/crates/openshell-sandbox/src/proxy.rs index f20e51655..3012930e2 100644 --- a/crates/openshell-sandbox/src/proxy.rs +++ b/crates/openshell-sandbox/src/proxy.rs @@ -10,7 +10,7 @@ use crate::opa::{NetworkAction, OpaEngine, PolicyGenerationGuard}; use crate::policy::ProxyPolicy; use crate::policy_local::{POLICY_LOCAL_HOST, PolicyLocalContext}; use crate::provider_credentials::ProviderCredentialState; -use crate::secrets::{SecretResolver, rewrite_header_line}; +use crate::secrets::{SecretResolver, rewrite_header_line_checked}; use miette::{IntoDiagnostic, Result}; use openshell_core::net::{is_always_blocked_ip, is_internal_ip}; use openshell_ocsf::{ @@ -2277,11 +2277,17 @@ fn rewrite_forward_request( used: usize, path: &str, secret_resolver: Option<&SecretResolver>, + request_body_credential_rewrite: bool, ) -> Result, crate::secrets::UnresolvedPlaceholderError> { let header_end = raw[..used] .windows(4) .position(|w| w == b"\r\n\r\n") .map_or(used, |p| p + 4); + let websocket_upgrade = crate::l7::rest::request_is_websocket_upgrade(&raw[..header_end]); + let upstream_path = match secret_resolver { + Some(resolver) => crate::secrets::rewrite_target_for_eval(path, resolver)?.resolved, + None => path.to_string(), + }; let header_str = String::from_utf8_lossy(&raw[..header_end]); let lines = header_str.split("\r\n").collect::>(); @@ -2298,7 +2304,7 @@ fn rewrite_forward_request( if parts.len() == 3 { output.extend_from_slice(parts[0].as_bytes()); output.push(b' '); - output.extend_from_slice(path.as_bytes()); + output.extend_from_slice(upstream_path.as_bytes()); output.push(b' '); output.extend_from_slice(parts[2].as_bytes()); } else { @@ -2325,14 +2331,19 @@ fn rewrite_forward_request( // Replace Connection header if lower.starts_with("connection:") { has_connection = true; + if websocket_upgrade { + output.extend_from_slice(line.as_bytes()); + output.extend_from_slice(b"\r\n"); + continue; + } output.extend_from_slice(b"Connection: close\r\n"); continue; } - let rewritten_line = secret_resolver.map_or_else( - || line.to_string(), - |resolver| rewrite_header_line(line, resolver), - ); + let rewritten_line = match secret_resolver { + Some(resolver) => rewrite_header_line_checked(line, resolver)?, + None => line.to_string(), + }; output.extend_from_slice(rewritten_line.as_bytes()); output.extend_from_slice(b"\r\n"); @@ -2343,7 +2354,7 @@ fn rewrite_forward_request( } // Inject missing headers - if !has_connection { + if !has_connection && !websocket_upgrade { output.extend_from_slice(b"Connection: close\r\n"); } if !has_via { @@ -2352,6 +2363,7 @@ fn rewrite_forward_request( // End of headers output.extend_from_slice(b"\r\n"); + let rewritten_header_end = output.len(); // Append any overflow body bytes from the original buffer if header_end < used { @@ -2360,8 +2372,15 @@ fn rewrite_forward_request( // Fail-closed: scan for any remaining unresolved placeholders if secret_resolver.is_some() { - let output_str = String::from_utf8_lossy(&output); - if output_str.contains(crate::secrets::PLACEHOLDER_PREFIX_PUBLIC) { + let scan_end = if request_body_credential_rewrite { + rewritten_header_end + } else { + output.len() + }; + let output_str = String::from_utf8_lossy(&output[..scan_end]); + if output_str.contains(crate::secrets::PLACEHOLDER_PREFIX_PUBLIC) + || output_str.contains(crate::secrets::PROVIDER_ALIAS_MARKER_PUBLIC) + { return Err(crate::secrets::UnresolvedPlaceholderError { location: "header" }); } } @@ -2369,13 +2388,20 @@ fn rewrite_forward_request( Ok(output) } +struct ForwardRelayOptions<'a> { + generation_guard: &'a PolicyGenerationGuard, + websocket_extensions: crate::l7::rest::WebSocketExtensionMode, + secret_resolver: Option<&'a SecretResolver>, + request_body_credential_rewrite: bool, +} + async fn relay_rewritten_forward_request( method: &str, path: &str, rewritten: Vec, client: &mut C, upstream: &mut U, - generation_guard: &PolicyGenerationGuard, + options: ForwardRelayOptions<'_>, ) -> Result where C: TokioAsyncRead + TokioAsyncWrite + Unpin, @@ -2396,12 +2422,16 @@ where body_length, }; - crate::l7::rest::relay_http_request_with_resolver_guarded( + crate::l7::rest::relay_http_request_with_options_guarded( &req, client, upstream, - None, - Some(generation_guard), + crate::l7::rest::RelayRequestOptions { + resolver: options.secret_resolver, + generation_guard: Some(options.generation_guard), + websocket_extensions: options.websocket_extensions, + request_body_credential_rewrite: options.request_body_credential_rewrite, + }, ) .await } @@ -2623,6 +2653,35 @@ async fn handle_forward_proxy( }; let mut forward_request_bytes = buf[..used].to_vec(); let mut upstream_target = path.clone(); + let mut websocket_extensions = crate::l7::rest::WebSocketExtensionMode::Preserve; + let mut forward_tunnel_engine: Option = None; + let mut forward_upgrade_config: Option = None; + let mut forward_upgrade_target = String::new(); + let mut forward_upgrade_query_params = std::collections::HashMap::new(); + let mut forward_websocket_request = + crate::l7::rest::request_is_websocket_upgrade(&forward_request_bytes); + let mut request_body_credential_rewrite = false; + let l7_ctx = crate::l7::relay::L7EvalContext { + host: host_lc.clone(), + port, + policy_name: matched_policy.clone().unwrap_or_default(), + binary_path: decision + .binary + .as_ref() + .map(|p| p.to_string_lossy().into_owned()) + .unwrap_or_default(), + ancestors: decision + .ancestors + .iter() + .map(|p| p.to_string_lossy().into_owned()) + .collect(), + cmdline_paths: decision + .cmdline_paths + .iter() + .map(|p| p.to_string_lossy().into_owned()) + .collect(), + secret_resolver: secret_resolver.clone(), + }; // 4b. If the endpoint has L7 config, evaluate the request against // L7 policy. The forward proxy handles exactly one request per @@ -2670,28 +2729,6 @@ async fn handle_forward_proxy( } }; - let l7_ctx = crate::l7::relay::L7EvalContext { - host: host_lc.clone(), - port, - policy_name: matched_policy.clone().unwrap_or_default(), - binary_path: decision - .binary - .as_ref() - .map(|p| p.to_string_lossy().into_owned()) - .unwrap_or_default(), - ancestors: decision - .ancestors - .iter() - .map(|p| p.to_string_lossy().into_owned()) - .collect(), - cmdline_paths: decision - .cmdline_paths - .iter() - .map(|p| p.to_string_lossy().into_owned()) - .collect(), - secret_resolver: secret_resolver.clone(), - }; - // Canonicalize the request-target. The canonical form is fed to OPA // AND reassigned to the outer `path` variable so the later call to // `rewrite_forward_request` writes canonical bytes to the upstream. @@ -2760,6 +2797,14 @@ async fn handle_forward_proxy( .await?; return Ok(()); }; + forward_websocket_request = + crate::l7::rest::request_is_websocket_upgrade(&forward_request_bytes); + websocket_extensions = crate::l7::relay::websocket_extension_mode(&l7_config.config); + request_body_credential_rewrite = l7_config.config.protocol == crate::l7::L7Protocol::Rest + && l7_config.config.request_body_credential_rewrite; + forward_upgrade_config = Some(l7_config.config.clone()); + forward_upgrade_target = path.clone(); + forward_upgrade_query_params = query_params.clone(); let graphql = if l7_config.config.protocol == crate::l7::L7Protocol::Graphql { let header_end = forward_request_bytes .windows(4) @@ -2920,6 +2965,7 @@ async fn handle_forward_proxy( .await?; return Ok(()); } + forward_tunnel_engine = Some(tunnel_engine); } // 5. DNS resolution + SSRF defence (mirrors the CONNECT path logic). @@ -3180,6 +3226,7 @@ async fn handle_forward_proxy( forward_request_bytes.len(), &upstream_target, secret_resolver.as_deref(), + request_body_credential_rewrite, ) { Ok(bytes) => bytes, Err(e) => { @@ -3222,11 +3269,47 @@ async fn handle_forward_proxy( rewritten, client, &mut upstream, - &forward_generation_guard, + ForwardRelayOptions { + generation_guard: &forward_generation_guard, + websocket_extensions, + secret_resolver: secret_resolver.as_deref(), + request_body_credential_rewrite, + }, ) .await?; - if let crate::l7::provider::RelayOutcome::Upgraded { overflow } = outcome { - crate::l7::relay::handle_upgrade(client, &mut upstream, overflow, &host_lc, port).await?; + if let crate::l7::provider::RelayOutcome::Upgraded { + overflow, + websocket_permessage_deflate, + } = outcome + { + let mut upgrade_options = if let (Some(config), Some(engine)) = ( + forward_upgrade_config.as_ref(), + forward_tunnel_engine.as_ref(), + ) { + crate::l7::relay::upgrade_options( + config, + &l7_ctx, + forward_websocket_request, + &forward_upgrade_target, + &forward_upgrade_query_params, + Some(engine), + ) + } else { + crate::l7::relay::UpgradeRelayOptions { + websocket_request: forward_websocket_request, + ..Default::default() + } + }; + upgrade_options.websocket.permessage_deflate = websocket_permessage_deflate; + crate::l7::relay::handle_upgrade( + client, + &mut upstream, + overflow, + &host_lc, + port, + upgrade_options, + ) + .await?; } Ok(()) @@ -3298,6 +3381,473 @@ fn is_benign_relay_error(err: &miette::Report) -> bool { mod tests { use super::*; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; + use std::sync::Arc; + + fn websocket_l7_config( + protocol: crate::l7::L7Protocol, + websocket_credential_rewrite: bool, + ) -> crate::l7::L7EndpointConfig { + crate::l7::L7EndpointConfig { + protocol, + path: "/**".to_string(), + tls: crate::l7::TlsMode::Auto, + enforcement: crate::l7::EnforcementMode::Enforce, + graphql_max_body_bytes: crate::l7::graphql::DEFAULT_MAX_BODY_BYTES, + allow_encoded_slash: false, + websocket_credential_rewrite, + request_body_credential_rewrite: false, + websocket_graphql_policy: false, + } + } + + fn forward_test_guard() -> PolicyGenerationGuard { + let policy = include_str!("../data/sandbox-policy.rego"); + let policy_data = "network_policies: {}\n"; + let engine = OpaEngine::from_strings(policy, policy_data).unwrap(); + engine + .generation_guard(engine.current_generation()) + .unwrap() + } + + async fn relay_forward_request_and_capture( + method: &str, + path: &str, + raw: &[u8], + resolver: Option<&SecretResolver>, + request_body_credential_rewrite: bool, + ) -> Result { + let guard = forward_test_guard(); + let rewritten = rewrite_forward_request( + raw, + raw.len(), + path, + resolver, + request_body_credential_rewrite, + ) + .map_err(|e| miette::miette!("{e}"))?; + let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); + let (mut _app_side, mut proxy_to_client) = tokio::io::duplex(8192); + + let upstream_task = tokio::spawn(async move { + let mut buf = vec![0u8; 8192]; + let mut total = 0usize; + let mut expected_total = None; + loop { + let n = upstream_side.read(&mut buf[total..]).await.unwrap(); + if n == 0 { + break; + } + total += n; + if expected_total.is_none() + && let Some(end) = buf[..total].windows(4).position(|w| w == b"\r\n\r\n") + { + let header_end = end + 4; + let headers = String::from_utf8_lossy(&buf[..header_end]); + let len = headers + .lines() + .find_map(|line| { + let (name, value) = line.split_once(':')?; + name.eq_ignore_ascii_case("content-length") + .then(|| value.trim().parse::().ok()) + .flatten() + }) + .unwrap_or(0); + expected_total = Some(header_end + len); + } + if expected_total.is_some_and(|expected| total >= expected) { + break; + } + } + upstream_side + .write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok") + .await + .unwrap(); + upstream_side.flush().await.unwrap(); + String::from_utf8_lossy(&buf[..total]).to_string() + }); + + relay_rewritten_forward_request( + method, + path, + rewritten, + &mut proxy_to_client, + &mut proxy_to_upstream, + ForwardRelayOptions { + generation_guard: &guard, + websocket_extensions: crate::l7::rest::WebSocketExtensionMode::Preserve, + secret_resolver: resolver, + request_body_credential_rewrite, + }, + ) + .await?; + + upstream_task + .await + .map_err(|e| miette::miette!("upstream task failed: {e}")) + } + + fn forward_websocket_policy_parts( + data: &str, + host: &str, + port: u16, + path: &str, + policy_name: &str, + ) -> ( + crate::l7::L7EndpointConfig, + crate::opa::TunnelPolicyEngine, + crate::l7::relay::L7EvalContext, + ) { + let policy = include_str!("../data/sandbox-policy.rego"); + let engine = OpaEngine::from_strings(policy, data).unwrap(); + let decision = ConnectDecision { + action: NetworkAction::Allow { + matched_policy: Some(policy_name.to_string()), + }, + generation: engine.current_generation(), + binary: Some(PathBuf::from("/usr/bin/node")), + binary_pid: None, + ancestors: vec![], + cmdline_paths: vec![], + }; + let route = + query_l7_route_snapshot(&engine, &decision, host, port).expect("L7 route should match"); + let config = select_l7_config_for_path(&route.configs, path) + .expect("path-specific L7 config should match") + .config + .clone(); + let tunnel_engine = engine + .clone_engine_for_tunnel(route.generation) + .expect("tunnel engine"); + let ctx = crate::l7::relay::L7EvalContext { + host: host.to_string(), + port, + policy_name: policy_name.to_string(), + binary_path: "/usr/bin/node".to_string(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: None, + }; + (config, tunnel_engine, ctx) + } + + async fn read_http_headers(reader: &mut R) -> Vec { + let mut bytes = Vec::new(); + let mut chunk = [0u8; 256]; + loop { + let n = + tokio::time::timeout(std::time::Duration::from_secs(1), reader.read(&mut chunk)) + .await + .expect("HTTP headers should arrive") + .expect("header read should succeed"); + assert!(n > 0, "stream closed before HTTP headers"); + bytes.extend_from_slice(&chunk[..n]); + if bytes.windows(4).any(|w| w == b"\r\n\r\n") { + return bytes; + } + } + } + + fn masked_text_frame(payload: &[u8]) -> Vec { + let mask = [0x11, 0x22, 0x33, 0x44]; + assert!( + payload.len() <= 125, + "test helper only supports small frames" + ); + let payload_len = u8::try_from(payload.len()).expect("small frame length"); + let mut frame = vec![0x81, 0x80 | payload_len]; + frame.extend_from_slice(&mask); + frame.extend( + payload + .iter() + .enumerate() + .map(|(idx, byte)| byte ^ mask[idx % 4]), + ); + frame + } + + async fn forward_websocket_denied_after_upgrade( + config: crate::l7::L7EndpointConfig, + tunnel_engine: crate::opa::TunnelPolicyEngine, + ctx: crate::l7::relay::L7EvalContext, + path: &str, + payload: &str, + ) -> (miette::Report, Vec) { + let host = ctx.host.clone(); + let port = ctx.port; + let raw = format!( + "GET http://{host}{path} HTTP/1.1\r\n\ + Host: {host}\r\n\ + Upgrade: websocket\r\n\ + Connection: Upgrade\r\n\ + Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\ + Sec-WebSocket-Version: 13\r\n\r\n" + ); + let rewritten = rewrite_forward_request(raw.as_bytes(), raw.len(), path, None, false) + .expect("forward websocket request should rewrite to origin form"); + let websocket_extensions = crate::l7::relay::websocket_extension_mode(&config); + let target = path.to_string(); + let query_params = std::collections::HashMap::new(); + let (mut proxy_to_upstream, mut upstream) = tokio::io::duplex(8192); + let (mut app, mut proxy_to_client) = tokio::io::duplex(8192); + + let relay = tokio::spawn(async move { + let guard = tunnel_engine.generation_guard(); + let outcome = relay_rewritten_forward_request( + "GET", + &target, + rewritten, + &mut proxy_to_client, + &mut proxy_to_upstream, + ForwardRelayOptions { + generation_guard: guard, + websocket_extensions, + secret_resolver: None, + request_body_credential_rewrite: false, + }, + ) + .await?; + if let crate::l7::provider::RelayOutcome::Upgraded { + overflow, + websocket_permessage_deflate, + } = outcome + { + let mut options = crate::l7::relay::upgrade_options( + &config, + &ctx, + true, + &target, + &query_params, + Some(&tunnel_engine), + ); + options.websocket.permessage_deflate = websocket_permessage_deflate; + crate::l7::relay::handle_upgrade( + &mut proxy_to_client, + &mut proxy_to_upstream, + overflow, + &host, + port, + options, + ) + .await?; + } + Ok::<(), miette::Report>(()) + }); + + let forwarded_headers = read_http_headers(&mut upstream).await; + let forwarded_headers = String::from_utf8_lossy(&forwarded_headers); + assert!(forwarded_headers.starts_with(&format!("GET {path} HTTP/1.1\r\n"))); + assert!(forwarded_headers.contains("Upgrade: websocket\r\n")); + + upstream + .write_all( + b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\r\n", + ) + .await + .unwrap(); + + let response = read_http_headers(&mut app).await; + assert!(String::from_utf8_lossy(&response).contains("101 Switching Protocols")); + + app.write_all(&masked_text_frame(payload.as_bytes())) + .await + .unwrap(); + + let err = tokio::time::timeout(std::time::Duration::from_secs(1), relay) + .await + .expect("websocket relay should fail closed after denied frame") + .expect("relay task should not panic") + .expect_err("denied websocket frame should fail the forward relay"); + + let mut leaked = Vec::new(); + tokio::time::timeout( + std::time::Duration::from_secs(1), + upstream.read_to_end(&mut leaked), + ) + .await + .expect("upstream side should close") + .expect("upstream read should succeed"); + (err, leaked) + } + + #[test] + fn forward_websocket_upgrade_options_enable_native_policy_context() { + let (_, resolver) = SecretResolver::from_provider_env( + [("DISCORD_BOT_TOKEN".to_string(), "discord-real".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.map(Arc::new); + let policy = include_str!("../data/sandbox-policy.rego"); + let policy_data = "network_policies: {}\n"; + let engine = OpaEngine::from_strings(policy, policy_data).unwrap(); + let tunnel_engine = engine + .clone_engine_for_tunnel(engine.current_generation()) + .unwrap(); + let ctx = crate::l7::relay::L7EvalContext { + host: "gateway.example.test".to_string(), + port: 80, + policy_name: "ws_api".to_string(), + binary_path: "/usr/bin/node".to_string(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: resolver, + }; + let query_params = std::collections::HashMap::new(); + + let extensions = crate::l7::relay::websocket_extension_mode(&websocket_l7_config( + crate::l7::L7Protocol::Websocket, + true, + )); + let options = crate::l7::relay::upgrade_options( + &websocket_l7_config(crate::l7::L7Protocol::Websocket, true), + &ctx, + true, + "/ws", + &query_params, + Some(&tunnel_engine), + ); + + assert_eq!( + extensions, + crate::l7::rest::WebSocketExtensionMode::PermessageDeflate + ); + assert!(options.websocket.credential_rewrite); + assert!(options.secret_resolver.is_some()); + assert!(options.engine.is_some()); + assert!(options.ctx.is_some()); + assert!(matches!( + options.websocket.message_policy, + crate::l7::relay::WebSocketMessagePolicy::Transport + )); + } + + #[test] + fn forward_websocket_upgrade_options_preserve_rest_without_rewrite() { + let ctx = crate::l7::relay::L7EvalContext { + host: "gateway.example.test".to_string(), + port: 80, + policy_name: "rest_api".to_string(), + binary_path: "/usr/bin/node".to_string(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: None, + }; + let query_params = std::collections::HashMap::new(); + let config = websocket_l7_config(crate::l7::L7Protocol::Rest, false); + let extensions = crate::l7::relay::websocket_extension_mode(&config); + let options = + crate::l7::relay::upgrade_options(&config, &ctx, true, "/ws", &query_params, None); + + assert_eq!( + extensions, + crate::l7::rest::WebSocketExtensionMode::Preserve + ); + assert!(!options.websocket.credential_rewrite); + assert!(options.secret_resolver.is_none()); + assert!(options.engine.is_none()); + assert!(options.ctx.is_none()); + assert!(matches!( + options.websocket.message_policy, + crate::l7::relay::WebSocketMessagePolicy::None + )); + } + + #[tokio::test] + async fn forward_websocket_upgrade_blocks_text_frame_by_policy() { + let data = r#" +network_policies: + ws_api: + name: ws_api + endpoints: + - host: gateway.example.test + port: 80 + path: "/ws" + protocol: websocket + enforcement: enforce + rules: + - allow: + method: GET + path: "/ws" + - allow: + method: WEBSOCKET_TEXT + path: "/ws" + deny_rules: + - method: WEBSOCKET_TEXT + path: "/ws" + binaries: + - { path: /usr/bin/node } +"#; + let (config, tunnel_engine, ctx) = + forward_websocket_policy_parts(data, "gateway.example.test", 80, "/ws", "ws_api"); + + let (err, leaked) = forward_websocket_denied_after_upgrade( + config, + tunnel_engine, + ctx, + "/ws", + r#"{"type":"unsafe"}"#, + ) + .await; + + assert!(err.to_string().contains("websocket text message denied")); + assert!( + leaked.is_empty(), + "denied forward-proxy WebSocket text frames must not reach upstream" + ); + } + + #[tokio::test] + async fn forward_graphql_websocket_upgrade_blocks_unallowed_operation() { + let data = r#" +network_policies: + graphql_ws: + name: graphql_ws + endpoints: + - host: gateway.example.test + port: 80 + path: "/graphql" + protocol: websocket + enforcement: enforce + rules: + - allow: + method: GET + path: "/graphql" + - allow: + operation_type: query + fields: [viewer] + deny_rules: + - operation_type: query + fields: [admin] + binaries: + - { path: /usr/bin/node } +"#; + let (config, tunnel_engine, ctx) = forward_websocket_policy_parts( + data, + "gateway.example.test", + 80, + "/graphql", + "graphql_ws", + ); + assert!( + config.websocket_graphql_policy, + "operation rules should enable GraphQL-over-WebSocket inspection" + ); + + let (err, leaked) = forward_websocket_denied_after_upgrade( + config, + tunnel_engine, + ctx, + "/graphql", + r#"{"id":"1","type":"subscribe","payload":{"query":"query { admin }"}}"#, + ) + .await; + + assert!(err.to_string().contains("websocket GraphQL message denied")); + assert!( + leaked.is_empty(), + "denied forward-proxy GraphQL WebSocket operations must not reach upstream" + ); + } #[test] fn l7_route_selection_prefers_path_specific_graphql_endpoint() { @@ -3310,6 +3860,9 @@ mod tests { enforcement: crate::l7::EnforcementMode::Enforce, graphql_max_body_bytes: crate::l7::graphql::DEFAULT_MAX_BODY_BYTES, allow_encoded_slash: false, + websocket_credential_rewrite: false, + request_body_credential_rewrite: false, + websocket_graphql_policy: false, }, }, L7ConfigSnapshot { @@ -3320,6 +3873,9 @@ mod tests { enforcement: crate::l7::EnforcementMode::Enforce, graphql_max_body_bytes: crate::l7::graphql::DEFAULT_MAX_BODY_BYTES, allow_encoded_slash: false, + websocket_credential_rewrite: false, + request_body_credential_rewrite: false, + websocket_graphql_policy: false, }, }, ]; @@ -4246,7 +4802,8 @@ mod tests { fn test_rewrite_get_request() { let raw = b"GET http://10.0.0.1:8000/api HTTP/1.1\r\nHost: 10.0.0.1:8000\r\nAccept: */*\r\n\r\n"; - let result = rewrite_forward_request(raw, raw.len(), "/api", None).expect("should succeed"); + let result = + rewrite_forward_request(raw, raw.len(), "/api", None, false).expect("should succeed"); let result_str = String::from_utf8_lossy(&result); assert!(result_str.starts_with("GET /api HTTP/1.1\r\n")); assert!(result_str.contains("Host: 10.0.0.1:8000")); @@ -4257,7 +4814,8 @@ mod tests { #[test] fn test_rewrite_strips_proxy_headers() { let raw = b"GET http://host/p HTTP/1.1\r\nHost: host\r\nProxy-Authorization: Basic abc\r\nProxy-Connection: keep-alive\r\nAccept: */*\r\n\r\n"; - let result = rewrite_forward_request(raw, raw.len(), "/p", None).expect("should succeed"); + let result = + rewrite_forward_request(raw, raw.len(), "/p", None, false).expect("should succeed"); let result_str = String::from_utf8_lossy(&result); assert!( !result_str @@ -4271,7 +4829,8 @@ mod tests { #[test] fn test_rewrite_replaces_connection_header() { let raw = b"GET http://host/p HTTP/1.1\r\nHost: host\r\nConnection: keep-alive\r\n\r\n"; - let result = rewrite_forward_request(raw, raw.len(), "/p", None).expect("should succeed"); + let result = + rewrite_forward_request(raw, raw.len(), "/p", None, false).expect("should succeed"); let result_str = String::from_utf8_lossy(&result); assert!(result_str.contains("Connection: close")); assert!(!result_str.contains("keep-alive")); @@ -4280,7 +4839,8 @@ mod tests { #[test] fn test_rewrite_preserves_body_overflow() { let raw = b"POST http://host/api HTTP/1.1\r\nHost: host\r\nContent-Length: 13\r\n\r\n{\"key\":\"val\"}"; - let result = rewrite_forward_request(raw, raw.len(), "/api", None).expect("should succeed"); + let result = + rewrite_forward_request(raw, raw.len(), "/api", None, false).expect("should succeed"); let result_str = String::from_utf8_lossy(&result); assert!(result_str.contains("{\"key\":\"val\"}")); assert!(result_str.contains("POST /api HTTP/1.1")); @@ -4289,7 +4849,8 @@ mod tests { #[test] fn test_rewrite_preserves_existing_via() { let raw = b"GET http://host/p HTTP/1.1\r\nHost: host\r\nVia: 1.0 upstream\r\n\r\n"; - let result = rewrite_forward_request(raw, raw.len(), "/p", None).expect("should succeed"); + let result = + rewrite_forward_request(raw, raw.len(), "/p", None, false).expect("should succeed"); let result_str = String::from_utf8_lossy(&result); assert!(result_str.contains("Via: 1.0 upstream")); // Should not add a second Via header @@ -4312,7 +4873,7 @@ mod tests { .expect("canonicalization should succeed for the attack payload"); assert_eq!(canon.path, "/secret"); - let rewritten = rewrite_forward_request(raw, raw.len(), &canon.path, None) + let rewritten = rewrite_forward_request(raw, raw.len(), &canon.path, None, false) .expect("rewrite_forward_request should succeed"); let rewritten_str = String::from_utf8_lossy(&rewritten); assert!( @@ -4338,7 +4899,7 @@ mod tests { _ => canon.path, }; - let rewritten = rewrite_forward_request(raw, raw.len(), &upstream_target, None) + let rewritten = rewrite_forward_request(raw, raw.len(), &upstream_target, None, false) .expect("rewrite_forward_request should succeed"); let rewritten_str = String::from_utf8_lossy(&rewritten); assert!( @@ -4357,13 +4918,169 @@ mod tests { .collect(), ); let raw = b"GET http://host/p HTTP/1.1\r\nHost: host\r\nAuthorization: Bearer openshell:resolve:env:ANTHROPIC_API_KEY\r\n\r\n"; - let result = rewrite_forward_request(raw, raw.len(), "/p", resolver.as_ref()) + let result = rewrite_forward_request(raw, raw.len(), "/p", resolver.as_ref(), false) .expect("should succeed"); let result_str = String::from_utf8_lossy(&result); assert!(result_str.contains("Authorization: Bearer sk-test")); assert!(!result_str.contains("openshell:resolve:env:ANTHROPIC_API_KEY")); } + #[tokio::test] + async fn forward_relay_rewrites_urlencoded_body_alias_from_initial_read() { + let (_, resolver) = SecretResolver::from_provider_env( + [("API_TOKEN".to_string(), "provider-real-token".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let alias = "provider-OPENSHELL-RESOLVE-ENV-API_TOKEN"; + let body = format!("token={alias}&channel=C123"); + let raw = format!( + "POST http://api.example.com/api/messages HTTP/1.1\r\n\ + Host: api.example.com\r\n\ + Authorization: Bearer {alias}\r\n\ + Content-Type: application/x-www-form-urlencoded\r\n\ + Content-Length: {}\r\n\r\n{}", + body.len(), + body + ); + + let forwarded = relay_forward_request_and_capture( + "POST", + "/api/messages", + raw.as_bytes(), + Some(&resolver), + true, + ) + .await + .expect("forward relay should rewrite credentials"); + + let expected_body = "token=provider-real-token&channel=C123"; + assert!(forwarded.starts_with("POST /api/messages HTTP/1.1\r\n")); + assert!(forwarded.contains("Authorization: Bearer provider-real-token\r\n")); + assert!(forwarded.contains(&format!("Content-Length: {}\r\n", expected_body.len()))); + assert!(forwarded.ends_with(expected_body)); + assert!(!forwarded.contains("OPENSHELL-RESOLVE-ENV")); + } + + #[tokio::test] + async fn forward_relay_rewrites_urlencoded_canonical_body_from_initial_read() { + let (_, resolver) = SecretResolver::from_provider_env( + [("API_TOKEN".to_string(), "provider-real-token".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let alias = "provider-OPENSHELL-RESOLVE-ENV-API_TOKEN"; + let body = "token=openshell%3Aresolve%3Aenv%3AAPI_TOKEN&channel=C123"; + let raw = format!( + "POST http://api.example.com/api/messages HTTP/1.1\r\n\ + Host: api.example.com\r\n\ + Authorization: Bearer {alias}\r\n\ + Content-Type: application/x-www-form-urlencoded\r\n\ + Content-Length: {}\r\n\r\n{}", + body.len(), + body + ); + + let forwarded = relay_forward_request_and_capture( + "POST", + "/api/messages", + raw.as_bytes(), + Some(&resolver), + true, + ) + .await + .expect("forward relay should rewrite credentials"); + + let expected_body = "token=provider-real-token&channel=C123"; + assert!(forwarded.contains("Authorization: Bearer provider-real-token\r\n")); + assert!(forwarded.contains(&format!("Content-Length: {}\r\n", expected_body.len()))); + assert!(forwarded.ends_with(expected_body)); + assert!(!forwarded.contains("openshell%3Aresolve%3Aenv%3AAPI_TOKEN")); + assert!(!forwarded.contains("openshell:resolve:env:API_TOKEN")); + } + + #[tokio::test] + async fn forward_relay_unresolved_body_placeholder_fails_before_upstream_write() { + let (_, resolver) = SecretResolver::from_provider_env( + [("API_TOKEN".to_string(), "provider-real-token".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let alias = "provider-OPENSHELL-RESOLVE-ENV-API_TOKEN"; + let body = "token=provider-OPENSHELL-RESOLVE-ENV-MISSING_TOKEN"; + let raw = format!( + "POST http://api.example.com/api/messages HTTP/1.1\r\n\ + Host: api.example.com\r\n\ + Authorization: Bearer {alias}\r\n\ + Content-Type: application/x-www-form-urlencoded\r\n\ + Content-Length: {}\r\n\r\n{}", + body.len(), + body + ); + let guard = forward_test_guard(); + let rewritten = rewrite_forward_request( + raw.as_bytes(), + raw.len(), + "/api/messages", + Some(&resolver), + true, + ) + .expect("header rewrite should defer body overflow to body rewriter"); + let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); + let (mut _app_side, mut proxy_to_client) = tokio::io::duplex(8192); + + let err = relay_rewritten_forward_request( + "POST", + "/api/messages", + rewritten, + &mut proxy_to_client, + &mut proxy_to_upstream, + ForwardRelayOptions { + generation_guard: &guard, + websocket_extensions: crate::l7::rest::WebSocketExtensionMode::Preserve, + secret_resolver: Some(&resolver), + request_body_credential_rewrite: true, + }, + ) + .await + .expect_err("unresolved body placeholder should fail closed"); + + assert!(!err.to_string().contains("provider-real-token")); + assert!(!err.to_string().contains("MISSING_TOKEN")); + drop(proxy_to_upstream); + let mut forwarded = Vec::new(); + upstream_side.read_to_end(&mut forwarded).await.unwrap(); + assert!( + forwarded.is_empty(), + "failed forward body rewrite must not reach upstream" + ); + } + + #[test] + fn test_forward_rewrite_preserves_websocket_upgrade_connection_header() { + let raw = "GET http://gateway.example.test/ws HTTP/1.1\r\n\ + Host: gateway.example.test\r\n\ + Upgrade: websocket\r\n\ + Connection: keep-alive, Upgrade\r\n\ + Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\ + Sec-WebSocket-Extensions: permessage-deflate; client_no_context_takeover\r\n\ + Sec-WebSocket-Version: 13\r\n\r\n"; + + let result = rewrite_forward_request(raw.as_bytes(), raw.len(), "/ws", None, false) + .expect("websocket forward rewrite should succeed"); + let result_str = String::from_utf8_lossy(&result); + + assert!(result_str.starts_with("GET /ws HTTP/1.1\r\n")); + assert!(result_str.contains("Connection: keep-alive, Upgrade\r\n")); + assert!( + !result_str.contains("Connection: close\r\n"), + "websocket forward proxy must not strip the upgrade token" + ); + } + #[tokio::test] async fn test_forward_relay_guard_blocks_stale_generation_before_upstream_write() { let policy = include_str!("../data/sandbox-policy.rego"); @@ -4375,8 +5092,8 @@ mod tests { engine.reload(policy, policy_data).unwrap(); let raw = b"GET http://host/api HTTP/1.1\r\nHost: host\r\n\r\n"; - let rewritten = - rewrite_forward_request(raw, raw.len(), "/api", None).expect("rewrite should succeed"); + let rewritten = rewrite_forward_request(raw, raw.len(), "/api", None, false) + .expect("rewrite should succeed"); let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); let (mut _app_side, mut proxy_to_client) = tokio::io::duplex(8192); @@ -4386,7 +5103,12 @@ mod tests { rewritten, &mut proxy_to_client, &mut proxy_to_upstream, - &guard, + ForwardRelayOptions { + generation_guard: &guard, + websocket_extensions: crate::l7::rest::WebSocketExtensionMode::Preserve, + secret_resolver: None, + request_body_credential_rewrite: false, + }, ) .await; assert!( @@ -4413,8 +5135,8 @@ mod tests { .unwrap(); let raw = b"POST http://host/api HTTP/1.1\r\nHost: host\r\nContent-Length: 4\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\n"; - let rewritten = - rewrite_forward_request(raw, raw.len(), "/api", None).expect("rewrite should succeed"); + let rewritten = rewrite_forward_request(raw, raw.len(), "/api", None, false) + .expect("rewrite should succeed"); let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); let (mut _app_side, mut proxy_to_client) = tokio::io::duplex(8192); @@ -4424,7 +5146,12 @@ mod tests { rewritten, &mut proxy_to_client, &mut proxy_to_upstream, - &guard, + ForwardRelayOptions { + generation_guard: &guard, + websocket_extensions: crate::l7::rest::WebSocketExtensionMode::Preserve, + secret_resolver: None, + request_body_credential_rewrite: false, + }, ) .await; assert!(result.is_err(), "forward relay must reject CL/TE ambiguity"); diff --git a/crates/openshell-sandbox/src/secrets.rs b/crates/openshell-sandbox/src/secrets.rs index d645e1482..6dbd34dcb 100644 --- a/crates/openshell-sandbox/src/secrets.rs +++ b/crates/openshell-sandbox/src/secrets.rs @@ -6,9 +6,11 @@ use std::collections::HashMap; use std::fmt; const PLACEHOLDER_PREFIX: &str = "openshell:resolve:env:"; +const PROVIDER_ALIAS_MARKER: &str = "OPENSHELL-RESOLVE-ENV-"; /// Public access to the placeholder prefix for fail-closed scanning in other modules. pub const PLACEHOLDER_PREFIX_PUBLIC: &str = PLACEHOLDER_PREFIX; +pub const PROVIDER_ALIAS_MARKER_PUBLIC: &str = PROVIDER_ALIAS_MARKER; /// Characters that are valid in an env var key name (used to extract /// placeholder boundaries within concatenated strings like path segments). @@ -16,6 +18,22 @@ fn is_env_key_char(b: u8) -> bool { b.is_ascii_alphanumeric() || b == b'_' } +fn is_alias_token_char(b: u8) -> bool { + b.is_ascii_alphanumeric() || matches!(b, b'_' | b'-' | b'.' | b'~') +} + +fn contains_raw_reserved_marker(value: &str) -> bool { + value.contains(PLACEHOLDER_PREFIX) || value.contains(PROVIDER_ALIAS_MARKER) +} + +pub fn contains_reserved_credential_marker(value: &str) -> bool { + if contains_raw_reserved_marker(value) { + return true; + } + let decoded = percent_decode(value); + contains_raw_reserved_marker(&decoded) +} + // --------------------------------------------------------------------------- // Error and result types // --------------------------------------------------------------------------- @@ -31,7 +49,7 @@ impl fmt::Display for UnresolvedPlaceholderError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, - "unresolved credential placeholder in {}: detected openshell:resolve:env:* token that could not be resolved", + "unresolved credential placeholder in {}: detected reserved credential token that could not be resolved", self.location ) } @@ -80,6 +98,38 @@ impl SecretResolver { pub(crate) fn from_provider_env_for_revision( provider_env: HashMap, revision: u64, + ) -> (HashMap, Option) { + Self::from_provider_env_for_revision_with_current_aliases(provider_env, revision, false) + } + + pub(crate) fn from_provider_env_for_current_revision( + provider_env: HashMap, + revision: u64, + ) -> (HashMap, Option, Option) { + if revision == 0 { + let (child_env, current_resolver) = + Self::from_provider_env_for_revision_with_current_aliases(provider_env, 0, true); + return (child_env, None, current_resolver); + } + let provider_env_for_current = provider_env.clone(); + let (child_env, revision_resolver) = + Self::from_provider_env_for_revision_with_current_aliases( + provider_env, + revision, + false, + ); + let (_, current_resolver) = Self::from_provider_env_for_revision_with_current_aliases( + provider_env_for_current, + revision, + true, + ); + (child_env, revision_resolver, current_resolver) + } + + fn from_provider_env_for_revision_with_current_aliases( + provider_env: HashMap, + revision: u64, + include_current_aliases: bool, ) -> (HashMap, Option) { if provider_env.is_empty() { return (HashMap::new(), None); @@ -90,8 +140,11 @@ impl SecretResolver { for (key, value) in provider_env { let placeholder = placeholder_for_env_key_for_revision(&key, revision); - child_env.insert(key, placeholder.clone()); - by_placeholder.insert(placeholder, value); + child_env.insert(key.clone(), placeholder.clone()); + by_placeholder.insert(placeholder, value.clone()); + if include_current_aliases && revision != 0 { + by_placeholder.insert(placeholder_for_env_key(&key), value.clone()); + } } (child_env, Some(Self { by_placeholder })) @@ -114,7 +167,13 @@ impl SecretResolver { /// Returns `None` if the placeholder is unknown or the resolved value /// contains prohibited control characters (CRLF, null byte). pub(crate) fn resolve_placeholder(&self, value: &str) -> Option<&str> { - let secret = self.by_placeholder.get(value).map(String::as_str)?; + let secret = if let Some(secret) = self.by_placeholder.get(value) { + secret.as_str() + } else { + let key = alias_env_key(value)?; + let canonical = placeholder_for_env_key(key); + self.by_placeholder.get(&canonical).map(String::as_str)? + }; match validate_resolved_secret(secret) { Ok(s) => Some(s), Err(reason) => { @@ -128,10 +187,13 @@ impl SecretResolver { } } - pub(crate) fn rewrite_header_value(&self, value: &str) -> Option { + pub(crate) fn rewrite_header_value( + &self, + value: &str, + ) -> Result, UnresolvedPlaceholderError> { // Direct placeholder match: `x-api-key: openshell:resolve:env:KEY` if let Some(secret) = self.resolve_placeholder(value.trim()) { - return Some(secret.to_string()); + return Ok(Some(secret.to_string())); } let trimmed = value.trim(); @@ -142,56 +204,228 @@ impl SecretResolver { .strip_prefix("Basic ") .or_else(|| trimmed.strip_prefix("basic ")) .map(str::trim) - && let Some(rewritten) = self.rewrite_basic_auth_token(encoded) + && let Some(rewritten) = self.rewrite_basic_auth_token(encoded)? { - return Some(format!("Basic {rewritten}")); + return Ok(Some(format!("Basic {rewritten}"))); } // Prefixed placeholder: `Bearer openshell:resolve:env:KEY` - let split_at = trimmed.find(char::is_whitespace)?; + let Some(split_at) = trimmed.find(char::is_whitespace) else { + if contains_reserved_credential_marker(trimmed) { + return Err(UnresolvedPlaceholderError { location: "header" }); + } + return Ok(None); + }; let prefix = &trimmed[..split_at]; let candidate = trimmed[split_at..].trim(); - let secret = self.resolve_placeholder(candidate)?; - Some(format!("{prefix} {secret}")) + if let Some(secret) = self.resolve_placeholder(candidate) { + return Ok(Some(format!("{prefix} {secret}"))); + } + + if contains_reserved_credential_marker(candidate) { + return Err(UnresolvedPlaceholderError { location: "header" }); + } + + Ok(None) + } + + pub(crate) fn rewrite_text_placeholders( + &self, + text: &mut String, + location: &'static str, + ) -> Result { + if !contains_raw_reserved_marker(text) { + return Ok(0); + } + + let mut rewritten = String::with_capacity(text.len()); + let mut pos = 0; + let mut replacements = 0; + + while pos < text.len() { + let next_canonical = text[pos..].find(PLACEHOLDER_PREFIX).map(|p| pos + p); + let next_alias = text[pos..].find(PROVIDER_ALIAS_MARKER).map(|marker_pos| { + let marker_abs = pos + marker_pos; + alias_start_for_marker(text, marker_abs) + }); + let Some(abs_start) = [next_canonical, next_alias].into_iter().flatten().min() else { + rewritten.push_str(&text[pos..]); + break; + }; + + rewritten.push_str(&text[pos..abs_start]); + + if text[abs_start..].starts_with(PLACEHOLDER_PREFIX) { + let Some((token_end, token)) = self.credential_token_at(text, abs_start) else { + return Err(UnresolvedPlaceholderError { location }); + }; + let Some(secret) = self.resolve_placeholder(token) else { + return Err(UnresolvedPlaceholderError { location }); + }; + rewritten.push_str(secret); + replacements += 1; + pos = token_end; + continue; + } + + if let Some((token_end, token)) = alias_token_at(text, abs_start) { + let Some(secret) = self.resolve_placeholder(token) else { + return Err(UnresolvedPlaceholderError { location }); + }; + rewritten.push_str(secret); + replacements += 1; + pos = token_end; + continue; + } + + return Err(UnresolvedPlaceholderError { location }); + } + + if contains_raw_reserved_marker(&rewritten) { + return Err(UnresolvedPlaceholderError { location }); + } + + *text = rewritten; + Ok(replacements) + } + + /// Rewrite credential placeholders inside a WebSocket text message. + /// + /// The message is mutated only after all placeholders resolve + /// successfully. The return value is the number of replacements; callers + /// must not log the rewritten text. + pub(crate) fn rewrite_websocket_text_placeholders( + &self, + text: &mut String, + ) -> Result { + self.rewrite_text_placeholders(text, "websocket") + } + + fn credential_token_at<'a>( + &'a self, + text: &'a str, + abs_start: usize, + ) -> Option<(usize, &'a str)> { + self.longest_known_token_match(text, abs_start) + .or_else(|| canonical_token_at(text, abs_start)) + .or_else(|| alias_token_at(text, abs_start)) + } + + fn longest_known_token_match<'a>( + &'a self, + text: &str, + abs_start: usize, + ) -> Option<(usize, &'a str)> { + let suffix = &text[abs_start..]; + self.by_placeholder + .keys() + .filter_map(|placeholder| { + if !suffix.starts_with(placeholder) { + return None; + } + let key_end = abs_start + placeholder.len(); + let boundary_ok = token_boundary_ok(text, abs_start, key_end, placeholder); + boundary_ok.then_some((key_end, placeholder.as_str())) + }) + .max_by_key(|(_, placeholder)| placeholder.len()) } /// Decode a Base64-encoded Basic auth token, resolve any placeholders in /// the decoded `username:password` string, and re-encode. /// /// Returns `None` if decoding fails or no placeholders are found. - fn rewrite_basic_auth_token(&self, encoded: &str) -> Option { + fn rewrite_basic_auth_token( + &self, + encoded: &str, + ) -> Result, UnresolvedPlaceholderError> { let b64 = base64::engine::general_purpose::STANDARD; - let decoded_bytes = b64.decode(encoded.trim()).ok()?; - let decoded = std::str::from_utf8(&decoded_bytes).ok()?; - - // Check if the decoded string contains any placeholder - if !decoded.contains(PLACEHOLDER_PREFIX) { - return None; + let Some(decoded_bytes) = b64.decode(encoded.trim()).ok() else { + return Ok(None); + }; + let Some(decoded) = std::str::from_utf8(&decoded_bytes).ok() else { + return Ok(None); + }; + + if !contains_raw_reserved_marker(decoded) { + return Ok(None); } - // Rewrite all placeholder occurrences in the decoded string let mut rewritten = decoded.to_string(); - for (placeholder, secret) in &self.by_placeholder { - if rewritten.contains(placeholder.as_str()) { - // Validate the resolved secret for control characters - if validate_resolved_secret(secret).is_err() { - tracing::warn!( - location = "basic_auth", - "credential resolution rejected: resolved value contains prohibited characters" - ); - return None; - } - rewritten = rewritten.replace(placeholder.as_str(), secret); - } - } + let replacements = self.rewrite_text_placeholders(&mut rewritten, "header")?; - // Only return if we actually changed something - if rewritten == decoded { - return None; + if replacements == 0 { + return Ok(None); } - Some(b64.encode(rewritten.as_bytes())) + Ok(Some(b64.encode(rewritten.as_bytes()))) + } +} + +fn alias_start_for_marker(text: &str, marker_abs: usize) -> usize { + let mut start = marker_abs; + let bytes = text.as_bytes(); + while start > 0 && is_alias_token_char(bytes[start - 1]) { + start -= 1; } + start +} + +fn canonical_token_at(text: &str, abs_start: usize) -> Option<(usize, &str)> { + if !text[abs_start..].starts_with(PLACEHOLDER_PREFIX) { + return None; + } + let key_start = abs_start + PLACEHOLDER_PREFIX.len(); + let key_end = text[key_start..] + .bytes() + .position(|b| !is_env_key_char(b)) + .map_or(text.len(), |p| key_start + p); + (key_end > key_start).then_some((key_end, &text[abs_start..key_end])) +} + +fn alias_token_at(text: &str, abs_start: usize) -> Option<(usize, &str)> { + let suffix = &text[abs_start..]; + let marker_rel = suffix.find(PROVIDER_ALIAS_MARKER)?; + if marker_rel == 0 { + return None; + } + let key_start = abs_start + marker_rel + PROVIDER_ALIAS_MARKER.len(); + let key_end = text[key_start..] + .bytes() + .position(|b| !is_env_key_char(b)) + .map_or(text.len(), |p| key_start + p); + if key_end == key_start { + return None; + } + let before_ok = abs_start == 0 || !is_alias_token_char(text.as_bytes()[abs_start - 1]); + let after_ok = key_end == text.len() || !is_alias_token_char(text.as_bytes()[key_end]); + (before_ok && after_ok).then_some((key_end, &text[abs_start..key_end])) +} + +fn alias_env_key(token: &str) -> Option<&str> { + let marker_start = token.find(PROVIDER_ALIAS_MARKER)?; + if marker_start == 0 { + return None; + } + if !token[..marker_start].bytes().all(is_alias_token_char) { + return None; + } + let key_start = marker_start + PROVIDER_ALIAS_MARKER.len(); + let key_end = token[key_start..] + .bytes() + .position(|b| !is_env_key_char(b)) + .map_or(token.len(), |p| key_start + p); + (key_end == token.len() && key_end > key_start).then_some(&token[key_start..key_end]) +} + +fn token_boundary_ok(text: &str, abs_start: usize, token_end: usize, token: &str) -> bool { + if token.starts_with(PLACEHOLDER_PREFIX) { + return token_end == text.len() + || !is_env_key_char(text.as_bytes()[token_end]) + || text[token_end..].starts_with(PLACEHOLDER_PREFIX); + } + let before_ok = abs_start == 0 || !is_alias_token_char(text.as_bytes()[abs_start - 1]); + let after_ok = token_end == text.len() || !is_alias_token_char(text.as_bytes()[token_end]); + before_ok && after_ok } pub fn placeholder_for_env_key(key: &str) -> String { @@ -387,8 +621,9 @@ fn rewrite_request_line( return unchanged(); }; - // Only rewrite if the URI contains a placeholder - if !uri.contains(PLACEHOLDER_PREFIX) { + // Only rewrite if the URI contains a placeholder or a provider-shaped + // credential alias, including percent-encoded canonical placeholders. + if !contains_reserved_credential_marker(uri) { return unchanged(); } @@ -444,10 +679,6 @@ fn rewrite_uri_path( path: &str, resolver: &SecretResolver, ) -> Result, UnresolvedPlaceholderError> { - if !path.contains(PLACEHOLDER_PREFIX) { - return Ok(None); - } - let segments: Vec<&str> = path.split('/').collect(); let mut resolved_segments = Vec::with_capacity(segments.len()); let mut redacted_segments = Vec::with_capacity(segments.len()); @@ -455,7 +686,7 @@ fn rewrite_uri_path( for segment in &segments { let decoded = percent_decode(segment); - if !decoded.contains(PLACEHOLDER_PREFIX) { + if !contains_raw_reserved_marker(&decoded) { resolved_segments.push(segment.to_string()); redacted_segments.push(segment.to_string()); continue; @@ -495,28 +726,23 @@ fn rewrite_path_segment( let bytes = segment.as_bytes(); while pos < bytes.len() { - if let Some(start) = segment[pos..].find(PLACEHOLDER_PREFIX) { - let abs_start = pos + start; + let next_canonical = segment[pos..].find(PLACEHOLDER_PREFIX).map(|p| pos + p); + let next_alias = segment[pos..] + .find(PROVIDER_ALIAS_MARKER) + .map(|marker_pos| { + let marker_abs = pos + marker_pos; + alias_start_for_marker(segment, marker_abs) + }); + if let Some(abs_start) = [next_canonical, next_alias].into_iter().flatten().min() { // Copy literal prefix before the placeholder resolved.push_str(&segment[pos..abs_start]); redacted.push_str(&segment[pos..abs_start]); - // Extract the key name using the env var grammar: [A-Za-z_][A-Za-z0-9_]* - let key_start = abs_start + PLACEHOLDER_PREFIX.len(); - let key_end = segment[key_start..] - .bytes() - .position(|b| !is_env_key_char(b)) - .map_or(segment.len(), |p| key_start + p); - - if key_end == key_start { - // Empty key — not a valid placeholder, copy literally - resolved.push_str(&segment[abs_start..abs_start + PLACEHOLDER_PREFIX.len()]); - redacted.push_str(&segment[abs_start..abs_start + PLACEHOLDER_PREFIX.len()]); - pos = abs_start + PLACEHOLDER_PREFIX.len(); - continue; - } - - let full_placeholder = &segment[abs_start..key_end]; + let Some((token_end, full_placeholder)) = canonical_token_at(segment, abs_start) + .or_else(|| alias_token_at(segment, abs_start)) + else { + return Err(UnresolvedPlaceholderError { location: "path" }); + }; if let Some(secret) = resolver.resolve_placeholder(full_placeholder) { validate_credential_for_path(secret).map_err(|reason| { tracing::warn!( @@ -531,7 +757,7 @@ fn rewrite_path_segment( } else { return Err(UnresolvedPlaceholderError { location: "path" }); } - pos = key_end; + pos = token_end; } else { // No more placeholders in remainder resolved.push_str(&segment[pos..]); @@ -550,7 +776,7 @@ fn rewrite_uri_query_params( query: &str, resolver: &SecretResolver, ) -> Result, UnresolvedPlaceholderError> { - if !query.contains(PLACEHOLDER_PREFIX) { + if !contains_reserved_credential_marker(query) { return Ok(None); } @@ -561,15 +787,18 @@ fn rewrite_uri_query_params( for param in query.split('&') { if let Some((key, value)) = param.split_once('=') { let decoded_value = percent_decode(value); - if let Some(secret) = resolver.resolve_placeholder(&decoded_value) { - resolved_params.push(format!("{key}={}", percent_encode_query(secret))); + if contains_raw_reserved_marker(&decoded_value) { + let mut rewritten = decoded_value.clone(); + let replacements = + resolver.rewrite_text_placeholders(&mut rewritten, "query_param")?; + if replacements == 0 || contains_raw_reserved_marker(&rewritten) { + return Err(UnresolvedPlaceholderError { + location: "query_param", + }); + } + resolved_params.push(format!("{key}={}", percent_encode_query(&rewritten))); redacted_params.push(format!("{key}=[CREDENTIAL]")); any_rewritten = true; - } else if decoded_value.contains(PLACEHOLDER_PREFIX) { - // Placeholder detected but not resolved - return Err(UnresolvedPlaceholderError { - location: "query_param", - }); } else { resolved_params.push(param.to_string()); redacted_params.push(param.to_string()); @@ -639,41 +868,42 @@ pub fn rewrite_http_header_block( break; } - output.extend_from_slice(rewrite_header_line(line, resolver).as_bytes()); + output.extend_from_slice(rewrite_header_line_checked(line, resolver)?.as_bytes()); output.extend_from_slice(b"\r\n"); } output.extend_from_slice(b"\r\n"); output.extend_from_slice(&raw[header_end..]); - // Fail-closed scan: check for any remaining unresolved placeholders - // in both raw form and percent-decoded form of the output header block. + // Fail-closed scan: check for any remaining unresolved placeholders or + // provider-shaped aliases in both raw and percent-decoded header bytes. let output_header = String::from_utf8_lossy(&output[..output.len().min(header_end + 256)]); - if output_header.contains(PLACEHOLDER_PREFIX) { + if contains_reserved_credential_marker(&output_header) { return Err(UnresolvedPlaceholderError { location: "header" }); } - // Also check percent-decoded form of the request line (F5 — encoded placeholder bypass) - let rewritten_rl = output_header.split("\r\n").next().unwrap_or(""); - let decoded_rl = percent_decode(rewritten_rl); - if decoded_rl.contains(PLACEHOLDER_PREFIX) { - return Err(UnresolvedPlaceholderError { location: "path" }); - } - Ok(RewriteResult { rewritten: output, redacted_target: rl_result.redacted_target, }) } +#[cfg_attr(not(test), allow(dead_code))] pub fn rewrite_header_line(line: &str, resolver: &SecretResolver) -> String { + rewrite_header_line_checked(line, resolver).unwrap_or_else(|_| line.to_string()) +} + +pub fn rewrite_header_line_checked( + line: &str, + resolver: &SecretResolver, +) -> Result { let Some((name, value)) = line.split_once(':') else { - return line.to_string(); + return Ok(line.to_string()); }; - resolver.rewrite_header_value(value.trim()).map_or_else( - || line.to_string(), - |rewritten| format!("{name}: {rewritten}"), + resolver.rewrite_header_value(value.trim())?.map_or_else( + || Ok(line.to_string()), + |rewritten| Ok(format!("{name}: {rewritten}")), ) } @@ -688,12 +918,7 @@ pub fn rewrite_target_for_eval( target: &str, resolver: &SecretResolver, ) -> Result { - if !target.contains(PLACEHOLDER_PREFIX) { - // Also check percent-decoded form - let decoded = percent_decode(target); - if decoded.contains(PLACEHOLDER_PREFIX) { - return Err(UnresolvedPlaceholderError { location: "path" }); - } + if !contains_reserved_credential_marker(target) { return Ok(RewriteTargetResult { resolved: target.to_string(), redacted: target.to_string(), @@ -800,6 +1025,50 @@ mod tests { ); } + #[test] + fn rewrites_provider_shaped_alias_header_values() { + let (_, resolver) = SecretResolver::from_provider_env( + [ + ("API_TOKEN".to_string(), "provider-real-token".to_string()), + ("CHAT_APP_TOKEN".to_string(), "app-real-token".to_string()), + ] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + + assert_eq!( + rewrite_header_line( + "Authorization: Bearer vendor-OPENSHELL-RESOLVE-ENV-API_TOKEN", + &resolver, + ), + "Authorization: Bearer provider-real-token" + ); + assert_eq!( + rewrite_header_line( + "x-app-token: token.v1-OPENSHELL-RESOLVE-ENV-CHAT_APP_TOKEN", + &resolver, + ), + "x-app-token: app-real-token" + ); + } + + #[test] + fn unresolved_provider_shaped_alias_fails_closed() { + let (_, resolver) = SecretResolver::from_provider_env( + [("API_TOKEN".to_string(), "provider-real-token".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let raw = b"GET / HTTP/1.1\r\nAuthorization: Bearer vendor-OPENSHELL-RESOLVE-ENV-UNKNOWN_TOKEN\r\n\r\n"; + + let err = rewrite_http_header_block(raw, Some(&resolver)) + .expect_err("unknown alias should fail closed"); + + assert_eq!(err.location, "header"); + } + #[test] fn rewrites_http_header_blocks_and_preserves_body() { let (_, resolver) = SecretResolver::from_provider_env( @@ -1410,6 +1679,29 @@ mod tests { ); } + #[test] + fn percent_encoded_canonical_placeholder_in_query_rewrites() { + let (_, resolver) = SecretResolver::from_provider_env( + [("API_TOKEN".to_string(), "provider-real-token".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let encoded = "openshell%3Aresolve%3Aenv%3AAPI_TOKEN"; + let raw = format!("GET /api?token={encoded} HTTP/1.1\r\nHost: x\r\n\r\n"); + + let result = + rewrite_http_header_block(raw.as_bytes(), Some(&resolver)).expect("should rewrite"); + let rewritten = String::from_utf8(result.rewritten).expect("utf8"); + + assert!(rewritten.starts_with("GET /api?token=provider-real-token HTTP/1.1")); + assert!(!rewritten.contains("openshell")); + assert_eq!( + result.redacted_target.as_deref(), + Some("/api?token=[CREDENTIAL]") + ); + } + #[test] fn all_resolved_succeeds() { let (child_env, resolver) = SecretResolver::from_provider_env( @@ -1444,6 +1736,129 @@ mod tests { assert_eq!(raw.as_slice(), result.rewritten.as_slice()); } + #[test] + fn rewrite_websocket_text_replaces_placeholders_and_returns_count() { + let (child_env, resolver) = SecretResolver::from_provider_env( + [ + ("DISCORD_BOT_TOKEN".to_string(), "real-token".to_string()), + ("APP_ID".to_string(), "app-123".to_string()), + ] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let token = child_env.get("DISCORD_BOT_TOKEN").unwrap(); + let app_id = child_env.get("APP_ID").unwrap(); + let mut payload = + format!(r#"{{"op":2,"d":{{"token":"{token}","properties":{{"app":"{app_id}"}}}}}}"#); + + let count = resolver + .rewrite_websocket_text_placeholders(&mut payload) + .expect("rewrite should succeed"); + + assert_eq!(count, 2); + assert!(payload.contains(r#""token":"real-token""#)); + assert!(payload.contains(r#""app":"app-123""#)); + assert!(!payload.contains(PLACEHOLDER_PREFIX)); + } + + #[test] + fn rewrite_websocket_text_replaces_provider_shaped_alias() { + let (_, resolver) = SecretResolver::from_provider_env( + [("APP_TOKEN".to_string(), "app-real-token".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let mut payload = r#"{"token":"provider-OPENSHELL-RESOLVE-ENV-APP_TOKEN"}"#.to_string(); + + let count = resolver + .rewrite_websocket_text_placeholders(&mut payload) + .expect("alias should rewrite"); + + assert_eq!(count, 1); + assert_eq!(payload, r#"{"token":"app-real-token"}"#); + } + + #[test] + fn rewrite_websocket_text_without_placeholder_is_unchanged() { + let (_, resolver) = SecretResolver::from_provider_env( + [("KEY".to_string(), "secret".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let mut payload = r#"{"op":1,"d":42}"#.to_string(); + + let count = resolver + .rewrite_websocket_text_placeholders(&mut payload) + .expect("rewrite should succeed"); + + assert_eq!(count, 0); + assert_eq!(payload, r#"{"op":1,"d":42}"#); + } + + #[test] + fn rewrite_websocket_text_unknown_placeholder_fails_closed_without_mutating() { + let (_, resolver) = SecretResolver::from_provider_env( + [("KEY".to_string(), "secret".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let original = r#"{"token":"openshell:resolve:env:UNKNOWN"}"#.to_string(); + let mut payload = original.clone(); + + let err = resolver + .rewrite_websocket_text_placeholders(&mut payload) + .expect_err("unknown placeholder should fail"); + + assert_eq!(err.location, "websocket"); + assert_eq!(payload, original); + } + + #[test] + fn rewrite_websocket_text_handles_repeated_adjacent_and_unicode_placeholders() { + let (child_env, resolver) = SecretResolver::from_provider_env( + [ + ("TOKEN".to_string(), "tok".to_string()), + ("APP".to_string(), "app".to_string()), + ] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let token = child_env.get("TOKEN").unwrap(); + let app = child_env.get("APP").unwrap(); + let mut payload = format!("prefix-☃-{token}{app}-{token}-suffix"); + + let count = resolver + .rewrite_websocket_text_placeholders(&mut payload) + .expect("rewrite should succeed"); + + assert_eq!(count, 3); + assert_eq!(payload, "prefix-☃-tokapp-tok-suffix"); + } + + #[test] + fn rewrite_websocket_text_placeholder_like_prefix_fails_without_mutating() { + let (_, resolver) = SecretResolver::from_provider_env( + [("KEY".to_string(), "secret".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let original = "openshell:resolve:env:-not-a-key".to_string(); + let mut payload = original.clone(); + + let err = resolver + .rewrite_websocket_text_placeholders(&mut payload) + .expect_err("placeholder-like prefix should fail closed"); + + assert_eq!(err.location, "websocket"); + assert_eq!(payload, original); + } + // === Redaction tests === #[test] diff --git a/crates/openshell-sandbox/tests/websocket_upgrade.rs b/crates/openshell-sandbox/tests/websocket_upgrade.rs index e4cd232ce..b35076a9a 100644 --- a/crates/openshell-sandbox/tests/websocket_upgrade.rs +++ b/crates/openshell-sandbox/tests/websocket_upgrade.rs @@ -124,7 +124,7 @@ async fn websocket_upgrade_through_l7_relay_exchanges_message() { .expect("relay should succeed"); match outcome { - RelayOutcome::Upgraded { overflow } => { + RelayOutcome::Upgraded { overflow, .. } => { // This is what handle_upgrade() does in relay.rs if !overflow.is_empty() { client_proxy.write_all(&overflow).await.unwrap(); diff --git a/crates/openshell-server/src/grpc/policy.rs b/crates/openshell-server/src/grpc/policy.rs index d5a47bcba..885dbc9ad 100644 --- a/crates/openshell-server/src/grpc/policy.rs +++ b/crates/openshell-server/src/grpc/policy.rs @@ -216,6 +216,12 @@ fn summarize_endpoint(endpoint: &NetworkEndpoint) -> String { if !endpoint.tls.is_empty() { parts.push(format!("tls={}", endpoint.tls)); } + if endpoint.websocket_credential_rewrite { + parts.push("websocket_credential_rewrite=true".to_string()); + } + if endpoint.request_body_credential_rewrite { + parts.push("request_body_credential_rewrite=true".to_string()); + } if !endpoint.allowed_ips.is_empty() { parts.push(format!("allowed_ips={}", endpoint.allowed_ips.len())); } @@ -4318,6 +4324,62 @@ mod tests { ); } + #[test] + fn summarize_cli_policy_merge_op_formats_websocket_credential_rewrite() { + let operation = PolicyMergeOp::AddRule { + rule_name: "realtime_api".to_string(), + rule: NetworkPolicyRule { + name: "realtime_api".to_string(), + endpoints: vec![NetworkEndpoint { + host: "realtime.example.com".to_string(), + port: 443, + protocol: "websocket".to_string(), + access: "read-write".to_string(), + enforcement: "enforce".to_string(), + websocket_credential_rewrite: true, + ..Default::default() + }], + binaries: vec![NetworkBinary { + path: "/usr/bin/node".to_string(), + ..Default::default() + }], + }, + }; + + assert_eq!( + summarize_cli_policy_merge_op(&operation), + "add-endpoint realtime_api endpoints=[realtime.example.com:443 protocol=websocket access=read-write enforcement=enforce websocket_credential_rewrite=true] binaries=[/usr/bin/node]" + ); + } + + #[test] + fn summarize_cli_policy_merge_op_formats_request_body_credential_rewrite() { + let operation = PolicyMergeOp::AddRule { + rule_name: "slack_api".to_string(), + rule: NetworkPolicyRule { + name: "slack_api".to_string(), + endpoints: vec![NetworkEndpoint { + host: "slack.com".to_string(), + port: 443, + protocol: "rest".to_string(), + access: "read-write".to_string(), + enforcement: "enforce".to_string(), + request_body_credential_rewrite: true, + ..Default::default() + }], + binaries: vec![NetworkBinary { + path: "/usr/bin/node".to_string(), + ..Default::default() + }], + }, + }; + + assert_eq!( + summarize_cli_policy_merge_op(&operation), + "add-endpoint slack_api endpoints=[slack.com:443 protocol=rest access=read-write enforcement=enforce request_body_credential_rewrite=true] binaries=[/usr/bin/node]" + ); + } + // ---- merge_chunk_into_policy ---- #[tokio::test] diff --git a/docs/reference/policy-schema.mdx b/docs/reference/policy-schema.mdx index a98e8087c..295f850df 100644 --- a/docs/reference/policy-schema.mdx +++ b/docs/reference/policy-schema.mdx @@ -155,7 +155,7 @@ Each endpoint defines a reachable destination and optional inspection rules. | `host` | string | Yes | Hostname or IP address. Supports wildcards: `*.example.com` matches any subdomain. | | `port` | integer | Yes | TCP port number. | | `path` | string | No | Optional HTTP path glob used to select between L7 endpoints that share the same host and port. Empty means all paths. Use this when REST and GraphQL live under the same host, such as `/repos/**` and `/graphql`. | -| `protocol` | string | No | Set to `rest` for HTTP method/path inspection or `graphql` for GraphQL operation inspection. Omit for TCP passthrough. | +| `protocol` | string | No | Set to `rest` for HTTP method/path inspection, `websocket` for RFC 6455 upgrade and client text-message inspection, or `graphql` for GraphQL-over-HTTP operation inspection. WebSocket endpoints can also use GraphQL operation rules for GraphQL-over-WebSocket traffic. Omit for TCP passthrough. | | `tls` | string | No | TLS handling mode. The proxy auto-detects TLS by peeking the first bytes of each connection and terminates it for inspected HTTPS traffic, so this field is optional in most cases. Set to `skip` to disable auto-detection for edge cases such as client-certificate mTLS or non-standard protocols. The values `terminate` and `passthrough` are deprecated and log a warning; they are still accepted for backward compatibility but have no effect on behavior. | | `enforcement` | string | No | `enforce` actively blocks disallowed requests. `audit` logs violations but allows traffic through. | | `access` | string | No | Access preset. One of `read-only`, `read-write`, or `full`. Mutually exclusive with `rules`. | @@ -163,19 +163,23 @@ Each endpoint defines a reachable destination and optional inspection rules. | `deny_rules` | list of deny rule objects | No | L7 deny rules that block specific requests even when allowed by `access` or `rules`. Deny rules take precedence over allow rules. | | `allowed_ips` | list of string | No | CIDR or IP allowlist for SSRF override. Entries overlapping loopback (`127.0.0.0/8`), link-local (`169.254.0.0/16`), or unspecified (`0.0.0.0`) are rejected at load time. | | `allow_encoded_slash` | bool | No | When `true`, L7 request parsing preserves `%2F` inside path segments instead of rejecting it. Use this for registries and APIs such as npm scoped packages (`/@scope%2Fname`). Defaults to `false`. | -| `persisted_queries` | string | No | GraphQL hash-only behavior. Default is `deny`; use `allow_registered` only with `graphql_persisted_queries`. | +| `websocket_credential_rewrite` | bool | No | When `true` on a `protocol: rest` or `protocol: websocket` endpoint, OpenShell rewrites credential placeholders in client-to-server WebSocket text messages after an allowed HTTP `101` upgrade. Binary frames are relayed but not rewritten. Defaults to `false`. | +| `request_body_credential_rewrite` | bool | No | When `true` on a `protocol: rest` endpoint, OpenShell rewrites credential placeholders in UTF-8 `application/json`, `application/x-www-form-urlencoded`, and `text/*` request bodies before forwarding upstream. The proxy buffers at most 256 KiB and updates `Content-Length` after rewriting. Defaults to `false`. | +| `persisted_queries` | string | No | GraphQL hash-only behavior for `protocol: graphql` and GraphQL-over-WebSocket operation policy. Default is `deny`; use `allow_registered` only with `graphql_persisted_queries`. | | `graphql_persisted_queries` | map | No | Trusted GraphQL persisted-query registry keyed by hash or saved-query ID. Values contain `operation_type`, optional `operation_name`, and optional root `fields`. | -| `graphql_max_body_bytes` | integer | No | Maximum GraphQL request body bytes buffered for inspection. Defaults to `65536`. | +| `graphql_max_body_bytes` | integer | No | Maximum GraphQL-over-HTTP request body bytes buffered for inspection. Defaults to `65536`. | + +Credential rewrite recognizes the canonical `openshell:resolve:env:KEY` placeholder form and whole-token provider-shaped aliases such as `provider-OPENSHELL-RESOLVE-ENV-API_TOKEN` when the referenced environment key exists in the configured provider credentials. #### Access Levels The `access` field accepts one of the following values: -| Value | REST expansion | GraphQL expansion | -|---|---|---| -| `full` | All methods and paths. | All operation types. | -| `read-only` | `GET`, `HEAD`, `OPTIONS`. | `query` operations. | -| `read-write` | `GET`, `HEAD`, `OPTIONS`, `POST`, `PUT`, `PATCH`. | `query` and `mutation` operations. | +| Value | REST expansion | WebSocket expansion | GraphQL expansion | +|---|---|---|---| +| `full` | All methods and paths. | WebSocket upgrade and all inspected client text-message paths. | All operation types. | +| `read-only` | `GET`, `HEAD`, `OPTIONS`. | WebSocket upgrade handshake only. | `query` operations. | +| `read-write` | `GET`, `HEAD`, `OPTIONS`, `POST`, `PUT`, `PATCH`. | WebSocket upgrade handshake and client text messages. | `query` and `mutation` operations. | #### Allow Rule Objects @@ -208,9 +212,31 @@ rules: any: ["v1.*", "v2.*"] ``` -##### GraphQL Allow Rule (`protocol: graphql`) +##### WebSocket Allow Rule (`protocol: websocket`) + +WebSocket allow rules match the RFC 6455 HTTP upgrade by path and match client-to-server text messages on the same upgraded connection with the synthetic `WEBSOCKET_TEXT` method. Binary frames are relayed but are not rewritten. + +| Field | Type | Required | Description | +|---|---|---|---| +| `method` | string | Yes | `GET` allows the upgrade handshake, `WEBSOCKET_TEXT` allows client text messages after upgrade, and `*` matches both inspected actions. | +| `path` | string | Yes | URL path pattern from the original upgrade request. Supports `*` and `**` glob syntax. | +| `query` | map | No | Query parameter matchers from the original upgrade request. Matcher syntax is the same as REST allow rules. | + +Example WebSocket allow rules: + +```yaml showLineNumbers={false} +rules: + - allow: + method: GET + path: /v1/realtime/** + - allow: + method: WEBSOCKET_TEXT + path: /v1/realtime/** +``` + +##### GraphQL Allow Rule (`protocol: graphql` or GraphQL-over-WebSocket) -GraphQL allow rules match parsed GraphQL operations by operation type, optional operation name, and optional root fields. +GraphQL allow rules match parsed GraphQL operations by operation type, optional operation name, and optional root fields. On `protocol: graphql`, they apply to GraphQL-over-HTTP `GET` and `POST` requests. On `protocol: websocket`, include a separate `GET` allow rule for the RFC 6455 upgrade, then use GraphQL allow rules for client operation messages using the `graphql-transport-ws` `subscribe` message type or the legacy `graphql-ws` `start` message type. | Field | Type | Required | Description | |---|---|---|---| @@ -231,6 +257,23 @@ rules: fields: [createIssue] ``` +Example GraphQL-over-WebSocket allow rules: + +```yaml showLineNumbers={false} +rules: + - allow: + method: GET + path: /graphql + - allow: + operation_type: subscription + fields: [messageAdded] + - allow: + operation_type: query + fields: [viewer] +``` + +Do not combine `method`, `path`, or `query` with `operation_type`, `operation_name`, or `fields` inside the same WebSocket rule. When a WebSocket endpoint has GraphQL operation policy, use GraphQL rules for client messages instead of a raw `WEBSOCKET_TEXT` allow rule. + #### Deny Rule Objects Blocks specific operations on endpoints that otherwise have broad access. Deny rules are evaluated after allow rules and take precedence: if a request matches any deny rule, it is blocked regardless of what the allow rules or access preset permit. @@ -263,9 +306,33 @@ endpoints: path: "/repos/*/rulesets" ``` -##### GraphQL Deny Rule (`protocol: graphql`) +##### WebSocket Deny Rule (`protocol: websocket`) + +WebSocket deny rules use the same field names as WebSocket allow rules, but they appear directly under each `deny_rules` entry instead of under an `allow` wrapper. + +| Field | Type | Required | Description | +|---|---|---|---| +| `method` | string | Yes | `GET` denies matching upgrade handshakes, `WEBSOCKET_TEXT` denies matching client text messages after upgrade, and `*` matches both inspected actions. | +| `path` | string | Yes | URL path pattern from the original upgrade request. Same glob syntax as allow rules. | +| `query` | map | No | Query parameter matchers from the original upgrade request. Same syntax as allow rule `query`. | + +Example WebSocket deny rules: + +```yaml showLineNumbers={false} +endpoints: + - host: realtime.example.com + port: 443 + protocol: websocket + enforcement: enforce + access: read-write + deny_rules: + - method: WEBSOCKET_TEXT + path: "/v1/admin/**" +``` + +##### GraphQL Deny Rule (`protocol: graphql` or GraphQL-over-WebSocket) -GraphQL deny rules use the same field names as GraphQL allow rules, but they appear directly under each `deny_rules` entry instead of under an `allow` wrapper. +GraphQL deny rules use the same field names as GraphQL allow rules, but they appear directly under each `deny_rules` entry instead of under an `allow` wrapper. On WebSocket GraphQL endpoints, they apply only to classified GraphQL operation messages; protocol lifecycle messages such as `connection_init`, `ping`, `pong`, and `complete` are allowed as WebSocket control-plane messages and are not payload-logged. | Field | Type | Required | Description | |---|---|---|---| diff --git a/docs/sandboxes/policies.mdx b/docs/sandboxes/policies.mdx index 4e0aa4357..fb0b04cfe 100644 --- a/docs/sandboxes/policies.mdx +++ b/docs/sandboxes/policies.mdx @@ -49,14 +49,14 @@ network_policies: Static sections are locked at sandbox creation. Changing them requires destroying and recreating the sandbox. Dynamic sections can be updated on a running sandbox with `openshell policy update` for incremental merges or `openshell policy set` for full replacement, and take effect without restarting. When a hot reload changes rules on an active HTTP L7 endpoint, existing keep-alive tunnels are closed before forwarding another parsed request. Credential-injection-only HTTP passthrough tunnels use the same reload boundary. Most HTTP clients reconnect automatically, and the next request is evaluated against the current policy. -Raw streams are connection-scoped and outside L7 live-reload guarantees. This includes `tls: skip`, non-HTTP TCP payloads, HTTP upgrades such as WebSocket, and long-lived response streams such as SSE. A reload applies to the next connection or next parsed HTTP request; it does not interrupt an already-forwarded raw stream. +Raw streams are connection-scoped and outside L7 live-reload guarantees. This includes `tls: skip`, non-HTTP TCP payloads, HTTP upgrades such as WebSocket, and long-lived response streams such as SSE. A reload applies to the next connection or next parsed HTTP request; it does not interrupt an already-forwarded raw stream. Use `protocol: websocket` when policy should stay attached to the RFC 6455 upgrade and client text messages after the allowed upgrade. Add `websocket_credential_rewrite: true` only when the relay should rewrite credential placeholders in client-to-server WebSocket text messages. Add `request_body_credential_rewrite: true` only on inspected REST endpoints that need OpenShell to rewrite placeholders in supported text request bodies. | Section | Type | Description | |---|---|---| | `filesystem_policy` | Static | Controls which directories the agent can access on disk. Paths are split into `read_only` and `read_write` lists. Any path not listed in either list is inaccessible. Set `include_workdir: true` to automatically add the agent's working directory to `read_write`. [Landlock LSM](https://docs.kernel.org/security/landlock.html) enforces these restrictions at the kernel level. | | `landlock` | Static | Configures Landlock LSM enforcement behavior. Set `compatibility` to `best_effort` (skip individual inaccessible paths while applying remaining rules) or `hard_requirement` (fail if any path is inaccessible or the required kernel ABI is unavailable). See the [Policy Schema Reference](/reference/policy-schema#landlock) for the full behavior table. | | `process` | Static | Sets the OS-level identity for the agent process. `run_as_user` and `run_as_group` default to `sandbox`. Root (`root` or `0`) is rejected. The agent also runs with seccomp filters that block dangerous system calls. | -| `network_policies` | Dynamic | Controls network access for ordinary outbound traffic from the sandbox. Each block has a name, a list of endpoints (host, port, protocol, and optional rules), and a list of binaries allowed to use those endpoints.
Every outbound connection except `https://inference.local` goes through the proxy, which queries the [policy engine](/about/how-it-works#core-components) with the destination and calling binary. A connection is allowed only when both match an entry in the same policy block.
For endpoints with `protocol: rest`, the proxy auto-detects TLS and terminates it so each HTTP request can be checked against that endpoint's `rules` (method and path).
Endpoints without `protocol` allow the TCP stream through without inspecting payloads.
If no endpoint matches, the connection is denied. Configure managed inference separately through [Inference Routing](/sandboxes/inference-routing). | +| `network_policies` | Dynamic | Controls network access for ordinary outbound traffic from the sandbox. Each block has a name, a list of endpoints (host, port, protocol, and optional rules), and a list of binaries allowed to use those endpoints.
Every outbound connection except `https://inference.local` goes through the proxy, which queries the [policy engine](/about/how-it-works#core-components) with the destination and calling binary. A connection is allowed only when both match an entry in the same policy block.
For endpoints with `protocol: rest`, the proxy auto-detects TLS and terminates it so each HTTP request can be checked against that endpoint's `rules` (method and path). For endpoints with `protocol: websocket`, the proxy validates the RFC 6455 upgrade and evaluates `GET` rules for the handshake plus either `WEBSOCKET_TEXT` rules for raw client text messages or GraphQL operation rules for GraphQL-over-WebSocket messages. Set `websocket_credential_rewrite: true` only when a WebSocket or REST compatibility endpoint must keep placeholder credentials in sandbox-owned text frames and resolve them at the OpenShell relay boundary.
Endpoints without `protocol` allow the TCP stream through without inspecting payloads.
If no endpoint matches, the connection is denied. Configure managed inference separately through [Inference Routing](/sandboxes/inference-routing). | ## Baseline Filesystem Paths @@ -123,7 +123,7 @@ The following steps outline the hot-reload policy update workflow. openshell logs --tail --source sandbox ``` -3. For additive network changes, use `openshell policy update`. This is the fastest path for adding endpoints, binaries, or REST allow/deny rules without replacing the full policy. The full option and format reference is in [Incremental Policy Updates](#incremental-policy-updates). +3. For additive network changes, use `openshell policy update`. This is the fastest path for adding endpoints, binaries, or REST and WebSocket allow/deny rules without replacing the full policy. The full option and format reference is in [Incremental Policy Updates](#incremental-policy-updates). ```shell openshell policy update \ @@ -136,7 +136,7 @@ The following steps outline the hot-reload policy update workflow. --wait ``` - `--add-allow` and `--add-deny` currently target existing `protocol: rest` endpoints only. If you pass multiple update flags in one command, OpenShell applies them as one atomic merge batch and persists at most one new revision. + `--add-allow` and `--add-deny` target existing `protocol: rest` or `protocol: websocket` endpoints. If you pass multiple update flags in one command, OpenShell applies them as one atomic merge batch and persists at most one new revision. 4. For larger edits, pull the current policy and edit the YAML directly. Strip the metadata header (Version, Hash, Status) before reusing the file. @@ -165,7 +165,7 @@ Use `openshell policy update` when you want to merge network policy changes into `openshell policy update` is useful when you want to: - add a new endpoint for an existing binary without touching other policy sections. -- add a few REST allow or deny rules after you see a blocked request in the logs. +- add a few REST or WebSocket allow/deny rules after you see a blocked request in the logs. - remove one endpoint or one named rule without rewriting the rest of the file. - preview a merged result locally with `--dry-run` before you send it to the gateway. @@ -173,15 +173,15 @@ Use `openshell policy set` instead when you want to replace the full policy, upd ### Update Commands -The incremental update surface is split into endpoint-level operations and REST rule-level operations. +The incremental update surface is split into endpoint-level operations and method/path rule-level operations for REST and WebSocket endpoints. | Flag | What it changes | Typical use | |---|---|---| -| `--add-endpoint ` | Creates or merges a network rule and endpoint. | Allow a new host and port, optionally with `access`, `protocol`, `enforcement`, and binaries. | +| `--add-endpoint ` | Creates or merges a network rule and endpoint. | Allow a new host and port, optionally with `access`, `protocol`, `enforcement`, endpoint options, and binaries. | | `--remove-endpoint ` | Removes one host and port match from the current policy. | Drop a stale endpoint or remove one port from a multi-port endpoint. | | `--remove-rule ` | Deletes a named `network_policies` entry. | Remove a whole rule by name when you no longer need it. | -| `--add-allow ` | Appends REST allow rules to an existing endpoint. | Permit one additional method and path on a REST API that is already configured. | -| `--add-deny ` | Appends REST deny rules to an existing endpoint. | Block a sensitive REST path under an endpoint that is otherwise allowed. | +| `--add-allow ` | Appends method/path allow rules to an existing REST or WebSocket endpoint. | Permit one additional REST method/path or WebSocket `WEBSOCKET_TEXT` path on an API that is already configured. | +| `--add-deny ` | Appends method/path deny rules to an existing REST or WebSocket endpoint. | Block a sensitive REST path or WebSocket text-message path under an endpoint that is otherwise allowed. | | `--binary ` | Adds binaries to every `--add-endpoint` rule in the same command. | Bind a new endpoint to one or more executables. | | `--rule-name ` | Overrides the generated rule name. | Keep a stable human-chosen rule name when adding exactly one endpoint. | | `--dry-run` | Shows the merged policy locally and does not call the gateway. | Review the result before persisting it. | @@ -194,17 +194,18 @@ The incremental update surface is split into endpoint-level operations and REST `--add-endpoint` works at the endpoint and rule level. It creates a new `network_policies` entry when needed, or merges into an existing rule that already covers the same host and port. Use it when you are defining where traffic may go and which binaries may send it. -`--add-allow` and `--add-deny` work at the REST request level. They do not create binaries, and they do not create a new endpoint. They modify an existing endpoint that already has `protocol: rest`. +`--add-allow` and `--add-deny` work at the method/path rule level. They do not create binaries, and they do not create a new endpoint. They modify an existing endpoint that already has `protocol: rest` or `protocol: websocket`. This is the practical difference: - Use `--add-endpoint` to say "allow this binary to reach `api.github.com:443`." - Use `--add-allow` to say "for that existing REST endpoint, also allow `POST /repos/*/issues`." - Use `--add-deny` to say "for that existing REST endpoint, explicitly deny `POST /admin/**`." +- Use `--add-allow` to say "for that existing WebSocket endpoint, also allow client text messages on `/v1/realtime/**`." -In the first pass of this feature: +Current constraints: -- `--add-allow` and `--add-deny` only work on `protocol: rest` endpoints. +- `--add-allow` and `--add-deny` work on `protocol: rest` and `protocol: websocket` endpoints. - `--add-deny` requires the endpoint to already have an allow base, either an `access` preset or explicit allow `rules`. - `protocol: sql` is not a practical incremental workflow today. OpenShell does not do full SQL parsing, and SQL enforcement is not meaningfully supported yet. @@ -213,7 +214,7 @@ In the first pass of this feature: `--add-endpoint` uses this format: ```text -host:port[:access[:protocol[:enforcement]]] +host:port[:access[:protocol[:enforcement[:options]]]] ``` Each segment has a fixed meaning: @@ -222,9 +223,10 @@ Each segment has a fixed meaning: |---|---|---| | `host` | Yes | Destination hostname. | | `port` | Yes | Destination port, `1` through `65535`. | -| `access` | No | Access preset for L7 endpoints: `read-only`, `read-write`, or `full`. Incremental updates currently expand presets for REST-shaped access. | -| `protocol` | No | L7 inspection mode: `rest` or `sql`. `sql` is audit-only and not a recommended workflow today. | +| `access` | No | Access preset for L7 endpoints: `read-only`, `read-write`, or `full`. Incremental updates expand presets into protocol-specific method/path rules for REST and WebSocket endpoints. | +| `protocol` | No | L7 inspection mode: `rest`, `websocket`, or `sql`. `sql` is audit-only and not a recommended workflow today. | | `enforcement` | No | Enforcement mode for inspected traffic: `enforce` or `audit`. | +| `options` | No | Comma-separated endpoint options. Use `websocket-credential-rewrite` with `protocol: websocket` or REST compatibility endpoints that perform a WebSocket upgrade. Use `request-body-credential-rewrite` only with `protocol: rest`. | Examples: @@ -232,19 +234,29 @@ Examples: |---|---| | `pypi.org:443` | Add a plain L4 endpoint. The proxy allows the TCP stream and does not inspect HTTP requests. | | `api.github.com:443:read-only:rest:enforce` | Add a REST endpoint with the `read-only` preset expanded by the policy engine into GET, HEAD, and OPTIONS access. | +| `api.example.com:443:read-write:rest:enforce:request-body-credential-rewrite` | Add a REST endpoint that rewrites credential placeholders in supported text request bodies. | +| `realtime.example.com:443:read-write:websocket:enforce` | Add a WebSocket endpoint with the `read-write` preset expanded by the policy engine into the upgrade `GET` and client `WEBSOCKET_TEXT` access. | +| `realtime.example.com:443:read-write:websocket:enforce:websocket-credential-rewrite` | Add a WebSocket endpoint that rewrites `openshell:resolve:env:*` placeholders in client text frames after an allowed upgrade. | -If you set `protocol: rest`, you also need an allow shape. With incremental updates, that means you should provide an `access` preset on `--add-endpoint`, then use `--add-allow` or `--add-deny` to refine REST endpoints later. +If you set `protocol: rest` or `protocol: websocket`, you also need an allow shape. With incremental updates, that means you should provide an `access` preset on `--add-endpoint`, then use `--add-allow` or `--add-deny` to refine method/path rules later. + +Use the `websocket-credential-rewrite` endpoint option with `protocol: websocket` when the sandbox should send credential placeholders in client text frames and have OpenShell resolve them after the allowed upgrade. The option can also be used with `protocol: rest` compatibility endpoints that perform a WebSocket upgrade. It is rejected for plain L4 or `protocol: sql` endpoints. + +Use the `request-body-credential-rewrite` endpoint option with `protocol: rest` when an API expects OpenShell-managed credentials in UTF-8 JSON, form, or text request bodies. OpenShell buffers up to 256 KiB, rewrites recognized credential placeholders, updates `Content-Length`, and rejects unresolved placeholders instead of forwarding them. The option is rejected for WebSocket, GraphQL, SQL, and plain L4 endpoints. + +Credential rewrite recognizes the canonical `openshell:resolve:env:KEY` placeholder form and whole-token provider-shaped aliases such as `provider-OPENSHELL-RESOLVE-ENV-API_TOKEN` when the referenced environment key exists in the configured provider credentials. For example: - `api.github.com:443:read-only:rest` is valid. +- `realtime.example.com:443:read-write:websocket` is valid. - `api.github.com:443::rest` is invalid. It does not mean "allow all traffic." An L7 endpoint with `protocol` but no `access` or `rules` is rejected when the policy loads. -When you pass multiple `--add-endpoint` flags in one command, every `--binary` value applies to every added endpoint in that command. If different endpoints need different binaries, use separate `policy update` commands. +Endpoint options belong to the individual `--add-endpoint` spec. When you pass multiple `--add-endpoint` flags in one command, every `--binary` value applies to every added endpoint in that command. If different endpoints need different binaries, use separate `policy update` commands. If you do not pass `--rule-name`, OpenShell generates one from the host and port, such as `allow_api_github_com_443`. -### REST Rule Specs +### Method/Path Rule Specs `--add-allow` and `--add-deny` use this format: @@ -252,7 +264,7 @@ If you do not pass `--rule-name`, OpenShell generates one from the host and port host:port:METHOD:path_glob ``` -This string identifies an existing REST endpoint and the request pattern you want to add. +This string identifies an existing REST or WebSocket endpoint and the request pattern you want to add. In shell commands, quote the full `SPEC` when it contains `*` or `**` so your shell passes it literally instead of expanding it as a local file glob. @@ -260,8 +272,8 @@ In shell commands, quote the full `SPEC` when it contains `*` or `**` so your sh |---|---| | `host` | Existing endpoint host. | | `port` | Existing endpoint port. | -| `METHOD` | HTTP method. The CLI normalizes it to uppercase. | -| `path_glob` | URL path glob. It must start with `/`, or be `**`, or start with `**/`. | +| `METHOD` | HTTP method for REST endpoints, or `GET` / `WEBSOCKET_TEXT` for WebSocket endpoints. The CLI normalizes it to uppercase. | +| `path_glob` | URL path glob. For WebSocket text messages, this still matches the upgraded request path, not message payload content. It must start with `/`, or be `**`, or start with `**/`. | This example: @@ -283,11 +295,11 @@ Path globs follow the same semantics as YAML allow and deny rules: - `/repos/*/issues` matches one repository owner or name segment in the middle. - `/repos/**` matches everything under `/repos/`. -The rule-level commands only modify method and path constraints. They do not change binaries, hostnames, ports, or protocol settings. +The rule-level commands only modify method and path constraints. They do not change binaries, hostnames, ports, protocol settings, or WebSocket message payload matching. ### Common Workflows -Use these patterns as starting points when you decide whether to update an endpoint or append REST rules. +Use these patterns as starting points when you decide whether to update an endpoint or append REST/WebSocket rules. #### Add a new L4 endpoint @@ -302,7 +314,7 @@ openshell policy update demo \ --wait ``` -This creates or merges endpoint entries and binds them to the listed binaries. It does not create per-path REST rules. +This creates or merges endpoint entries and binds them to the listed binaries. It does not create inspected method/path rules. #### Create a REST endpoint with a base allow set @@ -341,6 +353,31 @@ openshell policy update demo \ This adds a deny rule to the existing REST endpoint. The endpoint must already have an allow base. +#### Create a WebSocket endpoint with a base allow set + +Use `--add-endpoint` with `protocol: websocket` when the destination is an RFC 6455 WebSocket API. + +```shell +openshell policy update demo \ + --add-endpoint realtime.example.com:443:read-write:websocket:enforce:websocket-credential-rewrite \ + --binary /usr/bin/node \ + --wait +``` + +This creates a WebSocket endpoint and sets its base allow behavior through the `read-write` access preset. For WebSocket endpoints, `read-write` expands to the upgrade `GET` and client `WEBSOCKET_TEXT` messages on the upgraded request path. The rewrite option lets the sandbox send `openshell:resolve:env:*` placeholders in client text frames; OpenShell resolves them before forwarding to the upstream service. + +#### Add a WebSocket text-message deny rule + +Use `WEBSOCKET_TEXT` when you want to refine client-to-server text-frame policy without matching message payload content. + +```shell +openshell policy update demo \ + --add-deny 'realtime.example.com:443:WEBSOCKET_TEXT:/v1/admin/**' \ + --wait +``` + +This adds a deny rule to the existing WebSocket endpoint. The path glob matches the WebSocket upgrade path. + #### Remove one endpoint or rule Use `--remove-endpoint` to remove one host and port pair, or `--remove-rule` to delete the whole named rule. @@ -379,7 +416,7 @@ The CLI validates the argument shapes before it sends the request. The gateway t - a required segment is missing. - a port is outside `1` through `65535`. - `--add-allow` or `--add-deny` points at an endpoint that does not exist. -- `--add-allow` or `--add-deny` targets a non-REST endpoint. +- `--add-allow` or `--add-deny` targets an endpoint that is neither REST nor WebSocket. - `--add-deny` targets an endpoint that has no base allow set. ## Global Policy Override @@ -415,7 +452,7 @@ When triaging denied requests, check: - Destination host and port to confirm which endpoint is missing. - Calling binary path to confirm which `binaries` entry needs to be added or adjusted. -- HTTP method and path (for REST endpoints) to confirm which `rules` entry needs to be added or adjusted. +- HTTP method and path for REST endpoints, or `GET` / `WEBSOCKET_TEXT` and the upgraded request path for WebSocket endpoints, to confirm which `rules` entry needs to be added or adjusted. Then push the updated policy as described above. @@ -427,7 +464,7 @@ openshell policy update --add-allow 'api.github.com:443:GET:/repos/**' -- ## Examples -Add these blocks to the `network_policies` section of your sandbox policy. Apply simple endpoints and REST rule additions with `openshell policy update`, or apply any complete YAML block with `openshell policy set --policy --wait`. +Add these blocks to the `network_policies` section of your sandbox policy. Apply simple endpoints and REST/WebSocket rule additions with `openshell policy update`, or apply any complete YAML block with `openshell policy set --policy --wait`. Use **Simple endpoint** for host-level allowlists and **Granular rules** for method/path control. @@ -447,7 +484,7 @@ Allow `pip install` and `uv pip install` to reach PyPI: - { path: /usr/local/bin/uv } ``` -Endpoints without `protocol` use TCP passthrough, where the proxy allows the stream without inspecting payloads. If the stream is HTTP and TLS is auto-terminated, the proxy can still rewrite configured credential placeholders and closes keep-alive passthrough tunnels on policy reload before forwarding another request. +Endpoints without `protocol` use TCP passthrough, where the proxy allows the stream without inspecting payloads. If the stream is HTTP and TLS is auto-terminated, the proxy can still rewrite configured credential placeholders and closes keep-alive passthrough tunnels on policy reload before forwarding another request. WebSocket text-frame policy requires an explicit `protocol: websocket` endpoint. WebSocket payload credential rewrite can also be enabled on a `protocol: rest` compatibility endpoint with `websocket_credential_rewrite: true`. REST request body credential rewrite requires an inspected `protocol: rest` endpoint with `request_body_credential_rewrite: true`. @@ -505,7 +542,7 @@ For an end-to-end walkthrough that combines this policy with a GitHub credential - { path: /usr/bin/gh } ``` -Endpoints with `protocol: rest` enable HTTP request inspection. Endpoints with `protocol: graphql` parse GraphQL-over-HTTP payloads before evaluating rules. The endpoint-level `path` field lets both protocols share `api.github.com:443` without treating GraphQL payloads as plain REST `POST /graphql` requests. +Endpoints with `protocol: rest` enable HTTP request inspection and can opt in to supported text request body credential rewrite. Endpoints with `protocol: websocket` validate WebSocket upgrades and inspect client text messages on the upgraded request path. WebSocket endpoints can also classify GraphQL-over-WebSocket operation messages with the same operation rules used by GraphQL-over-HTTP. Endpoints with `protocol: graphql` parse GraphQL-over-HTTP payloads before evaluating rules. The endpoint-level `path` field lets these protocols share `api.github.com:443` without treating GraphQL payloads as plain REST `POST /graphql` requests. @@ -570,6 +607,36 @@ For allow rules, every selected root field in an operation must match one of the Hash-only persisted queries cannot be classified from the request alone. OpenShell denies them unless the endpoint uses `persisted_queries: allow_registered` and provides a trusted `graphql_persisted_queries` entry keyed by hash or saved-query ID. +### GraphQL-over-WebSocket matching + +Some APIs carry GraphQL operations over RFC 6455 WebSockets, commonly for subscriptions and realtime updates. Configure these as `protocol: websocket`, allow the upgrade with a normal `GET` rule, then add GraphQL operation rules for client operation messages. OpenShell recognizes modern `graphql-transport-ws` `subscribe` messages and legacy `graphql-ws` `start` messages. + +```yaml showLineNumbers={false} + realtime_graphql: + name: realtime_graphql + endpoints: + - host: realtime.example.com + port: 443 + path: "/graphql" + protocol: websocket + enforcement: enforce + rules: + - allow: + method: GET + path: "/graphql" + - allow: + operation_type: subscription + fields: [messageAdded] + - allow: + operation_type: query + fields: [viewer] + websocket_credential_rewrite: true + binaries: + - { path: /usr/bin/node } +``` + +When a WebSocket endpoint has GraphQL operation policy, client operation messages are fail-closed on malformed JSON, unsupported message types, parse errors, unregistered hash-only persisted queries, or unallowed operations. Use GraphQL operation rules for client messages rather than a raw `WEBSOCKET_TEXT` allow rule. Protocol lifecycle messages such as `connection_init`, `ping`, `pong`, and `complete` are allowed without payload logging; if `websocket_credential_rewrite: true` is set, placeholders inside those text messages are resolved before forwarding. + ### GraphQL service policy shapes GraphQL field names are application-specific, so treat these as starting shapes to review against the actual app schema: diff --git a/docs/security/best-practices.mdx b/docs/security/best-practices.mdx index 227bb4567..f51781144 100644 --- a/docs/security/best-practices.mdx +++ b/docs/security/best-practices.mdx @@ -97,9 +97,9 @@ The `protocol` field on an endpoint controls whether the proxy inspects individu | Aspect | Detail | |---|---| | Default | Endpoints without a `protocol` field use L4-only enforcement: the proxy checks host, port, and binary, then relays the TCP stream without inspecting payloads. | -| What you can change | Add `protocol: rest` to enable per-request HTTP method/path inspection, or `protocol: graphql` to inspect GraphQL operation type, operation name, and root fields. Pair either protocol with `rules` or access presets (`full`, `read-only`, `read-write`). | +| What you can change | Add `protocol: rest` to enable per-request HTTP method/path inspection, `protocol: websocket` to inspect RFC 6455 upgrade handshakes and client text messages, or `protocol: graphql` to inspect GraphQL-over-HTTP operation type, operation name, and root fields. WebSocket endpoints can also use GraphQL operation rules for GraphQL-over-WebSocket messages. Pair inspected protocols with `rules` or access presets (`full`, `read-only`, `read-write`). REST endpoints that need credential placeholders in supported text request bodies can set `request_body_credential_rewrite: true`. | | Risk if relaxed | L4-only endpoints allow the agent to send any data through the tunnel after the initial connection is permitted. The proxy cannot see HTTP methods, paths, or GraphQL operations. Adding `access: full` with L7 inspection enables observability but permits all inspected actions. | -| Recommendation | Use `protocol: rest` with specific `rules` for APIs where intent is encoded in method and path. Use `protocol: graphql` for GraphQL APIs where destructive operations are body-encoded. Prefer `access: read-only` or explicit allowlists, and deny hash-only persisted queries unless you maintain a trusted registry. Omit `protocol` for non-HTTP protocols (WebSocket, gRPC streaming). | +| Recommendation | Use `protocol: rest` with specific `rules` for APIs where intent is encoded in method and path. Add `request_body_credential_rewrite: true` only for REST APIs that require OpenShell-managed credentials in UTF-8 JSON, form, or text request bodies. Use `protocol: graphql` for GraphQL-over-HTTP APIs where destructive operations are body-encoded. Use `protocol: websocket` for RFC 6455 endpoints, with explicit `GET` and `WEBSOCKET_TEXT` rules for raw text protocols or explicit GraphQL operation rules for GraphQL-over-WebSocket. Prefer `access: read-only` or explicit allowlists, and deny hash-only persisted queries unless you maintain a trusted registry. Omit `protocol` for non-HTTP protocols. For WebSocket endpoints that must carry placeholder credentials in client text frames, add `websocket_credential_rewrite: true`. | ### Enforcement Mode (`audit` vs `enforce`) @@ -283,8 +283,8 @@ The following patterns weaken security without providing meaningful benefit. | Mistake | Why it matters | What to do instead | |---------|---------------|-------------------| -| Omitting `protocol: rest` on REST API endpoints | Without `protocol: rest`, the proxy uses L4-only enforcement. It allows the TCP stream through after checking host, port, and binary, but cannot inspect individual HTTP requests. | Add `protocol: rest` with specific `rules` to enable per-request method and path control. | -| Using `access: full` when finer rules would suffice | `access: full` with `protocol: rest` enables inspection but allows all HTTP methods and paths. | Use `access: read-only` or explicit `rules` to restrict what the agent can do at the HTTP level. | +| Omitting an inspected protocol on REST or WebSocket API endpoints | Without `protocol: rest` or `protocol: websocket`, the proxy uses L4-only enforcement. It allows the TCP stream through after checking host, port, and binary, but cannot inspect individual HTTP requests or WebSocket text messages. | Add `protocol: rest` or `protocol: websocket` with specific `rules` to enable method and path control. | +| Using `access: full` when finer rules would suffice | `access: full` with `protocol: rest` or `protocol: websocket` enables inspection but allows all methods and paths for that protocol. | Use `access: read-only` or explicit `rules` to restrict what the agent can do at the L7 level. | | Adding endpoints permanently when operator approval would suffice | Adding endpoints to the policy YAML makes them permanently reachable across all instances. | Use operator approval. Approved endpoints persist within the sandbox instance but reset on re-creation. | | Using broad binary globs | A glob like `/**` allows any binary to reach the endpoint, defeating binary-scoped enforcement. | Scope globs to specific directories (for example, `/sandbox/.vscode-server/**`). | | Skipping TLS termination on HTTPS APIs | Setting `tls: skip` disables credential injection and L7 inspection. | Use the default auto-detect behavior unless the upstream requires client-certificate mTLS. | diff --git a/e2e/rust/Cargo.lock b/e2e/rust/Cargo.lock index dcf1f72af..61f15866b 100644 --- a/e2e/rust/Cargo.lock +++ b/e2e/rust/Cargo.lock @@ -8,6 +8,12 @@ version = "1.0.102" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" +[[package]] +name = "base64" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + [[package]] name = "bitflags" version = "2.11.0" @@ -238,9 +244,11 @@ checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" name = "openshell-e2e" version = "0.1.0" dependencies = [ + "base64", "hex", "rand", "serde_json", + "sha1", "sha2", "tempfile", "tokio", @@ -429,6 +437,17 @@ dependencies = [ "zmij", ] +[[package]] +name = "sha1" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sha2" version = "0.10.9" diff --git a/e2e/rust/Cargo.toml b/e2e/rust/Cargo.toml index 072727b8a..89a75967a 100644 --- a/e2e/rust/Cargo.toml +++ b/e2e/rust/Cargo.toml @@ -45,6 +45,11 @@ name = "gateway_resume" path = "tests/gateway_resume.rs" required-features = ["e2e-docker"] +[[test]] +name = "websocket_conformance" +path = "tests/websocket_conformance.rs" +required-features = ["e2e-docker"] + [[test]] name = "user_namespaces" path = "tests/user_namespaces.rs" @@ -71,8 +76,10 @@ path = "tests/gpu_device_selection.rs" required-features = ["e2e-gpu"] [dependencies] +base64 = "0.22" tokio = { version = "1.43", features = ["full"] } tempfile = "3" +sha1 = "0.10" sha2 = "0.10" hex = "0.4" rand = "0.9" diff --git a/e2e/rust/tests/websocket_conformance.rs b/e2e/rust/tests/websocket_conformance.rs new file mode 100644 index 000000000..d87c9b9dd --- /dev/null +++ b/e2e/rust/tests/websocket_conformance.rs @@ -0,0 +1,482 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#![cfg(feature = "e2e")] + +//! E2E regression: WebSocket credential placeholders are resolved on the real +//! Docker-backed sandbox path after an RFC 6455 upgrade. +//! +//! The sandbox process sends its provider-managed placeholder in a masked text +//! frame. The local upstream only reports whether it saw the real secret and +//! whether any placeholder survived; it never echoes payload bytes, placeholder +//! text, or secret material back into test output. + +use std::io::{self, Error, ErrorKind, Write}; +use std::process::Stdio; +use std::sync::Mutex; + +use base64::Engine as _; +use openshell_e2e::harness::binary::openshell_cmd; +use openshell_e2e::harness::sandbox::SandboxGuard; +use sha1::{Digest, Sha1}; +use tempfile::NamedTempFile; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::{TcpListener, TcpStream}; +use tokio::task::JoinHandle; + +const WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; +const PROVIDER_NAME: &str = "e2e-websocket-conformance"; +const TEST_SERVER_HOST: &str = "host.openshell.internal"; +const TEST_SECRET: &str = "sk-e2e-websocket-conformance-secret"; +const TOKEN_ENV: &str = "WS_E2E_TOKEN"; +const PLACEHOLDER_PREFIX: &str = "openshell:resolve:env:"; +static PROVIDER_LOCK: Mutex<()> = Mutex::new(()); + +async fn run_cli(args: &[&str]) -> Result { + let mut cmd = openshell_cmd(); + cmd.args(args).stdout(Stdio::piped()).stderr(Stdio::piped()); + + let output = cmd + .output() + .await + .map_err(|e| format!("failed to spawn openshell {}: {e}", args.join(" ")))?; + + let stdout = String::from_utf8_lossy(&output.stdout).to_string(); + let stderr = String::from_utf8_lossy(&output.stderr).to_string(); + let combined = format!("{stdout}{stderr}"); + + if !output.status.success() { + return Err(format!( + "openshell {} failed (exit {:?}):\n{combined}", + args.join(" "), + output.status.code() + )); + } + + Ok(combined) +} + +async fn delete_provider(name: &str) { + let mut cmd = openshell_cmd(); + cmd.arg("provider") + .arg("delete") + .arg(name) + .stdout(Stdio::null()) + .stderr(Stdio::null()); + let _ = cmd.status().await; +} + +async fn create_generic_provider(name: &str) -> Result { + let credential = format!("{TOKEN_ENV}={TEST_SECRET}"); + run_cli(&[ + "provider", + "create", + "--name", + name, + "--type", + "generic", + "--credential", + &credential, + ]) + .await +} + +struct WebSocketProbeServer { + port: u16, + task: JoinHandle<()>, +} + +impl WebSocketProbeServer { + async fn start() -> Result { + let listener = TcpListener::bind(("0.0.0.0", 0)) + .await + .map_err(|e| format!("bind websocket probe server: {e}"))?; + let port = listener + .local_addr() + .map_err(|e| format!("read websocket probe server address: {e}"))? + .port(); + let task = tokio::spawn(async move { + loop { + let Ok((stream, _)) = listener.accept().await else { + break; + }; + tokio::spawn(async move { + let _ = handle_websocket_probe_connection(stream).await; + }); + } + }); + + Ok(Self { port, task }) + } +} + +impl Drop for WebSocketProbeServer { + fn drop(&mut self) { + self.task.abort(); + } +} + +async fn recv_until(stream: &mut TcpStream, marker: &[u8]) -> io::Result> { + let mut data = Vec::new(); + let mut buf = [0_u8; 1024]; + loop { + let read = stream.read(&mut buf).await?; + if read == 0 { + return Ok(data); + } + data.extend_from_slice(&buf[..read]); + if data.windows(marker.len()).any(|window| window == marker) { + return Ok(data); + } + } +} + +async fn read_websocket_text(stream: &mut TcpStream) -> io::Result { + let mut header = [0_u8; 2]; + stream.read_exact(&mut header).await?; + let length = match header[1] & 0x7F { + len @ 0..=125 => usize::from(len), + 126 => { + let mut bytes = [0_u8; 2]; + stream.read_exact(&mut bytes).await?; + usize::from(u16::from_be_bytes(bytes)) + } + 127 => { + let mut bytes = [0_u8; 8]; + stream.read_exact(&mut bytes).await?; + usize::try_from(u64::from_be_bytes(bytes)) + .map_err(|_| Error::new(ErrorKind::InvalidData, "websocket frame too large"))? + } + _ => unreachable!(), + }; + + let mut mask = [0_u8; 4]; + if header[1] & 0x80 != 0 { + stream.read_exact(&mut mask).await?; + } else { + mask = [0, 0, 0, 0]; + } + + let mut payload = vec![0_u8; length]; + stream.read_exact(&mut payload).await?; + if header[1] & 0x80 != 0 { + for (index, byte) in payload.iter_mut().enumerate() { + *byte ^= mask[index % mask.len()]; + } + } + + String::from_utf8(payload).map_err(|e| { + Error::new( + ErrorKind::InvalidData, + format!("invalid websocket text: {e}"), + ) + }) +} + +async fn send_websocket_text(stream: &mut TcpStream, payload: &str) -> io::Result<()> { + let data = payload.as_bytes(); + let mut frame = Vec::with_capacity(data.len() + 10); + frame.push(0x81); + if data.len() < 126 { + frame.push( + u8::try_from(data.len()) + .map_err(|_| Error::new(ErrorKind::InvalidData, "websocket frame too large"))?, + ); + } else if data.len() <= usize::from(u16::MAX) { + frame.push(126); + frame.extend_from_slice( + &u16::try_from(data.len()) + .map_err(|_| Error::new(ErrorKind::InvalidData, "websocket frame too large"))? + .to_be_bytes(), + ); + } else { + frame.push(127); + frame.extend_from_slice( + &u64::try_from(data.len()) + .map_err(|_| Error::new(ErrorKind::InvalidData, "websocket frame too large"))? + .to_be_bytes(), + ); + } + frame.extend_from_slice(data); + stream.write_all(&frame).await +} + +fn header_value(request: &str, name: &str) -> Option { + request.lines().find_map(|line| { + let (header, value) = line.split_once(':')?; + if header.trim().eq_ignore_ascii_case(name) { + Some(value.trim().to_string()) + } else { + None + } + }) +} + +fn websocket_accept_for_key(key: &str) -> String { + let mut hasher = Sha1::new(); + hasher.update(key.as_bytes()); + hasher.update(WEBSOCKET_GUID.as_bytes()); + base64::engine::general_purpose::STANDARD.encode(hasher.finalize()) +} + +async fn handle_websocket_probe_connection(mut stream: TcpStream) -> io::Result<()> { + let request_bytes = recv_until(&mut stream, b"\r\n\r\n").await?; + let request = String::from_utf8_lossy(&request_bytes); + if !request.to_ascii_lowercase().contains("upgrade: websocket") { + stream + .write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\nConnection: close\r\n\r\nok") + .await?; + return Ok(()); + } + + let accept = header_value(&request, "Sec-WebSocket-Key") + .map(|key| websocket_accept_for_key(&key)) + .ok_or_else(|| Error::new(ErrorKind::InvalidData, "missing Sec-WebSocket-Key"))?; + let response = format!( + "HTTP/1.1 101 Switching Protocols\r\n\ + Upgrade: websocket\r\n\ + Connection: Upgrade\r\n\ + Sec-WebSocket-Accept: {accept}\r\n\ + \r\n" + ); + stream.write_all(response.as_bytes()).await?; + + let text = read_websocket_text(&mut stream).await?; + let response = format!( + r#"{{"saw_placeholder": {}, "saw_secret": {}}}"#, + text.contains(PLACEHOLDER_PREFIX), + text.contains(TEST_SECRET) + ); + send_websocket_text(&mut stream, &response).await +} + +fn write_websocket_policy(host: &str, port: u16) -> Result { + let mut file = NamedTempFile::new().map_err(|e| format!("create temp policy file: {e}"))?; + let policy = format!( + r#"version: 1 + +filesystem_policy: + include_workdir: true + read_only: + - /usr + - /lib + - /proc + - /dev/urandom + - /app + - /etc + - /var/log + read_write: + - /sandbox + - /tmp + - /dev/null + +landlock: + compatibility: best_effort + +process: + run_as_user: sandbox + run_as_group: sandbox + +network_policies: + websocket_conformance: + name: websocket_conformance + endpoints: + - host: {host} + port: {port} + protocol: websocket + enforcement: enforce + access: read-write + websocket_credential_rewrite: true + allowed_ips: + - "10.0.0.0/8" + - "172.0.0.0/8" + - "192.168.0.0/16" + - "fc00::/7" + binaries: + - path: /usr/bin/python* + - path: /usr/local/bin/python* + - path: /sandbox/.uv/python/*/bin/python* +"# + ); + file.write_all(policy.as_bytes()) + .map_err(|e| format!("write temp policy file: {e}"))?; + file.flush() + .map_err(|e| format!("flush temp policy file: {e}"))?; + Ok(file) +} + +fn websocket_client_script(host: &str, port: u16) -> String { + format!( + r#" +import base64 +import json +import os +import socket +import struct +import time +import urllib.parse + +HOST = {host:?} +PORT = {port} +TOKEN_ENV = {token_env:?} + +def recv_until(sock, marker): + data = b"" + while marker not in data: + chunk = sock.recv(4096) + if not chunk: + break + data += chunk + return data + +def read_exact(sock, size): + data = b"" + while len(data) < size: + chunk = sock.recv(size - len(data)) + if not chunk: + raise EOFError("unexpected end of websocket frame") + data += chunk + return data + +def masked_text_frame(payload): + data = payload.encode("utf-8") + mask = os.urandom(4) + if len(data) < 126: + header = bytes([0x81, 0x80 | len(data)]) + elif len(data) <= 0xFFFF: + header = bytes([0x81, 0x80 | 126]) + struct.pack("!H", len(data)) + else: + header = bytes([0x81, 0x80 | 127]) + struct.pack("!Q", len(data)) + masked = bytes(byte ^ mask[index % 4] for index, byte in enumerate(data)) + return header + mask + masked + +def read_frame(sock): + first, second = read_exact(sock, 2) + length = second & 0x7F + if length == 126: + length = struct.unpack("!H", read_exact(sock, 2))[0] + elif length == 127: + length = struct.unpack("!Q", read_exact(sock, 8))[0] + mask = read_exact(sock, 4) if second & 0x80 else b"" + payload = read_exact(sock, length) + if mask: + payload = bytes(byte ^ mask[index % 4] for index, byte in enumerate(payload)) + return first, payload + +def proxy_parts(): + names = ("HTTP_PROXY", "http_proxy", "HTTPS_PROXY", "https_proxy", "ALL_PROXY", "all_proxy") + proxy_url = next((os.environ.get(name) for name in names if os.environ.get(name)), None) + if not proxy_url: + raise RuntimeError("proxy environment is not configured") + parsed = urllib.parse.urlparse(proxy_url) + if not parsed.hostname: + raise RuntimeError(f"invalid proxy URL: {{proxy_url!r}}") + return parsed.hostname, parsed.port or 80 + +def connect_with_retry(host, port, timeout_seconds=20): + proxy_host, proxy_port = proxy_parts() + target = f"{{host}}:{{port}}" + deadline = time.monotonic() + timeout_seconds + last_error = None + while time.monotonic() < deadline: + sock = None + try: + sock = socket.create_connection((proxy_host, proxy_port), timeout=5) + request = f"CONNECT {{target}} HTTP/1.1\r\nHost: {{target}}\r\n\r\n" + sock.sendall(request.encode("ascii")) + response = recv_until(sock, b"\r\n\r\n").decode("iso-8859-1", "replace") + if response.startswith("HTTP/1.1 200") or response.startswith("HTTP/1.0 200"): + return sock + first_line = response.splitlines()[0] if response else "" + raise RuntimeError(f"proxy CONNECT failed: {{first_line}}") + except (OSError, RuntimeError) as error: + if sock is not None: + sock.close() + last_error = error + time.sleep(0.25) + raise last_error + +token = os.environ[TOKEN_ENV] +payload = json.dumps({{"authorization": "Bearer " + token}}, sort_keys=True) +key = base64.b64encode(os.urandom(16)).decode("ascii") + +with connect_with_retry(HOST, PORT) as sock: + request = ( + f"GET /ws HTTP/1.1\r\n" + f"Host: {{HOST}}:{{PORT}}\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + f"Sec-WebSocket-Key: {{key}}\r\n" + "Sec-WebSocket-Version: 13\r\n" + "\r\n" + ) + sock.sendall(request.encode("ascii")) + response = recv_until(sock, b"\r\n\r\n").decode("iso-8859-1", "replace") + if not response.startswith("HTTP/1.1 101"): + raise RuntimeError("websocket upgrade failed") + sock.sendall(masked_text_frame(payload)) + _, response_payload = read_frame(sock) + print(response_payload.decode("utf-8")) +"#, + host = host, + port = port, + token_env = TOKEN_ENV, + ) +} + +#[tokio::test] +async fn websocket_text_placeholder_is_rewritten_in_docker_sandbox() { + let _provider_lock = PROVIDER_LOCK + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + + delete_provider(PROVIDER_NAME).await; + create_generic_provider(PROVIDER_NAME) + .await + .expect("create generic provider"); + + let result = async { + let server = WebSocketProbeServer::start().await?; + let policy = write_websocket_policy(TEST_SERVER_HOST, server.port)?; + let policy_path = policy + .path() + .to_str() + .ok_or_else(|| "temp policy path should be utf-8".to_string())? + .to_string(); + let script = websocket_client_script(TEST_SERVER_HOST, server.port); + + SandboxGuard::create(&[ + "--policy", + &policy_path, + "--provider", + PROVIDER_NAME, + "--", + "python3", + "-c", + &script, + ]) + .await + } + .await; + + delete_provider(PROVIDER_NAME).await; + + let guard = result.expect("sandbox create"); + assert!( + guard + .create_output + .contains(r#"{"saw_placeholder": false, "saw_secret": true}"#), + "expected upstream to see only the resolved secret marker:\n{}", + guard.create_output + ); + assert!( + !guard.create_output.contains(TEST_SECRET), + "test output should not expose the raw WebSocket credential:\n{}", + guard.create_output + ); + assert!( + !guard.create_output.contains(PLACEHOLDER_PREFIX), + "test output should not expose unresolved credential placeholders:\n{}", + guard.create_output + ); +} diff --git a/proto/sandbox.proto b/proto/sandbox.proto index f7df5945e..b40d95cb1 100644 --- a/proto/sandbox.proto +++ b/proto/sandbox.proto @@ -70,7 +70,7 @@ message NetworkEndpoint { // Single port (backwards compat). Use `ports` for multiple ports. // Mutually exclusive with `ports` — if both are set, `ports` takes precedence. uint32 port = 2; - // Application protocol for L7 inspection: "rest", "graphql", "sql", or "" (L4-only). + // Application protocol for L7 inspection: "rest", "websocket", "graphql", "sql", or "" (L4-only). string protocol = 3; // TLS handling: "terminate" or "passthrough" (default). string tls = 4; @@ -116,6 +116,14 @@ message NetworkEndpoint { // protocol "rest" when both surfaces live under api.example.com:443. // Empty means all paths. string path = 15; + // When true on a "rest" endpoint, OpenShell rewrites credential placeholders + // inside client-to-server WebSocket text messages after an allowed HTTP 101 + // upgrade. Defaults to false. + bool websocket_credential_rewrite = 16; + // When true on a "rest" endpoint, OpenShell rewrites credential placeholders + // inside supported textual HTTP request bodies before forwarding upstream. + // Defaults to false. + bool request_body_credential_rewrite = 17; } // Trusted GraphQL operation classification. diff --git a/tasks/test.toml b/tasks/test.toml index 8e26170d5..91d2c44f6 100644 --- a/tasks/test.toml +++ b/tasks/test.toml @@ -38,6 +38,13 @@ run = [ "e2e/with-docker-gateway.sh cargo test --manifest-path e2e/rust/Cargo.toml --features e2e-docker", ] +["e2e:websocket-conformance"] +description = "Run focused WebSocket conformance e2e tests against a Docker-backed gateway" +run = [ + "cargo build -p openshell-cli --features openshell-core/dev-settings", + "e2e/with-docker-gateway.sh cargo test --manifest-path e2e/rust/Cargo.toml --features e2e-docker --test websocket_conformance", +] + ["e2e:python"] description = "Run Python e2e tests against a Docker-backed gateway (E2E_PARALLEL=N or 'auto'; default 5)" depends = ["python:proto"] From df5a8b943f50aa5bbfbdd3505429276154fbccc8 Mon Sep 17 00:00:00 2001 From: Eric Curtin Date: Tue, 12 May 2026 15:43:23 +0100 Subject: [PATCH 039/157] fix(providers): read opencode config file during credential discovery (#1290) Supplement env-var discovery with API keys stored in the opencode config file at $XDG_CONFIG_HOME/opencode/opencode.json. Keys under provider..options.apiKey are surfaced as _API_KEY env vars; existing env vars take priority over file-sourced values. --- .../src/providers/opencode.rs | 136 +++++++++++++++++- 1 file changed, 132 insertions(+), 4 deletions(-) diff --git a/crates/openshell-providers/src/providers/opencode.rs b/crates/openshell-providers/src/providers/opencode.rs index 417bdb6c2..feb707fd0 100644 --- a/crates/openshell-providers/src/providers/opencode.rs +++ b/crates/openshell-providers/src/providers/opencode.rs @@ -1,8 +1,12 @@ // SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 +use std::collections::HashMap; +use std::path::{Path, PathBuf}; + use crate::{ - ProviderDiscoverySpec, ProviderError, ProviderPlugin, RealDiscoveryContext, discover_with_spec, + DiscoveredProvider, ProviderDiscoverySpec, ProviderError, ProviderPlugin, RealDiscoveryContext, + discover_with_spec, }; pub struct OpencodeProvider; @@ -12,13 +16,80 @@ pub const SPEC: ProviderDiscoverySpec = ProviderDiscoverySpec { credential_env_vars: &["OPENCODE_API_KEY", "OPENROUTER_API_KEY", "OPENAI_API_KEY"], }; +/// Return the path to the opencode config file, respecting `XDG_CONFIG_HOME`. +fn opencode_config_path() -> Option { + let config_home = std::env::var("XDG_CONFIG_HOME") + .ok() + .map(PathBuf::from) + .or_else(|| { + std::env::var("HOME") + .ok() + .map(|h| PathBuf::from(h).join(".config")) + })?; + Some(config_home.join("opencode").join("opencode.json")) +} + +/// Extract API key credentials from the contents of an opencode config file. +/// +/// opencode stores per-provider API keys at `provider..options.apiKey`. +/// Each key is surfaced as `_API_KEY` so that it can be injected +/// as an environment variable into the sandbox and picked up by opencode at runtime. +fn extract_credentials_from_opencode_config(content: &str) -> HashMap { + let Ok(json) = serde_json::from_str::(content) else { + return HashMap::new(); + }; + let Some(providers) = json.get("provider").and_then(|p| p.as_object()) else { + return HashMap::new(); + }; + + let mut creds = HashMap::new(); + for (provider_name, provider_cfg) in providers { + if let Some(api_key) = provider_cfg + .get("options") + .and_then(|o| o.get("apiKey")) + .and_then(|k| k.as_str()) + .filter(|k| !k.trim().is_empty()) + { + let env_var = format!("{}_API_KEY", provider_name.to_ascii_uppercase()); + creds.insert(env_var, api_key.to_string()); + } + } + creds +} + +/// Read opencode credentials from `path`, returning `None` if the file is absent or unreadable. +fn read_opencode_config_file(path: &Path) -> Option> { + let content = std::fs::read_to_string(path).ok()?; + let creds = extract_credentials_from_opencode_config(&content); + if creds.is_empty() { None } else { Some(creds) } +} + impl ProviderPlugin for OpencodeProvider { fn id(&self) -> &'static str { SPEC.id } - fn discover_existing(&self) -> Result, ProviderError> { - discover_with_spec(&SPEC, &RealDiscoveryContext) + fn discover_existing(&self) -> Result, ProviderError> { + let mut discovered = discover_with_spec(&SPEC, &RealDiscoveryContext)?.unwrap_or_default(); + + // Supplement env-var discovery with credentials stored in the opencode config file. + // opencode's native config lives at $XDG_CONFIG_HOME/opencode/opencode.json and stores + // API keys under `provider..options.apiKey`. If the user configured opencode + // normally (i.e. no env vars set), this is the only place the keys exist. + if let Some(path) = opencode_config_path() + && let Some(file_creds) = read_opencode_config_file(&path) + { + for (key, value) in file_creds { + // Env vars already set take priority; config file fills the gaps. + discovered.credentials.entry(key).or_insert(value); + } + } + + if discovered.is_empty() { + Ok(None) + } else { + Ok(Some(discovered)) + } } fn credential_env_vars(&self) -> &'static [&'static str] { @@ -28,7 +99,7 @@ impl ProviderPlugin for OpencodeProvider { #[cfg(test)] mod tests { - use super::SPEC; + use super::{SPEC, extract_credentials_from_opencode_config}; use crate::discover_with_spec; use crate::test_helpers::MockDiscoveryContext; @@ -43,4 +114,61 @@ mod tests { Some(&"op-key".to_string()) ); } + + #[test] + fn extracts_credentials_from_config_file() { + let config = r#"{ + "provider": { + "anthropic": { "options": { "apiKey": "sk-ant-key" } }, + "openai": { "options": { "apiKey": "sk-openai-key" } } + } + }"#; + let creds = extract_credentials_from_opencode_config(config); + assert_eq!( + creds.get("ANTHROPIC_API_KEY"), + Some(&"sk-ant-key".to_string()) + ); + assert_eq!( + creds.get("OPENAI_API_KEY"), + Some(&"sk-openai-key".to_string()) + ); + } + + #[test] + fn skips_providers_without_api_key() { + let config = r#"{ + "provider": { + "ollama": { "options": { "baseUrl": "http://localhost:11434" } } + } + }"#; + let creds = extract_credentials_from_opencode_config(config); + assert!( + creds.is_empty(), + "no credentials expected for keyless provider" + ); + } + + #[test] + fn skips_empty_api_keys() { + let config = r#"{ + "provider": { + "anthropic": { "options": { "apiKey": "" } } + } + }"#; + let creds = extract_credentials_from_opencode_config(config); + assert!(creds.is_empty()); + } + + #[test] + fn tolerates_malformed_json() { + let creds = extract_credentials_from_opencode_config("not json at all"); + assert!(creds.is_empty()); + } + + #[test] + fn tolerates_missing_provider_section() { + let config = r#"{ "theme": "dark" }"#; + let creds = extract_credentials_from_opencode_config(config); + assert!(creds.is_empty()); + } } From 3b61c9cdcf84186a0e0cedb213c57944ce2ff2df Mon Sep 17 00:00:00 2001 From: Arnon Rotem-Gal-Oz Date: Tue, 12 May 2026 22:44:38 +0300 Subject: [PATCH 040/157] feat(k8s): support nodeSelector and tolerations from platform_config (#1327) Extract node_selector and tolerations from the sandbox template's platform_config and apply them to the pod spec. This allows callers to schedule sandbox pods on specific node pools and tolerate taints without requiring first-class proto fields. --- .../openshell-driver-kubernetes/src/driver.rs | 110 ++++++++++++++++++ 1 file changed, 110 insertions(+) diff --git a/crates/openshell-driver-kubernetes/src/driver.rs b/crates/openshell-driver-kubernetes/src/driver.rs index cde1f4b22..56b73447a 100644 --- a/crates/openshell-driver-kubernetes/src/driver.rs +++ b/crates/openshell-driver-kubernetes/src/driver.rs @@ -1078,6 +1078,12 @@ fn sandbox_template_to_k8s( serde_json::json!(runtime_class), ); } + if let Some(node_selector) = platform_config_struct(template, "node_selector") { + spec.insert("nodeSelector".to_string(), node_selector); + } + if let Some(tolerations) = platform_config_struct(template, "tolerations") { + spec.insert("tolerations".to_string(), tolerations); + } // Per-sandbox platform_config.host_users overrides the cluster-wide default. let use_user_namespaces = platform_config_bool(template, "host_users") @@ -2322,4 +2328,108 @@ mod tests { ); assert!(cr["spec"].get("logLevel").is_none()); } + + #[test] + fn node_selector_from_platform_config() { + let template = SandboxTemplate { + platform_config: Some(Struct { + fields: std::iter::once(( + "node_selector".to_string(), + Value { + kind: Some(Kind::StructValue(Struct { + fields: std::iter::once(( + "gpu-pool".to_string(), + Value { + kind: Some(Kind::StringValue("true".to_string())), + }, + )) + .collect(), + })), + }, + )) + .collect(), + }), + ..SandboxTemplate::default() + }; + + let pod_template = { + let params = SandboxPodParams::default(); + sandbox_template_to_k8s( + &template, + false, + &std::collections::HashMap::new(), + false, + ¶ms, + ) + }; + + assert_eq!( + pod_template["spec"]["nodeSelector"]["gpu-pool"], + serde_json::json!("true") + ); + } + + #[test] + fn tolerations_from_platform_config() { + let toleration = Struct { + fields: [ + ( + "key".to_string(), + Value { + kind: Some(Kind::StringValue("nvidia.com/gpu".to_string())), + }, + ), + ( + "operator".to_string(), + Value { + kind: Some(Kind::StringValue("Exists".to_string())), + }, + ), + ( + "effect".to_string(), + Value { + kind: Some(Kind::StringValue("NoSchedule".to_string())), + }, + ), + ] + .into_iter() + .collect(), + }; + + let template = SandboxTemplate { + platform_config: Some(Struct { + fields: std::iter::once(( + "tolerations".to_string(), + Value { + kind: Some(Kind::ListValue(prost_types::ListValue { + values: vec![Value { + kind: Some(Kind::StructValue(toleration)), + }], + })), + }, + )) + .collect(), + }), + ..SandboxTemplate::default() + }; + + let pod_template = { + let params = SandboxPodParams::default(); + sandbox_template_to_k8s( + &template, + false, + &std::collections::HashMap::new(), + false, + ¶ms, + ) + }; + + let tolerations = pod_template["spec"]["tolerations"] + .as_array() + .expect("tolerations should be an array"); + assert_eq!(tolerations.len(), 1); + assert_eq!(tolerations[0]["key"], "nvidia.com/gpu"); + assert_eq!(tolerations[0]["operator"], "Exists"); + assert_eq!(tolerations[0]["effect"], "NoSchedule"); + } } From ba77967fc4c4aff4248b161fbb544395d754629c Mon Sep 17 00:00:00 2001 From: Taylor Mutch Date: Tue, 12 May 2026 16:05:49 -0700 Subject: [PATCH 041/157] refactor(docker): split gateway/supervisor Dockerfiles and use native rust builds (#1316) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * wip * refactor(docker): use native rust builds for split gateway/supervisor images Drop the in-Docker BUILD_FROM_SOURCE path so both images consume only prebuilt binaries staged natively via tasks/scripts/stage-prebuilt-binaries.sh. This mirrors what CI does and reuses the host's cargo target cache and sccache across rebuilds. - Dockerfile.gateway: nvcr.io/nvidia/distroless/cc:v4.0.4 base (the 4.0.0 tag does not exist on nvcr.io; the registry uses a v prefix). GNU-linked binary copied to /usr/local/bin. - Dockerfile.supervisor: scratch base, static musl binary. Static linkage lets the image stay scratch while still being executable as a Kubernetes init container. - skaffold.yaml: each artifact invokes tasks/scripts/docker-build-image.sh, which stages the binary natively (cargo / cargo-zigbuild) and then builds the image. Drops the cross-build.sh dependency from the supervisor build. - seccomp.rs: add a local SYS_kexec_file_load constant for musl/aarch64. libc 0.2.185 omits the symbol from its musl/aarch64 bindings, so the supervisor's seccomp filter previously failed to compile for that target. - architecture/build.md: describe the native-first pipeline and per-image runtime choices. Local validation: gateway image 101MB (was 194MB), supervisor image 21.7MB. helm:skaffold:run deploys cleanly; the static musl supervisor binary runs correctly in a non-glibc agent container. * refactor(docker): tighten binary perms via --chown + 0550 Replace `COPY --chmod=755` with `COPY --chown= --chmod=0550` in the gateway and supervisor Dockerfiles. The binary is no longer world-readable or world-executable; ownership is pinned to the runtime user. - Gateway uses `--chown=nvs:nvs` + `USER nvs:nvs`, matching the only non-root user defined in `nvcr.io/nvidia/distroless/cc` (UID 1000) and the Helm chart's `securityContext.runAsUser: 1000`, which overrides the Dockerfile USER at runtime. - Supervisor uses numeric `--chown=65534:65534` because the scratch base has no `/etc/passwd` for name resolution. The supervisor image is only consumed by the init-container copy-self path; the destination pod's runAsUser governs execute access. Validated by deploying to a local k3d cluster via `helm:skaffold:run` and confirming the gateway StatefulSet reaches 1/1 Running. * ci(gpu): repoint GPU probe image lookup at Dockerfile.gateway The previous awk parsed `FROM AS gateway` from the now-deleted `Dockerfile.images`. The new `Dockerfile.gateway` uses an ARG with a default (`ARG GATEWAY_BASE_IMAGE=nvcr.io/nvidia/distroless/cc:v4.0.4`) and `FROM ${GATEWAY_BASE_IMAGE} AS gateway`, so the old script returns nothing. Parse the ARG default value directly so the GPU prerequisites check keeps using the gateway base image as a `nvidia-smi` probe target. * ci(gpu): pin GPU probe to nvcr.io/nvidia/base/ubuntu:noble The previous probe parsed the gateway base image out of the Dockerfile, relying on the fact that the gateway ran on `nvcr.io/nvidia/base/ubuntu` and that NVIDIA Container Toolkit CDI injection would populate `nvidia-smi` and the supporting libs at runtime. The new gateway base (`nvcr.io/nvidia/distroless/cc`) lacks `ldconfig`, a populated `/usr/bin`, and the broader filesystem layout CDI injection assumes, so it cannot serve as a GPU probe. Pin the probe image explicitly to the NVIDIA-managed Ubuntu base. The probe is independent of the gateway runtime and survives future base swaps. * fix(e2e-gpu): pass GPU probe image via env, drop Dockerfile.images parse The Rust e2e test `gpu_request_for_each_discovered_device_matches_plain_container` was still parsing the deleted `Dockerfile.images` to derive its GPU probe image, panicking with `No such file or directory` after this PR's Dockerfile split. Move the probe image to a single source of truth in the workflow (`OPENSHELL_E2E_GPU_PROBE_IMAGE` env at the job level) and require the e2e test to read it from there. No silent codebase default — the test panics with a pointer at the workflow if the env is missing, so it fails loudly rather than drifting from CI. The prereq probe step in the workflow now consumes the same env, so the probe image is declared exactly once. * fix(docker): keep supervisor binary root-owned for rootless Podman The Podman driver mounts the supervisor image read-only into the sandbox container at /opt/openshell/bin and runs that container as UID 0, but deliberately drops DAC_OVERRIDE for hardening (container.rs:419). With --chown=65534:65534 --chmod=0550 the binary was r-xr-x--- owned by UID 65534, so the container's UID 0 fell into "other" with no read or exec access and the supervisor crashed on start (ContainerExited code 1). Docker and Kubernetes both retain DAC_OVERRIDE, so root could still exec the file — which is why this regression only surfaced in the Podman e2e job. Drop --chown from the supervisor COPY so the binary stays root-owned. Keep --chmod=0550: the security win was dropping world-execute, not changing the owner. The chown bought nothing here because the container is always UID 0 regardless of driver, but it actively broke the only driver that drops DAC_OVERRIDE. --------- Co-authored-by: Drew Newberry --- .../skills/debug-openshell-cluster/SKILL.md | 3 +- .github/workflows/e2e-gpu-test.yaml | 9 +- .github/workflows/rust-native-build.yml | 53 +++++++- architecture/build.md | 30 ++++- crates/openshell-driver-podman/README.md | 6 +- .../src/sandbox/linux/seccomp.rs | 14 +- deploy/docker/Dockerfile.ci | 1 + deploy/docker/Dockerfile.gateway | 41 ++++++ deploy/docker/Dockerfile.images | 122 ------------------ deploy/docker/Dockerfile.supervisor | 35 +++++ deploy/helm/openshell/skaffold.yaml | 40 +++--- deploy/helm/openshell/templates/certgen.yaml | 2 +- e2e/rust/tests/gpu_device_selection.rs | 54 ++------ tasks/scripts/docker-build-image.sh | 15 ++- tasks/scripts/stage-prebuilt-binaries.sh | 27 +++- 15 files changed, 230 insertions(+), 222 deletions(-) create mode 100644 deploy/docker/Dockerfile.gateway delete mode 100644 deploy/docker/Dockerfile.images create mode 100644 deploy/docker/Dockerfile.supervisor diff --git a/.agents/skills/debug-openshell-cluster/SKILL.md b/.agents/skills/debug-openshell-cluster/SKILL.md index b7b2c898c..48729e89a 100644 --- a/.agents/skills/debug-openshell-cluster/SKILL.md +++ b/.agents/skills/debug-openshell-cluster/SKILL.md @@ -74,6 +74,7 @@ Common findings: - Sandbox image missing or pull denied: verify image reference and registry credentials. - Docker driver cannot initialize because it cannot find `openshell-sandbox`: verify `OPENSHELL_DOCKER_SUPERVISOR_BIN`, the sibling binary next to `openshell-gateway`, or the configured supervisor image contains `/openshell-sandbox`. - Sandbox never registers: check gateway logs and supervisor callback endpoint. +- Supervisor image exits before printing `openshell-sandbox --version`: the image should be the scratch supervisor image from `deploy/docker/Dockerfile.supervisor` and must contain a static executable at `/openshell-sandbox`. For source checkout development, restart the local gateway with: @@ -126,7 +127,7 @@ kubectl -n openshell get statefulset openshell -o jsonpath="{.spec.template.spec helm -n openshell get values openshell | grep -E 'repository|tag|supervisorImage' ``` -The gateway image and `server.supervisorImage` should use the same build tag in branch and E2E deploys. A stale supervisor image can make sandbox behavior lag behind gateway policy or proto changes. +The gateway image built from `deploy/docker/Dockerfile.gateway` and the scratch supervisor image built from `deploy/docker/Dockerfile.supervisor` should use the same build tag in branch and E2E deploys. A stale supervisor image can make sandbox behavior lag behind gateway policy or proto changes. For local/external pull mode (the default local path via `mise run cluster`), local images are tagged to the configured local registry base, pushed to that registry, and pulled by k3s via the `registries.yaml` mirror endpoint. The `cluster` task pushes prebuilt local tags (`openshell/*:dev`, falling back to `localhost:5000/openshell/*:dev` or `127.0.0.1:5000/openshell/*:dev`). diff --git a/.github/workflows/e2e-gpu-test.yaml b/.github/workflows/e2e-gpu-test.yaml index 429a82524..0004bcbe2 100644 --- a/.github/workflows/e2e-gpu-test.yaml +++ b/.github/workflows/e2e-gpu-test.yaml @@ -49,6 +49,11 @@ jobs: OPENSHELL_REGISTRY_USERNAME: ${{ github.actor }} OPENSHELL_REGISTRY_PASSWORD: ${{ secrets.GITHUB_TOKEN }} OPENSHELL_E2E_DOCKER_GPU: "1" + # NVIDIA-managed Ubuntu base used as the GPU probe target: it has the + # filesystem layout CDI injection expects (ldconfig, populated /usr/bin) + # which the distroless gateway runtime lacks. Consumed by the prereq + # probe below and by the e2e tests in e2e/rust/tests/gpu_device_selection.rs. + OPENSHELL_E2E_GPU_PROBE_IMAGE: "nvcr.io/nvidia/base/ubuntu:noble-20251013" steps: - uses: actions/checkout@v6 @@ -58,9 +63,7 @@ jobs: - name: Check Docker GPU prerequisites run: | docker info --format '{{json .CDISpecDirs}}' - GPU_PROBE_IMAGE="$(awk '$1 == "FROM" && $3 == "AS" && $4 == "gateway" { print $2; exit }' deploy/docker/Dockerfile.images)" - test -n "${GPU_PROBE_IMAGE}" - docker run --rm --device nvidia.com/gpu=all "${GPU_PROBE_IMAGE}" nvidia-smi -L + docker run --rm --device nvidia.com/gpu=all "${OPENSHELL_E2E_GPU_PROBE_IMAGE}" nvidia-smi -L - name: Run tests run: mise run --no-deps --skip-deps e2e:docker:gpu diff --git a/.github/workflows/rust-native-build.yml b/.github/workflows/rust-native-build.yml index 40feb1dd0..edb1bfb7a 100644 --- a/.github/workflows/rust-native-build.yml +++ b/.github/workflows/rust-native-build.yml @@ -1,10 +1,12 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -name: Rust Native Build (openshell-gateway / openshell-sandbox) +name: Rust Image Binary Build (openshell-gateway / openshell-sandbox) -# Build Rust binaries natively per Linux architecture before the Docker image -# build consumes them as prebuilt artifacts. +# Build Rust binaries per Linux architecture before the Docker image build +# consumes them as prebuilt artifacts. Gateway images use GNU-linked binaries +# for the NVIDIA distroless C/C++ runtime; supervisor images use musl/static +# binaries so the final image can remain scratch. on: workflow_call: @@ -105,10 +107,12 @@ jobs: gateway) crate=openshell-server binary=openshell-gateway + zig_target= ;; sandbox) crate=openshell-sandbox binary=openshell-sandbox + zig_target= ;; *) echo "unsupported component: $COMPONENT" >&2 @@ -118,10 +122,20 @@ jobs: case "$ARCH" in amd64) - target=x86_64-unknown-linux-gnu + if [[ "$COMPONENT" == "sandbox" ]]; then + target=x86_64-unknown-linux-musl + zig_target=x86_64-linux-musl + else + target=x86_64-unknown-linux-gnu + fi ;; arm64) - target=aarch64-unknown-linux-gnu + if [[ "$COMPONENT" == "sandbox" ]]; then + target=aarch64-unknown-linux-musl + zig_target=aarch64-linux-musl + else + target=aarch64-unknown-linux-gnu + fi ;; *) echo "unsupported arch: $ARCH" >&2 @@ -133,6 +147,7 @@ jobs: echo "crate=$crate" echo "binary=$binary" echo "target=$target" + echo "zig_target=$zig_target" } >> "$GITHUB_OUTPUT" - name: Configure GHA sccache backend @@ -163,6 +178,30 @@ jobs: set -euo pipefail sed -i -E '/^\[workspace\.package\]/,/^\[/{s/^version[[:space:]]*=[[:space:]]*".*"/version = "'"${{ steps.version.outputs.cargo_version }}"'"/}' Cargo.toml + - name: Set up zig musl wrappers + if: contains(steps.target.outputs.target, 'musl') + run: | + set -euo pipefail + ZIG="$(mise which zig)" + ZIG_TARGET="${{ steps.target.outputs.zig_target }}" + mkdir -p /tmp/zig-musl + + # cc-rs injects --target=, which zig does not parse. + # Strip caller-provided --target and use the wrapper's zig target. + for tool in cc c++; do + printf '#!/bin/bash\nargs=()\nfor arg in "$@"; do\n case "$arg" in\n --target=*) ;;\n *) args+=("$arg") ;;\n esac\ndone\nexec "%s" %s --target=%s "${args[@]}"\n' \ + "$ZIG" "$tool" "$ZIG_TARGET" > "/tmp/zig-musl/${tool}" + chmod +x "/tmp/zig-musl/${tool}" + done + + TARGET_ENV=$(echo "${{ steps.target.outputs.target }}" | tr '-' '_') + TARGET_ENV_UPPER=${TARGET_ENV^^} + + echo "CC_${TARGET_ENV}=/tmp/zig-musl/cc" >> "$GITHUB_ENV" + echo "CXX_${TARGET_ENV}=/tmp/zig-musl/c++" >> "$GITHUB_ENV" + echo "CARGO_TARGET_${TARGET_ENV_UPPER}_LINKER=/tmp/zig-musl/cc" >> "$GITHUB_ENV" + echo "CARGO_TARGET_${TARGET_ENV_UPPER}_RUSTFLAGS=-Clink-self-contained=no" >> "$GITHUB_ENV" + - name: Build ${{ steps.target.outputs.binary }} (${{ steps.target.outputs.target }}) env: # Preserve the release-codegen setting used by the old Dockerfile @@ -171,6 +210,7 @@ jobs: OPENSHELL_IMAGE_TAG: ${{ inputs['image-tag'] }} run: | set -euo pipefail + mise x -- rustup target add "${{ steps.target.outputs.target }}" args=( --release --target "${{ steps.target.outputs.target }}" @@ -192,8 +232,7 @@ jobs: OUTPUT="$("$BIN" --version)" echo "$OUTPUT" grep -q "^${{ steps.target.outputs.binary }} " <<<"$OUTPUT" - # Record glibc linkage so drift from the Ubuntu noble runtime base - # image is visible in logs. + # Record linkage so image runtime drift is visible in logs. ldd --version ldd "$BIN" || true diff --git a/architecture/build.md b/architecture/build.md index cfe13c4b1..62d21cdf4 100644 --- a/architecture/build.md +++ b/architecture/build.md @@ -12,7 +12,8 @@ OpenShell builds these main artifacts: |---|---| | Gateway binary | `crates/openshell-server` | | CLI package and Python SDK | `python/openshell` plus Rust binaries where packaged | -| Gateway and supervisor container images | `deploy/docker/Dockerfile.images` | +| Gateway container image | `deploy/docker/Dockerfile.gateway` | +| Supervisor container image | `deploy/docker/Dockerfile.supervisor` | | Helm chart | `deploy/helm/openshell` | | VM driver/runtime assets | `crates/openshell-driver-vm` | | Published docs site | `docs/` rendered by Fern config in `fern/` | @@ -21,10 +22,29 @@ Sandbox community images are built outside this repository. ## Container Builds -The Docker image pipeline stages prebuilt Rust binaries, then builds container -images from `deploy/docker/Dockerfile.images`. CI builds native artifacts on the -target architecture, stages them under `deploy/docker/.build/`, and then uses -Buildx to publish per-architecture images and multi-architecture tags. +The Docker image pipeline is a two-step flow: build the Rust binary natively +for the target architecture, then assemble the container image from the +prebuilt binary. The gateway image is built from `deploy/docker/Dockerfile.gateway` +and the supervisor image from `deploy/docker/Dockerfile.supervisor`. Neither +Dockerfile compiles Rust — both copy a staged binary out of +`deploy/docker/.build/prebuilt-binaries//` into the final image. + +Binary staging is driven by `tasks/scripts/stage-prebuilt-binaries.sh`, which +runs `cargo build` natively on a matching host or `cargo zigbuild` when +cross-compiling. CI invokes the same staging step via the +`rust-native-build.yml` workflow (per-architecture, per-component) and uploads +the result as an artifact that the image build job downloads back into the +staging directory before running Buildx. + +Runtime layout: + +- **Gateway**: `nvcr.io/nvidia/distroless/cc` base, GNU-linked binary at + `/usr/local/bin/openshell-gateway`, runs as UID/GID `65532:65532`. +- **Supervisor**: `scratch` base, static musl binary at `/openshell-sandbox`. + Static linkage is required because the image is mounted/extracted into + sandbox environments (Docker extraction, Podman image volumes, Kubernetes + init-container copy-self) and cannot rely on a dynamic loader. + Gateway image builds bake the corresponding supervisor image tag into the gateway binary so Docker sandboxes do not depend on `:latest` by default. Package formulas also pin Docker supervisor extraction to the matching release diff --git a/crates/openshell-driver-podman/README.md b/crates/openshell-driver-podman/README.md index 5b88010e4..6b60613ed 100644 --- a/crates/openshell-driver-podman/README.md +++ b/crates/openshell-driver-podman/README.md @@ -86,8 +86,8 @@ sequenceDiagram C->>C: entrypoint: /opt/openshell/bin/openshell-sandbox ``` -The `supervisor` target in `deploy/docker/Dockerfile.images` copies the -`openshell-sandbox` binary to `/openshell-sandbox` in the supervisor image. +The supervisor image from `deploy/docker/Dockerfile.supervisor` copies the static +`openshell-sandbox` binary to `/openshell-sandbox`. Mounting that image at `/opt/openshell/bin` makes the binary available as `/opt/openshell/bin/openshell-sandbox`. @@ -352,4 +352,4 @@ matter compared to cluster or rootful runtimes: netns, proxy, and relay behavior shared by all drivers. - Container engine abstraction: `tasks/scripts/container-engine.sh` for build/deploy support across Docker and Podman. -- Supervisor image build: `deploy/docker/Dockerfile.images`. +- Supervisor image build: `deploy/docker/Dockerfile.supervisor`. diff --git a/crates/openshell-sandbox/src/sandbox/linux/seccomp.rs b/crates/openshell-sandbox/src/sandbox/linux/seccomp.rs index f61464023..1044623f5 100644 --- a/crates/openshell-sandbox/src/sandbox/linux/seccomp.rs +++ b/crates/openshell-sandbox/src/sandbox/linux/seccomp.rs @@ -25,6 +25,16 @@ use tracing::debug; /// Value of `SECCOMP_SET_MODE_FILTER` (linux/seccomp.h). const SECCOMP_SET_MODE_FILTER: u64 = 1; +// libc 0.2.185 omits `SYS_kexec_file_load` from the musl/aarch64 bindings even +// though the kernel exposes syscall 294. Fall back to the literal so the +// supervisor's seccomp filter still blocks fileless kernel-image loads when +// built statically against musl on aarch64. +#[cfg(all(target_arch = "aarch64", target_env = "musl"))] +#[allow(non_upper_case_globals)] +const SYS_kexec_file_load: libc::c_long = 294; +#[cfg(not(all(target_arch = "aarch64", target_env = "musl")))] +use libc::SYS_kexec_file_load; + /// Apply the supervisor seccomp filter across the running process. /// /// This runs after privileged startup helpers complete and synchronizes the @@ -81,7 +91,7 @@ fn build_supervisor_prelude_rules() -> BTreeMap> { libc::SYS_finit_module, libc::SYS_delete_module, libc::SYS_kexec_load, - libc::SYS_kexec_file_load, + SYS_kexec_file_load, ] { rules.entry(syscall).or_default(); } @@ -423,7 +433,7 @@ mod tests { libc::SYS_finit_module, libc::SYS_delete_module, libc::SYS_kexec_load, - libc::SYS_kexec_file_load, + SYS_kexec_file_load, ] { assert!( filter_rules.contains_key(&syscall), diff --git a/deploy/docker/Dockerfile.ci b/deploy/docker/Dockerfile.ci index 3c669a96f..ee67f97b2 100644 --- a/deploy/docker/Dockerfile.ci +++ b/deploy/docker/Dockerfile.ci @@ -29,6 +29,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ libz3-dev \ pkg-config \ libssl-dev \ + musl-tools \ openssh-client \ python3 \ python3-venv \ diff --git a/deploy/docker/Dockerfile.gateway b/deploy/docker/Dockerfile.gateway new file mode 100644 index 000000000..30a45e8c1 --- /dev/null +++ b/deploy/docker/Dockerfile.gateway @@ -0,0 +1,41 @@ +# syntax=docker/dockerfile:1.4 + +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +# Gateway image build. +# +# The Rust binary is built natively before this image build runs and staged at: +# deploy/docker/.build/prebuilt-binaries//openshell-gateway +# +# Use tasks/scripts/docker-build-image.sh gateway (or `mise run build:docker:gateway`) +# to stage the binary and build the image in one step. CI builds the binary +# per-architecture via the `rust-native-build.yml` workflow and uploads it as +# an artifact, which is downloaded into the same staging directory before the +# image build job runs. +# +# The runtime is `nvcr.io/nvidia/distroless/cc:4.0.0`, which provides glibc and +# the dynamic loader needed by the GNU-linked gateway binary while keeping the +# attack surface small. + +ARG GATEWAY_BASE_IMAGE=nvcr.io/nvidia/distroless/cc:v4.0.4 + +FROM ${GATEWAY_BASE_IMAGE} AS gateway + +ARG TARGETARCH + +WORKDIR /app + +# --chmod=0550 preserves the executable bit through actions/upload-artifact + +# download-artifact (which strip exec perms during the roundtrip) without +# granting world-execute. --chown=nvs:nvs matches the image's only defined +# non-root user (`nvs:1000`, the NVIDIA distroless convention) and aligns +# with the Helm chart's `securityContext.runAsUser: 1000`, which overrides +# the Dockerfile's USER at runtime. +COPY --chown=nvs:nvs --chmod=0550 deploy/docker/.build/prebuilt-binaries/${TARGETARCH}/openshell-gateway /usr/local/bin/openshell-gateway + +USER nvs:nvs +EXPOSE 8080 + +ENTRYPOINT ["/usr/local/bin/openshell-gateway"] +CMD ["--bind-address", "0.0.0.0", "--port", "8080"] diff --git a/deploy/docker/Dockerfile.images b/deploy/docker/Dockerfile.images deleted file mode 100644 index 62662fa93..000000000 --- a/deploy/docker/Dockerfile.images +++ /dev/null @@ -1,122 +0,0 @@ -# syntax=docker/dockerfile:1.4 - -# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -# Shared OpenShell image build graph. -# -# Targets: -# gateway Final gateway image -# supervisor Final supervisor image (Ubuntu base, supervisor binary) -# -# Rust binaries are built natively before the image build and staged at: -# deploy/docker/.build/prebuilt-binaries//openshell-{gateway,sandbox} -# -# For local dev (Skaffold), pass --build-arg BUILD_FROM_SOURCE=1 to compile -# binaries inside Docker instead. BuildKit only executes the selected binary -# staging stage, so missing prebuilt files do not cause a build failure. - -# Controls binary source: 0 = prebuilt (release), 1 = compile in Docker (local dev). -# Must be declared here (global scope) so it can be used in FROM instructions below. -ARG BUILD_FROM_SOURCE=0 - -# --------------------------------------------------------------------------- -# Optional in-Docker Rust build (BUILD_FROM_SOURCE=1, local dev only) -# --------------------------------------------------------------------------- -FROM rust:1.95.0-slim-bookworm AS rust-builder - -RUN apt-get update && apt-get install -y --no-install-recommends \ - build-essential \ - cmake \ - pkg-config \ - libssl-dev \ - ca-certificates \ - && rm -rf /var/lib/apt/lists/* - -WORKDIR /build - -COPY Cargo.toml Cargo.lock ./ -COPY crates/ crates/ -COPY proto/ proto/ -COPY providers/ providers/ - -RUN --mount=type=cache,target=/usr/local/cargo/registry \ - --mount=type=cache,target=/build/target \ - cargo build --release \ - --features "openshell-core/dev-settings" \ - --bin openshell-gateway \ - --bin openshell-sandbox \ - && mkdir -p /build/out \ - && install -m 0755 target/release/openshell-gateway /build/out/openshell-gateway \ - && install -m 0755 target/release/openshell-sandbox /build/out/openshell-sandbox - -# --------------------------------------------------------------------------- -# Per-arch binary stages -# --------------------------------------------------------------------------- - -# Prebuilt path (release default, BUILD_FROM_SOURCE=0) -FROM scratch AS gateway-binary-0 -ARG TARGETARCH -# --chmod=755 preserves the executable bit through actions/upload-artifact + -# download-artifact, which strip exec perms during the roundtrip. -COPY --chmod=755 deploy/docker/.build/prebuilt-binaries/${TARGETARCH}/openshell-gateway /build/out/openshell-gateway - -# Source-built path (local dev, BUILD_FROM_SOURCE=1) -FROM rust-builder AS gateway-binary-1 - -FROM gateway-binary-${BUILD_FROM_SOURCE} AS gateway-binary - -# Prebuilt path (release default, BUILD_FROM_SOURCE=0) -FROM scratch AS supervisor-binary-0 -ARG TARGETARCH -# --chmod=755 preserves the executable bit through actions/upload-artifact + -# download-artifact, which strip exec perms during the roundtrip. -COPY --chmod=755 deploy/docker/.build/prebuilt-binaries/${TARGETARCH}/openshell-sandbox /build/out/openshell-sandbox - -# Source-built path (local dev, BUILD_FROM_SOURCE=1) -FROM rust-builder AS supervisor-binary-1 - -FROM supervisor-binary-${BUILD_FROM_SOURCE} AS supervisor-binary - -# --------------------------------------------------------------------------- -# Final gateway image -# --------------------------------------------------------------------------- -FROM nvcr.io/nvidia/base/ubuntu:noble-20251013 AS gateway - -RUN apt-get update && apt-get install -y --no-install-recommends \ - ca-certificates && \ - apt-get install -y --only-upgrade gpgv && \ - rm -rf /var/lib/apt/lists/* - -RUN useradd --create-home --user-group openshell - -WORKDIR /app - -COPY --from=gateway-binary /build/out/openshell-gateway /usr/local/bin/ - -RUN mkdir -p /build/crates/openshell-server -COPY --chmod=755 crates/openshell-server/migrations /build/crates/openshell-server/migrations - -USER openshell -EXPOSE 8080 - -ENTRYPOINT ["openshell-gateway"] -CMD ["--bind-address", "0.0.0.0", "--port", "8080"] - -# --------------------------------------------------------------------------- -# Final supervisor image -# --------------------------------------------------------------------------- -# Supervisor image based on the same NVIDIA Ubuntu base used by the gateway. -# -# Used by: -# - Docker driver: binary is extracted from the image and run inside the -# agent container. -# - Podman driver: image is mounted as an OCI volume at /opt/openshell/bin. -# - Kubernetes driver: image runs as an init container that invokes the -# binary's `copy-self` subcommand to seed an emptyDir volume. -# -# An Ubuntu base provides glibc and the dynamic loader needed to exec the -# dynamically linked binary. `FROM scratch` would be smaller but cannot run -# the binary, breaking the Kubernetes init-container path. -FROM nvcr.io/nvidia/base/ubuntu:noble-20251013 AS supervisor -COPY --from=supervisor-binary /build/out/openshell-sandbox /openshell-sandbox diff --git a/deploy/docker/Dockerfile.supervisor b/deploy/docker/Dockerfile.supervisor new file mode 100644 index 000000000..c84cc70e9 --- /dev/null +++ b/deploy/docker/Dockerfile.supervisor @@ -0,0 +1,35 @@ +# syntax=docker/dockerfile:1.4 + +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +# Supervisor image build. +# +# The final image is `scratch`: it only carries the static `openshell-sandbox` +# binary used by Docker extraction, Podman image volumes, and the Kubernetes +# init container copy-self path. A static musl binary lets the image stay +# `scratch` while still being executable as an init container. +# +# The Rust binary is built natively before this image build runs and staged at: +# deploy/docker/.build/prebuilt-binaries//openshell-sandbox +# +# Use tasks/scripts/docker-build-image.sh supervisor (or `mise run build:docker:supervisor`) +# to stage the binary and build the image in one step. CI builds the binary +# per-architecture via the `rust-native-build.yml` workflow (with the musl +# target) and uploads it as an artifact, which is downloaded into the same +# staging directory before the image build job runs. + +FROM scratch AS supervisor + +ARG TARGETARCH + +# --chmod=0550 drops world-execute and survives the actions/upload-artifact +# + download-artifact roundtrip (which strips exec perms). Ownership is left +# at root (0:0) deliberately: the Podman driver mounts this image as a +# read-only image volume into the sandbox container and drops DAC_OVERRIDE, +# so the container's UID 0 must own the binary to read+exec it. Mode 0550 +# (r-xr-x---) is the security win; the chown to a non-root UID was breaking +# Podman without buying anything since the container is always UID 0. +COPY --chmod=0550 deploy/docker/.build/prebuilt-binaries/${TARGETARCH}/openshell-sandbox /openshell-sandbox + +ENTRYPOINT ["/openshell-sandbox"] diff --git a/deploy/helm/openshell/skaffold.yaml b/deploy/helm/openshell/skaffold.yaml index 2de9ee4e6..779211877 100644 --- a/deploy/helm/openshell/skaffold.yaml +++ b/deploy/helm/openshell/skaffold.yaml @@ -1,12 +1,15 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -# Local dev: builds gateway + supervisor images using Dockerfile.images with -# BUILD_FROM_SOURCE=1, which compiles Rust binaries inside Docker without -# requiring pre-staged artifacts. +# Local dev: builds gateway + supervisor images via tasks/scripts/docker-build-image.sh, +# which first stages Rust binaries natively on the host (using cargo / cargo-zigbuild +# when cross-compiling) and then builds the image from the prebuilt binary. This +# mirrors CI and is faster than compiling inside Docker on every rebuild because +# the host's cargo target cache and sccache are reused across iterations. # # Run from repo root: -# skaffold dev -f deploy/helm/openshell/skaffold.yaml +# mise run helm:skaffold:dev +# mise run helm:skaffold:run # # See https://skaffold.dev/docs/deployers/helm/ (setValueTemplates, IMAGE_* fields). apiVersion: skaffold/v4beta14 @@ -23,39 +26,34 @@ build: context: ../../.. custom: buildCommand: | - docker buildx build \ - --build-arg BUILD_FROM_SOURCE=1 \ - --target gateway \ - --tag "$IMAGE" \ - --load \ - --file deploy/docker/Dockerfile.images \ - . + IMAGE_NAME="${IMAGE%:*}" \ + IMAGE_TAG="${IMAGE##*:}" \ + tasks/scripts/docker-build-image.sh gateway dependencies: paths: - Cargo.toml - Cargo.lock - crates/** - proto/** - - deploy/docker/Dockerfile.images - - crates/openshell-server/migrations/** + - deploy/docker/Dockerfile.gateway + - tasks/scripts/docker-build-image.sh + - tasks/scripts/stage-prebuilt-binaries.sh - image: openshell/supervisor context: ../../.. custom: buildCommand: | - docker buildx build \ - --build-arg BUILD_FROM_SOURCE=1 \ - --target supervisor \ - --tag "$IMAGE" \ - --load \ - --file deploy/docker/Dockerfile.images \ - . + IMAGE_NAME="${IMAGE%:*}" \ + IMAGE_TAG="${IMAGE##*:}" \ + tasks/scripts/docker-build-image.sh supervisor dependencies: paths: - Cargo.toml - Cargo.lock - crates/** - proto/** - - deploy/docker/Dockerfile.images + - deploy/docker/Dockerfile.supervisor + - tasks/scripts/docker-build-image.sh + - tasks/scripts/stage-prebuilt-binaries.sh deploy: helm: releases: diff --git a/deploy/helm/openshell/templates/certgen.yaml b/deploy/helm/openshell/templates/certgen.yaml index d8136d581..ef4500db6 100644 --- a/deploy/helm/openshell/templates/certgen.yaml +++ b/deploy/helm/openshell/templates/certgen.yaml @@ -95,7 +95,7 @@ spec: valueFrom: fieldRef: fieldPath: metadata.namespace - command: ["openshell-gateway"] + command: ["/usr/local/bin/openshell-gateway"] args: - generate-certs - --server-secret-name={{ .Values.server.tls.certSecretName }} diff --git a/e2e/rust/tests/gpu_device_selection.rs b/e2e/rust/tests/gpu_device_selection.rs index 930ae73e1..73db9a6d8 100644 --- a/e2e/rust/tests/gpu_device_selection.rs +++ b/e2e/rust/tests/gpu_device_selection.rs @@ -7,7 +7,6 @@ //! //! Requires a GPU-backed gateway and a sandbox image containing `nvidia-smi`. -use std::path::{Path, PathBuf}; use std::process::Stdio; use std::time::Duration; @@ -19,9 +18,9 @@ use serde_json::{Map, Value}; use tokio::time::timeout; const SANDBOX_CREATE_TIMEOUT: Duration = Duration::from_secs(600); -const GPU_PROBE_DOCKERFILE_STAGE: &str = "gateway"; const CDI_GPU_DEVICE_ALL: &str = "nvidia.com/gpu=all"; const CDI_GPU_DEVICE_PREFIX: &str = "nvidia.com/gpu="; +const GPU_PROBE_IMAGE_ENV: &str = "OPENSHELL_E2E_GPU_PROBE_IMAGE"; fn gpu_lines(output: &str) -> Vec { strip_ansi(output) @@ -32,53 +31,18 @@ fn gpu_lines(output: &str) -> Vec { .collect() } -fn workspace_root() -> PathBuf { - Path::new(env!("CARGO_MANIFEST_DIR")) - .ancestors() - .nth(2) - .expect("failed to resolve workspace root from CARGO_MANIFEST_DIR") - .to_path_buf() -} - -fn dockerfile_images_gpu_probe_image() -> String { - let dockerfile = workspace_root().join("deploy/docker/Dockerfile.images"); - let contents = std::fs::read_to_string(&dockerfile) - .unwrap_or_else(|err| panic!("failed to read {}: {err}", dockerfile.display())); - - contents - .lines() - .map(str::trim) - .find_map(|line| { - let mut parts = line.split_whitespace(); - let instruction = parts.next()?; - let image = parts.next()?; - let as_keyword = parts.next()?; - let stage = parts.next()?; - - if instruction.eq_ignore_ascii_case("FROM") - && as_keyword.eq_ignore_ascii_case("AS") - && stage == GPU_PROBE_DOCKERFILE_STAGE - { - Some(image) - } else { - None - } - }) - .unwrap_or_else(|| { - panic!( - "failed to find a FROM AS {GPU_PROBE_DOCKERFILE_STAGE} stage in {}", - dockerfile.display() - ) - }) - .to_string() -} - fn gpu_probe_image() -> String { - std::env::var("OPENSHELL_E2E_GPU_PROBE_IMAGE") + std::env::var(GPU_PROBE_IMAGE_ENV) .ok() .map(|value| value.trim().to_string()) .filter(|value| !value.is_empty()) - .unwrap_or_else(dockerfile_images_gpu_probe_image) + .unwrap_or_else(|| { + panic!( + "{GPU_PROBE_IMAGE_ENV} must be set to a container image that supports \ + NVIDIA Container Toolkit CDI injection (set by \ + .github/workflows/e2e-gpu-test.yaml in CI)" + ) + }) } fn object_string<'a>(object: &'a Map, key: &str) -> Option<&'a str> { diff --git a/tasks/scripts/docker-build-image.sh b/tasks/scripts/docker-build-image.sh index 2fb86bc5e..b733ec4a1 100755 --- a/tasks/scripts/docker-build-image.sh +++ b/tasks/scripts/docker-build-image.sh @@ -93,31 +93,29 @@ ensure_prebuilt_binaries() { TARGET=${1:?"Usage: docker-build-image.sh [extra-args...]"} shift -DOCKERFILE="deploy/docker/Dockerfile.images" -if [[ ! -f "${DOCKERFILE}" ]]; then - echo "Error: Dockerfile not found: ${DOCKERFILE}" >&2 - exit 1 -fi - IS_FINAL_IMAGE=0 IMAGE_NAME="" DOCKER_TARGET="" +DOCKERFILE="" case "${TARGET}" in gateway) IS_FINAL_IMAGE=1 IMAGE_NAME="openshell/gateway" DOCKER_TARGET="gateway" + DOCKERFILE="deploy/docker/Dockerfile.gateway" ;; supervisor) IS_FINAL_IMAGE=1 IMAGE_NAME="openshell/supervisor" DOCKER_TARGET="supervisor" + DOCKERFILE="deploy/docker/Dockerfile.supervisor" ;; supervisor-output) # Backward-compat alias: same as "supervisor". IS_FINAL_IMAGE=1 IMAGE_NAME="openshell/supervisor" DOCKER_TARGET="supervisor" + DOCKERFILE="deploy/docker/Dockerfile.supervisor" ;; *) echo "Error: unsupported target '${TARGET}'" >&2 @@ -125,6 +123,11 @@ case "${TARGET}" in ;; esac +if [[ ! -f "${DOCKERFILE}" ]]; then + echo "Error: Dockerfile not found: ${DOCKERFILE}" >&2 + exit 1 +fi + if [[ -n "${IMAGE_REGISTRY:-}" && "${IS_FINAL_IMAGE}" == "1" ]]; then IMAGE_NAME="${IMAGE_REGISTRY}/${IMAGE_NAME#openshell/}" fi diff --git a/tasks/scripts/stage-prebuilt-binaries.sh b/tasks/scripts/stage-prebuilt-binaries.sh index 21d97c472..d05c0cc75 100755 --- a/tasks/scripts/stage-prebuilt-binaries.sh +++ b/tasks/scripts/stage-prebuilt-binaries.sh @@ -21,9 +21,22 @@ normalize_arch() { } target_triple() { + local libc=${2:-gnu} case "$1" in - amd64) echo "x86_64-unknown-linux-gnu" ;; - arm64) echo "aarch64-unknown-linux-gnu" ;; + amd64) + if [[ "$libc" == "musl" ]]; then + echo "x86_64-unknown-linux-musl" + else + echo "x86_64-unknown-linux-gnu" + fi + ;; + arm64) + if [[ "$libc" == "musl" ]]; then + echo "aarch64-unknown-linux-musl" + else + echo "aarch64-unknown-linux-gnu" + fi + ;; *) echo "unsupported architecture: $1" >&2 exit 1 @@ -71,10 +84,10 @@ components_for_target() { echo "gateway" ;; sandbox|supervisor|supervisor-output) - echo "sandbox" + echo "supervisor" ;; all) - echo "gateway sandbox" + echo "gateway supervisor" ;; *) usage @@ -88,10 +101,12 @@ resolve_component() { gateway) crate=openshell-server binary=openshell-gateway + target_libc=gnu ;; - sandbox) + supervisor) crate=openshell-sandbox binary=openshell-sandbox + target_libc=musl ;; *) echo "unsupported binary component: $1" >&2 @@ -130,7 +145,7 @@ build_component_for_arch() { local current_host_arch resolve_component "$component" - target="$(target_triple "$arch")" + target="$(target_triple "$arch" "$target_libc")" stage="${ROOT}/deploy/docker/.build/prebuilt-binaries/${arch}" features="${EXTRA_CARGO_FEATURES:-openshell-core/dev-settings}" current_host_os="$(host_os)" From 8322e4fd001cb9c792ed40be4d4a4d50ca39c6db Mon Sep 17 00:00:00 2001 From: Miyoung Choi Date: Tue, 12 May 2026 16:36:15 -0700 Subject: [PATCH 042/157] docs: style fixes (#1341) * docs: style fixes * docs: drop observability section overview page and rename a section title * docs: title updates --- docs/get-started/quickstart.mdx | 2 +- .../tutorials/first-network-policy.mdx | 4 ++-- docs/get-started/tutorials/github-sandbox.mdx | 6 ++--- docs/get-started/tutorials/index.mdx | 2 +- .../tutorials/inference-ollama.mdx | 18 +++++++------- .../tutorials/local-inference-lmstudio.mdx | 8 +++---- docs/index.mdx | 4 ++-- docs/index.yml | 2 +- docs/kubernetes/access-control.mdx | 6 ++--- docs/kubernetes/ingress.mdx | 2 +- docs/kubernetes/managing-certificates.mdx | 4 ++-- docs/kubernetes/openshift.mdx | 6 ++--- docs/kubernetes/setup.mdx | 24 +++++++++---------- docs/observability/accessing-logs.mdx | 2 +- docs/observability/logging.mdx | 4 ++-- docs/observability/ocsf-json-export.mdx | 14 ++++++----- docs/observability/overview.mdx | 17 ------------- docs/reference/default-policy.mdx | 2 +- docs/reference/gateway-auth.mdx | 2 +- docs/reference/policy-schema.mdx | 6 ++--- docs/reference/sandbox-compute-drivers.mdx | 2 +- docs/sandboxes/inference-routing.mdx | 2 +- docs/sandboxes/manage-gateways.mdx | 2 +- docs/sandboxes/manage-providers.mdx | 2 +- docs/sandboxes/policies.mdx | 4 ++-- docs/security/best-practices.mdx | 12 +++++----- fern/fern.config.json | 2 +- 27 files changed, 73 insertions(+), 88 deletions(-) delete mode 100644 docs/observability/overview.mdx diff --git a/docs/get-started/quickstart.mdx b/docs/get-started/quickstart.mdx index fcfbc945b..9c40eb024 100644 --- a/docs/get-started/quickstart.mdx +++ b/docs/get-started/quickstart.mdx @@ -36,7 +36,7 @@ If you prefer [uv](https://docs.astral.sh/uv/): uv tool install -U openshell ``` -After installing the CLI, run `openshell --help` in your terminal to see the full CLI reference. +After installing the CLI, run `openshell --help` in your terminal to view the full CLI reference. You can also clone the [NVIDIA OpenShell GitHub repository](https://github.com/NVIDIA/OpenShell) and use the `/openshell-cli` skill to load the CLI reference into your agent. diff --git a/docs/get-started/tutorials/first-network-policy.mdx b/docs/get-started/tutorials/first-network-policy.mdx index effe7c445..3e4593308 100644 --- a/docs/get-started/tutorials/first-network-policy.mdx +++ b/docs/get-started/tutorials/first-network-policy.mdx @@ -4,7 +4,7 @@ title: "Write Your First Sandbox Network Policy" sidebar-title: "First Network Policy" slug: "get-started/tutorials/first-network-policy" -description: "See how OpenShell network policies work by creating a sandbox, observing default-deny in action, and applying a fine-grained L7 read-only rule." +description: "Learn how OpenShell network policies work by creating a sandbox, observing default-deny in action, and applying a fine-grained L7 read-only rule." keywords: "Generative AI, Cybersecurity, Tutorial, Policy, Network Policy, Sandbox, Security" --- @@ -117,7 +117,7 @@ network_policies: - { path: /usr/bin/curl } ``` -The `filesystem_policy`, `landlock`, and `process` sections preserve the default sandbox settings. This is required because `policy set` replaces the entire policy. The `network_policies` section is the key part: `curl` may make GET, HEAD, and OPTIONS requests to `api.github.com` over HTTPS. Everything else is denied. The proxy auto-detects TLS on HTTPS endpoints and terminates it to inspect each HTTP request and enforce the `read-only` access preset at the method level. +The `filesystem_policy`, `landlock`, and `process` sections preserve the default sandbox settings. This is required because `policy set` replaces the entire policy. The `network_policies` section is the key part: `curl` can make GET, HEAD, and OPTIONS requests to `api.github.com` over HTTPS. Everything else is denied. The proxy auto-detects TLS on HTTPS endpoints and terminates it to inspect each HTTP request and enforce the `read-only` access preset at the method level. Apply it: diff --git a/docs/get-started/tutorials/github-sandbox.mdx b/docs/get-started/tutorials/github-sandbox.mdx index 1f225fb60..7c76d4e41 100644 --- a/docs/get-started/tutorials/github-sandbox.mdx +++ b/docs/get-started/tutorials/github-sandbox.mdx @@ -11,7 +11,7 @@ keywords: "Generative AI, Cybersecurity, Tutorial, GitHub, Sandbox, Policy, Clau This tutorial walks through an iterative sandbox policy workflow. You launch a sandbox, ask Claude Code to push code to GitHub, and observe the default network policy denying the request. You then diagnose the denial from your machine and from inside the sandbox, apply a policy update, and verify that the policy update to the sandbox takes effect. -After completing this tutorial, you will have: +After completing this tutorial, you have: - A running sandbox with Claude Code that can push to a GitHub repository. - A custom network policy that grants GitHub access for a specific repository. @@ -131,7 +131,7 @@ The sandbox runs a proxy that enforces policies on outbound traffic. The `github_rest_api` policy allows GET requests (used to read the file) but blocks PUT/write requests to GitHub. This is a sandbox-level restriction, not a token issue. No matter what token you provide, pushes through the API -will be blocked until the policy is updated. +are blocked until you update the policy. Both perspectives confirm the same thing: the proxy is doing its job. The default policy is designed to be restrictive. To allow GitHub pushes, you need to update the network policy. @@ -162,7 +162,7 @@ Refer to the following policy example to compare with the generated policy befor The following YAML shows a complete policy that extends the [default policy](/reference/default-policy) with GitHub access for a single repository. Replace `` with your GitHub organization or username and `` with your repository name. -The `filesystem_policy`, `landlock`, and `process` sections are static. They are read once at sandbox creation and cannot be changed by a hot-reload. They are included here for completeness so the file is self-contained, but only the `network_policies` section takes effect when you apply this to a running sandbox. +The `filesystem_policy`, `landlock`, and `process` sections are static. OpenShell reads them at sandbox creation, and a hot reload cannot change them. They are included here for completeness so the file is self-contained, but only the `network_policies` section takes effect when you apply this to a running sandbox. ```yaml version: 1 diff --git a/docs/get-started/tutorials/index.mdx b/docs/get-started/tutorials/index.mdx index f5a79b543..c03e924f7 100644 --- a/docs/get-started/tutorials/index.mdx +++ b/docs/get-started/tutorials/index.mdx @@ -29,6 +29,6 @@ Route inference through Ollama using cloud-hosted or local models, and verify it -Route inference to a local LM Studio server via the OpenAI or Anthropic compatible APIs. +Route inference to a local LM Studio server using the OpenAI-compatible or Anthropic-compatible APIs. diff --git a/docs/get-started/tutorials/inference-ollama.mdx b/docs/get-started/tutorials/inference-ollama.mdx index 4a16eee01..4f46b847e 100644 --- a/docs/get-started/tutorials/inference-ollama.mdx +++ b/docs/get-started/tutorials/inference-ollama.mdx @@ -8,12 +8,12 @@ description: "Run local and cloud models inside an OpenShell sandbox using the O keywords: "Generative AI, Cybersecurity, Tutorial, Inference Routing, Ollama, Local Inference, Sandbox" --- -This tutorial covers two ways to use Ollama with OpenShell: +This tutorial covers two ways of running Ollama with OpenShell: -1. **Ollama sandbox (recommended)** — a self-contained sandbox with Ollama, Claude Code, and Codex pre-installed. One command to start. -2. **Host-level Ollama** — run Ollama on the gateway host and route sandbox inference to it. Useful when you want a single Ollama instance shared across multiple sandboxes. +1. Ollama sandbox. This is the recommended way to run Ollama. A self-contained sandbox with Ollama, Claude Code, and Codex pre-installed. One command starts it. +2. Host-level Ollama. This is an alternative way to run Ollama. Run Ollama on the gateway host and route sandbox inference to it. Use this option when you want a single Ollama instance shared across multiple sandboxes. -After completing this tutorial, you will know how to: +After completing this tutorial, you know how to: - Launch the Ollama community sandbox for a batteries-included experience. - Use `ollama launch` to start coding agents inside a sandbox. @@ -190,11 +190,11 @@ The response should be JSON from the model. Common issues and fixes: -- **Ollama not reachable from sandbox** — Ollama must be bound to `0.0.0.0`, not `127.0.0.1`. This applies to host-level Ollama only; the community sandbox handles this automatically. -- **`OPENAI_BASE_URL` wrong** — Use `http://host.openshell.internal:11434/v1`, not `localhost` or `127.0.0.1`. -- **Model not found** — Run `ollama ps` to confirm the model is loaded. Run `ollama pull ` if needed. -- **HTTPS vs HTTP** — Code inside sandboxes must call `https://inference.local`, not `http://`. -- **AMD GPU driver issues** — Ollama v0.18+ requires ROCm 7 drivers for AMD GPUs. Update your drivers if you see GPU detection failures. +- **Ollama not reachable from sandbox:** Ollama must be bound to `0.0.0.0`, not `127.0.0.1`. This applies to host-level Ollama only; the community sandbox handles this automatically. +- **`OPENAI_BASE_URL` wrong:** Use `http://host.openshell.internal:11434/v1`, not `localhost` or `127.0.0.1`. +- **Model not found:** Run `ollama ps` to confirm the model is loaded. Run `ollama pull ` if needed. +- **HTTPS instead of HTTP:** Code inside sandboxes must call `https://inference.local`, not `http://`. +- **AMD GPU driver issues:** Ollama v0.18+ requires ROCm 7 drivers for AMD GPUs. Update your drivers if you see GPU detection failures. Useful commands: diff --git a/docs/get-started/tutorials/local-inference-lmstudio.mdx b/docs/get-started/tutorials/local-inference-lmstudio.mdx index 2d1246459..7ce604009 100644 --- a/docs/get-started/tutorials/local-inference-lmstudio.mdx +++ b/docs/get-started/tutorials/local-inference-lmstudio.mdx @@ -15,7 +15,7 @@ The LM Studio server provides easy setup with both OpenAI and Anthropic compatib -This tutorial will cover: +This tutorial covers: - Expose a local inference server to OpenShell sandboxes. - Verify end-to-end inference from inside a sandbox. @@ -54,11 +54,11 @@ lms daemon up Start the LM Studio local server from the Developer tab, and verify the OpenAI-compatible endpoint is enabled. -LM Studio will listen to `127.0.0.1:1234` by default. For use with OpenShell, you'll need to configure LM Studio to listen on all interfaces (`0.0.0.0`). +LM Studio listens to `127.0.0.1:1234` by default. For use with OpenShell, configure LM Studio to listen on all interfaces (`0.0.0.0`). -If you're using the GUI, go to the Developer Tab, select Server Settings, then enable Serve on Local Network. +If you use the GUI, go to the Developer Tab, select Server Settings, then enable Serve on Local Network. -If you're using llmster in headless mode, run `lms server start --bind 0.0.0.0`. +If you use llmster in headless mode, run `lms server start --bind 0.0.0.0`. ## Test with a small model diff --git a/docs/index.mdx b/docs/index.mdx index 4d358a827..91aba803d 100644 --- a/docs/index.mdx +++ b/docs/index.mdx @@ -38,7 +38,7 @@ uncontrolled network activity. Install OpenShell and create your first sandbox in two commands. -{/*Terminal demo styles live in fern/main.css — inline