Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
419 changes: 410 additions & 9 deletions Cargo.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,8 @@ zstd-framed = { version = "0.1.1", features = ["tokio"] }
# potentially very old system version.
openssl = { version = "0.10", features = ["vendored"] }

tract-onnx = "0.21"

[workspace.dependencies.wasmtime]
version = "39"
default-features = false
Expand Down
1 change: 1 addition & 0 deletions crates/bindings-sys/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ bench = false

[features]
unstable = []
onnx = []

[dependencies]
spacetimedb-primitives.workspace = true
Expand Down
131 changes: 131 additions & 0 deletions crates/bindings-sys/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -865,6 +865,75 @@ pub mod raw {
) -> u16;
}

#[cfg(feature = "onnx")]
// NOTE(review): these imports are pinned to ABI module "spacetime_10.5" — confirm
// this matches the module name the host registers for the onnx feature (and the
// version used by the other extern blocks in this file).
#[link(wasm_import_module = "spacetime_10.5")]
unsafe extern "C" {
    /// Runs ONNX inference on a model identified by name.
    ///
    /// `name_ptr[..name_len]` is a UTF-8 model name (e.g. `"bot_brain"`).
    /// The host resolves this to a `.onnx` file on its filesystem,
    /// loads and caches the model on first use. Model bytes never enter WASM memory.
    ///
    /// `input_ptr[..input_len]` should contain a BSATN-encoded `Vec<spacetimedb_lib::onnx::Tensor>`.
    ///
    /// On success, a [`BytesSource`] is written to `out[0]` containing a BSATN-encoded
    /// `Vec<spacetimedb_lib::onnx::Tensor>` with the inference output, and this function returns 0.
    ///
    /// NOTE(review): a [`BytesSource`] is written to `out[0]` on both the success and
    /// the error path; presumably the caller must consume it to release the host-side
    /// resource — confirm against the host's `BytesSource` lifetime rules.
    ///
    /// # Traps
    ///
    /// Traps if:
    /// - `name_ptr` is NULL or `name_ptr[..name_len]` is not in bounds of WASM memory.
    /// - `name_ptr[..name_len]` is not valid UTF-8.
    /// - `input_ptr` is NULL or `input_ptr[..input_len]` is not in bounds of WASM memory.
    /// - `out` is NULL or `out[..size_of::<u32>()]` is not in bounds of WASM memory.
    ///
    /// # Errors
    ///
    /// Returns an error:
    ///
    /// - `ONNX_ERROR` if the model could not be found, loaded, or inference failed.
    ///   In this case, a [`BytesSource`] containing a BSATN-encoded error message `String`
    ///   is written to `out[0]`.
    pub fn onnx_run(
        name_ptr: *const u8,
        name_len: u32,
        input_ptr: *const u8,
        input_len: u32,
        out: *mut u32,
    ) -> u16;

    /// Runs ONNX inference on multiple batches of inputs for a single model.
    ///
    /// `name_ptr[..name_len]` is a UTF-8 model name.
    /// `input_ptr[..input_len]` should contain a BSATN-encoded `Vec<Vec<spacetimedb_lib::onnx::Tensor>>`.
    ///
    /// On success, a [`BytesSource`] is written to `out[0]` containing a BSATN-encoded
    /// `Vec<Vec<spacetimedb_lib::onnx::Tensor>>` with the inference outputs, and this function returns 0.
    ///
    /// NOTE(review): same out-param convention as [`onnx_run`] — a [`BytesSource`] is
    /// produced on both paths and presumably must be consumed by the caller.
    ///
    /// # Traps
    ///
    /// Traps if:
    /// - `name_ptr` is NULL or `name_ptr[..name_len]` is not in bounds of WASM memory.
    /// - `name_ptr[..name_len]` is not valid UTF-8.
    /// - `input_ptr` is NULL or `input_ptr[..input_len]` is not in bounds of WASM memory.
    /// - `out` is NULL or `out[..size_of::<u32>()]` is not in bounds of WASM memory.
    ///
    /// # Errors
    ///
    /// Returns an error:
    ///
    /// - `ONNX_ERROR` if the model could not be found, loaded, or inference failed.
    ///   In this case, a [`BytesSource`] containing a BSATN-encoded error message `String`
    ///   is written to `out[0]`.
    pub fn onnx_run_multi(
        name_ptr: *const u8,
        name_len: u32,
        input_ptr: *const u8,
        input_len: u32,
        out: *mut u32,
    ) -> u16;
}

/// What strategy does the database index use?
///
/// See also: <https://www.postgresql.org/docs/current/sql-createindex.html>
Expand Down Expand Up @@ -1626,3 +1695,65 @@ pub mod procedure {
}
}
}

/// ONNX inference operations, available from both reducers and procedures.
#[cfg(feature = "onnx")]
pub mod onnx {
    use super::raw;

    /// The common signature of the `onnx_run*` host imports:
    /// `(name_ptr, name_len, input_ptr, input_len, out) -> errno`.
    type OnnxSyscall = unsafe extern "C" fn(*const u8, u32, *const u8, u32, *mut u32) -> u16;

