diff --git a/crates/duckdb/src/params.rs b/crates/duckdb/src/params.rs index efb4d9d4..80b29eef 100644 --- a/crates/duckdb/src/params.rs +++ b/crates/duckdb/src/params.rs @@ -1,4 +1,5 @@ -use crate::{Result, Statement, ToSql}; +use crate::{Error, Result, Statement, ToSql}; +use std::collections::HashMap; mod sealed { /// This trait exists just to ensure that the only impls of `trait Params` @@ -86,6 +87,36 @@ use sealed::Sealed; /// } /// ``` /// +/// ## Named parameters +/// +/// If you need named parameters, they can be passed in one of two ways: +/// +/// - As a `&HashMap` +/// - If the `serde_json` feature is enabled, use `serde_json::Value`. +/// +/// In both cases the keys should _not_ include the `$` prefix. +/// +/// ### Example (named parameters) +/// +/// ```rust,no_run +/// #[cfg(feature = "serde_json")] +/// { +/// use fallible_iterator::FallibleIterator; +/// use duckdb::{Connection, Result}; +/// use serde_json::json; +/// fn execute_query(conn: Connection) -> Result> { +/// let params = json!({ +/// "min": 23, +/// "max": 42, +/// }); +/// conn +/// .prepare("SELECT name FROM people WHERE age BETWEEN $min AND $max")? +/// .query(params)?.map(|row| row.get(0)) +/// .collect() +/// } +/// } +/// ``` +/// /// ## No parameters /// /// You can just use an empty array literal for no params. The @@ -303,3 +334,65 @@ where stmt.bind_parameters(self.0) } } + +impl Sealed for &HashMap {} + +impl Params for &HashMap { + fn __bind_in(self, stmt: &mut Statement<'_>) -> Result<()> { + let n = stmt.parameter_count(); + let params: Vec<_> = (1..=n) + .map(|i| { + let name = stmt.parameter_name(i)?; + let val = *self.get(&name).ok_or(Error::InvalidParameterName(name))?; + Ok(val) + }) + .collect::, Error>>()?; + stmt.bind_parameters(¶ms)?; + Ok(()) + } +} + +impl Sealed for HashMap {} + +impl Params for HashMap { + fn __bind_in(self, stmt: &mut Statement<'_>) -> Result<()> { + (&self).__bind_in(stmt) + } +} + +#[cfg(feature = "serde_json")] +mod serde_json_support { + use super::*; + use crate::{types::ToSql, Params, Statement}; + use serde_json::Value; + + impl Sealed for &Value {} + + impl Params for &Value { + fn __bind_in(self, stmt: &mut Statement<'_>) -> crate::Result<()> { + match self { + Value::Object(ref map) => { + let n = stmt.parameter_count(); + let params: Vec<&dyn ToSql> = (1..=n) + .map(|i| { + let name = stmt.parameter_name(i)?; + let val = map.get(&name).ok_or(Error::InvalidParameterName(name))?; + Ok(val as &dyn ToSql) + }) + .collect::, Error>>()?; + stmt.bind_parameters(¶ms) + } + Value::Array(ref array) => stmt.bind_parameters(&array[..]), + ref other => stmt.bind_parameters(&[other as &dyn ToSql]), + } + } + } + + impl Sealed for Value {} + + impl Params for Value { + fn __bind_in(self, stmt: &mut Statement<'_>) -> Result<()> { + (&self).__bind_in(stmt) + } + } +} diff --git a/crates/duckdb/src/statement.rs b/crates/duckdb/src/statement.rs index af8dc7b4..eee9706a 100644 --- a/crates/duckdb/src/statement.rs +++ b/crates/duckdb/src/statement.rs @@ -1120,6 +1120,30 @@ mod test { Ok(()) } + #[test] + fn test_named_parameters() -> Result<()> { + #[cfg(not(feature = "serde_json"))] + use std::collections::HashMap; + + #[cfg(not(feature = "serde_json"))] + let named_params: &HashMap = &HashMap::from([ + ("foo".to_string(), &42 as &dyn ToSql), + ("bar".to_string(), &23 as &dyn ToSql), + ]); + + #[cfg(feature = "serde_json")] + let named_params: serde_json::Value = serde_json::json!({ + "foo": 42, + "bar": 23 + }); + + let db = Connection::open_in_memory()?; + let sql = r#"SELECT $foo > $bar"#; + let result: bool = db.query_row(sql, named_params, |row| row.get(0))?; + assert_eq!(result, true); + Ok(()) + } + #[test] fn test_empty_stmt() -> Result<()> { let conn = Connection::open_in_memory()?;