From d181f6c8cca6932acb319b9a5ec75968b9c5ee96 Mon Sep 17 00:00:00 2001 From: Randolf Jung Date: Mon, 18 May 2026 02:02:03 -0700 Subject: [PATCH] feat: support borrowed parameter types via borrowed_rs_type override Adds a `borrowed_rs_type` field to type overrides so callers can take parameters by reference without forcing rows to also be borrowed. When set, the override participates in borrowed mode for scalar parameter positions; rows, array contents, and copyfrom items continue to use the owned form. Lifetimes are threaded via elision wherever possible: single-param functions take the borrowed type directly; params structs gain `<'a>` with the fn referencing it as `<'_>`; batch streams reuse the existing `'a`; copyfrom introduces `'a` because its where clause cannot elide. Vec wrappers around borrowed inners become `&[T_owned]` so a caller holding a `Vec` can pass `&v` without re-collecting. `rs_type` becomes optional when `borrowed_rs_type` is present; `borrowed_rs_type` is parsed at config load and rejected if it does not contain a reference, so a value like `Option` fails fast. --- Cargo.lock | 8 + Cargo.toml | 2 +- README.md | 51 ++- examples/advanced-types/src/queries.rs | 2 +- examples/basic/src/queries.rs | 2 +- examples/batch/src/queries.rs | 2 +- examples/borrowed/Cargo.toml | 8 + examples/borrowed/queries.sql | 14 + examples/borrowed/schema.sql | 5 + examples/borrowed/src/main.rs | 51 +++ examples/borrowed/src/queries.rs | 113 ++++++ examples/enums/src/queries.rs | 2 +- sqlc.yaml | 14 +- src/codegen/batch.rs | 39 ++- src/codegen/copyfrom.rs | 35 +- src/codegen/lifetimes.rs | 83 +++++ src/codegen/mod.rs | 1 + src/codegen/query.rs | 83 ++++- src/config.rs | 104 +++++- src/types.rs | 312 +++++++++++++---- tests/codegen.rs | 325 +++++++++++++++++- .../codegen__borrowed_array_param.snap | 53 +++ .../codegen__borrowed_batchexec.snap | 69 ++++ .../codegen__borrowed_batchexec_struct.snap | 77 +++++ .../snapshots/codegen__borrowed_copyfrom.snap | 80 +++++ .../codegen__borrowed_one_row_unaffected.snap | 61 ++++ .../codegen__borrowed_params_struct.snap | 62 ++++ tests/snapshots/codegen__borrowed_scalar.snap | 53 +++ .../codegen__borrowed_slice_param.snap | 53 +++ .../codegen__borrowed_with_custom_owned.snap | 61 ++++ 30 files changed, 1713 insertions(+), 112 deletions(-) create mode 100644 examples/borrowed/Cargo.toml create mode 100644 examples/borrowed/queries.sql create mode 100644 examples/borrowed/schema.sql create mode 100644 examples/borrowed/src/main.rs create mode 100644 examples/borrowed/src/queries.rs create mode 100644 src/codegen/lifetimes.rs create mode 100644 tests/snapshots/codegen__borrowed_array_param.snap create mode 100644 tests/snapshots/codegen__borrowed_batchexec.snap create mode 100644 tests/snapshots/codegen__borrowed_batchexec_struct.snap create mode 100644 tests/snapshots/codegen__borrowed_copyfrom.snap create mode 100644 tests/snapshots/codegen__borrowed_one_row_unaffected.snap create mode 100644 tests/snapshots/codegen__borrowed_params_struct.snap create mode 100644 tests/snapshots/codegen__borrowed_scalar.snap create mode 100644 tests/snapshots/codegen__borrowed_slice_param.snap create mode 100644 tests/snapshots/codegen__borrowed_with_custom_owned.snap diff --git a/Cargo.lock b/Cargo.lock index a761a0f..e3b1cce 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -304,6 +304,14 @@ dependencies = [ "time", ] +[[package]] +name = "borrowed" +version = "0.1.0" +dependencies = [ + "sqlx", + "tokio", +] + [[package]] name = "buffa" version = "0.3.0" diff --git a/Cargo.toml b/Cargo.toml index 8440234..eaafa5a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,7 +36,7 @@ proc-macro2 = "1" quote = "1" serde = { version = "1", features = ["derive"] } serde_json = { workspace = true } -syn = { version = "2", features = ["full"] } +syn = { version = "2", features = ["full", "visit-mut"] } [dev-dependencies] insta = { version = "1", features = ["filters"] } diff --git a/README.md b/README.md index e65f476..2ce2fb5 100644 --- a/README.md +++ b/README.md @@ -56,7 +56,7 @@ All options are passed in `codegen[*].options`: | Key | Type | Default | Description | |---|---|---|---| | `output` | string | `queries.rs` | Output filename | -| `overrides` | array | `[]` | Type overrides (see below) | +| `overrides` | array | `[]` | Type overrides (`rs_type`, optional `borrowed_rs_type`; see below) | | `row_derives` | array | `[]` | Extra derives for row and params structs | | `enum_derives` | array | `[]` | Extra derives for generated enum types | | `composite_derives` | array | `[]` | Extra derives for generated composite types | @@ -77,6 +77,55 @@ options: copy_cheap: false ``` +### Borrowed parameters + +Add `borrowed_rs_type` to a type or column override to take that type by +reference in parameter positions. Row struct fields, array contents, and the +`Item` of `:copyfrom` chunks continue to use the owned form: + +```yaml +options: + overrides: + - db_type: "text" + borrowed_rs_type: "&str" +``` + +With that override, generated signatures borrow scalar `text` parameters and +the codegen threads lifetimes only where needed: + +```rust +// Scalar — lifetime elided +pub async fn get_author_by_name( + mut db: E, name: &str, +) -> Result { ... } + +// Multiple params — struct carries `'a`, fn uses `'_` +pub struct CreateAuthorParams<'a> { + pub name: &'a str, + pub bio: Option<&'a str>, +} +pub async fn create_author( + mut db: E, arg: CreateAuthorParams<'_>, +) -> Result { ... } + +// Row struct stays owned — results are returned by value +pub struct GetAuthorByNameRow { pub name: String, /* ... */ } +``` + +`rs_type` is optional alongside `borrowed_rs_type`. Omit it to keep the +built-in owned default; set both to fully customize: + +```yaml +overrides: + - db_type: "text" + rs_type: "MyStr" # used for row fields & array contents + borrowed_rs_type: "&MyStr" # used for scalar params +``` + +For `text[]` and `sqlc.slice(text)` the wrapper becomes a borrowed slice while +the inner item stays owned (`&[String]`), so callers can pass `&my_vec` +directly without re-collecting. + ## Supported PostgreSQL types | PostgreSQL | Rust | diff --git a/examples/advanced-types/src/queries.rs b/examples/advanced-types/src/queries.rs index 036bd1d..ea9bef8 100644 --- a/examples/advanced-types/src/queries.rs +++ b/examples/advanced-types/src/queries.rs @@ -1,4 +1,4 @@ -// Code generated by sqlc-gen-sqlx v0.1.7. DO NOT EDIT. +// Code generated by sqlc-gen-sqlx v0.2.0. DO NOT EDIT. // sqlc version: v1.30.0 #![allow( diff --git a/examples/basic/src/queries.rs b/examples/basic/src/queries.rs index d31ddc2..f8a3d0f 100644 --- a/examples/basic/src/queries.rs +++ b/examples/basic/src/queries.rs @@ -1,4 +1,4 @@ -// Code generated by sqlc-gen-sqlx v0.1.7. DO NOT EDIT. +// Code generated by sqlc-gen-sqlx v0.2.0. DO NOT EDIT. // sqlc version: v1.30.0 #![allow( diff --git a/examples/batch/src/queries.rs b/examples/batch/src/queries.rs index fa9625e..88db2f6 100644 --- a/examples/batch/src/queries.rs +++ b/examples/batch/src/queries.rs @@ -1,4 +1,4 @@ -// Code generated by sqlc-gen-sqlx v0.1.7. DO NOT EDIT. +// Code generated by sqlc-gen-sqlx v0.2.0. DO NOT EDIT. // sqlc version: v1.30.0 #![allow( diff --git a/examples/borrowed/Cargo.toml b/examples/borrowed/Cargo.toml new file mode 100644 index 0000000..e64915d --- /dev/null +++ b/examples/borrowed/Cargo.toml @@ -0,0 +1,8 @@ +[package] +name = "borrowed" +version = "0.1.0" +edition = "2024" + +[dependencies] +sqlx = { workspace = true } +tokio = { workspace = true } diff --git a/examples/borrowed/queries.sql b/examples/borrowed/queries.sql new file mode 100644 index 0000000..5960ce7 --- /dev/null +++ b/examples/borrowed/queries.sql @@ -0,0 +1,14 @@ +-- name: GetAuthor :one +SELECT id, name, bio FROM authors WHERE id = $1; + +-- name: GetAuthorByName :one +SELECT id, name, bio FROM authors WHERE name = $1; + +-- name: ListAuthors :many +SELECT id, name, bio FROM authors ORDER BY name; + +-- name: CreateAuthor :one +INSERT INTO authors (name, bio) VALUES ($1, $2) RETURNING id, name, bio; + +-- name: DeleteAuthor :exec +DELETE FROM authors WHERE id = $1; diff --git a/examples/borrowed/schema.sql b/examples/borrowed/schema.sql new file mode 100644 index 0000000..ab00390 --- /dev/null +++ b/examples/borrowed/schema.sql @@ -0,0 +1,5 @@ +CREATE TABLE authors ( + id BIGSERIAL PRIMARY KEY, + name TEXT NOT NULL, + bio TEXT +); diff --git a/examples/borrowed/src/main.rs b/examples/borrowed/src/main.rs new file mode 100644 index 0000000..8254807 --- /dev/null +++ b/examples/borrowed/src/main.rs @@ -0,0 +1,51 @@ +#[cfg(test)] +use sqlx::{Connection as _, PgConnection}; + +#[path = "queries.rs"] +#[cfg(test)] +mod queries; +#[cfg(test)] +use queries::CreateAuthorParams; + +#[cfg(test)] +#[tokio::test] +async fn test_borrowed_author_roundtrip() { + let db_url = std::env::var("DATABASE_URL") + .unwrap_or_else(|_| "postgres://sqlc:sqlc@localhost:5432/sqlc_test".to_string()); + let mut conn = PgConnection::connect(&db_url).await.expect("connect"); + + sqlx::query("CREATE TABLE IF NOT EXISTS authors (id BIGSERIAL PRIMARY KEY, name TEXT NOT NULL, bio TEXT)") + .execute(&mut conn) + .await + .unwrap(); + sqlx::query("TRUNCATE authors RESTART IDENTITY CASCADE") + .execute(&mut conn) + .await + .unwrap(); + + // Borrow string literals directly — no allocation needed. + let author = queries::create_author( + &mut conn, + CreateAuthorParams { + name: "Alice", + bio: Some("Loves Rust"), + }, + ) + .await + .expect("create"); + assert_eq!(author.name, "Alice"); + assert_eq!(author.bio.as_deref(), Some("Loves Rust")); + + // Borrow an owned String the caller already has. + let stored_name: String = author.name.clone(); + let fetched = queries::get_author_by_name(&mut conn, &stored_name) + .await + .expect("get_by_name"); + assert_eq!(fetched.id, author.id); + + queries::delete_author(&mut conn, author.id) + .await + .expect("delete"); +} + +fn main() {} diff --git a/examples/borrowed/src/queries.rs b/examples/borrowed/src/queries.rs new file mode 100644 index 0000000..0e5fbc7 --- /dev/null +++ b/examples/borrowed/src/queries.rs @@ -0,0 +1,113 @@ +// Code generated by sqlc-gen-sqlx v0.2.0. DO NOT EDIT. +// sqlc version: v1.30.0 + +#![allow( + dead_code, + reason = "generated queries may expose items a caller does not use" +)] + +pub trait AsExecutor { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres>; +} +impl AsExecutor for sqlx::PgPool { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + &*self + } +} +impl AsExecutor for &sqlx::PgPool { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + *self + } +} +impl AsExecutor for sqlx::PgConnection { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + &mut *self + } +} +impl AsExecutor for sqlx::Transaction<'_, sqlx::Postgres> { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + &mut **self + } +} +impl AsExecutor for sqlx::pool::PoolConnection { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + &mut **self + } +} +impl AsExecutor for &mut T { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + (**self).as_executor() + } +} +const GET_AUTHOR: &str = "SELECT id, name, bio FROM authors WHERE id = $1"; +#[derive(Debug, Clone, sqlx::FromRow)] +pub struct GetAuthorRow { + pub id: i64, + pub name: String, + pub bio: Option, +} +pub async fn get_author(mut db: E, id: i64) -> Result { + sqlx::query_as::<_, GetAuthorRow>(GET_AUTHOR) + .bind(id) + .fetch_one(db.as_executor()) + .await +} +const GET_AUTHOR_BY_NAME: &str = "SELECT id, name, bio FROM authors WHERE name = $1"; +#[derive(Debug, Clone, sqlx::FromRow)] +pub struct GetAuthorByNameRow { + pub id: i64, + pub name: String, + pub bio: Option, +} +pub async fn get_author_by_name( + mut db: E, + name: &str, +) -> Result { + sqlx::query_as::<_, GetAuthorByNameRow>(GET_AUTHOR_BY_NAME) + .bind(name) + .fetch_one(db.as_executor()) + .await +} +const LIST_AUTHORS: &str = "SELECT id, name, bio FROM authors ORDER BY name"; +#[derive(Debug, Clone, sqlx::FromRow)] +pub struct ListAuthorsRow { + pub id: i64, + pub name: String, + pub bio: Option, +} +pub async fn list_authors(mut db: E) -> Result, sqlx::Error> { + sqlx::query_as::<_, ListAuthorsRow>(LIST_AUTHORS) + .fetch_all(db.as_executor()) + .await +} +#[derive(Debug, Clone)] +pub struct CreateAuthorParams<'a> { + pub name: &'a str, + pub bio: Option<&'a str>, +} +const CREATE_AUTHOR: &str = + "INSERT INTO authors (name, bio) VALUES ($1, $2) RETURNING id, name, bio"; +#[derive(Debug, Clone, sqlx::FromRow)] +pub struct CreateAuthorRow { + pub id: i64, + pub name: String, + pub bio: Option, +} +pub async fn create_author( + mut db: E, + arg: CreateAuthorParams<'_>, +) -> Result { + sqlx::query_as::<_, CreateAuthorRow>(CREATE_AUTHOR) + .bind(arg.name) + .bind(arg.bio) + .fetch_one(db.as_executor()) + .await +} +const DELETE_AUTHOR: &str = "DELETE FROM authors WHERE id = $1"; +pub async fn delete_author(mut db: E, id: i64) -> Result<(), sqlx::Error> { + sqlx::query(DELETE_AUTHOR) + .bind(id) + .execute(db.as_executor()) + .await?; + Ok(()) +} diff --git a/examples/enums/src/queries.rs b/examples/enums/src/queries.rs index 35be388..c63ec81 100644 --- a/examples/enums/src/queries.rs +++ b/examples/enums/src/queries.rs @@ -1,4 +1,4 @@ -// Code generated by sqlc-gen-sqlx v0.1.7. DO NOT EDIT. +// Code generated by sqlc-gen-sqlx v0.2.0. DO NOT EDIT. // sqlc version: v1.30.0 #![allow( diff --git a/sqlc.yaml b/sqlc.yaml index 44e9123..289c900 100644 --- a/sqlc.yaml +++ b/sqlc.yaml @@ -3,7 +3,7 @@ plugins: - name: sqlc-gen-sqlx wasm: url: file://target/wasm32-wasip1/debug/sqlc-gen-sqlx.wasm - sha256: "87780a1e4ebe3f7bfb3fdd4041e8f7127b5c35d540c40b456558d87f1ef62e5b" + sha256: "0405bc13685568ff34c6ff8fb43fc4fe474916b1175007c31b2485c99a6ee2a7" sql: - schema: examples/basic/schema.sql @@ -41,3 +41,15 @@ sql: out: examples/enums/src options: output: queries.rs + + - schema: examples/borrowed/schema.sql + queries: examples/borrowed/queries.sql + engine: postgresql + codegen: + - plugin: sqlc-gen-sqlx + out: examples/borrowed/src + options: + output: queries.rs + overrides: + - db_type: "text" + borrowed_rs_type: "&str" diff --git a/src/codegen/batch.rs b/src/codegen/batch.rs index 430069d..46923f2 100644 --- a/src/codegen/batch.rs +++ b/src/codegen/batch.rs @@ -3,37 +3,52 @@ use quote::{format_ident, quote}; use syn::parse_str; use crate::{ + codegen::lifetimes::inject_lifetime, config::Config, error::Error, ident::to_snake_case, plugin::QueryView, - types::{ResolvedType, TypeMap}, + types::{ColumnOverride, TypeMap}, }; use super::query::{ - bind_calls, dynamic_bind_statements, dynamic_sql_setup, has_dynamic_slice, maybe_params_struct, - resolve_columns, resolve_params, row_struct, sql_const, + Param, any_borrowed, bind_calls, dynamic_bind_statements, dynamic_sql_setup, has_dynamic_slice, + maybe_params_struct, resolve_columns, resolve_params, row_struct, sql_const, }; fn batch_items_type( query_name: &str, - params: &[super::query::Param], + params: &[Param], derives: &[String], ) -> Result<(Option, TokenStream), Error> { if params.len() >= 2 { let (struct_tokens, struct_ident) = maybe_params_struct(query_name, params, derives)? .expect("guarded by params.len() >= 2"); - Ok((Some(struct_tokens), quote! { #struct_ident })) + let item_ty = if any_borrowed(params) { + // The struct carries `<'a>`; the stream's existing `'a` is reused + // by referencing the struct as `StructName<'a>`. + quote! { #struct_ident<'a> } + } else { + quote! { #struct_ident } + }; + Ok((Some(struct_tokens), item_ty)) } else { let p = ¶ms[0]; - let ty: syn::Type = parse_str(&p.resolved.rust_type).map_err(|e| { - Error::Codegen(format!("invalid Rust type '{}': {e}", p.resolved.rust_type)) - })?; + // Single-param batch: if the param is borrowed we need to inject the + // stream's `'a` because `where Item = ...` cannot use elided + // lifetimes. + let ty_str = if let Some(borrowed) = &p.resolved.borrowed_rust_type { + inject_lifetime(borrowed, "'a")? + } else { + p.resolved.rust_type.clone() + }; + let ty: syn::Type = parse_str(&ty_str) + .map_err(|e| Error::Codegen(format!("invalid Rust type '{ty_str}': {e}")))?; Ok((None, quote! { #ty })) } } -fn single_item_alias(params: &[super::query::Param]) -> Option { +fn single_item_alias(params: &[Param]) -> Option { (params.len() == 1).then(|| { let ident = params[0].ident.clone(); quote! { let #ident = item; } @@ -75,7 +90,7 @@ pub fn gen_batchexec( query: &QueryView<'_>, type_map: &TypeMap, config: &Config, - col_overrides: &std::collections::HashMap, + col_overrides: &std::collections::HashMap, ) -> Result { let params = resolve_params(query.params.iter(), type_map, col_overrides)?; if params.is_empty() { @@ -132,7 +147,7 @@ pub fn gen_batchone( query: &QueryView<'_>, type_map: &TypeMap, config: &Config, - col_overrides: &std::collections::HashMap, + col_overrides: &std::collections::HashMap, ) -> Result { let params = resolve_params(query.params.iter(), type_map, col_overrides)?; if params.is_empty() { @@ -192,7 +207,7 @@ pub fn gen_batchmany( query: &QueryView<'_>, type_map: &TypeMap, config: &Config, - col_overrides: &std::collections::HashMap, + col_overrides: &std::collections::HashMap, ) -> Result { let params = resolve_params(query.params.iter(), type_map, col_overrides)?; if params.is_empty() { diff --git a/src/codegen/copyfrom.rs b/src/codegen/copyfrom.rs index cc52e99..46b04cf 100644 --- a/src/codegen/copyfrom.rs +++ b/src/codegen/copyfrom.rs @@ -3,20 +3,21 @@ use quote::{format_ident, quote}; use syn::parse_str; use crate::{ + codegen::lifetimes::inject_lifetime, config::Config, error::Error, ident::to_snake_case, plugin::QueryView, - types::{ResolvedType, TypeMap}, + types::{ColumnOverride, TypeMap}, }; -use super::query::{Param, maybe_params_struct, resolve_params, sql_const}; +use super::query::{Param, any_borrowed, maybe_params_struct, resolve_params, sql_const}; pub fn gen_copyfrom( query: &QueryView<'_>, type_map: &TypeMap, config: &Config, - col_overrides: &std::collections::HashMap, + col_overrides: &std::collections::HashMap, ) -> Result { let params = resolve_params(query.params.iter(), type_map, col_overrides)?; if params.is_empty() { @@ -31,26 +32,44 @@ pub fn gen_copyfrom( let insert_prefix = insert_prefix(query.text)?; let (const_tokens, const_name) = sql_const(query.name, &insert_prefix); let batch_size = std::cmp::max(1usize, 65_535usize / params.len()); + let has_borrowed = any_borrowed(¶ms); let (params_struct, items_ty, builder_binds) = if params.len() >= 2 { let (struct_tokens, struct_ident) = maybe_params_struct(query.name, ¶ms, &config.row_derives)? .expect("guarded by params.len() >= 2"); + let item_ty = if has_borrowed { + quote! { #struct_ident<'a> } + } else { + quote! { #struct_ident } + }; ( Some(struct_tokens), - quote! { #struct_ident }, + item_ty, push_bind_calls(¶ms, Some(&format_ident!("item"))), ) } else { let p = ¶ms[0]; - let ty: syn::Type = parse_str(&p.resolved.rust_type).map_err(|e| { - Error::Codegen(format!("invalid Rust type '{}': {e}", p.resolved.rust_type)) - })?; + let ty_str = if let Some(borrowed) = &p.resolved.borrowed_rust_type { + inject_lifetime(borrowed, "'a")? + } else { + p.resolved.rust_type.clone() + }; + let ty: syn::Type = parse_str(&ty_str) + .map_err(|e| Error::Codegen(format!("invalid Rust type '{ty_str}': {e}")))?; (None, quote! { #ty }, push_bind_calls(¶ms, None)) }; + // The `where Item = ...` clause cannot use anonymous lifetimes, so we + // introduce `'a` to the fn signature whenever any param is borrowed. + let generics = if has_borrowed { + quote! { <'a, E: AsExecutor, I> } + } else { + quote! { } + }; + let fn_tokens = quote! { - pub async fn #fn_name(mut db: E, items: I) -> Result + pub async fn #fn_name #generics (mut db: E, items: I) -> Result where I: IntoIterator, { diff --git a/src/codegen/lifetimes.rs b/src/codegen/lifetimes.rs new file mode 100644 index 0000000..77492f2 --- /dev/null +++ b/src/codegen/lifetimes.rs @@ -0,0 +1,83 @@ +use crate::error::Error; +use syn::visit_mut::VisitMut; + +/// Parse a borrowed Rust type string (e.g. `&str`, `Option<&str>`, +/// `&[String]`, `Option<&[i64]>`) and replace every anonymous or `'_` +/// reference lifetime with the named lifetime `lifetime`. +/// +/// Used when emitting borrowed types into positions that cannot rely on +/// elision: struct field types and `where Item = ...` clauses. +pub fn inject_lifetime(rust_type: &str, lifetime: &str) -> Result { + let mut ty: syn::Type = syn::parse_str(rust_type) + .map_err(|e| Error::Codegen(format!("invalid Rust type '{rust_type}': {e}")))?; + let mut visitor = LifetimeInjector { + lifetime: syn::Lifetime::new(lifetime, proc_macro2::Span::call_site()), + }; + visitor.visit_type_mut(&mut ty); + Ok(quote::quote!(#ty).to_string()) +} + +struct LifetimeInjector { + lifetime: syn::Lifetime, +} + +impl VisitMut for LifetimeInjector { + fn visit_type_reference_mut(&mut self, node: &mut syn::TypeReference) { + let needs_inject = match &node.lifetime { + None => true, + Some(lt) => lt.ident == "_", + }; + if needs_inject { + node.lifetime = Some(self.lifetime.clone()); + } + syn::visit_mut::visit_type_reference_mut(self, node); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn inject(s: &str) -> String { + inject_lifetime(s, "'a") + .unwrap() + .split_whitespace() + .collect::>() + .join(" ") + } + + #[test] + fn injects_into_bare_reference() { + assert_eq!(inject("&str"), "& 'a str"); + } + + #[test] + fn injects_into_anonymous_lifetime() { + assert_eq!(inject("&'_ str"), "& 'a str"); + } + + #[test] + fn injects_into_slice() { + assert_eq!(inject("&[String]"), "& 'a [String]"); + } + + #[test] + fn injects_inside_option() { + assert_eq!(inject("Option<&str>"), "Option < & 'a str >"); + } + + #[test] + fn injects_inside_option_slice() { + assert_eq!(inject("Option<&[i64]>"), "Option < & 'a [i64] >"); + } + + #[test] + fn leaves_named_lifetime_alone() { + assert_eq!(inject("&'b str"), "& 'b str"); + } + + #[test] + fn injects_inside_nested_borrowed_slice() { + assert_eq!(inject("&[&str]"), "& 'a [& 'a str]"); + } +} diff --git a/src/codegen/mod.rs b/src/codegen/mod.rs index 0ff98d2..1c4cb1f 100644 --- a/src/codegen/mod.rs +++ b/src/codegen/mod.rs @@ -7,6 +7,7 @@ mod batch; mod composites; mod copyfrom; mod enums; +pub(crate) mod lifetimes; mod query; pub fn generate(request: &GenerateRequestView<'_>, config: &Config) -> Result { diff --git a/src/codegen/query.rs b/src/codegen/query.rs index 3d95ecb..15e90c2 100644 --- a/src/codegen/query.rs +++ b/src/codegen/query.rs @@ -3,11 +3,12 @@ use quote::{format_ident, quote}; use syn::parse_str; use crate::{ + codegen::lifetimes::inject_lifetime, config::Config, error::Error, ident::{field_ident, query_params_name, to_pascal_case, to_snake_case, type_ident}, plugin::{ColumnView, ParameterView, QueryView}, - types::{ResolvedType, TypeMap}, + types::{ColumnOverride, ResolvedType, TypeMap}, }; /// Resolved parameter: Rust identifier + type. @@ -19,6 +20,21 @@ pub(crate) struct Param { pub(crate) resolved: ResolvedType, } +impl Param { + /// Parameter-position type: borrowed form if the resolver produced one, + /// otherwise the owned form. Lifetimes are anonymous (`&str`). + pub(crate) fn param_type(&self) -> &str { + self.resolved + .borrowed_rust_type + .as_deref() + .unwrap_or(&self.resolved.rust_type) + } + + pub(crate) fn is_borrowed(&self) -> bool { + self.resolved.borrowed_rust_type.is_some() + } +} + /// Columns from an embedded table (`sqlc.embed(table)`), grouped together. pub(crate) struct EmbeddedGroup { /// snake_case field name in the parent struct, e.g. `author` @@ -40,7 +56,7 @@ pub(crate) struct ResolvedColumnSet { pub(crate) fn resolve_params<'a>( params: impl Iterator>, type_map: &TypeMap, - col_overrides: &std::collections::HashMap, + col_overrides: &std::collections::HashMap, ) -> Result, Error> { let mut out = Vec::new(); for p in params { @@ -86,7 +102,14 @@ pub(crate) fn resolve_params<'a>( Ok(out) } -/// Emit a params struct when the query has ≥2 parameters. +/// Whether any param in the set carries a borrowed type. +pub(crate) fn any_borrowed(params: &[Param]) -> bool { + params.iter().any(Param::is_borrowed) +} + +/// Emit a params struct when the query has ≥2 parameters. When any field is +/// borrowed, the struct gains a `<'a>` lifetime parameter and each borrowed +/// field references it. pub(crate) fn maybe_params_struct( query_name: &str, params: &[Param], @@ -96,12 +119,17 @@ pub(crate) fn maybe_params_struct( return Ok(None); } let struct_name = type_ident(&query_params_name(query_name)); + let has_borrowed = any_borrowed(params); let mut field_tokens = Vec::new(); for p in params { let ident = &p.ident; - let ty: syn::Type = parse_str(&p.resolved.rust_type).map_err(|e| { - Error::Codegen(format!("invalid Rust type '{}': {e}", p.resolved.rust_type)) - })?; + let ty_str = if let Some(borrowed) = &p.resolved.borrowed_rust_type { + inject_lifetime(borrowed, "'a")? + } else { + p.resolved.rust_type.clone() + }; + let ty: syn::Type = parse_str(&ty_str) + .map_err(|e| Error::Codegen(format!("invalid Rust type '{ty_str}': {e}")))?; field_tokens.push(quote! { pub #ident: #ty, }); } let mut derive_paths = Vec::new(); @@ -110,9 +138,14 @@ pub(crate) fn maybe_params_struct( parse_str(d).map_err(|e| Error::Codegen(format!("invalid derive path '{d}': {e}")))?; derive_paths.push(quote! { #path }); } + let generics = if has_borrowed { + quote! { <'a> } + } else { + quote! {} + }; let tokens = quote! { #[derive(Debug, Clone, #(#derive_paths),*)] - pub struct #struct_name { + pub struct #struct_name #generics { #(#field_tokens)* } }; @@ -292,11 +325,13 @@ pub(crate) fn sql_const(query_name: &str, sql: &str) -> (TokenStream, proc_macro (tokens, const_name) } -/// Resolve result columns into flat fields and embedded groups. +/// Resolve result columns into flat fields and embedded groups. Row positions +/// always use the owned form, so any `borrowed_rust_type` on the resolution +/// is intentionally ignored downstream. pub(crate) fn resolve_columns<'a>( cols: impl Iterator>, type_map: &TypeMap, - col_overrides: &std::collections::HashMap, + col_overrides: &std::collections::HashMap, ) -> Result { let mut flat: Vec<(proc_macro2::Ident, ResolvedType)> = Vec::new(); let mut embedded_groups: Vec<(String, Vec<(proc_macro2::Ident, ResolvedType)>)> = Vec::new(); @@ -408,6 +443,11 @@ pub(crate) fn row_struct( /// Build the parameters portion of a query function signature. /// Returns `(params_struct_tokens, arg_ident, fn_params_tokens)`. +/// +/// When the query has a params struct and any field is borrowed, the +/// fn-signature reference uses `<'_>` to elide the struct's `'a` (Rust 2024 +/// anonymous lifetime in type paths). Single-param functions rely on +/// classical lifetime elision and take the borrowed type directly. pub(crate) fn build_fn_params( query_name: &str, params: &[Param], @@ -418,17 +458,22 @@ pub(crate) fn build_fn_params( let (struct_tokens, struct_ident) = maybe_params_struct(query_name, params, derives)? .expect("guarded by params.len() >= 2 check above"); let arg = format_ident!("arg"); + let arg_ty = if any_borrowed(params) { + quote! { #struct_ident<'_> } + } else { + quote! { #struct_ident } + }; Ok(( Some(struct_tokens), Some(arg.clone()), - quote! { #arg: #struct_ident }, + quote! { #arg: #arg_ty }, )) } else if params.len() == 1 { let p = ¶ms[0]; let ident = &p.ident; - let ty: syn::Type = parse_str(&p.resolved.rust_type).map_err(|e| { - Error::Codegen(format!("invalid Rust type '{}': {e}", p.resolved.rust_type)) - })?; + let ty_str = p.param_type(); + let ty: syn::Type = parse_str(ty_str) + .map_err(|e| Error::Codegen(format!("invalid Rust type '{ty_str}': {e}")))?; Ok((None, None, quote! { #ident: #ty })) } else { Ok((None, None, quote! {})) @@ -440,7 +485,7 @@ pub fn gen_one( query: &QueryView<'_>, type_map: &TypeMap, config: &Config, - col_overrides: &std::collections::HashMap, + col_overrides: &std::collections::HashMap, ) -> Result { let params = resolve_params(query.params.iter(), type_map, col_overrides)?; let columns = resolve_columns(query.columns.iter(), type_map, col_overrides)?; @@ -492,7 +537,7 @@ pub fn gen_many( query: &QueryView<'_>, type_map: &TypeMap, config: &Config, - col_overrides: &std::collections::HashMap, + col_overrides: &std::collections::HashMap, ) -> Result { let params = resolve_params(query.params.iter(), type_map, col_overrides)?; let columns = resolve_columns(query.columns.iter(), type_map, col_overrides)?; @@ -542,7 +587,7 @@ pub fn gen_execrows( query: &QueryView<'_>, type_map: &TypeMap, config: &Config, - col_overrides: &std::collections::HashMap, + col_overrides: &std::collections::HashMap, ) -> Result { let params = resolve_params(query.params.iter(), type_map, col_overrides)?; let fn_name = format_ident!("{}", to_snake_case(query.name)); @@ -582,7 +627,7 @@ pub fn gen_execresult( query: &QueryView<'_>, type_map: &TypeMap, config: &Config, - col_overrides: &std::collections::HashMap, + col_overrides: &std::collections::HashMap, ) -> Result { let params = resolve_params(query.params.iter(), type_map, col_overrides)?; let fn_name = format_ident!("{}", to_snake_case(query.name)); @@ -620,7 +665,7 @@ pub fn gen_exec( query: &QueryView<'_>, type_map: &TypeMap, config: &Config, - col_overrides: &std::collections::HashMap, + col_overrides: &std::collections::HashMap, ) -> Result { let params = resolve_params(query.params.iter(), type_map, col_overrides)?; let fn_name = format_ident!("{}", to_snake_case(query.name)); @@ -670,7 +715,7 @@ pub fn gen_execlastid( query: &QueryView<'_>, type_map: &TypeMap, config: &Config, - col_overrides: &std::collections::HashMap, + col_overrides: &std::collections::HashMap, ) -> Result { let params = resolve_params(query.params.iter(), type_map, col_overrides)?; let fn_name = format_ident!("{}", to_snake_case(query.name)); diff --git a/src/config.rs b/src/config.rs index d0b4e02..b3cccda 100644 --- a/src/config.rs +++ b/src/config.rs @@ -26,15 +26,62 @@ impl Default for Config { impl Config { pub fn from_bytes(bytes: &[u8]) -> Result { - Ok(serde_json::from_slice(bytes)?) + let cfg: Config = serde_json::from_slice(bytes)?; + for o in &cfg.overrides { + let target = || { + o.db_type + .as_deref() + .or(o.column.as_deref()) + .unwrap_or("") + .to_string() + }; + if o.rs_type.is_none() && o.borrowed_rs_type.is_none() { + return Err(Error::Codegen(format!( + "override for '{}' must set at least one of 'rs_type' or 'borrowed_rs_type'", + target() + ))); + } + if let Some(borrowed) = &o.borrowed_rs_type { + validate_borrowed_type(borrowed, &target())?; + } + } + Ok(cfg) + } +} + +/// `borrowed_rs_type` must parse as a Rust type AND contain at least one +/// reference (`&T`). Without a reference there is nothing for the codegen's +/// lifetime injector to do — the field name is then misleading and the user +/// has likely written the wrong thing. +fn validate_borrowed_type(rs_type: &str, target: &str) -> Result<(), Error> { + syn::parse_str::(rs_type).map_err(|e| { + Error::Codegen(format!( + "override for '{target}': borrowed_rs_type '{rs_type}' is not a valid Rust type: {e}" + )) + })?; + if !rs_type.contains('&') { + return Err(Error::Codegen(format!( + "override for '{target}': borrowed_rs_type '{rs_type}' must contain a reference \ + (e.g. '&str' or 'Option<&str>'); use 'rs_type' for owned types" + ))); } + Ok(()) } #[derive(Debug, serde::Deserialize)] pub struct TypeOverride { pub db_type: Option, pub column: Option, - pub rs_type: String, + /// Owned Rust type for rows and array contents. Optional when + /// `borrowed_rs_type` is set; missing values fall back to the built-in + /// default for the matched PG type. + pub rs_type: Option, + /// Borrowed Rust type for scalar parameter positions. When present, the + /// override participates in borrowed mode: parameter signatures use this + /// type (with lifetime injection where the position requires a named + /// lifetime), while rows and array contents continue to use the owned + /// form. + pub borrowed_rs_type: Option, #[serde(default)] pub copy_cheap: bool, } @@ -63,7 +110,11 @@ mod tests { let c = Config::from_bytes(json).unwrap(); assert_eq!(c.overrides.len(), 1); assert_eq!(c.overrides[0].db_type, Some("timestamptz".to_string())); - assert_eq!(c.overrides[0].rs_type, "chrono::DateTime"); + assert_eq!( + c.overrides[0].rs_type.as_deref(), + Some("chrono::DateTime") + ); + assert!(c.overrides[0].borrowed_rs_type.is_none()); } #[test] @@ -81,4 +132,51 @@ mod tests { assert_eq!(c.row_derives, ["serde::Serialize"]); assert_eq!(c.enum_derives.len(), 2); } + + #[test] + fn parses_borrowed_only_override() { + let json = br#"{"overrides":[{"db_type":"text","borrowed_rs_type":"&str"}]}"#; + let c = Config::from_bytes(json).unwrap(); + assert_eq!(c.overrides.len(), 1); + assert!(c.overrides[0].rs_type.is_none()); + assert_eq!(c.overrides[0].borrowed_rs_type.as_deref(), Some("&str")); + } + + #[test] + fn parses_owned_and_borrowed_override() { + let json = + br#"{"overrides":[{"db_type":"text","rs_type":"MyStr","borrowed_rs_type":"&MyStr"}]}"#; + let c = Config::from_bytes(json).unwrap(); + assert_eq!(c.overrides[0].rs_type.as_deref(), Some("MyStr")); + assert_eq!(c.overrides[0].borrowed_rs_type.as_deref(), Some("&MyStr")); + } + + #[test] + fn rejects_non_borrowed_in_borrowed_rs_type() { + let json = br#"{"overrides":[{"db_type":"text","borrowed_rs_type":"Option"}]}"#; + let err = Config::from_bytes(json).unwrap_err(); + let msg = err.to_string(); + assert!( + msg.contains("must contain a reference"), + "expected reference-required error, got: {msg}" + ); + assert!(msg.contains("text"), "expected target name in error: {msg}"); + } + + #[test] + fn rejects_invalid_rust_in_borrowed_rs_type() { + let json = br#"{"overrides":[{"db_type":"text","borrowed_rs_type":"¬ a type!!"}]}"#; + let err = Config::from_bytes(json).unwrap_err(); + let msg = err.to_string(); + assert!( + msg.contains("not a valid Rust type"), + "expected parse error, got: {msg}" + ); + } + + #[test] + fn accepts_option_of_reference_in_borrowed_rs_type() { + let json = br#"{"overrides":[{"db_type":"text","borrowed_rs_type":"Option<&str>"}]}"#; + Config::from_bytes(json).expect("Option<&str> should be accepted as borrowed"); + } } diff --git a/src/types.rs b/src/types.rs index 403bd47..bb9269f 100644 --- a/src/types.rs +++ b/src/types.rs @@ -3,23 +3,36 @@ use std::collections::HashMap; #[derive(Debug, Clone)] pub struct ResolvedType { + /// Owned Rust type used for row struct fields, array contents, and any + /// position that takes the value by value. Always present. pub rust_type: String, + /// Borrowed Rust type used for scalar parameter positions when the + /// override opted into borrowed mode. `None` means the position should + /// use `rust_type`. + /// + /// Lifetimes in this string are anonymous (`&str`, not `&'a str`); the + /// codegen runs a lifetime-injection pass when emitting into positions + /// that require a named lifetime (struct field, where clause). + pub borrowed_rust_type: Option, pub copy_cheap: bool, } -impl ResolvedType { - fn new(rust_type: impl Into, copy_cheap: bool) -> Self { - Self { - rust_type: rust_type.into(), - copy_cheap, - } - } +/// Stored representation of a user override: both forms tracked independently +/// so the resolver can pick the right one for each context. +#[derive(Debug, Clone)] +struct OverrideEntry { + /// Owned form: `None` means "use the built-in default for the PG type". + owned: Option, + /// Borrowed form: `None` means the override did not opt into borrowed + /// mode. + borrowed: Option, + copy_cheap: bool, } pub struct TypeMap { defaults: HashMap<&'static str, (&'static str, bool)>, - type_overrides: HashMap, - custom_types: HashMap, + type_overrides: HashMap, + custom_types: HashMap, } impl TypeMap { @@ -192,12 +205,16 @@ impl TypeMap { defaults.insert(n, ("bit_vec::BitVec", false)); } - let mut type_overrides = HashMap::new(); + let mut type_overrides: HashMap = HashMap::new(); for o in overrides { if let Some(db_type) = &o.db_type { type_overrides.insert( db_type.to_lowercase(), - ResolvedType::new(o.rs_type.clone(), o.copy_cheap), + OverrideEntry { + owned: o.rs_type.clone(), + borrowed: o.borrowed_rs_type.clone(), + copy_cheap: o.copy_cheap, + }, ); } } @@ -206,8 +223,15 @@ impl TypeMap { let key = name.to_lowercase(); if let Some(ovr) = type_overrides.get_mut(&key) { ovr.copy_cheap = true; - } else if let Some(&(ty, _)) = defaults.get(key.as_str()) { - type_overrides.insert(key, ResolvedType::new(ty.to_string(), true)); + } else if defaults.contains_key(key.as_str()) { + type_overrides.insert( + key, + OverrideEntry { + owned: None, + borrowed: None, + copy_cheap: true, + }, + ); } } @@ -222,10 +246,8 @@ impl TypeMap { /// Registered types are checked after both `type_overrides` and `defaults`, /// so user-level overrides and built-in defaults always take precedence. pub fn register(&mut self, pg_name: &str, rust_name: &str, copy_cheap: bool) { - self.custom_types.insert( - pg_name.to_lowercase(), - ResolvedType::new(rust_name.to_string(), copy_cheap), - ); + self.custom_types + .insert(pg_name.to_lowercase(), (rust_name.to_string(), copy_cheap)); } pub fn resolve_pg_type( @@ -244,20 +266,38 @@ impl TypeMap { array_dims: usize, ) -> Option { let key = pg_type.to_lowercase(); - let (inner, copy_cheap) = if let Some(ovr) = self.type_overrides.get(&key) { - (ovr.rust_type.clone(), ovr.copy_cheap) - } else if let Some(&(ty, cc)) = self.defaults.get(key.as_str()) { - (ty.to_string(), cc) - } else if let Some(custom) = self.custom_types.get(&key) { - (custom.rust_type.clone(), custom.copy_cheap) - } else { - return None; - }; - - let rust_type = wrap_type(&inner, nullable, array_dims); + let (owned_inner, borrowed_inner, copy_cheap) = + if let Some(ovr) = self.type_overrides.get(&key) { + let default = self.defaults.get(key.as_str()).map(|&(t, _)| t.to_string()); + let owned = ovr + .owned + .clone() + .or(default) + .or_else(|| self.custom_types.get(&key).map(|(name, _)| name.clone())); + (owned, ovr.borrowed.clone(), ovr.copy_cheap) + } else if let Some(&(ty, cc)) = self.defaults.get(key.as_str()) { + (Some(ty.to_string()), None, cc) + } else if let Some((name, cc)) = self.custom_types.get(&key) { + (Some(name.clone()), None, *cc) + } else { + return None; + }; + + let owned = wrap_owned(owned_inner.as_deref()?, nullable, array_dims); + let borrowed = borrowed_inner.map(|inner| { + // Array wraps revert inner to owned default, and the outermost + // wrapper becomes a borrowed slice. + wrap_borrowed( + &inner, + owned_inner.as_deref().unwrap_or(""), + nullable, + array_dims, + ) + }); let effective_copy_cheap = copy_cheap && !nullable && array_dims == 0; Some(ResolvedType { - rust_type, + rust_type: owned, + borrowed_rust_type: borrowed, copy_cheap: effective_copy_cheap, }) } @@ -268,7 +308,7 @@ impl TypeMap { nullable: bool, is_array: bool, column_key: Option<&str>, - column_overrides: &HashMap, + column_overrides: &HashMap, ) -> Option { self.resolve_column_dims( pg_type, @@ -285,15 +325,30 @@ impl TypeMap { nullable: bool, array_dims: usize, column_key: Option<&str>, - column_overrides: &HashMap, + column_overrides: &HashMap, ) -> Option { if let Some(key) = column_key && let Some(ovr) = column_overrides.get(key) { - let rust_type = wrap_type(&ovr.rust_type, nullable, array_dims); + // Owned form for the column: explicit `rs_type` wins; otherwise + // fall back to the type-level resolution (which itself may have + // an override or fall back to a default). + let owned_inner = if let Some(owned) = &ovr.owned { + owned.clone() + } else if let Some(resolved) = self.resolve_pg_type_dims(pg_type, false, 0) { + resolved.rust_type + } else { + return None; + }; + let owned = wrap_owned(&owned_inner, nullable, array_dims); + let borrowed = ovr + .borrowed + .as_ref() + .map(|b| wrap_borrowed(b, &owned_inner, nullable, array_dims)); let cc = ovr.copy_cheap && !nullable && array_dims == 0; return Some(ResolvedType { - rust_type, + rust_type: owned, + borrowed_rust_type: borrowed, copy_cheap: cc, }); } @@ -301,7 +356,26 @@ impl TypeMap { } } -fn wrap_type(inner: &str, nullable: bool, array_dims: usize) -> String { +/// Per-column override stored separately from the type-level map. +#[derive(Debug, Clone)] +pub struct ColumnOverride { + pub owned: Option, + pub borrowed: Option, + pub copy_cheap: bool, +} + +impl ColumnOverride { + #[cfg(test)] + pub fn owned_form(rust_type: impl Into, copy_cheap: bool) -> Self { + Self { + owned: Some(rust_type.into()), + borrowed: None, + copy_cheap, + } + } +} + +fn wrap_owned(inner: &str, nullable: bool, array_dims: usize) -> String { let mut t = inner.to_string(); for _ in 0..array_dims { t = format!("Vec<{t}>"); @@ -309,14 +383,42 @@ fn wrap_type(inner: &str, nullable: bool, array_dims: usize) -> String { if nullable { format!("Option<{t}>") } else { t } } -pub fn build_column_overrides(overrides: &[TypeOverride]) -> HashMap { +fn wrap_borrowed( + borrowed_inner: &str, + owned_inner: &str, + nullable: bool, + array_dims: usize, +) -> String { + let body = if array_dims == 0 { + borrowed_inner.to_string() + } else { + let mut t = owned_inner.to_string(); + // Inner array dimensions stay as owned `Vec<...>`. + for _ in 0..(array_dims - 1) { + t = format!("Vec<{t}>"); + } + // The outermost array becomes a borrowed slice. + format!("&[{t}]") + }; + if nullable { + format!("Option<{body}>") + } else { + body + } +} + +pub fn build_column_overrides(overrides: &[TypeOverride]) -> HashMap { overrides .iter() .filter_map(|o| { o.column.as_ref().map(|col| { ( col.clone(), - ResolvedType::new(o.rs_type.clone(), o.copy_cheap), + ColumnOverride { + owned: o.rs_type.clone(), + borrowed: o.borrowed_rs_type.clone(), + copy_cheap: o.copy_cheap, + }, ) }) }) @@ -331,10 +433,21 @@ mod tests { TypeMap::new(&[], &[]) } + fn owned_override(db_type: &str, rs_type: &str) -> TypeOverride { + TypeOverride { + db_type: Some(db_type.to_string()), + column: None, + rs_type: Some(rs_type.to_string()), + borrowed_rs_type: None, + copy_cheap: false, + } + } + #[test] fn maps_text() { let t = map().resolve_pg_type("text", false, false).unwrap(); assert_eq!(t.rust_type, "String"); + assert!(t.borrowed_rust_type.is_none()); assert!(!t.copy_cheap); } #[test] @@ -394,32 +507,24 @@ mod tests { } #[test] fn type_override_replaces_default() { - use crate::config::TypeOverride; - let ovr = TypeOverride { - db_type: Some("timestamptz".to_string()), - column: None, - rs_type: "time::OffsetDateTime".to_string(), - copy_cheap: false, - }; - let t = TypeMap::new(&[ovr], &[]) - .resolve_pg_type("timestamptz", false, false) - .unwrap(); + let t = TypeMap::new( + &[owned_override("timestamptz", "time::OffsetDateTime")], + &[], + ) + .resolve_pg_type("timestamptz", false, false) + .unwrap(); assert_eq!(t.rust_type, "time::OffsetDateTime"); + assert!(t.borrowed_rust_type.is_none()); } #[test] fn column_override_beats_type_override() { - use crate::config::TypeOverride; let overrides = vec![ - TypeOverride { - db_type: Some("text".to_string()), - column: None, - rs_type: "TypeLevel".to_string(), - copy_cheap: false, - }, + owned_override("text", "TypeLevel"), TypeOverride { db_type: None, column: Some("users.name".to_string()), - rs_type: "ColumnLevel".to_string(), + rs_type: Some("ColumnLevel".to_string()), + borrowed_rs_type: None, copy_cheap: false, }, ]; @@ -474,14 +579,7 @@ mod tests { } #[test] fn type_override_beats_registered_custom() { - use crate::config::TypeOverride; - let ovr = TypeOverride { - db_type: Some("my_enum".to_string()), - column: None, - rs_type: "Override".to_string(), - copy_cheap: false, - }; - let mut map = TypeMap::new(&[ovr], &[]); + let mut map = TypeMap::new(&[owned_override("my_enum", "Override")], &[]); map.register("my_enum", "MyEnum", false); let t = map.resolve_pg_type("my_enum", false, false).unwrap(); // type_overrides must win over custom_types @@ -542,4 +640,94 @@ mod tests { let t = map.resolve_pg_type("uuid", false, false).unwrap(); assert!(t.copy_cheap); } + + // --- Borrowed mode ----------------------------------------------------- + + fn borrowed_text() -> TypeOverride { + TypeOverride { + db_type: Some("text".to_string()), + column: None, + rs_type: None, + borrowed_rs_type: Some("&str".to_string()), + copy_cheap: false, + } + } + + #[test] + fn borrowed_only_override_keeps_default_for_owned_form() { + let map = TypeMap::new(&[borrowed_text()], &[]); + let t = map.resolve_pg_type("text", false, false).unwrap(); + assert_eq!(t.rust_type, "String"); + assert_eq!(t.borrowed_rust_type.as_deref(), Some("&str")); + } + + #[test] + fn borrowed_nullable_wraps_inside_option() { + let map = TypeMap::new(&[borrowed_text()], &[]); + let t = map.resolve_pg_type("text", true, false).unwrap(); + assert_eq!(t.rust_type, "Option"); + assert_eq!(t.borrowed_rust_type.as_deref(), Some("Option<&str>")); + } + + #[test] + fn borrowed_array_uses_owned_inner_in_slice() { + let map = TypeMap::new(&[borrowed_text()], &[]); + let t = map.resolve_pg_type("text", false, true).unwrap(); + assert_eq!(t.rust_type, "Vec"); + assert_eq!(t.borrowed_rust_type.as_deref(), Some("&[String]")); + } + + #[test] + fn borrowed_nullable_array() { + let map = TypeMap::new(&[borrowed_text()], &[]); + let t = map.resolve_pg_type("text", true, true).unwrap(); + assert_eq!(t.rust_type, "Option>"); + assert_eq!(t.borrowed_rust_type.as_deref(), Some("Option<&[String]>")); + } + + #[test] + fn borrowed_multidim_array_inner_stays_vec() { + let map = TypeMap::new(&[borrowed_text()], &[]); + let t = map.resolve_pg_type_dims("text", false, 2).unwrap(); + assert_eq!(t.rust_type, "Vec>"); + assert_eq!(t.borrowed_rust_type.as_deref(), Some("&[Vec]")); + } + + #[test] + fn owned_and_borrowed_pair() { + let ovr = TypeOverride { + db_type: Some("text".to_string()), + column: None, + rs_type: Some("MyStr".to_string()), + borrowed_rs_type: Some("&MyStr".to_string()), + copy_cheap: false, + }; + let map = TypeMap::new(&[ovr], &[]); + let t = map.resolve_pg_type("text", false, false).unwrap(); + assert_eq!(t.rust_type, "MyStr"); + assert_eq!(t.borrowed_rust_type.as_deref(), Some("&MyStr")); + + let t = map.resolve_pg_type("text", false, true).unwrap(); + // Array uses owned inner from the override. + assert_eq!(t.rust_type, "Vec"); + assert_eq!(t.borrowed_rust_type.as_deref(), Some("&[MyStr]")); + } + + #[test] + fn borrowed_column_override() { + let overrides = vec![TypeOverride { + db_type: None, + column: Some("users.name".to_string()), + rs_type: None, + borrowed_rs_type: Some("&str".to_string()), + copy_cheap: false, + }]; + let col_ovrs = build_column_overrides(&overrides); + let map = TypeMap::new(&overrides, &[]); + let t = map + .resolve_column("text", false, false, Some("users.name"), &col_ovrs) + .unwrap(); + assert_eq!(t.rust_type, "String"); + assert_eq!(t.borrowed_rust_type.as_deref(), Some("&str")); + } } diff --git a/tests/codegen.rs b/tests/codegen.rs index c4df9ac..bb7a2e9 100644 --- a/tests/codegen.rs +++ b/tests/codegen.rs @@ -34,6 +34,10 @@ fn authors_table() -> Table { } fn make_request(queries: Vec) -> Vec { + make_request_with_options(queries, b"{}") +} + +fn make_request_with_options(queries: Vec, options: &[u8]) -> Vec { let req = GenerateRequest { catalog: Some(Catalog { default_schema: "public".to_string(), @@ -46,7 +50,7 @@ fn make_request(queries: Vec) -> Vec { }) .into(), queries, - plugin_options: b"{}".to_vec(), + plugin_options: options.to_vec(), sqlc_version: "1.0.0-test".to_string(), ..Default::default() }; @@ -1021,3 +1025,322 @@ fn snapshot_batch_dynamic_slice_param() { "expected per-element binding inside batch stream in:\n{code}" ); } + +// ---- Borrowed mode -------------------------------------------------------- + +const BORROWED_TEXT_OPTIONS: &[u8] = + br#"{"overrides":[{"db_type":"text","borrowed_rs_type":"&str"}]}"#; + +#[test] +fn snapshot_borrowed_scalar() { + let query = Query { + name: "DeleteByName".to_string(), + cmd: ":exec".to_string(), + text: "DELETE FROM authors WHERE name = $1".to_string(), + params: vec![Parameter { + number: 1, + column: Some(make_author_column("name", "text", true)).into(), + ..Default::default() + }], + ..Default::default() + }; + let bytes = make_request_with_options(vec![query], BORROWED_TEXT_OPTIONS); + let out = run_with_bytes(&bytes).expect("generate failed"); + let resp = sqlc_gen_sqlx::plugin::GenerateResponse::decode_from_slice(&out).unwrap(); + let code = String::from_utf8(resp.files[0].contents.clone()).unwrap(); + assert_codegen_snapshot("borrowed_scalar", &code); + assert!( + code.contains("name: &str"), + "expected name: &str (elided) in:\n{code}" + ); + assert!( + !code.contains("DeleteByNameParams"), + "single param must not emit a params struct in:\n{code}" + ); +} + +#[test] +fn snapshot_borrowed_params_struct() { + let query = Query { + name: "UpdateAuthor".to_string(), + cmd: ":exec".to_string(), + text: "UPDATE authors SET name = $1, bio = $2".to_string(), + params: vec![ + Parameter { + number: 1, + column: Some(make_author_column("name", "text", true)).into(), + ..Default::default() + }, + Parameter { + number: 2, + column: Some(make_author_column("bio", "text", false)).into(), + ..Default::default() + }, + ], + ..Default::default() + }; + let bytes = make_request_with_options(vec![query], BORROWED_TEXT_OPTIONS); + let out = run_with_bytes(&bytes).expect("generate failed"); + let resp = sqlc_gen_sqlx::plugin::GenerateResponse::decode_from_slice(&out).unwrap(); + let code = String::from_utf8(resp.files[0].contents.clone()).unwrap(); + assert_codegen_snapshot("borrowed_params_struct", &code); + assert!( + code.contains("pub struct UpdateAuthorParams<'a>"), + "expected struct with <'a> in:\n{code}" + ); + assert!( + code.contains("pub name: &'a str"), + "expected &'a str field in:\n{code}" + ); + assert!( + code.contains("pub bio: Option<&'a str>"), + "expected Option<&'a str> field in:\n{code}" + ); + assert!( + code.contains("arg: UpdateAuthorParams<'_>"), + "expected arg: UpdateAuthorParams<'_> (elided) in:\n{code}" + ); +} + +#[test] +fn snapshot_borrowed_array_param() { + let array_col = Column { + name: "tags".to_string(), + r#type: Some(Identifier { + name: "text".to_string(), + ..Default::default() + }) + .into(), + not_null: true, + is_array: true, + ..Default::default() + }; + let query = Query { + name: "ListByTags".to_string(), + cmd: ":exec".to_string(), + text: "DELETE FROM authors WHERE name = ANY($1)".to_string(), + params: vec![Parameter { + number: 1, + column: Some(array_col).into(), + ..Default::default() + }], + ..Default::default() + }; + let bytes = make_request_with_options(vec![query], BORROWED_TEXT_OPTIONS); + let out = run_with_bytes(&bytes).expect("generate failed"); + let resp = sqlc_gen_sqlx::plugin::GenerateResponse::decode_from_slice(&out).unwrap(); + let code = String::from_utf8(resp.files[0].contents.clone()).unwrap(); + assert_codegen_snapshot("borrowed_array_param", &code); + assert!( + code.contains("tags: &[String]"), + "expected &[String] for borrowed array of text in:\n{code}" + ); +} + +#[test] +fn snapshot_borrowed_slice_param() { + let slice_col = Column { + name: "names".to_string(), + r#type: Some(Identifier { + name: "text".to_string(), + ..Default::default() + }) + .into(), + not_null: true, + is_sqlc_slice: true, + ..Default::default() + }; + let query = Query { + name: "DeleteByNames".to_string(), + cmd: ":exec".to_string(), + text: "DELETE FROM authors WHERE name = ANY($1)".to_string(), + params: vec![Parameter { + number: 1, + column: Some(slice_col).into(), + ..Default::default() + }], + ..Default::default() + }; + let bytes = make_request_with_options(vec![query], BORROWED_TEXT_OPTIONS); + let out = run_with_bytes(&bytes).expect("generate failed"); + let resp = sqlc_gen_sqlx::plugin::GenerateResponse::decode_from_slice(&out).unwrap(); + let code = String::from_utf8(resp.files[0].contents.clone()).unwrap(); + assert_codegen_snapshot("borrowed_slice_param", &code); + assert!( + code.contains("names: &[String]"), + "expected &[String] for sqlc.slice(text) in:\n{code}" + ); +} + +#[test] +fn snapshot_borrowed_one_row_unaffected() { + let query = Query { + name: "GetAuthorByName".to_string(), + cmd: ":one".to_string(), + text: "SELECT id, name, bio FROM authors WHERE name = $1".to_string(), + columns: vec![ + make_author_column("id", "int8", true), + make_author_column("name", "text", true), + make_author_column("bio", "text", false), + ], + params: vec![Parameter { + number: 1, + column: Some(make_author_column("name", "text", true)).into(), + ..Default::default() + }], + ..Default::default() + }; + let bytes = make_request_with_options(vec![query], BORROWED_TEXT_OPTIONS); + let out = run_with_bytes(&bytes).expect("generate failed"); + let resp = sqlc_gen_sqlx::plugin::GenerateResponse::decode_from_slice(&out).unwrap(); + let code = String::from_utf8(resp.files[0].contents.clone()).unwrap(); + assert_codegen_snapshot("borrowed_one_row_unaffected", &code); + assert!(code.contains("name: &str"), "param uses &str in:\n{code}"); + assert!( + code.contains("pub name: String"), + "row field stays owned String in:\n{code}" + ); + assert!( + code.contains("pub bio: Option"), + "nullable row field stays Option in:\n{code}" + ); +} + +#[test] +fn snapshot_borrowed_batchexec() { + let query = Query { + name: "BatchDeleteByName".to_string(), + cmd: ":batchexec".to_string(), + text: "DELETE FROM authors WHERE name = $1".to_string(), + params: vec![Parameter { + number: 1, + column: Some(make_author_column("name", "text", true)).into(), + ..Default::default() + }], + ..Default::default() + }; + let bytes = make_request_with_options(vec![query], BORROWED_TEXT_OPTIONS); + let out = run_with_bytes(&bytes).expect("generate failed"); + let resp = sqlc_gen_sqlx::plugin::GenerateResponse::decode_from_slice(&out).unwrap(); + let code = String::from_utf8(resp.files[0].contents.clone()).unwrap(); + assert_codegen_snapshot("borrowed_batchexec", &code); + assert!( + code.contains("IntoIterator"), + "expected Item = &'a str using the stream's lifetime in:\n{code}" + ); +} + +#[test] +fn snapshot_borrowed_batchexec_struct() { + let query = Query { + name: "BatchUpdateAuthor".to_string(), + cmd: ":batchexec".to_string(), + text: "UPDATE authors SET name = $1, bio = $2 WHERE id = $3".to_string(), + params: vec![ + Parameter { + number: 1, + column: Some(make_author_column("name", "text", true)).into(), + ..Default::default() + }, + Parameter { + number: 2, + column: Some(make_author_column("bio", "text", false)).into(), + ..Default::default() + }, + Parameter { + number: 3, + column: Some(make_author_column("id", "int8", true)).into(), + ..Default::default() + }, + ], + ..Default::default() + }; + let bytes = make_request_with_options(vec![query], BORROWED_TEXT_OPTIONS); + let out = run_with_bytes(&bytes).expect("generate failed"); + let resp = sqlc_gen_sqlx::plugin::GenerateResponse::decode_from_slice(&out).unwrap(); + let code = String::from_utf8(resp.files[0].contents.clone()).unwrap(); + assert_codegen_snapshot("borrowed_batchexec_struct", &code); + assert!( + code.contains("pub struct BatchUpdateAuthorParams<'a>"), + "expected params struct with <'a> in:\n{code}" + ); + assert!( + code.contains("IntoIterator>"), + "expected Item = BatchUpdateAuthorParams<'a> in:\n{code}" + ); +} + +#[test] +fn snapshot_borrowed_copyfrom() { + let query = Query { + name: "CopyAuthors".to_string(), + cmd: ":copyfrom".to_string(), + text: "INSERT INTO authors (name, bio) VALUES ($1, $2)".to_string(), + insert_into_table: Some(Identifier { + name: "authors".to_string(), + ..Default::default() + }) + .into(), + params: vec![ + Parameter { + number: 1, + column: Some(make_author_column("name", "text", true)).into(), + ..Default::default() + }, + Parameter { + number: 2, + column: Some(make_author_column("bio", "text", false)).into(), + ..Default::default() + }, + ], + ..Default::default() + }; + let bytes = make_request_with_options(vec![query], BORROWED_TEXT_OPTIONS); + let out = run_with_bytes(&bytes).expect("generate failed"); + let resp = sqlc_gen_sqlx::plugin::GenerateResponse::decode_from_slice(&out).unwrap(); + let code = String::from_utf8(resp.files[0].contents.clone()).unwrap(); + assert_codegen_snapshot("borrowed_copyfrom", &code); + assert!( + code.contains("pub async fn copy_authors<'a, E: AsExecutor, I>"), + "expected 'a in fn signature in:\n{code}" + ); + assert!( + code.contains("I: IntoIterator>"), + "expected Item with <'a> in where clause in:\n{code}" + ); +} + +#[test] +fn snapshot_borrowed_with_custom_owned() { + let options = + br#"{"overrides":[{"db_type":"text","rs_type":"MyStr","borrowed_rs_type":"&MyStr"}]}"#; + let query = Query { + name: "GetByName".to_string(), + cmd: ":one".to_string(), + text: "SELECT id, name, bio FROM authors WHERE name = $1".to_string(), + columns: vec![ + make_author_column("id", "int8", true), + make_author_column("name", "text", true), + make_author_column("bio", "text", false), + ], + params: vec![Parameter { + number: 1, + column: Some(make_author_column("name", "text", true)).into(), + ..Default::default() + }], + ..Default::default() + }; + let bytes = make_request_with_options(vec![query], options); + let out = run_with_bytes(&bytes).expect("generate failed"); + let resp = sqlc_gen_sqlx::plugin::GenerateResponse::decode_from_slice(&out).unwrap(); + let code = String::from_utf8(resp.files[0].contents.clone()).unwrap(); + assert_codegen_snapshot("borrowed_with_custom_owned", &code); + assert!( + code.contains("name: &MyStr"), + "param uses borrowed override in:\n{code}" + ); + assert!( + code.contains("pub name: MyStr"), + "row uses custom owned rs_type in:\n{code}" + ); +} diff --git a/tests/snapshots/codegen__borrowed_array_param.snap b/tests/snapshots/codegen__borrowed_array_param.snap new file mode 100644 index 0000000..aeea079 --- /dev/null +++ b/tests/snapshots/codegen__borrowed_array_param.snap @@ -0,0 +1,53 @@ +--- +source: tests/codegen.rs +expression: code +--- +// Code generated by sqlc-gen-sqlx v[VERSION]. DO NOT EDIT. +// sqlc version: [VERSION] + +#![allow( + dead_code, + reason = "generated queries may expose items a caller does not use" +)] + +pub trait AsExecutor { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres>; +} +impl AsExecutor for sqlx::PgPool { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + &*self + } +} +impl AsExecutor for &sqlx::PgPool { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + *self + } +} +impl AsExecutor for sqlx::PgConnection { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + &mut *self + } +} +impl AsExecutor for sqlx::Transaction<'_, sqlx::Postgres> { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + &mut **self + } +} +impl AsExecutor for sqlx::pool::PoolConnection { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + &mut **self + } +} +impl AsExecutor for &mut T { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + (**self).as_executor() + } +} +const LIST_BY_TAGS: &str = "DELETE FROM authors WHERE name = ANY($1)"; +pub async fn list_by_tags( + mut db: E, + tags: &[String], +) -> Result<(), sqlx::Error> { + sqlx::query(LIST_BY_TAGS).bind(tags).execute(db.as_executor()).await?; + Ok(()) +} diff --git a/tests/snapshots/codegen__borrowed_batchexec.snap b/tests/snapshots/codegen__borrowed_batchexec.snap new file mode 100644 index 0000000..e674322 --- /dev/null +++ b/tests/snapshots/codegen__borrowed_batchexec.snap @@ -0,0 +1,69 @@ +--- +source: tests/codegen.rs +expression: code +--- +// Code generated by sqlc-gen-sqlx v[VERSION]. DO NOT EDIT. +// sqlc version: [VERSION] + +#![allow( + dead_code, + reason = "generated queries may expose items a caller does not use" +)] + +pub trait AsExecutor { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres>; +} +impl AsExecutor for sqlx::PgPool { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + &*self + } +} +impl AsExecutor for &sqlx::PgPool { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + *self + } +} +impl AsExecutor for sqlx::PgConnection { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + &mut *self + } +} +impl AsExecutor for sqlx::Transaction<'_, sqlx::Postgres> { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + &mut **self + } +} +impl AsExecutor for sqlx::pool::PoolConnection { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + &mut **self + } +} +impl AsExecutor for &mut T { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + (**self).as_executor() + } +} +const BATCH_DELETE_BY_NAME: &str = "DELETE FROM authors WHERE name = $1"; +pub fn batch_delete_by_name<'a, E, I>( + db: E, + items: I, +) -> impl futures_core::stream::Stream> + 'a +where + E: AsExecutor + 'a, + I: IntoIterator + 'a, + I::IntoIter: 'a, +{ + futures_util::stream::try_unfold( + (db, items.into_iter()), + |(mut db, mut items)| async move { + let Some(item) = items.next() else { + return Ok(None); + }; + sqlx::query(BATCH_DELETE_BY_NAME) + .bind(item) + .execute(db.as_executor()) + .await?; + Ok(Some(((), (db, items)))) + }, + ) +} diff --git a/tests/snapshots/codegen__borrowed_batchexec_struct.snap b/tests/snapshots/codegen__borrowed_batchexec_struct.snap new file mode 100644 index 0000000..bcf4a0f --- /dev/null +++ b/tests/snapshots/codegen__borrowed_batchexec_struct.snap @@ -0,0 +1,77 @@ +--- +source: tests/codegen.rs +expression: code +--- +// Code generated by sqlc-gen-sqlx v[VERSION]. DO NOT EDIT. +// sqlc version: [VERSION] + +#![allow( + dead_code, + reason = "generated queries may expose items a caller does not use" +)] + +pub trait AsExecutor { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres>; +} +impl AsExecutor for sqlx::PgPool { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + &*self + } +} +impl AsExecutor for &sqlx::PgPool { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + *self + } +} +impl AsExecutor for sqlx::PgConnection { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + &mut *self + } +} +impl AsExecutor for sqlx::Transaction<'_, sqlx::Postgres> { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + &mut **self + } +} +impl AsExecutor for sqlx::pool::PoolConnection { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + &mut **self + } +} +impl AsExecutor for &mut T { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + (**self).as_executor() + } +} +#[derive(Debug, Clone)] +pub struct BatchUpdateAuthorParams<'a> { + pub name: &'a str, + pub bio: Option<&'a str>, + pub id: i64, +} +const BATCH_UPDATE_AUTHOR: &str = "UPDATE authors SET name = $1, bio = $2 WHERE id = $3"; +pub fn batch_update_author<'a, E, I>( + db: E, + items: I, +) -> impl futures_core::stream::Stream> + 'a +where + E: AsExecutor + 'a, + I: IntoIterator> + 'a, + I::IntoIter: 'a, +{ + futures_util::stream::try_unfold( + (db, items.into_iter()), + |(mut db, mut items)| async move { + let Some(item) = items.next() else { + return Ok(None); + }; + sqlx::query(BATCH_UPDATE_AUTHOR) + .bind(item.name) + .bind(item.bio) + .bind(item.id) + .execute(db.as_executor()) + .await?; + Ok(Some(((), (db, items)))) + }, + ) +} diff --git a/tests/snapshots/codegen__borrowed_copyfrom.snap b/tests/snapshots/codegen__borrowed_copyfrom.snap new file mode 100644 index 0000000..2cb73bc --- /dev/null +++ b/tests/snapshots/codegen__borrowed_copyfrom.snap @@ -0,0 +1,80 @@ +--- +source: tests/codegen.rs +expression: code +--- +// Code generated by sqlc-gen-sqlx v[VERSION]. DO NOT EDIT. +// sqlc version: [VERSION] + +#![allow( + dead_code, + reason = "generated queries may expose items a caller does not use" +)] + +pub trait AsExecutor { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres>; +} +impl AsExecutor for sqlx::PgPool { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + &*self + } +} +impl AsExecutor for &sqlx::PgPool { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + *self + } +} +impl AsExecutor for sqlx::PgConnection { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + &mut *self + } +} +impl AsExecutor for sqlx::Transaction<'_, sqlx::Postgres> { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + &mut **self + } +} +impl AsExecutor for sqlx::pool::PoolConnection { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + &mut **self + } +} +impl AsExecutor for &mut T { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + (**self).as_executor() + } +} +#[derive(Debug, Clone)] +pub struct CopyAuthorsParams<'a> { + pub name: &'a str, + pub bio: Option<&'a str>, +} +const COPY_AUTHORS: &str = "INSERT INTO authors (name, bio) "; +const COPY_AUTHORS_BATCH_SIZE: usize = 32767usize; +pub async fn copy_authors<'a, E: AsExecutor, I>( + mut db: E, + items: I, +) -> Result +where + I: IntoIterator>, +{ + let mut rows_affected = 0u64; + let mut items = items.into_iter(); + loop { + let chunk = items.by_ref().take(COPY_AUTHORS_BATCH_SIZE).collect::>(); + if chunk.is_empty() { + break; + } + let mut query_builder = sqlx::QueryBuilder::::new(COPY_AUTHORS); + query_builder + .push_values( + chunk, + |mut b, item| { + b.push_bind(item.name); + b.push_bind(item.bio); + }, + ); + rows_affected + += query_builder.build().execute(db.as_executor()).await?.rows_affected(); + } + Ok(rows_affected) +} diff --git a/tests/snapshots/codegen__borrowed_one_row_unaffected.snap b/tests/snapshots/codegen__borrowed_one_row_unaffected.snap new file mode 100644 index 0000000..ea3741f --- /dev/null +++ b/tests/snapshots/codegen__borrowed_one_row_unaffected.snap @@ -0,0 +1,61 @@ +--- +source: tests/codegen.rs +expression: code +--- +// Code generated by sqlc-gen-sqlx v[VERSION]. DO NOT EDIT. +// sqlc version: [VERSION] + +#![allow( + dead_code, + reason = "generated queries may expose items a caller does not use" +)] + +pub trait AsExecutor { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres>; +} +impl AsExecutor for sqlx::PgPool { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + &*self + } +} +impl AsExecutor for &sqlx::PgPool { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + *self + } +} +impl AsExecutor for sqlx::PgConnection { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + &mut *self + } +} +impl AsExecutor for sqlx::Transaction<'_, sqlx::Postgres> { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + &mut **self + } +} +impl AsExecutor for sqlx::pool::PoolConnection { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + &mut **self + } +} +impl AsExecutor for &mut T { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + (**self).as_executor() + } +} +const GET_AUTHOR_BY_NAME: &str = "SELECT id, name, bio FROM authors WHERE name = $1"; +#[derive(Debug, Clone, sqlx::FromRow)] +pub struct GetAuthorByNameRow { + pub id: i64, + pub name: String, + pub bio: Option, +} +pub async fn get_author_by_name( + mut db: E, + name: &str, +) -> Result { + sqlx::query_as::<_, GetAuthorByNameRow>(GET_AUTHOR_BY_NAME) + .bind(name) + .fetch_one(db.as_executor()) + .await +} diff --git a/tests/snapshots/codegen__borrowed_params_struct.snap b/tests/snapshots/codegen__borrowed_params_struct.snap new file mode 100644 index 0000000..572f403 --- /dev/null +++ b/tests/snapshots/codegen__borrowed_params_struct.snap @@ -0,0 +1,62 @@ +--- +source: tests/codegen.rs +expression: code +--- +// Code generated by sqlc-gen-sqlx v[VERSION]. DO NOT EDIT. +// sqlc version: [VERSION] + +#![allow( + dead_code, + reason = "generated queries may expose items a caller does not use" +)] + +pub trait AsExecutor { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres>; +} +impl AsExecutor for sqlx::PgPool { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + &*self + } +} +impl AsExecutor for &sqlx::PgPool { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + *self + } +} +impl AsExecutor for sqlx::PgConnection { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + &mut *self + } +} +impl AsExecutor for sqlx::Transaction<'_, sqlx::Postgres> { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + &mut **self + } +} +impl AsExecutor for sqlx::pool::PoolConnection { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + &mut **self + } +} +impl AsExecutor for &mut T { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + (**self).as_executor() + } +} +#[derive(Debug, Clone)] +pub struct UpdateAuthorParams<'a> { + pub name: &'a str, + pub bio: Option<&'a str>, +} +const UPDATE_AUTHOR: &str = "UPDATE authors SET name = $1, bio = $2"; +pub async fn update_author( + mut db: E, + arg: UpdateAuthorParams<'_>, +) -> Result<(), sqlx::Error> { + sqlx::query(UPDATE_AUTHOR) + .bind(arg.name) + .bind(arg.bio) + .execute(db.as_executor()) + .await?; + Ok(()) +} diff --git a/tests/snapshots/codegen__borrowed_scalar.snap b/tests/snapshots/codegen__borrowed_scalar.snap new file mode 100644 index 0000000..bcb9cc6 --- /dev/null +++ b/tests/snapshots/codegen__borrowed_scalar.snap @@ -0,0 +1,53 @@ +--- +source: tests/codegen.rs +expression: code +--- +// Code generated by sqlc-gen-sqlx v[VERSION]. DO NOT EDIT. +// sqlc version: [VERSION] + +#![allow( + dead_code, + reason = "generated queries may expose items a caller does not use" +)] + +pub trait AsExecutor { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres>; +} +impl AsExecutor for sqlx::PgPool { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + &*self + } +} +impl AsExecutor for &sqlx::PgPool { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + *self + } +} +impl AsExecutor for sqlx::PgConnection { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + &mut *self + } +} +impl AsExecutor for sqlx::Transaction<'_, sqlx::Postgres> { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + &mut **self + } +} +impl AsExecutor for sqlx::pool::PoolConnection { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + &mut **self + } +} +impl AsExecutor for &mut T { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + (**self).as_executor() + } +} +const DELETE_BY_NAME: &str = "DELETE FROM authors WHERE name = $1"; +pub async fn delete_by_name( + mut db: E, + name: &str, +) -> Result<(), sqlx::Error> { + sqlx::query(DELETE_BY_NAME).bind(name).execute(db.as_executor()).await?; + Ok(()) +} diff --git a/tests/snapshots/codegen__borrowed_slice_param.snap b/tests/snapshots/codegen__borrowed_slice_param.snap new file mode 100644 index 0000000..9fef906 --- /dev/null +++ b/tests/snapshots/codegen__borrowed_slice_param.snap @@ -0,0 +1,53 @@ +--- +source: tests/codegen.rs +expression: code +--- +// Code generated by sqlc-gen-sqlx v[VERSION]. DO NOT EDIT. +// sqlc version: [VERSION] + +#![allow( + dead_code, + reason = "generated queries may expose items a caller does not use" +)] + +pub trait AsExecutor { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres>; +} +impl AsExecutor for sqlx::PgPool { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + &*self + } +} +impl AsExecutor for &sqlx::PgPool { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + *self + } +} +impl AsExecutor for sqlx::PgConnection { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + &mut *self + } +} +impl AsExecutor for sqlx::Transaction<'_, sqlx::Postgres> { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + &mut **self + } +} +impl AsExecutor for sqlx::pool::PoolConnection { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + &mut **self + } +} +impl AsExecutor for &mut T { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + (**self).as_executor() + } +} +const DELETE_BY_NAMES: &str = "DELETE FROM authors WHERE name = ANY($1)"; +pub async fn delete_by_names( + mut db: E, + names: &[String], +) -> Result<(), sqlx::Error> { + sqlx::query(DELETE_BY_NAMES).bind(names).execute(db.as_executor()).await?; + Ok(()) +} diff --git a/tests/snapshots/codegen__borrowed_with_custom_owned.snap b/tests/snapshots/codegen__borrowed_with_custom_owned.snap new file mode 100644 index 0000000..2615fec --- /dev/null +++ b/tests/snapshots/codegen__borrowed_with_custom_owned.snap @@ -0,0 +1,61 @@ +--- +source: tests/codegen.rs +expression: code +--- +// Code generated by sqlc-gen-sqlx v[VERSION]. DO NOT EDIT. +// sqlc version: [VERSION] + +#![allow( + dead_code, + reason = "generated queries may expose items a caller does not use" +)] + +pub trait AsExecutor { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres>; +} +impl AsExecutor for sqlx::PgPool { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + &*self + } +} +impl AsExecutor for &sqlx::PgPool { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + *self + } +} +impl AsExecutor for sqlx::PgConnection { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + &mut *self + } +} +impl AsExecutor for sqlx::Transaction<'_, sqlx::Postgres> { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + &mut **self + } +} +impl AsExecutor for sqlx::pool::PoolConnection { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + &mut **self + } +} +impl AsExecutor for &mut T { + fn as_executor(&mut self) -> impl sqlx::Executor<'_, Database = sqlx::Postgres> { + (**self).as_executor() + } +} +const GET_BY_NAME: &str = "SELECT id, name, bio FROM authors WHERE name = $1"; +#[derive(Debug, Clone, sqlx::FromRow)] +pub struct GetByNameRow { + pub id: i64, + pub name: MyStr, + pub bio: Option, +} +pub async fn get_by_name( + mut db: E, + name: &MyStr, +) -> Result { + sqlx::query_as::<_, GetByNameRow>(GET_BY_NAME) + .bind(name) + .fetch_one(db.as_executor()) + .await +}