From 88f926fa7c4e631e708ea3acd8b08bc2e24425e7 Mon Sep 17 00:00:00 2001 From: Sean Griffin Date: Mon, 1 Jun 2026 09:54:10 -0600 Subject: [PATCH 1/5] Cache queries on whether they are direct-to-shard + SQL This change allows the query rewriter to safely modify its behavior for direct-to-shard queries, and skip rewrite steps that are only needed for cross-shard queries. Though this commit does not take advantage of this, future changes will require this behavior, and changing the cache key required enough code to justify being pulled into its own commit. The signature of Map::get and Borrow means this needs a bit more boilerplate than I would have liked. There is no way to implement Borrow<(str, _)> for (String, _) (or any other owned/ref types). Instead we have to erase the types to a trait object so we can get individual references to the fields when we need them for comparison and hashing. This trait is implemented for any combination of values/refs, and a special case for (Arc<A>, B). If in the future this trait is needed with additional pointer types, or Arc in the right position, additional manual impls will need to be added. 
--- .../router/parser/cache/cache_impl.rs | 39 ++++--- .../frontend/router/parser/cache/key_pair.rs | 106 ++++++++++++++++++ pgdog/src/frontend/router/parser/cache/mod.rs | 1 + pgdog/src/frontend/router/parser/mod.rs | 1 + pgdog/src/frontend/router/parser/route.rs | 10 ++ 5 files changed, 144 insertions(+), 13 deletions(-) create mode 100644 pgdog/src/frontend/router/parser/cache/key_pair.rs diff --git a/pgdog/src/frontend/router/parser/cache/cache_impl.rs b/pgdog/src/frontend/router/parser/cache/cache_impl.rs index 5d66d3430..b49580e56 100644 --- a/pgdog/src/frontend/router/parser/cache/cache_impl.rs +++ b/pgdog/src/frontend/router/parser/cache/cache_impl.rs @@ -10,8 +10,8 @@ use parking_lot::{Mutex, RawMutex}; use std::sync::Arc; use tracing::debug; -use super::super::{Error, Route}; -use super::{super::parse_edge_comment, Ast, AstContext, AstQuery}; +use super::super::{Error, Route, ShardOptionExt}; +use super::{super::parse_edge_comment, key_pair::KeyPair, Ast, AstContext, AstQuery}; use crate::frontend::{BufferedQuery, PreparedStatements}; static CACHE: Lazy = Lazy::new(Cache::new); @@ -47,7 +47,7 @@ impl Stats { #[derive(Debug)] pub(super) struct Inner { /// Least-recently-used cache. - queries: LruCache, Ast>, + queries: LruCache<(Arc, bool), Ast>, /// Cache global stats. pub(super) stats: Stats, } @@ -119,10 +119,16 @@ impl Cache { { let mut guard = self.inner.lock(); - let ast = guard.queries.get_mut(query_and_comment.query).map(|entry| { - entry.stats.lock().hits += 1; // No contention on this. - entry.clone() - }); + let ast = guard + .queries + .get_mut( + &(query_and_comment.query, query_and_comment.shard.is_direct()) + as &dyn KeyPair, + ) + .map(|entry| { + entry.stats.lock().hits += 1; // No contention on this. + entry.clone() + }); if let Some(mut ast) = ast { guard.stats.hits += 1; ast.comment_role = query_and_comment.role; @@ -153,9 +159,13 @@ impl Cache { // (direct-shard) variant. 
let cacheable = entry.comment_shard.is_none() || entry.rewrite_plan.is_empty(); if cacheable { - guard - .queries - .put(entry.query_without_comment.clone(), entry.clone()); + guard.queries.put( + ( + entry.query_without_comment.clone(), + entry.comment_shard.is_direct(), + ), + entry.clone(), + ); } guard.stats.misses += 1; guard.stats.parse_time += parse_time; @@ -208,7 +218,10 @@ impl Cache { { let mut guard = self.inner.lock(); - if let Some(entry) = guard.queries.get(normalized.as_str()) { + if let Some(entry) = guard + .queries + .get(&(normalized.as_str(), false) as &dyn KeyPair) + { entry.update_stats(route); guard.stats.hits += 1; return Ok(()); @@ -219,7 +232,7 @@ impl Cache { entry.update_stats(route); let mut guard = self.inner.lock(); - guard.queries.put(normalized.into(), entry); + guard.queries.put((normalized.into(), false), entry); guard.stats.misses += 1; Ok(()) @@ -259,7 +272,7 @@ impl Cache { .lock() .queries .iter() - .map(|i| (i.0.clone(), i.1.clone())) + .map(|i| (i.0 .0.clone(), i.1.clone())) .collect() } diff --git a/pgdog/src/frontend/router/parser/cache/key_pair.rs b/pgdog/src/frontend/router/parser/cache/key_pair.rs new file mode 100644 index 000000000..6278375f9 --- /dev/null +++ b/pgdog/src/frontend/router/parser/cache/key_pair.rs @@ -0,0 +1,106 @@ +//! A trait object used to index into a map-ish container using a tuple key. +//! This trait is needed when we want to use borrowed types to look up owned +//! data. There is no easy way to have a type which is semantically (String, +//! String) implement Borrow<(&str, &str)>. The signature of Borrow would +//! require us to manifest a tuple reference out of thin air. +//! +//! Instead we use a trait object which lets us erase the type entirely, +//! creating a reference to each element of the tuple when needed for comparison +//! and hashing. 
+use std::borrow::Borrow; +use std::hash::{Hash, Hasher}; +use std::sync::Arc; + +pub(crate) trait KeyPair { + fn left(&self) -> &A; + fn right(&self) -> &B; +} + +impl KeyPair for (A, B) { + fn left(&self) -> &A { + &self.0 + } + + fn right(&self) -> &B { + &self.1 + } +} + +impl KeyPair for (&A, &B) { + fn left(&self) -> &A { + self.0 + } + + fn right(&self) -> &B { + self.1 + } +} + +impl KeyPair for (&A, B) { + fn left(&self) -> &A { + self.0 + } + + fn right(&self) -> &B { + &self.1 + } +} + +impl KeyPair for (A, &B) { + fn left(&self) -> &A { + &self.0 + } + + fn right(&self) -> &B { + self.1 + } +} + +impl KeyPair for (Arc, B) { + fn left(&self) -> &A { + &self.0 + } + + fn right(&self) -> &B { + &self.1 + } +} + +impl<'a, A, B, X, Y> Borrow + 'a> for (X, Y) +where + A: ?Sized, + B: ?Sized, + (X, Y): KeyPair + 'a, +{ + fn borrow(&self) -> &(dyn KeyPair + 'a) { + self + } +} + +impl<'a, A, B> PartialEq for dyn KeyPair + 'a +where + A: PartialEq + ?Sized + 'a, + B: PartialEq + ?Sized + 'a, +{ + fn eq(&self, other: &Self) -> bool { + self.left() == other.left() && self.right() == other.right() + } +} + +impl<'a, A, B> Eq for dyn KeyPair + 'a +where + A: Eq + ?Sized + 'a, + B: Eq + ?Sized + 'a, +{ +} + +impl<'a, A, B> Hash for dyn KeyPair + 'a +where + A: ?Sized + 'a, + B: ?Sized + 'a, + for<'b, 'c> (&'b A, &'c B): Hash, +{ + fn hash(&self, state: &mut H) { + (self.left(), self.right()).hash(state) + } +} diff --git a/pgdog/src/frontend/router/parser/cache/mod.rs b/pgdog/src/frontend/router/parser/cache/mod.rs index 0a3437813..5eee7c1bc 100644 --- a/pgdog/src/frontend/router/parser/cache/mod.rs +++ b/pgdog/src/frontend/router/parser/cache/mod.rs @@ -6,6 +6,7 @@ pub mod ast; pub mod cache_impl; pub mod context; pub mod fingerprint; +pub mod key_pair; pub use ast::*; pub use cache_impl::*; diff --git a/pgdog/src/frontend/router/parser/mod.rs b/pgdog/src/frontend/router/parser/mod.rs index ccd86044a..f8973b5f9 100644 --- a/pgdog/src/frontend/router/parser/mod.rs 
+++ b/pgdog/src/frontend/router/parser/mod.rs @@ -55,6 +55,7 @@ pub use query::QueryParser; pub use rewrite::{ Assignment, AssignmentValue, ShardKeyRewritePlan, StatementRewrite, StatementRewriteContext, }; +pub(crate) use route::ShardOptionExt; pub use route::{Route, Shard, ShardWithPriority, ShardsWithPriority}; pub use schema::Schema; pub use sequence::{OwnedSequence, Sequence}; diff --git a/pgdog/src/frontend/router/parser/route.rs b/pgdog/src/frontend/router/parser/route.rs index 06ad11d35..b63a1866e 100644 --- a/pgdog/src/frontend/router/parser/route.rs +++ b/pgdog/src/frontend/router/parser/route.rs @@ -50,6 +50,16 @@ impl Shard { } } +pub(crate) trait ShardOptionExt { + fn is_direct(&self) -> bool; +} + +impl ShardOptionExt for Option { + fn is_direct(&self) -> bool { + self.as_ref().map(Shard::is_direct).unwrap_or(false) + } +} + impl From> for Shard { fn from(value: Option) -> Self { if let Some(value) = value { From 272ad0a028d535f0d193b7a6c45dbc6ee31a3575 Mon Sep 17 00:00:00 2001 From: Sean Griffin Date: Thu, 28 May 2026 12:28:33 -0600 Subject: [PATCH 2/5] Error when encountering an unsupported aggregate function When handling aggregate functions in cross-shard queries, we can only return a correct result for functions that we have explicit support for. No matter what the aggregate function is doing, we need to know how to combine the results from the separate shards. Prior to this change, we would just silently do the wrong thing, treating this as any other non-aggregate expression and just returning a union of the rows from each query. We now explicitly check if an aggregate function with that name exists and error if we don't recognize it. This is future proofed against new functions in later postgres versions, as well as user defined aggregates (which we can likely never support). This will produce a false positive if a function exists and is defined as aggregate for some argument types but not others. 
But frankly anyone doing that is asking for trouble. This logic is objectively in the wrong place. What we are doing here is validation, not parsing. However as I'm familiarizing myself with these code paths and preparing a larger rearchitecture, I'm slowly hammering things into a shape that's easier to move around. I do not intend on leaving this logic here long term. --- Cargo.lock | 1 + pgdog-stats/src/schema.rs | 1 + pgdog/Cargo.toml | 1 + .../src/backend/pool/connection/aggregate.rs | 2 +- pgdog/src/backend/schema/mod.rs | 37 ++++++++++++++++ pgdog/src/frontend/router/parser/aggregate.rs | 42 ++++++++++++++++--- .../frontend/router/parser/query/select.rs | 2 +- .../rewrite/statement/aggregate/engine.rs | 8 ++-- .../parser/rewrite/statement/aggregate/mod.rs | 9 +++- .../router/parser/rewrite/statement/error.rs | 3 ++ .../router/parser/rewrite/statement/mod.rs | 2 +- 11 files changed, 94 insertions(+), 14 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6f1855f5b..4c3271544 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3243,6 +3243,7 @@ dependencies = [ "hyper", "hyper-util", "indexmap", + "itertools 0.14.0", "lazy_static", "libc", "lru", diff --git a/pgdog-stats/src/schema.rs b/pgdog-stats/src/schema.rs index 1ceb88278..cf62fddc1 100644 --- a/pgdog-stats/src/schema.rs +++ b/pgdog-stats/src/schema.rs @@ -134,6 +134,7 @@ impl Relation { pub struct SchemaInner { pub search_path: Vec, pub relations: Relations, + pub aggregate_functions: Vec, } impl Hash for SchemaInner { diff --git a/pgdog/Cargo.toml b/pgdog/Cargo.toml index f2c4c875c..b5d6b3ef1 100644 --- a/pgdog/Cargo.toml +++ b/pgdog/Cargo.toml @@ -80,6 +80,7 @@ bit-vec = "0.8" reqwest = { version = "0.12", default-features = false, features = ["rustls-tls-webpki-roots-no-provider"] } hex = "0.4" x509-parser = "0.18" +itertools = "0.14.0" [target.'cfg(unix)'.dependencies] libc = "0.2" diff --git a/pgdog/src/backend/pool/connection/aggregate.rs b/pgdog/src/backend/pool/connection/aggregate.rs index 
e7c5ef814..03e830055 100644 --- a/pgdog/src/backend/pool/connection/aggregate.rs +++ b/pgdog/src/backend/pool/connection/aggregate.rs @@ -803,7 +803,7 @@ mod test { } fn parse(stmt: &str) -> Aggregate { - Aggregate::parse(&select(stmt)) + Aggregate::parse(&select(stmt), &Default::default()).unwrap() } #[test] diff --git a/pgdog/src/backend/schema/mod.rs b/pgdog/src/backend/schema/mod.rs index b836fcdf8..6a120bb7d 100644 --- a/pgdog/src/backend/schema/mod.rs +++ b/pgdog/src/backend/schema/mod.rs @@ -60,9 +60,16 @@ impl Schema { .map(|p| p.trim().replace("\"", "")) .collect(); + let aggregate_functions = server + .fetch_all::( + "SELECT DISTINCT proname FROM pg_proc INNER JOIN pg_aggregate ON oid = aggfnoid", + ) + .await?; + let inner = SchemaInner { search_path, relations, + aggregate_functions, }; Ok(Self { @@ -79,6 +86,15 @@ impl Schema { pub(crate) fn from_parts( search_path: Vec, relations: HashMap<(String, String), Relation>, + ) -> Self { + Self::from_parts_with_agg(search_path, relations, Vec::new()) + } + + #[cfg(test)] + pub(crate) fn from_parts_with_agg( + search_path: Vec, + relations: HashMap<(String, String), Relation>, + aggregate_functions: Vec, ) -> Self { let mut nested: StatsRelations = HashMap::new(); for ((schema, name), relation) in relations { @@ -91,6 +107,7 @@ impl Schema { inner: StatsSchema::new(SchemaInner { search_path, relations: nested, + aggregate_functions, }), } } @@ -555,4 +572,24 @@ mod test { ); } } + + #[tokio::test] + async fn test_loading_aggregate_functions() { + let mut server = test_server().await; + server.execute_checked("BEGIN").await.unwrap(); + + let schema = Schema::load(&mut server).await.unwrap(); + assert!(!schema + .aggregate_functions + .contains(&String::from("pgdog_sum"))); + + server + .execute_checked("CREATE AGGREGATE pgdog_sum (int4) (sfunc = int4_sum, stype = bigint)") + .await + .unwrap(); + let schema = Schema::load(&mut server).await.unwrap(); + assert!(schema + .aggregate_functions + 
.contains(&String::from("pgdog_sum"))); + } } diff --git a/pgdog/src/frontend/router/parser/aggregate.rs b/pgdog/src/frontend/router/parser/aggregate.rs index fd88a4410..feb736930 100644 --- a/pgdog/src/frontend/router/parser/aggregate.rs +++ b/pgdog/src/frontend/router/parser/aggregate.rs @@ -1,8 +1,10 @@ +use itertools::Itertools; use pg_query::protobuf::Integer; use pg_query::protobuf::{a_const::Val, Node, SelectStmt, String as PgQueryString}; use pg_query::NodeEnum; -use crate::frontend::router::parser::{ExpressionRegistry, Function}; +use super::{rewrite::statement::Error, ExpressionRegistry, Function}; +use crate::backend::schema::Schema; #[derive(Debug, Clone, PartialEq)] pub struct AggregateTarget { @@ -121,7 +123,7 @@ fn columns_match(group_by_names: &[&String], select_names: &[&String]) -> bool { impl Aggregate { /// Figure out what aggregates are present and which ones PgDog supports. - pub fn parse(stmt: &SelectStmt) -> Self { + pub fn parse(stmt: &SelectStmt, schema: &Schema) -> Result { let mut targets = vec![]; let mut registry = ExpressionRegistry::new(); let group_by = stmt @@ -168,7 +170,18 @@ impl Aggregate { "stddev_pop" => Some(AggregateFunction::StddevPop), "variance" | "var_samp" => Some(AggregateFunction::VarSamp), "var_pop" => Some(AggregateFunction::VarPop), - _ => None, + fname => { + if schema + .aggregate_functions + .iter() + .map(|s| s.as_ref()) + .contains(fname) + { + return Err(Error::UnsupportedAggregate(String::from(fname))); + } else { + None + } + } }; if let Some(function) = function { @@ -196,7 +209,7 @@ impl Aggregate { } } - Self { targets, group_by } + Ok(Self { targets, group_by }) } pub fn targets(&self) -> &[AggregateTarget] { @@ -259,7 +272,7 @@ mod test { } fn parse(stmt: &str) -> Aggregate { - Aggregate::parse(&select(stmt)) + Aggregate::parse(&select(stmt), &Default::default()).unwrap() } #[test] @@ -455,4 +468,23 @@ mod test { assert_eq!(aggr.group_by(), &[0]); assert_eq!(aggr.targets().len(), 1); } + + #[test] + 
fn test_unrecognized_aggregate_function_errors() { + let schema_with_agg = Schema::from_parts_with_agg( + Vec::new(), + Default::default(), + vec![String::from("mysum")], + ); + let schema_without_agg = Default::default(); + let query = select("SELECT mysum(lol) FROM example"); + + // A random function that isn't listed as aggregate in the schema + // doesn't require special support on our end, so we should be fine. + assert!(Aggregate::parse(&query, &schema_without_agg).is_ok()); + // If we see an aggregate function we don't recognize, we can't + // process the query correctly, since we need to combine the + // results from each shard. + assert!(Aggregate::parse(&query, &schema_with_agg).is_err()); + } } diff --git a/pgdog/src/frontend/router/parser/query/select.rs b/pgdog/src/frontend/router/parser/query/select.rs index bcfc38370..583ff0cd8 100644 --- a/pgdog/src/frontend/router/parser/query/select.rs +++ b/pgdog/src/frontend/router/parser/query/select.rs @@ -138,7 +138,7 @@ impl QueryParser { } let shard = Self::converge(&shards, ConvergeAlgorithm::default()); - let aggregates = Aggregate::parse(stmt); + let aggregates = Aggregate::parse(stmt, &context.router_context.schema)?; let limit = LimitClause::new(stmt, context.router_context.bind).limit_offset()?; let distinct = Distinct::new(stmt).distinct()?; diff --git a/pgdog/src/frontend/router/parser/rewrite/statement/aggregate/engine.rs b/pgdog/src/frontend/router/parser/rewrite/statement/aggregate/engine.rs index c7af4978e..631400860 100644 --- a/pgdog/src/frontend/router/parser/rewrite/statement/aggregate/engine.rs +++ b/pgdog/src/frontend/router/parser/rewrite/statement/aggregate/engine.rs @@ -301,7 +301,7 @@ mod tests { fn rewrite(sql: &str) -> (ParseResult, RewriteOutput) { let mut parsed = pg_query::parse(sql).unwrap().protobuf; let stmt = select(&mut parsed); - let aggregate = Aggregate::parse(stmt); + let aggregate = Aggregate::parse(stmt, &Default::default()).unwrap(); let output = 
AggregatesRewrite.rewrite_select(stmt, &aggregate); (parsed, output) } @@ -326,7 +326,7 @@ mod tests { assert!(!helper.distinct); assert!(matches!(helper.kind, HelperKind::Count)); - let aggregate = Aggregate::parse(select(&mut ast)); + let aggregate = Aggregate::parse(select(&mut ast), &Default::default()).unwrap(); assert_eq!(aggregate.targets().len(), 2); assert!(aggregate .targets() @@ -353,7 +353,7 @@ mod tests { assert!(!helper.distinct); assert!(matches!(helper.kind, HelperKind::Count)); - let aggregate = Aggregate::parse(select(&mut ast)); + let aggregate = Aggregate::parse(select(&mut ast), &Default::default()).unwrap(); assert_eq!(aggregate.targets().len(), 3); assert!( aggregate @@ -381,7 +381,7 @@ mod tests { assert_eq!(helper_discount.helper_column, 3); assert!(matches!(helper_discount.kind, HelperKind::Count)); - let aggregate = Aggregate::parse(select(&mut ast)); + let aggregate = Aggregate::parse(select(&mut ast), &Default::default()).unwrap(); assert_eq!(aggregate.targets().len(), 4); assert_eq!( aggregate diff --git a/pgdog/src/frontend/router/parser/rewrite/statement/aggregate/mod.rs b/pgdog/src/frontend/router/parser/rewrite/statement/aggregate/mod.rs index 0dca2c6ee..84146a2ae 100644 --- a/pgdog/src/frontend/router/parser/rewrite/statement/aggregate/mod.rs +++ b/pgdog/src/frontend/router/parser/rewrite/statement/aggregate/mod.rs @@ -2,6 +2,7 @@ pub mod engine; pub mod plan; use super::{Error, RewritePlan, StatementRewrite}; +use crate::backend::schema::Schema; use crate::frontend::router::parser::aggregate::Aggregate; use pg_query::NodeEnum; @@ -10,7 +11,11 @@ pub use plan::{AggregateRewritePlan, HelperKind, HelperMapping, RewriteOutput}; impl StatementRewrite<'_> { /// Add missing COUNT(*) and other helps when using aggregates. 
- pub(super) fn rewrite_aggregates(&mut self, plan: &mut RewritePlan) -> Result<(), Error> { + pub(super) fn rewrite_aggregates( + &mut self, + plan: &mut RewritePlan, + schema: &Schema, + ) -> Result<(), Error> { if self.schema.shards == 1 { return Ok(()); } @@ -27,7 +32,7 @@ impl StatementRewrite<'_> { return Ok(()); }; - let aggregate = Aggregate::parse(&select); + let aggregate = Aggregate::parse(&select, schema)?; if aggregate.is_empty() { return Ok(()); } diff --git a/pgdog/src/frontend/router/parser/rewrite/statement/error.rs b/pgdog/src/frontend/router/parser/rewrite/statement/error.rs index 1e0858159..7ea21e229 100644 --- a/pgdog/src/frontend/router/parser/rewrite/statement/error.rs +++ b/pgdog/src/frontend/router/parser/rewrite/statement/error.rs @@ -37,4 +37,7 @@ pub enum Error { #[error("missing AST on request")] MissingAst, + + #[error("aggregate function {0} is not yet supported")] + UnsupportedAggregate(String), } diff --git a/pgdog/src/frontend/router/parser/rewrite/statement/mod.rs b/pgdog/src/frontend/router/parser/rewrite/statement/mod.rs index b03fab315..45eab440a 100644 --- a/pgdog/src/frontend/router/parser/rewrite/statement/mod.rs +++ b/pgdog/src/frontend/router/parser/rewrite/statement/mod.rs @@ -140,7 +140,7 @@ impl<'a> StatementRewrite<'a> { } })?; - self.rewrite_aggregates(&mut plan)?; + self.rewrite_aggregates(&mut plan, self.db_schema)?; self.limit_offset(&mut plan)?; if self.rewritten { From 888714a4e331068275d2f26228a93a3f541c0d43 Mon Sep 17 00:00:00 2001 From: Sean Griffin Date: Thu, 28 May 2026 13:23:14 -0600 Subject: [PATCH 3/5] Use HashSet instead of Vec --- Cargo.lock | 1 - pgdog-stats/src/schema.rs | 9 +++++++-- pgdog/Cargo.toml | 1 - pgdog/src/backend/schema/mod.rs | 6 ++++-- pgdog/src/frontend/router/parser/aggregate.rs | 8 +------- 5 files changed, 12 insertions(+), 13 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4c3271544..6f1855f5b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3243,7 +3243,6 @@ dependencies = [ 
"hyper", "hyper-util", "indexmap", - "itertools 0.14.0", "lazy_static", "libc", "lru", diff --git a/pgdog-stats/src/schema.rs b/pgdog-stats/src/schema.rs index cf62fddc1..ee1844919 100644 --- a/pgdog-stats/src/schema.rs +++ b/pgdog-stats/src/schema.rs @@ -1,6 +1,11 @@ use indexmap::IndexMap; use serde::{Deserialize, Serialize}; -use std::{collections::HashMap, hash::Hash, ops::Deref, sync::Arc}; +use std::{ + collections::{HashMap, HashSet}, + hash::Hash, + ops::Deref, + sync::Arc, +}; /// Schema name -> Table name -> Relation pub type Relations = HashMap>; @@ -134,7 +139,7 @@ impl Relation { pub struct SchemaInner { pub search_path: Vec, pub relations: Relations, - pub aggregate_functions: Vec, + pub aggregate_functions: HashSet, } impl Hash for SchemaInner { diff --git a/pgdog/Cargo.toml b/pgdog/Cargo.toml index b5d6b3ef1..f2c4c875c 100644 --- a/pgdog/Cargo.toml +++ b/pgdog/Cargo.toml @@ -80,7 +80,6 @@ bit-vec = "0.8" reqwest = { version = "0.12", default-features = false, features = ["rustls-tls-webpki-roots-no-provider"] } hex = "0.4" x509-parser = "0.18" -itertools = "0.14.0" [target.'cfg(unix)'.dependencies] libc = "0.2" diff --git a/pgdog/src/backend/schema/mod.rs b/pgdog/src/backend/schema/mod.rs index 6a120bb7d..cda7da25c 100644 --- a/pgdog/src/backend/schema/mod.rs +++ b/pgdog/src/backend/schema/mod.rs @@ -64,7 +64,9 @@ impl Schema { .fetch_all::( "SELECT DISTINCT proname FROM pg_proc INNER JOIN pg_aggregate ON oid = aggfnoid", ) - .await?; + .await? 
+ .into_iter() + .collect(); let inner = SchemaInner { search_path, @@ -107,7 +109,7 @@ impl Schema { inner: StatsSchema::new(SchemaInner { search_path, relations: nested, - aggregate_functions, + aggregate_functions: aggregate_functions.into_iter().collect(), }), } } diff --git a/pgdog/src/frontend/router/parser/aggregate.rs b/pgdog/src/frontend/router/parser/aggregate.rs index feb736930..cfa322340 100644 --- a/pgdog/src/frontend/router/parser/aggregate.rs +++ b/pgdog/src/frontend/router/parser/aggregate.rs @@ -1,4 +1,3 @@ -use itertools::Itertools; use pg_query::protobuf::Integer; use pg_query::protobuf::{a_const::Val, Node, SelectStmt, String as PgQueryString}; use pg_query::NodeEnum; @@ -171,12 +170,7 @@ impl Aggregate { "variance" | "var_samp" => Some(AggregateFunction::VarSamp), "var_pop" => Some(AggregateFunction::VarPop), fname => { - if schema - .aggregate_functions - .iter() - .map(|s| s.as_ref()) - .contains(fname) - { + if schema.aggregate_functions.contains(fname) { return Err(Error::UnsupportedAggregate(String::from(fname))); } else { None From 48759bb23fae9739b8bc1f120e86c3c4fc62722d Mon Sep 17 00:00:00 2001 From: Sean Griffin Date: Fri, 29 May 2026 10:53:43 -0600 Subject: [PATCH 4/5] Add integration test --- integration/rust/tests/integration/mod.rs | 1 + .../integration/unrecognized_aggregate.rs | 46 +++++++++++++++++++ 2 files changed, 47 insertions(+) create mode 100644 integration/rust/tests/integration/unrecognized_aggregate.rs diff --git a/integration/rust/tests/integration/mod.rs b/integration/rust/tests/integration/mod.rs index 51bc0cb72..0edbba462 100644 --- a/integration/rust/tests/integration/mod.rs +++ b/integration/rust/tests/integration/mod.rs @@ -35,3 +35,4 @@ pub mod timestamp_sorting; pub mod tls_enforced; pub mod tls_reload; pub mod transaction_state; +pub mod unrecognized_aggregate; diff --git a/integration/rust/tests/integration/unrecognized_aggregate.rs b/integration/rust/tests/integration/unrecognized_aggregate.rs new file mode 
100644 index 000000000..b8f1b5431 --- /dev/null +++ b/integration/rust/tests/integration/unrecognized_aggregate.rs @@ -0,0 +1,46 @@ +use rust::setup::{admin_sqlx, connections_sqlx}; +use sqlx::Executor; +use std::assert_matches; + +async fn define_custom_aggregate_fn() { + for connection in connections_sqlx().await { + connection + .execute("DROP AGGREGATE IF EXISTS pgdog_sum (int4)") + .await + .unwrap(); + connection + .execute("CREATE AGGREGATE pgdog_sum (int4) (sfunc = int4_sum, stype = bigint)") + .await + .unwrap(); + connection + .execute("DROP TABLE IF EXISTS unrecognized_agg_test") + .await + .unwrap(); + connection + .execute("CREATE TABLE unrecognized_agg_test (lol int4)") + .await + .unwrap(); + } + admin_sqlx().await.execute("RELOAD").await.unwrap(); +} + +#[tokio::test] +async fn unrecognized_aggregate_function_errors_only_on_cross_shard_queries() { + define_custom_aggregate_fn().await; + let mut connections = connections_sqlx().await; + let sharded = connections.pop().unwrap(); + let unsharded = connections.pop().unwrap(); + + let unsharded_query = unsharded + .fetch_one("SELECT pgdog_sum(lol) FROM unrecognized_agg_test") + .await; + assert_matches!(unsharded_query, Ok(_)); + + let sharded_query = sharded + .fetch_one("SELECT pgdog_sum(lol) FROM unrecognized_agg_test") + .await; + let err = sharded_query + .err() + .expect("unrecognized aggregate executed successfully"); + assert!(err.to_string().contains("pgdog_sum is not yet supported")); +} From ed257bc59436b636369fbb78c87d374ad171f0f1 Mon Sep 17 00:00:00 2001 From: Sean Griffin Date: Mon, 1 Jun 2026 10:57:27 -0600 Subject: [PATCH 5/5] Don't perform aggregate rewrites on direct-to-shard queries While we can't completely treat direct-to-shard queries the same as unsharded queries, we can at least skip certain parts of the rewriter. 
In particular, the aggregate rewrite does some extra work that will never be needed on direct-to-shard queries, and skipping it allows us to error on aggregate functions we don't support without breaking the very niche case of an unsupported aggregate function being used by a user only in direct-to-shard queries. Actually implementing this was a bit of a PITA. Since we cache the AST after the rewriter mutates it, as opposed to just the results of pg_query::parse, we needed to add whether the query is direct-to-shard to the cache key (#1027). The natural place to exit early is the same place we do the "skip this if the schema only has one shard" check, but that information was previously lost all the way up at the cache impl, so we needed to pass this information through quite a few levels of indirection. Ultimately I think we need to separate out parsing from rewriting more concretely, and structure rewriting much differently with a better-structured context. However, as I work towards being ready to do that larger restructuring, putting this where it's most natural to make the dependency graph clear seems like the right short-term path forward. 
--- .../rust/tests/integration/unrecognized_aggregate.rs | 5 +++++ pgdog/src/frontend/router/parser/cache/ast.rs | 1 + pgdog/src/frontend/router/parser/cache/cache_impl.rs | 2 ++ pgdog/src/frontend/router/parser/cache/context.rs | 4 ++++ .../router/parser/rewrite/statement/aggregate/mod.rs | 3 ++- pgdog/src/frontend/router/parser/rewrite/statement/mod.rs | 7 ++++++- 6 files changed, 20 insertions(+), 2 deletions(-) diff --git a/integration/rust/tests/integration/unrecognized_aggregate.rs b/integration/rust/tests/integration/unrecognized_aggregate.rs index b8f1b5431..feeb62599 100644 --- a/integration/rust/tests/integration/unrecognized_aggregate.rs +++ b/integration/rust/tests/integration/unrecognized_aggregate.rs @@ -43,4 +43,9 @@ async fn unrecognized_aggregate_function_errors_only_on_cross_shard_queries() { .err() .expect("unrecognized aggregate executed successfully"); assert!(err.to_string().contains("pgdog_sum is not yet supported")); + + let direct_to_shard_query = sharded + .fetch_one("/* pgdog_shard: 0 */ SELECT pgdog_sum(lol) FROM unrecognized_agg_test") + .await; + assert_matches!(direct_to_shard_query, Ok(_)); } diff --git a/pgdog/src/frontend/router/parser/cache/ast.rs b/pgdog/src/frontend/router/parser/cache/ast.rs index d5b67292f..a6c3f89e7 100644 --- a/pgdog/src/frontend/router/parser/cache/ast.rs +++ b/pgdog/src/frontend/router/parser/cache/ast.rs @@ -99,6 +99,7 @@ impl Ast { db_schema, user, search_path, + is_direct_to_shard: query.is_direct_to_shard, }) .maybe_rewrite()?; diff --git a/pgdog/src/frontend/router/parser/cache/cache_impl.rs b/pgdog/src/frontend/router/parser/cache/cache_impl.rs index b49580e56..4edc13ff2 100644 --- a/pgdog/src/frontend/router/parser/cache/cache_impl.rs +++ b/pgdog/src/frontend/router/parser/cache/cache_impl.rs @@ -143,6 +143,7 @@ impl Cache { &AstQuery { original_query: query, query_without_comment: query_and_comment.query, + is_direct_to_shard: query_and_comment.shard.is_direct(), }, ctx, prepared_statements, @@ 
-187,6 +188,7 @@ impl Cache { &AstQuery { original_query: query, query_without_comment: query_and_comment.query, + is_direct_to_shard: query_and_comment.shard.is_direct(), }, ctx, prepared_statements, diff --git a/pgdog/src/frontend/router/parser/cache/context.rs b/pgdog/src/frontend/router/parser/cache/context.rs index a8ad44f85..048f23e26 100644 --- a/pgdog/src/frontend/router/parser/cache/context.rs +++ b/pgdog/src/frontend/router/parser/cache/context.rs @@ -52,14 +52,18 @@ pub struct AstQuery<'a> { pub original_query: &'a BufferedQuery, /// Query without comments and other noise. pub query_without_comment: &'a str, + /// Is this query direct-to-shard? + pub is_direct_to_shard: bool, } impl<'a> AstQuery<'a> { /// Create an AstQuery using the raw query text as the cache key. + #[cfg(test)] pub fn from_query(query: &'a BufferedQuery) -> Self { Self { query_without_comment: query.query(), original_query: query, + is_direct_to_shard: false, } } diff --git a/pgdog/src/frontend/router/parser/rewrite/statement/aggregate/mod.rs b/pgdog/src/frontend/router/parser/rewrite/statement/aggregate/mod.rs index 84146a2ae..f66c3a793 100644 --- a/pgdog/src/frontend/router/parser/rewrite/statement/aggregate/mod.rs +++ b/pgdog/src/frontend/router/parser/rewrite/statement/aggregate/mod.rs @@ -15,8 +15,9 @@ impl StatementRewrite<'_> { &mut self, plan: &mut RewritePlan, schema: &Schema, + is_direct_to_shard: bool, ) -> Result<(), Error> { - if self.schema.shards == 1 { + if self.schema.shards == 1 || is_direct_to_shard { return Ok(()); } diff --git a/pgdog/src/frontend/router/parser/rewrite/statement/mod.rs b/pgdog/src/frontend/router/parser/rewrite/statement/mod.rs index 45eab440a..91fc3389f 100644 --- a/pgdog/src/frontend/router/parser/rewrite/statement/mod.rs +++ b/pgdog/src/frontend/router/parser/rewrite/statement/mod.rs @@ -47,6 +47,8 @@ pub struct StatementRewriteContext<'a> { pub user: &'a str, /// Search path for table lookups. 
pub search_path: Option<&'a ParameterValue>, + /// Is the query being run direct-to-shard? + pub is_direct_to_shard: bool, } #[derive(Debug)] @@ -72,6 +74,8 @@ pub struct StatementRewrite<'a> { user: &'a str, /// Search path for table lookups. search_path: Option<&'a ParameterValue>, + /// Are we rewriting a direct-to-shard query? + is_direct_to_shard: bool, } impl<'a> StatementRewrite<'a> { @@ -90,6 +94,7 @@ impl<'a> StatementRewrite<'a> { db_schema: ctx.db_schema, user: ctx.user, search_path: ctx.search_path, + is_direct_to_shard: ctx.is_direct_to_shard, } } @@ -140,7 +145,7 @@ impl<'a> StatementRewrite<'a> { } })?; - self.rewrite_aggregates(&mut plan, self.db_schema)?; + self.rewrite_aggregates(&mut plan, self.db_schema, self.is_direct_to_shard)?; self.limit_offset(&mut plan)?; if self.rewritten {