diff --git a/integration/rust/tests/integration/mod.rs b/integration/rust/tests/integration/mod.rs index 51bc0cb72..0edbba462 100644 --- a/integration/rust/tests/integration/mod.rs +++ b/integration/rust/tests/integration/mod.rs @@ -35,3 +35,4 @@ pub mod timestamp_sorting; pub mod tls_enforced; pub mod tls_reload; pub mod transaction_state; +pub mod unrecognized_aggregate; diff --git a/integration/rust/tests/integration/unrecognized_aggregate.rs b/integration/rust/tests/integration/unrecognized_aggregate.rs new file mode 100644 index 000000000..feeb62599 --- /dev/null +++ b/integration/rust/tests/integration/unrecognized_aggregate.rs @@ -0,0 +1,51 @@ +use rust::setup::{admin_sqlx, connections_sqlx}; +use sqlx::Executor; +use std::assert_matches; + +async fn define_custom_aggregate_fn() { + for connection in connections_sqlx().await { + connection + .execute("DROP AGGREGATE IF EXISTS pgdog_sum (int4)") + .await + .unwrap(); + connection + .execute("CREATE AGGREGATE pgdog_sum (int4) (sfunc = int4_sum, stype = bigint)") + .await + .unwrap(); + connection + .execute("DROP TABLE IF EXISTS unrecognized_agg_test") + .await + .unwrap(); + connection + .execute("CREATE TABLE unrecognized_agg_test (lol int4)") + .await + .unwrap(); + } + admin_sqlx().await.execute("RELOAD").await.unwrap(); +} + +#[tokio::test] +async fn unrecognized_aggregate_function_errors_only_on_cross_shard_queries() { + define_custom_aggregate_fn().await; + let mut connections = connections_sqlx().await; + let sharded = connections.pop().unwrap(); + let unsharded = connections.pop().unwrap(); + + let unsharded_query = unsharded + .fetch_one("SELECT pgdog_sum(lol) FROM unrecognized_agg_test") + .await; + assert_matches!(unsharded_query, Ok(_)); + + let sharded_query = sharded + .fetch_one("SELECT pgdog_sum(lol) FROM unrecognized_agg_test") + .await; + let err = sharded_query + .err() + .expect("unrecognized aggregate executed successfully"); + assert!(err.to_string().contains("pgdog_sum is not yet supported")); + + let direct_to_shard_query = sharded + .fetch_one("/* pgdog_shard: 0 */ SELECT pgdog_sum(lol) FROM unrecognized_agg_test") + .await; + assert_matches!(direct_to_shard_query, Ok(_)); +} diff --git a/pgdog-stats/src/schema.rs b/pgdog-stats/src/schema.rs index 1ceb88278..ee1844919 100644 --- a/pgdog-stats/src/schema.rs +++ b/pgdog-stats/src/schema.rs @@ -1,6 +1,11 @@ use indexmap::IndexMap; use serde::{Deserialize, Serialize}; -use std::{collections::HashMap, hash::Hash, ops::Deref, sync::Arc}; +use std::{ + collections::{HashMap, HashSet}, + hash::Hash, + ops::Deref, + sync::Arc, +}; /// Schema name -> Table name -> Relation pub type Relations = HashMap>; @@ -134,6 +139,7 @@ impl Relation { pub struct SchemaInner { pub search_path: Vec, pub relations: Relations, + pub aggregate_functions: HashSet, } impl Hash for SchemaInner { diff --git a/pgdog/src/backend/pool/connection/aggregate.rs b/pgdog/src/backend/pool/connection/aggregate.rs index e7c5ef814..03e830055 100644 --- a/pgdog/src/backend/pool/connection/aggregate.rs +++ b/pgdog/src/backend/pool/connection/aggregate.rs @@ -803,7 +803,7 @@ mod test { } fn parse(stmt: &str) -> Aggregate { - Aggregate::parse(&select(stmt)) + Aggregate::parse(&select(stmt), &Default::default()).unwrap() } #[test] diff --git a/pgdog/src/backend/schema/mod.rs b/pgdog/src/backend/schema/mod.rs index b836fcdf8..cda7da25c 100644 --- a/pgdog/src/backend/schema/mod.rs +++ b/pgdog/src/backend/schema/mod.rs @@ -60,9 +60,18 @@ impl Schema { .map(|p| p.trim().replace("\"", "")) .collect(); + let aggregate_functions = server + .fetch_all::( + "SELECT DISTINCT proname FROM pg_proc INNER JOIN pg_aggregate ON oid = aggfnoid", + ) + .await? + .into_iter() + .collect(); + let inner = SchemaInner { search_path, relations, + aggregate_functions, }; Ok(Self { @@ -79,6 +88,15 @@ impl Schema { pub(crate) fn from_parts( search_path: Vec, relations: HashMap<(String, String), Relation>, + ) -> Self { + Self::from_parts_with_agg(search_path, relations, Vec::new()) + } + + #[cfg(test)] + pub(crate) fn from_parts_with_agg( + search_path: Vec, + relations: HashMap<(String, String), Relation>, + aggregate_functions: Vec, ) -> Self { let mut nested: StatsRelations = HashMap::new(); for ((schema, name), relation) in relations { @@ -91,6 +109,7 @@ impl Schema { inner: StatsSchema::new(SchemaInner { search_path, relations: nested, + aggregate_functions: aggregate_functions.into_iter().collect(), }), } } @@ -555,4 +574,24 @@ mod test { ); } } + + #[tokio::test] + async fn test_loading_aggregate_functions() { + let mut server = test_server().await; + server.execute_checked("BEGIN").await.unwrap(); + + let schema = Schema::load(&mut server).await.unwrap(); + assert!(!schema + .aggregate_functions + .contains(&String::from("pgdog_sum"))); + + server + .execute_checked("CREATE AGGREGATE pgdog_sum (int4) (sfunc = int4_sum, stype = bigint)") + .await + .unwrap(); + let schema = Schema::load(&mut server).await.unwrap(); + assert!(schema + .aggregate_functions + .contains(&String::from("pgdog_sum"))); + } } diff --git a/pgdog/src/frontend/router/parser/aggregate.rs b/pgdog/src/frontend/router/parser/aggregate.rs index fd88a4410..cfa322340 100644 --- a/pgdog/src/frontend/router/parser/aggregate.rs +++ b/pgdog/src/frontend/router/parser/aggregate.rs @@ -2,7 +2,8 @@ use pg_query::protobuf::Integer; use pg_query::protobuf::{a_const::Val, Node, SelectStmt, String as PgQueryString}; use pg_query::NodeEnum; -use crate::frontend::router::parser::{ExpressionRegistry, Function}; +use super::{rewrite::statement::Error, ExpressionRegistry, Function}; +use crate::backend::schema::Schema; #[derive(Debug, Clone, PartialEq)] pub struct AggregateTarget { @@ -121,7 +122,7 @@ fn columns_match(group_by_names: &[&String], select_names: &[&String]) -> bool { impl Aggregate { /// Figure out what aggregates are present and which ones PgDog supports. - pub fn parse(stmt: &SelectStmt) -> Self { + pub fn parse(stmt: &SelectStmt, schema: &Schema) -> Result { let mut targets = vec![]; let mut registry = ExpressionRegistry::new(); let group_by = stmt @@ -168,7 +169,13 @@ impl Aggregate { "stddev_pop" => Some(AggregateFunction::StddevPop), "variance" | "var_samp" => Some(AggregateFunction::VarSamp), "var_pop" => Some(AggregateFunction::VarPop), - _ => None, + fname => { + if schema.aggregate_functions.contains(fname) { + return Err(Error::UnsupportedAggregate(String::from(fname))); + } else { + None + } + } }; if let Some(function) = function { @@ -196,7 +203,7 @@ impl Aggregate { } } - Self { targets, group_by } + Ok(Self { targets, group_by }) } pub fn targets(&self) -> &[AggregateTarget] { @@ -259,7 +266,7 @@ mod test { } fn parse(stmt: &str) -> Aggregate { - Aggregate::parse(&select(stmt)) + Aggregate::parse(&select(stmt), &Default::default()).unwrap() } #[test] @@ -455,4 +462,23 @@ mod test { assert_eq!(aggr.group_by(), &[0]); assert_eq!(aggr.targets().len(), 1); } + + #[test] + fn test_unrecognized_aggregate_function_errors() { + let schema_with_agg = Schema::from_parts_with_agg( + Vec::new(), + Default::default(), + vec![String::from("mysum")], + ); + let schema_without_agg = Default::default(); + let query = select("SELECT mysum(lol) FROM example"); + + // A random function that isn't listed as aggregate in the schema + // doesn't require special support on our end, so we should be fine. + assert!(Aggregate::parse(&query, &schema_without_agg).is_ok()); + // If we see an aggregate function we don't recognize, we can't + // process the query correctly, since we need to combine the + // results from each shard. + assert!(Aggregate::parse(&query, &schema_with_agg).is_err()); + } } diff --git a/pgdog/src/frontend/router/parser/cache/ast.rs b/pgdog/src/frontend/router/parser/cache/ast.rs index d5b67292f..a6c3f89e7 100644 --- a/pgdog/src/frontend/router/parser/cache/ast.rs +++ b/pgdog/src/frontend/router/parser/cache/ast.rs @@ -99,6 +99,7 @@ impl Ast { db_schema, user, search_path, + is_direct_to_shard: query.is_direct_to_shard, }) .maybe_rewrite()?; diff --git a/pgdog/src/frontend/router/parser/cache/cache_impl.rs b/pgdog/src/frontend/router/parser/cache/cache_impl.rs index 5d66d3430..4edc13ff2 100644 --- a/pgdog/src/frontend/router/parser/cache/cache_impl.rs +++ b/pgdog/src/frontend/router/parser/cache/cache_impl.rs @@ -10,8 +10,8 @@ use parking_lot::{Mutex, RawMutex}; use std::sync::Arc; use tracing::debug; -use super::super::{Error, Route}; -use super::{super::parse_edge_comment, Ast, AstContext, AstQuery}; +use super::super::{Error, Route, ShardOptionExt}; +use super::{super::parse_edge_comment, key_pair::KeyPair, Ast, AstContext, AstQuery}; use crate::frontend::{BufferedQuery, PreparedStatements}; static CACHE: Lazy = Lazy::new(Cache::new); @@ -47,7 +47,7 @@ impl Stats { #[derive(Debug)] pub(super) struct Inner { /// Least-recently-used cache. - queries: LruCache, Ast>, + queries: LruCache<(Arc, bool), Ast>, /// Cache global stats. pub(super) stats: Stats, } @@ -119,10 +119,16 @@ impl Cache { { let mut guard = self.inner.lock(); - let ast = guard.queries.get_mut(query_and_comment.query).map(|entry| { - entry.stats.lock().hits += 1; // No contention on this. - entry.clone() - }); + let ast = guard + .queries + .get_mut( + &(query_and_comment.query, query_and_comment.shard.is_direct()) + as &dyn KeyPair, + ) + .map(|entry| { + entry.stats.lock().hits += 1; // No contention on this. + entry.clone() + }); if let Some(mut ast) = ast { guard.stats.hits += 1; ast.comment_role = query_and_comment.role; @@ -137,6 +143,7 @@ impl Cache { &AstQuery { original_query: query, query_without_comment: query_and_comment.query, + is_direct_to_shard: query_and_comment.shard.is_direct(), }, ctx, prepared_statements, @@ -153,9 +160,13 @@ impl Cache { // (direct-shard) variant. let cacheable = entry.comment_shard.is_none() || entry.rewrite_plan.is_empty(); if cacheable { - guard - .queries - .put(entry.query_without_comment.clone(), entry.clone()); + guard.queries.put( + ( + entry.query_without_comment.clone(), + entry.comment_shard.is_direct(), + ), + entry.clone(), + ); } guard.stats.misses += 1; guard.stats.parse_time += parse_time; @@ -177,6 +188,7 @@ impl Cache { &AstQuery { original_query: query, query_without_comment: query_and_comment.query, + is_direct_to_shard: query_and_comment.shard.is_direct(), }, ctx, prepared_statements, @@ -208,7 +220,10 @@ impl Cache { { let mut guard = self.inner.lock(); - if let Some(entry) = guard.queries.get(normalized.as_str()) { + if let Some(entry) = guard + .queries + .get(&(normalized.as_str(), false) as &dyn KeyPair) + { entry.update_stats(route); guard.stats.hits += 1; return Ok(()); @@ -219,7 +234,7 @@ impl Cache { entry.update_stats(route); let mut guard = self.inner.lock(); - guard.queries.put(normalized.into(), entry); + guard.queries.put((normalized.into(), false), entry); guard.stats.misses += 1; Ok(()) @@ -259,7 +274,7 @@ impl Cache { .lock() .queries .iter() - .map(|i| (i.0.clone(), i.1.clone())) + .map(|i| (i.0 .0.clone(), i.1.clone())) .collect() } diff --git a/pgdog/src/frontend/router/parser/cache/context.rs b/pgdog/src/frontend/router/parser/cache/context.rs index a8ad44f85..048f23e26 100644 --- a/pgdog/src/frontend/router/parser/cache/context.rs +++ b/pgdog/src/frontend/router/parser/cache/context.rs @@ -52,14 +52,18 @@ pub struct AstQuery<'a> { pub original_query: &'a BufferedQuery, /// Query without comments and other noise. pub query_without_comment: &'a str, + /// Is this query direct-to-shard? + pub is_direct_to_shard: bool, } impl<'a> AstQuery<'a> { /// Create an AstQuery using the raw query text as the cache key. + #[cfg(test)] pub fn from_query(query: &'a BufferedQuery) -> Self { Self { query_without_comment: query.query(), original_query: query, + is_direct_to_shard: false, } } diff --git a/pgdog/src/frontend/router/parser/cache/key_pair.rs b/pgdog/src/frontend/router/parser/cache/key_pair.rs new file mode 100644 index 000000000..6278375f9 --- /dev/null +++ b/pgdog/src/frontend/router/parser/cache/key_pair.rs @@ -0,0 +1,106 @@ +//! A trait object used to index into a map-ish container using a tuple key. +//! This trait is needed when we want to use borrowed types to look up owned +//! data. There is no easy way to have a type which is semantically (String, +//! String) implement Borrow<(&str, &str)>. The signature of Borrow would +//! require us to manifest a tuple reference out of thin air. +//! +//! Instead we use a trait object which lets us erase the type entirely, +//! creating a reference to each element of the tuple when needed for comparison +//! and hashing. +use std::borrow::Borrow; +use std::hash::{Hash, Hasher}; +use std::sync::Arc; + +pub(crate) trait KeyPair { + fn left(&self) -> &A; + fn right(&self) -> &B; +} + +impl KeyPair for (A, B) { + fn left(&self) -> &A { + &self.0 + } + + fn right(&self) -> &B { + &self.1 + } +} + +impl KeyPair for (&A, &B) { + fn left(&self) -> &A { + self.0 + } + + fn right(&self) -> &B { + self.1 + } +} + +impl KeyPair for (&A, B) { + fn left(&self) -> &A { + self.0 + } + + fn right(&self) -> &B { + &self.1 + } +} + +impl KeyPair for (A, &B) { + fn left(&self) -> &A { + &self.0 + } + + fn right(&self) -> &B { + self.1 + } +} + +impl KeyPair for (Arc, B) { + fn left(&self) -> &A { + &self.0 + } + + fn right(&self) -> &B { + &self.1 + } +} + +impl<'a, A, B, X, Y> Borrow + 'a> for (X, Y) +where + A: ?Sized, + B: ?Sized, + (X, Y): KeyPair + 'a, +{ + fn borrow(&self) -> &(dyn KeyPair + 'a) { + self + } +} + +impl<'a, A, B> PartialEq for dyn KeyPair + 'a +where + A: PartialEq + ?Sized + 'a, + B: PartialEq + ?Sized + 'a, +{ + fn eq(&self, other: &Self) -> bool { + self.left() == other.left() && self.right() == other.right() + } +} + +impl<'a, A, B> Eq for dyn KeyPair + 'a +where + A: Eq + ?Sized + 'a, + B: Eq + ?Sized + 'a, +{ +} + +impl<'a, A, B> Hash for dyn KeyPair + 'a +where + A: ?Sized + 'a, + B: ?Sized + 'a, + for<'b, 'c> (&'b A, &'c B): Hash, +{ + fn hash(&self, state: &mut H) { + (self.left(), self.right()).hash(state) + } +} diff --git a/pgdog/src/frontend/router/parser/cache/mod.rs b/pgdog/src/frontend/router/parser/cache/mod.rs index 0a3437813..5eee7c1bc 100644 --- a/pgdog/src/frontend/router/parser/cache/mod.rs +++ b/pgdog/src/frontend/router/parser/cache/mod.rs @@ -6,6 +6,7 @@ pub mod ast; pub mod cache_impl; pub mod context; pub mod fingerprint; +pub mod key_pair; pub use ast::*; pub use cache_impl::*; diff --git a/pgdog/src/frontend/router/parser/mod.rs b/pgdog/src/frontend/router/parser/mod.rs index ccd86044a..f8973b5f9 100644 --- a/pgdog/src/frontend/router/parser/mod.rs +++ b/pgdog/src/frontend/router/parser/mod.rs @@ -55,6 +55,7 @@ pub use query::QueryParser; pub use rewrite::{ Assignment, AssignmentValue, ShardKeyRewritePlan, StatementRewrite, StatementRewriteContext, }; +pub(crate) use route::ShardOptionExt; pub use route::{Route, Shard, ShardWithPriority, ShardsWithPriority}; pub use schema::Schema; pub use sequence::{OwnedSequence, Sequence}; diff --git a/pgdog/src/frontend/router/parser/query/select.rs b/pgdog/src/frontend/router/parser/query/select.rs index bcfc38370..583ff0cd8 100644 --- a/pgdog/src/frontend/router/parser/query/select.rs +++ b/pgdog/src/frontend/router/parser/query/select.rs @@ -138,7 +138,7 @@ impl QueryParser { } let shard = Self::converge(&shards, ConvergeAlgorithm::default()); - let aggregates = Aggregate::parse(stmt); + let aggregates = Aggregate::parse(stmt, &context.router_context.schema)?; let limit = LimitClause::new(stmt, context.router_context.bind).limit_offset()?; let distinct = Distinct::new(stmt).distinct()?; diff --git a/pgdog/src/frontend/router/parser/rewrite/statement/aggregate/engine.rs b/pgdog/src/frontend/router/parser/rewrite/statement/aggregate/engine.rs index c7af4978e..631400860 100644 --- a/pgdog/src/frontend/router/parser/rewrite/statement/aggregate/engine.rs +++ b/pgdog/src/frontend/router/parser/rewrite/statement/aggregate/engine.rs @@ -301,7 +301,7 @@ mod tests { fn rewrite(sql: &str) -> (ParseResult, RewriteOutput) { let mut parsed = pg_query::parse(sql).unwrap().protobuf; let stmt = select(&mut parsed); - let aggregate = Aggregate::parse(stmt); + let aggregate = Aggregate::parse(stmt, &Default::default()).unwrap(); let output = AggregatesRewrite.rewrite_select(stmt, &aggregate); (parsed, output) } @@ -326,7 +326,7 @@ mod tests { assert!(!helper.distinct); assert!(matches!(helper.kind, HelperKind::Count)); - let aggregate = Aggregate::parse(select(&mut ast)); + let aggregate = Aggregate::parse(select(&mut ast), &Default::default()).unwrap(); assert_eq!(aggregate.targets().len(), 2); assert!(aggregate .targets() @@ -353,7 +353,7 @@ mod tests { assert!(!helper.distinct); assert!(matches!(helper.kind, HelperKind::Count)); - let aggregate = Aggregate::parse(select(&mut ast)); + let aggregate = Aggregate::parse(select(&mut ast), &Default::default()).unwrap(); assert_eq!(aggregate.targets().len(), 3); assert!( aggregate @@ -381,7 +381,7 @@ mod tests { assert_eq!(helper_discount.helper_column, 3); assert!(matches!(helper_discount.kind, HelperKind::Count)); - let aggregate = Aggregate::parse(select(&mut ast)); + let aggregate = Aggregate::parse(select(&mut ast), &Default::default()).unwrap(); assert_eq!(aggregate.targets().len(), 4); assert_eq!( aggregate diff --git a/pgdog/src/frontend/router/parser/rewrite/statement/aggregate/mod.rs b/pgdog/src/frontend/router/parser/rewrite/statement/aggregate/mod.rs index 0dca2c6ee..f66c3a793 100644 --- a/pgdog/src/frontend/router/parser/rewrite/statement/aggregate/mod.rs +++ b/pgdog/src/frontend/router/parser/rewrite/statement/aggregate/mod.rs @@ -2,6 +2,7 @@ pub mod engine; pub mod plan; use super::{Error, RewritePlan, StatementRewrite}; +use crate::backend::schema::Schema; use crate::frontend::router::parser::aggregate::Aggregate; use pg_query::NodeEnum; @@ -10,8 +11,13 @@ pub use plan::{AggregateRewritePlan, HelperKind, HelperMapping, RewriteOutput}; impl StatementRewrite<'_> { /// Add missing COUNT(*) and other helps when using aggregates. - pub(super) fn rewrite_aggregates(&mut self, plan: &mut RewritePlan) -> Result<(), Error> { - if self.schema.shards == 1 { + pub(super) fn rewrite_aggregates( + &mut self, + plan: &mut RewritePlan, + schema: &Schema, + is_direct_to_shard: bool, + ) -> Result<(), Error> { + if self.schema.shards == 1 || is_direct_to_shard { return Ok(()); } @@ -27,7 +33,7 @@ impl StatementRewrite<'_> { return Ok(()); }; - let aggregate = Aggregate::parse(&select); + let aggregate = Aggregate::parse(&select, schema)?; if aggregate.is_empty() { return Ok(()); } diff --git a/pgdog/src/frontend/router/parser/rewrite/statement/error.rs b/pgdog/src/frontend/router/parser/rewrite/statement/error.rs index 1e0858159..7ea21e229 100644 --- a/pgdog/src/frontend/router/parser/rewrite/statement/error.rs +++ b/pgdog/src/frontend/router/parser/rewrite/statement/error.rs @@ -37,4 +37,7 @@ pub enum Error { #[error("missing AST on request")] MissingAst, + + #[error("aggregate function {0} is not yet supported")] + UnsupportedAggregate(String), } diff --git a/pgdog/src/frontend/router/parser/rewrite/statement/mod.rs b/pgdog/src/frontend/router/parser/rewrite/statement/mod.rs index b03fab315..91fc3389f 100644 --- a/pgdog/src/frontend/router/parser/rewrite/statement/mod.rs +++ b/pgdog/src/frontend/router/parser/rewrite/statement/mod.rs @@ -47,6 +47,8 @@ pub struct StatementRewriteContext<'a> { pub user: &'a str, /// Search path for table lookups. pub search_path: Option<&'a ParameterValue>, + /// Is the query being run direct-to-shard? + pub is_direct_to_shard: bool, } #[derive(Debug)] @@ -72,6 +74,8 @@ pub struct StatementRewrite<'a> { user: &'a str, /// Search path for table lookups. search_path: Option<&'a ParameterValue>, + /// Are we rewriting a direct-to-shard query? + is_direct_to_shard: bool, } impl<'a> StatementRewrite<'a> { @@ -90,6 +94,7 @@ impl<'a> StatementRewrite<'a> { db_schema: ctx.db_schema, user: ctx.user, search_path: ctx.search_path, + is_direct_to_shard: ctx.is_direct_to_shard, } } @@ -140,7 +145,7 @@ impl<'a> StatementRewrite<'a> { } })?; - self.rewrite_aggregates(&mut plan)?; + self.rewrite_aggregates(&mut plan, self.db_schema, self.is_direct_to_shard)?; self.limit_offset(&mut plan)?; if self.rewritten { diff --git a/pgdog/src/frontend/router/parser/route.rs b/pgdog/src/frontend/router/parser/route.rs index 06ad11d35..b63a1866e 100644 --- a/pgdog/src/frontend/router/parser/route.rs +++ b/pgdog/src/frontend/router/parser/route.rs @@ -50,6 +50,16 @@ impl Shard { } } +pub(crate) trait ShardOptionExt { + fn is_direct(&self) -> bool; +} + +impl ShardOptionExt for Option { + fn is_direct(&self) -> bool { + self.as_ref().map(Shard::is_direct).unwrap_or(false) + } +} + impl From> for Shard { fn from(value: Option) -> Self { if let Some(value) = value {