From 887aaa80e6e89c4e0326574bad0b3474ea97570f Mon Sep 17 00:00:00 2001 From: Max Dubrinsky Date: Thu, 28 May 2026 10:59:16 -0400 Subject: [PATCH 1/6] docs(rfc): add RFC 0005 for shared SDK core and TypeScript binding Captures the design behind extracting the shared client core out of openshell-cli into a standalone openshell-sdk crate, plus the napi-rs TypeScript binding (openshell-sdk-node, published as @openshell/sdk). Covers motivation (CLI/TUI/embedders sharing one transport, OIDC, and edge-tunnel implementation), surface area, error model, and the path for future language bindings. --- .../README.md | 179 ++++++++++++++++++ 1 file changed, 179 insertions(+) create mode 100644 rfc/0005-shared-sdk-core-and-ts-binding/README.md diff --git a/rfc/0005-shared-sdk-core-and-ts-binding/README.md b/rfc/0005-shared-sdk-core-and-ts-binding/README.md new file mode 100644 index 000000000..09d9b82ff --- /dev/null +++ b/rfc/0005-shared-sdk-core-and-ts-binding/README.md @@ -0,0 +1,179 @@ +--- +authors: + - "@mdubrinsky" +state: review +links: + - https://linear.app/nvidia/issue/OSGH-110/python-and-typescript-sdk-support +--- + +# RFC 0005 - Shared Rust SDK core and TypeScript binding + +## Summary + +Extract a new `openshell-sdk` Rust crate from gRPC client plumbing that today lives in `openshell-cli`, and ship a TypeScript SDK (`@openshell/sdk`) as a [napi-rs](https://napi.rs) wrapper over that crate. Refactor `openshell-cli` to consume `openshell-sdk` so the CLI and the TS SDK share a single transport, auth, and error implementation. The pure-Python SDK at `python/openshell/` stays as-is for this RFC; migrating it onto the shared core is deferred to a follow-up. + +## Motivation + +OpenShell currently has one programmable surface (Python) and a CLI. 
The Python SDK is a hand-written gRPC client in `python/openshell/sandbox.py` that duplicates concerns already implemented in Rust (`openshell-cli/src/tls.rs` and `openshell-cli/src/oidc_auth.rs`): + +- TLS material loading, mTLS channel setup +- Edge-auth bearer token attachment +- OIDC token refresh +- Plaintext vs TLS transport selection + +Adding a TypeScript SDK by hand-writing a third gRPC client would extend the duplication. Three reasons to share a Rust core instead: + +1. **Multi-language support without re-implementing the transport per language.** We expect TS/Node users (TS-authored agents, web tooling). A shared transport layer keeps retry, auth refresh, and streaming consistent across bindings. +2. **The Rust transport already runs in production.** The CLI exercises every auth mode today. +3. **Establishes the pattern for other-language bindings.** If this works for TS, the same crate can later back a PyO3 binding and replace the pure-Python SDK. + +## What exists today + +- **Python SDK.** `python/openshell/sandbox.py`, hand-written gRPC against the existing protos. +- **CLI transport stack.** Full set of transport/auth modes implemented in `openshell-cli/src/tls.rs`, `openshell-cli/src/oidc_auth.rs`, and `openshell-cli/src/edge_tunnel.rs`. Runs in production today. + +## Non-goals + +- **Replacing the pure-Python SDK.** That migration is a separate, larger decision (API parity, deprecation window, packaging). This RFC keeps Python on its current pure-Python stack and only ensures the shared core is shaped so a future PyO3 wrapper is feasible. +- **gRPC contract changes.** The SDK is a client of the existing `proto/openshell.proto`, `proto/sandbox.proto`, `proto/inference.proto`. No service or message changes. +- **Browser / WebAssembly support.** napi-rs targets Node only. A browser SDK is a separate future RFC. 
+- **Bundling the `openshell` CLI binary inside the npm package.** Unlike the Python wheel (which uses maturin's `bindings = "bin"` to bundle the CLI), the TS SDK is gRPC-only. CLI installation stays a separate concern. +- **Streaming `exec` in the initial slice.** Tracked separately. + +## Proposal + +### New and changed crates + +``` +crates/ + openshell-sdk/ NEW. Pure Rust async client library. No FFI, no CLI deps. + openshell-sdk-node/ NEW. napi-rs wrapper over openshell-sdk. Ships as @openshell/sdk. + openshell-cli/ REFACTORED. Channel/auth code moves out; CLI consumes openshell-sdk. + openshell-core/ UNCHANGED. Still owns proto codegen; openshell-sdk depends on it. +``` + +### `openshell-sdk` surface + +```rust +pub struct ClientConfig { + pub gateway: String, // "https://..." or "http://..." + pub tls: Option, // required for mTLS, ignored for plaintext + pub auth: Option, // bearer token or OIDC refresh closure + pub timeout: Option, // default: None (no client-side timeout) +} + +pub enum AuthConfig { + Bearer(String), + Oidc { token: String, refresh: Arc }, +} + +pub struct OpenShellClient { /* tonic Channel + interceptor */ } + +impl OpenShellClient { + pub async fn connect(config: ClientConfig) -> Result; + + pub async fn health(&self) -> Result; + pub async fn create_sandbox(&self, spec: SandboxSpec) -> Result; + pub async fn get_sandbox(&self, name: &str) -> Result; + pub async fn list_sandboxes(&self, opts: ListOptions) -> Result, SdkError>; + pub async fn delete_sandbox(&self, name: &str) -> Result; + pub async fn wait_ready(&self, name: &str, timeout: Duration) -> Result; + pub async fn wait_deleted(&self, name: &str, timeout: Duration) -> Result<(), SdkError>; + pub async fn exec(&self, name: &str, cmd: &[String], opts: ExecOptions) -> Result; +} +``` + +### `openshell-sdk-node` surface + +A thin napi-rs wrapper exposing the same surface as JS classes / objects. 
Idiomatic camelCase (`createSandbox`, `waitReady`) is generated automatically from snake_case Rust by napi-derive. + +### CLI refactor + +Transport mechanics move out of `openshell-cli` and into `openshell-sdk`: gRPC channel construction, TLS material handling, request interceptors, and the Cloudflare Access tunnel. The CLI keeps everything user-facing — gateway-name resolution, default-path lookups, and the OIDC browser flow. The SDK never sees a browser; it consumes a `Refresh` trait that the CLI implements. + +### Transport and auth modes + +MVP must support the same five transport/auth modes the CLI exercises today, so a CLI user can move to the SDK without losing connectivity options: + +- Plaintext (local development) +- mTLS (self-deployed gateways with client certs) +- OIDC bearer over HTTPS (gateways behind an OAuth2/OIDC IdP) +- Cloudflare Access tunnel (hosted gateways) +- Insecure TLS (development/debug; certificate verification disabled) + +### Current leanings + +| Decision | Choice | Rationale | +| ------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| **Binding tool for TS** | napi-rs v3 | Required for `ThreadsafeFunction>` + `call_async`. Considered: UniFFI (no stable TS target), Diplomat (smaller community, JS support nascent), wit-bindgen (Node packaging not yet ergonomic). 
| +| **Tokio runtime ownership at FFI boundary** | napi's ambient tokio runtime is available only inside `async fn` entry points. Every user-facing napi function that needs the runtime must be `async`. No `Handle` plumbing in `ClientConfig`. | Sync `#[napi] fn` runs on the JS thread with no reactor; `tokio::spawn` from sync napi context panics with "no reactor running". | +| **API shape** | Async-only, no blocking facade | Tonic is async-native; a blocking facade would require `block_on` plumbing and confuse the JS Promise contract. Callers needing sync can wrap with `tokio::runtime::Runtime::block_on` themselves. | +| **Error type** | `thiserror` enum in `openshell-sdk`, mapped to napi `Error` with a `code` field for TS discriminated-union ergonomics | Better than the Python SDK's single `SandboxError(RuntimeError)`. Lets TS consumers `switch` on error kind. | +| **Retry policy** | Per-call configurable; default = no retry | Matches the Python SDK. Advanced users opt in. | +| **OIDC refresh trait** | SDK accepts a `Refresh` trait with a domain error type, not `napi::Error`. CLI provides the browser-flow impl. Node binding wraps a JS callback as a `ThreadsafeFunction<(), Promise>`. | Keeps `openshell-sdk` napi-free. | +| **Single-flight refresh coalescing** | Lives in `openshell-sdk` core, not in the binding. | napi does not provide it; standard OIDC pattern needs it (one refresh in flight, all waiters share the result). | +| **OIDC refresh cancellation** | Rust-side future drop does not propagate to JS. In-flight JS refresher promise runs to completion; SDK ignores late-arriving values. | Trait does not need cooperative cancellation. | +| **Streaming pattern** | Iterator-style: a napi class with an async `next()` method, drop-based cancellation, and a thin TS shim layering `for await` on top. Not native AsyncGenerator. | napi-rs v3.8 has no native AsyncGenerator return type. | +| **Auth token file loading** | NOT in `openshell-sdk` directly. 
Callers pass an explicit token. A separate convenience helper in `openshell-cli` (or a thin helper crate) handles the `~/.config/openshell/gateways//edge_token` lookup. | Keeps `openshell-sdk` free of filesystem access. Usable as a library without CLI assumptions. | +| **SDK layering scope (MVP)** | Sandbox-focused. High-level methods cover health, sandbox CRUD, waits, and non-streaming exec. A `raw` module re-exports generated tonic clients as an escape hatch. Inference, providers, policy, logs, settings, SSH, forwarding, and completions are out of MVP and deferred. | The `raw` escape hatch lets callers reach RPCs the high-level surface doesn't yet cover. | +| **TypeScript API model** | Curated SDK types, not raw proto shapes. Enum-valued fields use string literals (e.g., `"Pending"`), not numeric proto enums. Captured high-level types: `SandboxSpec`, `SandboxRef`, `Health`, `ListOptions`, `ExecOptions`. | TS DX is better with discriminated string unions than with numeric proto enums. | + +## Implementation plan + +Phases ordered by dependency. No time estimates. This RFC establishes direction; detailed contracts (the `Refresh` trait shape, error codes, exec semantics) settle at implementation time. + +### Phase 1 — Refactor and extract `openshell-sdk` + +- Create `crates/openshell-sdk/` with transport, auth, error, and edge-tunnel modules. +- Execute the CLI refactor described above. +- Exit criteria: all existing `mise run test` and `mise run e2e` paths pass. No new SDK consumers yet. + +### Phase 2 — High-level SDK methods + +- Implement `health`, `create_sandbox`, `get_sandbox`, `list_sandboxes`, `delete_sandbox`, `wait_ready`, `wait_deleted`, non-streaming `exec`. +- Unit tests with a mock tonic server. +- Settle the `Refresh` trait contract: single-flight semantics, proactive vs reactive trigger, deadline, retry-after-refresh-failure, terminal-failure signalling. 
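
The single-flight semantics named in the bullet above (one refresh in flight, all waiters share the result) can be sketched as follows. This is a minimal illustration under stated assumptions, not the contract the RFC settles on: the `Refresh` trait shape, the `SingleFlight` type, and the `demo_concurrent_refresh` helper are all hypothetical names, and the sketch uses blocking `std` threads with a `Condvar` so that it runs without dependencies, whereas the real SDK is async on tokio.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Condvar, Mutex};
use std::thread;
use std::time::Duration;

/// Hypothetical stand-in for the RFC's `Refresh` trait. The real trait is
/// async and its error type is still to be settled in Phase 2.
pub trait Refresh: Send + Sync {
    fn refresh(&self) -> Result<String, String>;
}

struct Inner {
    in_flight: bool,
    /// Bumped each time a refresh completes, so waiters can detect completion.
    generation: u64,
    last: Option<Result<String, String>>,
}

/// Coalesces concurrent refresh requests into one call to the refresher.
pub struct SingleFlight<R: Refresh> {
    refresher: R,
    inner: Mutex<Inner>,
    done: Condvar,
}

impl<R: Refresh> SingleFlight<R> {
    pub fn new(refresher: R) -> Self {
        Self {
            refresher,
            inner: Mutex::new(Inner { in_flight: false, generation: 0, last: None }),
            done: Condvar::new(),
        }
    }

    pub fn get(&self) -> Result<String, String> {
        let mut guard = self.inner.lock().unwrap();
        if guard.in_flight {
            // Follower: wait until the in-flight refresh finishes, then share
            // its result instead of starting a second refresh.
            let waiting_for = guard.generation;
            while guard.generation == waiting_for {
                guard = self.done.wait(guard).unwrap();
            }
            return guard.last.clone().expect("leader always stores a result");
        }
        // Leader: mark the refresh as in flight and release the lock while
        // the (potentially slow) refresher runs.
        guard.in_flight = true;
        drop(guard);
        let result = self.refresher.refresh();

        let mut guard = self.inner.lock().unwrap();
        guard.in_flight = false;
        guard.generation += 1;
        guard.last = Some(result.clone());
        drop(guard);
        self.done.notify_all();
        result
    }
}

/// Demo refresher that counts invocations and simulates a slow IdP round trip.
struct CountingRefresher {
    calls: Arc<AtomicUsize>,
}

impl Refresh for CountingRefresher {
    fn refresh(&self) -> Result<String, String> {
        let n = self.calls.fetch_add(1, Ordering::SeqCst) + 1;
        thread::sleep(Duration::from_millis(300));
        Ok(format!("token-{n}"))
    }
}

/// Spawns `callers` threads that all request a token at roughly the same time.
/// Returns the tokens they observed and how often the refresher actually ran.
pub fn demo_concurrent_refresh(callers: usize) -> (Vec<String>, usize) {
    let calls = Arc::new(AtomicUsize::new(0));
    let sf = Arc::new(SingleFlight::new(CountingRefresher { calls: Arc::clone(&calls) }));
    let handles: Vec<_> = (0..callers)
        .map(|_| {
            let sf = Arc::clone(&sf);
            thread::spawn(move || sf.get().expect("demo refresher never fails"))
        })
        .collect();
    let tokens = handles.into_iter().map(|h| h.join().unwrap()).collect();
    (tokens, calls.load(Ordering::SeqCst))
}
```

Calling `demo_concurrent_refresh(8)` should show every caller holding the same token while the refresher ran once, provided all callers arrive during the simulated 300 ms refresh. The sketch deliberately leaves out the parts the RFC lists as unsettled (proactive vs reactive trigger, deadline, retry after failure), and it does not recover if the refresher panics while a refresh is in flight.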
+ +### Phase 3 — `openshell-sdk-node` napi binding + +- Build the JS-facing client surface over `openshell-sdk`. +- Wire the OIDC refresh callback path between Rust and JS. +- Map SDK errors to JS errors with a discriminable `code` field. +- Resolve the tunnel-vs-refresh interaction with one targeted test (does the CF tunnel re-handshake on bearer rotation, swap headers in place, or tear down and rebuild?). +- Smoke test against a plaintext local gateway. + +## Migration and compatibility + +- **CLI surface preserved.** The phase 1 refactor does not change `openshell-cli` flags, behavior, or output. Existing scripts continue to work. +- **gRPC contract unchanged** (see Non-goals). +- **Python SDK frozen.** The pure-Python SDK is unaffected by this RFC. +- **Alpha contract.** `@openshell/sdk` ships under `0.0.0-alpha.x` until the surface stabilizes; no semver guarantee before 1.0. + +## Risks + +- **CLI regression during phase 1.** Mitigation: extraction PR ships first with no SDK consumers, with the existing CLI tests as the regression surface. +- **napi-rs prebuilt binary CI complexity.** Six-target build matrices break in interesting ways (musl static linking, macOS codesigning, cross-compilation for aarch64). The v3 toolchain has only been exercised on darwin-arm64 so far; the full cross-platform matrix is unproven. Mitigation: lean on napi-rs's published workflow template; treat the first publish as the trigger for completing the build matrix. +- **Python/TS SDK behavior drift.** While Python stays pure-Python over gRPC, behavior (timeouts, retry, error mapping) may drift from the Rust SDK. Mitigation: keep the Python SDK frozen during this RFC's implementation; track parity as a precondition for a future Python-on-shared-core RFC. +- **Refresh contract details.** The FFI mechanism is settled. Still unspecified: proactive vs reactive trigger, deadline, retry-after-refresh-failure, terminal-failure signalling. Mitigation: design alongside phase 2. 
+- **Tunnel-vs-refresh interaction.** The CF tunnel captures bearer headers at connection time; bearer rotation mid-session is not yet specified. Mitigation: settle in phase 3 with one targeted test before npm alpha publish. + +## Alternatives + +- **Pure-TS gRPC client (e.g., `@connectrpc/connect`, `ts-proto`).** Cheaper and faster initially, no shared runtime. Loses all the shared-core benefits (auth refresh, retry, error taxonomy) and locks us into duplicating logic per language. Reasonable if the project decided the shared-core direction is overkill; this RFC argues it's not. +- **TS calls Python via subprocess or IPC.** Rejected — terrible DX, forces a Python runtime on Node consumers. +- **UniFFI for both Python and TS.** UniFFI's TS target is not yet stable. Re-evaluate once it lands. +- **Diplomat (Rust → JS/Dart/Kotlin).** Smaller community, JS support less proven. +- **`wit-bindgen` + WebAssembly component model.** The likely long-term target once Node packaging of wasm components matures. +- **Do nothing; tell TS users to use the gRPC stubs directly.** Possible, but leaves every TS consumer to roll their own wrapper. + +## Prior art + +- **Polars** — Rust core, PyO3 for Python, napi-rs for Node. Same pattern. +- **swc** and **Turbopack** — large napi-rs projects in the JS tooling ecosystem, demonstrate the publishing/CI patterns. +- **Bitwarden SDK** — Rust core with UniFFI bindings; useful reference for Refresh-trait-style auth design even though we're not using UniFFI. +- **1Password Connect SDK** — multi-language SDK over a shared gRPC contract, same design choice in a different domain. + +## Open questions + +- **Retry policy shape.** Builder on `ClientConfig` (declarative) or `tower::Layer` (composable)? Composable is more flexible; declarative is friendlier for napi/PyO3 consumers who can't construct a `Layer`. +- **Should `OpenShellClient::from_gateway_name(name)` exist in `openshell-sdk` at all,** or only in a CLI-config helper crate? 
Tradeoff between ergonomics and keeping `openshell-sdk` filesystem-free. From 4a2caba7be7efea5131910e7cdabf7e77c04033b Mon Sep 17 00:00:00 2001 From: Max Dubrinsky Date: Thu, 28 May 2026 11:00:45 -0400 Subject: [PATCH 2/6] feat(sdk): extract openshell-sdk crate Per RFC 0005, lift the gRPC client, TLS, OIDC, edge-tunnel, and refresh plumbing out of openshell-cli into a new openshell-sdk crate. CLI and TUI now consume the SDK; openshell-cli/src/{tls.rs,oidc_auth.rs} shrink to thin wrappers over the SDK's transport and OIDC modules. - New crate openshell-sdk exposes a typed gRPC client, TLS resolver, OidcRefresher with single-flight semantics, edge-tunnel dialer, and a Sandbox-API surface that mirrors the existing CLI behavior. - crates/openshell-core/src/auth.rs moves into the SDK as auth.rs. - crates/openshell-cli/src/edge_tunnel.rs moves into the SDK as edge_tunnel.rs. Tests: 3 unit + 10 mock-gateway integration tests in openshell-sdk. --- Cargo.lock | 30 +- crates/openshell-cli/Cargo.toml | 6 +- crates/openshell-cli/src/completers.rs | 2 +- crates/openshell-cli/src/lib.rs | 1 - crates/openshell-cli/src/main.rs | 2 +- crates/openshell-cli/src/oidc_auth.rs | 87 +- crates/openshell-cli/src/tls.rs | 306 +++---- crates/openshell-core/src/lib.rs | 1 - crates/openshell-sdk/Cargo.toml | 38 + .../src/auth.rs | 20 +- crates/openshell-sdk/src/client.rs | 331 ++++++++ crates/openshell-sdk/src/config.rs | 81 ++ .../src/edge_tunnel.rs | 19 +- crates/openshell-sdk/src/error.rs | 138 ++++ crates/openshell-sdk/src/lib.rs | 51 ++ crates/openshell-sdk/src/oidc.rs | 146 ++++ crates/openshell-sdk/src/raw.rs | 45 ++ crates/openshell-sdk/src/refresh.rs | 301 +++++++ crates/openshell-sdk/src/transport.rs | 253 ++++++ crates/openshell-sdk/src/types.rs | 169 ++++ crates/openshell-sdk/tests/client_mock.rs | 764 ++++++++++++++++++ crates/openshell-tui/Cargo.toml | 1 + crates/openshell-tui/src/app.rs | 2 +- crates/openshell-tui/src/lib.rs | 2 +- 24 files changed, 2493 insertions(+), 303 
deletions(-) create mode 100644 crates/openshell-sdk/Cargo.toml rename crates/{openshell-core => openshell-sdk}/src/auth.rs (78%) create mode 100644 crates/openshell-sdk/src/client.rs create mode 100644 crates/openshell-sdk/src/config.rs rename crates/{openshell-cli => openshell-sdk}/src/edge_tunnel.rs (92%) create mode 100644 crates/openshell-sdk/src/error.rs create mode 100644 crates/openshell-sdk/src/lib.rs create mode 100644 crates/openshell-sdk/src/oidc.rs create mode 100644 crates/openshell-sdk/src/raw.rs create mode 100644 crates/openshell-sdk/src/refresh.rs create mode 100644 crates/openshell-sdk/src/transport.rs create mode 100644 crates/openshell-sdk/src/types.rs create mode 100644 crates/openshell-sdk/tests/client_mock.rs diff --git a/Cargo.lock b/Cargo.lock index 92bc18499..739f47568 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3428,6 +3428,7 @@ dependencies = [ "openshell-policy", "openshell-prover", "openshell-providers", + "openshell-sdk", "openshell-tui", "owo-colors", "prost-types", @@ -3443,11 +3444,8 @@ dependencies = [ "tempfile", "thiserror 2.0.18", "tokio", - "tokio-rustls", "tokio-stream", - "tokio-tungstenite 0.26.2", "tonic", - "tower 0.5.3", "tracing", "tracing-subscriber", "url", @@ -3684,6 +3682,31 @@ dependencies = [ "webpki-roots 1.0.7", ] +[[package]] +name = "openshell-sdk" +version = "0.0.0" +dependencies = [ + "async-trait", + "futures", + "hyper", + "hyper-util", + "miette", + "oauth2", + "openshell-core", + "reqwest 0.12.28", + "rustls", + "rustls-pemfile", + "serde", + "thiserror 2.0.18", + "tokio", + "tokio-rustls", + "tokio-stream", + "tokio-tungstenite 0.26.2", + "tonic", + "tower 0.5.3", + "tracing", +] + [[package]] name = "openshell-server" version = "0.0.0" @@ -3773,6 +3796,7 @@ dependencies = [ "openshell-core", "openshell-policy", "openshell-providers", + "openshell-sdk", "owo-colors", "ratatui", "serde", diff --git a/crates/openshell-cli/Cargo.toml b/crates/openshell-cli/Cargo.toml index b69a9629b..7df60b6dd 100644 --- 
a/crates/openshell-cli/Cargo.toml +++ b/crates/openshell-cli/Cargo.toml @@ -20,6 +20,7 @@ openshell-core = { path = "../openshell-core" } openshell-policy = { path = "../openshell-policy" } openshell-providers = { path = "../openshell-providers" } openshell-prover = { path = "../openshell-prover" } +openshell-sdk = { path = "../openshell-sdk" } openshell-tui = { path = "../openshell-tui" } serde = { workspace = true } serde_json = { workspace = true } @@ -49,8 +50,6 @@ hyper-util = { workspace = true } hyper-rustls = { version = "0.27", default-features = false, features = ["native-tokio", "http1", "http2", "tls12", "logging", "ring", "webpki-tokio"] } rustls = { workspace = true } rustls-pemfile = { workspace = true } -tokio-rustls = { workspace = true } -tower = { workspace = true } reqwest = { workspace = true } # Error handling @@ -66,9 +65,6 @@ tempfile = "3" oauth2 = "5" base64 = { workspace = true } -# WebSocket (Cloudflare tunnel proxy) -tokio-tungstenite = { workspace = true } - # Streams futures = { workspace = true } tokio-stream = { workspace = true } diff --git a/crates/openshell-cli/src/completers.rs b/crates/openshell-cli/src/completers.rs index a421b418a..ff8713dcb 100644 --- a/crates/openshell-cli/src/completers.rs +++ b/crates/openshell-cli/src/completers.rs @@ -9,9 +9,9 @@ use openshell_bootstrap::edge_token::load_edge_token; use openshell_bootstrap::oidc_token::{is_token_expired, load_oidc_token, store_oidc_token}; use openshell_bootstrap::{list_gateways, load_active_gateway, load_gateway_metadata}; use openshell_core::ObjectName; -use openshell_core::auth::EdgeAuthInterceptor; use openshell_core::proto::open_shell_client::OpenShellClient; use openshell_core::proto::{ListProvidersRequest, ListSandboxesRequest}; +use openshell_sdk::EdgeAuthInterceptor; use tonic::service::interceptor::InterceptedService; use tonic::transport::Channel; diff --git a/crates/openshell-cli/src/lib.rs b/crates/openshell-cli/src/lib.rs index 84a87acd2..bdf3fa092 100644 
--- a/crates/openshell-cli/src/lib.rs +++ b/crates/openshell-cli/src/lib.rs @@ -10,7 +10,6 @@ pub(crate) static TEST_ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(() pub mod auth; pub mod completers; -pub mod edge_tunnel; pub mod oidc_auth; pub(crate) mod policy_update; pub mod run; diff --git a/crates/openshell-cli/src/main.rs b/crates/openshell-cli/src/main.rs index 917c8faa1..e4ea3a0b4 100644 --- a/crates/openshell-cli/src/main.rs +++ b/crates/openshell-cli/src/main.rs @@ -2899,7 +2899,7 @@ async fn main() -> Result<()> { let mut tls = tls.with_gateway_name(&ctx.name); apply_auth(&mut tls, &ctx.name); let channel = openshell_cli::tls::build_channel(&ctx.endpoint, &tls).await?; - let interceptor = openshell_core::auth::EdgeAuthInterceptor::new( + let interceptor = openshell_sdk::EdgeAuthInterceptor::new( tls.oidc_token.as_deref(), tls.edge_token.as_deref(), )?; diff --git a/crates/openshell-cli/src/oidc_auth.rs b/crates/openshell-cli/src/oidc_auth.rs index 379a53112..bdc30e902 100644 --- a/crates/openshell-cli/src/oidc_auth.rs +++ b/crates/openshell-cli/src/oidc_auth.rs @@ -17,10 +17,10 @@ use miette::{IntoDiagnostic, Result}; use oauth2::basic::BasicClient; use oauth2::{ AuthType, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, PkceCodeChallenge, - RedirectUrl, RefreshToken, Scope, TokenResponse, TokenUrl, + RedirectUrl, Scope, TokenResponse, TokenUrl, }; use openshell_bootstrap::oidc_token::OidcTokenBundle; -use serde::Deserialize; +use openshell_sdk::oidc::{RefreshTokenInput, discover, http_client, refresh_token}; use std::convert::Infallible; use std::sync::{Arc, Mutex}; use std::time::Duration; @@ -30,50 +30,6 @@ use tracing::debug; const AUTH_TIMEOUT: Duration = Duration::from_secs(120); -/// OIDC discovery document (subset of fields we need). 
-#[derive(Debug, Deserialize)] -struct OidcDiscovery { - issuer: String, - authorization_endpoint: String, - token_endpoint: String, -} - -/// Discover OIDC endpoints from the issuer's well-known configuration. -/// -/// Validates that the discovery document's `issuer` field matches the -/// configured issuer URL to prevent SSRF or misdirection. -async fn discover(issuer: &str, insecure: bool) -> Result { - let normalized_issuer = issuer.trim_end_matches('/'); - let url = format!("{normalized_issuer}/.well-known/openid-configuration"); - let client = http_client(insecure); - let resp: OidcDiscovery = client - .get(&url) - .send() - .await - .into_diagnostic()? - .json() - .await - .into_diagnostic()?; - - let discovered_issuer = resp.issuer.trim_end_matches('/'); - if discovered_issuer != normalized_issuer { - return Err(miette::miette!( - "OIDC discovery issuer mismatch: expected '{}', got '{}'", - normalized_issuer, - discovered_issuer - )); - } - Ok(resp) -} - -fn http_client(insecure: bool) -> reqwest::Client { - let mut builder = reqwest::ClientBuilder::new().redirect(reqwest::redirect::Policy::none()); - if insecure { - builder = builder.danger_accept_invalid_certs(true); - } - builder.build().expect("failed to build HTTP client") -} - fn build_scopes(scopes: Option<&str>) -> Vec { let mut result = vec![Scope::new("openid".to_string())]; if let Some(s) = scopes { @@ -227,36 +183,33 @@ pub async fn oidc_client_credentials_flow( /// Refresh an OIDC token using the `refresh_token` grant. /// -/// Preserves the existing refresh token if the server does not return a new -/// one (per OAuth 2.0 spec, the refresh response may omit `refresh_token`). +/// Wraps [`openshell_sdk::oidc::refresh_token`] with the CLI's +/// [`OidcTokenBundle`] storage shape. Preserves the existing refresh +/// token when the server omits one (per OAuth 2.0 the refresh response +/// is allowed to leave `refresh_token` out). 
pub async fn oidc_refresh_token( bundle: &OidcTokenBundle, insecure: bool, ) -> Result { - let refresh_token = bundle.refresh_token.as_deref().ok_or_else(|| { + let refresh = bundle.refresh_token.as_deref().ok_or_else(|| { miette::miette!( "no refresh token available — re-authenticate with: openshell gateway login" ) })?; - let discovery = discover(&bundle.issuer, insecure).await?; - - let client = BasicClient::new(ClientId::new(bundle.client_id.clone())) - .set_token_uri(TokenUrl::new(discovery.token_endpoint).into_diagnostic()?); - - let http = http_client(insecure); - let token_response = client - .exchange_refresh_token(&RefreshToken::new(refresh_token.to_string())) - .request_async(&http) - .await - .map_err(|e| miette::miette!("token refresh failed: {e}"))?; - - let mut refreshed = - bundle_from_oauth2_response(&token_response, &bundle.issuer, &bundle.client_id); - if refreshed.refresh_token.is_none() { - refreshed.refresh_token.clone_from(&bundle.refresh_token); - } - Ok(refreshed) + let input = + RefreshTokenInput::new(refresh, &bundle.issuer, &bundle.client_id).with_insecure(insecure); + let output = refresh_token(&input).await.into_diagnostic()?; + + Ok(OidcTokenBundle { + access_token: output.access_token, + refresh_token: output + .refresh_token + .or_else(|| bundle.refresh_token.clone()), + expires_at: output.expires_at, + issuer: bundle.issuer.clone(), + client_id: bundle.client_id.clone(), + }) } /// Ensure we have a valid OIDC token for the given gateway, refreshing if needed. 
diff --git a/crates/openshell-cli/src/tls.rs b/crates/openshell-cli/src/tls.rs index 10df401a5..89097b0de 100644 --- a/crates/openshell-cli/src/tls.rs +++ b/crates/openshell-cli/src/tls.rs @@ -2,25 +2,26 @@ // SPDX-License-Identifier: Apache-2.0 use miette::{IntoDiagnostic, Result, WrapErr}; -use openshell_core::auth::EdgeAuthInterceptor; use openshell_core::proto::inference_client::InferenceClient; use openshell_core::proto::open_shell_client::OpenShellClient; +use openshell_sdk::EdgeAuthInterceptor; use rustls::{ RootCertStore, - client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}, - pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime}, + pki_types::{CertificateDer, PrivateKeyDer}, }; -use std::collections::HashMap; -use std::future::Future; use std::io::Cursor; -use std::net::SocketAddr; use std::path::PathBuf; -use std::sync::OnceLock; use std::time::Duration; -use tokio::sync::Mutex; use tonic::service::interceptor::InterceptedService; use tonic::transport::{Certificate, Channel, ClientTlsConfig, Endpoint, Identity}; -use tracing::debug; + +// `build_insecure_rustls_config` lives in the SDK (used by the SDK's +// transport stack and by CLI's HTTP health check). The other former +// `tls.rs` helpers (`build_rustls_config`, `build_tonic_tls_config`, +// `load_private_key`, `TlsMaterials`) were tied to mTLS and now live +// below as CLI-private legacy code — they will go away when mTLS is +// retired as an auth method. +pub use openshell_sdk::transport::build_insecure_rustls_config; /// Concrete gRPC client type used by all commands. pub type GrpcClient = OpenShellClient>; @@ -104,12 +105,44 @@ impl TlsOptions { pub fn is_bearer_auth(&self) -> bool { self.edge_token.is_some() || self.oidc_token.is_some() } -} -pub struct TlsMaterials { - ca: Vec, - cert: Vec, - key: Vec, + /// Returns `true` when this `TlsOptions` carries a full mTLS client + /// identity (cert + key on disk). 
Used by [`build_channel`] to route + /// mTLS-authenticated gateways through the legacy inline path. + pub fn has_mtls_identity(&self, server: &str) -> bool { + let resolved = self.with_default_paths(server); + resolved.cert.as_ref().is_some_and(|p| p.exists()) + && resolved.key.as_ref().is_some_and(|p| p.exists()) + } + + /// Convert this CLI-side `TlsOptions` into an SDK [`openshell_sdk::ClientConfig`] + /// for non-mTLS gateways. + /// + /// Reads the CA cert from disk if a path resolves; a missing file is + /// non-fatal and falls back to system roots (matches today's OIDC + /// fallback behavior). Maps tokens to [`openshell_sdk::AuthConfig`] + /// with OIDC taking precedence over `EdgeJwt` when both are set. + /// + /// mTLS materials are intentionally not carried through; gateways + /// requiring client certificates are dispatched to the legacy inline + /// path in [`build_channel`] before this conversion is reached. + pub fn to_client_config(&self, server: &str) -> openshell_sdk::ClientConfig { + let resolved = self.with_default_paths(server); + let ca_cert = resolved + .ca + .as_ref() + .and_then(|ca_path| std::fs::read(ca_path).ok()); + let auth = match (&resolved.oidc_token, &resolved.edge_token) { + (Some(token), _) => Some(openshell_sdk::AuthConfig::Oidc(token.clone())), + (None, Some(token)) => Some(openshell_sdk::AuthConfig::EdgeJwt(token.clone())), + (None, None) => None, + }; + let mut config = openshell_sdk::ClientConfig::new(server); + config.ca_cert = ca_cert; + config.auth = auth; + config.insecure_skip_verify = resolved.gateway_insecure; + config + } } /// Resolve the TLS cert directory for a known gateway name. @@ -163,6 +196,20 @@ fn xdg_config_dir() -> Result { openshell_core::paths::xdg_config_dir() } +// ── Legacy mTLS path ───────────────────────────────────────────────── +// Everything in this section supports gateways that authenticate clients +// with an mTLS certificate. 
mTLS is being retired as an auth method, and +// the SDK does not speak it. Until product removes mTLS support, these +// helpers stay in CLI for the `else { full mTLS }` branch of +// `build_channel` and the matching branch of `http_health_check`. + +/// In-memory mTLS materials read from disk by [`require_tls_materials`]. +pub struct TlsMaterials { + pub ca: Vec, + pub cert: Vec, + pub key: Vec, +} + pub fn require_tls_materials(server: &str, tls: &TlsOptions) -> Result { let resolved = tls.with_default_paths(server); let default_hint = default_tls_dir(server).map_or_else(String::new, |dir| { @@ -192,6 +239,7 @@ pub fn require_tls_materials(server: &str, tls: &TlsOptions) -> Result Result> { let mut cursor = Cursor::new(pem); let key = rustls_pemfile::private_key(&mut cursor) @@ -200,11 +248,12 @@ fn load_private_key(pem: &[u8]) -> Result> { Ok(key) } +/// Build a `rustls` mTLS client config (used by `http_health_check`). pub fn build_rustls_config(materials: &TlsMaterials) -> Result { let mut roots = RootCertStore::empty(); let mut ca_cursor = Cursor::new(&materials.ca); let ca_certs = rustls_pemfile::certs(&mut ca_cursor) - .collect::>, _>>() + .collect::>, _>>() .into_diagnostic()?; for cert in ca_certs { roots.add(cert).into_diagnostic()?; @@ -212,7 +261,7 @@ pub fn build_rustls_config(materials: &TlsMaterials) -> Result>, _>>() + .collect::>, _>>() .into_diagnostic()?; let key = load_private_key(&materials.key)?; @@ -222,6 +271,8 @@ pub fn build_rustls_config(materials: &TlsMaterials) -> Result ClientTlsConfig { let ca_cert = Certificate::from_pem(materials.ca.clone()); let identity = Identity::from_pem(materials.cert.clone(), materials.key.clone()); @@ -230,202 +281,48 @@ pub fn build_tonic_tls_config(materials: &TlsMaterials) -> ClientTlsConfig { .identity(identity) } -#[derive(Debug)] -struct InsecureServerCertVerifier; - -impl ServerCertVerifier for InsecureServerCertVerifier { - fn verify_server_cert( - &self, - _end_entity: &CertificateDer<'_>, - 
_intermediates: &[CertificateDer<'_>], - _server_name: &ServerName<'_>, - _ocsp_response: &[u8], - _now: UnixTime, - ) -> Result { - Ok(ServerCertVerified::assertion()) - } - - fn verify_tls12_signature( - &self, - _message: &[u8], - _cert: &CertificateDer<'_>, - _dss: &rustls::DigitallySignedStruct, - ) -> Result { - Ok(HandshakeSignatureValid::assertion()) - } - - fn verify_tls13_signature( - &self, - _message: &[u8], - _cert: &CertificateDer<'_>, - _dss: &rustls::DigitallySignedStruct, - ) -> Result { - Ok(HandshakeSignatureValid::assertion()) - } - - fn supported_verify_schemes(&self) -> Vec { - rustls::crypto::ring::default_provider() - .signature_verification_algorithms - .supported_schemes() - } -} - -#[derive(Clone)] -struct InsecureTlsConnector { - tls_connector: tokio_rustls::TlsConnector, -} - -impl tower::Service for InsecureTlsConnector { - type Response = hyper_util::rt::TokioIo>; - type Error = Box; - type Future = - std::pin::Pin> + Send>>; - - fn poll_ready( - &mut self, - _cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - std::task::Poll::Ready(Ok(())) - } +// ── Channel construction (legacy mTLS dispatcher) ──────────────────── +// `build_channel` is a thin dispatcher: gateways that authenticate +// clients with mTLS take the inline `build_legacy_mtls_channel` path +// below; everything else converts to a `ClientConfig` and delegates to +// `openshell_sdk::transport::build_channel`. When mTLS retires as an +// auth method, `needs_legacy_mtls` and `build_legacy_mtls_channel` go +// with it. 
- fn call(&mut self, uri: hyper::Uri) -> Self::Future { - let tls_connector = self.tls_connector.clone(); - Box::pin(async move { - let host = uri.host().unwrap_or("localhost").to_string(); - let port = uri.port_u16().unwrap_or(443); - let addr = format!("{host}:{port}"); - let tcp = tokio::net::TcpStream::connect(addr).await?; - let server_name = ServerName::try_from(host)?; - let tls_stream = tls_connector.connect(server_name, tcp).await?; - Ok(hyper_util::rt::TokioIo::new(tls_stream)) - }) +pub async fn build_channel(server: &str, tls: &TlsOptions) -> Result { + if needs_legacy_mtls(tls, server) { + return build_legacy_mtls_channel(server, tls).await; } + let config = tls.to_client_config(server); + Ok(openshell_sdk::transport::build_channel(&config).await?) } -pub fn build_insecure_rustls_config() -> Result { - let config = rustls::ClientConfig::builder() - .dangerous() - .with_custom_certificate_verifier(std::sync::Arc::new(InsecureServerCertVerifier)) - .with_no_client_auth(); - Ok(config) -} - -/// Tunnel proxy addresses keyed by upstream endpoint + token. -/// -/// Each distinct edge-authenticated gateway gets its own local proxy instead of -/// reusing the first gateway touched in the current process. 
-static EDGE_TUNNEL_ADDRS: OnceLock>> = OnceLock::new(); - -async fn edge_tunnel_addr(server: &str, token: &str) -> Result { - let key = (server.to_string(), token.to_string()); - let registry = EDGE_TUNNEL_ADDRS.get_or_init(|| Mutex::new(HashMap::new())); - - { - let addrs = registry.lock().await; - if let Some(addr) = addrs.get(&key).copied() { - return Ok(addr); - } - } - - let proxy = crate::edge_tunnel::start_tunnel_proxy(server, token).await?; - debug!( - local_addr = %proxy.local_addr, - server, - "edge tunnel proxy started, routing gRPC through local proxy" - ); - - let mut addrs = registry.lock().await; - Ok(*addrs.entry(key).or_insert(proxy.local_addr)) +/// Returns `true` when this connection should run through the CLI's +/// inline mTLS path: HTTPS, no insecure-skip, no edge tunnel, and either +/// no OIDC token or OIDC paired with mTLS materials on disk. The combined +/// mTLS+OIDC case preserves the documented "mTLS as transport trust +/// boundary, Bearer for full scope" deployment model. +fn needs_legacy_mtls(tls: &TlsOptions, server: &str) -> bool { + server.starts_with("https://") + && !tls.gateway_insecure + && tls.edge_token.is_none() + && (tls.oidc_token.is_none() || tls.has_mtls_identity(server)) } -pub async fn build_channel(server: &str, tls: &TlsOptions) -> Result { - if server.starts_with("http://") { - let endpoint = Endpoint::from_shared(server.to_string()) - .into_diagnostic()? - .connect_timeout(Duration::from_secs(10)) - .http2_adaptive_window(true) - .http2_keep_alive_interval(Duration::from_secs(10)) - .keep_alive_while_idle(true); - return endpoint.connect().await.into_diagnostic(); - } - - // When Cloudflare edge bearer auth is active and the server is HTTPS, - // route traffic through a local WebSocket tunnel proxy instead. - // OIDC tokens bypass the tunnel — they connect directly. 
- if tls.edge_token.is_some() && server.starts_with("https://") { - let token = tls - .edge_token - .as_deref() - .ok_or_else(|| miette::miette!("edge token required for tunnel"))?; - let local_addr = edge_tunnel_addr(server, token).await?; - - // Connect to the local tunnel proxy over plaintext HTTP/2. - let local_url = format!("http://{local_addr}"); - let endpoint = Endpoint::from_shared(local_url) - .into_diagnostic()? - .connect_timeout(Duration::from_secs(10)) - .http2_adaptive_window(true) - .http2_keep_alive_interval(Duration::from_secs(10)) - .keep_alive_while_idle(true); - return endpoint.connect().await.into_diagnostic(); - } - - if tls.gateway_insecure && server.starts_with("https://") { - tracing::warn!("TLS certificate verification is disabled — do not use in production"); - let rustls_config = build_insecure_rustls_config()?; - let tls_connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(rustls_config)); - let connector = InsecureTlsConnector { tls_connector }; - // Use http:// so tonic does not layer its own TLS on top — our - // connector performs TLS with the insecure config. - let http_uri = server.replacen("https://", "http://", 1); - let endpoint = Endpoint::from_shared(http_uri) - .into_diagnostic()? - .connect_timeout(Duration::from_secs(10)) - .http2_keep_alive_interval(Duration::from_secs(10)) - .keep_alive_while_idle(true); - return endpoint - .connect_with_connector(connector) - .await - .into_diagnostic(); - } - - let mut endpoint = Endpoint::from_shared(server.to_string()) +/// Inline mTLS channel construction for gateways that require client +/// certificates as the transport-level trust boundary. Goes away when +/// mTLS is retired as an auth method. +async fn build_legacy_mtls_channel(server: &str, tls: &TlsOptions) -> Result { + let materials = require_tls_materials(server, tls)?; + let tls_config = build_tonic_tls_config(&materials); + let endpoint = Endpoint::from_shared(server.to_string()) .into_diagnostic()? 
.connect_timeout(Duration::from_secs(10)) .http2_adaptive_window(true) .http2_keep_alive_interval(Duration::from_secs(10)) - .keep_alive_while_idle(true); - - let tls_config = if tls.oidc_token.is_some() { - // Bearer auth over HTTPS: use mTLS certs for the transport layer when - // available (server may still require client certs), and layer the - // Bearer token on top via the interceptor. - require_tls_materials(server, tls).map_or_else( - |_| { - let resolved = tls.with_default_paths(server); - resolved - .ca - .as_ref() - .and_then(|ca_path| std::fs::read(ca_path).ok()) - .map_or_else( - || ClientTlsConfig::new().with_enabled_roots(), - |ca_pem| { - ClientTlsConfig::new().ca_certificate(Certificate::from_pem(ca_pem)) - }, - ) - }, - |materials| build_tonic_tls_config(&materials), - ) - } else if tls.edge_token.is_some() { - // Edge bearer mode — routed through tunnel above; if we reach here - // the server is not HTTPS so connect plaintext. - return endpoint.connect().await.into_diagnostic(); - } else { - // Standard mTLS: private CA + client cert. - let materials = require_tls_materials(server, tls)?; - build_tonic_tls_config(&materials) - }; - endpoint = endpoint.tls_config(tls_config).into_diagnostic()?; + .keep_alive_while_idle(true) + .tls_config(tls_config) + .into_diagnostic()?; endpoint.connect().await.into_diagnostic() } @@ -441,7 +338,10 @@ pub async fn grpc_client(server: &str, tls: &TlsOptions) -> Result { } fn interceptor_from_tls(tls: &TlsOptions) -> Result { - EdgeAuthInterceptor::new(tls.oidc_token.as_deref(), tls.edge_token.as_deref()) + Ok(EdgeAuthInterceptor::new( + tls.oidc_token.as_deref(), + tls.edge_token.as_deref(), + )?) } pub async fn grpc_inference_client(server: &str, tls: &TlsOptions) -> Result { diff --git a/crates/openshell-core/src/lib.rs b/crates/openshell-core/src/lib.rs index 17548ad1a..09da7fd64 100644 --- a/crates/openshell-core/src/lib.rs +++ b/crates/openshell-core/src/lib.rs @@ -9,7 +9,6 @@ //! - Common error types //! 
- Build version metadata -pub mod auth; pub mod config; pub mod driver_utils; pub mod error; diff --git a/crates/openshell-sdk/Cargo.toml b/crates/openshell-sdk/Cargo.toml new file mode 100644 index 000000000..9bab09fab --- /dev/null +++ b/crates/openshell-sdk/Cargo.toml @@ -0,0 +1,38 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +[package] +name = "openshell-sdk" +description = "Shared async Rust client for OpenShell gateways" +version.workspace = true +edition.workspace = true +rust-version.workspace = true +license.workspace = true +repository.workspace = true + +[dependencies] +openshell-core = { path = "../openshell-core" } +async-trait = "0.1" +futures = { workspace = true } +hyper = { workspace = true } +hyper-util = { workspace = true } +miette = { workspace = true } +oauth2 = "5" +reqwest = { workspace = true } +rustls = { workspace = true } +rustls-pemfile = { workspace = true } +serde = { workspace = true } +thiserror = { workspace = true } +tokio = { workspace = true } +tokio-rustls = { workspace = true } +tokio-stream = { workspace = true } +tokio-tungstenite = { workspace = true } +tonic = { workspace = true, features = ["tls", "tls-native-roots"] } +tower = { workspace = true } +tracing = { workspace = true } + +[dev-dependencies] +tokio = { workspace = true, features = ["macros", "rt-multi-thread"] } + +[lints] +workspace = true diff --git a/crates/openshell-core/src/auth.rs b/crates/openshell-sdk/src/auth.rs similarity index 78% rename from crates/openshell-core/src/auth.rs rename to crates/openshell-sdk/src/auth.rs index 16d513346..79e6a1fc0 100644 --- a/crates/openshell-core/src/auth.rs +++ b/crates/openshell-sdk/src/auth.rs @@ -1,15 +1,15 @@ // SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 -//! gRPC authentication interceptor shared by CLI and TUI. 
+//! Bearer-token authentication interceptor for outgoing gRPC requests. -use miette::Result; +use crate::error::{Result, SdkError}; /// Interceptor that injects authentication headers into every outgoing gRPC request. /// -/// Supports application-layer Bearer tokens (standard `authorization` -/// header) and Cloudflare Access tokens (custom headers). When no token is -/// set, acts as a no-op. OIDC takes precedence over edge tokens. +/// Supports OIDC Bearer tokens (standard `authorization` header) and +/// Cloudflare Access tokens (custom headers). When no token is set, acts +/// as a no-op. OIDC takes precedence over edge tokens. #[derive(Clone)] #[allow(clippy::struct_field_names)] pub struct EdgeAuthInterceptor { @@ -21,14 +21,14 @@ pub struct EdgeAuthInterceptor { impl EdgeAuthInterceptor { /// Create an interceptor from optional token strings. /// - /// OIDC bearer tokens take precedence over edge tokens. Returns a no-op - /// interceptor when no token is provided. + /// OIDC bearer token takes precedence over edge token. Returns a no-op + /// interceptor when neither token is provided. 
pub fn new(oidc_token: Option<&str>, edge_token: Option<&str>) -> Result { if let Some(token) = oidc_token { let bearer: tonic::metadata::MetadataValue = format!("Bearer {token}") .parse() - .map_err(|_| miette::miette!("invalid bearer token value"))?; + .map_err(|_| SdkError::auth("invalid OIDC token value"))?; return Ok(Self { bearer_value: Some(bearer), header_value: None, @@ -40,11 +40,11 @@ impl EdgeAuthInterceptor { Some(t) => { let hv: tonic::metadata::MetadataValue = t .parse() - .map_err(|_| miette::miette!("invalid edge token value"))?; + .map_err(|_| SdkError::auth("invalid edge token value"))?; let cv: tonic::metadata::MetadataValue = format!("CF_Authorization={t}") .parse() - .map_err(|_| miette::miette!("invalid edge token value for cookie"))?; + .map_err(|_| SdkError::auth("invalid edge token value for cookie"))?; (Some(hv), Some(cv)) } None => (None, None), diff --git a/crates/openshell-sdk/src/client.rs b/crates/openshell-sdk/src/client.rs new file mode 100644 index 000000000..fb3209cc9 --- /dev/null +++ b/crates/openshell-sdk/src/client.rs @@ -0,0 +1,331 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! High-level async client over the gateway gRPC surface. +//! +//! Covers the sandbox-focused MVP slice: health, sandbox CRUD, readiness / +//! deletion waits, and non-streaming exec. Other RPCs (inference, providers, +//! policy, logs, settings, SSH, forwarding) are reachable via +//! [`OpenShellClient::raw_grpc`] / [`OpenShellClient::raw_inference`]. 
+ +use crate::auth::EdgeAuthInterceptor; +use crate::config::{AuthConfig, ClientConfig}; +use crate::error::{Result, SdkError}; +use crate::raw::{AuthedGrpcClient, AuthedInferenceClient}; +use crate::transport; +use crate::types::{ + ExecOptions, ExecResult, Health, ListOptions, SandboxPhase, SandboxRef, SandboxSpec, +}; +use futures::StreamExt; +use openshell_core::proto; +use std::time::{Duration, Instant}; +use tonic::transport::Channel; + +/// Async client for a single `OpenShell` gateway. +/// +/// Cheap to clone — the underlying tonic [`Channel`] multiplexes RPCs over a +/// shared HTTP/2 connection. Construct one per logical gateway and share it +/// across tasks; do not call [`OpenShellClient::connect`] per request. +#[derive(Clone)] +pub struct OpenShellClient { + channel: Channel, + interceptor: EdgeAuthInterceptor, +} + +impl OpenShellClient { + /// Open a connection to the gateway described by `config`. + /// + /// Performs the gRPC channel handshake immediately; subsequent RPCs reuse + /// the connection. + pub async fn connect(config: ClientConfig) -> Result { + let channel = transport::build_channel(&config).await?; + let interceptor = interceptor_from_config(&config)?; + Ok(Self { + channel, + interceptor, + }) + } + + /// Construct from an already-built [`Channel`] and interceptor. + /// + /// Use when the caller needs to customize channel construction beyond + /// what [`ClientConfig`] exposes. + pub fn from_parts(channel: Channel, interceptor: EdgeAuthInterceptor) -> Self { + Self { + channel, + interceptor, + } + } + + /// Underlying tonic [`Channel`]. + pub fn channel(&self) -> Channel { + self.channel.clone() + } + + /// Authenticated gRPC client for the main `OpenShell` service. + /// + /// Use this when the curated surface below doesn't expose the RPC or + /// field you need. 
+ pub fn raw_grpc(&self) -> AuthedGrpcClient { + proto::open_shell_client::OpenShellClient::with_interceptor( + self.channel.clone(), + self.interceptor.clone(), + ) + } + + /// Authenticated gRPC client for the inference service. + pub fn raw_inference(&self) -> AuthedInferenceClient { + proto::inference_client::InferenceClient::with_interceptor( + self.channel.clone(), + self.interceptor.clone(), + ) + } + + /// Gateway health snapshot. + pub async fn health(&self) -> Result { + let mut grpc = self.raw_grpc(); + let resp = grpc + .health(proto::HealthRequest {}) + .await + .map_err(map_status)? + .into_inner(); + Ok(Health { + status: resp.status.into(), + version: resp.version, + }) + } + + /// Create a new sandbox from a curated [`SandboxSpec`]. + pub async fn create_sandbox(&self, spec: SandboxSpec) -> Result { + let request = create_sandbox_request(spec); + let mut grpc = self.raw_grpc(); + let response = grpc + .create_sandbox(request) + .await + .map_err(map_status)? + .into_inner(); + sandbox_from_response(response.sandbox) + } + + /// Fetch a sandbox by name. + pub async fn get_sandbox(&self, name: &str) -> Result { + let mut grpc = self.raw_grpc(); + let response = grpc + .get_sandbox(proto::GetSandboxRequest { + name: name.to_string(), + }) + .await + .map_err(map_status)? + .into_inner(); + sandbox_from_response(response.sandbox) + } + + /// List sandboxes. + pub async fn list_sandboxes(&self, opts: ListOptions) -> Result> { + let mut grpc = self.raw_grpc(); + let response = grpc + .list_sandboxes(proto::ListSandboxesRequest { + limit: opts.limit, + offset: opts.offset, + label_selector: opts.label_selector.unwrap_or_default(), + }) + .await + .map_err(map_status)? + .into_inner(); + Ok(response + .sandboxes + .into_iter() + .map(SandboxRef::from_proto) + .collect()) + } + + /// Delete a sandbox by name. + /// + /// Returns `true` when the gateway acknowledges the deletion, `false` + /// when it was already absent. 
The sandbox may still be in + /// [`SandboxPhase::Deleting`] when this returns; pair with + /// [`OpenShellClient::wait_deleted`] when you need a terminal guarantee. + pub async fn delete_sandbox(&self, name: &str) -> Result<bool> { + let mut grpc = self.raw_grpc(); + let response = grpc + .delete_sandbox(proto::DeleteSandboxRequest { + name: name.to_string(), + }) + .await + .map_err(map_status)? + .into_inner(); + Ok(response.deleted) + } + + /// Poll [`OpenShellClient::get_sandbox`] until the sandbox reaches + /// [`SandboxPhase::Ready`] or the `timeout` elapses. + /// + /// Returns the terminal sandbox snapshot on success. Returns an + /// [`SdkError::Connect`] when the timeout expires or when the sandbox + /// transitions into [`SandboxPhase::Error`]; gateway errors from the + /// underlying `get_sandbox` call are returned unchanged. + pub async fn wait_ready(&self, name: &str, timeout: Duration) -> Result<SandboxRef> { + self.wait_for(name, timeout, |phase| match phase { + SandboxPhase::Ready => Some(Ok(())), + SandboxPhase::Error => Some(Err(SdkError::connect(format!( + "sandbox '{name}' entered error phase" + )))), + _ => None, + }) + .await + } + + /// Poll until the sandbox is gone (gRPC `NotFound`) or the `timeout` + /// elapses. + pub async fn wait_deleted(&self, name: &str, timeout: Duration) -> Result<()> { + let deadline = Instant::now() + timeout; + let mut delay = Duration::from_millis(250); + loop { + match self.get_sandbox(name).await { + Err(SdkError::NotFound { .. }) => return Ok(()), + Err(other) => return Err(other), + Ok(snapshot) if snapshot.phase == SandboxPhase::Deleting => {} + Ok(_) => {} + } + if Instant::now() >= deadline { + return Err(SdkError::connect(format!( + "timed out waiting for sandbox '{name}' to delete" + ))); + } + tokio::time::sleep(delay).await; + delay = (delay * 2).min(Duration::from_secs(2)); + } + } + + /// Run a command inside a sandbox and buffer stdout/stderr to the end.
+ /// + /// For streaming output, drop down to [`OpenShellClient::raw_grpc`] and + /// call `exec_sandbox` directly. + pub async fn exec(&self, name: &str, cmd: &[String], opts: ExecOptions) -> Result { + let sandbox = self.get_sandbox(name).await?; + let request = proto::ExecSandboxRequest { + sandbox_id: sandbox.id, + command: cmd.to_vec(), + workdir: opts.workdir.unwrap_or_default(), + environment: opts.environment, + timeout_seconds: opts + .timeout + .map_or(0, |d| u32::try_from(d.as_secs()).unwrap_or(u32::MAX)), + stdin: opts.stdin.unwrap_or_default(), + tty: false, + cols: 0, + rows: 0, + }; + + let mut grpc = self.raw_grpc(); + let mut stream = grpc + .exec_sandbox(request) + .await + .map_err(map_status)? + .into_inner(); + + let mut stdout = Vec::new(); + let mut stderr = Vec::new(); + let mut exit_code: Option = None; + + while let Some(event) = stream.next().await { + let event = event.map_err(map_status)?; + match event.payload { + Some(proto::exec_sandbox_event::Payload::Stdout(chunk)) => { + stdout.extend_from_slice(&chunk.data); + } + Some(proto::exec_sandbox_event::Payload::Stderr(chunk)) => { + stderr.extend_from_slice(&chunk.data); + } + Some(proto::exec_sandbox_event::Payload::Exit(exit)) => { + exit_code = Some(exit.exit_code); + } + None => {} + } + } + + Ok(ExecResult { + exit_code: exit_code.unwrap_or(-1), + stdout, + stderr, + }) + } + + async fn wait_for(&self, name: &str, timeout: Duration, mut decide: F) -> Result + where + F: FnMut(SandboxPhase) -> Option>, + { + let deadline = Instant::now() + timeout; + let mut delay = Duration::from_millis(250); + loop { + let snapshot = self.get_sandbox(name).await?; + if let Some(verdict) = decide(snapshot.phase) { + verdict?; + return Ok(snapshot); + } + if Instant::now() >= deadline { + return Err(SdkError::connect(format!( + "timed out waiting for sandbox '{name}'" + ))); + } + tokio::time::sleep(delay).await; + delay = (delay * 2).min(Duration::from_secs(2)); + } + } +} + +fn 
interceptor_from_config(config: &ClientConfig) -> Result { + match &config.auth { + None => Ok(EdgeAuthInterceptor::noop()), + Some(AuthConfig::Oidc(token)) => EdgeAuthInterceptor::new(Some(token), None), + Some(AuthConfig::EdgeJwt(token)) => EdgeAuthInterceptor::new(None, Some(token)), + } +} + +fn create_sandbox_request(spec: SandboxSpec) -> proto::CreateSandboxRequest { + let SandboxSpec { + name, + image, + labels, + environment, + providers, + gpu, + gpu_device, + } = spec; + let template = image.map(|image| proto::SandboxTemplate { + image, + ..proto::SandboxTemplate::default() + }); + proto::CreateSandboxRequest { + spec: Some(proto::SandboxSpec { + environment, + template, + providers, + gpu, + gpu_device: gpu_device.unwrap_or_default(), + ..proto::SandboxSpec::default() + }), + name: name.unwrap_or_default(), + labels, + } +} + +fn sandbox_from_response(sandbox: Option) -> Result { + sandbox + .map(SandboxRef::from_proto) + .ok_or_else(|| SdkError::invalid_config("sandbox missing from gateway response")) +} + +fn map_status(status: tonic::Status) -> SdkError { + let message = status.message().to_string(); + match status.code() { + tonic::Code::NotFound => SdkError::NotFound { message }, + tonic::Code::AlreadyExists => SdkError::AlreadyExists { message }, + tonic::Code::InvalidArgument => SdkError::invalid_config(message), + tonic::Code::Unauthenticated | tonic::Code::PermissionDenied => SdkError::auth(message), + _ => SdkError::Rpc { + code: status.code() as i32, + message, + }, + } +} diff --git a/crates/openshell-sdk/src/config.rs b/crates/openshell-sdk/src/config.rs new file mode 100644 index 000000000..f54cac5e7 --- /dev/null +++ b/crates/openshell-sdk/src/config.rs @@ -0,0 +1,81 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Public input types for the SDK: how callers describe a gateway and the +//! credentials used to talk to it. +//! +//! 
The CLI keeps its own filesystem-aware `TlsOptions` for plumbing; it +//! converts to a `ClientConfig` at the moment of dialing the gateway. + +/// Authentication mode for outgoing gRPC requests. +/// +/// The two variants are functionally distinct in the transport layer: +/// `EdgeJwt` routes through a local WebSocket tunnel (the only way to get a +/// browser-flow JWT past Cloudflare Access on POST/HTTP2), while `Oidc` +/// connects directly over HTTPS and adds an `authorization: Bearer ...` +/// header. +// +// `#[non_exhaustive]` keeps phase 2 additive: when we promote `Oidc(String)` +// to `Oidc { token, refresh: Option> }` or add a third +// variant, downstream `match` arms aren't forced to break. +#[derive(Clone)] +#[non_exhaustive] +pub enum AuthConfig { + /// Cloudflare Access JWT — routes through the edge WebSocket tunnel. + EdgeJwt(String), + /// OIDC bearer token — direct HTTPS, `authorization` header. + Oidc(String), +} + +/// Configuration for opening a gRPC channel to an `OpenShell` gateway. +/// +/// Consumed by `openshell_sdk::transport::grpc_client` and the +/// inference-client equivalent. One `ClientConfig` per logical connection; +/// callers that want connection pooling cache the resulting `tonic::Channel`. +// +// NOTE: +// - `gateway` is a full URL (`http://...` or `https://...`) so the scheme +// tells the transport layer whether to use plaintext or TLS. Matches +// today's CLI convention; matches the RFC's `pub gateway: String`. +// - `ca_cert` pins a private-CA certificate (PEM-encoded). `None` falls +// back to the platform's system roots. +// - This SDK does not speak mTLS. Gateways requiring client certificates +// are handled by `openshell-cli`'s legacy mTLS path until product +// retires that auth method. +// - `insecure_skip_verify` is a separate flag rather than a third +// `AuthConfig` variant because it's a transport concern (cert +// verification) that's orthogonal to auth. +// - No `timeout` field yet. 
The RFC mentions one but today's behavior is +// `connect_timeout(10s)` hard-coded; introducing a configurable timeout +// here would be a behavior change. Phase 2 territory. +// - No `Debug` derive: `auth` carries secrets; `ca_cert` is fine but we +// redact the whole struct for safety. If callers want ergonomic printing +// we can implement `Debug` manually with a redacted token field. +// - `#[non_exhaustive]` + `Default` lets phase 2 add fields (timeout, retry +// policy, `Refresh` trait) without breaking callers. Outside this crate the +// idiom is `ClientConfig::new(g)` plus field assignment, not a struct literal. +#[derive(Clone, Default)] +#[non_exhaustive] +pub struct ClientConfig { + /// Gateway URL, e.g. `http://127.0.0.1:8080` or `https://gw.example.com`. + pub gateway: String, + /// CA certificate (PEM) for private-CA gateways. `None` uses system + /// roots. Ignored for plaintext gateways and when + /// `insecure_skip_verify` is enabled. + pub ca_cert: Option<Vec<u8>>, + /// Bearer-token auth mode. `None` = anonymous TLS over HTTPS, or + /// plaintext when `gateway` is `http://`. + pub auth: Option<AuthConfig>, + /// Disable TLS certificate verification (development/debug only). + /// Ignored for plaintext gateways. **Do not enable in production.** + pub insecure_skip_verify: bool, +} + +impl ClientConfig { + pub fn new(gateway: impl Into<String>) -> Self { + Self { + gateway: gateway.into(), + ..Default::default() + } + } +} diff --git a/crates/openshell-cli/src/edge_tunnel.rs b/crates/openshell-sdk/src/edge_tunnel.rs similarity index 92% rename from crates/openshell-cli/src/edge_tunnel.rs rename to crates/openshell-sdk/src/edge_tunnel.rs index 814e245f3..5ced5fc35 100644 --- a/crates/openshell-cli/src/edge_tunnel.rs +++ b/crates/openshell-sdk/src/edge_tunnel.rs @@ -19,13 +19,13 @@ //! 3. Bidirectionally pipe bytes between the local TCP stream and the //! WebSocket. //! -//! The gRPC [`Channel`] then connects to `http://127.0.0.1:` +//! The gRPC `Channel` then connects to `http://127.0.0.1:` //! 
(plaintext) — the edge handles TLS, and the WebSocket carries the raw //! bytes to the origin. +use crate::error::{Result, SdkError}; use futures::stream::{SplitSink, SplitStream}; use futures::{SinkExt, StreamExt}; -use miette::{IntoDiagnostic, Result}; use std::net::SocketAddr; use std::sync::Arc; use tokio::io::{AsyncReadExt, AsyncWriteExt}; @@ -63,8 +63,8 @@ pub async fn start_tunnel_proxy( gateway_endpoint: &str, edge_token: &str, ) -> Result { - let listener = TcpListener::bind("127.0.0.1:0").await.into_diagnostic()?; - let local_addr = listener.local_addr().into_diagnostic()?; + let listener = TcpListener::bind("127.0.0.1:0").await?; + let local_addr = listener.local_addr()?; // Convert the gateway endpoint to a WebSocket URL. // https://foo.com -> wss://foo.com @@ -88,7 +88,6 @@ pub async fn start_tunnel_proxy( "starting edge tunnel proxy" ); - // Spawn the accept loop. tokio::spawn(accept_loop(listener, config)); Ok(EdgeTunnelProxy { local_addr }) @@ -149,13 +148,15 @@ async fn handle_connection(tcp_stream: TcpStream, config: &TunnelConfig) -> Resu /// Open a WebSocket connection to the edge proxy. async fn open_ws(config: &TunnelConfig) -> Result>> { - let mut request = (&config.ws_url).into_client_request().into_diagnostic()?; + let mut request = (&config.ws_url) + .into_client_request() + .map_err(|e| SdkError::invalid_config(format!("invalid tunnel URL: {e}")))?; // Inject the bearer token via multiple headers for compatibility with // Cloudflare Access (which checks `Cf-Access-Token`, the // `CF_Authorization` cookie, and the `Cf-Access-Jwt-Assertion` header). 
let token_val = HeaderValue::from_str(&config.edge_token) - .map_err(|e| miette::miette!("invalid edge token header value: {e}"))?; + .map_err(|e| SdkError::auth(format!("invalid edge token header value: {e}")))?; request .headers_mut() .insert("Cf-Access-Token", token_val.clone()); @@ -165,14 +166,14 @@ async fn open_ws(config: &TunnelConfig) -> Result = std::result::Result; + +/// Errors produced by `openshell-sdk`. +/// +/// CLI consumers convert these to `miette::Report` at the call boundary; +/// future TS/Python bindings will map them to language-native exceptions +/// via the [`SdkError::code`] accessor. +#[derive(Debug, Error, Diagnostic)] +pub enum SdkError { + /// Caller-supplied configuration is invalid (URL parse, missing field, + /// illegal token characters). + #[error("invalid configuration: {message}")] + #[diagnostic(code(openshell::sdk::invalid_config))] + InvalidConfig { + /// Error message. + message: String, + }, + + /// TLS material parse or rustls config build failure. + #[error("TLS error: {message}")] + #[diagnostic(code(openshell::sdk::tls))] + Tls { + /// Error message. + message: String, + }, + + /// Failed to establish a connection to the gateway (TCP, TLS handshake, + /// HTTP/2, WebSocket upgrade). + #[error("connect error: {message}")] + #[diagnostic(code(openshell::sdk::connect))] + Connect { + /// Error message. + message: String, + }, + + /// Auth-related failure: OIDC discovery / refresh, token format invalid + /// for header injection. + #[error("auth error: {message}")] + #[diagnostic(code(openshell::sdk::auth))] + Auth { + /// Error message. + message: String, + }, + + /// Local IO failure (file read, listener bind, socket). + #[error("I/O error: {source}")] + #[diagnostic(code(openshell::sdk::io))] + Io { + /// Underlying I/O error. + #[from] + source: std::io::Error, + }, + + /// Gateway reported the requested object does not exist (gRPC `NotFound`). 
+ #[error("not found: {message}")] + #[diagnostic(code(openshell::sdk::not_found))] + NotFound { + /// Error message. + message: String, + }, + + /// Gateway reported the requested object already exists (gRPC `AlreadyExists`). + #[error("already exists: {message}")] + #[diagnostic(code(openshell::sdk::already_exists))] + AlreadyExists { + /// Error message. + message: String, + }, + + /// Catch-all for gRPC errors not mapped to a more specific variant. + #[error("gateway error ({code}): {message}")] + #[diagnostic(code(openshell::sdk::rpc))] + Rpc { + /// Numeric gRPC status code (see [`tonic::Code`]). + code: i32, + /// Error message. + message: String, + }, +} + +impl SdkError { + /// Create an `InvalidConfig` error. + pub fn invalid_config(message: impl Into) -> Self { + Self::InvalidConfig { + message: message.into(), + } + } + + /// Create a `Tls` error. + pub fn tls(message: impl Into) -> Self { + Self::Tls { + message: message.into(), + } + } + + /// Create a `Connect` error. + pub fn connect(message: impl Into) -> Self { + Self::Connect { + message: message.into(), + } + } + + /// Create an `Auth` error. + pub fn auth(message: impl Into) -> Self { + Self::Auth { + message: message.into(), + } + } + + /// Stable string code for cross-language binding consumers. + /// + /// Returns one of: `invalid_config`, `tls`, `connect`, `auth`, `io`, + /// `not_found`, `already_exists`, `rpc`. Phase 3 (napi binding) will + /// surface this as the JS error's `code` field for discriminated-union + /// ergonomics. + pub const fn code(&self) -> &'static str { + match self { + Self::InvalidConfig { .. } => "invalid_config", + Self::Tls { .. } => "tls", + Self::Connect { .. } => "connect", + Self::Auth { .. } => "auth", + Self::Io { .. } => "io", + Self::NotFound { .. } => "not_found", + Self::AlreadyExists { .. } => "already_exists", + Self::Rpc { .. 
} => "rpc", + } + } +} diff --git a/crates/openshell-sdk/src/lib.rs b/crates/openshell-sdk/src/lib.rs new file mode 100644 index 000000000..53bcb336a --- /dev/null +++ b/crates/openshell-sdk/src/lib.rs @@ -0,0 +1,51 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Shared async Rust client for `OpenShell` gateways. +//! +//! Two layers: +//! +//! - [`OpenShellClient`] — the high-level sandbox-focused MVP surface: +//! health, sandbox CRUD, readiness/deletion waits, non-streaming exec. +//! - [`raw`] — direct access to the generated tonic clients for RPCs the +//! curated surface doesn't yet cover (inference, providers, policy, logs, +//! settings, SSH, forwarding). +//! +//! Owns the gRPC transport stack — channel construction, TLS material +//! handling, request interceptors, OIDC token refresh, and the Cloudflare +//! Access tunnel proxy. Consumed by `openshell-cli`, `openshell-tui`, and +//! the napi-rs wrapper that ships as `@openshell/sdk`. +//! +//! # Quick start +//! +//! ```ignore +//! use openshell_sdk::{ClientConfig, ListOptions, OpenShellClient}; +//! +//! # async fn run() -> Result<(), openshell_sdk::SdkError> { +//! let client = OpenShellClient::connect(ClientConfig::new("http://127.0.0.1:8080")).await?; +//! let health = client.health().await?; +//! let sandboxes = client.list_sandboxes(ListOptions::default()).await?; +//! # Ok(()) +//! # } +//! 
``` + +pub mod auth; +pub mod client; +pub mod config; +pub mod edge_tunnel; +pub mod error; +pub mod oidc; +pub mod raw; +pub mod refresh; +pub mod transport; +pub mod types; + +pub use auth::EdgeAuthInterceptor; +pub use client::OpenShellClient; +pub use config::{AuthConfig, ClientConfig}; +pub use error::SdkError; +pub use refresh::{Refresh, RefreshError, RefreshedToken, TokenSource}; +pub use types::{ + ExecOptions, ExecResult, Health, ListOptions, SandboxPhase, SandboxRef, SandboxSpec, + ServiceStatus, +}; diff --git a/crates/openshell-sdk/src/oidc.rs b/crates/openshell-sdk/src/oidc.rs new file mode 100644 index 000000000..6a26678bb --- /dev/null +++ b/crates/openshell-sdk/src/oidc.rs @@ -0,0 +1,146 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! OIDC discovery and refresh-token flow (non-interactive). +//! +//! Browser-based authorization flows live in `openshell-cli` since they +//! require a local callback HTTP server and an OS browser launcher. + +use crate::error::{Result, SdkError}; +use oauth2::basic::BasicClient; +use oauth2::{ClientId, RefreshToken, TokenResponse, TokenUrl}; +use serde::Deserialize; + +/// OIDC discovery document (subset of fields callers consume). +#[derive(Debug, Deserialize)] +#[non_exhaustive] +pub struct OidcDiscovery { + pub issuer: String, + pub authorization_endpoint: String, + pub token_endpoint: String, +} + +/// Input to [`refresh_token`]. +/// +/// Constructed by the caller from whatever bundle / storage shape they +/// use — the SDK does not assume any particular persistence model. 
+#[derive(Clone)] +#[non_exhaustive] +pub struct RefreshTokenInput { + pub refresh_token: String, + pub issuer: String, + pub client_id: String, + pub insecure: bool, +} + +impl RefreshTokenInput { + pub fn new( + refresh_token: impl Into<String>, + issuer: impl Into<String>, + client_id: impl Into<String>, + ) -> Self { + Self { + refresh_token: refresh_token.into(), + issuer: issuer.into(), + client_id: client_id.into(), + insecure: false, + } + } + + #[must_use] + pub fn with_insecure(mut self, insecure: bool) -> Self { + self.insecure = insecure; + self + } +} + +/// Output from [`refresh_token`]. +/// +/// `refresh_token` is `None` when the OIDC server did not return a new +/// refresh token; per OAuth 2.0, callers should preserve the previous +/// refresh token in that case. `expires_at` is a Unix timestamp (seconds +/// since epoch); `None` when the server omits `expires_in`. +#[derive(Clone, Debug)] +#[non_exhaustive] +pub struct RefreshTokenOutput { + pub access_token: String, + pub refresh_token: Option<String>, + pub expires_at: Option<u64>, +} + +/// Discover OIDC endpoints from the issuer's well-known configuration. +/// +/// Validates that the discovery document's `issuer` field matches the +/// configured issuer URL to prevent SSRF or misdirection. When `insecure` +/// is true, TLS certificate verification is disabled (intended for +/// development against self-signed gateways). +pub async fn discover(issuer: &str, insecure: bool) -> Result<OidcDiscovery> { + let normalized_issuer = issuer.trim_end_matches('/'); + let url = format!("{normalized_issuer}/.well-known/openid-configuration"); + let client = http_client(insecure); + let resp: OidcDiscovery = client + .get(&url) + .send() + .await + .map_err(|e| SdkError::auth(format!("OIDC discovery request failed: {e}")))?
+ .json() + .await + .map_err(|e| SdkError::auth(format!("OIDC discovery JSON parse failed: {e}")))?; + + let discovered_issuer = resp.issuer.trim_end_matches('/'); + if discovered_issuer != normalized_issuer { + return Err(SdkError::auth(format!( + "OIDC discovery issuer mismatch: expected '{normalized_issuer}', got '{discovered_issuer}'" + ))); + } + Ok(resp) +} + +/// Build an HTTP client suitable for OIDC token-endpoint requests. +/// +/// Disables redirects so token-endpoint responses aren't accidentally +/// followed; OIDC providers should not redirect on the token endpoint. +/// When `insecure` is true, TLS certificate verification is disabled. +pub fn http_client(insecure: bool) -> reqwest::Client { + let mut builder = reqwest::ClientBuilder::new().redirect(reqwest::redirect::Policy::none()); + if insecure { + builder = builder.danger_accept_invalid_certs(true); + } + builder.build().expect("failed to build HTTP client") +} + +/// Refresh an OIDC access token using the `refresh_token` grant. +/// +/// The caller is responsible for preserving the prior refresh token when +/// the output's `refresh_token` is `None` — per OAuth 2.0 the server may +/// omit it from the refresh response. 
+pub async fn refresh_token(input: &RefreshTokenInput) -> Result<RefreshTokenOutput> { + let discovery = discover(&input.issuer, input.insecure).await?; + + let client = BasicClient::new(ClientId::new(input.client_id.clone())).set_token_uri( + TokenUrl::new(discovery.token_endpoint) + .map_err(|e| SdkError::auth(format!("invalid token endpoint URL: {e}")))?, + ); + + let http = http_client(input.insecure); + let token_response = client + .exchange_refresh_token(&RefreshToken::new(input.refresh_token.clone())) + .request_async(&http) + .await + .map_err(|e| SdkError::auth(format!("token refresh failed: {e}")))?; + + Ok(output_from_oauth2_response(&token_response)) +} + +fn output_from_oauth2_response(resp: &oauth2::basic::BasicTokenResponse) -> RefreshTokenOutput { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + + RefreshTokenOutput { + access_token: resp.access_token().secret().clone(), + refresh_token: resp.refresh_token().map(|rt| rt.secret().clone()), + expires_at: resp.expires_in().map(|ei| now + ei.as_secs()), + } +} diff --git a/crates/openshell-sdk/src/raw.rs b/crates/openshell-sdk/src/raw.rs new file mode 100644 index 000000000..0b3b18a04 --- /dev/null +++ b/crates/openshell-sdk/src/raw.rs @@ -0,0 +1,45 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Escape hatch — direct access to the generated tonic clients and protobuf +//! types. +//! +//! Use this module when the curated high-level surface in +//! [`crate::client::OpenShellClient`] doesn't expose the RPC or field you +//! need. The high-level surface is sandbox-focused for MVP; inference, +//! providers, policy, logs, settings, SSH, and forwarding all live here. +//! +//! ```ignore +//! use openshell_sdk::{ClientConfig, OpenShellClient}; +//! use openshell_sdk::raw::ListProvidersRequest; +//! +//!
let client = OpenShellClient::connect(ClientConfig::new("http://127.0.0.1:8080")).await?; +//! let mut grpc = client.raw_grpc(); +//! let providers = grpc.list_providers(ListProvidersRequest::default()).await?; +//! ``` + +pub use openshell_core::proto; +pub use openshell_core::proto::inference_client::InferenceClient; +pub use openshell_core::proto::open_shell_client::OpenShellClient as GrpcClient; +pub use openshell_core::proto::{ + CreateSandboxRequest, DeleteSandboxRequest, ExecSandboxRequest, GetSandboxRequest, + HealthRequest, ListProvidersRequest, ListSandboxesRequest, Sandbox, + SandboxPhase as ProtoSandboxPhase, SandboxSpec as ProtoSandboxSpec, SandboxTemplate, + ServiceStatus as ProtoServiceStatus, +}; + +/// Type alias for the gRPC client wrapped in the SDK's auth interceptor. +pub type AuthedGrpcClient = GrpcClient< + tonic::service::interceptor::InterceptedService< + tonic::transport::Channel, + crate::EdgeAuthInterceptor, + >, +>; + +/// Type alias for the inference client wrapped in the SDK's auth interceptor. +pub type AuthedInferenceClient = InferenceClient< + tonic::service::interceptor::InterceptedService< + tonic::transport::Channel, + crate::EdgeAuthInterceptor, + >, +>; diff --git a/crates/openshell-sdk/src/refresh.rs b/crates/openshell-sdk/src/refresh.rs new file mode 100644 index 000000000..46809700d --- /dev/null +++ b/crates/openshell-sdk/src/refresh.rs @@ -0,0 +1,301 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! OIDC bearer-token refresh contract. +//! +//! The SDK never talks to a browser or any specific `IdP`. Callers that need +//! the SDK to rotate an OIDC bearer mid-session implement [`Refresh`] and +//! construct a [`TokenSource`] around it. Implementations live where the +//! browser flow / token store / FFI callback belongs — in `openshell-cli` +//! for the desktop browser flow, in `openshell-sdk-node` for a JS callback. 
+//! +//! The trait is intentionally minimal. Single-flight coalescing (one refresh +//! in flight at a time, with all waiters sharing the result) is the SDK's +//! responsibility, not the implementer's; see [`TokenSource`]. +//! +//! TODO(rfc-0004): plumb [`TokenSource`] into the gRPC auth interceptor so +//! refreshes happen automatically before each request. Today the napi +//! binding exposes [`TokenSource::refresh_now`] / [`TokenSource::current`] +//! directly to JS callers, which can rotate the token by calling +//! `set_oidc_token` on a future iteration of the SDK client. + +use crate::error::{Result, SdkError}; +use std::fmt; +use std::sync::Arc; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; +use tokio::sync::{Mutex, RwLock}; + +/// Errors a refresher can return. +/// +/// Domain-specific, deliberately not coupled to `tonic`, `napi`, or any +/// FFI-facing error type. The SDK maps these into [`SdkError::Auth`] before +/// surfacing to callers. +#[derive(Debug)] +#[non_exhaustive] +pub enum RefreshError { + /// Refresh failed but a retry might succeed (network blip, transient + /// `IdP` error). + Transient(String), + /// Refresh cannot succeed without user interaction (refresh token + /// expired, `IdP` revoked the session). Callers should not retry; they + /// should re-authenticate. + Terminal(String), +} + +impl fmt::Display for RefreshError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Transient(msg) => write!(f, "transient refresh error: {msg}"), + Self::Terminal(msg) => write!(f, "terminal refresh error: {msg}"), + } + } +} + +impl std::error::Error for RefreshError {} + +impl From<RefreshError> for SdkError { + fn from(value: RefreshError) -> Self { + Self::auth(value.to_string()) + } +} + +/// A freshly minted access token + its absolute expiry. +/// +/// `expires_at` is seconds since the Unix epoch.
`None` means the token's +/// expiry was not advertised — the SDK will not refresh it proactively but +/// may refresh on demand if [`Refresh::refresh`] is called. +#[derive(Clone, Debug)] +#[non_exhaustive] +pub struct RefreshedToken { + pub access_token: String, + pub expires_at: Option<u64>, +} + +impl RefreshedToken { + pub fn new(access_token: impl Into<String>) -> Self { + Self { + access_token: access_token.into(), + expires_at: None, + } + } + + #[must_use] + pub fn with_expires_at(mut self, expires_at: u64) -> Self { + self.expires_at = Some(expires_at); + self + } +} + +/// Pluggable OIDC refresher. +/// +/// Implementations should be cheap to clone and safe to call from any tokio +/// task. They MUST NOT do their own single-flight coalescing — that's the +/// SDK's job (see [`TokenSource`]). +#[async_trait::async_trait] +pub trait Refresh: Send + Sync + 'static { + /// Mint a fresh access token. Called by the SDK when it determines the + /// current token is near expiry (or has been explicitly invalidated). + async fn refresh(&self) -> std::result::Result<RefreshedToken, RefreshError>; +} + +/// Mutable token state shared between the auth interceptor and the +/// background refresh task. +#[derive(Debug)] +struct TokenState { + token: String, + expires_at: Option<u64>, +} + +/// A bearer-token source with single-flight refresh coalescing. +/// +/// Wraps a [`Refresh`] implementation and tracks the current token + its +/// advertised expiry. Phase 3 of the RFC plumbs this into the auth path; for +/// now language bindings hand it out directly so JS/Python code can drive +/// refreshes externally. +#[derive(Clone)] +pub struct TokenSource { + state: Arc<RwLock<TokenState>>, + refresher: Arc<dyn Refresh>, + in_flight: Arc<Mutex<()>>, + /// Refresh `skew` seconds before the advertised `expires_at`. Tokens + /// without `expires_at` are not auto-refreshed. + skew: Duration, +} + +impl TokenSource { + /// Construct a token source backed by `refresher`. Use this when wiring + /// an FFI callback or browser flow into the SDK.
+ pub fn new(initial: RefreshedToken, refresher: Arc<dyn Refresh>) -> Self { + Self { + state: Arc::new(RwLock::new(TokenState { + token: initial.access_token, + expires_at: initial.expires_at, + })), + refresher, + in_flight: Arc::new(Mutex::new(())), + skew: Duration::from_secs(60), + } + } + + /// Current token without checking expiry. Used by the sync gRPC + /// interceptor, which can't await. + pub fn snapshot(&self) -> String { + self.state + .try_read() + .map(|s| s.token.clone()) + .unwrap_or_default() + } + + /// Async-fetch the current token, refreshing if it's within `skew` of + /// expiry. Single-flight: concurrent callers share one refresh. + pub async fn current(&self) -> Result<String> { + if !self.needs_refresh().await { + return Ok(self.state.read().await.token.clone()); + } + self.refresh_now().await + } + + /// Refresh under the single-flight lock; skipped when the token is not + /// within `skew` of expiry. Used on `Unauthenticated` gateway responses. + pub async fn refresh_now(&self) -> Result<String> { + // Single-flight: only one refresh in flight at a time. Other waiters + // block here and then see the updated state on re-check. + let _guard = self.in_flight.lock().await; + + // Re-check inside the critical section: another caller may have just + // refreshed while we were waiting on the lock. + if !self.needs_refresh().await { + return Ok(self.state.read().await.token.clone()); + } + + let refreshed = self.refresher.refresh().await?; + let mut state = self.state.write().await; + state.token.clone_from(&refreshed.access_token); + state.expires_at = refreshed.expires_at; + Ok(refreshed.access_token) + } + + async fn needs_refresh(&self) -> bool { + let state = self.state.read().await; + let Some(expires_at) = state.expires_at else { + return false; + }; + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + now + self.skew.as_secs() >= expires_at + } + + /// Replace the current token without invoking the refresher.
+ /// + /// Used by callers that manage refresh externally (e.g. the napi + /// binding's JS-side timer) or for testing. + pub async fn replace(&self, token: RefreshedToken) { + let mut state = self.state.write().await; + state.token = token.access_token; + state.expires_at = token.expires_at; + } +} + +impl fmt::Debug for TokenSource { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("TokenSource") + .field("skew", &self.skew) + .finish_non_exhaustive() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::atomic::{AtomicUsize, Ordering}; + + struct CountingRefresher { + calls: Arc<AtomicUsize>, + delay: Duration, + } + + #[async_trait::async_trait] + impl Refresh for CountingRefresher { + async fn refresh(&self) -> std::result::Result<RefreshedToken, RefreshError> { + tokio::time::sleep(self.delay).await; + let n = self.calls.fetch_add(1, Ordering::SeqCst) + 1; + Ok(RefreshedToken::new(format!("token-{n}")).with_expires_at( + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() + + 3600, + )) + } + } + + #[tokio::test] + async fn refresh_now_coalesces_concurrent_callers() { + let calls = Arc::new(AtomicUsize::new(0)); + let refresher = Arc::new(CountingRefresher { + calls: Arc::clone(&calls), + delay: Duration::from_millis(50), + }); + let source = TokenSource::new(RefreshedToken::new("initial").with_expires_at(0), refresher); + + let tasks = (0..5).map(|_| { + let src = source.clone(); + tokio::spawn(async move { src.refresh_now().await }) + }); + for t in tasks { + t.await.unwrap().unwrap(); + } + + assert_eq!( + calls.load(Ordering::SeqCst), + 1, + "single-flight should have collapsed 5 concurrent calls into 1 refresh" + ); + } + + #[tokio::test] + async fn current_returns_cached_when_not_near_expiry() { + let calls = Arc::new(AtomicUsize::new(0)); + let refresher = Arc::new(CountingRefresher { + calls: Arc::clone(&calls), + delay: Duration::from_millis(0), + }); + let future = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() +
.as_secs() + + 3600; + let source = TokenSource::new( + RefreshedToken::new("fresh").with_expires_at(future), + refresher, + ); + + let token = source.current().await.unwrap(); + assert_eq!(token, "fresh"); + assert_eq!(calls.load(Ordering::SeqCst), 0); + } + + #[tokio::test] + async fn current_refreshes_when_within_skew() { + let calls = Arc::new(AtomicUsize::new(0)); + let refresher = Arc::new(CountingRefresher { + calls: Arc::clone(&calls), + delay: Duration::from_millis(0), + }); + let near = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() + + 5; + let source = TokenSource::new( + RefreshedToken::new("stale").with_expires_at(near), + refresher, + ); + + let token = source.current().await.unwrap(); + assert_eq!(token, "token-1"); + assert_eq!(calls.load(Ordering::SeqCst), 1); + } +} diff --git a/crates/openshell-sdk/src/transport.rs b/crates/openshell-sdk/src/transport.rs new file mode 100644 index 000000000..d930aee5e --- /dev/null +++ b/crates/openshell-sdk/src/transport.rs @@ -0,0 +1,253 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! gRPC transport stack: channel construction and the insecure-TLS connector. +//! +//! mTLS is intentionally out of scope here. Gateways that require client +//! certificates are handled by `openshell-cli`'s legacy path until the auth +//! method is retired. 
+ +use crate::config::{AuthConfig, ClientConfig}; +use crate::edge_tunnel; +use crate::error::{Result, SdkError}; +use rustls::{ + client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}, + pki_types::{CertificateDer, ServerName, UnixTime}, +}; +use std::collections::HashMap; +use std::future::Future; +use std::net::SocketAddr; +use std::sync::OnceLock; +use std::time::Duration; +use tokio::sync::Mutex; +use tonic::transport::{Certificate, Channel, ClientTlsConfig, Endpoint}; +use tracing::debug; + +/// Standard endpoint settings used by every dialed connection. +/// +/// Centralizes timeouts and HTTP/2 keepalive so behavior is consistent across +/// transport branches. Returns an `Endpoint` ready for `connect()` / +/// `connect_with_connector()`. +fn standard_endpoint(uri: String) -> Result<Endpoint> { + Endpoint::from_shared(uri) + .map(|ep| { + ep.connect_timeout(Duration::from_secs(10)) + .http2_adaptive_window(true) + .http2_keep_alive_interval(Duration::from_secs(10)) + .keep_alive_while_idle(true) + }) + .map_err(|e| SdkError::invalid_config(format!("invalid gateway URL: {e}"))) +} + +// ── Edge tunnel registry ───────────────────────────────────────────── +// Each distinct edge-authenticated gateway gets its own local proxy +// instead of reusing the first gateway touched in the current process. +static EDGE_TUNNEL_ADDRS: OnceLock<Mutex<HashMap<(String, String), SocketAddr>>> = OnceLock::new(); + +/// Look up (or start) the local tunnel proxy for an edge-authenticated +/// gateway. Subsequent calls with the same `(server, token)` reuse the +/// existing proxy.
+async fn edge_tunnel_addr(server: &str, token: &str) -> Result<SocketAddr> { + let key = (server.to_string(), token.to_string()); + let registry = EDGE_TUNNEL_ADDRS.get_or_init(|| Mutex::new(HashMap::new())); + + { + let addrs = registry.lock().await; + if let Some(addr) = addrs.get(&key).copied() { + return Ok(addr); + } + } + + let proxy = edge_tunnel::start_tunnel_proxy(server, token).await?; + debug!( + local_addr = %proxy.local_addr, + server, + "edge tunnel proxy started, routing gRPC through local proxy" + ); + + let mut addrs = registry.lock().await; + Ok(*addrs.entry(key).or_insert(proxy.local_addr)) +} + +// ── Channel construction ───────────────────────────────────────────── + +/// Open a gRPC channel to the gateway described by `config`. +/// +/// Routing is determined by `gateway` scheme + `auth` variant + +/// `insecure_skip_verify`. Reference today's CLI implementation in +/// `openshell-cli/src/tls.rs::build_channel` (lines 219–308) for behavior +/// the SDK needs to preserve. +/// +/// **Branch table:** +/// +/// | `gateway` scheme | `auth` | `insecure_skip_verify` | TLS handling | +/// |------------------|--------|------------------------|-------------| +/// | `http://` | (any) | (any) | plaintext, ignore tls | +/// | `https://` | `Some(EdgeJwt)` | (any) | tunnel proxy + plaintext to local proxy | +/// | `https://` | (any) | `true` | `InsecureTlsConnector`, no verification | +/// | `https://` | `Some(Oidc)` or `None` | `false` | tonic TLS, pin `ca_cert` if set, system roots otherwise | +pub async fn build_channel(config: &ClientConfig) -> Result<Channel> { + let gateway = &config.gateway; + + // Branch 1 — plaintext. + // Reference: cli/tls.rs:220-228 (http:// branch). + if gateway.starts_with("http://") { + return standard_endpoint(gateway.clone())?
+ .connect() + .await + .map_err(|e| SdkError::connect(format!("{e}"))); + } + + if !gateway.starts_with("https://") { + return Err(SdkError::invalid_config(format!( + "gateway URL must start with http:// or https://: {gateway}" + ))); + } + + // Branch 2 — Cloudflare Access edge JWT: tunnel proxy + plaintext-to-local. + // Reference: cli/tls.rs:233-249 (https:// + edge_token branch). Use + // `edge_tunnel_addr(gateway, token).await?` to get the local proxy + // address, then `standard_endpoint(format!("http://{local_addr}"))?.connect()`. + if let Some(AuthConfig::EdgeJwt(token)) = &config.auth { + let local_addr = edge_tunnel_addr(gateway, token).await?; + return standard_endpoint(format!("http://{local_addr}"))? + .connect() + .await + .map_err(|e| SdkError::connect(format!("{e}"))); + } + + // Branch 3 — insecure TLS (skip cert verification). + // Reference: cli/tls.rs:251-268 (gateway_insecure branch). Build the + // insecure rustls config, wrap it in `InsecureTlsConnector`, swap the + // gateway scheme to http:// (so tonic doesn't double-layer TLS), and + // call `endpoint.connect_with_connector(connector)`. + if config.insecure_skip_verify { + tracing::warn!("TLS certificate verification is disabled — do not use in production"); + let rustls_config = build_insecure_rustls_config()?; + let tls_connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(rustls_config)); + let connector = InsecureTlsConnector { tls_connector }; + let http_uri = gateway.replacen("https://", "http://", 1); + return standard_endpoint(http_uri)? + .connect_with_connector(connector) + .await + .map_err(|e| SdkError::connect(format!("{e}"))); + } + + // Branch 4 — anonymous TLS or OIDC bearer over HTTPS. + // Reference: cli/tls.rs:270-307 (the `oidc_token` and final mTLS + // branches collapsed). 
Build a `ClientTlsConfig`: + // - if `config.ca_cert` is `Some(pem)`, pin it via `.ca_certificate(...)` + // - else fall back to `.with_enabled_roots()` (system roots) + // Then `endpoint.tls_config(tls_config)?.connect()`. + // + // The OIDC bearer header is added by the gRPC interceptor at request + // time, not here — `build_channel` only owns the TLS layer. + + let tls_config = config.ca_cert.as_ref().map_or_else( + || ClientTlsConfig::new().with_enabled_roots(), + |ca_cert| ClientTlsConfig::new().ca_certificate(Certificate::from_pem(ca_cert)), + ); + standard_endpoint(gateway.clone())? + .tls_config(tls_config) + .map_err(|e| SdkError::tls(format!("{e}")))? + .connect() + .await + .map_err(|e| SdkError::connect(format!("{e}"))) +} + +/// rustls verifier that accepts any server certificate. +/// +/// Used only when the caller explicitly opts into +/// [`ClientConfig::insecure_skip_verify`]. Do not use in production. +/// +/// [`ClientConfig::insecure_skip_verify`]: crate::config::ClientConfig::insecure_skip_verify +#[derive(Debug)] +pub struct InsecureServerCertVerifier; + +impl ServerCertVerifier for InsecureServerCertVerifier { + fn verify_server_cert( + &self, + _end_entity: &CertificateDer<'_>, + _intermediates: &[CertificateDer<'_>], + _server_name: &ServerName<'_>, + _ocsp_response: &[u8], + _now: UnixTime, + ) -> std::result::Result<ServerCertVerified, rustls::Error> { + Ok(ServerCertVerified::assertion()) + } + + fn verify_tls12_signature( + &self, + _message: &[u8], + _cert: &CertificateDer<'_>, + _dss: &rustls::DigitallySignedStruct, + ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> { + Ok(HandshakeSignatureValid::assertion()) + } + + fn verify_tls13_signature( + &self, + _message: &[u8], + _cert: &CertificateDer<'_>, + _dss: &rustls::DigitallySignedStruct, + ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> { + Ok(HandshakeSignatureValid::assertion()) + } + + fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> { + rustls::crypto::ring::default_provider() + .signature_verification_algorithms + .supported_schemes() + } +} + +/// rustls client
config that disables server certificate verification. +/// +/// Pairs with [`InsecureTlsConnector`] for transports that need to skip +/// verification (development, debug). Returns `Result` for symmetry with +/// future verifying variants; the current implementation cannot fail. +pub fn build_insecure_rustls_config() -> Result<rustls::ClientConfig> { + Ok(rustls::ClientConfig::builder() + .dangerous() + .with_custom_certificate_verifier(std::sync::Arc::new(InsecureServerCertVerifier)) + .with_no_client_auth()) +} + +/// `tower::Service` connector that performs TLS using the supplied rustls +/// connector, bypassing tonic's built-in TLS layering. +/// +/// Used to plumb [`InsecureServerCertVerifier`]-backed configs into a tonic +/// `Endpoint` via `connect_with_connector`. +#[derive(Clone)] +pub struct InsecureTlsConnector { + /// Inner rustls connector configured by the caller. + pub tls_connector: tokio_rustls::TlsConnector, +} + +impl tower::Service<hyper::Uri> for InsecureTlsConnector { + type Response = hyper_util::rt::TokioIo<tokio_rustls::client::TlsStream<tokio::net::TcpStream>>; + type Error = Box<dyn std::error::Error + Send + Sync>; + type Future = std::pin::Pin< + Box<dyn Future<Output = std::result::Result<Self::Response, Self::Error>> + Send>, + >; + + fn poll_ready( + &mut self, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll<std::result::Result<(), Self::Error>> { + std::task::Poll::Ready(Ok(())) + } + + fn call(&mut self, uri: hyper::Uri) -> Self::Future { + let tls_connector = self.tls_connector.clone(); + Box::pin(async move { + let host = uri.host().unwrap_or("localhost").to_string(); + let port = uri.port_u16().unwrap_or(443); + let addr = format!("{host}:{port}"); + let tcp = tokio::net::TcpStream::connect(addr).await?; + let server_name = ServerName::try_from(host)?; + let tls_stream = tls_connector.connect(server_name, tcp).await?; + Ok(hyper_util::rt::TokioIo::new(tls_stream)) + }) + } +} diff --git a/crates/openshell-sdk/src/types.rs b/crates/openshell-sdk/src/types.rs new file mode 100644 index 000000000..5cfa99f48 --- /dev/null +++ b/crates/openshell-sdk/src/types.rs @@ -0,0 +1,169 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION
& AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Curated public types for the high-level SDK surface. +//! +//! These types intentionally diverge from the raw protobuf shapes so future +//! language bindings (TypeScript via napi, Python via `PyO3`) can render them +//! idiomatically. In particular, enum-valued fields use Rust enums that map +//! to string literals in TypeScript rather than numeric proto enums; nested +//! `Option<...>` chains from proto are flattened where one of the wrappers +//! is structurally meaningless. +//! +//! The raw proto clients are still accessible via [`crate::raw`] as an +//! escape hatch for callers who need fields not exposed here. + +use openshell_core::proto; +use std::collections::HashMap; +use std::time::Duration; + +/// Gateway health snapshot. +#[derive(Clone, Debug)] +#[non_exhaustive] +pub struct Health { + pub status: ServiceStatus, + pub version: String, +} + +/// Coarse gateway service status. +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +#[non_exhaustive] +pub enum ServiceStatus { + Unspecified, + Healthy, + Degraded, + Unhealthy, +} + +impl From<proto::ServiceStatus> for ServiceStatus { + fn from(value: proto::ServiceStatus) -> Self { + match value { + proto::ServiceStatus::Healthy => Self::Healthy, + proto::ServiceStatus::Degraded => Self::Degraded, + proto::ServiceStatus::Unhealthy => Self::Unhealthy, + proto::ServiceStatus::Unspecified => Self::Unspecified, + } + } +} + +impl From<i32> for ServiceStatus { + fn from(value: i32) -> Self { + proto::ServiceStatus::try_from(value).map_or(Self::Unspecified, Self::from) + } +} + +/// High-level sandbox lifecycle phase.
+#[derive(Clone, Copy, Debug, Eq, PartialEq)] +#[non_exhaustive] +pub enum SandboxPhase { + Unspecified, + Provisioning, + Ready, + Error, + Deleting, + Unknown, +} + +impl From<proto::SandboxPhase> for SandboxPhase { + fn from(value: proto::SandboxPhase) -> Self { + match value { + proto::SandboxPhase::Unspecified => Self::Unspecified, + proto::SandboxPhase::Provisioning => Self::Provisioning, + proto::SandboxPhase::Ready => Self::Ready, + proto::SandboxPhase::Error => Self::Error, + proto::SandboxPhase::Deleting => Self::Deleting, + proto::SandboxPhase::Unknown => Self::Unknown, + } + } +} + +impl From<i32> for SandboxPhase { + fn from(value: i32) -> Self { + proto::SandboxPhase::try_from(value).map_or(Self::Unspecified, Self::from) + } +} + +/// Caller intent for a new sandbox. +/// +/// Only the most commonly used fields are exposed. Callers that need the +/// full proto surface (volume claim templates, runtime classes, struct +/// resources, etc.) should drop down to [`crate::raw`]. +#[derive(Clone, Debug, Default)] +pub struct SandboxSpec { + /// Optional user-supplied sandbox name. When empty the server generates one. + pub name: Option<String>, + /// Container image reference (e.g. `ghcr.io/nvidia/openshell-community/sandboxes/python:latest`). + pub image: Option<String>, + /// Labels attached to the sandbox. + pub labels: HashMap<String, String>, + /// Environment variables injected into the sandbox runtime. + pub environment: HashMap<String, String>, + /// Provider names to attach. + pub providers: Vec<String>, + /// Request a GPU. + pub gpu: bool, + /// Driver-specific GPU device selector (CDI ID for Docker/Podman; BDF or + /// index for VM). Only meaningful when `gpu` is true; empty defers to the + /// driver's default selection. + pub gpu_device: Option<String>, +} + +/// Reference to a sandbox owned by the gateway.
+#[derive(Clone, Debug)] +#[non_exhaustive] +pub struct SandboxRef { + pub id: String, + pub name: String, + pub phase: SandboxPhase, + pub labels: HashMap<String, String>, + pub resource_version: u64, +} + +impl SandboxRef { + pub(crate) fn from_proto(sandbox: proto::Sandbox) -> Self { + let meta = sandbox.metadata.unwrap_or_default(); + Self { + id: meta.id, + name: meta.name, + phase: sandbox.phase.into(), + labels: meta.labels, + resource_version: meta.resource_version, + } + } +} + +/// Options for listing sandboxes. +#[derive(Clone, Debug, Default)] +pub struct ListOptions { + /// Maximum sandboxes to return. `0` defers to the server default. + pub limit: u32, + /// Offset into the result list. + pub offset: u32, + /// Optional Kubernetes-style label selector (e.g. `env=prod,team=core`). + pub label_selector: Option<String>, +} + +/// Options for [`crate::client::OpenShellClient::exec`]. +#[derive(Clone, Debug, Default)] +pub struct ExecOptions { + /// Working directory inside the sandbox. + pub workdir: Option<String>, + /// Environment overrides for the exec. + pub environment: HashMap<String, String>, + /// Optional command timeout. `None` lets the gateway choose. + pub timeout: Option<Duration>, + /// Optional stdin payload. + pub stdin: Option<Vec<u8>>, +} + +/// Result of a non-streaming exec call. +/// +/// `stdout` and `stderr` are buffered to the end of the command. Use the +/// raw streaming RPC ([`crate::raw`]) for long-running output. +#[derive(Clone, Debug)] +#[non_exhaustive] +pub struct ExecResult { + pub exit_code: i32, + pub stdout: Vec<u8>, + pub stderr: Vec<u8>, +} diff --git a/crates/openshell-sdk/tests/client_mock.rs b/crates/openshell-sdk/tests/client_mock.rs new file mode 100644 index 000000000..672ccc9ca --- /dev/null +++ b/crates/openshell-sdk/tests/client_mock.rs @@ -0,0 +1,764 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! High-level [`OpenShellClient`] tests against an in-process mock gateway.
+//! +//! The mock binds an ephemeral plaintext TCP listener and serves the +//! `OpenShell` gRPC service. Tests dial it via `http://127.0.0.1:` so +//! TLS and auth code paths are skipped — those are exercised by the CLI's +//! `mtls_integration` and OIDC tests. + +use openshell_core::proto; +use openshell_core::proto::open_shell_server::{OpenShell, OpenShellServer}; +use openshell_sdk::{ + ClientConfig, ExecOptions, ListOptions, OpenShellClient, SandboxPhase, SandboxSpec, + ServiceStatus as SdkServiceStatus, +}; +use std::collections::HashMap; +use std::sync::Arc; +use std::sync::atomic::{AtomicU32, Ordering}; +use tokio::net::TcpListener; +use tokio::sync::Mutex; +use tokio_stream::wrappers::TcpListenerStream; +use tonic::{Response, Status}; + +/// Captured fixture state — what the mock observed and the canned replies it +/// returned. One per test so assertions are scoped. +#[derive(Default)] +struct MockState { + last_get_name: Mutex>, + last_create: Mutex>, + last_delete_name: Mutex>, + last_list_request: Mutex>, + last_exec_request: Mutex>, + get_calls: AtomicU32, + phase_sequence: Vec, + get_returns_not_found: bool, + not_found_after: Option, +} + +#[derive(Clone)] +struct TestOpenShell { + state: Arc, +} + +fn sandbox_with_phase(name: &str, phase: proto::SandboxPhase) -> proto::Sandbox { + proto::Sandbox { + metadata: Some(proto::datamodel::v1::ObjectMeta { + id: format!("id-{name}"), + name: name.to_string(), + created_at_ms: 0, + labels: HashMap::new(), + resource_version: 1, + }), + spec: None, + status: None, + phase: phase.into(), + current_policy_version: 0, + } +} + +#[tonic::async_trait] +impl OpenShell for TestOpenShell { + async fn health( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(proto::HealthResponse { + status: proto::ServiceStatus::Healthy.into(), + version: "test-1.2.3".to_string(), + })) + } + + async fn create_sandbox( + &self, + request: tonic::Request, + ) -> Result, Status> { + let req = 
request.into_inner(); + let name = if req.name.is_empty() { + "generated".to_string() + } else { + req.name.clone() + }; + *self.state.last_create.lock().await = Some(req); + Ok(Response::new(proto::SandboxResponse { + sandbox: Some(sandbox_with_phase(&name, proto::SandboxPhase::Provisioning)), + })) + } + + async fn get_sandbox( + &self, + request: tonic::Request, + ) -> Result, Status> { + let name = request.into_inner().name; + *self.state.last_get_name.lock().await = Some(name.clone()); + let count = self.state.get_calls.fetch_add(1, Ordering::SeqCst); + + if self.state.get_returns_not_found { + return Err(Status::not_found(format!("sandbox '{name}' not found"))); + } + if let Some(threshold) = self.state.not_found_after + && count >= threshold + { + return Err(Status::not_found(format!("sandbox '{name}' not found"))); + } + + let phase = self + .state + .phase_sequence + .get(count as usize) + .copied() + .or_else(|| self.state.phase_sequence.last().copied()) + .unwrap_or(proto::SandboxPhase::Ready); + + Ok(Response::new(proto::SandboxResponse { + sandbox: Some(sandbox_with_phase(&name, phase)), + })) + } + + async fn list_sandboxes( + &self, + request: tonic::Request, + ) -> Result, Status> { + *self.state.last_list_request.lock().await = Some(request.into_inner()); + Ok(Response::new(proto::ListSandboxesResponse { + sandboxes: vec![ + sandbox_with_phase("alpha", proto::SandboxPhase::Ready), + sandbox_with_phase("beta", proto::SandboxPhase::Provisioning), + ], + })) + } + + async fn list_sandbox_providers( + &self, + _: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(proto::ListSandboxProvidersResponse::default())) + } + + async fn attach_sandbox_provider( + &self, + _: tonic::Request, + ) -> Result, Status> { + Ok(Response::new( + proto::AttachSandboxProviderResponse::default(), + )) + } + + async fn detach_sandbox_provider( + &self, + _: tonic::Request, + ) -> Result, Status> { + Ok(Response::new( + 
proto::DetachSandboxProviderResponse::default(), + )) + } + + async fn delete_sandbox( + &self, + request: tonic::Request, + ) -> Result, Status> { + let name = request.into_inner().name; + *self.state.last_delete_name.lock().await = Some(name); + Ok(Response::new(proto::DeleteSandboxResponse { + deleted: true, + })) + } + + async fn create_ssh_session( + &self, + _: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(proto::CreateSshSessionResponse::default())) + } + + async fn expose_service( + &self, + _: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(proto::ServiceEndpointResponse::default())) + } + + async fn get_service( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn list_services( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn delete_service( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn revoke_ssh_session( + &self, + _: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(proto::RevokeSshSessionResponse::default())) + } + + type ExecSandboxStream = + tokio_stream::wrappers::ReceiverStream>; + + async fn exec_sandbox( + &self, + request: tonic::Request, + ) -> Result, Status> { + *self.state.last_exec_request.lock().await = Some(request.into_inner()); + let (tx, rx) = tokio::sync::mpsc::channel(8); + tokio::spawn(async move { + let _ = tx + .send(Ok(proto::ExecSandboxEvent { + payload: Some(proto::exec_sandbox_event::Payload::Stdout( + proto::ExecSandboxStdout { + data: b"hello\n".to_vec(), + }, + )), + })) + .await; + let _ = tx + .send(Ok(proto::ExecSandboxEvent { + payload: Some(proto::exec_sandbox_event::Payload::Stderr( + proto::ExecSandboxStderr { + data: b"warn\n".to_vec(), + }, + )), + })) + .await; + let _ = tx + .send(Ok(proto::ExecSandboxEvent { + payload: Some(proto::exec_sandbox_event::Payload::Exit( + proto::ExecSandboxExit { 
exit_code: 7 }, + )), + })) + .await; + }); + Ok(Response::new(tokio_stream::wrappers::ReceiverStream::new( + rx, + ))) + } + + type ForwardTcpStream = + tokio_stream::wrappers::ReceiverStream>; + + async fn forward_tcp( + &self, + _: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + type ExecSandboxInteractiveStream = + tokio_stream::wrappers::ReceiverStream>; + + async fn exec_sandbox_interactive( + &self, + _: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn create_provider( + &self, + _: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(proto::ProviderResponse::default())) + } + + async fn get_provider( + &self, + _: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(proto::ProviderResponse::default())) + } + + async fn list_providers( + &self, + _: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(proto::ListProvidersResponse::default())) + } + + async fn list_provider_profiles( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn get_provider_profile( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn import_provider_profiles( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn lint_provider_profiles( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn update_provider( + &self, + _: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(proto::ProviderResponse::default())) + } + + async fn get_provider_refresh_status( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn configure_provider_refresh( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn rotate_provider_credential( + &self, + _: 
tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn delete_provider_refresh( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn delete_provider( + &self, + _: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(proto::DeleteProviderResponse::default())) + } + + async fn delete_provider_profile( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn get_sandbox_config( + &self, + _: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(proto::GetSandboxConfigResponse::default())) + } + + async fn get_gateway_config( + &self, + _: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(proto::GetGatewayConfigResponse::default())) + } + + async fn update_config( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn get_sandbox_policy_status( + &self, + _: tonic::Request, + ) -> Result, Status> { + Ok(Response::new( + proto::GetSandboxPolicyStatusResponse::default(), + )) + } + + async fn list_sandbox_policies( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn report_policy_status( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn get_sandbox_provider_environment( + &self, + _: tonic::Request, + ) -> Result, Status> { + Ok(Response::new( + proto::GetSandboxProviderEnvironmentResponse::default(), + )) + } + + async fn get_sandbox_logs( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn push_sandbox_logs( + &self, + _: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + type WatchSandboxStream = + tokio_stream::wrappers::ReceiverStream>; + + async fn watch_sandbox( + &self, + _: tonic::Request, + ) -> Result, Status> { + 
Err(Status::unimplemented("unused")) + } + + async fn submit_policy_analysis( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn get_draft_policy( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn approve_draft_chunk( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn approve_all_draft_chunks( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn reject_draft_chunk( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn edit_draft_chunk( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn undo_draft_chunk( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn clear_draft_chunks( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn get_draft_history( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + type ConnectSupervisorStream = + tokio_stream::wrappers::ReceiverStream>; + + async fn connect_supervisor( + &self, + _: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + type RelayStreamStream = + tokio_stream::wrappers::ReceiverStream>; + + async fn relay_stream( + &self, + _: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn issue_sandbox_token( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn refresh_sandbox_token( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } +} + +/// Spin up the mock gateway, return its endpoint URL. 
+async fn start_mock(state: Arc) -> String { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let endpoint = format!("http://{addr}"); + let stream = TcpListenerStream::new(listener); + let svc = OpenShellServer::new(TestOpenShell { state }); + tokio::spawn(async move { + let _ = tonic::transport::Server::builder() + .add_service(svc) + .serve_with_incoming(stream) + .await; + }); + endpoint +} + +async fn connect(endpoint: &str) -> OpenShellClient { + OpenShellClient::connect(ClientConfig::new(endpoint)) + .await + .expect("connect should succeed against local mock") +} + +#[tokio::test] +async fn health_returns_curated_snapshot() { + let state = Arc::new(MockState::default()); + let endpoint = start_mock(state.clone()).await; + let client = connect(&endpoint).await; + + let h = client.health().await.unwrap(); + assert_eq!(h.status, SdkServiceStatus::Healthy); + assert_eq!(h.version, "test-1.2.3"); +} + +#[tokio::test] +async fn create_sandbox_passes_spec_through() { + let state = Arc::new(MockState::default()); + let endpoint = start_mock(state.clone()).await; + let client = connect(&endpoint).await; + + let mut labels = HashMap::new(); + labels.insert("team".to_string(), "core".to_string()); + + let spec = SandboxSpec { + name: Some("my-box".to_string()), + image: Some("ghcr.io/foo:bar".to_string()), + labels: labels.clone(), + gpu: true, + ..Default::default() + }; + + let result = client.create_sandbox(spec).await.unwrap(); + assert_eq!(result.name, "my-box"); + assert_eq!(result.phase, SandboxPhase::Provisioning); + + let observed = state.last_create.lock().await.clone().unwrap(); + assert_eq!(observed.name, "my-box"); + assert_eq!(observed.labels, labels); + let observed_spec = observed.spec.unwrap(); + assert!(observed_spec.gpu); + assert_eq!( + observed_spec.template.as_ref().unwrap().image, + "ghcr.io/foo:bar" + ); +} + +#[tokio::test] +async fn get_sandbox_sends_name_and_maps_phase() { + let 
state = Arc::new(MockState { + phase_sequence: vec![proto::SandboxPhase::Ready], + ..Default::default() + }); + let endpoint = start_mock(state.clone()).await; + let client = connect(&endpoint).await; + + let sandbox = client.get_sandbox("my-box").await.unwrap(); + assert_eq!(sandbox.name, "my-box"); + assert_eq!(sandbox.id, "id-my-box"); + assert_eq!(sandbox.phase, SandboxPhase::Ready); + + let observed = state.last_get_name.lock().await.clone(); + assert_eq!(observed.as_deref(), Some("my-box")); +} + +#[tokio::test] +async fn list_sandboxes_propagates_filters() { + let state = Arc::new(MockState::default()); + let endpoint = start_mock(state.clone()).await; + let client = connect(&endpoint).await; + + let opts = ListOptions { + limit: 25, + offset: 5, + label_selector: Some("team=core".to_string()), + }; + let items = client.list_sandboxes(opts).await.unwrap(); + assert_eq!(items.len(), 2); + assert_eq!(items[0].name, "alpha"); + assert_eq!(items[0].phase, SandboxPhase::Ready); + assert_eq!(items[1].phase, SandboxPhase::Provisioning); + + let observed = state.last_list_request.lock().await.clone().unwrap(); + assert_eq!(observed.limit, 25); + assert_eq!(observed.offset, 5); + assert_eq!(observed.label_selector, "team=core"); +} + +#[tokio::test] +async fn delete_sandbox_returns_server_ack() { + let state = Arc::new(MockState::default()); + let endpoint = start_mock(state.clone()).await; + let client = connect(&endpoint).await; + + let deleted = client.delete_sandbox("doomed").await.unwrap(); + assert!(deleted); + + let observed = state.last_delete_name.lock().await.clone(); + assert_eq!(observed.as_deref(), Some("doomed")); +} + +#[tokio::test] +async fn wait_ready_transitions_through_phases() { + let state = Arc::new(MockState { + phase_sequence: vec![ + proto::SandboxPhase::Provisioning, + proto::SandboxPhase::Provisioning, + proto::SandboxPhase::Ready, + ], + ..Default::default() + }); + let endpoint = start_mock(state.clone()).await; + let client = 
connect(&endpoint).await; + + let sandbox = client + .wait_ready("my-box", std::time::Duration::from_secs(5)) + .await + .unwrap(); + assert_eq!(sandbox.phase, SandboxPhase::Ready); + assert!(state.get_calls.load(Ordering::SeqCst) >= 3); +} + +#[tokio::test] +async fn wait_ready_surfaces_error_phase() { + let state = Arc::new(MockState { + phase_sequence: vec![proto::SandboxPhase::Error], + ..Default::default() + }); + let endpoint = start_mock(state.clone()).await; + let client = connect(&endpoint).await; + + let err = client + .wait_ready("my-box", std::time::Duration::from_secs(5)) + .await + .unwrap_err(); + assert_eq!(err.code(), "connect"); +} + +#[tokio::test] +async fn wait_deleted_returns_when_get_reports_not_found() { + let state = Arc::new(MockState { + phase_sequence: vec![proto::SandboxPhase::Deleting], + not_found_after: Some(2), + ..Default::default() + }); + let endpoint = start_mock(state.clone()).await; + let client = connect(&endpoint).await; + + client + .wait_deleted("my-box", std::time::Duration::from_secs(5)) + .await + .unwrap(); + assert!(state.get_calls.load(Ordering::SeqCst) >= 3); +} + +#[tokio::test] +async fn get_sandbox_not_found_maps_to_typed_error() { + let state = Arc::new(MockState { + get_returns_not_found: true, + ..Default::default() + }); + let endpoint = start_mock(state.clone()).await; + let client = connect(&endpoint).await; + + let err = client.get_sandbox("missing").await.unwrap_err(); + assert_eq!(err.code(), "not_found"); +} + +#[tokio::test] +async fn exec_buffers_stdout_stderr_and_exit() { + let state = Arc::new(MockState { + phase_sequence: vec![proto::SandboxPhase::Ready], + ..Default::default() + }); + let endpoint = start_mock(state.clone()).await; + let client = connect(&endpoint).await; + + let result = client + .exec( + "my-box", + &["echo".to_string(), "hello".to_string()], + ExecOptions { + workdir: Some("/work".to_string()), + timeout: Some(std::time::Duration::from_secs(10)), + ..Default::default() + }, + ) 
+ .await + .unwrap(); + + assert_eq!(result.exit_code, 7); + assert_eq!(result.stdout, b"hello\n"); + assert_eq!(result.stderr, b"warn\n"); + + let observed = state.last_exec_request.lock().await.clone().unwrap(); + assert_eq!(observed.sandbox_id, "id-my-box"); + assert_eq!( + observed.command, + vec!["echo".to_string(), "hello".to_string()] + ); + assert_eq!(observed.workdir, "/work"); + assert_eq!(observed.timeout_seconds, 10); +} diff --git a/crates/openshell-tui/Cargo.toml b/crates/openshell-tui/Cargo.toml index b0ac0c7ca..e7230c2f4 100644 --- a/crates/openshell-tui/Cargo.toml +++ b/crates/openshell-tui/Cargo.toml @@ -15,6 +15,7 @@ openshell-core = { path = "../openshell-core" } openshell-bootstrap = { path = "../openshell-bootstrap" } openshell-policy = { path = "../openshell-policy" } openshell-providers = { path = "../openshell-providers" } +openshell-sdk = { path = "../openshell-sdk" } base64 = { workspace = true } ratatui = { workspace = true } diff --git a/crates/openshell-tui/src/app.rs b/crates/openshell-tui/src/app.rs index ba817bcf8..0325742fe 100644 --- a/crates/openshell-tui/src/app.rs +++ b/crates/openshell-tui/src/app.rs @@ -5,10 +5,10 @@ use std::collections::HashMap; use std::time::{Duration, Instant}; use crossterm::event::{KeyCode, KeyEvent, KeyModifiers}; -use openshell_core::auth::EdgeAuthInterceptor; use openshell_core::proto::open_shell_client::OpenShellClient; use openshell_core::proto::setting_value; use openshell_core::settings::{self, SettingValueKind}; +use openshell_sdk::EdgeAuthInterceptor; use tonic::service::interceptor::InterceptedService; use tonic::transport::Channel; diff --git a/crates/openshell-tui/src/lib.rs b/crates/openshell-tui/src/lib.rs index 1969715ce..7fa671355 100644 --- a/crates/openshell-tui/src/lib.rs +++ b/crates/openshell-tui/src/lib.rs @@ -18,9 +18,9 @@ use crossterm::terminal::{ EnterAlternateScreen, LeaveAlternateScreen, disable_raw_mode, enable_raw_mode, }; use miette::{IntoDiagnostic, Result}; -use 
openshell_core::auth::EdgeAuthInterceptor; use openshell_core::metadata::{ObjectId, ObjectLabels, ObjectName}; use openshell_core::proto::open_shell_client::OpenShellClient; +use openshell_sdk::EdgeAuthInterceptor; use ratatui::Terminal; use ratatui::backend::CrosstermBackend; use tokio::sync::mpsc; From 99ea7632241c32397c39924763042f370c993755 Mon Sep 17 00:00:00 2001 From: Max Dubrinsky Date: Thu, 28 May 2026 11:01:16 -0400 Subject: [PATCH 3/6] feat(sdk-node): add napi-rs TypeScript binding (@openshell/sdk) Per RFC 0005, ship openshell-sdk-node as a napi-rs wrapper over the openshell-sdk crate, exposing the same surface to TypeScript/Node consumers as @openshell/sdk. - New crate openshell-sdk-node binds OpenShellClient, OidcRefresher, the edge-tunnel dialer, and the sandbox/exec API through napi-rs. - Generated index.d.ts and index.js are committed; the per-platform .node binary is gitignored and rebuilt with `napi build` per host. - lib.mjs provides a small ESM facade with named exports plus the errorCode() helper for typed error discrimination. Tests: 5-case smoke suite that exercises exports, typed connect errors, and OidcRefresher single-flight semantics. 
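The single-flight behaviour checked by the smoke suite can be sketched roughly as below. This is only an illustration under stated assumptions, not the code in this patch: `SingleFlightRefresher` and `RefreshedToken` are hypothetical names, the real `OidcRefresher` lives in Rust behind napi-rs, and expiry handling is left out entirely.

```typescript
// Hypothetical sketch of single-flight token refresh; not the SDK's implementation.
type RefreshedToken = { accessToken: string; expiresAt?: number };

class SingleFlightRefresher {
  private token: string;
  private readonly callback: () => Promise<RefreshedToken>;
  // The refresh currently in progress, if any. Concurrent callers share it.
  private inFlight: Promise<string> | null = null;

  constructor(initialToken: string, callback: () => Promise<RefreshedToken>) {
    this.token = initialToken;
    this.callback = callback;
  }

  // Snapshot of the last known token; never triggers a refresh.
  currentToken(): string {
    return this.token;
  }

  // Start a refresh, or join the one that is already running.
  refresh(): Promise<string> {
    if (this.inFlight === null) {
      // Clear the slot on success or failure so the next call starts fresh.
      const clear = (): void => {
        this.inFlight = null;
      };
      this.inFlight = this.callback().then(
        (fresh) => {
          clear();
          this.token = fresh.accessToken;
          return fresh.accessToken;
        },
        (err) => {
          clear();
          throw err;
        },
      );
    }
    return this.inFlight;
  }
}

// Demo: three concurrent callers are served by one callback invocation.
let callbackCalls = 0;
const refresher = new SingleFlightRefresher("initial", () => {
  callbackCalls += 1;
  return Promise.resolve({ accessToken: "token-" + callbackCalls });
});

Promise.all([refresher.refresh(), refresher.refresh(), refresher.refresh()]).then((tokens) => {
  console.log(tokens, "callback invocations:", callbackCalls);
});
```

Sharing the pending promise means every concurrent caller sees the same result or the same rejection, and a failed refresh does not block later attempts because the slot is cleared on both paths.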
--- Cargo.lock | 106 +++- crates/openshell-sdk-node/.gitignore | 3 + crates/openshell-sdk-node/Cargo.toml | 37 ++ crates/openshell-sdk-node/build.rs | 6 + crates/openshell-sdk-node/index.d.ts | 138 ++++++ crates/openshell-sdk-node/index.js | 590 +++++++++++++++++++++++ crates/openshell-sdk-node/lib.d.ts | 24 + crates/openshell-sdk-node/lib.mjs | 32 ++ crates/openshell-sdk-node/package.json | 55 +++ crates/openshell-sdk-node/src/lib.rs | 406 ++++++++++++++++ crates/openshell-sdk-node/test/smoke.mjs | 105 ++++ 11 files changed, 1500 insertions(+), 2 deletions(-) create mode 100644 crates/openshell-sdk-node/.gitignore create mode 100644 crates/openshell-sdk-node/Cargo.toml create mode 100644 crates/openshell-sdk-node/build.rs create mode 100644 crates/openshell-sdk-node/index.d.ts create mode 100644 crates/openshell-sdk-node/index.js create mode 100644 crates/openshell-sdk-node/lib.d.ts create mode 100644 crates/openshell-sdk-node/lib.mjs create mode 100644 crates/openshell-sdk-node/package.json create mode 100644 crates/openshell-sdk-node/src/lib.rs create mode 100644 crates/openshell-sdk-node/test/smoke.mjs diff --git a/Cargo.lock b/Cargo.lock index 739f47568..cb3d85c08 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -741,7 +741,7 @@ checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4" dependencies = [ "glob", "libc", - "libloading", + "libloading 0.8.9", ] [[package]] @@ -901,6 +901,15 @@ version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3d52eff69cd5e647efe296129160853a42795992097e8af39800e1060caeea9b" +[[package]] +name = "convert_case" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "affbf0190ed2caf063e3def54ff444b449371d55c58e513a95ab98eca50adb49" +dependencies = [ + "unicode-segmentation", +] + [[package]] name = "core-foundation" version = "0.10.1" @@ -1105,6 +1114,12 @@ dependencies = [ "rand_core 0.10.0-rc-3", ] +[[package]] +name = "ctor" 
+version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d765eb1c0bda10d31e0ea185f5ee15da532d60b0912d2bd1441783439e749c5" + [[package]] name = "ctr" version = "0.9.2" @@ -2852,6 +2867,16 @@ dependencies = [ "windows-link", ] +[[package]] +name = "libloading" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "754ca22de805bb5744484a5b151a9e1a8e837d5dc232c2d7d8c2e3492edc8b60" +dependencies = [ + "cfg-if", + "windows-link", +] + [[package]] name = "libm" version = "0.2.16" @@ -3128,6 +3153,65 @@ version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1d87ecb2933e8aeadb3e3a02b828fed80a7528047e68b4f424523a0981a3a084" +[[package]] +name = "napi" +version = "3.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1d395473824516f38dd1071a1a37bc57daa7be65b293ebba4ead5f7abb017a2" +dependencies = [ + "anyhow", + "bitflags", + "ctor", + "futures", + "napi-build", + "napi-sys", + "nohash-hasher", + "rustc-hash 2.1.2", + "tokio", +] + +[[package]] +name = "napi-build" +version = "2.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c9c366d2c8c60b86fa632df75f745509b52f9128f91a6bad4c796e44abb505e1" + +[[package]] +name = "napi-derive" +version = "3.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "89b3f766e04667e6da0e181e2da4f85475d5a6513b7cf6a80bea184e224a5b42" +dependencies = [ + "convert_case", + "ctor", + "napi-derive-backend", + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "napi-derive-backend" +version = "5.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d5af30503edf933ce7377cf6d4c877a62b0f1107ea05585f1b5e430e88d5baf" +dependencies = [ + "convert_case", + "proc-macro2", + "quote", + "semver", + "syn 2.0.117", +] + +[[package]] +name = "napi-sys" +version = "3.2.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "8eb602b84d7c1edae45e50bbf1374696548f36ae179dfa667f577e384bb90c2b" +dependencies = [ + "libloading 0.9.0", +] + [[package]] name = "nix" version = "0.29.0" @@ -3140,6 +3224,12 @@ dependencies = [ "libc", ] +[[package]] +name = "nohash-hasher" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2bf50223579dc7cdcfb3bfcacf7069ff68243f8c363f62ffa99cf000a6b9c451" + [[package]] name = "nom" version = "7.1.3" @@ -3543,7 +3633,7 @@ dependencies = [ "flate2", "futures", "libc", - "libloading", + "libloading 0.8.9", "miette", "nix", "oci-client", @@ -3707,6 +3797,18 @@ dependencies = [ "tracing", ] +[[package]] +name = "openshell-sdk-node" +version = "0.0.0" +dependencies = [ + "async-trait", + "napi", + "napi-build", + "napi-derive", + "openshell-sdk", + "tokio", +] + [[package]] name = "openshell-server" version = "0.0.0" diff --git a/crates/openshell-sdk-node/.gitignore b/crates/openshell-sdk-node/.gitignore new file mode 100644 index 000000000..fb4a29529 --- /dev/null +++ b/crates/openshell-sdk-node/.gitignore @@ -0,0 +1,3 @@ +node_modules/ +package-lock.json +*.node diff --git a/crates/openshell-sdk-node/Cargo.toml b/crates/openshell-sdk-node/Cargo.toml new file mode 100644 index 000000000..e798acc21 --- /dev/null +++ b/crates/openshell-sdk-node/Cargo.toml @@ -0,0 +1,37 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +[package] +name = "openshell-sdk-node" +description = "napi-rs bindings for the OpenShell Rust SDK; published as @openshell/sdk" +version.workspace = true +edition.workspace = true +rust-version.workspace = true +license.workspace = true +repository.workspace = true + +# Required for napi-rs: produces a dynamic library Node.js loads via N-API. +# `rlib` is kept so the crate can also be exercised by `cargo test --workspace`. 
+[lib]
+crate-type = ["cdylib", "rlib"]
+
+[dependencies]
+openshell-sdk = { path = "../openshell-sdk" }
+async-trait = "0.1"
+napi = { version = "3", default-features = false, features = [
+    "async",
+    "napi6",
+    "tokio_rt",
+    "error_anyhow",
+] }
+napi-derive = "3"
+tokio = { workspace = true }
+
+[build-dependencies]
+napi-build = "2"
+
+# Lints are disabled for this crate because napi-derive expands into patterns
+# (raw FFI signatures, unsafe blocks for N-API, generated TypeScript shim
+# functions) that trigger several pedantic lints. The wrapped Rust SDK already
+# enforces the workspace lint policy.
+[lints]
diff --git a/crates/openshell-sdk-node/build.rs b/crates/openshell-sdk-node/build.rs
new file mode 100644
index 000000000..010bd1add
--- /dev/null
+++ b/crates/openshell-sdk-node/build.rs
@@ -0,0 +1,6 @@
+// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+fn main() {
+    napi_build::setup();
+}
diff --git a/crates/openshell-sdk-node/index.d.ts b/crates/openshell-sdk-node/index.d.ts
new file mode 100644
index 000000000..7b1b4453c
--- /dev/null
+++ b/crates/openshell-sdk-node/index.d.ts
@@ -0,0 +1,138 @@
+/* auto-generated by NAPI-RS */
+/* eslint-disable */
+/**
+ * A live token source backed by a JS callback. Hand off to
+ * [`OpenShellClient::set_oidc_refresher`] before any RPCs run; the SDK
+ * proactively refreshes when the token is within 60s of expiry, and
+ * coalesces concurrent refreshes into a single callback invocation.
+ */
+export declare class OidcRefresher {
+  /**
+   * Create a refresher with an initial token and a JS callback.
+   *
+   * The callback must return a Promise resolving to
+   * `{ accessToken, expiresAt? }`. `expiresAt` is Unix epoch seconds.
+   */
+  constructor(initialToken: string, initialExpiresAt: number | undefined | null, callback: () => Promise<{ accessToken: string; expiresAt?: number }>)
+  /**
+   * Snapshot the current token (no refresh check). Mostly useful for
+   * tests; in steady-state the SDK calls this internally.
+   */
+  currentToken(): string
+  /**
+   * Force a refresh now and return the new access token. Concurrent
+   * callers coalesce.
+   */
+  refresh(): Promise<string>
+}
+
+/**
+ * The JS-facing client. Cheap to share between async tasks; do not call
+ * `connect` per request.
+ */
+export declare class OpenShellClient {
+  /** Open a connection to the gateway described by `options`. */
+  static connect(options: ConnectOptions): Promise<OpenShellClient>
+  /** Gateway health snapshot. */
+  health(): Promise<Health>
+  /** Create a new sandbox. */
+  createSandbox(spec: SandboxSpec): Promise<SandboxRef>
+  /** Fetch a sandbox by name. */
+  getSandbox(name: string): Promise<SandboxRef>
+  /** List sandboxes. */
+  listSandboxes(options?: ListOptions | undefined | null): Promise<Array<SandboxRef>>
+  /**
+   * Delete a sandbox by name. Returns `true` when the gateway acknowledged
+   * the deletion, `false` when it was already absent.
+   */
+  deleteSandbox(name: string): Promise<boolean>
+  /** Poll until the sandbox reaches `ready` or `timeout_secs` elapses. */
+  waitReady(name: string, timeoutSecs: number): Promise<SandboxRef>
+  /** Poll until the sandbox is gone or `timeout_secs` elapses. */
+  waitDeleted(name: string, timeoutSecs: number): Promise<void>
+  /** Run a command inside a sandbox; buffers stdout/stderr to the end. */
+  exec(name: string, command: Array<string>, options?: ExecOptions | undefined | null): Promise<ExecResult>
+}
+
+/**
+ * Connection options. Mirrors [`openshell_sdk::ClientConfig`] with
+ * JS-friendly field names.
+ */
+export interface ConnectOptions {
+  /** Gateway URL (`http://...` or `https://...`). */
+  gateway: string
+  /** CA certificate (PEM-encoded). `None` falls back to system roots. */
+  caCert?: Buffer
+  /** Bearer token for direct OIDC auth. Mutually exclusive with `edge_token`. */
+  oidcToken?: string
+  /** Cloudflare Access bearer token. Routes through a local WebSocket tunnel. */
+  edgeToken?: string
+  /** Disable TLS certificate verification (development/debug only). */
+  insecureSkipVerify?: boolean
+}
+
+/** Options for [`OpenShellClient::exec`]. */
+export interface ExecOptions {
+  workdir?: string
+  environment?: Record<string, string>
+  /** Timeout in seconds. `None` lets the gateway choose. */
+  timeoutSecs?: number
+  /** Optional stdin payload. */
+  stdin?: Buffer
+}
+
+/** Result of a non-streaming exec call. */
+export interface ExecResult {
+  exitCode: number
+  stdout: Buffer
+  stderr: Buffer
+}
+
+/** Gateway health snapshot. */
+export interface Health {
+  /** Coarse status: `"healthy"`, `"degraded"`, `"unhealthy"`, `"unspecified"`. */
+  status: string
+  version: string
+}
+
+/** JS-side refresh callback returning a Promise<{ accessToken, expiresAt? }>. */
+export interface JsRefreshedToken {
+  accessToken: string
+  /**
+   * Expiry as Unix epoch seconds. Stored as `f64` because JS numbers
+   * can't hold `u64` exactly past 2^53; values are clamped to that range
+   * in practice (the year 287396 is fine).
+   */
+  expiresAt?: number
+}
+
+/** Options for [`OpenShellClient::list_sandboxes`]. */
+export interface ListOptions {
+  limit?: number
+  offset?: number
+  labelSelector?: string
+}
+
+/**
+ * Lifecycle phase: `"unspecified"`, `"provisioning"`, `"ready"`, `"error"`,
+ * `"deleting"`, `"unknown"`.
+ */
+export interface SandboxRef {
+  id: string
+  name: string
+  phase: string
+  labels: Record<string, string>
+  /** Resource version as a string, because JS numbers can't safely hold u64. */
+  resourceVersion: string
+}
+
+/** Caller intent for a new sandbox.
*/ +export interface SandboxSpec { + name?: string + image?: string + labels?: Record<string, string> + environment?: Record<string, string> + providers?: Array<string> + gpu?: boolean + gpuDevice?: string +} diff --git a/crates/openshell-sdk-node/index.js b/crates/openshell-sdk-node/index.js new file mode 100644 index 000000000..565c20ef5 --- /dev/null +++ b/crates/openshell-sdk-node/index.js @@ -0,0 +1,590 @@ +// prettier-ignore +/* eslint-disable */ +// @ts-nocheck +/* auto-generated by NAPI-RS */ + +const { readFileSync } = require('node:fs') +let nativeBinding = null +const loadErrors = [] + +const isMusl = () => { + let musl = false + if (process.platform === 'linux') { + musl = isMuslFromFilesystem() + if (musl === null) { + musl = isMuslFromReport() + } + if (musl === null) { + musl = isMuslFromChildProcess() + } + } + return musl +} + +const isFileMusl = (f) => f.includes('libc.musl-') || f.includes('ld-musl-') + +const isMuslFromFilesystem = () => { + try { + return readFileSync('/usr/bin/ldd', 'utf-8').includes('musl') + } catch { + return null + } +} + +const isMuslFromReport = () => { + let report = null + if (typeof process.report?.getReport === 'function') { + process.report.excludeNetwork = true + report = process.report.getReport() + } + if (!report) { + return null + } + if (report.header && report.header.glibcVersionRuntime) { + return false + } + if (Array.isArray(report.sharedObjects)) { + if (report.sharedObjects.some(isFileMusl)) { + return true + } + } + return false +} + +const isMuslFromChildProcess = () => { + try { + return require('child_process').execSync('ldd --version', { encoding: 'utf8' }).includes('musl') + } catch (e) { + // If we reach this case, we don't know if the system is musl or not, so is better to just fallback to false + return false + } +} + +function requireNative() { + if (process.env.NAPI_RS_NATIVE_LIBRARY_PATH) { + try { + return require(process.env.NAPI_RS_NATIVE_LIBRARY_PATH); + } catch (err) { + loadErrors.push(err) + } + } else if (process.platform
=== 'android') { + if (process.arch === 'arm64') { + try { + return require('./openshell-sdk.android-arm64.node') + } catch (e) { + loadErrors.push(e) + } + try { + const binding = require('@openshell/sdk-android-arm64') + const bindingPackageVersion = require('@openshell/sdk-android-arm64/package.json').version + if (bindingPackageVersion !== '0.0.0-alpha.0' && process.env.NAPI_RS_ENFORCE_VERSION_CHECK && process.env.NAPI_RS_ENFORCE_VERSION_CHECK !== '0') { + throw new Error(`Native binding package version mismatch, expected 0.0.0-alpha.0 but got ${bindingPackageVersion}. You can reinstall dependencies to fix this issue.`) + } + return binding + } catch (e) { + loadErrors.push(e) + } + } else if (process.arch === 'arm') { + try { + return require('./openshell-sdk.android-arm-eabi.node') + } catch (e) { + loadErrors.push(e) + } + try { + const binding = require('@openshell/sdk-android-arm-eabi') + const bindingPackageVersion = require('@openshell/sdk-android-arm-eabi/package.json').version + if (bindingPackageVersion !== '0.0.0-alpha.0' && process.env.NAPI_RS_ENFORCE_VERSION_CHECK && process.env.NAPI_RS_ENFORCE_VERSION_CHECK !== '0') { + throw new Error(`Native binding package version mismatch, expected 0.0.0-alpha.0 but got ${bindingPackageVersion}. 
You can reinstall dependencies to fix this issue.`) + } + return binding + } catch (e) { + loadErrors.push(e) + } + } else { + loadErrors.push(new Error(`Unsupported architecture on Android ${process.arch}`)) + } + } else if (process.platform === 'win32') { + if (process.arch === 'x64') { + if (process.config?.variables?.shlib_suffix === 'dll.a' || process.config?.variables?.node_target_type === 'shared_library') { + try { + return require('./openshell-sdk.win32-x64-gnu.node') + } catch (e) { + loadErrors.push(e) + } + try { + const binding = require('@openshell/sdk-win32-x64-gnu') + const bindingPackageVersion = require('@openshell/sdk-win32-x64-gnu/package.json').version + if (bindingPackageVersion !== '0.0.0-alpha.0' && process.env.NAPI_RS_ENFORCE_VERSION_CHECK && process.env.NAPI_RS_ENFORCE_VERSION_CHECK !== '0') { + throw new Error(`Native binding package version mismatch, expected 0.0.0-alpha.0 but got ${bindingPackageVersion}. You can reinstall dependencies to fix this issue.`) + } + return binding + } catch (e) { + loadErrors.push(e) + } + } else { + try { + return require('./openshell-sdk.win32-x64-msvc.node') + } catch (e) { + loadErrors.push(e) + } + try { + const binding = require('@openshell/sdk-win32-x64-msvc') + const bindingPackageVersion = require('@openshell/sdk-win32-x64-msvc/package.json').version + if (bindingPackageVersion !== '0.0.0-alpha.0' && process.env.NAPI_RS_ENFORCE_VERSION_CHECK && process.env.NAPI_RS_ENFORCE_VERSION_CHECK !== '0') { + throw new Error(`Native binding package version mismatch, expected 0.0.0-alpha.0 but got ${bindingPackageVersion}. 
You can reinstall dependencies to fix this issue.`) + } + return binding + } catch (e) { + loadErrors.push(e) + } + } + } else if (process.arch === 'ia32') { + try { + return require('./openshell-sdk.win32-ia32-msvc.node') + } catch (e) { + loadErrors.push(e) + } + try { + const binding = require('@openshell/sdk-win32-ia32-msvc') + const bindingPackageVersion = require('@openshell/sdk-win32-ia32-msvc/package.json').version + if (bindingPackageVersion !== '0.0.0-alpha.0' && process.env.NAPI_RS_ENFORCE_VERSION_CHECK && process.env.NAPI_RS_ENFORCE_VERSION_CHECK !== '0') { + throw new Error(`Native binding package version mismatch, expected 0.0.0-alpha.0 but got ${bindingPackageVersion}. You can reinstall dependencies to fix this issue.`) + } + return binding + } catch (e) { + loadErrors.push(e) + } + } else if (process.arch === 'arm64') { + try { + return require('./openshell-sdk.win32-arm64-msvc.node') + } catch (e) { + loadErrors.push(e) + } + try { + const binding = require('@openshell/sdk-win32-arm64-msvc') + const bindingPackageVersion = require('@openshell/sdk-win32-arm64-msvc/package.json').version + if (bindingPackageVersion !== '0.0.0-alpha.0' && process.env.NAPI_RS_ENFORCE_VERSION_CHECK && process.env.NAPI_RS_ENFORCE_VERSION_CHECK !== '0') { + throw new Error(`Native binding package version mismatch, expected 0.0.0-alpha.0 but got ${bindingPackageVersion}. 
You can reinstall dependencies to fix this issue.`) + } + return binding + } catch (e) { + loadErrors.push(e) + } + } else { + loadErrors.push(new Error(`Unsupported architecture on Windows: ${process.arch}`)) + } + } else if (process.platform === 'darwin') { + try { + return require('./openshell-sdk.darwin-universal.node') + } catch (e) { + loadErrors.push(e) + } + try { + const binding = require('@openshell/sdk-darwin-universal') + const bindingPackageVersion = require('@openshell/sdk-darwin-universal/package.json').version + if (bindingPackageVersion !== '0.0.0-alpha.0' && process.env.NAPI_RS_ENFORCE_VERSION_CHECK && process.env.NAPI_RS_ENFORCE_VERSION_CHECK !== '0') { + throw new Error(`Native binding package version mismatch, expected 0.0.0-alpha.0 but got ${bindingPackageVersion}. You can reinstall dependencies to fix this issue.`) + } + return binding + } catch (e) { + loadErrors.push(e) + } + if (process.arch === 'x64') { + try { + return require('./openshell-sdk.darwin-x64.node') + } catch (e) { + loadErrors.push(e) + } + try { + const binding = require('@openshell/sdk-darwin-x64') + const bindingPackageVersion = require('@openshell/sdk-darwin-x64/package.json').version + if (bindingPackageVersion !== '0.0.0-alpha.0' && process.env.NAPI_RS_ENFORCE_VERSION_CHECK && process.env.NAPI_RS_ENFORCE_VERSION_CHECK !== '0') { + throw new Error(`Native binding package version mismatch, expected 0.0.0-alpha.0 but got ${bindingPackageVersion}. 
You can reinstall dependencies to fix this issue.`) + } + return binding + } catch (e) { + loadErrors.push(e) + } + } else if (process.arch === 'arm64') { + try { + return require('./openshell-sdk.darwin-arm64.node') + } catch (e) { + loadErrors.push(e) + } + try { + const binding = require('@openshell/sdk-darwin-arm64') + const bindingPackageVersion = require('@openshell/sdk-darwin-arm64/package.json').version + if (bindingPackageVersion !== '0.0.0-alpha.0' && process.env.NAPI_RS_ENFORCE_VERSION_CHECK && process.env.NAPI_RS_ENFORCE_VERSION_CHECK !== '0') { + throw new Error(`Native binding package version mismatch, expected 0.0.0-alpha.0 but got ${bindingPackageVersion}. You can reinstall dependencies to fix this issue.`) + } + return binding + } catch (e) { + loadErrors.push(e) + } + } else { + loadErrors.push(new Error(`Unsupported architecture on macOS: ${process.arch}`)) + } + } else if (process.platform === 'freebsd') { + if (process.arch === 'x64') { + try { + return require('./openshell-sdk.freebsd-x64.node') + } catch (e) { + loadErrors.push(e) + } + try { + const binding = require('@openshell/sdk-freebsd-x64') + const bindingPackageVersion = require('@openshell/sdk-freebsd-x64/package.json').version + if (bindingPackageVersion !== '0.0.0-alpha.0' && process.env.NAPI_RS_ENFORCE_VERSION_CHECK && process.env.NAPI_RS_ENFORCE_VERSION_CHECK !== '0') { + throw new Error(`Native binding package version mismatch, expected 0.0.0-alpha.0 but got ${bindingPackageVersion}. 
You can reinstall dependencies to fix this issue.`) + } + return binding + } catch (e) { + loadErrors.push(e) + } + } else if (process.arch === 'arm64') { + try { + return require('./openshell-sdk.freebsd-arm64.node') + } catch (e) { + loadErrors.push(e) + } + try { + const binding = require('@openshell/sdk-freebsd-arm64') + const bindingPackageVersion = require('@openshell/sdk-freebsd-arm64/package.json').version + if (bindingPackageVersion !== '0.0.0-alpha.0' && process.env.NAPI_RS_ENFORCE_VERSION_CHECK && process.env.NAPI_RS_ENFORCE_VERSION_CHECK !== '0') { + throw new Error(`Native binding package version mismatch, expected 0.0.0-alpha.0 but got ${bindingPackageVersion}. You can reinstall dependencies to fix this issue.`) + } + return binding + } catch (e) { + loadErrors.push(e) + } + } else { + loadErrors.push(new Error(`Unsupported architecture on FreeBSD: ${process.arch}`)) + } + } else if (process.platform === 'linux') { + if (process.arch === 'x64') { + if (isMusl()) { + try { + return require('./openshell-sdk.linux-x64-musl.node') + } catch (e) { + loadErrors.push(e) + } + try { + const binding = require('@openshell/sdk-linux-x64-musl') + const bindingPackageVersion = require('@openshell/sdk-linux-x64-musl/package.json').version + if (bindingPackageVersion !== '0.0.0-alpha.0' && process.env.NAPI_RS_ENFORCE_VERSION_CHECK && process.env.NAPI_RS_ENFORCE_VERSION_CHECK !== '0') { + throw new Error(`Native binding package version mismatch, expected 0.0.0-alpha.0 but got ${bindingPackageVersion}. 
You can reinstall dependencies to fix this issue.`) + } + return binding + } catch (e) { + loadErrors.push(e) + } + } else { + try { + return require('./openshell-sdk.linux-x64-gnu.node') + } catch (e) { + loadErrors.push(e) + } + try { + const binding = require('@openshell/sdk-linux-x64-gnu') + const bindingPackageVersion = require('@openshell/sdk-linux-x64-gnu/package.json').version + if (bindingPackageVersion !== '0.0.0-alpha.0' && process.env.NAPI_RS_ENFORCE_VERSION_CHECK && process.env.NAPI_RS_ENFORCE_VERSION_CHECK !== '0') { + throw new Error(`Native binding package version mismatch, expected 0.0.0-alpha.0 but got ${bindingPackageVersion}. You can reinstall dependencies to fix this issue.`) + } + return binding + } catch (e) { + loadErrors.push(e) + } + } + } else if (process.arch === 'arm64') { + if (isMusl()) { + try { + return require('./openshell-sdk.linux-arm64-musl.node') + } catch (e) { + loadErrors.push(e) + } + try { + const binding = require('@openshell/sdk-linux-arm64-musl') + const bindingPackageVersion = require('@openshell/sdk-linux-arm64-musl/package.json').version + if (bindingPackageVersion !== '0.0.0-alpha.0' && process.env.NAPI_RS_ENFORCE_VERSION_CHECK && process.env.NAPI_RS_ENFORCE_VERSION_CHECK !== '0') { + throw new Error(`Native binding package version mismatch, expected 0.0.0-alpha.0 but got ${bindingPackageVersion}. 
You can reinstall dependencies to fix this issue.`) + } + return binding + } catch (e) { + loadErrors.push(e) + } + } else { + try { + return require('./openshell-sdk.linux-arm64-gnu.node') + } catch (e) { + loadErrors.push(e) + } + try { + const binding = require('@openshell/sdk-linux-arm64-gnu') + const bindingPackageVersion = require('@openshell/sdk-linux-arm64-gnu/package.json').version + if (bindingPackageVersion !== '0.0.0-alpha.0' && process.env.NAPI_RS_ENFORCE_VERSION_CHECK && process.env.NAPI_RS_ENFORCE_VERSION_CHECK !== '0') { + throw new Error(`Native binding package version mismatch, expected 0.0.0-alpha.0 but got ${bindingPackageVersion}. You can reinstall dependencies to fix this issue.`) + } + return binding + } catch (e) { + loadErrors.push(e) + } + } + } else if (process.arch === 'arm') { + if (isMusl()) { + try { + return require('./openshell-sdk.linux-arm-musleabihf.node') + } catch (e) { + loadErrors.push(e) + } + try { + const binding = require('@openshell/sdk-linux-arm-musleabihf') + const bindingPackageVersion = require('@openshell/sdk-linux-arm-musleabihf/package.json').version + if (bindingPackageVersion !== '0.0.0-alpha.0' && process.env.NAPI_RS_ENFORCE_VERSION_CHECK && process.env.NAPI_RS_ENFORCE_VERSION_CHECK !== '0') { + throw new Error(`Native binding package version mismatch, expected 0.0.0-alpha.0 but got ${bindingPackageVersion}. 
You can reinstall dependencies to fix this issue.`) + } + return binding + } catch (e) { + loadErrors.push(e) + } + } else { + try { + return require('./openshell-sdk.linux-arm-gnueabihf.node') + } catch (e) { + loadErrors.push(e) + } + try { + const binding = require('@openshell/sdk-linux-arm-gnueabihf') + const bindingPackageVersion = require('@openshell/sdk-linux-arm-gnueabihf/package.json').version + if (bindingPackageVersion !== '0.0.0-alpha.0' && process.env.NAPI_RS_ENFORCE_VERSION_CHECK && process.env.NAPI_RS_ENFORCE_VERSION_CHECK !== '0') { + throw new Error(`Native binding package version mismatch, expected 0.0.0-alpha.0 but got ${bindingPackageVersion}. You can reinstall dependencies to fix this issue.`) + } + return binding + } catch (e) { + loadErrors.push(e) + } + } + } else if (process.arch === 'loong64') { + if (isMusl()) { + try { + return require('./openshell-sdk.linux-loong64-musl.node') + } catch (e) { + loadErrors.push(e) + } + try { + const binding = require('@openshell/sdk-linux-loong64-musl') + const bindingPackageVersion = require('@openshell/sdk-linux-loong64-musl/package.json').version + if (bindingPackageVersion !== '0.0.0-alpha.0' && process.env.NAPI_RS_ENFORCE_VERSION_CHECK && process.env.NAPI_RS_ENFORCE_VERSION_CHECK !== '0') { + throw new Error(`Native binding package version mismatch, expected 0.0.0-alpha.0 but got ${bindingPackageVersion}. 
You can reinstall dependencies to fix this issue.`) + } + return binding + } catch (e) { + loadErrors.push(e) + } + } else { + try { + return require('./openshell-sdk.linux-loong64-gnu.node') + } catch (e) { + loadErrors.push(e) + } + try { + const binding = require('@openshell/sdk-linux-loong64-gnu') + const bindingPackageVersion = require('@openshell/sdk-linux-loong64-gnu/package.json').version + if (bindingPackageVersion !== '0.0.0-alpha.0' && process.env.NAPI_RS_ENFORCE_VERSION_CHECK && process.env.NAPI_RS_ENFORCE_VERSION_CHECK !== '0') { + throw new Error(`Native binding package version mismatch, expected 0.0.0-alpha.0 but got ${bindingPackageVersion}. You can reinstall dependencies to fix this issue.`) + } + return binding + } catch (e) { + loadErrors.push(e) + } + } + } else if (process.arch === 'riscv64') { + if (isMusl()) { + try { + return require('./openshell-sdk.linux-riscv64-musl.node') + } catch (e) { + loadErrors.push(e) + } + try { + const binding = require('@openshell/sdk-linux-riscv64-musl') + const bindingPackageVersion = require('@openshell/sdk-linux-riscv64-musl/package.json').version + if (bindingPackageVersion !== '0.0.0-alpha.0' && process.env.NAPI_RS_ENFORCE_VERSION_CHECK && process.env.NAPI_RS_ENFORCE_VERSION_CHECK !== '0') { + throw new Error(`Native binding package version mismatch, expected 0.0.0-alpha.0 but got ${bindingPackageVersion}. 
You can reinstall dependencies to fix this issue.`) + } + return binding + } catch (e) { + loadErrors.push(e) + } + } else { + try { + return require('./openshell-sdk.linux-riscv64-gnu.node') + } catch (e) { + loadErrors.push(e) + } + try { + const binding = require('@openshell/sdk-linux-riscv64-gnu') + const bindingPackageVersion = require('@openshell/sdk-linux-riscv64-gnu/package.json').version + if (bindingPackageVersion !== '0.0.0-alpha.0' && process.env.NAPI_RS_ENFORCE_VERSION_CHECK && process.env.NAPI_RS_ENFORCE_VERSION_CHECK !== '0') { + throw new Error(`Native binding package version mismatch, expected 0.0.0-alpha.0 but got ${bindingPackageVersion}. You can reinstall dependencies to fix this issue.`) + } + return binding + } catch (e) { + loadErrors.push(e) + } + } + } else if (process.arch === 'ppc64') { + try { + return require('./openshell-sdk.linux-ppc64-gnu.node') + } catch (e) { + loadErrors.push(e) + } + try { + const binding = require('@openshell/sdk-linux-ppc64-gnu') + const bindingPackageVersion = require('@openshell/sdk-linux-ppc64-gnu/package.json').version + if (bindingPackageVersion !== '0.0.0-alpha.0' && process.env.NAPI_RS_ENFORCE_VERSION_CHECK && process.env.NAPI_RS_ENFORCE_VERSION_CHECK !== '0') { + throw new Error(`Native binding package version mismatch, expected 0.0.0-alpha.0 but got ${bindingPackageVersion}. 
You can reinstall dependencies to fix this issue.`) + } + return binding + } catch (e) { + loadErrors.push(e) + } + } else if (process.arch === 's390x') { + try { + return require('./openshell-sdk.linux-s390x-gnu.node') + } catch (e) { + loadErrors.push(e) + } + try { + const binding = require('@openshell/sdk-linux-s390x-gnu') + const bindingPackageVersion = require('@openshell/sdk-linux-s390x-gnu/package.json').version + if (bindingPackageVersion !== '0.0.0-alpha.0' && process.env.NAPI_RS_ENFORCE_VERSION_CHECK && process.env.NAPI_RS_ENFORCE_VERSION_CHECK !== '0') { + throw new Error(`Native binding package version mismatch, expected 0.0.0-alpha.0 but got ${bindingPackageVersion}. You can reinstall dependencies to fix this issue.`) + } + return binding + } catch (e) { + loadErrors.push(e) + } + } else { + loadErrors.push(new Error(`Unsupported architecture on Linux: ${process.arch}`)) + } + } else if (process.platform === 'openharmony') { + if (process.arch === 'arm64') { + try { + return require('./openshell-sdk.openharmony-arm64.node') + } catch (e) { + loadErrors.push(e) + } + try { + const binding = require('@openshell/sdk-openharmony-arm64') + const bindingPackageVersion = require('@openshell/sdk-openharmony-arm64/package.json').version + if (bindingPackageVersion !== '0.0.0-alpha.0' && process.env.NAPI_RS_ENFORCE_VERSION_CHECK && process.env.NAPI_RS_ENFORCE_VERSION_CHECK !== '0') { + throw new Error(`Native binding package version mismatch, expected 0.0.0-alpha.0 but got ${bindingPackageVersion}. 
You can reinstall dependencies to fix this issue.`) + } + return binding + } catch (e) { + loadErrors.push(e) + } + } else if (process.arch === 'x64') { + try { + return require('./openshell-sdk.openharmony-x64.node') + } catch (e) { + loadErrors.push(e) + } + try { + const binding = require('@openshell/sdk-openharmony-x64') + const bindingPackageVersion = require('@openshell/sdk-openharmony-x64/package.json').version + if (bindingPackageVersion !== '0.0.0-alpha.0' && process.env.NAPI_RS_ENFORCE_VERSION_CHECK && process.env.NAPI_RS_ENFORCE_VERSION_CHECK !== '0') { + throw new Error(`Native binding package version mismatch, expected 0.0.0-alpha.0 but got ${bindingPackageVersion}. You can reinstall dependencies to fix this issue.`) + } + return binding + } catch (e) { + loadErrors.push(e) + } + } else if (process.arch === 'arm') { + try { + return require('./openshell-sdk.openharmony-arm.node') + } catch (e) { + loadErrors.push(e) + } + try { + const binding = require('@openshell/sdk-openharmony-arm') + const bindingPackageVersion = require('@openshell/sdk-openharmony-arm/package.json').version + if (bindingPackageVersion !== '0.0.0-alpha.0' && process.env.NAPI_RS_ENFORCE_VERSION_CHECK && process.env.NAPI_RS_ENFORCE_VERSION_CHECK !== '0') { + throw new Error(`Native binding package version mismatch, expected 0.0.0-alpha.0 but got ${bindingPackageVersion}. 
You can reinstall dependencies to fix this issue.`) + } + return binding + } catch (e) { + loadErrors.push(e) + } + } else { + loadErrors.push(new Error(`Unsupported architecture on OpenHarmony: ${process.arch}`)) + } + } else { + loadErrors.push(new Error(`Unsupported OS: ${process.platform}, architecture: ${process.arch}`)) + } +} + +nativeBinding = requireNative() + +// NAPI_RS_FORCE_WASI is a tri-state flag: +// unset / any other value → native binding preferred, WASI is only a fallback +// 'true' → force WASI fallback even if native loaded +// 'error' → force WASI and throw if no WASI binding is found +// Treating any non-empty string as truthy (the historical behavior) meant +// NAPI_RS_FORCE_WASI=false, NAPI_RS_FORCE_WASI=0, etc. inadvertently triggered +// the WASI path, causing ENOENT for packages shipped without a .wasi.cjs file. +const forceWasi = + process.env.NAPI_RS_FORCE_WASI === 'true' || process.env.NAPI_RS_FORCE_WASI === 'error' + +if (!nativeBinding || forceWasi) { + let wasiBinding = null + let wasiBindingError = null + try { + wasiBinding = require('./openshell-sdk.wasi.cjs') + nativeBinding = wasiBinding + } catch (err) { + if (forceWasi) { + wasiBindingError = err + } + } + if (!nativeBinding || forceWasi) { + try { + wasiBinding = require('@openshell/sdk-wasm32-wasi') + nativeBinding = wasiBinding + } catch (err) { + if (forceWasi) { + if (!wasiBindingError) { + wasiBindingError = err + } else { + wasiBindingError.cause = err + } + loadErrors.push(err) + } + } + } + if (process.env.NAPI_RS_FORCE_WASI === 'error' && !wasiBinding) { + const error = new Error('WASI binding not found and NAPI_RS_FORCE_WASI is set to error') + error.cause = wasiBindingError + throw error + } +} + +if (!nativeBinding) { + if (loadErrors.length > 0) { + throw new Error( + `Cannot find native binding. ` + + `npm has a bug related to optional dependencies (https://github.com/npm/cli/issues/4828). 
` + + 'Please try `npm i` again after removing both package-lock.json and node_modules directory.', + { + cause: loadErrors.reduce((err, cur) => { + cur.cause = err + return cur + }), + }, + ) + } + throw new Error(`Failed to load native binding`) +} + +module.exports = nativeBinding +module.exports.OidcRefresher = nativeBinding.OidcRefresher +module.exports.OpenShellClient = nativeBinding.OpenShellClient diff --git a/crates/openshell-sdk-node/lib.d.ts b/crates/openshell-sdk-node/lib.d.ts new file mode 100644 index 000000000..f6b74039d --- /dev/null +++ b/crates/openshell-sdk-node/lib.d.ts @@ -0,0 +1,24 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +export { OidcRefresher, OpenShellClient } from './index.js' +export type { + ConnectOptions, + ExecOptions, + ExecResult, + Health, + JsRefreshedToken, + ListOptions, + SandboxRef, + SandboxSpec, +} from './index.js' + +/** + * Extract the SDK error code from a thrown error. + * + * The native binding prefixes every error message with `[code] ` where + * `code` is one of: `invalid_config`, `tls`, `connect`, `auth`, `io`, + * `not_found`, `already_exists`, `rpc`. Returns `null` when the prefix is + * missing (the error wasn't thrown by this binding). + */ +export declare function errorCode(err: unknown): string | null diff --git a/crates/openshell-sdk-node/lib.mjs b/crates/openshell-sdk-node/lib.mjs new file mode 100644 index 000000000..a40d5eff5 --- /dev/null +++ b/crates/openshell-sdk-node/lib.mjs @@ -0,0 +1,32 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +// ESM facade over the auto-generated CommonJS index. +// +// Re-exports the napi-generated classes and adds the `errorCode` helper, +// which parses the `[code] message` prefix the binding uses to surface the +// SDK's discriminable error kind. 
The auto-generated `index.js`/`index.d.ts` +// pair from napi-rs is left untouched so the build matrix stays cookie-cutter. + +import nativeBinding from './index.js' + +const { OpenShellClient, OidcRefresher } = nativeBinding + +export { OpenShellClient, OidcRefresher } + +/** + * Extract the SDK error code from a thrown error. + * + * The native binding prefixes every error message with `[code] ` where + * `code` is one of: `invalid_config`, `tls`, `connect`, `auth`, `io`, + * `not_found`, `already_exists`, `rpc`. Returns `null` when the prefix is + * missing (the error wasn't thrown by this binding). + * + * @param {unknown} err + * @returns {string | null} + */ +export function errorCode(err) { + if (!err || typeof err.message !== 'string') return null + const match = err.message.match(/^\[([^\]]+)\]/) + return match ? match[1] : null +} diff --git a/crates/openshell-sdk-node/package.json b/crates/openshell-sdk-node/package.json new file mode 100644 index 000000000..c31300b0e --- /dev/null +++ b/crates/openshell-sdk-node/package.json @@ -0,0 +1,55 @@ +{ + "name": "@openshell/sdk", + "version": "0.0.0-alpha.0", + "description": "TypeScript SDK for OpenShell gateways. 
Wraps the openshell-sdk Rust core via napi-rs.", + "main": "index.js", + "types": "index.d.ts", + "exports": { + ".": { + "import": { + "types": "./lib.d.ts", + "default": "./lib.mjs" + }, + "require": { + "types": "./index.d.ts", + "default": "./index.js" + } + }, + "./package.json": "./package.json" + }, + "license": "Apache-2.0", + "repository": { + "type": "git", + "url": "https://github.com/NVIDIA/OpenShell.git", + "directory": "crates/openshell-sdk-node" + }, + "engines": { + "node": ">= 18" + }, + "files": [ + "index.js", + "index.d.ts", + "lib.mjs", + "lib.d.ts" + ], + "napi": { + "binaryName": "openshell-sdk", + "targets": [ + "x86_64-apple-darwin", + "aarch64-apple-darwin", + "x86_64-unknown-linux-gnu", + "aarch64-unknown-linux-gnu", + "x86_64-pc-windows-msvc", + "aarch64-pc-windows-msvc" + ] + }, + "scripts": { + "build": "napi build --platform --release", + "build:debug": "napi build --platform", + "test": "node test/smoke.mjs", + "prepublishOnly": "napi prepublish -t npm" + }, + "devDependencies": { + "@napi-rs/cli": "^3" + } +} diff --git a/crates/openshell-sdk-node/src/lib.rs b/crates/openshell-sdk-node/src/lib.rs new file mode 100644 index 000000000..08d2d52d9 --- /dev/null +++ b/crates/openshell-sdk-node/src/lib.rs @@ -0,0 +1,406 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! napi-rs bindings over [`openshell_sdk`]. +//! +//! This crate is a thin adapter — it owns no business logic. It maps the +//! curated SDK surface to JS-shaped types (camelCase keys, string-literal +//! enums, JS `Error` with a discriminable `code` field), and bridges a +//! JS-side OIDC refresh callback to the SDK's [`openshell_sdk::Refresh`] +//! trait. +//! +//! Published as `@openshell/sdk` (alpha; no semver guarantee until 1.0). +//! +//! # Runtime ownership +//! +//! napi-rs v3 provides an ambient tokio runtime that's only available inside +//! 
`async fn` napi entry points. Every JS-facing function on [`OpenShellClient`] +//! is therefore `async`. Sync FFI entry points cannot call the SDK because +//! tonic requires a reactor; attempting `tokio::spawn` from a sync `#[napi]` +//! function panics with "no reactor running". + +#![allow(clippy::needless_pass_by_value, clippy::missing_errors_doc)] + +use napi::Status; +use napi::bindgen_prelude::*; +use napi::threadsafe_function::ThreadsafeFunction; +use napi_derive::napi; +use openshell_sdk as sdk; +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Duration; + +// ── Error mapping ───────────────────────────────────────────────────── + +fn to_napi_error(error: sdk::SdkError) -> Error { + let code = error.code(); + // Embed the SDK code as a `[code] message` prefix. N-API maps the + // status enum to JS `err.code`, which is too coarse for discrimination + // (`GenericFailure` covers most variants), so callers parse the prefix + // or use the `errorCode()` helper exported from the JS shim. + Error::new(Status::GenericFailure, format!("[{code}] {error}")) +} + +// ── Public input/output types (JS-shaped) ──────────────────────────── + +/// Connection options. Mirrors [`openshell_sdk::ClientConfig`] with +/// JS-friendly field names. +#[napi(object)] +#[derive(Default)] +pub struct ConnectOptions { + /// Gateway URL (`http://...` or `https://...`). + pub gateway: String, + /// CA certificate (PEM-encoded). `None` falls back to system roots. + pub ca_cert: Option<Buffer>, + /// Bearer token for direct OIDC auth. Mutually exclusive with `edge_token`. + pub oidc_token: Option<String>, + /// Cloudflare Access bearer token. Routes through a local WebSocket tunnel. + pub edge_token: Option<String>, + /// Disable TLS certificate verification (development/debug only). + pub insecure_skip_verify: Option<bool>, +} + +/// Gateway health snapshot. +#[napi(object)] +pub struct Health { + /// Coarse status: `"healthy"`, `"degraded"`, `"unhealthy"`, `"unspecified"`.
+ pub status: String, + pub version: String, +} + +/// Lifecycle phase: `"unspecified"`, `"provisioning"`, `"ready"`, `"error"`, +/// `"deleting"`, `"unknown"`. +#[napi(object)] +pub struct SandboxRef { + pub id: String, + pub name: String, + pub phase: String, + pub labels: HashMap<String, String>, + /// Resource version as a string, since JS numbers can't safely hold u64. + pub resource_version: String, +} + +/// Caller intent for a new sandbox. +#[napi(object)] +#[derive(Default)] +pub struct SandboxSpec { + pub name: Option<String>, + pub image: Option<String>, + pub labels: Option<HashMap<String, String>>, + pub environment: Option<HashMap<String, String>>, + pub providers: Option<Vec<String>>, + pub gpu: Option<bool>, + pub gpu_device: Option<String>, +} + +/// Options for [`OpenShellClient::list_sandboxes`]. +#[napi(object)] +#[derive(Default)] +pub struct ListOptions { + pub limit: Option, + pub offset: Option, + pub label_selector: Option<String>, +} + +/// Options for [`OpenShellClient::exec`]. +#[napi(object)] +#[derive(Default)] +pub struct ExecOptions { + pub workdir: Option<String>, + pub environment: Option<HashMap<String, String>>, + /// Timeout in seconds. `None` lets the gateway choose. + pub timeout_secs: Option<u32>, + /// Optional stdin payload. + pub stdin: Option<Buffer>, +} + +/// Result of a non-streaming exec call.
+#[napi(object)] +pub struct ExecResult { + pub exit_code: i32, + pub stdout: Buffer, + pub stderr: Buffer, +} + +// ── Type conversions ───────────────────────────────────────────────── + +fn phase_to_str(phase: sdk::SandboxPhase) -> &'static str { + match phase { + sdk::SandboxPhase::Provisioning => "provisioning", + sdk::SandboxPhase::Ready => "ready", + sdk::SandboxPhase::Error => "error", + sdk::SandboxPhase::Deleting => "deleting", + sdk::SandboxPhase::Unknown => "unknown", + _ => "unspecified", + } +} + +fn status_to_str(status: sdk::ServiceStatus) -> &'static str { + match status { + sdk::ServiceStatus::Healthy => "healthy", + sdk::ServiceStatus::Degraded => "degraded", + sdk::ServiceStatus::Unhealthy => "unhealthy", + _ => "unspecified", + } +} + +impl From<sdk::SandboxRef> for SandboxRef { + fn from(r: sdk::SandboxRef) -> Self { + Self { + id: r.id, + name: r.name, + phase: phase_to_str(r.phase).to_string(), + labels: r.labels, + resource_version: r.resource_version.to_string(), + } + } +} + +fn sdk_spec_from_js(spec: SandboxSpec) -> sdk::SandboxSpec { + sdk::SandboxSpec { + name: spec.name, + image: spec.image, + labels: spec.labels.unwrap_or_default(), + environment: spec.environment.unwrap_or_default(), + providers: spec.providers.unwrap_or_default(), + gpu: spec.gpu.unwrap_or(false), + gpu_device: spec.gpu_device, + } +} + +fn sdk_list_opts_from_js(opts: ListOptions) -> sdk::ListOptions { + sdk::ListOptions { + limit: opts.limit.unwrap_or(0), + offset: opts.offset.unwrap_or(0), + label_selector: opts.label_selector, + } +} + +fn sdk_exec_opts_from_js(opts: ExecOptions) -> sdk::ExecOptions { + sdk::ExecOptions { + workdir: opts.workdir, + environment: opts.environment.unwrap_or_default(), + timeout: opts.timeout_secs.map(|s| Duration::from_secs(u64::from(s))), + stdin: opts.stdin.map(|b| b.to_vec()), + } +} + +fn build_client_config(opts: ConnectOptions) -> sdk::ClientConfig { + let auth = match (opts.oidc_token, opts.edge_token) { + (Some(token), _) =>
Some(sdk::AuthConfig::Oidc(token)), + (None, Some(token)) => Some(sdk::AuthConfig::EdgeJwt(token)), + (None, None) => None, + }; + let mut cfg = sdk::ClientConfig::new(opts.gateway); + cfg.ca_cert = opts.ca_cert.map(|b| b.to_vec()); + cfg.auth = auth; + cfg.insecure_skip_verify = opts.insecure_skip_verify.unwrap_or(false); + cfg +} + +// ── OIDC refresh callback bridge ───────────────────────────────────── + +/// JS-side refresh callback returning a Promise<{ accessToken, expiresAt? }>. +#[napi(object)] +pub struct JsRefreshedToken { + pub access_token: String, + /// Expiry as Unix epoch seconds. Stored as `f64` because JS numbers + /// can't hold `u64` exactly past 2^53; values are clamped to that range + /// in practice (the year 287396 is fine). + pub expires_at: Option<f64>, +} + +/// Bridge between a JS refresh callback and the SDK's [`sdk::Refresh`] trait. +struct JsRefresher { + callback: ThreadsafeFunction<(), Promise<JsRefreshedToken>, (), Status, false>, +} + +#[async_trait::async_trait] +impl sdk::Refresh for JsRefresher { + async fn refresh(&self) -> std::result::Result<sdk::RefreshedToken, sdk::RefreshError> { + // Invoke the JS callback; it returns a Promise. + let promise = + self.callback.call_async(()).await.map_err(|e| { + sdk::RefreshError::Transient(format!("refresh callback failed: {e}")) + })?; + let result = promise + .await + .map_err(|e| sdk::RefreshError::Transient(format!("refresh promise rejected: {e}")))?; + let token = sdk::RefreshedToken::new(result.access_token); + Ok(match result.expires_at { + Some(expires_at) if expires_at.is_finite() && expires_at > 0.0 => { + #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)] + { + token.with_expires_at(expires_at as u64) + } + } + _ => token, + }) + } +} + +/// A live token source backed by a JS callback. Hand off to +/// [`OpenShellClient::set_oidc_refresher`] before any RPCs run; the SDK +/// proactively refreshes when the token is within 60s of expiry, and +/// coalesces concurrent refreshes into a single callback invocation. 
+#[napi] +pub struct OidcRefresher { + inner: Arc<sdk::TokenSource>, +} + +#[napi] +impl OidcRefresher { + /// Create a refresher with an initial token and a JS callback. + /// + /// The callback must return a Promise resolving to + /// `{ accessToken, expiresAt? }`. `expiresAt` is Unix epoch seconds. + #[napi(constructor)] + pub fn new( + initial_token: String, + initial_expires_at: Option<f64>, + #[napi(ts_arg_type = "() => Promise<{ accessToken: string; expiresAt?: number }>")] + callback: ThreadsafeFunction<(), Promise<JsRefreshedToken>, (), Status, false>, + ) -> Self { + let initial = match initial_expires_at { + Some(exp) if exp.is_finite() && exp > 0.0 => { + #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)] + { + sdk::RefreshedToken::new(initial_token).with_expires_at(exp as u64) + } + } + _ => sdk::RefreshedToken::new(initial_token), + }; + let refresher = Arc::new(JsRefresher { callback }); + Self { + inner: Arc::new(sdk::TokenSource::new(initial, refresher)), + } + } + + /// Snapshot the current token (no refresh check). Mostly useful for + /// tests; in steady-state the SDK calls this internally. + #[napi] + pub fn current_token(&self) -> String { + self.inner.snapshot() + } + + /// Force a refresh now and return the new access token. Concurrent + /// callers coalesce. + #[napi] + pub async fn refresh(&self) -> Result<String> { + self.inner.refresh_now().await.map_err(to_napi_error) + } +} + +// ── Main client class ──────────────────────────────────────────────── + +/// The JS-facing client. Cheap to share between async tasks; do not call +/// `connect` per request. +#[napi] +pub struct OpenShellClient { + inner: sdk::OpenShellClient, +} + +#[napi] +impl OpenShellClient { + /// Open a connection to the gateway described by `options`. 
+ #[napi(factory)] + pub async fn connect(options: ConnectOptions) -> Result<Self> { + let cfg = build_client_config(options); + let inner = sdk::OpenShellClient::connect(cfg) + .await + .map_err(to_napi_error)?; + Ok(Self { inner }) + } + + /// Gateway health snapshot. + #[napi] + pub async fn health(&self) -> Result<Health> { + let h = self.inner.health().await.map_err(to_napi_error)?; + Ok(Health { + status: status_to_str(h.status).to_string(), + version: h.version, + }) + } + + /// Create a new sandbox. + #[napi] + pub async fn create_sandbox(&self, spec: SandboxSpec) -> Result<SandboxRef> { + self.inner + .create_sandbox(sdk_spec_from_js(spec)) + .await + .map(Into::into) + .map_err(to_napi_error) + } + + /// Fetch a sandbox by name. + #[napi] + pub async fn get_sandbox(&self, name: String) -> Result<SandboxRef> { + self.inner + .get_sandbox(&name) + .await + .map(Into::into) + .map_err(to_napi_error) + } + + /// List sandboxes. + #[napi] + pub async fn list_sandboxes(&self, options: Option<ListOptions>) -> Result<Vec<SandboxRef>> { + let opts = sdk_list_opts_from_js(options.unwrap_or_default()); + let items = self + .inner + .list_sandboxes(opts) + .await + .map_err(to_napi_error)?; + Ok(items.into_iter().map(Into::into).collect()) + } + + /// Delete a sandbox by name. Returns `true` when the gateway acknowledged + /// the deletion, `false` when it was already absent. + #[napi] + pub async fn delete_sandbox(&self, name: String) -> Result<bool> { + self.inner + .delete_sandbox(&name) + .await + .map_err(to_napi_error) + } + + /// Poll until the sandbox reaches `ready` or `timeout_secs` elapses. + #[napi] + pub async fn wait_ready(&self, name: String, timeout_secs: u32) -> Result<SandboxRef> { + self.inner + .wait_ready(&name, Duration::from_secs(u64::from(timeout_secs))) + .await + .map(Into::into) + .map_err(to_napi_error) + } + + /// Poll until the sandbox is gone or `timeout_secs` elapses. 
+ #[napi] + pub async fn wait_deleted(&self, name: String, timeout_secs: u32) -> Result<()> { + self.inner + .wait_deleted(&name, Duration::from_secs(u64::from(timeout_secs))) + .await + .map_err(to_napi_error) + } + + /// Run a command inside a sandbox; buffers stdout/stderr to the end. + #[napi] + pub async fn exec( + &self, + name: String, + command: Vec<String>, + options: Option<ExecOptions>, + ) -> Result<ExecResult> { + let opts = sdk_exec_opts_from_js(options.unwrap_or_default()); + let res = self + .inner + .exec(&name, &command, opts) + .await + .map_err(to_napi_error)?; + Ok(ExecResult { + exit_code: res.exit_code, + stdout: res.stdout.into(), + stderr: res.stderr.into(), + }) + } +} diff --git a/crates/openshell-sdk-node/test/smoke.mjs b/crates/openshell-sdk-node/test/smoke.mjs new file mode 100644 index 000000000..cf7dd6f7c --- /dev/null +++ b/crates/openshell-sdk-node/test/smoke.mjs @@ -0,0 +1,105 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +// Smoke test for @openshell/sdk. +// +// Verifies the binding's surface: exports, error mapping, and the +// `OidcRefresher` single-flight contract. End-to-end RPC verification +// requires a running mock gateway and lives in the Rust crate's mock tests +// (see `crates/openshell-sdk/tests/client_mock.rs`). +// +// Implemented as a plain async script rather than `node --test` because +// the napi-rs `tokio_rt` feature holds the libuv event loop open and the +// test runner never exits cleanly. We explicitly `process.exit(0)` here. 
+ +import { strict as assert } from 'node:assert' +import { errorCode, OidcRefresher, OpenShellClient } from '../lib.mjs' + +const cases = [] +function test(name, fn) { + cases.push({ name, fn }) +} + +test('module exports the documented classes', () => { + assert.equal(typeof OpenShellClient, 'function') + assert.equal(typeof OpenShellClient.connect, 'function') + assert.equal(typeof OidcRefresher, 'function') + assert.equal(typeof errorCode, 'function') +}) + +test('connect rejects against a closed port with a typed error', async () => { + let caught + try { + await OpenShellClient.connect({ gateway: 'http://127.0.0.1:1' }) + } catch (err) { + caught = err + } + assert.ok(caught, 'expected connect to fail against 127.0.0.1:1') + assert.equal(errorCode(caught), 'connect', `unexpected error: ${caught?.message}`) +}) + +test('connect rejects with invalid_config for a malformed gateway URL', async () => { + let caught + try { + await OpenShellClient.connect({ gateway: 'not-a-url' }) + } catch (err) { + caught = err + } + assert.ok(caught, 'expected connect to fail on malformed URL') + assert.equal(errorCode(caught), 'invalid_config', `unexpected error: ${caught?.message}`) +}) + +test('OidcRefresher coalesces concurrent refresh calls', async () => { + let calls = 0 + // expiresAt = 1 (Unix epoch 1970) is in the past, so the SDK's reactive + // path treats the token as expired. See refresh.rs `needs_refresh`. 
+ const refresher = new OidcRefresher('initial', 1, async () => { + calls += 1 + await new Promise((resolve) => setTimeout(resolve, 25)) + return { + accessToken: `token-${calls}`, + expiresAt: Math.floor(Date.now() / 1000) + 3600, + } + }) + + const results = await Promise.all([ + refresher.refresh(), + refresher.refresh(), + refresher.refresh(), + refresher.refresh(), + ]) + + assert.equal(calls, 1, 'callback should have been invoked once for coalesced calls') + assert.equal(new Set(results).size, 1, 'all waiters should observe the same token') + assert.equal(results[0], 'token-1') + assert.equal(refresher.currentToken(), 'token-1') +}) + +test('OidcRefresher surfaces callback rejections as auth errors', async () => { + const refresher = new OidcRefresher('stale', 1, async () => { + throw new Error('IdP unreachable') + }) + + let caught + try { + await refresher.refresh() + } catch (err) { + caught = err + } + assert.ok(caught, 'expected refresh to reject when callback throws') + assert.equal(errorCode(caught), 'auth', `unexpected error: ${caught?.message}`) +}) + +let failed = 0 +for (const { name, fn } of cases) { + try { + await fn() + console.log(`ok ${name}`) + } catch (err) { + failed += 1 + console.error(`fail ${name}`) + console.error(err) + } +} +console.log(`\n${cases.length - failed}/${cases.length} passed`) +process.exit(failed === 0 ? 0 : 1) From 279e1c29ec390c096aa4368f710eee1193f851d0 Mon Sep 17 00:00:00 2001 From: Max Dubrinsky Date: Thu, 28 May 2026 11:01:25 -0400 Subject: [PATCH 4/6] feat(examples): Pi extension wrapping @openshell/sdk Adds a Pi (pi.dev) coding-agent extension under examples/ that treats OpenShell sandboxes as disposable sub-agents. The extension wraps @openshell/sdk and registers five tools (openshell_run_task, openshell_spawn_sandbox, openshell_exec, openshell_list_sandboxes, openshell_destroy_sandbox) plus a /openshell-health slash command. 
openshell_run_task is the primary dispatch surface: one call creates a sandbox, waits for ready, runs a command, streams stdout/stderr back as the tool result, and deletes the sandbox unless keep_sandbox=true. --- .../pi-extension-sandbox-agents/.gitignore | 3 + .../pi-extension-sandbox-agents/README.md | 162 ++++++++ .../pi-extension-sandbox-agents/extension.ts | 346 ++++++++++++++++++ .../pi-extension-sandbox-agents/package.json | 27 ++ .../pi-extension-sandbox-agents/tsconfig.json | 14 + 5 files changed, 552 insertions(+) create mode 100644 examples/pi-extension-sandbox-agents/.gitignore create mode 100644 examples/pi-extension-sandbox-agents/README.md create mode 100644 examples/pi-extension-sandbox-agents/extension.ts create mode 100644 examples/pi-extension-sandbox-agents/package.json create mode 100644 examples/pi-extension-sandbox-agents/tsconfig.json diff --git a/examples/pi-extension-sandbox-agents/.gitignore b/examples/pi-extension-sandbox-agents/.gitignore new file mode 100644 index 000000000..79bd70dbb --- /dev/null +++ b/examples/pi-extension-sandbox-agents/.gitignore @@ -0,0 +1,3 @@ +node_modules/ +.npm-cache/ +package-lock.json diff --git a/examples/pi-extension-sandbox-agents/README.md b/examples/pi-extension-sandbox-agents/README.md new file mode 100644 index 000000000..f81ef47a8 --- /dev/null +++ b/examples/pi-extension-sandbox-agents/README.md @@ -0,0 +1,162 @@ + + + +# Pi extension: OpenShell sandboxes as sub-agents + +A [Pi coding agent](https://pi.dev/) extension that wraps the `@openshell/sdk` +TypeScript binding so a controlling Pi agent can spin up fresh OpenShell +sandboxes, dispatch tasks into them, and tear them down — turning each +sandbox into a disposable sub-agent. + +## What this demo shows + +The Pi agent stays in your terminal. 
Long-running, isolated, or +untrusted-input work is dispatched to OpenShell sandboxes, which give you: + +- A fresh filesystem and process namespace per task +- Policy-enforced egress (no exfil from a runaway tool) +- Credential placeholders instead of real tokens for upstream calls +- Logs you can audit after the fact + +The extension exposes five tools to the Pi model: + +| Tool | Purpose | +| --- | --- | +| `openshell_run_task` | One-shot: create → exec → return → delete. The killer feature. | +| `openshell_spawn_sandbox` | Create a long-lived sandbox and wait for ready. | +| `openshell_exec` | Run a follow-up command inside an existing sandbox. | +| `openshell_list_sandboxes` | Observe active sub-agents, optionally label-filtered. | +| `openshell_destroy_sandbox` | Release a long-lived sandbox. | + +And one slash command: + +| Command | Purpose | +| --- | --- | +| `/openshell-health` | Probe the gateway and print its status. | + +## How sub-agent semantics work + +`openshell_run_task` is the primary dispatch surface. Each call: + +1. Creates a sandbox with `pi.openshell/role=sub-agent` plus a + `pi.openshell/task=