    /// Invokes an ONNX host syscall with `name` and `input_bsatn`, decoding the
    /// returned error code and out-param [`raw::BytesSource`].
    ///
    /// Returns `Ok(bytes_source)` on success (code 0) and `Err(bytes_source)` on
    /// `ONNX_ERROR`; panics on any other error code, mirroring the convention
    /// used by the other syscall wrappers in this crate.
    fn call(syscall: OnnxSyscall, name: &str, input_bsatn: &[u8]) -> Result<raw::BytesSource, raw::BytesSource> {
        let mut out = [raw::BytesSource::INVALID; 1];

        // SAFETY: `name` and `input_bsatn` are live, valid slices for the duration
        // of the call, and `out` points to writable memory large enough for one `u32`.
        let res = unsafe {
            syscall(
                name.as_ptr(),
                name.len() as u32,
                input_bsatn.as_ptr(),
                input_bsatn.len() as u32,
                out.as_mut_ptr().cast(),
            )
        };

        match super::Errno::from_code(res) {
            None => Ok(out[0]),
            Some(errno) if errno == super::Errno::ONNX_ERROR => Err(out[0]),
            Some(errno) => panic!("{errno}"),
        }
    }

    /// Run ONNX inference on a named model with BSATN-encoded input tensors.
    ///
    /// The host loads and caches the model on first use.
    /// `input_bsatn` should be a BSATN-encoded `Vec<spacetimedb_lib::onnx::Tensor>`.
    ///
    /// On success, returns `Ok(bytes_source)` containing BSATN-encoded output tensors.
    /// On failure, returns `Err(bytes_source)` containing a BSATN-encoded error message.
    #[inline]
    pub fn run(name: &str, input_bsatn: &[u8]) -> Result<raw::BytesSource, raw::BytesSource> {
        call(raw::onnx_run, name, input_bsatn)
    }

    /// Run ONNX inference on multiple batches of inputs for a named model.
    ///
    /// The host loads and caches the model on first use.
    /// `input_bsatn` should be a BSATN-encoded `Vec<Vec<spacetimedb_lib::onnx::Tensor>>`.
    ///
    /// On success, returns `Ok(bytes_source)` containing BSATN-encoded `Vec<Vec<Tensor>>`.
    /// On failure, returns `Err(bytes_source)` containing a BSATN-encoded error message.
    #[inline]
    pub fn run_multi(name: &str, input_bsatn: &[u8]) -> Result<raw::BytesSource, raw::BytesSource> {
        call(raw::onnx_run_multi, name, input_bsatn)
    }
}
1 change: 1 addition & 0 deletions crates/bindings/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ default = ["rand"]
rand = ["rand08"]
rand08 = ["dep:rand08", "dep:getrandom02"]
unstable = ["spacetimedb-bindings-sys/unstable"]
onnx = ["spacetimedb-bindings-sys/onnx"]

