diff --git a/Cargo.lock b/Cargo.lock index 091a88c9..87581d29 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1348,10 +1348,12 @@ dependencies = [ "futures", "futures-util", "getrandom 0.3.4", + "hyper-util", "js-sys", "libloading", "nemo-relay-plugin", "nemo-relay-types", + "nemo-relay-worker-proto", "object_store", "openinference-semantic-conventions", "opentelemetry", @@ -1373,6 +1375,7 @@ dependencies = [ "tokio-tungstenite", "toml", "tonic", + "tower", "typed-builder", "uuid", "wasm-bindgen", diff --git a/crates/cli/Cargo.toml b/crates/cli/Cargo.toml index d3ef689d..07e1087b 100644 --- a/crates/cli/Cargo.toml +++ b/crates/cli/Cargo.toml @@ -25,7 +25,7 @@ default = ["atof-streaming"] atof-streaming = ["nemo-relay/atof-streaming"] [dependencies] -nemo-relay = { workspace = true, features = ["guardrails-remote", "object-store", "openinference"] } +nemo-relay = { workspace = true, features = ["guardrails-remote", "object-store", "openinference", "worker-grpc"] } nemo-relay-adaptive = { workspace = true, features = ["redis-backend"] } nemo-relay-pii-redaction.workspace = true async-stream = "0.3" diff --git a/crates/cli/src/server.rs b/crates/cli/src/server.rs index 9765cddd..dc697791 100644 --- a/crates/cli/src/server.rs +++ b/crates/cli/src/server.rs @@ -10,7 +10,8 @@ use axum::http::HeaderMap; use axum::routing::{get, post}; use axum::{Json, Router}; use nemo_relay::plugin::dynamic::{ - DynamicPluginKind, NativePluginActivation, NativePluginLoadSpec, load_native_plugins, + DynamicPluginKind, NativePluginActivation, NativePluginLoadSpec, WorkerPluginActivation, + WorkerPluginLoadSpec, load_native_plugins, load_worker_plugins, }; use nemo_relay::plugin::{ PluginComponentSpec, PluginConfig, clear_plugin_configuration, initialize_plugins_exact, @@ -216,6 +217,7 @@ async fn idle_shutdown_future( struct PluginActivation { active: bool, native: Option, + worker: Option, } impl PluginActivation { @@ -227,6 +229,7 @@ impl PluginActivation { return Ok(Self { active: false, native: None, + worker: None, }); }; register_adaptive_component().map_err(|error| { @@ -251,6 +254,23 @@ impl PluginActivation { }) }) .collect::, CliError>>()?; + let worker_specs = dynamic_plugins + .iter() + .filter(|plugin| plugin.kind == DynamicPluginKind::Worker) + .map(|plugin| { + let manifest_ref = plugin.manifest_ref.clone().ok_or_else(|| { + CliError::Config(format!( + "worker dynamic plugin '{}' has no manifest_ref in lifecycle state", + plugin.plugin_id + )) + })?; + Ok(WorkerPluginLoadSpec { + plugin_id: plugin.plugin_id.clone(), + manifest_ref, + config: plugin.config.clone(), + }) + }) + .collect::, CliError>>()?; let native = if native_specs.is_empty() { None @@ -259,6 +279,14 @@ impl PluginActivation { CliError::Config(format!("native plugin load failed: {error}")) })?) }; + let worker = + if worker_specs.is_empty() { + None + } else { + Some(load_worker_plugins(worker_specs).map_err(|error| { + CliError::Config(format!("worker plugin load failed: {error}")) + })?) + }; // Gateway already resolved its config; activate exactly (no re-discovery). let mut plugin_config: PluginConfig = match config { Some(config) => serde_json::from_value(config) @@ -282,6 +310,7 @@ impl PluginActivation { Ok(Self { active: true, native, + worker, }) } @@ -295,6 +324,7 @@ impl PluginActivation { Ok(()) }; self.native.take(); + self.worker.take(); result } } diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index 97f66869..9082eb91 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -73,9 +73,24 @@ openinference = [ "dep:wasm-bindgen", "dep:wasm-bindgen-futures", ] +worker-grpc = [ + "dep:nemo-relay-worker-proto", + "dep:hyper-util", + "dep:tower", + "dep:tonic", + "tokio-stream/net", + "tonic/codegen", + "tonic/router", + "tonic/transport", + "tokio/io-util", + "tokio/net", + "tokio/process", + "tokio/rt-multi-thread", +] [dependencies] nemo-relay-types.workspace = true +nemo-relay-worker-proto = { workspace = true, optional = true } uuid = { workspace = true, features = ["v7", "serde"] } serde = { version = "1", features = ["derive", "rc"] } serde_json = "1" @@ -120,6 +135,11 @@ path = "tests/integration/context_isolation_tests.rs" name = "native_plugin_integration" path = "tests/integration/native_plugin_tests.rs" +[[test]] +name = "worker_plugin_integration" +path = "tests/integration/worker_plugin_tests.rs" +required-features = ["worker-grpc"] + [[test]] name = "middleware_integration" path = "tests/integration/middleware_tests.rs" @@ -163,3 +183,5 @@ rustls = { version = "0.23", default-features = false, features = ["ring", "std" tonic = { version = "0.14.1", default-features = false, optional = true } object_store = { version = "0.13", default-features = false, features = ["aws"], optional = true } tokio-tungstenite = { version = "0.27", default-features = false, features = ["connect", "rustls-tls-native-roots"], optional = true } +tower = { version = "0.5", features = ["util"], optional = true } +hyper-util = { version = "0.1", features = ["tokio"], optional = true } diff --git a/crates/core/src/plugin/dynamic.rs b/crates/core/src/plugin/dynamic.rs index a15af408..05182c7f 100644 --- a/crates/core/src/plugin/dynamic.rs +++ b/crates/core/src/plugin/dynamic.rs @@ -22,11 +22,15 @@ mod manifest; #[cfg(not(target_arch = "wasm32"))] mod native; mod registry; +#[cfg(all(feature = "worker-grpc", not(target_arch = "wasm32")))] +mod worker; pub use manifest::*; #[cfg(not(target_arch = "wasm32"))] pub use native::*; pub use registry::*; +#[cfg(all(feature = "worker-grpc", not(target_arch = "wasm32")))] +pub use worker::*; /// Plugin execution lane. #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash, Display)] diff --git a/crates/core/src/plugin/dynamic/worker.rs b/crates/core/src/plugin/dynamic/worker.rs new file mode 100644 index 00000000..a4e46e56 --- /dev/null +++ b/crates/core/src/plugin/dynamic/worker.rs @@ -0,0 +1,2038 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! gRPC worker dynamic plugin loader and host-side proxy adapter. + +use std::collections::HashMap; +use std::future::Future; +use std::path::{Path, PathBuf}; +use std::pin::Pin; +use std::process::{Child, Command, Stdio}; +use std::sync::{Arc, Mutex}; +use std::time::Duration; + +use nemo_relay_worker_proto::v1::plugin_worker_client::PluginWorkerClient; +use nemo_relay_worker_proto::v1::relay_host_runtime_server::{ + RelayHostRuntime, RelayHostRuntimeServer, +}; +use nemo_relay_worker_proto::v1::{ + CreateScopeStackRequest, CreateScopeStackResponse, DropScopeStackRequest, EmitMarkRequest, + GuardrailResult, HandshakeRequest, HealthRequest, HostAck, InvokeRequest, InvokeResponse, + JsonEnvelope, JsonResult, LlmInvocation, LlmNextRequest, LlmStreamNextRequest, PopScopeRequest, + PushScopeRequest, PushScopeResponse, RegisterRequest, RegisterResponse, Registration, + RegistrationSurface, ScopeContext, ShutdownRequest, StreamChunk, ToolInvocation, + ToolNextRequest, ValidateRequest, WorkerError, +}; +use nemo_relay_worker_proto::{WORKER_PROTOCOL_GRPC_V1, decode_json_envelope, json_envelope}; +use semver::{Version, VersionReq}; +use serde_json::{Map, Value as Json}; +use tokio::runtime::{Builder as RuntimeBuilder, Runtime}; +use tokio::sync::{mpsc, oneshot}; +use tokio_stream::StreamExt; +use tonic::transport::{Channel, Endpoint, Server}; +use tonic::{Request, Response, Status}; +use uuid::Uuid; + +#[cfg(unix)] +use hyper_util::rt::TokioIo; +#[cfg(not(unix))] +use std::net::{SocketAddr, TcpListener}; +#[cfg(unix)] +use std::os::unix::net::UnixListener as StdUnixListener; +#[cfg(not(unix))] +use tokio::net::TcpListener as TokioTcpListener; +#[cfg(unix)] +use tokio::net::{UnixListener, UnixStream}; +#[cfg(not(unix))] +use tokio_stream::wrappers::TcpListenerStream; +#[cfg(unix)] +use tokio_stream::wrappers::UnixListenerStream; +#[cfg(unix)] +use tower::service_fn; + +use crate::api::event::Event; +use crate::api::llm::LlmRequest; +use crate::api::runtime::{ + LlmExecutionNextFn, LlmJsonStream, LlmStreamExecutionNextFn, ToolExecutionNextFn, + current_scope_stack, with_scope_stack, +}; +use crate::api::scope::{ + EmitMarkEventParams, PopScopeParams, PushScopeParams, ScopeAttributes, ScopeHandle, ScopeType, + event as emit_scope_mark, pop_scope, push_scope, +}; +use crate::codec::request::AnnotatedLlmRequest; +use crate::error::{FlowError, Result as FlowResult}; +use crate::plugin::{ + ConfigDiagnostic, DiagnosticLevel, Plugin, PluginError, PluginRegistrationContext, + deregister_plugin, register_plugin, +}; + +use super::{DynamicPluginKind, DynamicPluginManifest, DynamicPluginManifestLoad, WorkerRuntime}; + +const JSON_SCHEMA: &str = "nemo.relay.Json@1"; +const EVENT_SCHEMA: &str = "nemo.relay.Event@1"; +const LLM_REQUEST_SCHEMA: &str = "nemo.relay.LlmRequest@1"; +const ANNOTATED_LLM_REQUEST_SCHEMA: &str = "nemo.relay.AnnotatedLlmRequest@1"; +const WORKER_STARTUP_TIMEOUT: Duration = Duration::from_secs(10); +const WORKER_RPC_TIMEOUT: Duration = Duration::from_secs(30); +const WORKER_CONNECT_RETRY: Duration = Duration::from_millis(25); +const PYTHON_WORKER_BOOTSTRAP: &str = r#" +import asyncio +import importlib +import inspect +import sys + +target = sys.argv[1] +module_name, separator, function_name = target.partition(":") +if not separator or not module_name or not function_name: + raise SystemExit("Python worker entrypoint must be 'module:function'") + +entrypoint = getattr(importlib.import_module(module_name), function_name) +result = entrypoint() +if inspect.isawaitable(result): + asyncio.run(result) +"#; + +/// Worker plugin load request derived from host dynamic-plugin state. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct WorkerPluginLoadSpec { + /// Expected plugin id. + pub plugin_id: String, + /// Path to the authored `relay-plugin.toml`. + pub manifest_ref: String, + /// Resolved dynamic plugin config passed to the worker. + pub config: Map, +} + +/// Owns gRPC worker processes registered into the plugin registry. +/// +/// Dropping this value deregisters worker plugin kinds and shuts down worker +/// processes. Clear active plugin configuration before dropping it so runtime +/// callbacks cannot outlive the worker activation. +pub struct WorkerPluginActivation { + plugins: Vec>, + plugin_kinds: Vec, +} + +impl WorkerPluginActivation { + /// Returns `true` when no worker plugins were loaded. + pub fn is_empty(&self) -> bool { + self.plugins.is_empty() + } + + /// Consumes the activation; deregistration runs from `Drop`. + pub fn clear(self) {} +} + +impl Drop for WorkerPluginActivation { + fn drop(&mut self) { + for plugin_kind in self.plugin_kinds.iter().rev() { + let _ = deregister_plugin(plugin_kind); + } + } +} + +/// Loads gRPC worker plugins and registers their plugin kinds. +/// +/// The returned activation must be kept alive until after active plugin +/// configuration has been cleared. +pub fn load_worker_plugins(specs: I) -> crate::plugin::Result +where + I: IntoIterator, +{ + let mut activation = WorkerPluginActivation { + plugins: Vec::new(), + plugin_kinds: Vec::new(), + }; + for spec in specs { + let instance = load_one_worker_plugin(&spec)?; + let plugin_kind = instance.plugin_kind.clone(); + register_plugin(Arc::new(WorkerPluginAdapter { + plugin_kind: plugin_kind.clone(), + allows_multiple_components: instance.allows_multiple_components, + instance: instance.clone(), + }))?; + activation.plugins.push(instance); + activation.plugin_kinds.push(plugin_kind); + } + Ok(activation) +} + +struct WorkerPluginAdapter { + plugin_kind: String, + allows_multiple_components: bool, + instance: Arc, +} + +impl Plugin for WorkerPluginAdapter { + fn plugin_kind(&self) -> &str { + &self.plugin_kind + } + + fn allows_multiple_components(&self) -> bool { + self.allows_multiple_components + } + + fn validate(&self, plugin_config: &Map) -> Vec { + if plugin_config != &self.instance.config { + return vec![worker_error_diagnostic( + &self.plugin_kind, + "plugin.worker_config_mismatch", + "worker plugin config changed after dynamic activation; reload the worker activation", + )]; + } + self.instance.validation_diagnostics.clone() + } + + fn register<'a>( + &'a self, + plugin_config: &Map, + ctx: &'a mut PluginRegistrationContext, + ) -> Pin> + Send + 'a>> { + let config_matches = plugin_config == &self.instance.config; + Box::pin(async move { + if !config_matches { + return Err(PluginError::RegistrationFailed( + "worker plugin config changed after dynamic activation; reload the worker activation" + .into(), + )); + } + self.instance.install_registrations(ctx) + }) + } +} + +struct WorkerPluginInstance { + plugin_kind: String, + allows_multiple_components: bool, + config: Map, + validation_diagnostics: Vec, + registrations: Vec, + runtime: OwnedWorkerRuntime, + client: PluginWorkerClient, + host_state: Arc, + shutdown: Mutex>>, + process: Mutex>, + activation_dir: PathBuf, +} + +impl Drop for WorkerPluginInstance { + fn drop(&mut self) { + let mut client = self.client.clone(); + let request = ShutdownRequest { + activation_id: self.host_state.activation_id.clone(), + auth_token: self.host_state.auth_token.clone(), + reason: "plugin activation dropped".into(), + }; + let _ = block_on_runtime(self.runtime.runtime(), async move { + worker_rpc(client.shutdown(worker_rpc_request(request))).await + }); + if let Ok(mut shutdown) = self.shutdown.lock() + && let Some(sender) = shutdown.take() + { + let _ = sender.send(()); + } + if let Ok(mut process) = self.process.lock() + && let Some(mut child) = process.take() + { + let _ = child.kill(); + let _ = child.wait(); + } + let _ = std::fs::remove_dir_all(&self.activation_dir); + } +} + +fn load_one_worker_plugin( + spec: &WorkerPluginLoadSpec, +) -> crate::plugin::Result> { + let (manifest, manifest_ref) = DynamicPluginManifest::load_from_path(&spec.manifest_ref)?; + if manifest.plugin.id.trim() != spec.plugin_id { + return Err(PluginError::InvalidConfig(format!( + "dynamic plugin manifest id '{}' does not match expected id '{}'", + manifest.plugin.id, spec.plugin_id + ))); + } + if manifest.plugin.kind != DynamicPluginKind::Worker { + return Err(PluginError::InvalidConfig(format!( + "dynamic plugin '{}' is kind {}; worker loader only supports worker", + spec.plugin_id, manifest.plugin.kind + ))); + } + validate_relay_compatibility(manifest.compat.relay.as_deref())?; + let DynamicPluginManifestLoad::Worker(load) = &manifest.load else { + unreachable!("validated worker manifest must carry worker load contract"); + }; + let runtime = load + .runtime + .ok_or_else(|| PluginError::InvalidConfig("load.runtime is required".into()))?; + let entrypoint = load + .entrypoint + .as_deref() + .ok_or_else(|| PluginError::InvalidConfig("load.entrypoint is required".into()))?; + + let activation_uuid = Uuid::now_v7(); + let activation_id = activation_uuid.to_string(); + let auth_token = Uuid::now_v7().to_string(); + let activation_dir = std::env::temp_dir().join(format!("nmrw-{}", activation_uuid.simple())); + std::fs::create_dir_all(&activation_dir) + .map_err(|err| PluginError::Internal(format!("worker activation directory: {err}")))?; + let mut activation_dir_guard = ActivationDirGuard::new(activation_dir.clone()); + let runtime_handle = OwnedWorkerRuntime::new( + RuntimeBuilder::new_multi_thread() + .enable_all() + .thread_name("nemo-relay-worker-host") + .build() + .map_err(|err| PluginError::Internal(format!("worker runtime: {err}")))?, + ); + let WorkerEndpoints { + host_server, + host_advertise, + worker_advertise, + worker_connect, + worker_endpoint_file, + } = WorkerEndpoints::new(&activation_dir)?; + let host_state = Arc::new(WorkerHostRuntimeState::new( + activation_id.clone(), + auth_token.clone(), + )); + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + runtime_handle.runtime().spawn(serve_host_runtime( + host_server, + host_state.clone(), + shutdown_rx, + )); + + let manifest_path = PathBuf::from(&manifest_ref); + let mut child = ChildGuard::new(spawn_worker_process(WorkerProcessLaunch { + runtime, + manifest_path: &manifest_path, + plugin_id: &spec.plugin_id, + entrypoint, + activation_id: &activation_id, + auth_token: &auth_token, + host_endpoint: &host_advertise, + worker_endpoint: &worker_advertise, + worker_endpoint_file: worker_endpoint_file.as_deref(), + })?); + let mut client = block_on_runtime( + runtime_handle.runtime(), + connect_worker_with_retry(&worker_connect), + )?; + + let health = block_on_runtime( + runtime_handle.runtime(), + worker_rpc(client.health(worker_rpc_request(HealthRequest { + activation_id: activation_id.clone(), + auth_token: auth_token.clone(), + }))), + ) + .map_err(|err| PluginError::RegistrationFailed(format!("worker health check failed: {err}")))?; + let health = health.into_inner(); + if !health.ok { + let message = format!("worker plugin health check failed: {}", health.message); + return Err(PluginError::RegistrationFailed(message)); + } + + let handshake = block_on_runtime( + runtime_handle.runtime(), + worker_rpc(client.handshake(worker_rpc_request(HandshakeRequest { + activation_id: activation_id.clone(), + plugin_id: spec.plugin_id.clone(), + relay_version: env!("CARGO_PKG_VERSION").into(), + worker_protocol: WORKER_PROTOCOL_GRPC_V1.into(), + auth_token: auth_token.clone(), + host_endpoint: host_advertise.clone(), + }))), + ) + .map_err(|err| PluginError::RegistrationFailed(format!("worker handshake failed: {err}")))?; + let handshake = handshake.into_inner(); + if handshake.plugin_id != spec.plugin_id || handshake.plugin_kind != spec.plugin_id { + return Err(PluginError::InvalidConfig(format!( + "worker plugin returned id '{}' kind '{}' but manifest id is '{}'", + handshake.plugin_id, handshake.plugin_kind, spec.plugin_id + ))); + } + if handshake.worker_protocol != WORKER_PROTOCOL_GRPC_V1 { + let message = format!( + "unsupported worker_protocol '{}'", + handshake.worker_protocol + ); + return Err(PluginError::InvalidConfig(message)); + } + + let config = Json::Object(spec.config.clone()); + let validate = block_on_runtime( + runtime_handle.runtime(), + worker_rpc(client.validate(worker_rpc_request(ValidateRequest { + activation_id: activation_id.clone(), + plugin_id: spec.plugin_id.clone(), + auth_token: auth_token.clone(), + config: Some(json_envelope(JSON_SCHEMA, &config)?), + }))), + ) + .map_err(|err| { + PluginError::RegistrationFailed(format!("worker validation RPC failed: {err}")) + })?; + let validate = validate.into_inner(); + if let Some(error) = validate.error { + return Err(worker_error_to_plugin(error, "worker validation failed")); + } + let validation_diagnostics = match validate.diagnostics { + Some(diagnostics) => decode_json_envelope::>(&diagnostics) + .map_err(PluginError::Serialization)?, + None => Vec::new(), + }; + + let registrations = if diagnostics_have_errors(&validation_diagnostics) { + Vec::new() + } else { + let register = block_on_runtime( + runtime_handle.runtime(), + worker_rpc(client.register(worker_rpc_request(RegisterRequest { + activation_id: activation_id.clone(), + plugin_id: spec.plugin_id.clone(), + auth_token: auth_token.clone(), + config: Some(json_envelope(JSON_SCHEMA, &config)?), + }))), + ) + .map_err(|err| { + PluginError::RegistrationFailed(format!("worker registration RPC failed: {err}")) + })?; + let register = register.into_inner(); + if let Some(error) = register.error { + return Err(worker_error_to_plugin(error, "worker registration failed")); + } + validate_registration_plan(&spec.plugin_id, ®ister)?; + register.registrations + }; + + Ok(Arc::new(WorkerPluginInstance { + plugin_kind: spec.plugin_id.clone(), + allows_multiple_components: handshake.allows_multiple_components, + config: spec.config.clone(), + validation_diagnostics, + registrations, + runtime: runtime_handle, + client, + host_state, + shutdown: Mutex::new(Some(shutdown_tx)), + process: Mutex::new(Some(child.take())), + activation_dir: activation_dir_guard.keep(), + })) +} + +enum HostRuntimeServer { + #[cfg(unix)] + Unix(StdUnixListener), + #[cfg(not(unix))] + Tcp(TcpListener), +} + +#[derive(Clone)] +enum WorkerConnectEndpoint { + #[cfg(unix)] + Unix(PathBuf), + #[cfg(not(unix))] + Tcp(String), + #[cfg(not(unix))] + Announced(PathBuf), +} + +struct WorkerEndpoints { + host_server: HostRuntimeServer, + host_advertise: String, + worker_advertise: String, + worker_connect: WorkerConnectEndpoint, + worker_endpoint_file: Option, +} + +impl WorkerEndpoints { + fn new(activation_dir: &Path) -> crate::plugin::Result { + #[cfg(not(unix))] + let _ = activation_dir; + + #[cfg(unix)] + { + let host_socket = activation_dir.join("host.sock"); + let worker_socket = activation_dir.join("worker.sock"); + let _ = std::fs::remove_file(&host_socket); + let host_listener = StdUnixListener::bind(&host_socket).map_err(|err| { + PluginError::RegistrationFailed(format!( + "failed to bind worker host runtime socket '{}': {err}", + host_socket.display() + )) + })?; + host_listener.set_nonblocking(true).map_err(|err| { + PluginError::RegistrationFailed(format!( + "failed to configure worker host runtime socket '{}': {err}", + host_socket.display() + )) + })?; + Ok(Self { + host_server: HostRuntimeServer::Unix(host_listener), + host_advertise: unix_endpoint_display(&host_socket), + worker_advertise: unix_endpoint_display(&worker_socket), + worker_connect: WorkerConnectEndpoint::Unix(worker_socket), + worker_endpoint_file: None, + }) + } + + #[cfg(not(unix))] + { + let (host_listener, host_addr) = bind_loopback_listener()?; + let worker_endpoint_file = activation_dir.join("worker-endpoint"); + Ok(Self { + host_server: HostRuntimeServer::Tcp(host_listener), + host_advertise: format!("http://{host_addr}"), + worker_advertise: "tcp://127.0.0.1:0".into(), + worker_connect: WorkerConnectEndpoint::Announced(worker_endpoint_file.clone()), + worker_endpoint_file: Some(worker_endpoint_file), + }) + } + } +} + +async fn serve_host_runtime( + endpoint: HostRuntimeServer, + state: Arc, + shutdown: oneshot::Receiver<()>, +) { + let service = RelayHostRuntimeServer::new(WorkerHostRuntimeService { state }); + let result = match endpoint { + #[cfg(unix)] + HostRuntimeServer::Unix(listener) => { + let listener = match UnixListener::from_std(listener) { + Ok(listener) => listener, + Err(err) => { + eprintln!("failed to attach worker host runtime socket: {err}"); + return; + } + }; + Server::builder() + .add_service(service) + .serve_with_incoming_shutdown(UnixListenerStream::new(listener), async { + let _ = shutdown.await; + }) + .await + } + #[cfg(not(unix))] + HostRuntimeServer::Tcp(listener) => { + let listener = match TokioTcpListener::from_std(listener) { + Ok(listener) => listener, + Err(err) => { + eprintln!("failed to attach worker host runtime endpoint: {err}"); + return; + } + }; + Server::builder() + .add_service(service) + .serve_with_incoming_shutdown(TcpListenerStream::new(listener), async { + let _ = shutdown.await; + }) + .await + } + }; + if let Err(err) = result { + eprintln!("worker host runtime server failed: {err}"); + } +} + +async fn connect_worker_with_retry( + endpoint: &WorkerConnectEndpoint, +) -> crate::plugin::Result> { + let start = std::time::Instant::now(); + loop { + let connect_endpoint = match resolve_worker_connect_endpoint(endpoint) { + Ok(Some(endpoint)) => endpoint, + Ok(None) if start.elapsed() < WORKER_STARTUP_TIMEOUT => { + tokio::time::sleep(WORKER_CONNECT_RETRY).await; + continue; + } + Ok(None) => { + let message = format!( + "worker did not announce endpoint within {}s", + WORKER_STARTUP_TIMEOUT.as_secs() + ); + return Err(PluginError::RegistrationFailed(message)); + } + Err(err) => return Err(err), + }; + match connect_worker(&connect_endpoint).await { + Ok(client) => return Ok(client), + Err(err) if start.elapsed() < WORKER_STARTUP_TIMEOUT => { + let _ = err; + tokio::time::sleep(WORKER_CONNECT_RETRY).await; + } + Err(err) => { + let message = format!( + "worker did not start within {}s: {err}", + WORKER_STARTUP_TIMEOUT.as_secs() + ); + return Err(PluginError::RegistrationFailed(message)); + } + } + } +} + +#[cfg(not(unix))] +fn normalize_worker_tcp_endpoint(endpoint: &str) -> crate::plugin::Result { + let endpoint = endpoint.trim(); + if let Some(authority) = endpoint.strip_prefix("tcp://") { + if authority.is_empty() { + return Err(PluginError::RegistrationFailed( + "worker announced an empty TCP endpoint".into(), + )); + } + return Ok(format!("http://{authority}")); + } + if endpoint.starts_with("http://") { + return Ok(endpoint.to_owned()); + } + Err(PluginError::RegistrationFailed(format!( + "worker announced unsupported endpoint '{endpoint}'" + ))) +} + +fn resolve_worker_connect_endpoint( + endpoint: &WorkerConnectEndpoint, +) -> crate::plugin::Result> { + match endpoint { + #[cfg(unix)] + WorkerConnectEndpoint::Unix(path) => Ok(Some(WorkerConnectEndpoint::Unix(path.clone()))), + #[cfg(not(unix))] + WorkerConnectEndpoint::Tcp(endpoint) => Ok(Some(WorkerConnectEndpoint::Tcp( + normalize_worker_tcp_endpoint(endpoint)?, + ))), + #[cfg(not(unix))] + WorkerConnectEndpoint::Announced(path) => { + let endpoint = match std::fs::read_to_string(path) { + Ok(endpoint) => endpoint, + Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(None), + Err(err) => { + return Err(PluginError::RegistrationFailed(format!( + "failed to read worker endpoint file '{}': {err}", + path.display() + ))); + } + }; + Ok(Some(WorkerConnectEndpoint::Tcp( + normalize_worker_tcp_endpoint(endpoint.trim())?, + ))) + } + } +} + +async fn connect_worker( + endpoint: &WorkerConnectEndpoint, +) -> crate::plugin::Result> { + match endpoint { + #[cfg(unix)] + WorkerConnectEndpoint::Unix(socket) => { + let path = Arc::new(socket.to_path_buf()); + let endpoint = Endpoint::try_from("http://[::]:50051") + .map_err(|err| PluginError::Internal(format!("invalid worker endpoint: {err}")))?; + let channel = endpoint + .connect_with_connector(service_fn(move |_| { + let path = path.clone(); + async move { UnixStream::connect(&*path).await.map(TokioIo::new) } + })) + .await + .map_err(|err| { + PluginError::RegistrationFailed(format!( + "failed to connect to worker socket '{}': {err}", + socket.display() + )) + })?; + Ok(PluginWorkerClient::new(channel)) + } + #[cfg(not(unix))] + WorkerConnectEndpoint::Tcp(endpoint) => { + let channel = Endpoint::from_shared(endpoint.clone()) + .map_err(|err| PluginError::Internal(format!("invalid worker endpoint: {err}")))? + .connect() + .await + .map_err(|err| { + PluginError::RegistrationFailed(format!( + "failed to connect to worker endpoint '{endpoint}': {err}" + )) + })?; + Ok(PluginWorkerClient::new(channel)) + } + #[cfg(not(unix))] + WorkerConnectEndpoint::Announced(path) => Err(PluginError::Internal(format!( + "worker endpoint file '{}' was not resolved before connect", + path.display() + ))), + } +} + +struct WorkerProcessLaunch<'a> { + runtime: WorkerRuntime, + manifest_path: &'a Path, + plugin_id: &'a str, + entrypoint: &'a str, + activation_id: &'a str, + auth_token: &'a str, + host_endpoint: &'a str, + worker_endpoint: &'a str, + worker_endpoint_file: Option<&'a Path>, +} + +fn spawn_worker_process(spec: WorkerProcessLaunch<'_>) -> crate::plugin::Result { + let manifest_dir = spec + .manifest_path + .parent() + .unwrap_or_else(|| Path::new(".")); + let (mut command, command_display) = match spec.runtime { + WorkerRuntime::Python => { + let python = std::env::var("NEMO_RELAY_PYTHON").unwrap_or_else(|_| "python3".into()); + let mut command = Command::new(python); + command + .arg("-c") + .arg(PYTHON_WORKER_BOOTSTRAP) + .arg(spec.entrypoint); + (command, spec.entrypoint.to_string()) + } + WorkerRuntime::Rust | WorkerRuntime::Command => { + let entrypoint = resolve_manifest_relative_path(spec.manifest_path, spec.entrypoint); + let command_display = entrypoint.display().to_string(); + (Command::new(entrypoint), command_display) + } + }; + command + .current_dir(manifest_dir) + .env("NEMO_RELAY_WORKER_ID", spec.activation_id) + .env("NEMO_RELAY_PLUGIN_ID", spec.plugin_id) + .env("NEMO_RELAY_WORKER_SOCKET", spec.worker_endpoint) + .env("NEMO_RELAY_HOST_SOCKET", spec.host_endpoint) + .env("NEMO_RELAY_WORKER_TOKEN", spec.auth_token) + .stdin(Stdio::null()) + .stdout(Stdio::null()) + .stderr(Stdio::inherit()); + if let Some(path) = spec.worker_endpoint_file { + command.env("NEMO_RELAY_WORKER_ENDPOINT_FILE", path); + } + command.spawn().map_err(|err| { + PluginError::RegistrationFailed(format!( + "failed to spawn {} worker '{}': {err}", + spec.runtime, command_display + )) + }) +} + +impl WorkerPluginInstance { + fn install_registrations( + &self, + ctx: &mut PluginRegistrationContext, + ) -> crate::plugin::Result<()> { + for registration in &self.registrations { + let surface = RegistrationSurface::try_from(registration.surface).map_err(|_| { + PluginError::RegistrationFailed(format!( + "worker plugin '{}' returned unsupported registration surface {}", + self.plugin_kind, registration.surface + )) + })?; + let name = registration.local_name.clone(); + let priority = registration.priority; + let break_chain = registration.break_chain; + match surface { + RegistrationSurface::Subscriber => { + let instance = Arc::new(self.clone_for_callback()); + let callback_name = name.clone(); + ctx.register_subscriber( + &name, + Arc::new(move |event| { + let _ = instance.invoke_subscriber(&callback_name, event); + }), + )?; + } + RegistrationSurface::ToolSanitizeRequestGuardrail => { + let instance = Arc::new(self.clone_for_callback()); + let callback_name = name.clone(); + ctx.register_tool_sanitize_request_guardrail( + &name, + priority, + Arc::new(move |tool_name, value| { + instance + .invoke_tool_json( + &callback_name, + RegistrationSurface::ToolSanitizeRequestGuardrail, + tool_name, + value.clone(), + None, + ) + .unwrap_or(value) + }), + )?; + } + RegistrationSurface::ToolSanitizeResponseGuardrail => { + let instance = Arc::new(self.clone_for_callback()); + let callback_name = name.clone(); + ctx.register_tool_sanitize_response_guardrail( + &name, + priority, + Arc::new(move |tool_name, value| { + instance + .invoke_tool_json( + &callback_name, + RegistrationSurface::ToolSanitizeResponseGuardrail, + tool_name, + value.clone(), + None, + ) + .unwrap_or(value) + }), + )?; + } + RegistrationSurface::ToolConditionalExecutionGuardrail => { + let instance = Arc::new(self.clone_for_callback()); + let callback_name = name.clone(); + ctx.register_tool_conditional_execution_guardrail( + &name, + priority, + Arc::new(move |tool_name, value| { + instance.invoke_tool_guardrail(&callback_name, tool_name, value.clone()) + }), + )?; + } + RegistrationSurface::ToolRequestIntercept => { + let instance = Arc::new(self.clone_for_callback()); + let callback_name = name.clone(); + ctx.register_tool_request_intercept( + &name, + priority, + break_chain, + Arc::new(move |tool_name, value| { + instance.invoke_tool_json( + &callback_name, + RegistrationSurface::ToolRequestIntercept, + tool_name, + value, + None, + ) + }), + )?; + } + RegistrationSurface::ToolExecutionIntercept => { + let instance = Arc::new(self.clone_for_callback()); + let callback_name = name.clone(); + ctx.register_tool_execution_intercept( + &name, + priority, + Arc::new(move |tool_name, value, next| { + let instance = instance.clone(); + let name = callback_name.clone(); + let tool_name = tool_name.to_string(); + Box::pin(async move { + instance + .invoke_tool_execution(&name, &tool_name, value, next) + .await + }) + }), + )?; + } + RegistrationSurface::LlmSanitizeRequestGuardrail => { + let instance = Arc::new(self.clone_for_callback()); + let callback_name = name.clone(); + ctx.register_llm_sanitize_request_guardrail( + &name, + priority, + Arc::new(move |request| { + instance + .invoke_llm_request_json( + &callback_name, + RegistrationSurface::LlmSanitizeRequestGuardrail, + "", + request.clone(), + None, + None, + ) + .unwrap_or(request) + }), + )?; + } + RegistrationSurface::LlmSanitizeResponseGuardrail => { + let instance = Arc::new(self.clone_for_callback()); + let callback_name = name.clone(); + ctx.register_llm_sanitize_response_guardrail( + &name, + priority, + Arc::new(move |value| { + instance + .invoke_llm_response_json( + &callback_name, + RegistrationSurface::LlmSanitizeResponseGuardrail, + "", + value.clone(), + ) + .unwrap_or(value) + }), + )?; + } + RegistrationSurface::LlmConditionalExecutionGuardrail => { + let instance = Arc::new(self.clone_for_callback()); + let callback_name = name.clone(); + ctx.register_llm_conditional_execution_guardrail( + &name, + priority, + Arc::new(move |request| { + instance.invoke_llm_guardrail(&callback_name, request.clone()) + }), + )?; + } + RegistrationSurface::LlmRequestIntercept => { + let instance = Arc::new(self.clone_for_callback()); + let callback_name = name.clone(); + ctx.register_llm_request_intercept( + &name, + priority, + break_chain, + Arc::new(move |model_name, request, annotated| { + instance.invoke_llm_request_intercept( + &callback_name, + model_name, + request, + annotated, + ) + }), + )?; + } + RegistrationSurface::LlmExecutionIntercept => { + let instance = Arc::new(self.clone_for_callback()); + let callback_name = name.clone(); + ctx.register_llm_execution_intercept( + &name, + priority, + Arc::new(move |model_name, request, next| { + let instance = instance.clone(); + let name = callback_name.clone(); + let model_name = model_name.to_string(); + Box::pin(async move { + instance + .invoke_llm_execution(&name, &model_name, request, next) + .await + }) + }), + )?; + } + RegistrationSurface::LlmStreamExecutionIntercept => { + let instance = Arc::new(self.clone_for_callback()); + let callback_name = name.clone(); + ctx.register_llm_stream_execution_intercept( + &name, + priority, + Arc::new(move |model_name, request, next| { + let instance = instance.clone(); + let name = callback_name.clone(); + let model_name = model_name.to_string(); + Box::pin(async move { + instance + .invoke_llm_stream_execution(&name, &model_name, request, next) + .await + }) + }), + )?; + } + RegistrationSurface::Unspecified => { + return Err(PluginError::RegistrationFailed(format!( + "worker plugin '{}' returned unspecified registration surface", + self.plugin_kind + ))); + } + } + } + Ok(()) + } + + fn clone_for_callback(&self) -> WorkerPluginCallback { + WorkerPluginCallback { + activation_id: self.host_state.activation_id.clone(), + runtime: self.runtime.handle(), + client: self.client.clone(), + host_state: self.host_state.clone(), + } + } +} + +#[cfg(not(unix))] +fn bind_loopback_listener() -> crate::plugin::Result<(TcpListener, SocketAddr)> { + let listener = TcpListener::bind(("127.0.0.1", 0)).map_err(|err| { + PluginError::RegistrationFailed(format!( + "failed to bind worker host runtime endpoint: {err}" + )) + })?; + listener.set_nonblocking(true).map_err(|err| { + PluginError::RegistrationFailed(format!( + "failed to configure worker host runtime endpoint: {err}" + )) + })?; + let addr = listener.local_addr().map_err(|err| { + PluginError::RegistrationFailed(format!( + "failed to inspect worker host runtime endpoint: {err}" + )) + })?; + Ok((listener, addr)) +} + +#[derive(Clone)] +struct WorkerPluginCallback { + activation_id: String, + runtime: tokio::runtime::Handle, + client: PluginWorkerClient, + host_state: Arc, +} + +impl WorkerPluginCallback { + fn invoke_subscriber(&self, registration_name: &str, event: &Event) -> FlowResult<()> { + let request = self.base_request( + registration_name, + RegistrationSurface::Subscriber, + None, + Some(invoke_request_payload_event(event)), + ); + let response = self.invoke_blocking(request)?; + match response.result { + Some(invoke_response_result::Result::Empty(_)) | None => Ok(()), + Some(invoke_response_result::Result::Error(error)) => Err(worker_error_to_flow(error)), + _ => Err(FlowError::Internal( + "worker subscriber returned unexpected result".into(), + )), + } + } + + fn invoke_tool_json( + &self, + registration_name: &str, + surface: RegistrationSurface, + tool_name: &str, + value: Json, + continuation_id: Option, + ) -> FlowResult { + let request = self.base_request( + registration_name, + surface, + continuation_id, + Some(invoke_request_payload_tool(tool_name, value)), + ); + json_from_invoke_response(self.invoke_blocking(request)?) + } + + fn invoke_tool_guardrail( + &self, + registration_name: &str, + tool_name: &str, + value: Json, + ) -> FlowResult> { + let request = self.base_request( + registration_name, + RegistrationSurface::ToolConditionalExecutionGuardrail, + None, + Some(invoke_request_payload_tool(tool_name, value)), + ); + guardrail_from_invoke_response(self.invoke_blocking(request)?) + } + + async fn invoke_tool_execution( + &self, + registration_name: &str, + tool_name: &str, + value: Json, + next: ToolExecutionNextFn, + ) -> FlowResult { + let continuation_id = self + .host_state + .insert_continuation(Continuation::Tool(next))?; + let callback = self.clone(); + let registration_name = registration_name.to_string(); + let tool_name = tool_name.to_string(); + tokio::task::spawn_blocking(move || { + callback.invoke_tool_json( + ®istration_name, + RegistrationSurface::ToolExecutionIntercept, + &tool_name, + value, + Some(continuation_id.clone()), + ) + }) + .await + .map_err(|err| FlowError::Internal(format!("worker task join failed: {err}")))? + } + + fn invoke_llm_request_json( + &self, + registration_name: &str, + surface: RegistrationSurface, + model_name: &str, + request: LlmRequest, + annotated: Option, + continuation_id: Option, + ) -> FlowResult { + let invoke = self.base_request( + registration_name, + surface, + continuation_id, + Some(invoke_request_payload_llm( + model_name, + Some(request), + annotated, + None, + )), + ); + let value = json_from_invoke_response(self.invoke_blocking(invoke)?)?; + serde_json::from_value(value).map_err(|err| { + FlowError::Internal(format!("worker returned invalid LLM request: {err}")) + }) + } + + fn invoke_llm_response_json( + &self, + registration_name: &str, + surface: RegistrationSurface, + model_name: &str, + response: Json, + ) -> FlowResult { + let invoke = self.base_request( + registration_name, + surface, + None, + Some(invoke_request_payload_llm( + model_name, + None, + None, + Some(response), + )), + ); + json_from_invoke_response(self.invoke_blocking(invoke)?) + } + + fn invoke_llm_guardrail( + &self, + registration_name: &str, + request: LlmRequest, + ) -> FlowResult> { + let invoke = self.base_request( + registration_name, + RegistrationSurface::LlmConditionalExecutionGuardrail, + None, + Some(invoke_request_payload_llm("", Some(request), None, None)), + ); + guardrail_from_invoke_response(self.invoke_blocking(invoke)?) + } + + fn invoke_llm_request_intercept( + &self, + registration_name: &str, + model_name: &str, + request: LlmRequest, + annotated: Option, + ) -> FlowResult<(LlmRequest, Option)> { + let invoke = self.base_request( + registration_name, + RegistrationSurface::LlmRequestIntercept, + None, + Some(invoke_request_payload_llm( + model_name, + Some(request), + annotated, + None, + )), + ); + let response = self.invoke_blocking(invoke)?; + match response.result { + Some(invoke_response_result::Result::LlmRequest(result)) => { + let request = required_envelope(result.request, "llm request intercept request")?; + let request = decode_json_envelope::(&request).map_err(|err| { + FlowError::Internal(format!("worker returned invalid LLM request: {err}")) + })?; + let annotated = if result.has_annotated_request { + let envelope = required_envelope( + result.annotated_request, + "llm request intercept annotated request", + )?; + Some( + decode_json_envelope::(&envelope).map_err(|err| { + FlowError::Internal(format!( + "worker returned invalid annotated LLM request: {err}" + )) + })?, + ) + } else { + None + }; + Ok((request, annotated)) + } + Some(invoke_response_result::Result::Error(error)) => Err(worker_error_to_flow(error)), + _ => Err(FlowError::Internal( + "worker LLM request intercept returned unexpected result".into(), + )), + } + } + + async fn invoke_llm_execution( + &self, + registration_name: &str, + model_name: &str, + request: LlmRequest, + next: LlmExecutionNextFn, + ) -> FlowResult { + let continuation_id = self + .host_state + .insert_continuation(Continuation::Llm(next))?; + let callback = self.clone(); + let registration_name = registration_name.to_string(); + let model_name = model_name.to_string(); + tokio::task::spawn_blocking(move || { + let invoke = callback.base_request( + ®istration_name, + RegistrationSurface::LlmExecutionIntercept, + Some(continuation_id), + Some(invoke_request_payload_llm( + &model_name, + Some(request), + None, + None, + )), + ); + json_from_invoke_response(callback.invoke_blocking(invoke)?) + }) + .await + .map_err(|err| FlowError::Internal(format!("worker task join failed: {err}")))? + } + + async fn invoke_llm_stream_execution( + &self, + registration_name: &str, + model_name: &str, + request: LlmRequest, + next: LlmStreamExecutionNextFn, + ) -> FlowResult { + let continuation_id = self + .host_state + .insert_continuation(Continuation::LlmStream(next))?; + let invoke = self.base_request( + registration_name, + RegistrationSurface::LlmStreamExecutionIntercept, + Some(continuation_id.clone()), + Some(invoke_request_payload_llm( + model_name, + Some(request), + None, + None, + )), + ); + let scope_stack_id = invoke + .scope + .as_ref() + .map(|scope| scope.scope_stack_id.clone()) + .unwrap_or_default(); + let mut client = self.client.clone(); + let host_state = self.host_state.clone(); + let (tx, rx) = mpsc::channel(16); + self.runtime.spawn(async move { + let result = worker_rpc(client.invoke_stream(worker_rpc_request(invoke))).await; + match result { + Ok(response) => { + let mut stream = response.into_inner(); + while let Some(item) = stream.next().await { + let result = match item { + Ok(chunk) => json_from_stream_chunk(chunk), + Err(err) => Err(FlowError::Internal(format!( + "worker stream transport failed: {err}" + ))), + }; + if tx.send(result).await.is_err() { + break; + } + } + } + Err(err) => { + let _ = tx + .send(Err(FlowError::Internal(format!( + "worker stream invoke failed: {err}" + )))) + .await; + } + } + host_state.remove_continuation(&continuation_id); + if !scope_stack_id.is_empty() { + host_state.remove_invocation_scope_stack(&scope_stack_id); + } + }); + Ok(Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx))) + } + + fn base_request( + &self, + registration_name: &str, + surface: RegistrationSurface, + continuation_id: Option, + payload: Option, + ) -> InvokeRequest { + let scope_stack_id = self + .host_state + .insert_invocation_scope_stack(current_scope_stack()); + InvokeRequest { + activation_id: self.activation_id.clone(), + auth_token: self.host_state.auth_token.clone(), + invocation_id: Uuid::now_v7().to_string(), + registration_name: registration_name.into(), + surface: surface as i32, + continuation_id: continuation_id.unwrap_or_default(), + scope: Some(ScopeContext { + scope_stack_id, + parent_scope_id: String::new(), + }), + payload, + } + } + + fn invoke_blocking(&self, request: InvokeRequest) -> FlowResult { + let scope_stack_id = request + .scope + .as_ref() + .map(|scope| scope.scope_stack_id.clone()) + .unwrap_or_default(); + let continuation_id = request.continuation_id.clone(); + let mut client = self.client.clone(); + let result = block_on_handle(&self.runtime, async move { + worker_rpc(client.invoke(worker_rpc_request(request))).await + }) + .map(|response| response.into_inner()) + .map_err(|err| FlowError::Internal(format!("worker invoke failed: {err}"))); + if !continuation_id.is_empty() { + self.host_state.remove_continuation(&continuation_id); + } + if !scope_stack_id.is_empty() { + self.host_state + .remove_invocation_scope_stack(&scope_stack_id); + } + result + } +} + +struct OwnedWorkerRuntime { + runtime: Option, +} + +impl OwnedWorkerRuntime { + fn new(runtime: Runtime) -> Self { + Self { + runtime: Some(runtime), + } + } + + fn runtime(&self) -> &Runtime { + self.runtime + .as_ref() + .expect("worker runtime accessed after drop") + } + + fn handle(&self) -> tokio::runtime::Handle { + self.runtime().handle().clone() + } +} + +impl Drop for OwnedWorkerRuntime { + fn drop(&mut self) { + let Some(runtime) = self.runtime.take() else { + return; + }; + if tokio::runtime::Handle::try_current().is_ok() { + std::thread::scope(|scope| { + scope + .spawn(move || drop(runtime)) + .join() + .expect("worker runtime drop thread panicked"); + }); + } else { + drop(runtime); + } + } +} + +struct ChildGuard { + child: Option, +} + +impl ChildGuard { + fn new(child: Child) -> Self { + Self { child: Some(child) } + } + + fn take(&mut self) -> Child { + self.child.take().expect("worker child already taken") + } +} + +impl Drop for ChildGuard { + fn drop(&mut self) { + if let Some(mut child) = self.child.take() { + let _ = child.kill(); + let _ = child.wait(); + } + } +} + +struct ActivationDirGuard { + path: Option, +} + +impl ActivationDirGuard { + fn new(path: PathBuf) -> Self { + Self { path: Some(path) } + } + + fn keep(&mut self) -> PathBuf { + self.path + .take() + .expect("worker activation directory already taken") + } +} + +impl Drop for ActivationDirGuard { + fn drop(&mut self) { + if let Some(path) = self.path.take() { + let _ = std::fs::remove_dir_all(path); + } + } +} + +fn worker_rpc_request(message: T) -> Request { + Request::new(message) +} + +async fn worker_rpc(future: F) -> Result, Status> +where + F: Future, Status>>, +{ + match tokio::time::timeout(WORKER_RPC_TIMEOUT, future).await { + Ok(result) => result, + Err(_) => Err(Status::deadline_exceeded(format!( + "worker RPC timed out after {}s", + WORKER_RPC_TIMEOUT.as_secs() + ))), + } +} + +fn block_on_runtime(runtime: &Runtime, future: F) -> F::Output +where + F: Future + Send, + F::Output: Send, +{ + if tokio::runtime::Handle::try_current().is_ok() { + std::thread::scope(|scope| { + scope + .spawn(|| runtime.block_on(future)) + .join() + .expect("worker runtime blocking thread panicked") + }) + } else { + runtime.block_on(future) + } +} + +fn block_on_handle(handle: &tokio::runtime::Handle, future: F) -> F::Output +where + F: Future + Send, + F::Output: Send, +{ + if tokio::runtime::Handle::try_current().is_ok() { + let handle = handle.clone(); + std::thread::scope(|scope| { + scope + .spawn(move || handle.block_on(future)) + .join() + .expect("worker callback blocking thread panicked") + }) + } else { + handle.block_on(future) + } +} + +struct WorkerHostRuntimeState { + activation_id: String, + auth_token: String, + scope_stacks: Mutex>, + scope_handles: Mutex>, + continuations: Mutex>, +} + +struct StoredScopeHandle { + handle: ScopeHandle, + scope_stack_id: String, +} + +impl WorkerHostRuntimeState { + fn new(activation_id: String, auth_token: String) -> Self { + Self { + activation_id, + auth_token, + scope_stacks: Mutex::new(HashMap::new()), + scope_handles: Mutex::new(HashMap::new()), + continuations: Mutex::new(HashMap::new()), + } + } + + fn authorize(&self, activation_id: &str, token: &str) -> Result<(), Status> { + if activation_id != self.activation_id || token != self.auth_token { + return Err(Status::permission_denied("invalid worker host token")); + } + Ok(()) + } + + fn insert_invocation_scope_stack( + &self, + stack: crate::api::runtime::ScopeStackHandle, + ) -> String { + let id = format!("invoke-{}", Uuid::now_v7()); + if let Ok(mut stacks) = self.scope_stacks.lock() { + stacks.insert(id.clone(), stack); + } + id + } + + fn remove_invocation_scope_stack(&self, id: &str) { + if let Ok(mut stacks) = self.scope_stacks.lock() { + stacks.remove(id); + } + } + + fn insert_continuation(&self, continuation: Continuation) -> FlowResult { + let id = format!("next-{}", Uuid::now_v7()); + let mut continuations = self + .continuations + .lock() + .map_err(|err| FlowError::Internal(format!("continuation lock poisoned: {err}")))?; + continuations.insert(id.clone(), continuation); + Ok(id) + } + + fn remove_continuation(&self, id: &str) { + if let Ok(mut continuations) = self.continuations.lock() { + continuations.remove(id); + } + } + + fn continuation(&self, id: &str) -> Result { + self.continuations + .lock() + .map_err(|err| Status::internal(format!("continuation lock poisoned: {err}")))? + .get(id) + .cloned() + .ok_or_else(|| Status::not_found("continuation not found")) + } + + fn stack(&self, id: &str) -> Result, Status> { + if id.is_empty() { + return Ok(None); + } + self.scope_stacks + .lock() + .map_err(|err| Status::internal(format!("scope stack lock poisoned: {err}")))? + .get(id) + .cloned() + .map(Some) + .ok_or_else(|| Status::not_found("scope stack not found")) + } +} + +#[derive(Clone)] +enum Continuation { + Tool(ToolExecutionNextFn), + Llm(LlmExecutionNextFn), + LlmStream(LlmStreamExecutionNextFn), +} + +struct WorkerHostRuntimeService { + state: Arc, +} + +#[tonic::async_trait] +impl RelayHostRuntime for WorkerHostRuntimeService { + async fn emit_mark( + &self, + request: Request, + ) -> Result, Status> { + let request = request.into_inner(); + self.state + .authorize(&request.activation_id, &request.auth_token)?; + let result = self.with_stack(request.scope.as_ref(), || { + emit_scope_mark( + EmitMarkEventParams::builder() + .name(&request.name) + .data_opt(optional_envelope_to_json(request.data)?) + .metadata_opt(optional_envelope_to_json(request.metadata)?) + .build(), + ) + }); + Ok(Response::new(host_ack(result))) + } + + async fn push_scope( + &self, + request: Request, + ) -> Result, Status> { + let request = request.into_inner(); + self.state + .authorize(&request.activation_id, &request.auth_token)?; + let result = self.with_stack(request.scope.as_ref(), || { + push_scope( + PushScopeParams::builder() + .name(&request.name) + .scope_type(proto_scope_type(request.scope_type)) + .attributes(ScopeAttributes::empty()) + .data_opt(optional_envelope_to_json(request.data)?) + .metadata_opt(optional_envelope_to_json(request.metadata)?) + .input_opt(optional_envelope_to_json(request.input)?) + .build(), + ) + }); + match result { + Ok(handle) => { + let id = format!("scope-{}", handle.uuid); + let scope_stack_id = request + .scope + .as_ref() + .map(|scope| scope.scope_stack_id.clone()) + .unwrap_or_default(); + self.state + .scope_handles + .lock() + .map_err(|err| Status::internal(format!("scope handle lock poisoned: {err}")))? + .insert( + id.clone(), + StoredScopeHandle { + handle, + scope_stack_id, + }, + ); + Ok(Response::new(PushScopeResponse { + scope_handle_id: id, + error: None, + })) + } + Err(err) => Ok(Response::new(PushScopeResponse { + scope_handle_id: String::new(), + error: Some(flow_error_to_worker(err)), + })), + } + } + + async fn pop_scope( + &self, + request: Request, + ) -> Result, Status> { + let request = request.into_inner(); + self.state + .authorize(&request.activation_id, &request.auth_token)?; + let handle = self + .state + .scope_handles + .lock() + .map_err(|err| Status::internal(format!("scope handle lock poisoned: {err}")))? + .remove(&request.scope_handle_id) + .ok_or_else(|| Status::not_found("scope handle not found"))?; + let output = optional_envelope_to_json(request.output).map_err(status_from_flow)?; + let metadata = optional_envelope_to_json(request.metadata).map_err(status_from_flow)?; + let pop = || { + pop_scope( + PopScopeParams::builder() + .handle_uuid(&handle.handle.uuid) + .output_opt(output) + .metadata_opt(metadata) + .build(), + ) + }; + let result = if handle.scope_stack_id.is_empty() { + pop() + } else if let Some(stack) = self.state.stack(&handle.scope_stack_id)? { + with_scope_stack(stack, pop) + } else { + pop() + }; + Ok(Response::new(host_ack(result))) + } + + async fn create_scope_stack( + &self, + request: Request, + ) -> Result, Status> { + let request = request.into_inner(); + self.state + .authorize(&request.activation_id, &request.auth_token)?; + let id = format!("stack-{}", Uuid::now_v7()); + self.state + .scope_stacks + .lock() + .map_err(|err| Status::internal(format!("scope stack lock poisoned: {err}")))? + .insert(id.clone(), crate::api::runtime::create_scope_stack()); + Ok(Response::new(CreateScopeStackResponse { + scope_stack_id: id, + error: None, + })) + } + + async fn drop_scope_stack( + &self, + request: Request, + ) -> Result, Status> { + let request = request.into_inner(); + self.state + .authorize(&request.activation_id, &request.auth_token)?; + self.state + .scope_stacks + .lock() + .map_err(|err| Status::internal(format!("scope stack lock poisoned: {err}")))? + .remove(&request.scope_stack_id); + Ok(Response::new(HostAck { + ok: true, + error: None, + })) + } + + async fn tool_next( + &self, + request: Request, + ) -> Result, Status> { + let request = request.into_inner(); + self.state + .authorize(&request.activation_id, &request.auth_token)?; + let continuation = self.state.continuation(&request.continuation_id)?; + let Continuation::Tool(next) = continuation else { + return Err(Status::invalid_argument( + "continuation is not a tool continuation", + )); + }; + let value = + required_envelope(request.value, "tool next value").map_err(status_from_flow)?; + let value = decode_json_envelope::(&value) + .map_err(|err| Status::invalid_argument(format!("invalid tool next JSON: {err}")))?; + let result = next(value).await; + Ok(Response::new(json_result(result))) + } + + async fn llm_next( + &self, + request: Request, + ) -> Result, Status> { + let request = request.into_inner(); + self.state + .authorize(&request.activation_id, &request.auth_token)?; + let continuation = self.state.continuation(&request.continuation_id)?; + let Continuation::Llm(next) = continuation else { + return Err(Status::invalid_argument( + "continuation is not an LLM continuation", + )); + }; + let request = + required_envelope(request.request, "llm next request").map_err(status_from_flow)?; + let request = decode_json_envelope::(&request) + .map_err(|err| Status::invalid_argument(format!("invalid LLM next request: {err}")))?; + let result = next(request).await; + Ok(Response::new(json_result(result))) + } + + type LlmStreamNextStream = + Pin> + Send>>; + + async fn llm_stream_next( + &self, + request: Request, + ) -> Result, Status> { + let request = request.into_inner(); + self.state + .authorize(&request.activation_id, &request.auth_token)?; + let continuation = self.state.continuation(&request.continuation_id)?; + let Continuation::LlmStream(next) = continuation else { + return Err(Status::invalid_argument( + "continuation is not an LLM stream continuation", + )); + }; + let request = required_envelope(request.request, "llm stream next request") + .map_err(status_from_flow)?; + let request = decode_json_envelope::(&request).map_err(|err| { + Status::invalid_argument(format!("invalid LLM stream next request: {err}")) + })?; + let stream = next(request).await.map_err(status_from_flow)?; + let mapped = stream.map(|item| match item { + Ok(value) => Ok(StreamChunk { + item: Some(stream_chunk_item::Item::Value(json_envelope_infallible( + JSON_SCHEMA, + &value, + ))), + }), + Err(err) => Ok(StreamChunk { + item: Some(stream_chunk_item::Item::Error(flow_error_to_worker(err))), + }), + }); + Ok(Response::new(Box::pin(mapped))) + } +} + +impl WorkerHostRuntimeService { + fn with_stack( + &self, + scope: Option<&ScopeContext>, + f: impl FnOnce() -> FlowResult, + ) -> FlowResult { + let Some(stack_id) = scope.map(|scope| scope.scope_stack_id.as_str()) else { + return f(); + }; + let Some(stack) = self + .state + .stack(stack_id) + .map_err(|err| FlowError::Internal(err.to_string()))? + else { + return f(); + }; + with_scope_stack(stack, f) + } +} + +mod invoke_request_payload { + pub(crate) use nemo_relay_worker_proto::v1::invoke_request::Payload; +} + +mod invoke_response_result { + pub(crate) use nemo_relay_worker_proto::v1::invoke_response::Result; +} + +mod stream_chunk_item { + pub(crate) use nemo_relay_worker_proto::v1::stream_chunk::Item; +} + +fn invoke_request_payload_event(event: &Event) -> invoke_request_payload::Payload { + invoke_request_payload::Payload::Event(json_envelope_infallible(EVENT_SCHEMA, event)) +} + +fn invoke_request_payload_tool(tool_name: &str, value: Json) -> invoke_request_payload::Payload { + invoke_request_payload::Payload::Tool(ToolInvocation { + tool_name: tool_name.into(), + value: Some(json_envelope_infallible(JSON_SCHEMA, &value)), + }) +} + +fn invoke_request_payload_llm( + model_name: &str, + request: Option, + annotated_request: Option, + response: Option, +) -> invoke_request_payload::Payload { + invoke_request_payload::Payload::Llm(LlmInvocation { + model_name: model_name.into(), + request: request + .as_ref() + .map(|request| json_envelope_infallible(LLM_REQUEST_SCHEMA, request)), + annotated_request: annotated_request + .as_ref() + .map(|request| json_envelope_infallible(ANNOTATED_LLM_REQUEST_SCHEMA, request)), + response: response + .as_ref() + .map(|response| json_envelope_infallible(JSON_SCHEMA, response)), + }) +} + +fn json_envelope_infallible(schema: &str, value: &T) -> JsonEnvelope { + json_envelope(schema, value).expect("Relay DTO JSON serialization should be infallible") +} + +fn json_from_invoke_response(response: InvokeResponse) -> FlowResult { + match response.result { + Some(invoke_response_result::Result::Json(result)) => { + if let Some(error) = result.error { + return Err(worker_error_to_flow(error)); + } + let envelope = required_envelope(result.value, "worker JSON result")?; + decode_json_envelope::(&envelope).map_err(|err| { + FlowError::Internal(format!("worker returned invalid JSON result: {err}")) + }) + } + Some(invoke_response_result::Result::Error(error)) => Err(worker_error_to_flow(error)), + _ => Err(FlowError::Internal( + "worker returned unexpected invoke result".into(), + )), + } +} + +fn guardrail_from_invoke_response(response: InvokeResponse) -> FlowResult> { + match response.result { + Some(invoke_response_result::Result::Guardrail(GuardrailResult { block_reason })) => { + Ok((!block_reason.is_empty()).then_some(block_reason)) + } + Some(invoke_response_result::Result::Error(error)) => Err(worker_error_to_flow(error)), + _ => Err(FlowError::Internal( + "worker guardrail returned unexpected invoke result".into(), + )), + } +} + +fn json_from_stream_chunk(chunk: StreamChunk) -> FlowResult { + match chunk.item { + Some(stream_chunk_item::Item::Value(value)) => decode_json_envelope::(&value) + .map_err(|err| FlowError::Internal(format!("invalid worker stream chunk: {err}"))), + Some(stream_chunk_item::Item::Error(error)) => Err(worker_error_to_flow(error)), + None => Err(FlowError::Internal("worker stream chunk was empty".into())), + } +} + +fn required_envelope(value: Option, field: &str) -> FlowResult { + value.ok_or_else(|| FlowError::Internal(format!("{field} is missing"))) +} + +fn optional_envelope_to_json(value: Option) -> FlowResult> { + value + .map(|value| { + decode_json_envelope::(&value) + .map_err(|err| FlowError::Internal(format!("invalid JSON envelope: {err}"))) + }) + .transpose() +} + +fn host_ack(result: FlowResult<()>) -> HostAck { + match result { + Ok(()) => HostAck { + ok: true, + error: None, + }, + Err(err) => HostAck { + ok: false, + error: Some(flow_error_to_worker(err)), + }, + } +} + +fn json_result(result: FlowResult) -> JsonResult { + match result { + Ok(value) => JsonResult { + value: Some(json_envelope_infallible(JSON_SCHEMA, &value)), + error: None, + }, + Err(err) => JsonResult { + value: None, + error: Some(flow_error_to_worker(err)), + }, + } +} + +fn flow_error_to_worker(err: FlowError) -> WorkerError { + WorkerError { + code: "host.runtime_error".into(), + message: err.to_string(), + retryable: false, + } +} + +fn worker_error_to_flow(error: WorkerError) -> FlowError { + FlowError::Internal(format!("{}: {}", error.code, error.message)) +} + +fn worker_error_to_plugin(error: WorkerError, fallback: &str) -> PluginError { + let message = if error.message.is_empty() { + fallback.to_string() + } else { + format!("{}: {}", error.code, error.message) + }; + PluginError::RegistrationFailed(message) +} + +fn status_from_flow(err: FlowError) -> Status { + Status::internal(err.to_string()) +} + +fn proto_scope_type(scope_type: i32) -> ScopeType { + match nemo_relay_worker_proto::v1::ScopeType::try_from(scope_type) { + Ok(nemo_relay_worker_proto::v1::ScopeType::Agent) => ScopeType::Agent, + Ok(nemo_relay_worker_proto::v1::ScopeType::Function) => ScopeType::Function, + Ok(nemo_relay_worker_proto::v1::ScopeType::Tool) => ScopeType::Tool, + Ok(nemo_relay_worker_proto::v1::ScopeType::Llm) => ScopeType::Llm, + Ok(nemo_relay_worker_proto::v1::ScopeType::Retriever) => ScopeType::Retriever, + Ok(nemo_relay_worker_proto::v1::ScopeType::Embedder) => ScopeType::Embedder, + Ok(nemo_relay_worker_proto::v1::ScopeType::Reranker) => ScopeType::Reranker, + Ok(nemo_relay_worker_proto::v1::ScopeType::Guardrail) => ScopeType::Guardrail, + Ok(nemo_relay_worker_proto::v1::ScopeType::Evaluator) => ScopeType::Evaluator, + Ok(nemo_relay_worker_proto::v1::ScopeType::Custom) => ScopeType::Custom, + Ok(nemo_relay_worker_proto::v1::ScopeType::Unknown) => ScopeType::Unknown, + _ => ScopeType::Custom, + } +} + +fn validate_registration_plan( + plugin_id: &str, + response: &RegisterResponse, +) -> crate::plugin::Result<()> { + for registration in &response.registrations { + if registration.local_name.trim().is_empty() { + return Err(PluginError::RegistrationFailed(format!( + "worker plugin '{plugin_id}' returned a registration with empty local_name" + ))); + } + let surface = RegistrationSurface::try_from(registration.surface).map_err(|_| { + PluginError::RegistrationFailed(format!( + "worker plugin '{plugin_id}' returned unsupported registration surface {}", + registration.surface + )) + })?; + if surface == RegistrationSurface::Unspecified { + return Err(PluginError::RegistrationFailed(format!( + "worker plugin '{plugin_id}' returned unspecified registration surface" + ))); + } + } + Ok(()) +} + +fn diagnostics_have_errors(diagnostics: &[ConfigDiagnostic]) -> bool { + diagnostics + .iter() + .any(|diagnostic| diagnostic.level == DiagnosticLevel::Error) +} + +fn worker_error_diagnostic(plugin_kind: &str, code: &str, message: &str) -> ConfigDiagnostic { + ConfigDiagnostic { + level: DiagnosticLevel::Error, + code: code.into(), + component: Some(plugin_kind.into()), + field: None, + message: message.into(), + } +} + +fn validate_relay_compatibility(relay: Option<&str>) -> crate::plugin::Result<()> { + let relay = relay + .map(str::trim) + .filter(|value| !value.is_empty()) + .ok_or_else(|| PluginError::InvalidConfig("compat.relay is required".into()))?; + let req = VersionReq::parse(relay).map_err(|err| { + PluginError::InvalidConfig(format!("invalid compat.relay version requirement: {err}")) + })?; + let version = Version::parse(env!("CARGO_PKG_VERSION")) + .map_err(|err| PluginError::Internal(format!("failed to parse host version: {err}")))?; + if req.matches(&version) { + Ok(()) + } else { + Err(PluginError::InvalidConfig(format!( + "worker plugin requires relay '{relay}' but host version is {version}" + ))) + } +} + +fn resolve_manifest_relative_path(manifest_path: &Path, value: &str) -> PathBuf { + let path = PathBuf::from(value); + if path.is_absolute() { + path + } else { + manifest_path + .parent() + .map(|parent| parent.join(&path)) + .unwrap_or(path) + } +} + +#[cfg(unix)] +fn unix_endpoint_display(path: &Path) -> String { + format!("unix://{}", path.display()) +} + +#[cfg(test)] +#[path = "../../../tests/unit/dynamic_worker_tests.rs"] +mod tests; diff --git a/crates/core/tests/fixtures/worker_plugin/Cargo.lock b/crates/core/tests/fixtures/worker_plugin/Cargo.lock new file mode 100644 index 00000000..fe31ce56 --- /dev/null +++ b/crates/core/tests/fixtures/worker_plugin/Cargo.lock @@ -0,0 +1,1372 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 4 + +[[package]] +name = "aho-corasick" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301" +dependencies = [ + "memchr", +] + +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + +[[package]] +name = "anyhow" +version = "1.0.103" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a4385e2e34eb35d6b3efe798b9eb88096925d87726c0798709bf56d9ed84af3" + +[[package]] +name = "async-trait" +version = "0.1.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "atomic-waker" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" + +[[package]] +name = "autocfg" +version = "1.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2032f911046de80f0a198e0901378627c33f59ea0ac00e363d481118bd70a53" + +[[package]] +name = "axum" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31b698c5f9a010f6573133b09e0de5408834d0c82f8d7475a89fc1867a71cd90" +dependencies = [ + "axum-core", + "bytes", + "futures-util", + "http", + "http-body", + "http-body-util", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "serde_core", + "sync_wrapper", + "tower", + "tower-layer", + "tower-service", +] + +[[package]] +name = "axum-core" +version = "0.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08c78f31d7b1291f7ee735c1c6780ccde7785daae9a9206026862dab7d8792d1" +dependencies = [ + "bytes", + "futures-core", + "http", + "http-body", + "http-body-util", + "mime", + "pin-project-lite", + "sync_wrapper", + "tower-layer", + "tower-service", +] + +[[package]] +name = "base64" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + +[[package]] +name = "bitflags" +version = "2.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4388bee8683e3d04af747c73422af53102d2bd24d9eadb6cbc100baef4b43f8" +dependencies = [ + "serde_core", +] + +[[package]] +name = "bumpalo" +version = "3.20.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72f5acc6cb2ba439de613abc23857ec3d78374d8ed5ac84e9d11336e87da8649" + +[[package]] +name = "bytes" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ae3f5d315924270530207e2a68396c3cc547f6dca3fbdca317cfb1a51edb593" + +[[package]] +name = "cc" +version = "1.2.65" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e228eec9be7c17ccb640b59b36a5cd805ea2a564a4c5e162c2f659fea30d3b96" +dependencies = [ + "find-msvc-tools", + "shlex", +] + +[[package]] +name = "cfg-if" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" + +[[package]] +name = "chrono" +version = "0.4.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1aa79e62e7697b8e29b513a68abacf485adcd1fe8284a4316c5ae868e6633327" +dependencies = [ + "iana-time-zone", + "js-sys", + "num-traits", + "serde", + "wasm-bindgen", + "windows-link", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" + +[[package]] +name = "either" +version = "1.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91622ff5e7162018101f2fea40d6ebf4a78bbe5a49736a2020649edf9693679e" + +[[package]] +name = "equivalent" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" + +[[package]] +name = "errno" +version = "0.3.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" +dependencies = [ + "libc", + "windows-sys", +] + +[[package]] +name = "fastrand" +version = "2.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f1f227452a390804cdb637b74a86990f2a7d7ba4b7d5693aac9b4dd6defd8d6" + +[[package]] +name = "find-msvc-tools" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" + +[[package]] +name = "fixedbitset" +version = "0.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99" + +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + +[[package]] +name = "foldhash" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" + +[[package]] +name = "futures-channel" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07bbe89c50d7a535e539b8c17bc0b49bdb77747034daa8087407d655f3f7cc1d" +dependencies = [ + "futures-core", +] + +[[package]] +name = "futures-core" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d" + +[[package]] +name = "futures-macro" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e835b70203e41293343137df5c0664546da5745f82ec9b84d40be8336958447b" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "futures-sink" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c39754e157331b013978ec91992bde1ac089843443c49cbc7f46150b0fad0893" + +[[package]] +name = "futures-task" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393" + +[[package]] +name = "futures-util" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6" +dependencies = [ + "futures-core", + "futures-macro", + "futures-task", + "pin-project-lite", + "slab", +] + +[[package]] +name = "getrandom" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" +dependencies = [ + "cfg-if", + "libc", + "r-efi 5.3.0", + "wasip2", +] + +[[package]] +name = "getrandom" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "300e883d756b2e4ec94e02791f39b04b522276138852cfc41d9fb7e904106099" +dependencies = [ + "cfg-if", + "libc", + "r-efi 6.0.0", +] + +[[package]] +name = "h2" +version = "0.4.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6cb093c84e8bd9b188d4c4a8cb6579fc016968d14c99882163cd3ff402a4f155" +dependencies = [ + "atomic-waker", + "bytes", + "fnv", + "futures-core", + "futures-sink", + "http", + "indexmap", + "slab", + "tokio", + "tokio-util", + "tracing", +] + +[[package]] +name = "hashbrown" +version = "0.15.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" +dependencies = [ + "foldhash", +] + +[[package]] +name = "hashbrown" +version = "0.17.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed5909b6e89a2db4456e54cd5f673791d7eca6732202bbf2a9cc504fe2f9b84a" + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "http" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6970f50e31d6fc17d3fa27329444bfa74e196cf62e95052a3f6fee181dba6425" +dependencies = [ + "bytes", + "itoa", +] + +[[package]] +name = "http-body" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" +dependencies = [ + "bytes", + "http", +] + +[[package]] +name = "http-body-util" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" +dependencies = [ + "bytes", + "futures-core", + "http", + "http-body", + "pin-project-lite", +] + +[[package]] +name = "httparse" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" + +[[package]] +name = "httpdate" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" + +[[package]] +name = "hyper" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55281c53a1894c864990125767da440a4e630446785086f52523b20033b74498" +dependencies = [ + "atomic-waker", + "bytes", + "futures-channel", + "futures-core", + "h2", + "http", + "http-body", + "httparse", + "httpdate", + "itoa", + "pin-project-lite", + "smallvec", + "tokio", + "want", +] + +[[package]] +name = "hyper-timeout" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b90d566bffbce6a75bd8b09a05aa8c2cb1fabb6cb348f8840c9e4c90a0d83b0" +dependencies = [ + "hyper", + "hyper-util", + "pin-project-lite", + "tokio", + "tower-service", +] + +[[package]] +name = "hyper-util" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96547c2556ec9d12fb1578c4eaf448b04993e7fb79cbaad930a656880a6bdfa0" +dependencies = [ + "bytes", + "futures-channel", + "futures-util", + "http", + "http-body", + "hyper", + "libc", + "pin-project-lite", + "socket2", + "tokio", + "tower-service", + "tracing", +] + +[[package]] +name = "iana-time-zone" +version = "0.1.65" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e31bc9ad994ba00e440a8aa5c9ef0ec67d5cb5e5cb0cc7f8b744a35b389cc470" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "log", + "wasm-bindgen", + "windows-core", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + +[[package]] +name = "indexmap" +version = "2.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d466e9454f08e4a911e14806c24e16fba1b4c121d1ea474396f396069cf949d9" +dependencies = [ + "equivalent", + "hashbrown 0.17.1", +] + +[[package]] +name = "itertools" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b192c782037fadd9cfa75548310488aabdbf3d2da73885b31bd0abd03351285" +dependencies = [ + "either", +] + +[[package]] +name = "itoa" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f42a60cbdf9a97f5d2305f08a87dc4e09308d1276d28c869c684d7777685682" + +[[package]] +name = "js-sys" +version = "0.3.103" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53b44bfcdb3f8d5837a46dae1ca9660a837176eee74a28b229bc626816589102" +dependencies = [ + "cfg-if", + "futures-util", + "wasm-bindgen", +] + +[[package]] +name = "libc" +version = "0.2.186" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68ab91017fe16c622486840e4c83c9a37afeff978bd239b5293d61ece587de66" + +[[package]] +name = "linux-raw-sys" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a66949e030da00e8c7d4434b251670a91556f4144941d37452769c25d58a53" + +[[package]] +name = "log" +version = "0.4.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ceec5bc11778974d1bcb055b18002eba7f4b3518b6a0081b3af5f21666da9ad" + +[[package]] +name = "matchit" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" + +[[package]] +name = "memchr" +version = "2.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88904434abc2901f197fe8cc55f0445e7ded921dba5911dad2e2b39b48e663c4" + +[[package]] +name = "mime" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" + +[[package]] +name = "mio" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02bd0af71c67b473010cbbc60715ee815645a4dc942899111f494b4b737d6fda" +dependencies = [ + "libc", + "wasi", + "windows-sys", +] + +[[package]] +name = "multimap" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d87ecb2933e8aeadb3e3a02b828fed80a7528047e68b4f424523a0981a3a084" + +[[package]] +name = "nemo-relay-types" +version = "0.5.0" +dependencies = [ + "bitflags", + "chrono", + "serde", + "serde_json", + "typed-builder", + "uuid", +] + +[[package]] +name = "nemo-relay-worker" +version = "0.5.0" +dependencies = [ + "futures-util", + "hyper-util", + "nemo-relay-types", + "nemo-relay-worker-proto", + "serde", + "serde_json", + "thiserror", + "tokio", + "tokio-stream", + "tonic", + "tower", +] + +[[package]] +name = "nemo-relay-worker-plugin-fixture" +version = "0.0.0" +dependencies = [ + "nemo-relay-worker", + "serde_json", + "tokio", + "tokio-stream", +] + +[[package]] +name = "nemo-relay-worker-proto" +version = "0.5.0" +dependencies = [ + "prost", + "prost-build", + "protoc-bin-vendored", + "serde", + "serde_json", + "tonic", + "tonic-prost", + "tonic-prost-build", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + +[[package]] +name = "once_cell" +version = "1.21.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50" + +[[package]] +name = "percent-encoding" +version = "2.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" + +[[package]] +name = "petgraph" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8701b58ea97060d5e5b155d383a69952a60943f0e6dfe30b04c287beb0b27455" +dependencies = [ + "fixedbitset", + "hashbrown 0.15.5", + "indexmap", +] + +[[package]] +name = "pin-project" +version = "1.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2466b2336ed02bcdca6b294417127b90ec92038d1d5c4fbeac971a922e0e0924" +dependencies = [ + "pin-project-internal", +] + +[[package]] +name = "pin-project-internal" +version = "1.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c96395f0a926bc13b1c17622aaddda1ecb55d49c8f1bf9777e4d877800a43f8b" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "pin-project-lite" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a89322df9ebe1c1578d689c92318e070967d1042b512afbe49518723f4e6d5cd" + +[[package]] +name = "prettyplease" +version = "0.2.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" +dependencies = [ + "proc-macro2", + "syn", +] + +[[package]] +name = "proc-macro2" +version = "1.0.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "prost" +version = "0.14.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "528ac67416ff8646872a3c02cad9cc4ee5dc9f9540c9b10771855c95cb2e5ae1" +dependencies = [ + "bytes", + "prost-derive", +] + +[[package]] +name = "prost-build" +version = "0.14.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03da047801ff44bb6a4d407d4860c05fd70bb81714e6b2f3812603d5b145b042" +dependencies = [ + "heck", + "itertools", + "log", + "multimap", + "petgraph", + "prettyplease", + "prost", + "prost-types", + "pulldown-cmark", + "pulldown-cmark-to-cmark", + "regex", + "syn", + "tempfile", +] + +[[package]] +name = "prost-derive" +version = "0.14.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b570b25f7617e43d59005d0990ccb79e950a423952cea19671b7a876da390adf" +dependencies = [ + "anyhow", + "itertools", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "prost-types" +version = "0.14.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f94967dc7688f3054c7fac87473ffae4cc4c3904800e2d9f5b857246d8963b0a" +dependencies = [ + "prost", +] + +[[package]] +name = "protoc-bin-vendored" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1c381df33c98266b5f08186583660090a4ffa0889e76c7e9a5e175f645a67fa" +dependencies = [ + "protoc-bin-vendored-linux-aarch_64", + "protoc-bin-vendored-linux-ppcle_64", + "protoc-bin-vendored-linux-s390_64", + "protoc-bin-vendored-linux-x86_32", + "protoc-bin-vendored-linux-x86_64", + "protoc-bin-vendored-macos-aarch_64", + "protoc-bin-vendored-macos-x86_64", + "protoc-bin-vendored-win32", +] + +[[package]] +name = "protoc-bin-vendored-linux-aarch_64" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c350df4d49b5b9e3ca79f7e646fde2377b199e13cfa87320308397e1f37e1a4c" + +[[package]] +name = "protoc-bin-vendored-linux-ppcle_64" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a55a63e6c7244f19b5c6393f025017eb5d793fd5467823a099740a7a4222440c" + +[[package]] +name = "protoc-bin-vendored-linux-s390_64" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1dba5565db4288e935d5330a07c264a4ee8e4a5b4a4e6f4e83fad824cc32f3b0" + +[[package]] +name = "protoc-bin-vendored-linux-x86_32" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8854774b24ee28b7868cd71dccaae8e02a2365e67a4a87a6cd11ee6cdbdf9cf5" + +[[package]] +name = "protoc-bin-vendored-linux-x86_64" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b38b07546580df720fa464ce124c4b03630a6fb83e05c336fea2a241df7e5d78" + +[[package]] +name = "protoc-bin-vendored-macos-aarch_64" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "89278a9926ce312e51f1d999fee8825d324d603213344a9a706daa009f1d8092" + +[[package]] +name = "protoc-bin-vendored-macos-x86_64" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81745feda7ccfb9471d7a4de888f0652e806d5795b61480605d4943176299756" + +[[package]] +name = "protoc-bin-vendored-win32" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95067976aca6421a523e491fce939a3e65249bac4b977adee0ee9771568e8aa3" + +[[package]] +name = "pulldown-cmark" +version = "0.13.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e9f068eba8e7071c5f9511831b44f32c740d5adf574e990f946ddb53db2f314e" +dependencies = [ + "bitflags", + "memchr", + "unicase", +] + +[[package]] +name = "pulldown-cmark-to-cmark" +version = "22.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50793def1b900256624a709439404384204a5dc3a6ec580281bfaac35e882e90" +dependencies = [ + "pulldown-cmark", +] + +[[package]] +name = "quote" +version = "1.0.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dfbc457d0c7a0759a614551b11a6409e5951f6c7537be1f1b7682b9ae9230368" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "r-efi" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" + +[[package]] +name = "r-efi" +version = "6.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8dcc9c7d52a811697d2151c701e0d08956f92b0e24136cf4cf27b57a6a0d9bf" + +[[package]] +name = "regex" +version = "1.12.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1292b7759ae1cb9ec195452d1390a074f0cd8541ab7a5a8c31cd6db45d4a6ba" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e1dd4122fc1595e8162618945476892eefca7b88c52820e74af6262213cae8f" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6f6ff9a378485b298a5286656da665ba74413d36db0979633275d2e708145d4" + +[[package]] +name = "rustix" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6fe4565b9518b83ef4f91bb47ce29620ca828bd32cb7e408f0062e9930ba190" +dependencies = [ + "bitflags", + "errno", + "libc", + "linux-raw-sys", + "windows-sys", +] + +[[package]] +name = "rustversion" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" + +[[package]] +name = "serde" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.150" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8014e44b4736ed0538adeecded0fce2a272f22dc9578a7eb6b2d9993c74cfb9" +dependencies = [ + "itoa", + "memchr", + "serde", + "serde_core", + "zmij", +] + +[[package]] +name = "shlex" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8fadd59c855ef2080decdef8ff161eb6661b86933c9d82e5ba29dc602a55aba" + +[[package]] +name = "slab" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c790de23124f9ab44544d7ac05d60440adc586479ce501c1d6d7da3cd8c9cf5" + +[[package]] +name = "smallvec" +version = "1.15.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ed6a63f02c8539c91a8685a86f4099661ba3da017932f6ebbea6de3f0fa7c90" + +[[package]] +name = "socket2" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52d1cfed4120b4d927bf7c0f86d2087a4a7d6027c906d9f9d525a80573b9be51" +dependencies = [ + "libc", + "windows-sys", +] + +[[package]] +name = "syn" +version = "2.0.118" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b9ae57f904213ebb649ce6895b8a66c66f0203b9319718f69a5612a065b1422" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "sync_wrapper" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263" + +[[package]] +name = "tempfile" +version = "3.27.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32497e9a4c7b38532efcdebeef879707aa9f794296a4f0244f6f69e9bc8574bd" +dependencies = [ + "fastrand", + "getrandom 0.4.3", + "once_cell", + "rustix", + "windows-sys", +] + +[[package]] +name = "thiserror" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4288b5bcbc7920c07a1149a35cf9590a2aa808e0bc1eafaade0b80947865fbc4" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc4ee7f67670e9b64d05fa4253e753e016c6c95ff35b89b7941d6b856dec1d5" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tokio" +version = "1.52.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fc7f01b389ac15039e4dc9531aa973a135d7a4135281b12d7c1bc79fd57fffe" +dependencies = [ + "bytes", + "libc", + "mio", + "pin-project-lite", + "socket2", + "tokio-macros", + "windows-sys", +] + +[[package]] +name = "tokio-macros" +version = "2.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "385a6cb71ab9ab790c5fe8d67f1645e6c450a7ce006a33de03daa956cf70a496" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tokio-stream" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32da49809aab5c3bc678af03902d4ccddea2a87d028d86392a4b1560c6906c70" +dependencies = [ + "futures-core", + "pin-project-lite", + "tokio", + "tokio-util", +] + +[[package]] +name = "tokio-util" +version = "0.7.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ae9cec805b01e8fc3fd2fe289f89149a9b66dd16786abd8b19cfa7b48cb0098" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", +] + +[[package]] +name = "tonic" +version = "0.14.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac2a5518c70fa84342385732db33fb3f44bc4cc748936eb5833d2df34d6445ef" +dependencies = [ + "async-trait", + "axum", + "base64", + "bytes", + "h2", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-timeout", + "hyper-util", + "percent-encoding", + "pin-project", + "socket2", + "sync_wrapper", + "tokio", + "tokio-stream", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tonic-build" +version = "0.14.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c68f61875ac5293cf72e6c8cf0158086428c82c37229e98c840878f1706b0322" +dependencies = [ + "prettyplease", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tonic-prost" +version = "0.14.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50849f68853be452acf590cde0b146665b8d507b3b8af17261df47e02c209ea0" +dependencies = [ + "bytes", + "prost", + "tonic", +] + +[[package]] +name = "tonic-prost-build" +version = "0.14.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "654e5643eff75d7f8c99197ce1440ed19a3474eada74c12bbac488b2cafdae27" +dependencies = [ + "prettyplease", + "proc-macro2", + "prost-build", + "prost-types", + "quote", + "syn", + "tempfile", + "tonic-build", +] + +[[package]] +name = "tower" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebe5ef63511595f1344e2d5cfa636d973292adc0eec1f0ad45fae9f0851ab1d4" +dependencies = [ + "futures-core", + "futures-util", + "indexmap", + "pin-project-lite", + "slab", + "sync_wrapper", + "tokio", + "tokio-util", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tower-layer" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" + +[[package]] +name = "tower-service" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" + +[[package]] +name = "tracing" +version = "0.1.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" +dependencies = [ + "pin-project-lite", + "tracing-attributes", + "tracing-core", +] + +[[package]] +name = "tracing-attributes" +version = "0.1.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tracing-core" +version = "0.1.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a" +dependencies = [ + "once_cell", +] + +[[package]] +name = "try-lock" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" + +[[package]] +name = "typed-builder" +version = "0.23.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31aa81521b70f94402501d848ccc0ecaa8f93c8eb6999eb9747e72287757ffda" +dependencies = [ + "typed-builder-macro", +] + +[[package]] +name = "typed-builder-macro" +version = "0.23.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "076a02dc54dd46795c2e9c8282ed40bcfb1e22747e955de9389a1de28190fb26" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "unicase" +version = "2.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbc4bc3a9f746d862c45cb89d705aa10f187bb96c76001afab07a0d35ce60142" + +[[package]] +name = "unicode-ident" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" + +[[package]] +name = "uuid" +version = "1.18.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f87b8aa10b915a06587d0dec516c282ff295b475d94abf425d62b57710070a2" +dependencies = [ + "getrandom 0.3.4", + "js-sys", + "serde", + "wasm-bindgen", +] + +[[package]] +name = "want" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa7760aed19e106de2c7c0b581b509f2f25d3dacaf737cb82ac61bc6d760b0e" +dependencies = [ + "try-lock", +] + +[[package]] +name = "wasi" +version = "0.11.1+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" + +[[package]] +name = "wasip2" +version = "1.0.4+wasi-0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b67efb37e106e55ce722a510d6b5f9c17f083e5fc79afc2badeb12cc313d9487" +dependencies = [ + "wit-bindgen", +] + +[[package]] +name = "wasm-bindgen" +version = "0.2.126" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b067c0c11094aef6b7a801c1e34a26affafdf3d051dba08456b868789aaf9a4" +dependencies = [ + "cfg-if", + "once_cell", + "rustversion", + "wasm-bindgen-macro", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.126" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "167ce5e579f6bcf889c4f7175a8a5a585de84e8ff93976ce393efa5f2837aab1" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.126" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3997c7839262f4ef12cf90b818d6340c18e80f263f1a94bf157d0ec4420380e" +dependencies = [ + "bumpalo", + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.126" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc1b4cb0cc549fcf58d7dfc081778139b3d283a081644e833e84682ad71cea24" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "windows-core" +version = "0.62.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8e83a14d34d0623b51dce9581199302a221863196a1dde71a7663a4c2be9deb" +dependencies = [ + "windows-implement", + "windows-interface", + "windows-link", + "windows-result", + "windows-strings", +] + +[[package]] +name = "windows-implement" +version = "0.60.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "053e2e040ab57b9dc951b72c264860db7eb3b0200ba345b4e4c3b14f67855ddf" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "windows-interface" +version = "0.59.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f316c4a2570ba26bbec722032c4099d8c8bc095efccdc15688708623367e358" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" + +[[package]] +name = "windows-result" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7781fa89eaf60850ac3d2da7af8e5242a5ea78d1a11c49bf2910bb5a73853eb5" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-strings" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7837d08f69c77cf6b07689544538e017c1bfcf57e34b4c0ff58e6c2cd3b37091" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-sys" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" +dependencies = [ + "windows-link", +] + +[[package]] +name = "wit-bindgen" +version = "0.57.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ebf944e87a7c253233ad6766e082e3cd714b5d03812acc24c318f549614536e" + +[[package]] +name = "zmij" +version = "1.0.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa" diff --git a/crates/core/tests/fixtures/worker_plugin/Cargo.toml b/crates/core/tests/fixtures/worker_plugin/Cargo.toml new file mode 100644 index 00000000..4b9c4a90 --- /dev/null +++ b/crates/core/tests/fixtures/worker_plugin/Cargo.toml @@ -0,0 +1,21 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +[workspace] + +[package] +name = "nemo-relay-worker-plugin-fixture" +version = "0.0.0" +edition = "2024" +license = "Apache-2.0" +publish = false + +[[bin]] +name = "nemo-relay-worker-plugin-fixture" +path = "src/main.rs" + +[dependencies] +nemo-relay-worker = { path = "../../../../worker" } +serde_json = "1" +tokio = { version = "1", features = ["macros", "rt-multi-thread"] } +tokio-stream = "0.1" diff --git a/crates/core/tests/fixtures/worker_plugin/src/main.rs b/crates/core/tests/fixtures/worker_plugin/src/main.rs new file mode 100644 index 00000000..c6e1e250 --- /dev/null +++ b/crates/core/tests/fixtures/worker_plugin/src/main.rs @@ -0,0 +1,294 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use nemo_relay_worker::{ + JsonStream, LlmNext, LlmStreamNext, PluginContext, ScopeType, ToolNext, WorkerPlugin, + WorkerSdkError, serve_plugin, +}; +use nemo_relay_worker::{ConfigDiagnostic, DiagnosticLevel, Json, LlmRequest}; +use serde_json::json; + +struct FixtureWorkerPlugin; + +impl WorkerPlugin for FixtureWorkerPlugin { + fn plugin_id(&self) -> &str { + if std::env::var("FIXTURE_WORKER_PLUGIN_ID").as_deref() == Ok("other_worker") { + return "other_worker"; + } + "fixture_worker" + } + + fn validate(&self, config: &Json) -> Vec { + if config + .get("exit_in_validate") + .and_then(Json::as_bool) + .unwrap_or(false) + { + std::process::exit(42); + } + if config + .get("reject") + .and_then(Json::as_bool) + .unwrap_or(false) + { + return vec![ConfigDiagnostic { + level: DiagnosticLevel::Error, + code: "fixture.rejected".into(), + component: Some("fixture_worker".into()), + field: Some("reject".into()), + message: "fixture rejection requested".into(), + }]; + } + Vec::new() + } + + fn register(&self, ctx: &mut PluginContext, config: &Json) -> nemo_relay_worker::Result<()> { + let register_error = config + .get("register_error") + .and_then(Json::as_bool) + .unwrap_or(false); + let exit_in_register = config + .get("exit_in_register") + .and_then(Json::as_bool) + .unwrap_or(false); + if exit_in_register { + std::process::exit(43); + } + if register_error { + return Err(WorkerSdkError::Callback( + "fixture registration error requested".into(), + )); + } + + let empty_registration_name = config + .get("empty_registration_name") + .and_then(Json::as_bool) + .unwrap_or(false); + if empty_registration_name { + ctx.register_subscriber("", |_| {}); + return Ok(()); + } + + let block_tool = config + .get("block_tool") + .and_then(Json::as_bool) + .unwrap_or(false); + let tool_request_error = config + .get("tool_request_error") + .and_then(Json::as_bool) + .unwrap_or(false); + let llm_request_error = config + .get("llm_request_error") + .and_then(Json::as_bool) + .unwrap_or(false); + let llm_stream_open_error = config + .get("llm_stream_open_error") + .and_then(Json::as_bool) + .unwrap_or(false); + + let runtime = ctx + .runtime() + .ok_or_else(|| WorkerSdkError::Callback("runtime handle missing".into()))?; + + ctx.register_subscriber("fixture_subscriber", { + let runtime = runtime.clone(); + move |event| { + if event.name() == "worker-plugin-test-outer" { + let runtime = runtime.clone(); + let _ = tokio::task::block_in_place(|| { + tokio::runtime::Handle::current().block_on(async move { + runtime + .emit_mark( + "fixture.worker.subscriber.mark", + Some(json!("subscriber")), + None, + ) + .await + }) + }); + } + } + }); + + ctx.register_tool_sanitize_request_guardrail( + "fixture_tool_sanitize_request", + 0, + |_name, args| mark_json(args, "worker_plugin_tool_sanitize_request"), + ); + ctx.register_tool_sanitize_response_guardrail( + "fixture_tool_sanitize_response", + 0, + |_name, result| mark_json(result, "worker_plugin_tool_sanitize_response"), + ); + ctx.register_tool_conditional_execution_guardrail( + "fixture_tool_conditional", + 0, + move |_name, _args| { + if block_tool { + Ok(Some("fixture tool blocked".into())) + } else { + Ok(None) + } + }, + ); + ctx.register_tool_request_intercept("fixture_rewrite_args", 0, false, { + let runtime = runtime.clone(); + move |_name, args| { + if tool_request_error { + return Err(WorkerSdkError::Callback( + "fixture tool request error requested".into(), + )); + } + let runtime = runtime.clone(); + tokio::task::block_in_place(|| { + tokio::runtime::Handle::current().block_on(emit_runtime_events(runtime)) + })?; + Ok(mark_json(args, "worker_plugin")) + } + }); + ctx.register_tool_execution_intercept( + "fixture_tool_execution", + 0, + |_name, args, next: ToolNext| async move { + let result = next + .call(mark_json(args, "worker_plugin_tool_execution_request")) + .await?; + Ok(mark_json(result, "worker_plugin_tool_execution")) + }, + ); + + ctx.register_llm_sanitize_request_guardrail( + "fixture_llm_sanitize_request", + 0, + |request| mark_llm_request(request, "worker_plugin_llm_sanitize_request"), + ); + ctx.register_llm_sanitize_response_guardrail( + "fixture_llm_sanitize_response", + 0, + |response| mark_json(response, "worker_plugin_llm_sanitize_response"), + ); + ctx.register_llm_conditional_execution_guardrail( + "fixture_llm_conditional", + 0, + |_request| Ok(None), + ); + ctx.register_llm_request_intercept( + "fixture_llm_request_intercept", + 0, + false, + move |_name, request, annotated| { + if llm_request_error { + return Err(WorkerSdkError::Callback( + "fixture LLM request error requested".into(), + )); + } + let annotated = annotated.map(|mut annotated| { + annotated + .extra + .insert("worker_plugin_annotated_request".into(), json!(true)); + annotated + }); + Ok((mark_llm_request(request, "worker_plugin_llm_request_intercept"), annotated)) + }, + ); + ctx.register_llm_execution_intercept( + "fixture_llm_execution", + 0, + |_name, request, next: LlmNext| async move { + let response = next + .call(mark_llm_request( + request, + "worker_plugin_llm_execution_request", + )) + .await?; + Ok(mark_json(response, "worker_plugin_llm_execution")) + }, + ); + ctx.register_llm_stream_execution_intercept( + "fixture_llm_stream_execution", + 0, + move |_name, request, next: LlmStreamNext| async move { + if llm_stream_open_error { + return Err(WorkerSdkError::Callback( + "fixture LLM stream open error requested".into(), + )); + } + let stream = next + .call(mark_llm_request( + request, + "worker_plugin_llm_stream_execution_request", + )) + .await?; + let mapped: JsonStream = Box::pin(tokio_stream::StreamExt::map(stream, |chunk| { + chunk.map(|value| mark_json(value, "worker_plugin_llm_stream_execution")) + })); + Ok(mapped) + }, + ); + + Ok(()) + } +} + +async fn emit_runtime_events(runtime: nemo_relay_worker::PluginRuntime) -> nemo_relay_worker::Result<()> { + runtime + .emit_mark("fixture.worker.mark", Some(json!("current")), None) + .await?; + let scope = runtime + .push_scope( + None, + "fixture.worker.scope", + ScopeType::Custom, + None, + None, + Some(json!("current-scope-input")), + ) + .await?; + runtime + .pop_scope(&scope, Some(json!("current-scope-output")), None) + .await?; + + let isolated = runtime.create_scope_stack().await?; + let isolated_scope = runtime + .push_scope( + Some(&isolated), + "fixture.worker.isolated.scope", + ScopeType::Custom, + None, + None, + Some(json!("isolated-input")), + ) + .await?; + let isolated_runtime = runtime.clone(); + runtime + .with_scope_stack(&isolated, || async move { + isolated_runtime + .emit_mark("fixture.worker.isolated.mark", Some(json!("isolated")), None) + .await + }) + .await?; + runtime + .pop_scope(&isolated_scope, Some(json!("isolated-output")), None) + .await?; + runtime.drop_scope_stack(&isolated).await +} + +fn mark_llm_request(mut request: LlmRequest, key: &str) -> LlmRequest { + request.content = mark_json(request.content, key); + request +} + +fn mark_json(mut value: Json, key: &str) -> Json { + if let Json::Object(object) = &mut value { + object.insert(key.into(), json!(true)); + } + value +} + +#[tokio::main] +async fn main() { + if let Err(error) = serve_plugin(FixtureWorkerPlugin).await { + eprintln!("fixture worker failed: {error}"); + std::process::exit(1); + } +} diff --git a/crates/core/tests/integration/worker_plugin_tests.rs b/crates/core/tests/integration/worker_plugin_tests.rs new file mode 100644 index 00000000..6b0fa173 --- /dev/null +++ b/crates/core/tests/integration/worker_plugin_tests.rs @@ -0,0 +1,1041 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Integration coverage for gRPC worker dynamic plugins. + +use std::path::{Path, PathBuf}; +use std::process::Command; +use std::sync::{Arc, Mutex, OnceLock}; + +use futures::StreamExt; +use nemo_relay::api::event::{Event, ScopeCategory}; +use nemo_relay::api::llm::{ + LlmCallExecuteParams, LlmRequest, LlmStreamCallExecuteParams, llm_call_execute, + llm_stream_call_execute, +}; +use nemo_relay::api::runtime::{TASK_SCOPE_STACK, create_scope_stack}; +use nemo_relay::api::scope::{PopScopeParams, PushScopeParams, ScopeType, pop_scope, push_scope}; +use nemo_relay::api::subscriber::{flush_subscribers, register_subscriber}; +use nemo_relay::api::tool::{ToolCallExecuteParams, tool_call_execute, tool_request_intercepts}; +use nemo_relay::codec::request::AnnotatedLlmRequest; +use nemo_relay::codec::traits::LlmCodec; +use nemo_relay::error::Result as FlowResult; +use nemo_relay::plugin::dynamic::{ + WorkerPluginActivation, WorkerPluginLoadSpec, load_worker_plugins, +}; +use nemo_relay::plugin::{ + PluginComponentSpec, PluginConfig, clear_plugin_configuration, initialize_plugins_exact, +}; +use serde_json::{Map, Value as Json, json}; +use tempfile::TempDir; +use uuid::Uuid; + +static WORKER_PLUGIN_TEST_LOCK: tokio::sync::Mutex<()> = tokio::sync::Mutex::const_new(()); + +#[test] +fn worker_activation_with_no_specs_is_empty() { + let activation = load_worker_plugins(Vec::::new()) + .expect("empty worker activation should succeed"); + assert!(activation.is_empty()); + activation.clear(); +} + +#[tokio::test] +async fn rust_worker_registers_and_invokes_all_current_surfaces() { + let _guard = WORKER_PLUGIN_TEST_LOCK.lock().await; + let loaded = load_and_initialize_fixture(Map::new()).await; + + let events = Arc::new(Mutex::new(Vec::::new())); + let captured = events.clone(); + register_subscriber( + "worker_plugin_fixture_events", + Arc::new(move |event| { + captured.lock().unwrap().push(event.clone()); + }), + ) + .expect("test subscriber should register"); + + let stack = create_scope_stack(); + let (outer_uuid, rewritten, tool_result) = TASK_SCOPE_STACK + .scope(stack, async { + let outer = push_scope( + PushScopeParams::builder() + .name("worker-plugin-test-outer") + .scope_type(ScopeType::Agent) + .build(), + ) + .expect("outer scope should push"); + let outer_uuid = outer.uuid; + let rewritten = tool_request_intercepts("demo_tool", json!({ "input": "value" })) + .expect("worker request intercept should run"); + let tool_result = tool_call_execute( + ToolCallExecuteParams::builder() + .name("worker-fixture-tool") + .args(json!({ "input": "execute" })) + .func(Arc::new(|args| { + Box::pin(async move { Ok(json!({ "tool_callback": true, "args": args })) }) + })) + .build(), + ) + .await + .expect("worker tool middleware should run"); + pop_scope(PopScopeParams::builder().handle_uuid(&outer.uuid).build()) + .expect("outer scope should pop"); + (outer_uuid, rewritten, tool_result) + }) + .await; + + assert_eq!(rewritten["worker_plugin"], true); + assert_eq!(tool_result["tool_callback"], true); + assert_eq!(tool_result["worker_plugin_tool_execution"], true); + assert_eq!( + tool_result["args"]["worker_plugin_tool_execution_request"], + true + ); + + flush_subscribers().expect("worker fixture events should flush"); + let captured_events = events.lock().unwrap().clone(); + assert_parent( + &captured_events, + "fixture.worker.mark", + None, + Some(outer_uuid), + ); + assert_parent( + &captured_events, + "fixture.worker.scope", + Some(ScopeCategory::Start), + Some(outer_uuid), + ); + assert_not_parent( + &captured_events, + "fixture.worker.isolated.scope", + Some(ScopeCategory::Start), + outer_uuid, + ); + let isolated_scope = find_event( + &captured_events, + "fixture.worker.isolated.scope", + Some(ScopeCategory::Start), + ); + let isolated_mark = find_event(&captured_events, "fixture.worker.isolated.mark", None); + assert_eq!(isolated_mark.parent_uuid(), Some(isolated_scope.uuid())); + assert_ne!( + isolated_mark.parent_uuid(), + Some(outer_uuid), + "worker isolated mark should use the plugin-selected isolated stack" + ); + let tool_start = find_event( + &captured_events, + "worker-fixture-tool", + Some(ScopeCategory::Start), + ); + assert_eq!( + tool_start.input().unwrap()["worker_plugin_tool_sanitize_request"], + true + ); + let tool_end = find_event( + &captured_events, + "worker-fixture-tool", + Some(ScopeCategory::End), + ); + assert_eq!( + tool_end.output().unwrap()["worker_plugin_tool_sanitize_response"], + true + ); + + let llm_execute_response = llm_call_execute( + LlmCallExecuteParams::builder() + .name("worker-fixture-llm-execute") + .request(LlmRequest { + headers: Map::new(), + content: json!({ "prompt": "managed" }), + }) + .func(Arc::new(|request| { + Box::pin(async move { + Ok(json!({ + "id": "managed-response", + "request": request.content, + "llm_callback": true + })) + }) + })) + .build(), + ) + .await + .expect("worker LLM middleware should run"); + assert_eq!(llm_execute_response["llm_callback"], true); + assert_eq!(llm_execute_response["worker_plugin_llm_execution"], true); + assert_eq!( + llm_execute_response["request"]["worker_plugin_llm_execution_request"], + true + ); + flush_subscribers().expect("worker fixture LLM events should flush"); + let captured_events = events.lock().unwrap().clone(); + find_event(&captured_events, "fixture.worker.subscriber.mark", None); + let llm_start = find_event( + &captured_events, + "worker-fixture-llm-execute", + Some(ScopeCategory::Start), + ); + assert_eq!( + llm_start.input().unwrap()["content"]["worker_plugin_llm_sanitize_request"], + true + ); + let llm_end = find_event( + &captured_events, + "worker-fixture-llm-execute", + Some(ScopeCategory::End), + ); + assert_eq!( + llm_end.output().unwrap()["worker_plugin_llm_sanitize_response"], + true + ); + + let stream_values = llm_stream_call_execute( + LlmStreamCallExecuteParams::builder() + .name("worker-fixture-llm-stream") + .request(LlmRequest { + headers: Map::new(), + content: json!({ "prompt": "stream" }), + }) + .func(Arc::new(|request| { + Box::pin(async move { + let first = json!({ + "chunk": 1, + "request": request.content, + }); + Ok(Box::pin(tokio_stream::iter(vec![Ok(first)])) as _) + }) + })) + .collector(Box::new(|_chunk| Ok(()))) + .finalizer(Box::new(|| json!({ "done": true }))) + .build(), + ) + .await + .expect("worker stream middleware should start") + .collect::>() + .await; + let stream_value = stream_values + .into_iter() + .next() + .expect("one stream chunk should be returned") + .expect("stream chunk should succeed"); + assert_eq!(stream_value["worker_plugin_llm_stream_execution"], true); + assert_eq!( + stream_value["request"]["worker_plugin_llm_stream_execution_request"], + true + ); + + loaded.clear(); +} + +#[tokio::test] +async fn worker_request_intercept_callback_error_surfaces_to_host() { + let _guard = WORKER_PLUGIN_TEST_LOCK.lock().await; + let loaded = + load_and_initialize_fixture(Map::from_iter([("tool_request_error".into(), json!(true))])) + .await; + + let error = tool_request_intercepts("demo_tool", json!({ "input": "value" })) + .expect_err("worker callback error should surface"); + assert!( + error + .to_string() + .contains("fixture tool request error requested"), + "{error}" + ); + + loaded.clear(); +} + +#[tokio::test] +async fn worker_conditional_guardrail_blocks_tool_execution() { + let _guard = WORKER_PLUGIN_TEST_LOCK.lock().await; + let loaded = + load_and_initialize_fixture(Map::from_iter([("block_tool".into(), json!(true))])).await; + + let error = tool_call_execute( + ToolCallExecuteParams::builder() + .name("worker-fixture-blocked-tool") + .args(json!({ "input": "blocked" })) + .func(Arc::new(|_| { + Box::pin(async move { Ok(json!({ "should_not_run": true })) }) + })) + .build(), + ) + .await + .expect_err("worker guardrail should block tool execution"); + assert!( + error.to_string().contains("fixture tool blocked"), + "{error}" + ); + + loaded.clear(); +} + +#[tokio::test] +async fn worker_llm_request_intercept_round_trips_annotations() { + let _guard = WORKER_PLUGIN_TEST_LOCK.lock().await; + let loaded = load_and_initialize_fixture(Map::new()).await; + + let response = llm_call_execute( + LlmCallExecuteParams::builder() + .name("worker-fixture-llm-annotated") + .request(LlmRequest { + headers: Map::new(), + content: json!({ "prompt": "annotated" }), + }) + .codec(Arc::new(FixtureCodec)) + .func(Arc::new(|request| { + Box::pin(async move { + Ok(json!({ + "request": request.content, + "llm_callback": true + })) + }) + })) + .build(), + ) + .await + .expect("worker LLM request intercept should preserve annotations"); + assert_eq!(response["llm_callback"], true); + assert_eq!(response["request"]["worker_plugin_annotated_request"], true); + + loaded.clear(); +} + +#[tokio::test] +async fn worker_llm_request_intercept_callback_error_surfaces_to_host() { + let _guard = WORKER_PLUGIN_TEST_LOCK.lock().await; + let loaded = + load_and_initialize_fixture(Map::from_iter([("llm_request_error".into(), json!(true))])) + .await; + + let error = llm_call_execute( + LlmCallExecuteParams::builder() + .name("worker-fixture-llm-error") + .request(LlmRequest { + headers: Map::new(), + content: json!({ "prompt": "error" }), + }) + .func(Arc::new(|request| { + Box::pin(async move { + Ok(json!({ + "request": request.content, + "should_not_complete": true + })) + }) + })) + .build(), + ) + .await + .expect_err("worker LLM request intercept error should surface"); + assert!( + error + .to_string() + .contains("fixture LLM request error requested"), + "{error}" + ); + + loaded.clear(); +} + +#[tokio::test] +async fn worker_llm_stream_open_error_surfaces_to_host() { + let _guard = WORKER_PLUGIN_TEST_LOCK.lock().await; + let loaded = load_and_initialize_fixture(Map::from_iter([( + "llm_stream_open_error".into(), + json!(true), + )])) + .await; + + let mut stream = llm_stream_call_execute( + LlmStreamCallExecuteParams::builder() + .name("worker-fixture-llm-stream-error") + .request(LlmRequest { + headers: Map::new(), + content: json!({ "prompt": "stream-error" }), + }) + .func(Arc::new(|request| { + Box::pin(async move { + let chunk = json!({ "request": request.content }); + Ok(Box::pin(tokio_stream::iter(vec![Ok(chunk)])) as _) + }) + })) + .collector(Box::new(|_chunk| Ok(()))) + .finalizer(Box::new(|| json!({ "done": true }))) + .build(), + ) + .await + .expect("worker stream invoke should return a host stream"); + let error = stream + .next() + .await + .expect("stream should yield the worker error") + .expect_err("worker stream callback error should surface"); + assert!( + error + .to_string() + .contains("fixture LLM stream open error requested"), + "{error}" + ); + + loaded.clear(); +} + +#[tokio::test] +async fn worker_validation_diagnostics_prevent_initialization() { + let _guard = WORKER_PLUGIN_TEST_LOCK.lock().await; + let fixture = build_fixture_worker(); + let (_manifest_dir, manifest_ref) = write_manifest(fixture.binary_path()); + let config = Map::from_iter([("reject".into(), json!(true))]); + + let activation = load_worker_plugins([WorkerPluginLoadSpec { + plugin_id: "fixture_worker".into(), + manifest_ref: manifest_ref.to_string_lossy().into_owned(), + config: config.clone(), + }]) + .expect("worker plugin should load with validation diagnostics"); + + let mut plugin_config = PluginConfig::default(); + plugin_config.components.push(PluginComponentSpec { + kind: "fixture_worker".into(), + enabled: true, + config, + }); + let error = initialize_plugins_exact(plugin_config) + .await + .expect_err("validation diagnostics should prevent initialization") + .to_string(); + assert!(error.contains("fixture rejection requested"), "{error}"); + + clear_plugin_configuration().expect("worker plugin config should clear"); + activation.clear(); +} + +#[tokio::test] +async fn worker_duplicate_component_rejected_for_single_instance_plugin() { + let _guard = WORKER_PLUGIN_TEST_LOCK.lock().await; + let fixture = build_fixture_worker(); + let (_manifest_dir, manifest_ref) = write_manifest(fixture.binary_path()); + + let activation = load_worker_plugins([WorkerPluginLoadSpec { + plugin_id: "fixture_worker".into(), + manifest_ref: manifest_ref.to_string_lossy().into_owned(), + config: Map::new(), + }]) + .expect("worker plugin should load"); + + let mut plugin_config = PluginConfig::default(); + plugin_config.components.push(PluginComponentSpec { + kind: "fixture_worker".into(), + enabled: true, + config: Map::new(), + }); + plugin_config.components.push(PluginComponentSpec { + kind: "fixture_worker".into(), + enabled: true, + config: Map::new(), + }); + let error = initialize_plugins_exact(plugin_config) + .await + .expect_err("single-instance worker plugin should reject duplicate components") + .to_string(); + assert!(error.contains("may only appear once"), "{error}"); + + clear_plugin_configuration().expect("worker plugin config should clear"); + activation.clear(); +} + +#[tokio::test] +async fn worker_config_mismatch_prevents_initialization() { + let _guard = WORKER_PLUGIN_TEST_LOCK.lock().await; + let fixture = build_fixture_worker(); + let (_manifest_dir, manifest_ref) = write_manifest(fixture.binary_path()); + + let activation = load_worker_plugins([WorkerPluginLoadSpec { + plugin_id: "fixture_worker".into(), + manifest_ref: manifest_ref.to_string_lossy().into_owned(), + config: Map::new(), + }]) + .expect("worker plugin should load"); + + let mut plugin_config = PluginConfig::default(); + plugin_config.components.push(PluginComponentSpec { + kind: "fixture_worker".into(), + enabled: true, + config: Map::from_iter([("changed".into(), json!(true))]), + }); + let error = initialize_plugins_exact(plugin_config) + .await + .expect_err("config drift should prevent initialization") + .to_string(); + assert!(error.contains("config changed"), "{error}"); + + clear_plugin_configuration().expect("worker plugin config should clear"); + activation.clear(); +} + +#[tokio::test] +async fn worker_registration_error_fails_activation() { + let _guard = WORKER_PLUGIN_TEST_LOCK.lock().await; + let fixture = build_fixture_worker(); + let (_manifest_dir, manifest_ref) = write_manifest(fixture.binary_path()); + + let error = match load_worker_plugins([WorkerPluginLoadSpec { + plugin_id: "fixture_worker".into(), + manifest_ref: manifest_ref.to_string_lossy().into_owned(), + config: Map::from_iter([("register_error".into(), json!(true))]), + }]) { + Ok(activation) => { + activation.clear(); + panic!("worker registration error should fail activation"); + } + Err(error) => error.to_string(), + }; + assert!( + error.contains("fixture registration error requested"), + "{error}" + ); +} + +#[tokio::test] +async fn worker_invalid_registration_plan_fails_activation() { + let _guard = WORKER_PLUGIN_TEST_LOCK.lock().await; + let fixture = build_fixture_worker(); + let (_manifest_dir, manifest_ref) = write_manifest(fixture.binary_path()); + + let error = match load_worker_plugins([WorkerPluginLoadSpec { + plugin_id: "fixture_worker".into(), + manifest_ref: manifest_ref.to_string_lossy().into_owned(), + config: Map::from_iter([("empty_registration_name".into(), json!(true))]), + }]) { + Ok(activation) => { + activation.clear(); + panic!("empty registration name should fail activation"); + } + Err(error) => error.to_string(), + }; + assert!(error.contains("empty local_name"), "{error}"); +} + +#[tokio::test] +async fn worker_handshake_plugin_id_mismatch_reports_config_error() { + let _guard = WORKER_PLUGIN_TEST_LOCK.lock().await; + let _env = EnvVarGuard::set("FIXTURE_WORKER_PLUGIN_ID", "other_worker"); + let fixture = build_fixture_worker(); + let (_manifest_dir, manifest_ref) = write_manifest(fixture.binary_path()); + + let error = match load_worker_plugins([WorkerPluginLoadSpec { + plugin_id: "fixture_worker".into(), + manifest_ref: manifest_ref.to_string_lossy().into_owned(), + config: Map::new(), + }]) { + Ok(activation) => { + activation.clear(); + panic!("worker handshake id mismatch should fail activation"); + } + Err(error) => error.to_string(), + }; + assert!(error.contains("returned id 'other_worker'"), "{error}"); +} + +#[tokio::test] +async fn worker_validation_rpc_failure_reports_activation_error() { + let _guard = WORKER_PLUGIN_TEST_LOCK.lock().await; + let fixture = build_fixture_worker(); + let (_manifest_dir, manifest_ref) = write_manifest(fixture.binary_path()); + + let error = match load_worker_plugins([WorkerPluginLoadSpec { + plugin_id: "fixture_worker".into(), + manifest_ref: manifest_ref.to_string_lossy().into_owned(), + config: Map::from_iter([("exit_in_validate".into(), json!(true))]), + }]) { + Ok(activation) => { + activation.clear(); + panic!("worker validation process exit should fail activation"); + } + Err(error) => error.to_string(), + }; + assert!(error.contains("worker validation RPC failed"), "{error}"); +} + +#[tokio::test] +async fn worker_registration_rpc_failure_reports_activation_error() { + let _guard = WORKER_PLUGIN_TEST_LOCK.lock().await; + let fixture = build_fixture_worker(); + let (_manifest_dir, manifest_ref) = write_manifest(fixture.binary_path()); + + let error = match load_worker_plugins([WorkerPluginLoadSpec { + plugin_id: "fixture_worker".into(), + manifest_ref: manifest_ref.to_string_lossy().into_owned(), + config: Map::from_iter([("exit_in_register".into(), json!(true))]), + }]) { + Ok(activation) => { + activation.clear(); + panic!("worker registration process exit should fail activation"); + } + Err(error) => error.to_string(), + }; + assert!(error.contains("worker registration RPC failed"), "{error}"); +} + +#[test] +fn missing_worker_executable_reports_startup_error() { + let _guard = WORKER_PLUGIN_TEST_LOCK.blocking_lock(); + let missing_binary = std::env::temp_dir().join(format!("missing-worker-{}", Uuid::now_v7())); + let (_manifest_dir, manifest_ref) = write_manifest(&missing_binary); + + let error = match load_worker_plugins([WorkerPluginLoadSpec { + plugin_id: "fixture_worker".into(), + manifest_ref: manifest_ref.to_string_lossy().into_owned(), + config: Map::new(), + }]) { + Ok(activation) => { + activation.clear(); + panic!("missing worker executable should fail activation"); + } + Err(error) => error.to_string(), + }; + assert!(error.contains("failed to spawn"), "{error}"); +} + +#[test] +fn worker_manifest_id_mismatch_reports_config_error() { + let _guard = WORKER_PLUGIN_TEST_LOCK.blocking_lock(); + let missing_binary = std::env::temp_dir().join(format!("unused-worker-{}", Uuid::now_v7())); + let (_manifest_dir, manifest_ref) = write_manifest(&missing_binary); + + let error = match load_worker_plugins([WorkerPluginLoadSpec { + plugin_id: "different_worker".into(), + manifest_ref: manifest_ref.to_string_lossy().into_owned(), + config: Map::new(), + }]) { + Ok(activation) => { + activation.clear(); + panic!("manifest id mismatch should fail"); + } + Err(error) => error.to_string(), + }; + assert!(error.contains("does not match expected id"), "{error}"); +} + +#[test] +fn worker_manifest_kind_mismatch_reports_config_error() { + let _guard = WORKER_PLUGIN_TEST_LOCK.blocking_lock(); + let relay = supported_relay_requirement(); + let (_manifest_dir, manifest_ref) = write_manifest_text(&format!( + r#" +manifest_version = 1 + +[plugin] +id = "fixture_worker" +kind = "rust_dynamic" + +[compat] +relay = {relay} +native_api = "1" + +[defaults] +enabled = false + +[capabilities] +items = ["plugin_native"] + +[load] +library = "missing" +symbol = "nemo_relay_plugin_entry" +"#, + relay = toml_string(&relay) + )); + + let error = match load_worker_plugins([WorkerPluginLoadSpec { + plugin_id: "fixture_worker".into(), + manifest_ref: manifest_ref.to_string_lossy().into_owned(), + config: Map::new(), + }]) { + Ok(activation) => { + activation.clear(); + panic!("manifest kind mismatch should fail"); + } + Err(error) => error.to_string(), + }; + assert!( + error.contains("worker loader only supports worker"), + "{error}" + ); +} + +#[test] +fn unsupported_worker_relay_requirement_reports_compatibility_error() { + let _guard = WORKER_PLUGIN_TEST_LOCK.blocking_lock(); + let missing_binary = std::env::temp_dir().join(format!("unused-worker-{}", Uuid::now_v7())); + let (_manifest_dir, manifest_ref) = + write_manifest_with_relay(&missing_binary, ">=9999.0,<10000.0"); + + let error = match load_worker_plugins([WorkerPluginLoadSpec { + plugin_id: "fixture_worker".into(), + manifest_ref: manifest_ref.to_string_lossy().into_owned(), + config: Map::new(), + }]) { + Ok(activation) => { + activation.clear(); + panic!("unsupported relay requirement should fail"); + } + Err(error) => error.to_string(), + }; + assert!(error.contains("requires relay"), "{error}"); +} + +#[test] +fn invalid_worker_relay_requirement_reports_parse_error() { + let _guard = WORKER_PLUGIN_TEST_LOCK.blocking_lock(); + let missing_binary = std::env::temp_dir().join(format!("unused-worker-{}", Uuid::now_v7())); + let (_manifest_dir, manifest_ref) = write_manifest_with_relay(&missing_binary, "not semver"); + + let error = match load_worker_plugins([WorkerPluginLoadSpec { + plugin_id: "fixture_worker".into(), + manifest_ref: manifest_ref.to_string_lossy().into_owned(), + config: Map::new(), + }]) { + Ok(activation) => { + activation.clear(); + panic!("invalid relay requirement should fail"); + } + Err(error) => error.to_string(), + }; + assert!(error.contains("invalid compat.relay"), "{error}"); +} + +#[test] +fn command_worker_entrypoint_is_resolved_relative_to_manifest() { + let _guard = WORKER_PLUGIN_TEST_LOCK.blocking_lock(); + let relay = supported_relay_requirement(); + let (manifest_dir, manifest_ref) = + write_worker_manifest("fixture_worker", &relay, "command", "missing-worker"); + + let error = match load_worker_plugins([WorkerPluginLoadSpec { + plugin_id: "fixture_worker".into(), + manifest_ref: manifest_ref.to_string_lossy().into_owned(), + config: Map::new(), + }]) { + Ok(activation) => { + activation.clear(); + panic!("missing relative command worker should fail"); + } + Err(error) => error.to_string(), + }; + assert!(error.contains("failed to spawn command worker"), "{error}"); + assert_error_mentions_manifest_relative_entrypoint( + &error, + manifest_dir.path(), + "missing-worker", + ); +} + +#[test] +fn python_worker_uses_configured_interpreter() { + let _guard = WORKER_PLUGIN_TEST_LOCK.blocking_lock(); + let missing_python = std::env::temp_dir().join(format!("missing-python-{}", Uuid::now_v7())); + let _env = EnvVarGuard::set( + "NEMO_RELAY_PYTHON", + missing_python.to_string_lossy().as_ref(), + ); + let relay = supported_relay_requirement(); + let (_manifest_dir, manifest_ref) = write_worker_manifest( + "fixture_worker", + &relay, + "python", + "fixture_worker:create_plugin", + ); + + let error = match load_worker_plugins([WorkerPluginLoadSpec { + plugin_id: "fixture_worker".into(), + manifest_ref: manifest_ref.to_string_lossy().into_owned(), + config: Map::new(), + }]) { + Ok(activation) => { + activation.clear(); + panic!("missing configured Python interpreter should fail"); + } + Err(error) => error.to_string(), + }; + assert!(error.contains("failed to spawn python worker"), "{error}"); +} + +struct FixtureCodec; + +impl LlmCodec for FixtureCodec { + fn decode(&self, request: &LlmRequest) -> FlowResult { + Ok(AnnotatedLlmRequest { + messages: Vec::new(), + model: Some("fixture-model".into()), + params: None, + tools: None, + tool_choice: None, + store: None, + previous_response_id: None, + truncation: None, + reasoning: None, + include: None, + user: None, + metadata: None, + service_tier: None, + parallel_tool_calls: None, + max_output_tokens: None, + max_tool_calls: None, + top_logprobs: None, + stream: None, + extra: request.content.as_object().cloned().unwrap_or_default(), + }) + } + + fn encode( + &self, + annotated: &AnnotatedLlmRequest, + original: &LlmRequest, + ) -> FlowResult { + Ok(LlmRequest { + headers: original.headers.clone(), + content: Json::Object(annotated.extra.clone()), + }) + } +} + +struct LoadedWorker { + activation: Option, + _manifest_dir: TempDir, +} + +impl LoadedWorker { + fn clear(mut self) { + clear_plugin_configuration().expect("worker plugin config should clear"); + if let Some(activation) = self.activation.take() { + activation.clear(); + } + } +} + +impl Drop for LoadedWorker { + fn drop(&mut self) { + let _ = clear_plugin_configuration(); + if let Some(activation) = self.activation.take() { + activation.clear(); + } + } +} + +async fn load_and_initialize_fixture(config: Map) -> LoadedWorker { + let fixture = build_fixture_worker(); + let (manifest_dir, manifest_ref) = write_manifest(fixture.binary_path()); + + let activation = load_worker_plugins([WorkerPluginLoadSpec { + plugin_id: "fixture_worker".into(), + manifest_ref: manifest_ref.to_string_lossy().into_owned(), + config: config.clone(), + }]) + .expect("worker plugin should load"); + + let mut plugin_config = PluginConfig::default(); + plugin_config.components.push(PluginComponentSpec { + kind: "fixture_worker".into(), + enabled: true, + config, + }); + initialize_plugins_exact(plugin_config) + .await + .expect("worker plugin should initialize"); + + LoadedWorker { + activation: Some(activation), + _manifest_dir: manifest_dir, + } +} + +struct BuiltWorkerFixture { + binary_path: PathBuf, +} + +impl BuiltWorkerFixture { + fn binary_path(&self) -> &Path { + &self.binary_path + } +} + +fn build_fixture_worker() -> BuiltWorkerFixture { + static FIXTURE_BINARY: OnceLock = OnceLock::new(); + let binary_path = FIXTURE_BINARY.get_or_init(|| { + let fixture_dir = fixture_root(); + let target_root = + Path::new(env!("CARGO_MANIFEST_DIR")).join("../../target/worker-plugin-fixture"); + let target_dir = target_root.join("target"); + let manifest = fixture_dir.join("Cargo.toml"); + let status = Command::new("cargo") + .arg("build") + .arg("--quiet") + .arg("--locked") + .arg("--manifest-path") + .arg(&manifest) + .arg("--target-dir") + .arg(&target_dir) + .status() + .expect("fixture worker build should start"); + assert!(status.success(), "fixture worker build should succeed"); + let binary_path = target_dir.join("debug").join(format!( + "nemo-relay-worker-plugin-fixture{}", + std::env::consts::EXE_SUFFIX + )); + assert!(binary_path.exists(), "fixture worker binary should exist"); + binary_path + }); + BuiltWorkerFixture { + binary_path: binary_path.clone(), + } +} + +fn write_manifest(binary: &Path) -> (TempDir, PathBuf) { + let relay = supported_relay_requirement(); + write_manifest_with_relay(binary, &relay) +} + +fn write_manifest_with_relay(binary: &Path, relay: &str) -> (TempDir, PathBuf) { + write_worker_manifest( + "fixture_worker", + relay, + "rust", + binary.to_string_lossy().as_ref(), + ) +} + +fn write_worker_manifest( + plugin_id: &str, + relay: &str, + runtime: &str, + entrypoint: &str, +) -> (TempDir, PathBuf) { + write_manifest_text(&format!( + r#" +manifest_version = 1 + +[plugin] +id = {plugin_id} +kind = "worker" + +[compat] +relay = {relay} +worker_protocol = "grpc-v1" + +[defaults] +enabled = false + +[capabilities] +items = ["plugin_worker"] + +[load] +runtime = {runtime} +entrypoint = {entrypoint} +"#, + plugin_id = toml_string(plugin_id), + relay = toml_string(relay), + runtime = toml_string(runtime), + entrypoint = toml_string(entrypoint) + )) +} + +fn write_manifest_text(contents: &str) -> (TempDir, PathBuf) { + let temp = TempDir::new().expect("manifest tempdir should be created"); + let manifest = temp.path().join("relay-plugin.toml"); + std::fs::write(&manifest, contents).expect("manifest should be written"); + (temp, manifest) +} + +fn toml_string(value: &str) -> String { + format!("{value:?}") +} + +fn supported_relay_requirement() -> String { + format!("={}", env!("CARGO_PKG_VERSION")) +} + +fn assert_error_mentions_manifest_relative_entrypoint( + error: &str, + manifest_dir: &Path, + entrypoint: &str, +) { + let manifest_dir_name = manifest_dir + .file_name() + .expect("manifest dir should have a leaf name") + .to_string_lossy(); + let manifest_dir_pos = error.find(manifest_dir_name.as_ref()).unwrap_or_else(|| { + panic!("error did not mention manifest dir '{manifest_dir_name}': {error}") + }); + assert!( + error[manifest_dir_pos + manifest_dir_name.len()..].contains(entrypoint), + "error did not mention entrypoint '{entrypoint}' after manifest dir '{manifest_dir_name}': {error}" + ); +} + +struct EnvVarGuard { + key: &'static str, + previous: Option, +} + +impl EnvVarGuard { + fn set(key: &'static str, value: &str) -> Self { + let previous = std::env::var(key).ok(); + // SAFETY: this module serializes worker tests with WORKER_PLUGIN_TEST_LOCK. + unsafe { + std::env::set_var(key, value); + } + Self { key, previous } + } +} + +impl Drop for EnvVarGuard { + fn drop(&mut self) { + // SAFETY: this module serializes worker tests with WORKER_PLUGIN_TEST_LOCK. + unsafe { + if let Some(previous) = &self.previous { + std::env::set_var(self.key, previous); + } else { + std::env::remove_var(self.key); + } + } + } +} + +fn fixture_root() -> PathBuf { + PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/fixtures/worker_plugin") +} + +fn find_event<'a>( + events: &'a [Event], + name: &str, + scope_category: Option, +) -> &'a Event { + events + .iter() + .find(|event| event.name() == name && event.scope_category() == scope_category) + .unwrap_or_else(|| panic!("event {name:?} with category {scope_category:?} not found")) +} + +fn assert_parent( + events: &[Event], + name: &str, + scope_category: Option, + expected_parent: Option, +) { + let event = find_event(events, name, scope_category); + assert_eq!(event.parent_uuid(), expected_parent); +} + +fn assert_not_parent( + events: &[Event], + name: &str, + scope_category: Option, + excluded_parent: Uuid, +) { + let event = find_event(events, name, scope_category); + assert_ne!(event.parent_uuid(), Some(excluded_parent)); +} diff --git a/crates/core/tests/unit/dynamic_worker_tests.rs b/crates/core/tests/unit/dynamic_worker_tests.rs new file mode 100644 index 00000000..aa77c043 --- /dev/null +++ b/crates/core/tests/unit/dynamic_worker_tests.rs @@ -0,0 +1,1235 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use std::sync::{Arc, Mutex}; + +use crate::api::event::{BaseEvent, MarkEvent}; +use nemo_relay_worker_proto::json_envelope; +use nemo_relay_worker_proto::v1::invoke_response::Result as InvokeResult; +use nemo_relay_worker_proto::v1::plugin_worker_server::{PluginWorker, PluginWorkerServer}; +use nemo_relay_worker_proto::v1::stream_chunk::Item as StreamItem; +use nemo_relay_worker_proto::v1::{ + CancelInvocationRequest, CreateScopeStackRequest, DropScopeStackRequest, EmitMarkRequest, + EmptyResult, GuardrailResult, HandshakeRequest, HandshakeResponse, HealthRequest, + HealthResponse, JsonEnvelope, JsonResult, LlmNextRequest, LlmRequestInterceptResult, + LlmStreamNextRequest, PopScopeRequest, PushScopeRequest, Registration, ScopeContext, + ScopeType as ProtoScopeType, ShutdownRequest, StreamChunk, ToolNextRequest, ValidateRequest, + ValidateResponse, WorkerAck, +}; +use serde_json::json; +use tokio_stream::StreamExt; +use tokio_stream::wrappers::TcpListenerStream; +use tonic::Request; +use tonic::transport::Server; + +use super::*; + +const ACTIVATION_ID: &str = "activation-test"; +const AUTH_TOKEN: &str = "auth-test"; + +#[test] +fn response_helpers_cover_error_and_unexpected_shapes() { + let worker_error = WorkerError { + code: "worker.failed".into(), + message: "boom".into(), + retryable: false, + }; + + let error = json_from_invoke_response(InvokeResponse { + result: Some(InvokeResult::Json(JsonResult { + value: None, + error: Some(worker_error.clone()), + })), + }) + .expect_err("json result worker error should surface"); + assert!(error.to_string().contains("worker.failed: boom")); + + let error = json_from_invoke_response(InvokeResponse { + result: Some(InvokeResult::Error(worker_error.clone())), + }) + .expect_err("top-level worker error should surface"); + assert!(error.to_string().contains("worker.failed: boom")); + + let error = json_from_invoke_response(InvokeResponse { + result: Some(InvokeResult::Empty(EmptyResult {})), + }) + .expect_err("unexpected JSON result shape should fail"); + assert!(error.to_string().contains("unexpected invoke result")); + + let error = json_from_invoke_response(InvokeResponse { + result: Some(InvokeResult::Json(JsonResult { + value: Some(JsonEnvelope { + schema: JSON_SCHEMA.into(), + json: b"{".to_vec(), + }), + error: None, + })), + }) + .expect_err("invalid JSON envelope should fail"); + assert!(error.to_string().contains("invalid JSON result")); + + assert_eq!( + guardrail_from_invoke_response(InvokeResponse { + result: Some(InvokeResult::Guardrail(GuardrailResult { + block_reason: String::new(), + })), + }) + .expect("empty block reason is allowed"), + None + ); + assert_eq!( + guardrail_from_invoke_response(InvokeResponse { + result: Some(InvokeResult::Guardrail(GuardrailResult { + block_reason: "blocked".into(), + })), + }) + .expect("block reason should parse"), + Some("blocked".into()) + ); + assert!( + guardrail_from_invoke_response(InvokeResponse { + result: Some(InvokeResult::Error(worker_error.clone())), + }) + .expect_err("guardrail worker error should surface") + .to_string() + .contains("worker.failed") + ); + assert!( + guardrail_from_invoke_response(InvokeResponse { + result: Some(InvokeResult::Empty(EmptyResult {})), + }) + .expect_err("unexpected guardrail shape should fail") + .to_string() + .contains("guardrail returned unexpected") + ); + + assert!( + json_from_stream_chunk(StreamChunk { + item: Some(StreamItem::Error(worker_error.clone())), + }) + .expect_err("stream worker error should surface") + .to_string() + .contains("worker.failed") + ); + assert!( + json_from_stream_chunk(StreamChunk { + item: Some(StreamItem::Value(JsonEnvelope { + schema: JSON_SCHEMA.into(), + json: b"{".to_vec(), + })), + }) + .expect_err("invalid stream JSON envelope should fail") + .to_string() + .contains("invalid worker stream chunk") + ); + assert!( + json_from_stream_chunk(StreamChunk { item: None }) + .expect_err("empty stream chunk should fail") + .to_string() + .contains("stream chunk was empty") + ); +} + +#[test] +fn envelope_and_error_helpers_cover_failure_paths() { + assert!( + required_envelope(None, "required test") + .expect_err("missing envelope should fail") + .to_string() + .contains("required test is missing") + ); + assert!( + optional_envelope_to_json(Some(JsonEnvelope { + schema: JSON_SCHEMA.into(), + json: b"not-json".to_vec(), + })) + .expect_err("invalid optional envelope should fail") + .to_string() + .contains("invalid JSON envelope") + ); + + let ack = host_ack(Err(FlowError::Internal("host failed".into()))); + assert!(!ack.ok); + assert_eq!(ack.error.expect("host error").code, "host.runtime_error"); + + let result = json_result(Err(FlowError::Internal("json failed".into()))); + assert!(result.value.is_none()); + assert_eq!(result.error.expect("json error").code, "host.runtime_error"); + + let fallback = worker_error_to_plugin( + WorkerError { + code: "worker.empty".into(), + message: String::new(), + retryable: false, + }, + "fallback message", + ); + assert!(fallback.to_string().contains("fallback message")); + + let status = status_from_flow(FlowError::Internal("status failed".into())); + assert_eq!(status.code(), tonic::Code::Internal); + assert!(status.message().contains("status failed")); +} + +#[test] +fn registration_plan_and_scope_type_helpers_validate_edges() { + let empty_name = validate_registration_plan( + "fixture_worker", + &RegisterResponse { + registrations: vec![Registration { + local_name: " ".into(), + surface: RegistrationSurface::Subscriber as i32, + priority: 0, + break_chain: false, + }], + error: None, + }, + ) + .expect_err("empty registration names should fail"); + assert!(empty_name.to_string().contains("empty local_name")); + + let unsupported = validate_registration_plan( + "fixture_worker", + &RegisterResponse { + registrations: vec![Registration { + local_name: "bad".into(), + surface: 999, + priority: 0, + break_chain: false, + }], + error: None, + }, + ) + .expect_err("unsupported registration surfaces should fail"); + assert!( + unsupported + .to_string() + .contains("unsupported registration surface") + ); + + let unspecified = validate_registration_plan( + "fixture_worker", + &RegisterResponse { + registrations: vec![Registration { + local_name: "bad".into(), + surface: RegistrationSurface::Unspecified as i32, + priority: 0, + break_chain: false, + }], + error: None, + }, + ) + .expect_err("unspecified registration surfaces should fail"); + assert!( + unspecified + .to_string() + .contains("unspecified registration surface") + ); + + let cases = [ + (ProtoScopeType::Agent, crate::api::scope::ScopeType::Agent), + ( + ProtoScopeType::Function, + crate::api::scope::ScopeType::Function, + ), + (ProtoScopeType::Tool, crate::api::scope::ScopeType::Tool), + (ProtoScopeType::Llm, crate::api::scope::ScopeType::Llm), + ( + ProtoScopeType::Retriever, + crate::api::scope::ScopeType::Retriever, + ), + ( + ProtoScopeType::Embedder, + crate::api::scope::ScopeType::Embedder, + ), + ( + ProtoScopeType::Reranker, + crate::api::scope::ScopeType::Reranker, + ), + ( + ProtoScopeType::Guardrail, + crate::api::scope::ScopeType::Guardrail, + ), + ( + ProtoScopeType::Evaluator, + crate::api::scope::ScopeType::Evaluator, + ), + (ProtoScopeType::Custom, crate::api::scope::ScopeType::Custom), + ( + ProtoScopeType::Unknown, + crate::api::scope::ScopeType::Unknown, + ), + ]; + for (proto, expected) in cases { + assert_eq!(proto_scope_type(proto as i32), expected); + } + assert_eq!(proto_scope_type(999), crate::api::scope::ScopeType::Custom); +} + +#[test] +fn relay_compatibility_and_blocking_helpers_cover_local_edges() { + assert!( + validate_relay_compatibility(None) + .expect_err("missing relay compatibility should fail") + .to_string() + .contains("compat.relay is required") + ); + assert!( + validate_relay_compatibility(Some("not semver")) + .expect_err("invalid relay compatibility should fail") + .to_string() + .contains("invalid compat.relay") + ); + + let runtime = RuntimeBuilder::new_current_thread() + .enable_all() + .build() + .expect("runtime should build"); + assert_eq!(block_on_runtime(&runtime, async { 42 }), 42); +} + +#[test] +#[cfg(unix)] +fn worker_endpoints_fail_when_host_socket_cannot_bind() { + let activation_dir = std::env::temp_dir().join(format!("nmrw-unit-{}", Uuid::now_v7())); + let host_socket = activation_dir.join("host.sock"); + std::fs::create_dir_all(&host_socket).expect("host socket directory should be created"); + + let error = match WorkerEndpoints::new(&activation_dir) { + Ok(_) => panic!("endpoint creation should fail when host socket path is a directory"), + Err(error) => error, + }; + assert!( + error + .to_string() + .contains("failed to bind worker host runtime socket") + ); + + let _ = std::fs::remove_dir_all(&activation_dir); +} + +#[tokio::test(flavor = "multi_thread")] +async fn callback_helpers_cover_worker_response_edges() { + let worker_error = WorkerError { + code: "worker.failed".into(), + message: "boom".into(), + retryable: false, + }; + let (callback, _shutdown) = fake_callback_service({ + let worker_error = worker_error.clone(); + move |request| match request.registration_name.as_str() { + "subscriber_error" => InvokeResponse { + result: Some(InvokeResult::Error(worker_error.clone())), + }, + "subscriber_unexpected" | "llm_intercept_unexpected" => InvokeResponse { + result: Some(InvokeResult::Json(JsonResult { + value: Some(json_envelope(JSON_SCHEMA, &json!({})).expect("json envelope")), + error: None, + })), + }, + "llm_json_invalid" => InvokeResponse { + result: Some(InvokeResult::Json(JsonResult { + value: Some(json_envelope(JSON_SCHEMA, &json!(null)).expect("json envelope")), + error: None, + })), + }, + "llm_intercept_invalid_request" => InvokeResponse { + result: Some(InvokeResult::LlmRequest(LlmRequestInterceptResult { + request: Some(JsonEnvelope { + schema: LLM_REQUEST_SCHEMA.into(), + json: b"null".to_vec(), + }), + annotated_request: None, + has_annotated_request: false, + })), + }, + "llm_intercept_missing_annotated" => InvokeResponse { + result: Some(InvokeResult::LlmRequest(LlmRequestInterceptResult { + request: Some(valid_llm_request_envelope()), + annotated_request: None, + has_annotated_request: true, + })), + }, + "llm_intercept_invalid_annotated" => InvokeResponse { + result: Some(InvokeResult::LlmRequest(LlmRequestInterceptResult { + request: Some(valid_llm_request_envelope()), + annotated_request: Some(JsonEnvelope { + schema: ANNOTATED_LLM_REQUEST_SCHEMA.into(), + json: b"null".to_vec(), + }), + has_annotated_request: true, + })), + }, + "llm_intercept_error" => InvokeResponse { + result: Some(InvokeResult::Error(worker_error.clone())), + }, + _ => InvokeResponse { + result: Some(InvokeResult::Empty(EmptyResult {})), + }, + } + }) + .await; + let event = Event::Mark(MarkEvent::new( + BaseEvent::builder().name("callback-edge").build(), + None, + None, + )); + + let error = callback + .invoke_subscriber("subscriber_error", &event) + .expect_err("subscriber worker error should surface"); + assert!(error.to_string().contains("worker.failed: boom")); + + let error = callback + .invoke_subscriber("subscriber_unexpected", &event) + .expect_err("unexpected subscriber result should fail"); + assert!(error.to_string().contains("subscriber returned unexpected")); + + let error = callback + .invoke_llm_request_json( + "llm_json_invalid", + RegistrationSurface::LlmSanitizeRequestGuardrail, + "model", + valid_llm_request(), + None, + None, + ) + .expect_err("invalid LLM JSON result should fail"); + assert!(error.to_string().contains("invalid type")); + + let error = callback + .invoke_llm_request_intercept( + "llm_intercept_invalid_request", + "model", + valid_llm_request(), + None, + ) + .expect_err("invalid LLM intercept request should fail"); + assert!(error.to_string().contains("invalid LLM request")); + + let error = callback + .invoke_llm_request_intercept( + "llm_intercept_missing_annotated", + "model", + valid_llm_request(), + None, + ) + .expect_err("missing annotated request should fail when flagged present"); + assert!( + error + .to_string() + .contains("llm request intercept annotated request is missing") + ); + + let error = callback + .invoke_llm_request_intercept( + "llm_intercept_invalid_annotated", + "model", + valid_llm_request(), + None, + ) + .expect_err("invalid annotated request should fail"); + assert!(error.to_string().contains("invalid annotated LLM request")); + + let error = callback + .invoke_llm_request_intercept("llm_intercept_error", "model", valid_llm_request(), None) + .expect_err("LLM intercept worker error should surface"); + assert!(error.to_string().contains("worker.failed: boom")); + + let error = callback + .invoke_llm_request_intercept( + "llm_intercept_unexpected", + "model", + valid_llm_request(), + None, + ) + .expect_err("unexpected LLM intercept result should fail"); + assert!( + error + .to_string() + .contains("LLM request intercept returned unexpected") + ); +} + +#[tokio::test(flavor = "multi_thread")] +async fn callback_stream_transport_error_surfaces_to_host_stream() { + let (callback, _shutdown) = fake_callback_service(|_| InvokeResponse { + result: Some(InvokeResult::Empty(EmptyResult {})), + }) + .await; + + let mut stream = callback + .invoke_llm_stream_execution( + "stream_transport_error", + "model", + valid_llm_request(), + Arc::new(|_request| { + Box::pin(async { Ok(Box::pin(tokio_stream::empty()) as LlmJsonStream) }) + }), + ) + .await + .expect("host stream should be returned"); + + let error = stream + .next() + .await + .expect("transport error should be yielded") + .expect_err("stream transport error should surface"); + assert!(error.to_string().contains("worker stream transport failed")); +} + +#[tokio::test(flavor = "multi_thread")] +async fn callback_stream_stops_when_host_receiver_is_dropped() { + let (yield_tx, yield_rx) = oneshot::channel(); + let (stream_dropped_tx, stream_dropped_rx) = oneshot::channel(); + let stream_dropped_tx = Arc::new(Mutex::new(Some(stream_dropped_tx))); + let yield_rx = Arc::new(Mutex::new(Some(yield_rx))); + let (callback, _shutdown) = fake_callback_service_with_stream( + |_| InvokeResponse { + result: Some(InvokeResult::Empty(EmptyResult {})), + }, + { + let stream_dropped_tx = stream_dropped_tx.clone(); + let yield_rx = yield_rx.clone(); + move |_| { + let dropped = stream_dropped_tx + .lock() + .expect("stream drop signal lock should not be poisoned") + .take() + .expect("test stream should be created once"); + let yield_rx = yield_rx + .lock() + .expect("stream yield signal lock should not be poisoned") + .take() + .expect("test stream should be created once"); + Box::pin(SignalChunkThenPendingStream { + yield_rx, + dropped: Some(dropped), + yielded: false, + }) as FakeInvokeStream + } + }, + ) + .await; + + let stream = callback + .invoke_llm_stream_execution( + "stream_receiver_drop", + "model", + valid_llm_request(), + Arc::new(|_request| { + Box::pin(async { Ok(Box::pin(tokio_stream::empty()) as LlmJsonStream) }) + }), + ) + .await + .expect("host stream should be returned"); + drop(stream); + yield_tx + .send(()) + .expect("worker stream yield signal should be delivered"); + tokio::time::timeout(std::time::Duration::from_secs(1), stream_dropped_rx) + .await + .expect("worker stream should be dropped after host receiver is dropped") + .expect("worker stream drop signal should be delivered"); +} + +#[tokio::test(flavor = "multi_thread")] +async fn install_registrations_covers_registry_error_edges() { + for surface in [ + RegistrationSurface::Subscriber, + RegistrationSurface::ToolSanitizeRequestGuardrail, + RegistrationSurface::ToolSanitizeResponseGuardrail, + RegistrationSurface::ToolConditionalExecutionGuardrail, + RegistrationSurface::ToolRequestIntercept, + RegistrationSurface::ToolExecutionIntercept, + RegistrationSurface::LlmSanitizeRequestGuardrail, + RegistrationSurface::LlmSanitizeResponseGuardrail, + RegistrationSurface::LlmConditionalExecutionGuardrail, + RegistrationSurface::LlmRequestIntercept, + RegistrationSurface::LlmExecutionIntercept, + RegistrationSurface::LlmStreamExecutionIntercept, + ] { + let (instance, _shutdown) = fake_worker_instance(vec![ + registration(surface, "duplicate"), + registration(surface, "duplicate"), + ]) + .await; + let mut ctx = PluginRegistrationContext::new(); + let error = instance + .install_registrations(&mut ctx) + .expect_err("duplicate worker registration should fail"); + assert!( + error.to_string().contains("duplicate") + || error.to_string().contains("already registered"), + "{surface:?}: {error}" + ); + let mut registrations = ctx.into_registrations(); + crate::plugin::rollback_registrations(&mut registrations); + } + + let (instance, _shutdown) = fake_worker_instance(vec![Registration { + surface: 999, + ..registration(RegistrationSurface::Subscriber, "bad") + }]) + .await; + let mut ctx = PluginRegistrationContext::new(); + assert!( + instance + .install_registrations(&mut ctx) + .expect_err("unsupported registration surface should fail") + .to_string() + .contains("unsupported registration surface") + ); + + let (instance, _shutdown) = + fake_worker_instance(vec![registration(RegistrationSurface::Unspecified, "bad")]).await; + let mut ctx = PluginRegistrationContext::new(); + assert!( + instance + .install_registrations(&mut ctx) + .expect_err("unspecified registration surface should fail") + .to_string() + .contains("unspecified registration surface") + ); +} + +#[tokio::test(flavor = "multi_thread")] +async fn adapter_register_rejects_config_drift_even_without_validation_call() { + let (instance, _shutdown) = fake_worker_instance(Vec::new()).await; + let adapter = WorkerPluginAdapter { + plugin_kind: "fixture_worker".into(), + allows_multiple_components: false, + instance: Arc::new(instance), + }; + let mut ctx = PluginRegistrationContext::new(); + let changed = serde_json::Map::from_iter([("changed".into(), json!(true))]); + + let error = adapter + .register(&changed, &mut ctx) + .await + .expect_err("config drift should fail registration"); + assert!(error.to_string().contains("config changed"), "{error}"); +} + +#[tokio::test] +async fn host_runtime_service_covers_auth_scope_and_ack_errors() { + let state = Arc::new(WorkerHostRuntimeState::new( + ACTIVATION_ID.into(), + AUTH_TOKEN.into(), + )); + let service = WorkerHostRuntimeService { + state: state.clone(), + }; + + let auth_error = service + .emit_mark(Request::new(EmitMarkRequest { + activation_id: "wrong".into(), + auth_token: AUTH_TOKEN.into(), + name: "auth-failure".into(), + scope: None, + data: None, + metadata: None, + })) + .await + .expect_err("bad activation id should fail auth"); + assert_eq!(auth_error.code(), tonic::Code::PermissionDenied); + + let ack = service + .emit_mark(Request::new(EmitMarkRequest { + activation_id: ACTIVATION_ID.into(), + auth_token: AUTH_TOKEN.into(), + name: "missing-stack".into(), + scope: Some(ScopeContext { + scope_stack_id: "missing-stack".into(), + parent_scope_id: String::new(), + }), + data: None, + metadata: None, + })) + .await + .expect("missing stack should return host ack") + .into_inner(); + assert!(!ack.ok); + assert!( + ack.error + .expect("missing stack error") + .message + .contains("not found") + ); + + let ack = service + .emit_mark(Request::new(EmitMarkRequest { + activation_id: ACTIVATION_ID.into(), + auth_token: AUTH_TOKEN.into(), + name: "no-scope".into(), + scope: None, + data: None, + metadata: None, + })) + .await + .expect("no-scope mark should succeed") + .into_inner(); + assert!(ack.ok); + + let push = service + .push_scope(Request::new(PushScopeRequest { + activation_id: ACTIVATION_ID.into(), + auth_token: AUTH_TOKEN.into(), + scope: None, + name: "invalid-json-scope".into(), + scope_type: ProtoScopeType::Custom as i32, + data: Some(JsonEnvelope { + schema: JSON_SCHEMA.into(), + json: b"not-json".to_vec(), + }), + metadata: None, + input: None, + })) + .await + .expect("invalid JSON should be structured") + .into_inner(); + assert!( + push.error + .expect("push error") + .message + .contains("invalid JSON") + ); + + let pop_error = service + .pop_scope(Request::new(PopScopeRequest { + activation_id: ACTIVATION_ID.into(), + auth_token: AUTH_TOKEN.into(), + scope_handle_id: "missing-scope".into(), + output: None, + metadata: None, + })) + .await + .expect_err("missing scope handle should fail"); + assert_eq!(pop_error.code(), tonic::Code::NotFound); + + let created = service + .create_scope_stack(Request::new(CreateScopeStackRequest { + activation_id: ACTIVATION_ID.into(), + auth_token: AUTH_TOKEN.into(), + })) + .await + .expect("scope stack should be created") + .into_inner(); + let scope_stack_id = created.scope_stack_id.clone(); + assert!( + state + .stack("") + .expect("empty stack id should be valid") + .is_none() + ); + let dropped = service + .drop_scope_stack(Request::new(DropScopeStackRequest { + activation_id: ACTIVATION_ID.into(), + auth_token: AUTH_TOKEN.into(), + scope_stack_id: scope_stack_id.clone(), + })) + .await + .expect("scope stack should be dropped") + .into_inner(); + assert!(dropped.ok); + assert_eq!( + state + .stack(&scope_stack_id) + .expect_err("dropped stack should be removed") + .code(), + tonic::Code::NotFound + ); + + assert_eq!( + service + .with_stack( + Some(&ScopeContext { + scope_stack_id: String::new(), + parent_scope_id: String::new(), + }), + || Ok(7), + ) + .expect("empty explicit stack id should run without binding"), + 7 + ); +} + +#[tokio::test] +async fn host_runtime_service_reports_poisoned_internal_locks() { + let state = Arc::new(WorkerHostRuntimeState::new( + ACTIVATION_ID.into(), + AUTH_TOKEN.into(), + )); + poison_mutex({ + let state = state.clone(); + move || { + let _guard = state.scope_handles.lock().expect("scope handles lock"); + panic!("poison scope handles"); + } + }); + let service = WorkerHostRuntimeService { + state: state.clone(), + }; + let push_error = service + .push_scope(Request::new(PushScopeRequest { + activation_id: ACTIVATION_ID.into(), + auth_token: AUTH_TOKEN.into(), + scope: None, + name: "poisoned".into(), + scope_type: ProtoScopeType::Custom as i32, + data: None, + metadata: None, + input: None, + })) + .await + .expect_err("poisoned scope handle lock should fail"); + assert_eq!(push_error.code(), tonic::Code::Internal); + + let pop_error = service + .pop_scope(Request::new(PopScopeRequest { + activation_id: ACTIVATION_ID.into(), + auth_token: AUTH_TOKEN.into(), + scope_handle_id: "missing".into(), + output: None, + metadata: None, + })) + .await + .expect_err("poisoned scope handle lock should fail"); + assert_eq!(pop_error.code(), tonic::Code::Internal); + + let state = Arc::new(WorkerHostRuntimeState::new( + ACTIVATION_ID.into(), + AUTH_TOKEN.into(), + )); + poison_mutex({ + let state = state.clone(); + move || { + let _guard = state.scope_stacks.lock().expect("scope stacks lock"); + panic!("poison scope stacks"); + } + }); + let service = WorkerHostRuntimeService { state }; + let create_error = service + .create_scope_stack(Request::new(CreateScopeStackRequest { + activation_id: ACTIVATION_ID.into(), + auth_token: AUTH_TOKEN.into(), + })) + .await + .expect_err("poisoned scope stack lock should fail"); + assert_eq!(create_error.code(), tonic::Code::Internal); + + let drop_error = service + .drop_scope_stack(Request::new(DropScopeStackRequest { + activation_id: ACTIVATION_ID.into(), + auth_token: AUTH_TOKEN.into(), + scope_stack_id: "stack".into(), + })) + .await + .expect_err("poisoned scope stack lock should fail"); + assert_eq!(drop_error.code(), tonic::Code::Internal); +} + +#[test] +fn owned_worker_runtime_drop_is_idempotent_when_runtime_already_taken() { + drop(OwnedWorkerRuntime { runtime: None }); +} + +#[tokio::test] +async fn host_runtime_service_covers_continuation_errors_and_stream_items() { + let state = Arc::new(WorkerHostRuntimeState::new( + ACTIVATION_ID.into(), + AUTH_TOKEN.into(), + )); + let service = WorkerHostRuntimeService { + state: state.clone(), + }; + + let llm_continuation = state + .insert_continuation(Continuation::Llm(Arc::new(|request| { + Box::pin(async move { Ok(request.content) }) + }))) + .expect("llm continuation should insert"); + let wrong_type = service + .tool_next(Request::new(ToolNextRequest { + activation_id: ACTIVATION_ID.into(), + auth_token: AUTH_TOKEN.into(), + continuation_id: llm_continuation, + value: Some(json_envelope(JSON_SCHEMA, &json!({})).expect("json envelope")), + })) + .await + .expect_err("wrong continuation type should fail"); + assert_eq!(wrong_type.code(), tonic::Code::InvalidArgument); + + let tool_continuation = state + .insert_continuation(Continuation::Tool(Arc::new(|value| { + Box::pin(async move { Ok(value) }) + }))) + .expect("tool continuation should insert"); + let invalid_tool_json = service + .tool_next(Request::new(ToolNextRequest { + activation_id: ACTIVATION_ID.into(), + auth_token: AUTH_TOKEN.into(), + continuation_id: tool_continuation, + value: Some(JsonEnvelope { + schema: JSON_SCHEMA.into(), + json: b"not-json".to_vec(), + }), + })) + .await + .expect_err("invalid tool next JSON should fail"); + assert_eq!(invalid_tool_json.code(), tonic::Code::InvalidArgument); + + let llm_continuation = state + .insert_continuation(Continuation::Llm(Arc::new(|request| { + Box::pin(async move { Ok(request.content) }) + }))) + .expect("llm continuation should insert"); + let invalid_llm_json = service + .llm_next(Request::new(LlmNextRequest { + activation_id: ACTIVATION_ID.into(), + auth_token: AUTH_TOKEN.into(), + continuation_id: llm_continuation, + request: Some(JsonEnvelope { + schema: LLM_REQUEST_SCHEMA.into(), + json: b"not-json".to_vec(), + }), + })) + .await + .expect_err("invalid LLM next request should fail"); + assert_eq!(invalid_llm_json.code(), tonic::Code::InvalidArgument); + + let stream_continuation = state + .insert_continuation(Continuation::LlmStream(Arc::new(|_request| { + Box::pin(async move { + Ok(Box::pin(tokio_stream::iter(vec![Err(FlowError::Internal( + "stream item failed".into(), + ))])) as LlmJsonStream) + }) + }))) + .expect("stream continuation should insert"); + let stream_response = service + .llm_stream_next(Request::new(LlmStreamNextRequest { + activation_id: ACTIVATION_ID.into(), + auth_token: AUTH_TOKEN.into(), + continuation_id: stream_continuation, + request: Some( + json_envelope( + LLM_REQUEST_SCHEMA, + &LlmRequest { + headers: serde_json::Map::new(), + content: json!({ "prompt": "stream" }), + }, + ) + .expect("llm request envelope"), + ), + })) + .await + .expect("stream next should return stream"); + let mut stream = stream_response.into_inner(); + let chunk = stream + .next() + .await + .expect("stream should yield one item") + .expect("transport should be ok"); + match chunk.item { + Some(StreamItem::Error(error)) => { + assert!(error.message.contains("stream item failed")); + } + other => panic!("expected worker stream error, got {other:?}"), + } + + let stream_continuation = state + .insert_continuation(Continuation::LlmStream(Arc::new(|_request| { + Box::pin(async move { Ok(Box::pin(tokio_stream::empty()) as LlmJsonStream) }) + }))) + .expect("stream continuation should insert"); + let invalid_stream_request = match service + .llm_stream_next(Request::new(LlmStreamNextRequest { + activation_id: ACTIVATION_ID.into(), + auth_token: AUTH_TOKEN.into(), + continuation_id: stream_continuation, + request: Some(JsonEnvelope { + schema: LLM_REQUEST_SCHEMA.into(), + json: b"not-json".to_vec(), + }), + })) + .await + { + Ok(_) => panic!("invalid LLM stream request should fail"), + Err(error) => error, + }; + assert_eq!(invalid_stream_request.code(), tonic::Code::InvalidArgument); +} + +fn valid_llm_request() -> LlmRequest { + LlmRequest { + headers: serde_json::Map::new(), + content: json!({ "prompt": "unit" }), + } +} + +fn valid_llm_request_envelope() -> JsonEnvelope { + json_envelope(LLM_REQUEST_SCHEMA, &valid_llm_request()).expect("llm request envelope") +} + +async fn fake_callback_service( + invoke: impl Fn(InvokeRequest) -> InvokeResponse + Send + Sync + 'static, +) -> (WorkerPluginCallback, oneshot::Sender<()>) { + let (client, shutdown_tx) = fake_worker_client(invoke).await; + callback_for_client(client, shutdown_tx) +} + +async fn fake_callback_service_with_stream( + invoke: impl Fn(InvokeRequest) -> InvokeResponse + Send + Sync + 'static, + invoke_stream: impl Fn(InvokeRequest) -> FakeInvokeStream + Send + Sync + 'static, +) -> (WorkerPluginCallback, oneshot::Sender<()>) { + let (client, shutdown_tx) = fake_worker_client_with_stream(invoke, invoke_stream).await; + callback_for_client(client, shutdown_tx) +} + +fn callback_for_client( + client: PluginWorkerClient, + shutdown_tx: oneshot::Sender<()>, +) -> (WorkerPluginCallback, oneshot::Sender<()>) { + let state = Arc::new(WorkerHostRuntimeState::new( + ACTIVATION_ID.into(), + AUTH_TOKEN.into(), + )); + ( + WorkerPluginCallback { + activation_id: ACTIVATION_ID.into(), + runtime: tokio::runtime::Handle::current(), + client, + host_state: state, + }, + shutdown_tx, + ) +} + +async fn fake_worker_instance( + registrations: Vec, +) -> (WorkerPluginInstance, oneshot::Sender<()>) { + let (client, shutdown_tx) = fake_worker_client(|_| InvokeResponse { + result: Some(InvokeResult::Empty(EmptyResult {})), + }) + .await; + let activation_dir = std::env::temp_dir().join(format!("nmrw-unit-{}", Uuid::now_v7())); + std::fs::create_dir_all(&activation_dir).expect("unit activation dir should be created"); + ( + WorkerPluginInstance { + plugin_kind: "fixture_worker".into(), + allows_multiple_components: false, + config: serde_json::Map::new(), + validation_diagnostics: Vec::new(), + registrations, + runtime: OwnedWorkerRuntime::new( + RuntimeBuilder::new_multi_thread() + .enable_all() + .build() + .expect("worker runtime should build"), + ), + client, + host_state: Arc::new(WorkerHostRuntimeState::new( + ACTIVATION_ID.into(), + AUTH_TOKEN.into(), + )), + shutdown: Mutex::new(None), + process: Mutex::new(None), + activation_dir, + }, + shutdown_tx, + ) +} + +async fn fake_worker_client( + invoke: impl Fn(InvokeRequest) -> InvokeResponse + Send + Sync + 'static, +) -> (PluginWorkerClient, oneshot::Sender<()>) { + fake_worker_client_with_stream(invoke, |_| { + Box::pin(tokio_stream::iter(vec![Err(Status::unavailable( + "stream transport down", + ))])) as FakeInvokeStream + }) + .await +} + +async fn fake_worker_client_with_stream( + invoke: impl Fn(InvokeRequest) -> InvokeResponse + Send + Sync + 'static, + invoke_stream: impl Fn(InvokeRequest) -> FakeInvokeStream + Send + Sync + 'static, +) -> (PluginWorkerClient, oneshot::Sender<()>) { + let listener = tokio::net::TcpListener::bind(("127.0.0.1", 0)) + .await + .expect("fake worker listener should bind"); + let addr = listener + .local_addr() + .expect("fake worker listener address should be available"); + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + tokio::spawn( + Server::builder() + .add_service(PluginWorkerServer::new(FakePluginWorker { + invoke: Arc::new(invoke), + invoke_stream: Arc::new(invoke_stream), + })) + .serve_with_incoming_shutdown(TcpListenerStream::new(listener), async { + let _ = shutdown_rx.await; + }), + ); + let client = PluginWorkerClient::connect(format!("http://{addr}")) + .await + .expect("fake worker client should connect"); + (client, shutdown_tx) +} + +fn registration(surface: RegistrationSurface, local_name: &str) -> Registration { + Registration { + local_name: local_name.into(), + surface: surface as i32, + priority: 0, + break_chain: false, + } +} + +fn poison_mutex(f: impl FnOnce() + std::panic::UnwindSafe) { + let _ = std::panic::catch_unwind(f); +} + +struct FakePluginWorker { + invoke: Arc InvokeResponse + Send + Sync>, + invoke_stream: Arc FakeInvokeStream + Send + Sync>, +} + +type FakeInvokeStream = + Pin> + Send>>; + +struct SignalChunkThenPendingStream { + yield_rx: oneshot::Receiver<()>, + dropped: Option>, + yielded: bool, +} + +impl tokio_stream::Stream for SignalChunkThenPendingStream { + type Item = std::result::Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + if self.yielded { + return std::task::Poll::Pending; + } + match Pin::new(&mut self.yield_rx).poll(cx) { + std::task::Poll::Ready(_) => { + self.yielded = true; + std::task::Poll::Ready(Some(Ok(StreamChunk { + item: Some(StreamItem::Value( + json_envelope(JSON_SCHEMA, &json!({ "after_receiver_drop": true })) + .expect("test stream chunk should encode"), + )), + }))) + } + std::task::Poll::Pending => std::task::Poll::Pending, + } + } +} + +impl Drop for SignalChunkThenPendingStream { + fn drop(&mut self) { + if let Some(dropped) = self.dropped.take() { + let _ = dropped.send(()); + } + } +} + +#[tonic::async_trait] +impl PluginWorker for FakePluginWorker { + async fn handshake( + &self, + _request: Request, + ) -> std::result::Result, tonic::Status> { + Ok(tonic::Response::new(HandshakeResponse { + plugin_id: "fixture_worker".into(), + plugin_kind: "fixture_worker".into(), + allows_multiple_components: false, + worker_protocol: WORKER_PROTOCOL_GRPC_V1.into(), + sdk_name: "unit".into(), + sdk_version: "0".into(), + runtime_name: "unit".into(), + runtime_version: "0".into(), + supported_surfaces: Vec::new(), + })) + } + + async fn health( + &self, + _request: Request, + ) -> std::result::Result, tonic::Status> { + Ok(tonic::Response::new(HealthResponse { + ok: true, + message: String::new(), + plugin_id: "fixture_worker".into(), + worker_protocol: WORKER_PROTOCOL_GRPC_V1.into(), + sdk_name: "unit".into(), + sdk_version: "0".into(), + runtime_name: "unit".into(), + runtime_version: "0".into(), + })) + } + + async fn validate( + &self, + _request: Request, + ) -> std::result::Result, tonic::Status> { + Ok(tonic::Response::new(ValidateResponse { + diagnostics: None, + error: None, + })) + } + + async fn register( + &self, + _request: Request, + ) -> std::result::Result, tonic::Status> { + Ok(tonic::Response::new(RegisterResponse { + registrations: Vec::new(), + error: None, + })) + } + + async fn invoke( + &self, + request: Request, + ) -> std::result::Result, tonic::Status> { + Ok(tonic::Response::new((self.invoke)(request.into_inner()))) + } + + type InvokeStreamStream = + Pin> + Send>>; + + async fn invoke_stream( + &self, + request: Request, + ) -> std::result::Result, tonic::Status> { + Ok(tonic::Response::new((self.invoke_stream)( + request.into_inner(), + ))) + } + + async fn cancel_invocation( + &self, + _request: Request, + ) -> std::result::Result, tonic::Status> { + Ok(tonic::Response::new(WorkerAck { + accepted: false, + message: "not implemented".into(), + })) + } + + async fn shutdown( + &self, + _request: Request, + ) -> std::result::Result, tonic::Status> { + Ok(tonic::Response::new(WorkerAck { + accepted: false, + message: "not implemented".into(), + })) + } +} diff --git a/crates/worker/src/lib.rs b/crates/worker/src/lib.rs index 357f4265..4bb394c0 100644 --- a/crates/worker/src/lib.rs +++ b/crates/worker/src/lib.rs @@ -11,7 +11,6 @@ use std::future::Future; use std::net::{SocketAddr, ToSocketAddrs}; #[cfg(unix)] use std::os::unix::fs::FileTypeExt; -#[cfg(unix)] use std::path::{Path, PathBuf}; use std::pin::Pin; use std::sync::{Arc, Mutex}; @@ -38,9 +37,11 @@ use nemo_relay_worker_proto::v1::{ WorkerError, }; use nemo_relay_worker_proto::{WORKER_PROTOCOL_GRPC_V1, decode_json_envelope, json_envelope}; +use tokio::net::TcpListener; #[cfg(unix)] use tokio::net::{UnixListener, UnixStream}; use tokio::sync::OnceCell; +use tokio_stream::wrappers::TcpListenerStream; #[cfg(unix)] use tokio_stream::wrappers::UnixListenerStream; use tonic::transport::{Channel, Endpoint, Server}; @@ -501,6 +502,24 @@ impl PluginRuntime { ack_to_result(response.ok, response.error) } + /// Runs an async operation with runtime calls bound to a specific host-owned scope stack. + /// + /// This is useful for isolated stacks created with [`Self::create_scope_stack`]. The previous + /// worker invocation scope is restored after the future completes. + pub async fn with_scope_stack(&self, scope_stack_id: &str, f: F) -> T + where + F: FnOnce() -> Fut, + Fut: Future, + { + let scope = Some(scope_context(scope_stack_id)); + TASK_SCOPE_CONTEXT + .scope(scope.clone(), async move { + let future = with_thread_scope(&scope, f); + future.await + }) + .await + } + /// Pushes a scope through the host runtime. pub async fn push_scope( &self, @@ -683,7 +702,12 @@ pub async fn serve_plugin_arc(plugin: Arc) -> Result<()> { activation_id: required_env("NEMO_RELAY_WORKER_ID")?, auth_token: required_env("NEMO_RELAY_WORKER_TOKEN")?, }; - serve_plugin_arc_with_config(plugin, config).await + serve_plugin_arc_with_endpoint_file( + plugin, + config, + optional_env("NEMO_RELAY_WORKER_ENDPOINT_FILE").map(PathBuf::from), + ) + .await } /// Serves a shared worker plugin using explicit endpoint and authentication configuration. @@ -697,6 +721,14 @@ pub async fn serve_plugin_arc(plugin: Arc) -> Result<()> { pub async fn serve_plugin_arc_with_config( plugin: Arc, config: WorkerServerConfig, +) -> Result<()> { + serve_plugin_arc_with_endpoint_file(plugin, config, None).await +} + +async fn serve_plugin_arc_with_endpoint_file( + plugin: Arc, + config: WorkerServerConfig, + endpoint_file: Option, ) -> Result<()> { let runtime = PluginRuntime { activation_id: config.activation_id, @@ -709,11 +741,15 @@ pub async fn serve_plugin_arc_with_config( runtime, handlers: Arc::new(Mutex::new(WorkerHandlers::default())), }; - serve_worker_service(service, &config.worker_endpoint).await + serve_worker_service(service, &config.worker_endpoint, endpoint_file.as_deref()).await } #[cfg(unix)] -async fn serve_worker_service(service: WorkerService, endpoint: &str) -> Result<()> { +async fn serve_worker_service( + service: WorkerService, + endpoint: &str, + endpoint_file: Option<&Path>, +) -> Result<()> { if endpoint.starts_with("unix://") { let path = parse_unix_endpoint(endpoint)?; remove_stale_socket(&path)?; @@ -726,24 +762,41 @@ async fn serve_worker_service(service: WorkerService, endpoint: &str) -> Result< .await .map_err(|err| WorkerSdkError::Transport(err.to_string())); } - serve_tcp_worker_service(service, endpoint).await + serve_tcp_worker_service(service, endpoint, endpoint_file).await } #[cfg(not(unix))] -async fn serve_worker_service(service: WorkerService, endpoint: &str) -> Result<()> { +async fn serve_worker_service( + service: WorkerService, + endpoint: &str, + endpoint_file: Option<&Path>, +) -> Result<()> { if endpoint.starts_with("unix://") { return Err(WorkerSdkError::InvalidInput( "unix endpoints are not supported on this platform".into(), )); } - serve_tcp_worker_service(service, endpoint).await + serve_tcp_worker_service(service, endpoint, endpoint_file).await } -async fn serve_tcp_worker_service(service: WorkerService, endpoint: &str) -> Result<()> { +async fn serve_tcp_worker_service( + service: WorkerService, + endpoint: &str, + endpoint_file: Option<&Path>, +) -> Result<()> { let addr = parse_tcp_endpoint(endpoint)?; + let listener = TcpListener::bind(addr) + .await + .map_err(|err| WorkerSdkError::Transport(format!("failed to bind worker socket: {err}")))?; + if let Some(path) = endpoint_file { + let local_addr = listener.local_addr().map_err(|err| { + WorkerSdkError::Transport(format!("failed to inspect worker socket: {err}")) + })?; + write_endpoint_file(path, &format!("http://{local_addr}"))?; + } Server::builder() .add_service(PluginWorkerServer::new(service)) - .serve(addr) + .serve_with_incoming(TcpListenerStream::new(listener)) .await .map_err(|err| WorkerSdkError::Transport(err.to_string())) } @@ -1605,6 +1658,29 @@ fn required_env(name: &str) -> Result { }) } +fn optional_env(name: &str) -> Option { + std::env::var(name) + .ok() + .filter(|value| !value.trim().is_empty()) +} + +fn write_endpoint_file(path: &Path, endpoint: &str) -> Result<()> { + if let Some(parent) = path.parent() { + std::fs::create_dir_all(parent).map_err(|err| { + WorkerSdkError::Transport(format!( + "failed to create worker endpoint file directory '{}': {err}", + parent.display() + )) + })?; + } + std::fs::write(path, endpoint).map_err(|err| { + WorkerSdkError::Transport(format!( + "failed to write worker endpoint file '{}': {err}", + path.display() + )) + }) +} + fn rustc_version_runtime() -> String { option_env!("RUSTC_VERSION") .unwrap_or("unknown") diff --git a/crates/worker/tests/worker_sdk_tests.rs b/crates/worker/tests/worker_sdk_tests.rs index 6941b71f..30ae616a 100644 --- a/crates/worker/tests/worker_sdk_tests.rs +++ b/crates/worker/tests/worker_sdk_tests.rs @@ -5,14 +5,13 @@ use std::future::Future; use std::net::{SocketAddr, TcpListener}; +use std::path::Path; #[cfg(unix)] -use std::path::{Path, PathBuf}; +use std::path::PathBuf; use std::pin::Pin; use std::sync::{Arc, Mutex}; use std::task::{Context, Poll}; -use std::time::Duration; -#[cfg(unix)] -use std::time::{SystemTime, UNIX_EPOCH}; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; use futures_util::{Stream, StreamExt}; #[cfg(unix)] @@ -453,6 +452,8 @@ async fn worker_service_invokes_every_registration_surface() { let calls = host.calls(); assert!(calls.contains(&"mark:tool-exec:stack-1:parent-1".into())); assert!(calls.contains(&"create_scope_stack".into())); + assert!(calls.contains(&"mark:tool-exec-isolated:isolated-stack:".into())); + assert!(calls.contains(&"mark:tool-exec-restored:stack-1:parent-1".into())); assert!(calls.contains(&"push:worker-scope:stack-1:parent-1".into())); assert!(calls.contains(&"pop:scope-handle-1".into())); assert!(calls.contains(&"drop:isolated-stack".into())); @@ -689,6 +690,46 @@ async fn worker_service_validates_env_and_endpoints() { } } +#[tokio::test(flavor = "multi_thread")] +async fn worker_service_announces_ephemeral_tcp_endpoint_file() { + const ENVS: &[&str] = &[ + "NEMO_RELAY_WORKER_SOCKET", + "NEMO_RELAY_HOST_SOCKET", + "NEMO_RELAY_WORKER_ID", + "NEMO_RELAY_WORKER_TOKEN", + "NEMO_RELAY_WORKER_ENDPOINT_FILE", + ]; + let _env_guard = ENV_LOCK.lock().await; + let snapshot = EnvSnapshot::capture(ENVS); + let endpoint_file = unique_temp_file("nrw-endpoint"); + let _ = std::fs::remove_file(&endpoint_file); + + set_required_envs(); + set_env("NEMO_RELAY_WORKER_SOCKET", "tcp://127.0.0.1:0"); + set_env( + "NEMO_RELAY_WORKER_ENDPOINT_FILE", + endpoint_file.to_str().expect("endpoint path utf-8"), + ); + let handle = tokio::spawn(serve_plugin_arc(Arc::new(MinimalPlugin))); + let endpoint = wait_for_endpoint_file(&endpoint_file).await; + assert!(endpoint.starts_with("http://127.0.0.1:")); + + let mut client = connect_worker(&endpoint).await; + let health = client + .health(Request::new(HealthRequest { + activation_id: ACTIVATION_ID.into(), + auth_token: AUTH_TOKEN.into(), + })) + .await + .expect("announced endpoint should accept connections") + .into_inner(); + assert!(health.ok); + + handle.abort(); + let _ = std::fs::remove_file(endpoint_file); + snapshot.restore(); +} + #[cfg(unix)] #[tokio::test(flavor = "multi_thread")] async fn worker_service_supports_unix_socket_worker_and_host_endpoints() { @@ -1202,6 +1243,15 @@ impl WorkerPlugin for SurfacePlugin { async move { runtime.emit_mark("tool-exec", None, None).await?; let stack_id = runtime.create_scope_stack().await?; + let isolated_runtime = runtime.clone(); + runtime + .with_scope_stack(&stack_id, || async move { + isolated_runtime + .emit_mark("tool-exec-isolated", None, None) + .await + }) + .await?; + runtime.emit_mark("tool-exec-restored", None, None).await?; let handle = runtime .push_scope(None, "worker-scope", ScopeType::Function, None, None, None) .await?; @@ -1685,6 +1735,16 @@ async fn wait_for_port(endpoint: &str) { panic!("server did not start at {endpoint}"); } +async fn wait_for_endpoint_file(path: &Path) -> String { + for _ in 0..50 { + match std::fs::read_to_string(path) { + Ok(endpoint) if !endpoint.trim().is_empty() => return endpoint, + Ok(_) | Err(_) => std::thread::sleep(Duration::from_millis(20)), + } + } + panic!("endpoint file was not written at {}", path.display()); +} + #[cfg(unix)] async fn wait_for_unix_socket(path: &Path) { for _ in 0..50 { @@ -2072,6 +2132,14 @@ fn remove_env(name: &str) { } } +fn unique_temp_file(prefix: &str) -> std::path::PathBuf { + let nanos = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("system time after epoch") + .as_nanos(); + std::env::temp_dir().join(format!("{prefix}-{}-{nanos}", std::process::id())) +} + #[cfg(unix)] fn unique_temp_path(prefix: &str) -> std::path::PathBuf { let nanos = SystemTime::now()