[dependencies]
spacetimedb-bindings-sys.workspace = true
Expand Down
16 changes: 16 additions & 0 deletions crates/bindings/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ mod client_visibility_filter;
#[cfg(feature = "unstable")]
pub mod http;
pub mod log_stopwatch;
#[cfg(feature = "onnx")]
pub mod onnx;
mod logger;
#[cfg(feature = "rand08")]
mod rng;
Expand Down Expand Up @@ -1005,6 +1007,10 @@ pub struct ReducerContext {
/// See the [`#[table]`](macro@crate::table) macro for more information.
pub db: Local,

/// Methods for performing ONNX inference.
#[cfg(feature = "onnx")]
pub onnx: crate::onnx::OnnxClient,

#[cfg(feature = "rand08")]
rng: std::cell::OnceCell<StdbRng>,
/// A counter used for generating UUIDv7 values.
Expand All @@ -1018,6 +1024,8 @@ impl ReducerContext {
pub fn __dummy() -> Self {
Self {
db: Local {},
#[cfg(feature = "onnx")]
onnx: crate::onnx::OnnxClient {},
sender: Identity::__dummy(),
timestamp: Timestamp::UNIX_EPOCH,
connection_id: None,
Expand All @@ -1033,6 +1041,8 @@ impl ReducerContext {
fn new(db: Local, sender: Identity, connection_id: Option<ConnectionId>, timestamp: Timestamp) -> Self {
Self {
db,
#[cfg(feature = "onnx")]
onnx: crate::onnx::OnnxClient {},
sender,
timestamp,
connection_id,
Expand Down Expand Up @@ -1179,6 +1189,10 @@ pub struct ProcedureContext {

/// Methods for performing HTTP requests.
pub http: crate::http::HttpClient,

/// Methods for performing ONNX inference.
#[cfg(feature = "onnx")]
pub onnx: crate::onnx::OnnxClient,
// TODO: Change rng?
// Complex and requires design because we may want procedure RNG to behave differently from reducer RNG,
// as it could actually be seeded by OS randomness rather than a deterministic source.
Expand All @@ -1199,6 +1213,8 @@ impl ProcedureContext {
timestamp,
connection_id,
http: http::HttpClient {},
#[cfg(feature = "onnx")]
onnx: crate::onnx::OnnxClient {},
#[cfg(feature = "rand08")]
rng: std::cell::OnceCell::new(),
#[cfg(feature = "rand")]
Expand Down
94 changes: 94 additions & 0 deletions crates/bindings/src/onnx.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
//! ONNX inference support for SpacetimeDB modules.
//!
//! Run ONNX model inference from within reducers or procedures.
//! Models are stored on the host filesystem — the model bytes never enter WASM memory.
//! Models are cached on the host after first load.
//!
//! # Example
//!
//! ```no_run
//! # use spacetimedb::{reducer, ReducerContext, onnx::Tensor};
//! // In a reducer:
//! # #[reducer]
//! # fn my_reducer(ctx: &ReducerContext) {
//! let input = vec![Tensor {
//! shape: vec![1, 10],
//! data: vec![0.0; 10],
//! }];
//! let output = ctx.onnx.run("bot_brain", &input).expect("Inference failed");
//! log::info!("Output: {:?}", output[0].data);
//! # }
//! ```

use crate::rt::read_bytes_source_as;
use spacetimedb_lib::bsatn;

pub use spacetimedb_lib::onnx::Tensor;

/// Client for performing ONNX inference.
///
/// Access from within reducers via [`ReducerContext::onnx`](crate::ReducerContext)
/// or from procedures via [`ProcedureContext::onnx`](crate::ProcedureContext).
///
/// This is a zero-sized handle: model loading, caching, and session state all
/// live on the host side, so constructing it is free.
#[non_exhaustive]
pub struct OnnxClient {}

impl OnnxClient {
/// Run inference on a named ONNX model.
///
/// The host resolves `model_name` to a `.onnx` file on its filesystem,
/// loads and caches it on first use, then runs inference with the given inputs.
/// Model bytes never enter WASM memory — only tensor data crosses the boundary.
///
/// `inputs` are the input tensors for the model, in the order expected by the model's input nodes.
/// Returns the output tensors from the model.
pub fn run(&self, model_name: &str, inputs: &[Tensor]) -> Result<Vec<Tensor>, Error> {
let input_bsatn = bsatn::to_vec(inputs).expect("Failed to BSATN-serialize input tensors");

match spacetimedb_bindings_sys::onnx::run(model_name, &input_bsatn) {
Ok(output_source) => {
let output = read_bytes_source_as::<Vec<Tensor>>(output_source);
Ok(output)
}
Err(err_source) => {
let message = read_bytes_source_as::<String>(err_source);
Err(Error { message })
}
}
}

/// Run inference on multiple batches of inputs in a single host call.
///
/// Each element of `batches` is one set of input tensors (one inference invocation).
/// Returns one `Vec<Tensor>` of outputs per batch, in the same order.
///
/// This is more efficient than calling [`run`](Self::run) in a loop because it
/// crosses the WASM boundary only once for all batches.
pub fn run_multi(&self, model_name: &str, batches: &[Vec<Tensor>]) -> Result<Vec<Vec<Tensor>>, Error> {
let input_bsatn = bsatn::to_vec(batches).expect("Failed to BSATN-serialize input tensor batches");

match spacetimedb_bindings_sys::onnx::run_multi(model_name, &input_bsatn) {
Ok(output_source) => {
let output = read_bytes_source_as::<Vec<Vec<Tensor>>>(output_source);
Ok(output)
}
Err(err_source) => {
let message = read_bytes_source_as::<String>(err_source);
Err(Error { message })
}
}
}
}

/// An error produced by ONNX model loading or inference.
///
/// Carries the error message reported by the host; it is surfaced via
/// the [`std::fmt::Display`] impl.
#[derive(Clone, Debug)]
pub struct Error {
    message: String,
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "{}", self.message)
    }
}

impl std::error::Error for Error {}
3 changes: 3 additions & 0 deletions crates/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ url.workspace = true
urlencoding.workspace = true
uuid.workspace = true
v8.workspace = true
tract-onnx = { workspace = true, optional = true }
wasmtime.workspace = true
wasmtime-internal-fiber.workspace = true
jwks.workspace = true
Expand Down Expand Up @@ -150,6 +151,7 @@ perfmap = []
# Disables core pinning
no-core-pinning = []
no-job-core-pinning = []
onnx = ["dep:tract-onnx"]

[dev-dependencies]
spacetimedb-lib = { path = "../lib", features = ["proptest", "test"] }
Expand All @@ -168,6 +170,7 @@ pretty_assertions.workspace = true
jsonwebtoken.workspace = true
axum.workspace = true
fs_extra.workspace = true
prost = "0.11"

[lints]
workspace = true
3 changes: 3 additions & 0 deletions crates/core/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,9 @@ pub enum NodesError {
ScheduleError(#[source] ScheduleError),
#[error("HTTP request failed: {0}")]
HttpError(String),
#[cfg(feature = "onnx")]
#[error("ONNX inference failed: {0}")]
OnnxError(String),
}

impl From<DBError> for NodesError {
Expand Down
Loading